From 869fb9691f877116d5b15a92de006d0daf4d70e5 Mon Sep 17 00:00:00 2001 From: René 'Necoro' Neumann Date: Thu, 17 Oct 2024 00:27:08 +0200 Subject: Restructure and change to chi as muxing framework --- auth.go | 130 ----------------------------------------------------- csrf.go | 31 ------------- csrf/csrf.go | 31 +++++++++++++ go.mod | 8 ++-- go.sum | 18 +++----- main.go | 87 +++++++++++------------------------ pages/login.go | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++ pages/logout.go | 15 +++++++ pages/page.go | 58 ++++++++++++++++++++++++ pages/pages.go | 26 +++++++++++ session.go | 79 -------------------------------- session/session.go | 87 +++++++++++++++++++++++++++++++++++ 12 files changed, 375 insertions(+), 318 deletions(-) delete mode 100644 auth.go delete mode 100644 csrf.go create mode 100644 csrf/csrf.go create mode 100644 pages/login.go create mode 100644 pages/logout.go create mode 100644 pages/page.go create mode 100644 pages/pages.go delete mode 100644 session.go create mode 100644 session/session.go diff --git a/auth.go b/auth.go deleted file mode 100644 index 7e23cd6..0000000 --- a/auth.go +++ /dev/null @@ -1,130 +0,0 @@ -package main - -import ( - "context" - "database/sql" - "errors" - "log" - "net/http" - "net/url" - - "golang.org/x/crypto/bcrypt" -) - -type userContextKey struct{} - -const ( - sessionDuration = 86400 * 7 // 7 days - loginQueryMarker = "next" -) - -func RequireAuth(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - s := session(r) - - if !s.s.IsNew && s.Authenticated { - u, err := Q.GetUserById(r.Context(), s.UserID) - if err == nil { - // authenticated --> done - ctx := context.WithValue(r.Context(), userContextKey{}, u.ID) - next.ServeHTTP(w, r.WithContext(ctx)) - return - } - - s.Invalidate() - s.Save(w, r) - } - - // redirect to login with next-param - v := url.Values{} - v.Set(loginQueryMarker, r.URL.Path) - redirPath := "/login?" + v.Encode() - http.Redirect(w, r, redirPath, http.StatusFound) - }) -} - -func checkLogin(ctx context.Context, user User) (bool, int32) { - dbUser, err := Q.GetUserByName(ctx, user.Name) - if err == nil { - hash := []byte(dbUser.Pwd) - pwd := []byte(user.Password) - - if bcrypt.CompareHashAndPassword(hash, pwd) != nil { - return false, 0 - } - } else if errors.Is(err, sql.ErrNoRows) { - return false, 0 - } else { - log.Panicf("Could not load user '%s': %v", user.Name, err) - } - - return true, dbUser.ID -} - -func handleLogin(w http.ResponseWriter, r *http.Request) { - u := User{} - parseForm(r, &u) - - ok, userId := checkLogin(r.Context(), u) - - if !ok { - u.Errors = []error{fieldError{"Password", "Invalid"}} - showLoginPage(w, u) - return - } - - s := session(r) - if u.RememberMe { - s.MaxAge(sessionDuration) // 1 week - } else { - s.MaxAge(0) - } - - s.UserID = userId - s.Authenticated = true - s.Save(w, r) - - // redirect - next := r.URL.Query().Get(loginQueryMarker) - if next == "" { - next = "/" - } - http.Redirect(w, r, next, http.StatusFound) -} - -func handleLogout() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - s := session(r) - s.Invalidate() - s.Save(w, r) - - http.Redirect(w, r, "/", http.StatusFound) - } -} - -type User struct { - Name string `form:"options=required,autofocus"` - Password string `form:"type=password;options=required"` - RememberMe bool `form:"type=checkbox;value=y;options=checked"` - Errors []error `form:"-"` - Csrf -} - -func showLoginPage(w http.ResponseWriter, u User) { - showTemplate(w, "login", u) -} - -func loginPage() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if session(r).Authenticated { - http.Redirect(w, r, "/", http.StatusFound) - } - u := User{} - u.SetCsrfField(r) - showLoginPage(w, u) - } -} - -func userId(r *http.Request) int32 { - return r.Context().Value(userContextKey{}).(int32) -} diff --git a/csrf.go b/csrf.go deleted file mode 100644 index 4539825..0000000 --- a/csrf.go +++ /dev/null @@ -1,31 +0,0 @@ -package main - -import ( - "html/template" - "net/http" - - "github.com/gorilla/csrf" - "github.com/gorilla/securecookie" -) - -func csrfHandler(next http.Handler) http.Handler { - return csrf.Protect( - securecookie.GenerateRandomKey(32), - csrf.SameSite(csrf.SameSiteStrictMode), - csrf.FieldName("csrf.csrffield"), // should match the structure in `Csrf` - )(next) -} - -// Csrf handles the CSRF data for a form. -// Include it verbatim and then use `{{.CsrfField}}` in templates. -type Csrf struct { - CsrfField template.HTML `form:"-" schema:"-"` -} - -func (c *Csrf) SetCsrfField(r *http.Request) { - c.CsrfField = csrf.TemplateField(r) -} - -type WithCsrf interface { - SetCsrfField(r *http.Request) -} diff --git a/csrf/csrf.go b/csrf/csrf.go new file mode 100644 index 0000000..18fdb81 --- /dev/null +++ b/csrf/csrf.go @@ -0,0 +1,31 @@ +package csrf + +import ( + "html/template" + "net/http" + + "github.com/gorilla/csrf" + "github.com/gorilla/securecookie" +) + +func Handler() func(http.Handler) http.Handler { + return csrf.Protect( + securecookie.GenerateRandomKey(32), + csrf.SameSite(csrf.SameSiteStrictMode), + csrf.FieldName("csrf.csrffield"), // should match the structure in `Csrf` + ) +} + +// Csrf handles the CSRF data for a form. +// Include it verbatim and then use `{{.CsrfField}}` in templates. +type Csrf struct { + CsrfField template.HTML `form:"-" schema:"-"` +} + +func (c *Csrf) SetCsrfField(r *http.Request) { + c.CsrfField = csrf.TemplateField(r) +} + +type Enabled interface { + SetCsrfField(r *http.Request) +} diff --git a/go.mod b/go.mod index c63ac1e..f1acb5c 100644 --- a/go.mod +++ b/go.mod @@ -5,22 +5,20 @@ go 1.23 toolchain go1.23.1 require ( - github.com/Necoro/form v0.0.0-20240211223301-6fa9f8196e1e + github.com/go-chi/chi/v5 v5.1.0 github.com/gorilla/csrf v1.7.2 - github.com/gorilla/handlers v1.5.2 github.com/gorilla/schema v1.4.1 github.com/gorilla/securecookie v1.1.2 github.com/gorilla/sessions v1.4.0 github.com/jackc/pgx/v5 v5.7.1 github.com/joho/godotenv v1.5.1 - golang.org/x/crypto v0.27.0 + golang.org/x/crypto v0.28.0 ) require ( - github.com/felixge/httpsnoop v1.0.4 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect golang.org/x/sync v0.8.0 // indirect - golang.org/x/text v0.18.0 // indirect + golang.org/x/text v0.19.0 // indirect ) diff --git a/go.sum b/go.sum index 4564b09..82440f9 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,12 @@ -github.com/Necoro/form v0.0.0-20240211223301-6fa9f8196e1e h1:v3DDTGBMt9pclCdG7jRyNAABmtJw3uky/Xoi/DfbWNs= -github.com/Necoro/form v0.0.0-20240211223301-6fa9f8196e1e/go.mod h1:JxpmgZ5hjL6fyhBoZ4HAUadkp7DNqWlHbFL7l8oic4Y= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= -github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= -github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E= github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= @@ -36,12 +30,12 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index 05ee948..34a6719 100644 --- a/main.go +++ b/main.go @@ -6,16 +6,16 @@ import ( "net/http" "os" - "github.com/gorilla/handlers" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" "github.com/jackc/pgx/v5/pgxpool" "github.com/joho/godotenv" - "gosten/model" - "gosten/templ" + "gosten/csrf" + "gosten/pages" + "gosten/session" ) -var Q *model.Queries - func checkEnvEntry(e string) { if os.Getenv(e) == "" { log.Fatalf("Variable '%s' not set", e) @@ -43,66 +43,31 @@ func main() { } defer db.Close() - Q = model.New(db) - - mux := http.NewServeMux() - - // handlers that DO NOT require authentification - mux.Handle("GET /login", loginPage()) - mux.HandleFunc("POST /login", handleLogin) - mux.Handle("GET /logout", handleLogout()) - mux.Handle("/static/", http.StripPrefix("/static", http.FileServer(http.Dir("static")))) - mux.Handle("/favicon.ico", http.NotFoundHandler()) - - handler := sessionHandler(csrfHandler(mux)) - handler = handlers.CombinedLoggingHandler(os.Stderr, handler) - handler = handlers.ProxyHeaders(handler) + pages.Connect(db) - // setup authentification - authMux := http.NewServeMux() - mux.Handle("/", RequireAuth(authMux)) + router := chi.NewRouter() - // handlers that required authentification - authMux.Handle("GET /{$}", indexPage()) - authMux.Handle("GET /recur/{$}", recurPage()) - authMux.Handle("GET /categories/{$}", categoriesPage()) - authMux.Handle("GET /", notfound()) - - log.Fatal(http.ListenAndServe(os.Getenv("GOSTEN_ADDRESS"), handler)) -} - -func showTemplate(w http.ResponseWriter, tpl string, data any) { - if err := templ.Lookup(tpl).Execute(w, data); err != nil { - log.Panicf("Executing '%s' with %+v: %v", tpl, data, err) - } -} + // A good base middleware stack + router.Use(middleware.RequestID) + router.Use(middleware.RealIP) + router.Use(middleware.CleanPath) + router.Use(middleware.Logger) + router.Use(middleware.Recoverer) -func indexPage() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid := userId(r) - u, _ := Q.GetUserById(r.Context(), uid) - showTemplate(w, "index", u.Name) - } -} + // handlers that DO NOT require authentification + router.Handle("/static/*", http.StripPrefix("/static", http.FileServer(http.Dir("static")))) + router.Get("/favicon.ico", http.NotFound) -func recurPage() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid := userId(r) - exps, _ := Q.GetRecurExpenses(r.Context(), uid) - showTemplate(w, "recur", exps) - } -} + appRouter := router.With(csrf.Handler(), session.Handler()) + appRouter.Get("/login", pages.Login()) + appRouter.Post("/login", pages.HandleLogin) + appRouter.Get("/logout", pages.Logout()) -func categoriesPage() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid := userId(r) - cats, _ := Q.GetCategoriesOrdered(r.Context(), uid) - showTemplate(w, "categories", cats) - } -} + authRouter := appRouter.With(pages.RequireAuth) + authRouter.Mount("/", pages.Init()) + authRouter.Mount("/recur", pages.Recur()) + authRouter.Mount("/categories", pages.Categories()) + authRouter.NotFound(pages.NotFound()) -func notfound() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - showTemplate(w, "404", r.RequestURI) - } + log.Fatal(http.ListenAndServe(os.Getenv("GOSTEN_ADDRESS"), router)) } diff --git a/pages/login.go b/pages/login.go new file mode 100644 index 0000000..fb7859a --- /dev/null +++ b/pages/login.go @@ -0,0 +1,123 @@ +package pages + +import ( + "context" + "database/sql" + "errors" + "gosten/csrf" + "gosten/form" + "gosten/session" + "log" + "net/http" + "net/url" + + "golang.org/x/crypto/bcrypt" +) + +type userContextKey struct{} + +const ( + sessionDuration = 86400 * 7 // 7 days + loginQueryMarker = "next" +) + +func RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := session.From(r) + + if !s.IsNew() && s.Authenticated { + u, err := Q.GetUserById(r.Context(), s.UserID) + if err == nil { + // authenticated --> done + ctx := context.WithValue(r.Context(), userContextKey{}, u.ID) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + + s.Invalidate() + s.Save(w, r) + } + + // redirect to login with next-param + v := url.Values{} + v.Set(loginQueryMarker, r.URL.Path) + redirPath := "/login?" + v.Encode() + http.Redirect(w, r, redirPath, http.StatusFound) + }) +} + +type User struct { + Name string `form:"options=required,autofocus"` + Password string `form:"type=password;options=required"` + RememberMe bool `form:"type=checkbox;value=y;options=checked"` + Errors []error `form:"-"` + csrf.Csrf +} + +func showLoginPage(w http.ResponseWriter, u User) { + showTemplate(w, "login", u) +} + +func Login() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if session.From(r).Authenticated { + http.Redirect(w, r, "/", http.StatusFound) + } + u := User{} + u.SetCsrfField(r) + showLoginPage(w, u) + } +} + +func userId(r *http.Request) int32 { + return r.Context().Value(userContextKey{}).(int32) +} + +func checkLogin(ctx context.Context, user User) (bool, int32) { + dbUser, err := Q.GetUserByName(ctx, user.Name) + if err == nil { + hash := []byte(dbUser.Pwd) + pwd := []byte(user.Password) + + if bcrypt.CompareHashAndPassword(hash, pwd) != nil { + return false, 0 + } + } else if errors.Is(err, sql.ErrNoRows) { + return false, 0 + } else { + log.Panicf("Could not load user '%s': %v", user.Name, err) + } + + return true, dbUser.ID +} + +func HandleLogin(w http.ResponseWriter, r *http.Request) { + u := User{} + form.Parse(r, &u) + + ok, userId := checkLogin(r.Context(), u) + + if !ok { + u.Errors = []error{form.FieldError{Field: "Password", Issue: "Invalid"}} + showLoginPage(w, u) + return + } + + s := session.From(r) + if u.RememberMe { + s.MaxAge(sessionDuration) // 1 week + } else { + s.MaxAge(0) + } + + s.UserID = userId + s.Authenticated = true + s.Save(w, r) + + // redirect + next := r.URL.Query().Get(loginQueryMarker) + if next == "" { + next = "/" + } + http.Redirect(w, r, next, http.StatusFound) +} diff --git a/pages/logout.go b/pages/logout.go new file mode 100644 index 0000000..dad0e1a --- /dev/null +++ b/pages/logout.go @@ -0,0 +1,15 @@ +package pages + +import ( + "gosten/session" + "net/http" +) + +func Logout() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + s := session.From(r) + s.Invalidate() + s.Save(w, r) + http.Redirect(w, r, "/", http.StatusFound) + } +} diff --git a/pages/page.go b/pages/page.go new file mode 100644 index 0000000..25c2331 --- /dev/null +++ b/pages/page.go @@ -0,0 +1,58 @@ +package pages + +import ( + "context" + "gosten/model" + "gosten/templ" + "log" + "net/http" + + "github.com/go-chi/chi/v5" +) + +var Q *model.Queries + +func Connect(tx model.DBTX) { + Q = model.New(tx) +} + +type Page interface { + http.Handler +} + +type dataFunc func(r *http.Request, uid int32) any + +type simplePage struct { + dataFn dataFunc + template string +} + +func (p simplePage) ServeHTTP(w http.ResponseWriter, r *http.Request) { + input := p.dataFn(r, userId(r)) + p.showTemplate(w, input) +} + +func simpleByQuery[T any](tpl string, query func(ctx context.Context, id int32) (T, error)) Page { + dataFn := func(r *http.Request, uid int32) any { + d, _ := query(r.Context(), uid) + return d + } + return simple(tpl, dataFn) +} + +func simple(tpl string, dataFn dataFunc) Page { + p := simplePage{dataFn, tpl} + r := chi.NewRouter() + r.Get("/", p.ServeHTTP) + return r +} + +func showTemplate(w http.ResponseWriter, tpl string, data any) { + if err := templ.Lookup(tpl).Execute(w, data); err != nil { + log.Panicf("Executing '%s' with %+v: %v", tpl, data, err) + } +} + +func (p simplePage) showTemplate(w http.ResponseWriter, data any) { + showTemplate(w, p.template, data) +} diff --git a/pages/pages.go b/pages/pages.go new file mode 100644 index 0000000..e965bdd --- /dev/null +++ b/pages/pages.go @@ -0,0 +1,26 @@ +package pages + +import ( + "net/http" +) + +func Init() Page { + return simple("index", func(r *http.Request, uid int32) any { + u, _ := Q.GetUserById(r.Context(), uid) + return u.Name + }) +} + +func Recur() Page { + return simpleByQuery("recur", Q.GetRecurExpenses) +} + +func Categories() Page { + return simpleByQuery("categories", Q.GetCategoriesOrdered) +} + +func NotFound() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + showTemplate(w, "404", r.RequestURI) + } +} diff --git a/session.go b/session.go deleted file mode 100644 index 680f648..0000000 --- a/session.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "context" - "encoding/gob" - "log" - "net/http" - "os" - - "github.com/gorilla/securecookie" - "github.com/gorilla/sessions" -) - -type sessionContextKey struct{} - -const ( - sessionCookie = "sessionKeks" - dataKey = "data" -) - -func init() { - gob.Register(SessionData{}) -} - -type Session struct { - *SessionData - s *sessions.Session -} - -type SessionData struct { - UserID int32 - Authenticated bool -} - -func (s *Session) Save(w http.ResponseWriter, r *http.Request) { - s.s.Values[dataKey] = *s.SessionData - if err := s.s.Save(r, w); err != nil { - log.Panic("Storing session: ", err) - } -} - -func (s *Session) MaxAge(maxAge int) { - s.s.Options.MaxAge = maxAge -} - -func (s *Session) Invalidate() { - s.MaxAge(-1) - s.Authenticated = false -} - -func session(r *http.Request) Session { - s := r.Context().Value(sessionContextKey{}).(*sessions.Session) - s.Options.HttpOnly = true - - sd, ok := s.Values[dataKey].(SessionData) - if !ok { - sd = SessionData{} - } - return Session{&sd, s} -} - -func sessionHandler(next http.Handler) http.Handler { - var key []byte - - if envKey := os.Getenv("GOSTEN_SECRET"); len(envKey) >= 32 { - key = []byte(envKey) - } else { - key = securecookie.GenerateRandomKey(32) - } - - sessionStore := sessions.NewCookieStore(key) - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session, _ := sessionStore.Get(r, sessionCookie) - - ctx := context.WithValue(r.Context(), sessionContextKey{}, session) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} diff --git a/session/session.go b/session/session.go new file mode 100644 index 0000000..5ffd5cd --- /dev/null +++ b/session/session.go @@ -0,0 +1,87 @@ +package session + +import ( + "context" + "encoding/gob" + "log" + "net/http" + "os" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" +) + +type sessionContextKey struct{} + +const ( + sessionCookie = "sessionKeks" + dataKey = "data" +) + +func init() { + gob.Register(sessionData{}) +} + +type Session struct { + *sessionData + s *sessions.Session +} + +type sessionData struct { + UserID int32 + Authenticated bool +} + +func (s *Session) Save(w http.ResponseWriter, r *http.Request) { + s.s.Values[dataKey] = *s.sessionData + if err := s.s.Save(r, w); err != nil { + log.Panic("Storing session: ", err) + } +} + +func (s *Session) MaxAge(maxAge int) { + s.s.Options.MaxAge = maxAge +} + +func (s *Session) Invalidate() { + s.MaxAge(-1) + s.Authenticated = false +} + +func (s *Session) IsNew() bool { + return s.s.IsNew +} + +// From extracts the `Session` from the `Request`. +func From(r *http.Request) Session { + s := r.Context().Value(sessionContextKey{}).(*sessions.Session) + s.Options.HttpOnly = true + + sd, ok := s.Values[dataKey].(sessionData) + if !ok { + sd = sessionData{} + } + return Session{&sd, s} +} + +func Handler() func(next http.Handler) http.Handler { + var key []byte + + if envKey := os.Getenv("GOSTEN_SECRET"); len(envKey) >= 32 { + key = []byte(envKey) + } else { + key = securecookie.GenerateRandomKey(32) + } + + sessionStore := sessions.NewCookieStore(key) + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + session, _ := sessionStore.Get(r, sessionCookie) + + ctx := context.WithValue(r.Context(), sessionContextKey{}, session) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} -- cgit v1.2.3-70-g09d2