From 70031d8971136583647125d967ace7b9c831ed00 Mon Sep 17 00:00:00 2001 From: René 'Necoro' Neumann Date: Thu, 17 Oct 2024 18:07:11 +0200 Subject: Save User as part of the context --- pages/login.go | 9 +++++---- pages/page.go | 45 ++++++++++++++++++++++++++++++--------------- pages/pages.go | 9 ++------- pages/pages.templ | 8 ++++---- pages/pages_templ.go | 12 ++++++------ 5 files changed, 47 insertions(+), 36 deletions(-) diff --git a/pages/login.go b/pages/login.go index 14b1ce1..84119e9 100644 --- a/pages/login.go +++ b/pages/login.go @@ -6,6 +6,7 @@ import ( "errors" "gosten/csrf" "gosten/form" + "gosten/model" "gosten/session" "log" "net/http" @@ -17,8 +18,8 @@ import ( type userContextKey struct{} -func userId(r *http.Request) int32 { - return r.Context().Value(userContextKey{}).(int32) +func getUser(ctx context.Context) model.User { + return ctx.Value(userContextKey{}).(model.User) } const ( @@ -34,7 +35,7 @@ func RequireAuth(next http.Handler) http.Handler { u, err := Q.GetUserById(r.Context(), s.UserID) if err == nil { // authenticated --> done - ctx := context.WithValue(r.Context(), userContextKey{}, u.ID) + ctx := context.WithValue(r.Context(), userContextKey{}, u) next.ServeHTTP(w, r.WithContext(ctx)) return } @@ -126,5 +127,5 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { } func showLoginPage(r *http.Request, w http.ResponseWriter, u user) { - render(login(u), w, r) + render(login(u))(w, r) } diff --git a/pages/page.go b/pages/page.go index 6ce8cee..c10fb21 100644 --- a/pages/page.go +++ b/pages/page.go @@ -20,38 +20,53 @@ type Page interface { http.Handler } -type dataFunc[T any] func(r *http.Request, uid int32) T +type dataFunc[T any] func(ctx context.Context) T type tplFunc[T any] func(T) templ.Component -func simpleByQuery[T any](tpl tplFunc[T], query func(ctx context.Context, id int32) (T, error)) Page { - dataFn := func(r *http.Request, uid int32) T { - d, _ := query(r.Context(), uid) - return d - } - return simple(tpl, dataFn) +func simple(c templ.Component) Page { + r := chi.NewRouter() + r.Get("/", render(c)) + return r } -func simple[T any](tpl tplFunc[T], dataFn dataFunc[T]) Page { +func simpleWithData[T any](tpl tplFunc[T], dataFn dataFunc[T]) Page { r := chi.NewRouter() r.Get("/", func(w http.ResponseWriter, r *http.Request) { - input := dataFn(r, userId(r)) + input := dataFn(r.Context()) c := tpl(input) - render(c, w, r) + render(c)(w, r) }) return r } +func simpleByQuery[T any](tpl tplFunc[T], query func(context.Context, int32) (T, error)) Page { + dataFn := func(ctx context.Context) T { + d, err := query(ctx, getUser(ctx).ID) + if err != nil { + log.Panic(err.Error()) + } + return d + } + return simpleWithData(tpl, dataFn) +} + type ctxPath struct{} -func render(c templ.Component, w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), ctxPath{}, r.URL.Path) - if err := c.Render(ctx, w); err != nil { - log.Panic(err.Error()) +func render(c templ.Component) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), ctxPath{}, r.URL.Path) + if err := c.Render(ctx, w); err != nil { + log.Panic(err.Error()) + } } } +func getCurrPath(ctx context.Context) string { + return ctx.Value(ctxPath{}).(string) +} + func isCurrPath(ctx context.Context, path string) bool { - currPath := ctx.Value(ctxPath{}).(string) + currPath := getCurrPath(ctx) if path[0] != '/' { path = "/" + path } diff --git a/pages/pages.go b/pages/pages.go index eb7a3f6..0f03cdd 100644 --- a/pages/pages.go +++ b/pages/pages.go @@ -5,10 +5,7 @@ import ( ) func Init() Page { - return simple(index, func(r *http.Request, uid int32) string { - u, _ := Q.GetUserById(r.Context(), uid) - return u.Name - }) + return simple(index()) } func Recur() Page { @@ -20,7 +17,5 @@ func Categories() Page { } func NotFound() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - render(notfound(r.RequestURI), w, r) - } + return render(notfound()) } diff --git a/pages/pages.templ b/pages/pages.templ index d2bfad2..67e2792 100644 --- a/pages/pages.templ +++ b/pages/pages.templ @@ -2,22 +2,22 @@ package pages import "gosten/model" -templ notfound(uri string) { +templ notfound() { @content() { } } -templ index(user string) { +templ index() { @content() { - Logged in with user: {user} + Logged in with user: {getUser(ctx).Name} } } diff --git a/pages/pages_templ.go b/pages/pages_templ.go index fe1ad49..1f526ca 100644 --- a/pages/pages_templ.go +++ b/pages/pages_templ.go @@ -10,7 +10,7 @@ import templruntime "github.com/a-h/templ/runtime" import "gosten/model" -func notfound(uri string) templ.Component { +func notfound() templ.Component { return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil { @@ -48,9 +48,9 @@ func notfound(uri string) templ.Component { return templ_7745c5c3_Err } var templ_7745c5c3_Var3 string - templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(uri) + templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(getCurrPath(ctx)) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `pages/pages.templ`, Line: 12, Col: 47} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `pages/pages.templ`, Line: 12, Col: 60} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3)) if templ_7745c5c3_Err != nil { @@ -70,7 +70,7 @@ func notfound(uri string) templ.Component { }) } -func index(user string) templ.Component { +func index() templ.Component { return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil { @@ -108,9 +108,9 @@ func index(user string) templ.Component { return templ_7745c5c3_Err } var templ_7745c5c3_Var6 string - templ_7745c5c3_Var6, templ_7745c5c3_Err = templ.JoinStringErrs(user) + templ_7745c5c3_Var6, templ_7745c5c3_Err = templ.JoinStringErrs(getUser(ctx).Name) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `pages/pages.templ`, Line: 20, Col: 30} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `pages/pages.templ`, Line: 20, Col: 43} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var6)) if templ_7745c5c3_Err != nil { -- cgit v1.2.3-70-g09d2