summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRené 'Necoro' Neumann <necoro@necoro.eu>2024-02-14 00:23:02 +0100
committerRené 'Necoro' Neumann <necoro@necoro.eu>2024-02-14 00:23:02 +0100
commit24c2071fcaa8065d450dae78a80a671697f0e873 (patch)
tree7c301de897b0b51079090fdc10560fc52f4f97ed
parent4c98ab6a3a1f41ebaa5360a6a4615cd705a94db0 (diff)
downloadgosten-24c2071fcaa8065d450dae78a80a671697f0e873.tar.gz
gosten-24c2071fcaa8065d450dae78a80a671697f0e873.tar.bz2
gosten-24c2071fcaa8065d450dae78a80a671697f0e873.zip
Restructure: Move auth and session to their own files
Make auth handling nicer.
-rw-r--r--auth.go116
-rw-r--r--main.go140
-rw-r--r--session.go70
-rw-r--r--templ/index.tpl2
-rw-r--r--templ/login2.tpl4
5 files changed, 204 insertions, 128 deletions
diff --git a/auth.go b/auth.go
new file mode 100644
index 0000000..601f28a
--- /dev/null
+++ b/auth.go
@@ -0,0 +1,116 @@
+package main
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "log"
+ "net/http"
+ "net/url"
+
+ "golang.org/x/crypto/bcrypt"
+)
+
+const (
+ userContextKey = "_user"
+ sessionDuration = 86400 * 7 // 7 days
+ loginQueryMarker = "next"
+)
+
+func RequireAuth(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ s := session(r)
+
+ if !s.s.IsNew && s.Authenticated {
+ u, err := Q.GetUserById(r.Context(), s.UserID)
+ if err == nil {
+ // authenticated --> done
+ ctx := context.WithValue(r.Context(), userContextKey, u.ID)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ return
+ }
+
+ s.Invalidate()
+ s.Save(w, r)
+ }
+
+ // redirect to login with next-param
+ v := url.Values{}
+ v.Set(loginQueryMarker, r.URL.Path)
+ redirPath := "/login?" + v.Encode()
+ http.Redirect(w, r, redirPath, http.StatusFound)
+ })
+}
+
+func checkLogin(ctx context.Context, user User) (bool, int64) {
+ dbUser, err := Q.GetUserByName(ctx, user.Name)
+ if err == nil {
+ hash := []byte(dbUser.Pwd)
+ pwd := []byte(user.Password)
+
+ if bcrypt.CompareHashAndPassword(hash, pwd) != nil {
+ return false, 0
+ }
+ } else if errors.Is(err, sql.ErrNoRows) {
+ return false, 0
+ } else {
+ log.Panicf("Could not load user '%s': %v", user.Name, err)
+ }
+
+ return true, dbUser.ID
+}
+
+func handleLogin(w http.ResponseWriter, r *http.Request) {
+ u := User{}
+ parseForm(r, &u)
+
+ ok, userId := checkLogin(r.Context(), u)
+
+ if !ok {
+ u.Errors = []error{fieldError{"Password", "Invalid"}}
+ showLoginPage(w, u)
+ return
+ }
+
+ s := session(r)
+ if u.RememberMe {
+ s.MaxAge(sessionDuration) // 1 week
+ } else {
+ s.MaxAge(0)
+ }
+
+ s.UserID = userId
+ s.Authenticated = true
+ s.Save(w, r)
+
+ // redirect
+ next := r.URL.Query().Get(loginQueryMarker)
+ if next == "" {
+ next = "/"
+ }
+ http.Redirect(w, r, next, http.StatusFound)
+}
+
+func handleLogout() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ s := session(r)
+ s.Invalidate()
+ s.Save(w, r)
+
+ http.Redirect(w, r, "/", http.StatusFound)
+ }
+}
+
+func showLoginPage(w http.ResponseWriter, u User) {
+ showTemplate(w, "login", u)
+}
+
+func loginPage() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ showLoginPage(w, User{})
+ }
+}
+
+func userId(r *http.Request) int64 {
+ return r.Context().Value(userContextKey).(int64)
+}
diff --git a/main.go b/main.go
index ef6cef3..66fa34d 100644
--- a/main.go
+++ b/main.go
@@ -1,10 +1,7 @@
package main
import (
- "context"
- "database/sql"
"encoding/gob"
- "errors"
"flag"
"fmt"
"log"
@@ -15,9 +12,6 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/schema"
- "github.com/gorilla/securecookie"
- "github.com/gorilla/sessions"
- "golang.org/x/crypto/bcrypt"
"gosten/model"
"gosten/templ"
@@ -38,9 +32,6 @@ func init() {
var Q *model.Queries
var s *schema.Decoder
-var sessionStore sessions.Store
-
-const sessionCookie = "sessionKeks"
func main() {
flag.Parse()
@@ -53,19 +44,24 @@ func main() {
Q = model.New(db)
s = schema.NewDecoder()
- sessionStore = sessions.NewCookieStore(securecookie.GenerateRandomKey(32))
mux := http.NewServeMux()
- mux.Handle("/{$}", showTemplate("index", nil))
- mux.HandleFunc("GET /login", loginPage)
+ mux.Handle("GET /login", loginPage())
mux.HandleFunc("POST /login", handleLogin)
- mux.HandleFunc("GET /logout", handleLogout)
+ mux.Handle("GET /logout", handleLogout())
+ mux.Handle("/favicon.ico", http.NotFoundHandler())
handler := sessionHandler(mux)
handler = handlers.CombinedLoggingHandler(os.Stderr, handler)
handler = handlers.ProxyHeaders(handler)
+ // the real content, needing authentification
+ authMux := http.NewServeMux()
+ mux.Handle("/", RequireAuth(authMux))
+
+ authMux.Handle("GET /{$}", indexPage())
+
address := net.JoinHostPort(host, strconv.FormatUint(port, 10))
log.Fatal(http.ListenAndServe(address, handler))
}
@@ -90,11 +86,9 @@ func (fe fieldError) FieldError() (field, err string) {
return fe.Field, fe.Issue
}
-func showTemplate(tpl string, data any) http.HandlerFunc {
- return func(w http.ResponseWriter, _ *http.Request) {
- if err := templ.Lookup(tpl).Execute(w, data); err != nil {
- log.Panicf("Executing '%s' with %+v: %v", tpl, data, err)
- }
+func showTemplate(w http.ResponseWriter, tpl string, data any) {
+ if err := templ.Lookup(tpl).Execute(w, data); err != nil {
+ log.Panicf("Executing '%s' with %+v: %v", tpl, data, err)
}
}
@@ -107,110 +101,10 @@ func parseForm[T any](r *http.Request, data *T) {
}
}
-func loginPage(w http.ResponseWriter, r *http.Request) {
- s := session(r)
-
- if !s.s.IsNew && s.Authenticated {
- u, err := Q.GetUserById(r.Context(), s.UserID)
- if err != nil {
- s.Authenticated = false
- s.Save(w, r)
- } else {
- u2 := User{Name: u.Name}
- showTemplate("login2", u2).ServeHTTP(w, r)
- return
- }
- }
-
- showTemplate("login", User{}).ServeHTTP(w, r)
-}
-
-func handleLogin(w http.ResponseWriter, r *http.Request) {
- u := User{}
- parseForm(r, &u)
-
- invalid := false
-
- dbUser, err := Q.GetUserByName(r.Context(), u.Name)
- if err == nil {
- hash := []byte(dbUser.Pwd)
- pwd := []byte(u.Password)
-
- if bcrypt.CompareHashAndPassword(hash, pwd) != nil {
- invalid = true
- }
- } else if errors.Is(err, sql.ErrNoRows) {
- invalid = true
- } else {
- log.Panicf("Could not load user '%s': %v", u.Name, err)
- }
-
- if invalid {
- u.Errors = []error{fieldError{"Password", "Invalid"}}
- showTemplate("login", u).ServeHTTP(w, r)
- return
- }
-
- s := session(r)
- if u.RememberMe {
- s.MaxAge(86400 * 7) // 1 week
- } else {
- s.MaxAge(0)
- }
-
- s.UserID = dbUser.ID
- s.Authenticated = true
- s.Save(w, r)
-
- showTemplate("login2", u).ServeHTTP(w, r)
-}
-
-func handleLogout(w http.ResponseWriter, r *http.Request) {
- s := session(r)
- s.Authenticated = false
- s.MaxAge(-1)
- s.Save(w, r)
-}
-
-func sessionHandler(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- session, err := sessionStore.Get(r, sessionCookie)
- if err != nil {
- }
-
- ctx := context.WithValue(r.Context(), "_session", session)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
-}
-
-type Session struct {
- *SessionData
- s *sessions.Session
-}
-
-type SessionData struct {
- UserID int64
- Authenticated bool
-}
-
-func (s *Session) Save(w http.ResponseWriter, r *http.Request) {
- s.s.Values["data"] = *s.SessionData
- if err := s.s.Save(r, w); err != nil {
- log.Panic("Storing session: ", err)
- }
-}
-
-func (s *Session) MaxAge(maxAge int) {
- s.s.Options.MaxAge = maxAge
-}
-
-func session(r *http.Request) Session {
- s := r.Context().Value("_session").(*sessions.Session)
- s.Options.HttpOnly = true
-
- sd, ok := s.Values["data"].(SessionData)
- if !ok {
- sd = SessionData{}
+func indexPage() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ uid := userId(r)
+ u, _ := Q.GetUserById(r.Context(), uid)
+ showTemplate(w, "index", u.Name)
}
- return Session{&sd, s}
}
diff --git a/session.go b/session.go
new file mode 100644
index 0000000..f495cdd
--- /dev/null
+++ b/session.go
@@ -0,0 +1,70 @@
+package main
+
+import (
+ "context"
+ "encoding/gob"
+ "log"
+ "net/http"
+
+ "github.com/gorilla/securecookie"
+ "github.com/gorilla/sessions"
+)
+
+const (
+ sessionCookie = "sessionKeks"
+ sessionContextKey = "_session"
+ dataKey = "data"
+)
+
+var sessionStore sessions.Store
+
+func init() {
+ gob.Register(SessionData{})
+ sessionStore = sessions.NewCookieStore(securecookie.GenerateRandomKey(32))
+}
+
+type Session struct {
+ *SessionData
+ s *sessions.Session
+}
+
+type SessionData struct {
+ UserID int64
+ Authenticated bool
+}
+
+func (s *Session) Save(w http.ResponseWriter, r *http.Request) {
+ s.s.Values[dataKey] = *s.SessionData
+ if err := s.s.Save(r, w); err != nil {
+ log.Panic("Storing session: ", err)
+ }
+}
+
+func (s *Session) MaxAge(maxAge int) {
+ s.s.Options.MaxAge = maxAge
+}
+
+func (s *Session) Invalidate() {
+ s.MaxAge(-1)
+ s.Authenticated = false
+}
+
+func session(r *http.Request) Session {
+ s := r.Context().Value(sessionContextKey).(*sessions.Session)
+ s.Options.HttpOnly = true
+
+ sd, ok := s.Values[dataKey].(SessionData)
+ if !ok {
+ sd = SessionData{}
+ }
+ return Session{&sd, s}
+}
+
+func sessionHandler(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ session, _ := sessionStore.Get(r, sessionCookie)
+
+ ctx := context.WithValue(r.Context(), sessionContextKey, session)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
diff --git a/templ/index.tpl b/templ/index.tpl
index 2c52810..013c5ea 100644
--- a/templ/index.tpl
+++ b/templ/index.tpl
@@ -1,3 +1,3 @@
{{define "body"}}
- Das ist die Basis
+ Logged in with user: {{.}}
{{end}} \ No newline at end of file
diff --git a/templ/login2.tpl b/templ/login2.tpl
deleted file mode 100644
index 89ba6a5..0000000
--- a/templ/login2.tpl
+++ /dev/null
@@ -1,4 +0,0 @@
-{{define "body"}}
- Logged in with user: {{.Name}} <br>
- You have chosen: {{.Password}}
-{{end}} \ No newline at end of file