You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

332 lines
7.7 KiB

//
// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package main
import (
"context"
"crypto/rand"
"crypto/sha512"
"crypto/subtle"
"database/sql"
"fmt"
"hash"
"io"
"log"
"net/http"
"reflect"
"regexp"
"strings"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
)
// keytype is the unexported type used as the request-context key under which
// LoginChecker stores the authenticated *UserInfo. Because the type is
// unexported, no other package can construct a colliding key.
type keytype struct{}

// thekey is the single context key value of type keytype.
var thekey = keytype{}
// LoginChecker wraps handler so that, when the request carries an auth cookie
// that checkauthcookie accepts, the matching *UserInfo is attached to the
// request context under thekey (readable via GetUserInfo). Requests without a
// valid cookie are passed through unchanged; this middleware never rejects.
func LoginChecker(handler http.Handler) http.Handler {
	attach := func(w http.ResponseWriter, r *http.Request) {
		if userinfo, ok := checkauthcookie(r); ok {
			r = r.WithContext(context.WithValue(r.Context(), thekey, userinfo))
		}
		handler.ServeHTTP(w, r)
	}
	return http.HandlerFunc(attach)
}
// LoginRequired wraps handler so it only runs when GetUserInfo finds a user in
// the request context; otherwise the client is sent to the login page via
// loginredirect. It presumably must sit inside LoginChecker in the handler
// chain, since LoginChecker is what stores the user in the context.
func LoginRequired(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if GetUserInfo(r) != nil {
			handler.ServeHTTP(w, r)
			return
		}
		loginredirect(w, r)
	})
}
// GetUserInfo returns the *UserInfo stored in r's context under thekey, or nil
// when the request was not authenticated (nothing stored, or a value of a
// different type).
func GetUserInfo(r *http.Request) *UserInfo {
	if userinfo, ok := r.Context().Value(thekey).(*UserInfo); ok {
		return userinfo
	}
	return nil
}
// calculateCSRF builds a CSRF token: salt followed by a truncated hex
// SHA-512/256 over auth, csrfkey, salt and action. Each part is framed by NUL
// bytes (one leading NUL, then one after every part) so that different
// splits of the same concatenated bytes hash differently.
func calculateCSRF(salt, action, auth string) string {
	sep := []byte{0}
	digest := sha512.New512_256()
	digest.Write(sep)
	for _, part := range []string{auth, csrfkey, salt, action} {
		digest.Write([]byte(part))
		digest.Write(sep)
	}
	return salt + hexsum(digest)
}
// GetCSRF returns a fresh CSRF token that binds action to the caller's auth
// cookie, for CheckCSRF to verify later. The token is a random salt (hexsum
// of 32 bytes from crypto/rand) followed by calculateCSRF's keyed hash.
//
// It returns "" when the request has no valid auth cookie, or when reading
// randomness fails; CheckCSRF rejects an empty token, so both cases fail
// closed.
func GetCSRF(action string, r *http.Request) string {
	auth := getauthcookie(r)
	if auth == "" {
		return ""
	}
	hasher := sha512.New512_256()
	// Do not mint a token from a short or failed read of crypto/rand: the
	// salt would no longer be random.
	if _, err := io.CopyN(hasher, rand.Reader, 32); err != nil {
		log.Printf("login: error reading random bytes: %s", err)
		return ""
	}
	salt := hexsum(hasher)
	return calculateCSRF(salt, action, auth)
}
// CheckCSRF reports whether the request's "CSRF" form value is a valid token
// for action and the request's auth cookie. A valid token is exactly
// 2*authlen characters: an authlen-character salt, then the hash that
// calculateCSRF derives from that salt. The comparison is constant-time.
func CheckCSRF(action string, r *http.Request) bool {
	auth := getauthcookie(r)
	if auth == "" {
		return false
	}
	token := r.FormValue("CSRF")
	if len(token) != 2*authlen {
		return false
	}
	expected := calculateCSRF(token[:authlen], action, auth)
	return subtle.ConstantTimeCompare([]byte(expected), []byte(token)) == 1
}
// CSRFWrap wraps handler so it only runs when CheckCSRF accepts the request
// for action; otherwise it responds 403 "invalid csrf".
func CSRFWrap(action string, handler http.Handler) http.Handler {
	guarded := func(w http.ResponseWriter, r *http.Request) {
		if CheckCSRF(action, r) {
			handler.ServeHTTP(w, r)
			return
		}
		http.Error(w, "invalid csrf", http.StatusForbidden)
	}
	return http.HandlerFunc(guarded)
}
// loginredirect expires the "auth" cookie (empty value, MaxAge -1) and sends a
// 303 See Other redirect to /login.
func loginredirect(w http.ResponseWriter, r *http.Request) {
	expired := &http.Cookie{
		Name:     "auth",
		Value:    "",
		MaxAge:   -1,
		Secure:   securecookies,
		HttpOnly: true,
	}
	http.SetCookie(w, expired)
	http.Redirect(w, r, "/login", http.StatusSeeOther)
}
var (
	// authregex and authlen define a well-formed auth cookie value: exactly
	// authlen ASCII alphanumeric characters. authlen is also the length of
	// every hexsum result.
	authregex = regexp.MustCompile(`^[[:alnum:]]+$`)
	authlen   = 32

	// Prepared statements, set up by LoginInit.
	stmtUserName, stmtUserAuth, stmtSaveAuth, stmtDeleteAuth *sql.Stmt

	// csrfkey is mixed into every CSRF hash; LoginInit loads it from config.
	csrfkey string
	// securecookies controls the Secure flag on auth cookies; LoginInit sets
	// it to the negation of the "debug" config setting.
	securecookies bool
)
// LoginInit prepares the SQL statements used by the login code and loads the
// "debug" and "csrfkey" settings through getconfig. A statement that fails to
// prepare is fatal (log.Fatal). Auth cookies are marked Secure unless debug
// is enabled.
func LoginInit(db *sql.DB) {
	mustprepare := func(query string) *sql.Stmt {
		stmt, err := db.Prepare(query)
		if err != nil {
			log.Fatal(err)
		}
		return stmt
	}

	stmtUserName = mustprepare("select userid, hash from users where username = ?")

	// Select one column per UserInfo field, lowercased and in declaration
	// order, so checkauthcookie can scan the row positionally into the struct.
	infotype := reflect.TypeOf(UserInfo{})
	columns := make([]string, infotype.NumField())
	for i := range columns {
		columns[i] = strings.ToLower(infotype.Field(i).Name)
	}
	stmtUserAuth = mustprepare(fmt.Sprintf("select %s from users where userid = (select userid from auth where hash = ?)", strings.Join(columns, ", ")))
	stmtSaveAuth = mustprepare("insert into auth (userid, hash) values (?, ?)")
	stmtDeleteAuth = mustprepare("delete from auth where userid = ?")

	var debug bool
	getconfig("debug", &debug)
	securecookies = !debug
	getconfig("csrfkey", &csrfkey)
}
// authinprogress holds usernames whose login attempt was admitted within the
// last half second; authprogressmtx guards every access to it.
var authinprogress = make(map[string]bool)
var authprogressmtx sync.Mutex

// rateandwait admits at most one login attempt per username per half second.
// It returns true and marks username busy if no attempt is currently marked,
// or false if one is. The mark is removed by a timer after 500ms; the entry is
// deleted rather than set to false, so the map does not keep one stale entry
// per username ever seen.
func rateandwait(username string) bool {
	authprogressmtx.Lock()
	defer authprogressmtx.Unlock()
	if authinprogress[username] {
		return false
	}
	authinprogress[username] = true
	// AfterFunc avoids parking a sleeping goroutine for every attempt.
	time.AfterFunc(500*time.Millisecond, func() {
		authprogressmtx.Lock()
		delete(authinprogress, username)
		authprogressmtx.Unlock()
	})
	return true
}
// getauthcookie returns the value of the request's "auth" cookie when it is
// exactly authlen alphanumeric characters. It returns "" when the cookie is
// missing or malformed, logging the malformed case.
func getauthcookie(r *http.Request) string {
	cookie, err := r.Cookie("auth")
	if err != nil {
		return ""
	}
	value := cookie.Value
	if len(value) != authlen || !authregex.MatchString(value) {
		log.Printf("login: bad auth: %s", value)
		return ""
	}
	return value
}
// checkauthcookie resolves the request's auth cookie to a user. The database
// stores only hexsum(SHA-512/256(token)), so the cookie value is hashed the
// same way before the lookup through stmtUserAuth. Returns the populated
// *UserInfo and true on success, or nil and false (after logging) when the
// cookie is absent or malformed, no matching auth row exists, or the scan
// fails.
func checkauthcookie(r *http.Request) (*UserInfo, bool) {
	auth := getauthcookie(r)
	if auth == "" {
		return nil, false
	}
	digest := sha512.New512_256()
	digest.Write([]byte(auth))
	row := stmtUserAuth.QueryRow(hexsum(digest))

	// Scan positionally into every field of UserInfo; LoginInit builds the
	// stmtUserAuth column list from the same fields in the same order.
	var userinfo UserInfo
	elem := reflect.ValueOf(&userinfo).Elem()
	dests := make([]interface{}, elem.NumField())
	for i := range dests {
		dests[i] = elem.Field(i).Addr().Interface()
	}
	switch err := row.Scan(dests...); {
	case err == nil:
		return &userinfo, true
	case err == sql.ErrNoRows:
		log.Printf("login: no auth found")
	default:
		log.Printf("login: error scanning auth row: %s", err)
	}
	return nil, false
}
// loaduser looks up username via stmtUserName and returns its userid, the
// stored password hash, and true. On a missing user or a query error it logs
// and returns (-1, "", false).
func loaduser(username string) (int64, string, bool) {
	var (
		userid int64
		pwhash string
	)
	err := stmtUserName.QueryRow(username).Scan(&userid, &pwhash)
	switch {
	case err == sql.ErrNoRows:
		log.Printf("login: no username found")
	case err != nil:
		log.Printf("login: error loading username: %s", err)
	default:
		return userid, pwhash, true
	}
	return -1, "", false
}
// Input limits enforced by dologin: usernames are 1..userlen ASCII
// alphanumeric characters, passwords are 1..passlen bytes.
var (
	userregex = regexp.MustCompile(`^[[:alnum:]]+$`)
	userlen   = 32
	passlen   = 128
)
// hexsum returns the first authlen characters of the lowercase hex encoding
// of h's current digest. It does not reset h.
func hexsum(h hash.Hash) string {
	digest := h.Sum(nil)
	encoded := fmt.Sprintf("%x", digest)
	return encoded[:authlen]
}
// dologin handles a submitted login form.
//
// The "username" and "password" form values are checked against userlen,
// userregex and passlen before any database access. The user is loaded with
// loaduser and rate-limited with rateandwait, and the password is verified
// with bcrypt.
//
// On success a new random token is generated. Its hash is stored through
// stmtSaveAuth, the raw token is set as the "auth" cookie, and the client is
// redirected to "/". Every failure path goes through loginredirect, which
// expires the auth cookie and redirects to /login.
func dologin(w http.ResponseWriter, r *http.Request) {
	username := r.FormValue("username")
	password := r.FormValue("password")
	if len(username) == 0 || len(username) > userlen ||
		!userregex.MatchString(username) || len(password) == 0 ||
		len(password) > passlen {
		log.Printf("login: invalid password attempt")
		loginredirect(w, r)
		return
	}
	userid, pwhash, ok := loaduser(username)
	if !ok {
		loginredirect(w, r)
		return
	}
	// rateandwait rejects a second attempt for the same username while the
	// previous one is still within its rate window.
	if !rateandwait(username) {
		loginredirect(w, r)
		return
	}
	err := bcrypt.CompareHashAndPassword([]byte(pwhash), []byte(password))
	if err != nil {
		log.Printf("login: incorrect password")
		loginredirect(w, r)
		return
	}

	// Derive the session token from 32 bytes of crypto/rand. A failed or short
	// read must abort the login: hashing partial (or empty) input would give a
	// predictable token.
	hasher := sha512.New512_256()
	if _, err = io.CopyN(hasher, rand.Reader, 32); err != nil {
		log.Printf("login: error reading random bytes: %s", err)
		loginredirect(w, r)
		return
	}
	auth := hexsum(hasher)

	// Only the hash of the token is stored; the raw token lives in the cookie.
	hasher.Reset()
	hasher.Write([]byte(auth))
	authhash := hexsum(hasher)

	// Persist before issuing the cookie, so a database failure does not hand
	// the client a token that checkauthcookie can never resolve.
	if _, err = stmtSaveAuth.Exec(userid, authhash); err != nil {
		log.Printf("error saving auth: %s", err)
		loginredirect(w, r)
		return
	}
	http.SetCookie(w, &http.Cookie{
		Name:     "auth",
		Value:    auth,
		MaxAge:   3600 * 24 * 30,
		Secure:   securecookies,
		HttpOnly: true,
	})
	log.Printf("login: successful login")
	http.Redirect(w, r, "/", http.StatusSeeOther)
}
// dologout logs the user out when the request has a valid auth cookie and a
// valid "logout" CSRF token. It deletes every stored auth row for the user
// (stmtDeleteAuth deletes by userid) and expires the auth cookie. In all
// cases, including a rejected request or a database error, the client is
// redirected to "/"; on a database error the cookie is left in place.
func dologout(w http.ResponseWriter, r *http.Request) {
	userinfo, ok := checkauthcookie(r)
	if !ok || !CheckCSRF("logout", r) {
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	if _, err := stmtDeleteAuth.Exec(userinfo.UserID); err != nil {
		log.Printf("login: error deleting old auth: %s", err)
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}
	expired := &http.Cookie{
		Name:     "auth",
		Value:    "",
		MaxAge:   -1,
		Secure:   securecookies,
		HttpOnly: true,
	}
	http.SetCookie(w, expired)
	http.Redirect(w, r, "/", http.StatusSeeOther)
}