package pages

import (
	"context"
	"database/sql"
	"errors"
	"gosten/csrf"
	"gosten/form"
	"gosten/model"
	"gosten/session"
	"log"
	"net/http"
	"net/url"

	"github.com/go-chi/chi/v5"
	"golang.org/x/crypto/bcrypt"
)

// userContextKey is the private context key under which RequireAuth stores
// the authenticated model.User for the current request.
type userContextKey struct{}

// getUser returns the user stored in ctx by RequireAuth.
//
// It panics if ctx carries no model.User (unchecked type assertion), so it
// must only be called from handlers wrapped by RequireAuth.
func getUser(ctx context.Context) model.User {
	return ctx.Value(userContextKey{}).(model.User)
}

const (
	// sessionDuration is the session lifetime used for "remember me" logins:
	// 7 days, presumably in seconds (cookie Max-Age) — confirm in session.
	sessionDuration = 86400 * 7
	// loginQueryMarker is the /login query parameter carrying the URL to
	// return to after a successful login.
	loginQueryMarker = "next"
)

// setUserInContext loads the user with id uid and returns a child of ctx
// carrying it (retrievable via getUser). On lookup failure it returns ctx
// unchanged together with the lookup error.
func setUserInContext(ctx context.Context, uid int32) (context.Context, error) {
	u, err := Q.GetUserById(ctx, uid)
	if err != nil {
		return ctx, err
	}
	u.Pwd = "" // don't carry the password hash around in request contexts
	return context.WithValue(ctx, userContextKey{}, u), nil
}

// RequireAuth is middleware that only lets authenticated requests through.
//
// If the session is not new and is marked as authenticated, the session's
// user is loaded into the request context (see getUser) and next is called.
// Otherwise — or if that user can no longer be loaded, in which case the
// session is invalidated and saved — the client is redirected to /login
// with the originally requested URL in the "next" query parameter.
func RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		s := session.From(r)
		if !s.IsNew() && s.Authenticated {
			ctx, err := setUserInContext(r.Context(), s.UserID)
			if err == nil {
				// authenticated --> done
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}
			// sql.ErrNoRows presumably means the user no longer exists; any
			// other error (e.g. a database outage) would otherwise silently
			// log the user out, so make it visible.
			if !errors.Is(err, sql.ErrNoRows) {
				log.Printf("loading user %d for session: %v", s.UserID, err)
			}
			s.Invalidate()
			s.Save(w, r)
		}

		// Redirect to login, remembering the full request target (path AND
		// query) so handleLogin can send the user back to exactly this URL.
		v := url.Values{}
		v.Set(loginQueryMarker, r.URL.RequestURI())
		http.Redirect(w, r, "/login?"+v.Encode(), http.StatusFound)
	})
}

// user is the login form model. The `form` struct tags are consumed by the
// gosten/form package (parsing and rendering); their exact semantics are
// defined there. FormErrors carries errors to display next to the form and
// CsrfField the anti-CSRF token.
type user struct {
	Name       string `form:"options=required,autofocus"`
	Password   string `form:"type=password;options=required"`
	RememberMe bool   `form:"type=checkbox;value=y;options=checked"`
	form.FormErrors
	csrf.CsrfField
}

// Login returns the login page router: GET renders an empty login form
// (already-authenticated sessions are redirected to "/"), POST is handled
// by handleLogin.
func Login() Page {
	r := chi.NewRouter()
	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		if session.From(r).Authenticated {
			http.Redirect(w, r, "/", http.StatusFound)
			// Without this return the login page was rendered into the
			// same response after the redirect had been written.
			return
		}
		u := user{}
		u.SetCsrfField(r)
		showLoginPage(r, w, u)
	})
	r.Post("/", handleLogin)
	return r
}

// validatePwd reports whether pwd matches the bcrypt hash.
func validatePwd(hash, pwd string) bool {
	return bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd)) == nil
}

// checkLogin verifies the submitted credentials. It returns (true, id) for a
// valid name/password pair and (false, 0) for an unknown user name or a
// wrong password. Any other error while loading the user panics, aborting
// the request.
//
// NOTE(review): an unknown user skips the bcrypt comparison and therefore
// answers noticeably faster than a wrong password, which can reveal which
// user names exist. Consider comparing against a dummy hash in that case.
func checkLogin(ctx context.Context, user user) (bool, int32) {
	dbUser, err := Q.GetUserByName(ctx, user.Name)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return false, 0
	case err != nil:
		// %q: the name is untrusted input and must not be able to forge
		// extra log lines.
		log.Panicf("Could not load user %q: %v", user.Name, err)
	}
	if !validatePwd(dbUser.Pwd, user.Password) {
		return false, 0
	}
	return true, dbUser.ID
}

// safeRedirectTarget returns next if it is a local absolute path on this
// site (optionally with a query), and "/" otherwise. This keeps the login
// form from being abused as an open redirect, e.g. with
// ?next=https://evil.example, ?next=//evil.example or ?next=/\evil.example
// (browsers treat the latter two as protocol-relative URLs).
func safeRedirectTarget(next string) string {
	if next == "" || next[0] != '/' {
		return "/"
	}
	if len(next) > 1 && (next[1] == '/' || next[1] == '\\') {
		return "/"
	}
	// url.Parse also rejects ASCII control characters, which browsers may
	// strip and thereby turn "/\t/evil.example" into "//evil.example".
	u, err := url.Parse(next)
	if err != nil || u.Scheme != "" || u.Host != "" {
		return "/"
	}
	return next
}

// handleLogin processes a submitted login form.
//
// On bad credentials it re-renders the form with an "Invalid" error on the
// Password field. On success it sets the session's max age
// (sessionDuration when "remember me" is checked, 0 otherwise — presumably
// a browser-session cookie), marks the session as authenticated for the
// user, saves it, and redirects to the "next" query parameter when that is
// a safe local target, or to "/" otherwise.
//
// NOTE(review): no session ID rotation is visible here; confirm that the
// session package protects against session fixation on login.
func handleLogin(w http.ResponseWriter, r *http.Request) {
	u := user{}
	form.Parse(r, &u)

	ok, userID := checkLogin(r.Context(), u)
	if !ok {
		u.Password = "" // never echo the submitted password back into the page
		u.Errors = []error{form.FieldError{Field: "Password", Issue: "Invalid"}}
		showLoginPage(r, w, u)
		return
	}

	s := session.From(r)
	if u.RememberMe {
		s.MaxAge(sessionDuration) // 1 week
	} else {
		s.MaxAge(0)
	}
	s.UserID = userID
	s.Authenticated = true
	s.Save(w, r)

	next := safeRedirectTarget(r.URL.Query().Get(loginQueryMarker))
	http.Redirect(w, r, next, http.StatusFound)
}

// showLoginPage renders the login form for u.
func showLoginPage(r *http.Request, w http.ResponseWriter, u user) {
	render(login(u))(w, r)
}