summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--csrf/csrf.go (renamed from csrf.go)8
-rw-r--r--go.mod8
-rw-r--r--go.sum18
-rw-r--r--main.go87
-rw-r--r--pages/login.go (renamed from auth.go)81
-rw-r--r--pages/logout.go15
-rw-r--r--pages/page.go58
-rw-r--r--pages/pages.go26
-rw-r--r--session/session.go (renamed from session.go)36
9 files changed, 197 insertions, 140 deletions
diff --git a/csrf.go b/csrf/csrf.go
index 4539825..18fdb81 100644
--- a/csrf.go
+++ b/csrf/csrf.go
@@ -1,4 +1,4 @@
-package main
+package csrf
import (
"html/template"
@@ -8,12 +8,12 @@ import (
"github.com/gorilla/securecookie"
)
-func csrfHandler(next http.Handler) http.Handler {
+func Handler() func(http.Handler) http.Handler {
return csrf.Protect(
securecookie.GenerateRandomKey(32),
csrf.SameSite(csrf.SameSiteStrictMode),
csrf.FieldName("csrf.csrffield"), // should match the structure in `Csrf`
- )(next)
+ )
}
// Csrf handles the CSRF data for a form.
@@ -26,6 +26,6 @@ func (c *Csrf) SetCsrfField(r *http.Request) {
c.CsrfField = csrf.TemplateField(r)
}
-type WithCsrf interface {
+type Enabled interface {
SetCsrfField(r *http.Request)
}
diff --git a/go.mod b/go.mod
index c63ac1e..f1acb5c 100644
--- a/go.mod
+++ b/go.mod
@@ -5,22 +5,20 @@ go 1.23
toolchain go1.23.1
require (
- github.com/Necoro/form v0.0.0-20240211223301-6fa9f8196e1e
+ github.com/go-chi/chi/v5 v5.1.0
github.com/gorilla/csrf v1.7.2
- github.com/gorilla/handlers v1.5.2
github.com/gorilla/schema v1.4.1
github.com/gorilla/securecookie v1.1.2
github.com/gorilla/sessions v1.4.0
github.com/jackc/pgx/v5 v5.7.1
github.com/joho/godotenv v1.5.1
- golang.org/x/crypto v0.27.0
+ golang.org/x/crypto v0.28.0
)
require (
- github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.8.0 // indirect
- golang.org/x/text v0.18.0 // indirect
+ golang.org/x/text v0.19.0 // indirect
)
diff --git a/go.sum b/go.sum
index 4564b09..82440f9 100644
--- a/go.sum
+++ b/go.sum
@@ -1,18 +1,12 @@
-github.com/Necoro/form v0.0.0-20240211223301-6fa9f8196e1e h1:v3DDTGBMt9pclCdG7jRyNAABmtJw3uky/Xoi/DfbWNs=
-github.com/Necoro/form v0.0.0-20240211223301-6fa9f8196e1e/go.mod h1:JxpmgZ5hjL6fyhBoZ4HAUadkp7DNqWlHbFL7l8oic4Y=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
-github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
-github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
-github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
+github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI=
github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
-github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
-github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
@@ -36,12 +30,12 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
-golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
+golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
+golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
-golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
+golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/main.go b/main.go
index 05ee948..34a6719 100644
--- a/main.go
+++ b/main.go
@@ -6,16 +6,16 @@ import (
"net/http"
"os"
- "github.com/gorilla/handlers"
+ "github.com/go-chi/chi/v5"
+ "github.com/go-chi/chi/v5/middleware"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/joho/godotenv"
- "gosten/model"
- "gosten/templ"
+ "gosten/csrf"
+ "gosten/pages"
+ "gosten/session"
)
-var Q *model.Queries
-
func checkEnvEntry(e string) {
if os.Getenv(e) == "" {
log.Fatalf("Variable '%s' not set", e)
@@ -43,66 +43,31 @@ func main() {
}
defer db.Close()
- Q = model.New(db)
-
- mux := http.NewServeMux()
-
- // handlers that DO NOT require authentification
- mux.Handle("GET /login", loginPage())
- mux.HandleFunc("POST /login", handleLogin)
- mux.Handle("GET /logout", handleLogout())
- mux.Handle("/static/", http.StripPrefix("/static", http.FileServer(http.Dir("static"))))
- mux.Handle("/favicon.ico", http.NotFoundHandler())
-
- handler := sessionHandler(csrfHandler(mux))
- handler = handlers.CombinedLoggingHandler(os.Stderr, handler)
- handler = handlers.ProxyHeaders(handler)
+ pages.Connect(db)
- // setup authentification
- authMux := http.NewServeMux()
- mux.Handle("/", RequireAuth(authMux))
+ router := chi.NewRouter()
- // handlers that required authentification
- authMux.Handle("GET /{$}", indexPage())
- authMux.Handle("GET /recur/{$}", recurPage())
- authMux.Handle("GET /categories/{$}", categoriesPage())
- authMux.Handle("GET /", notfound())
-
- log.Fatal(http.ListenAndServe(os.Getenv("GOSTEN_ADDRESS"), handler))
-}
-
-func showTemplate(w http.ResponseWriter, tpl string, data any) {
- if err := templ.Lookup(tpl).Execute(w, data); err != nil {
- log.Panicf("Executing '%s' with %+v: %v", tpl, data, err)
- }
-}
+ // A good base middleware stack
+ router.Use(middleware.RequestID)
+ router.Use(middleware.RealIP)
+ router.Use(middleware.CleanPath)
+ router.Use(middleware.Logger)
+ router.Use(middleware.Recoverer)
-func indexPage() http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- uid := userId(r)
- u, _ := Q.GetUserById(r.Context(), uid)
- showTemplate(w, "index", u.Name)
- }
-}
+ // handlers that DO NOT require authentification
+ router.Handle("/static/*", http.StripPrefix("/static", http.FileServer(http.Dir("static"))))
+ router.Get("/favicon.ico", http.NotFound)
-func recurPage() http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- uid := userId(r)
- exps, _ := Q.GetRecurExpenses(r.Context(), uid)
- showTemplate(w, "recur", exps)
- }
-}
+ appRouter := router.With(csrf.Handler(), session.Handler())
+ appRouter.Get("/login", pages.Login())
+ appRouter.Post("/login", pages.HandleLogin)
+ appRouter.Get("/logout", pages.Logout())
-func categoriesPage() http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- uid := userId(r)
- cats, _ := Q.GetCategoriesOrdered(r.Context(), uid)
- showTemplate(w, "categories", cats)
- }
-}
+ authRouter := appRouter.With(pages.RequireAuth)
+ authRouter.Mount("/", pages.Init())
+ authRouter.Mount("/recur", pages.Recur())
+ authRouter.Mount("/categories", pages.Categories())
+ authRouter.NotFound(pages.NotFound())
-func notfound() http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- showTemplate(w, "404", r.RequestURI)
- }
+ log.Fatal(http.ListenAndServe(os.Getenv("GOSTEN_ADDRESS"), router))
}
diff --git a/auth.go b/pages/login.go
index 7e23cd6..fb7859a 100644
--- a/auth.go
+++ b/pages/login.go
@@ -1,9 +1,12 @@
-package main
+package pages
import (
"context"
"database/sql"
"errors"
+ "gosten/csrf"
+ "gosten/form"
+ "gosten/session"
"log"
"net/http"
"net/url"
@@ -20,9 +23,9 @@ const (
func RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- s := session(r)
+ s := session.From(r)
- if !s.s.IsNew && s.Authenticated {
+ if !s.IsNew() && s.Authenticated {
u, err := Q.GetUserById(r.Context(), s.UserID)
if err == nil {
// authenticated --> done
@@ -43,6 +46,33 @@ func RequireAuth(next http.Handler) http.Handler {
})
}
+type User struct {
+ Name string `form:"options=required,autofocus"`
+ Password string `form:"type=password;options=required"`
+ RememberMe bool `form:"type=checkbox;value=y;options=checked"`
+ Errors []error `form:"-"`
+ csrf.Csrf
+}
+
+func showLoginPage(w http.ResponseWriter, u User) {
+ showTemplate(w, "login", u)
+}
+
+func Login() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if session.From(r).Authenticated {
+ http.Redirect(w, r, "/", http.StatusFound)
+ }
+ u := User{}
+ u.SetCsrfField(r)
+ showLoginPage(w, u)
+ }
+}
+
+func userId(r *http.Request) int32 {
+ return r.Context().Value(userContextKey{}).(int32)
+}
+
func checkLogin(ctx context.Context, user User) (bool, int32) {
dbUser, err := Q.GetUserByName(ctx, user.Name)
if err == nil {
@@ -61,19 +91,19 @@ func checkLogin(ctx context.Context, user User) (bool, int32) {
return true, dbUser.ID
}
-func handleLogin(w http.ResponseWriter, r *http.Request) {
+func HandleLogin(w http.ResponseWriter, r *http.Request) {
u := User{}
- parseForm(r, &u)
+ form.Parse(r, &u)
ok, userId := checkLogin(r.Context(), u)
if !ok {
- u.Errors = []error{fieldError{"Password", "Invalid"}}
+ u.Errors = []error{form.FieldError{Field: "Password", Issue: "Invalid"}}
showLoginPage(w, u)
return
}
- s := session(r)
+ s := session.From(r)
if u.RememberMe {
s.MaxAge(sessionDuration) // 1 week
} else {
@@ -91,40 +121,3 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
}
http.Redirect(w, r, next, http.StatusFound)
}
-
-func handleLogout() http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- s := session(r)
- s.Invalidate()
- s.Save(w, r)
-
- http.Redirect(w, r, "/", http.StatusFound)
- }
-}
-
-type User struct {
- Name string `form:"options=required,autofocus"`
- Password string `form:"type=password;options=required"`
- RememberMe bool `form:"type=checkbox;value=y;options=checked"`
- Errors []error `form:"-"`
- Csrf
-}
-
-func showLoginPage(w http.ResponseWriter, u User) {
- showTemplate(w, "login", u)
-}
-
-func loginPage() http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- if session(r).Authenticated {
- http.Redirect(w, r, "/", http.StatusFound)
- }
- u := User{}
- u.SetCsrfField(r)
- showLoginPage(w, u)
- }
-}
-
-func userId(r *http.Request) int32 {
- return r.Context().Value(userContextKey{}).(int32)
-}
diff --git a/pages/logout.go b/pages/logout.go
new file mode 100644
index 0000000..dad0e1a
--- /dev/null
+++ b/pages/logout.go
@@ -0,0 +1,15 @@
+package pages
+
+import (
+ "gosten/session"
+ "net/http"
+)
+
+func Logout() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ s := session.From(r)
+ s.Invalidate()
+ s.Save(w, r)
+ http.Redirect(w, r, "/", http.StatusFound)
+ }
+}
diff --git a/pages/page.go b/pages/page.go
new file mode 100644
index 0000000..25c2331
--- /dev/null
+++ b/pages/page.go
@@ -0,0 +1,58 @@
+package pages
+
+import (
+ "context"
+ "gosten/model"
+ "gosten/templ"
+ "log"
+ "net/http"
+
+ "github.com/go-chi/chi/v5"
+)
+
+var Q *model.Queries
+
+func Connect(tx model.DBTX) {
+ Q = model.New(tx)
+}
+
+type Page interface {
+ http.Handler
+}
+
+type dataFunc func(r *http.Request, uid int32) any
+
+type simplePage struct {
+ dataFn dataFunc
+ template string
+}
+
+func (p simplePage) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ input := p.dataFn(r, userId(r))
+ p.showTemplate(w, input)
+}
+
+func simpleByQuery[T any](tpl string, query func(ctx context.Context, id int32) (T, error)) Page {
+ dataFn := func(r *http.Request, uid int32) any {
+ d, _ := query(r.Context(), uid)
+ return d
+ }
+ return simple(tpl, dataFn)
+}
+
+func simple(tpl string, dataFn dataFunc) Page {
+ p := simplePage{dataFn, tpl}
+ r := chi.NewRouter()
+ r.Get("/", p.ServeHTTP)
+ return r
+}
+
+func showTemplate(w http.ResponseWriter, tpl string, data any) {
+ if err := templ.Lookup(tpl).Execute(w, data); err != nil {
+ log.Panicf("Executing '%s' with %+v: %v", tpl, data, err)
+ }
+}
+
+func (p simplePage) showTemplate(w http.ResponseWriter, data any) {
+ showTemplate(w, p.template, data)
+}
diff --git a/pages/pages.go b/pages/pages.go
new file mode 100644
index 0000000..e965bdd
--- /dev/null
+++ b/pages/pages.go
@@ -0,0 +1,26 @@
+package pages
+
+import (
+ "net/http"
+)
+
+func Init() Page {
+ return simple("index", func(r *http.Request, uid int32) any {
+ u, _ := Q.GetUserById(r.Context(), uid)
+ return u.Name
+ })
+}
+
+func Recur() Page {
+ return simpleByQuery("recur", Q.GetRecurExpenses)
+}
+
+func Categories() Page {
+ return simpleByQuery("categories", Q.GetCategoriesOrdered)
+}
+
+func NotFound() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ showTemplate(w, "404", r.RequestURI)
+ }
+}
diff --git a/session.go b/session/session.go
index 680f648..5ffd5cd 100644
--- a/session.go
+++ b/session/session.go
@@ -1,4 +1,4 @@
-package main
+package session
import (
"context"
@@ -19,21 +19,21 @@ const (
)
func init() {
- gob.Register(SessionData{})
+ gob.Register(sessionData{})
}
type Session struct {
- *SessionData
+ *sessionData
s *sessions.Session
}
-type SessionData struct {
+type sessionData struct {
UserID int32
Authenticated bool
}
func (s *Session) Save(w http.ResponseWriter, r *http.Request) {
- s.s.Values[dataKey] = *s.SessionData
+ s.s.Values[dataKey] = *s.sessionData
if err := s.s.Save(r, w); err != nil {
log.Panic("Storing session: ", err)
}
@@ -48,18 +48,23 @@ func (s *Session) Invalidate() {
s.Authenticated = false
}
-func session(r *http.Request) Session {
+func (s *Session) IsNew() bool {
+ return s.s.IsNew
+}
+
+// From extracts the `Session` from the `Request`.
+func From(r *http.Request) Session {
s := r.Context().Value(sessionContextKey{}).(*sessions.Session)
s.Options.HttpOnly = true
- sd, ok := s.Values[dataKey].(SessionData)
+ sd, ok := s.Values[dataKey].(sessionData)
if !ok {
- sd = SessionData{}
+ sd = sessionData{}
}
return Session{&sd, s}
}
-func sessionHandler(next http.Handler) http.Handler {
+func Handler() func(next http.Handler) http.Handler {
var key []byte
if envKey := os.Getenv("GOSTEN_SECRET"); len(envKey) >= 32 {
@@ -70,10 +75,13 @@ func sessionHandler(next http.Handler) http.Handler {
sessionStore := sessions.NewCookieStore(key)
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- session, _ := sessionStore.Get(r, sessionCookie)
+ return func(next http.Handler) http.Handler {
+ fn := func(w http.ResponseWriter, r *http.Request) {
+ session, _ := sessionStore.Get(r, sessionCookie)
- ctx := context.WithValue(r.Context(), sessionContextKey{}, session)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
+ ctx := context.WithValue(r.Context(), sessionContextKey{}, session)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ }
+ return http.HandlerFunc(fn)
+ }
}