From 24c2071fcaa8065d450dae78a80a671697f0e873 Mon Sep 17 00:00:00 2001 From: René 'Necoro' Neumann Date: Wed, 14 Feb 2024 00:23:02 +0100 Subject: Restructure: Move auth and session to their own files Make auth handling nicer. --- auth.go | 116 +++++++++++++++++++++++++++++++++++++++++++++ main.go | 140 +++++++------------------------------------------------ session.go | 70 ++++++++++++++++++++++++++++ templ/index.tpl | 2 +- templ/login2.tpl | 4 -- 5 files changed, 204 insertions(+), 128 deletions(-) create mode 100644 auth.go create mode 100644 session.go delete mode 100644 templ/login2.tpl diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..601f28a --- /dev/null +++ b/auth.go @@ -0,0 +1,116 @@ +package main + +import ( + "context" + "database/sql" + "errors" + "log" + "net/http" + "net/url" + + "golang.org/x/crypto/bcrypt" +) + +const ( + userContextKey = "_user" + sessionDuration = 86400 * 7 // 7 days + loginQueryMarker = "next" +) + +func RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := session(r) + + if !s.s.IsNew && s.Authenticated { + u, err := Q.GetUserById(r.Context(), s.UserID) + if err == nil { + // authenticated --> done + ctx := context.WithValue(r.Context(), userContextKey, u.ID) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + + s.Invalidate() + s.Save(w, r) + } + + // redirect to login with next-param + v := url.Values{} + v.Set(loginQueryMarker, r.URL.Path) + redirPath := "/login?" + v.Encode() + http.Redirect(w, r, redirPath, http.StatusFound) + }) +} + +func checkLogin(ctx context.Context, user User) (bool, int64) { + dbUser, err := Q.GetUserByName(ctx, user.Name) + if err == nil { + hash := []byte(dbUser.Pwd) + pwd := []byte(user.Password) + + if bcrypt.CompareHashAndPassword(hash, pwd) != nil { + return false, 0 + } + } else if errors.Is(err, sql.ErrNoRows) { + return false, 0 + } else { + log.Panicf("Could not load user '%s': %v", user.Name, err) + } + + return true, dbUser.ID +} + +func handleLogin(w http.ResponseWriter, r *http.Request) { + u := User{} + parseForm(r, &u) + + ok, userId := checkLogin(r.Context(), u) + + if !ok { + u.Errors = []error{fieldError{"Password", "Invalid"}} + showLoginPage(w, u) + return + } + + s := session(r) + if u.RememberMe { + s.MaxAge(sessionDuration) // 1 week + } else { + s.MaxAge(0) + } + + s.UserID = userId + s.Authenticated = true + s.Save(w, r) + + // redirect + next := r.URL.Query().Get(loginQueryMarker) + if next == "" { + next = "/" + } + http.Redirect(w, r, next, http.StatusFound) +} + +func handleLogout() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + s := session(r) + s.Invalidate() + s.Save(w, r) + + http.Redirect(w, r, "/", http.StatusFound) + } +} + +func showLoginPage(w http.ResponseWriter, u User) { + showTemplate(w, "login", u) +} + +func loginPage() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + showLoginPage(w, User{}) + } +} + +func userId(r *http.Request) int64 { + return r.Context().Value(userContextKey).(int64) +} diff --git a/main.go b/main.go index ef6cef3..66fa34d 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,7 @@ package main import ( - "context" - "database/sql" "encoding/gob" - "errors" "flag" "fmt" "log" @@ -15,9 +12,6 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/schema" - "github.com/gorilla/securecookie" - "github.com/gorilla/sessions" - "golang.org/x/crypto/bcrypt" "gosten/model" "gosten/templ" @@ -38,9 +32,6 @@ func init() { var Q *model.Queries var s *schema.Decoder -var sessionStore sessions.Store - -const sessionCookie = "sessionKeks" func main() { flag.Parse() @@ -53,19 +44,24 @@ func main() { Q = model.New(db) s = schema.NewDecoder() - sessionStore = sessions.NewCookieStore(securecookie.GenerateRandomKey(32)) mux := http.NewServeMux() - mux.Handle("/{$}", showTemplate("index", nil)) - mux.HandleFunc("GET /login", loginPage) + mux.Handle("GET /login", loginPage()) mux.HandleFunc("POST /login", handleLogin) - mux.HandleFunc("GET /logout", handleLogout) + mux.Handle("GET /logout", handleLogout()) + mux.Handle("/favicon.ico", http.NotFoundHandler()) handler := sessionHandler(mux) handler = handlers.CombinedLoggingHandler(os.Stderr, handler) handler = handlers.ProxyHeaders(handler) + // the real content, needing authentification + authMux := http.NewServeMux() + mux.Handle("/", RequireAuth(authMux)) + + authMux.Handle("GET /{$}", indexPage()) + address := net.JoinHostPort(host, strconv.FormatUint(port, 10)) log.Fatal(http.ListenAndServe(address, handler)) } @@ -90,11 +86,9 @@ func (fe fieldError) FieldError() (field, err string) { return fe.Field, fe.Issue } -func showTemplate(tpl string, data any) http.HandlerFunc { - return func(w http.ResponseWriter, _ *http.Request) { - if err := templ.Lookup(tpl).Execute(w, data); err != nil { - log.Panicf("Executing '%s' with %+v: %v", tpl, data, err) - } +func showTemplate(w http.ResponseWriter, tpl string, data any) { + if err := templ.Lookup(tpl).Execute(w, data); err != nil { + log.Panicf("Executing '%s' with %+v: %v", tpl, data, err) } } @@ -107,110 +101,10 @@ func parseForm[T any](r *http.Request, data *T) { } } -func loginPage(w http.ResponseWriter, r *http.Request) { - s := session(r) - - if !s.s.IsNew && s.Authenticated { - u, err := Q.GetUserById(r.Context(), s.UserID) - if err != nil { - s.Authenticated = false - s.Save(w, r) - } else { - u2 := User{Name: u.Name} - showTemplate("login2", u2).ServeHTTP(w, r) - return - } - } - - showTemplate("login", User{}).ServeHTTP(w, r) -} - -func handleLogin(w http.ResponseWriter, r *http.Request) { - u := User{} - parseForm(r, &u) - - invalid := false - - dbUser, err := Q.GetUserByName(r.Context(), u.Name) - if err == nil { - hash := []byte(dbUser.Pwd) - pwd := []byte(u.Password) - - if bcrypt.CompareHashAndPassword(hash, pwd) != nil { - invalid = true - } - } else if errors.Is(err, sql.ErrNoRows) { - invalid = true - } else { - log.Panicf("Could not load user '%s': %v", u.Name, err) - } - - if invalid { - u.Errors = []error{fieldError{"Password", "Invalid"}} - showTemplate("login", u).ServeHTTP(w, r) - return - } - - s := session(r) - if u.RememberMe { - s.MaxAge(86400 * 7) // 1 week - } else { - s.MaxAge(0) - } - - s.UserID = dbUser.ID - s.Authenticated = true - s.Save(w, r) - - showTemplate("login2", u).ServeHTTP(w, r) -} - -func handleLogout(w http.ResponseWriter, r *http.Request) { - s := session(r) - s.Authenticated = false - s.MaxAge(-1) - s.Save(w, r) -} - -func sessionHandler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session, err := sessionStore.Get(r, sessionCookie) - if err != nil { - } - - ctx := context.WithValue(r.Context(), "_session", session) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -type Session struct { - *SessionData - s *sessions.Session -} - -type SessionData struct { - UserID int64 - Authenticated bool -} - -func (s *Session) Save(w http.ResponseWriter, r *http.Request) { - s.s.Values["data"] = *s.SessionData - if err := s.s.Save(r, w); err != nil { - log.Panic("Storing session: ", err) - } -} - -func (s *Session) MaxAge(maxAge int) { - s.s.Options.MaxAge = maxAge -} - -func session(r *http.Request) Session { - s := r.Context().Value("_session").(*sessions.Session) - s.Options.HttpOnly = true - - sd, ok := s.Values["data"].(SessionData) - if !ok { - sd = SessionData{} +func indexPage() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + uid := userId(r) + u, _ := Q.GetUserById(r.Context(), uid) + showTemplate(w, "index", u.Name) } - return Session{&sd, s} } diff --git a/session.go b/session.go new file mode 100644 index 0000000..f495cdd --- /dev/null +++ b/session.go @@ -0,0 +1,70 @@ +package main + +import ( + "context" + "encoding/gob" + "log" + "net/http" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" +) + +const ( + sessionCookie = "sessionKeks" + sessionContextKey = "_session" + dataKey = "data" +) + +var sessionStore sessions.Store + +func init() { + gob.Register(SessionData{}) + sessionStore = sessions.NewCookieStore(securecookie.GenerateRandomKey(32)) +} + +type Session struct { + *SessionData + s *sessions.Session +} + +type SessionData struct { + UserID int64 + Authenticated bool +} + +func (s *Session) Save(w http.ResponseWriter, r *http.Request) { + s.s.Values[dataKey] = *s.SessionData + if err := s.s.Save(r, w); err != nil { + log.Panic("Storing session: ", err) + } +} + +func (s *Session) MaxAge(maxAge int) { + s.s.Options.MaxAge = maxAge +} + +func (s *Session) Invalidate() { + s.MaxAge(-1) + s.Authenticated = false +} + +func session(r *http.Request) Session { + s := r.Context().Value(sessionContextKey).(*sessions.Session) + s.Options.HttpOnly = true + + sd, ok := s.Values[dataKey].(SessionData) + if !ok { + sd = SessionData{} + } + return Session{&sd, s} +} + +func sessionHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, _ := sessionStore.Get(r, sessionCookie) + + ctx := context.WithValue(r.Context(), sessionContextKey, session) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/templ/index.tpl b/templ/index.tpl index 2c52810..013c5ea 100644 --- a/templ/index.tpl +++ b/templ/index.tpl @@ -1,3 +1,3 @@ {{define "body"}} - Das ist die Basis + Logged in with user: {{.}} {{end}} \ No newline at end of file diff --git a/templ/login2.tpl b/templ/login2.tpl deleted file mode 100644 index 89ba6a5..0000000 --- a/templ/login2.tpl +++ /dev/null @@ -1,4 +0,0 @@ -{{define "body"}} - Logged in with user: {{.Name}}
- You have chosen: {{.Password}} -{{end}} \ No newline at end of file -- cgit v1.2.3-70-g09d2