summaryrefslogtreecommitdiff
path: root/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'auth.go')
-rw-r--r--auth.go116
1 files changed, 116 insertions, 0 deletions
diff --git a/auth.go b/auth.go
new file mode 100644
index 0000000..601f28a
--- /dev/null
+++ b/auth.go
@@ -0,0 +1,116 @@
+package main
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "log"
+ "net/http"
+ "net/url"
+
+ "golang.org/x/crypto/bcrypt"
+)
+
+const (
+ userContextKey = "_user"
+ sessionDuration = 86400 * 7 // 7 days
+ loginQueryMarker = "next"
+)
+
+func RequireAuth(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ s := session(r)
+
+ if !s.s.IsNew && s.Authenticated {
+ u, err := Q.GetUserById(r.Context(), s.UserID)
+ if err == nil {
+ // authenticated --> done
+ ctx := context.WithValue(r.Context(), userContextKey, u.ID)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ return
+ }
+
+ s.Invalidate()
+ s.Save(w, r)
+ }
+
+ // redirect to login with next-param
+ v := url.Values{}
+ v.Set(loginQueryMarker, r.URL.Path)
+ redirPath := "/login?" + v.Encode()
+ http.Redirect(w, r, redirPath, http.StatusFound)
+ })
+}
+
+func checkLogin(ctx context.Context, user User) (bool, int64) {
+ dbUser, err := Q.GetUserByName(ctx, user.Name)
+ if err == nil {
+ hash := []byte(dbUser.Pwd)
+ pwd := []byte(user.Password)
+
+ if bcrypt.CompareHashAndPassword(hash, pwd) != nil {
+ return false, 0
+ }
+ } else if errors.Is(err, sql.ErrNoRows) {
+ return false, 0
+ } else {
+ log.Panicf("Could not load user '%s': %v", user.Name, err)
+ }
+
+ return true, dbUser.ID
+}
+
+func handleLogin(w http.ResponseWriter, r *http.Request) {
+ u := User{}
+ parseForm(r, &u)
+
+ ok, userId := checkLogin(r.Context(), u)
+
+ if !ok {
+ u.Errors = []error{fieldError{"Password", "Invalid"}}
+ showLoginPage(w, u)
+ return
+ }
+
+ s := session(r)
+ if u.RememberMe {
+ s.MaxAge(sessionDuration) // 1 week
+ } else {
+ s.MaxAge(0)
+ }
+
+ s.UserID = userId
+ s.Authenticated = true
+ s.Save(w, r)
+
+ // redirect
+ next := r.URL.Query().Get(loginQueryMarker)
+ if next == "" {
+ next = "/"
+ }
+ http.Redirect(w, r, next, http.StatusFound)
+}
+
+func handleLogout() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ s := session(r)
+ s.Invalidate()
+ s.Save(w, r)
+
+ http.Redirect(w, r, "/", http.StatusFound)
+ }
+}
+
+func showLoginPage(w http.ResponseWriter, u User) {
+ showTemplate(w, "login", u)
+}
+
+func loginPage() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ showLoginPage(w, User{})
+ }
+}
+
+func userId(r *http.Request) int64 {
+ return r.Context().Value(userContextKey).(int64)
+}