parent
d728004cad
commit
723efee364
@ -1,331 +0,0 @@
|
|||||||
//
|
|
||||||
// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
|
|
||||||
//
|
|
||||||
// Permission to use, copy, modify, and distribute this software for any
|
|
||||||
// purpose with or without fee is hereby granted, provided that the above
|
|
||||||
// copyright notice and this permission notice appear in all copies.
|
|
||||||
//
|
|
||||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
|
||||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
|
||||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
|
||||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
||||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
|
||||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
|
||||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha512"
|
|
||||||
"crypto/subtle"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"reflect"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type keytype struct{}
|
|
||||||
|
|
||||||
var thekey keytype
|
|
||||||
|
|
||||||
func LoginChecker(handler http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userinfo, ok := checkauthcookie(r)
|
|
||||||
if ok {
|
|
||||||
ctx := context.WithValue(r.Context(), thekey, userinfo)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
}
|
|
||||||
handler.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoginRequired(handler http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ok := GetUserInfo(r) != nil
|
|
||||||
if !ok {
|
|
||||||
loginredirect(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
handler.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserInfo(r *http.Request) *UserInfo {
|
|
||||||
userinfo, ok := r.Context().Value(thekey).(*UserInfo)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return userinfo
|
|
||||||
}
|
|
||||||
|
|
||||||
func calculateCSRF(salt, action, auth string) string {
|
|
||||||
hasher := sha512.New512_256()
|
|
||||||
zero := []byte{0}
|
|
||||||
hasher.Write(zero)
|
|
||||||
hasher.Write([]byte(auth))
|
|
||||||
hasher.Write(zero)
|
|
||||||
hasher.Write([]byte(csrfkey))
|
|
||||||
hasher.Write(zero)
|
|
||||||
hasher.Write([]byte(salt))
|
|
||||||
hasher.Write(zero)
|
|
||||||
hasher.Write([]byte(action))
|
|
||||||
hasher.Write(zero)
|
|
||||||
hash := hexsum(hasher)
|
|
||||||
|
|
||||||
return salt + hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetCSRF(action string, r *http.Request) string {
|
|
||||||
auth := getauthcookie(r)
|
|
||||||
if auth == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
hasher := sha512.New512_256()
|
|
||||||
io.CopyN(hasher, rand.Reader, 32)
|
|
||||||
salt := hexsum(hasher)
|
|
||||||
|
|
||||||
return calculateCSRF(salt, action, auth)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CheckCSRF(action string, r *http.Request) bool {
|
|
||||||
auth := getauthcookie(r)
|
|
||||||
if auth == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
csrf := r.FormValue("CSRF")
|
|
||||||
if len(csrf) != authlen*2 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
salt := csrf[0:authlen]
|
|
||||||
rv := calculateCSRF(salt, action, auth)
|
|
||||||
ok := subtle.ConstantTimeCompare([]byte(rv), []byte(csrf)) == 1
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func CSRFWrap(action string, handler http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ok := CheckCSRF(action, r)
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "invalid csrf", 403)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
handler.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func loginredirect(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "auth",
|
|
||||||
Value: "",
|
|
||||||
MaxAge: -1,
|
|
||||||
Secure: securecookies,
|
|
||||||
HttpOnly: true,
|
|
||||||
})
|
|
||||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
||||||
}
|
|
||||||
|
|
||||||
var authregex = regexp.MustCompile("^[[:alnum:]]+$")
|
|
||||||
var authlen = 32
|
|
||||||
|
|
||||||
var stmtUserName, stmtUserAuth, stmtSaveAuth, stmtDeleteAuth *sql.Stmt
|
|
||||||
var csrfkey string
|
|
||||||
var securecookies bool
|
|
||||||
|
|
||||||
func LoginInit(db *sql.DB) {
|
|
||||||
var err error
|
|
||||||
stmtUserName, err = db.Prepare("select userid, hash from users where username = ?")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
var userinfo UserInfo
|
|
||||||
t := reflect.TypeOf(userinfo)
|
|
||||||
var fields []string
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
f := t.Field(i)
|
|
||||||
fields = append(fields, strings.ToLower(f.Name))
|
|
||||||
}
|
|
||||||
stmtUserAuth, err = db.Prepare(fmt.Sprintf("select %s from users where userid = (select userid from auth where hash = ?)", strings.Join(fields, ", ")))
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
stmtSaveAuth, err = db.Prepare("insert into auth (userid, hash) values (?, ?)")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
stmtDeleteAuth, err = db.Prepare("delete from auth where userid = ?")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
debug := false
|
|
||||||
getconfig("debug", &debug)
|
|
||||||
securecookies = !debug
|
|
||||||
getconfig("csrfkey", &csrfkey)
|
|
||||||
}
|
|
||||||
|
|
||||||
var authinprogress = make(map[string]bool)
|
|
||||||
var authprogressmtx sync.Mutex
|
|
||||||
|
|
||||||
func rateandwait(username string) bool {
|
|
||||||
authprogressmtx.Lock()
|
|
||||||
defer authprogressmtx.Unlock()
|
|
||||||
if authinprogress[username] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
authinprogress[username] = true
|
|
||||||
go func(name string) {
|
|
||||||
time.Sleep(1 * time.Second / 2)
|
|
||||||
authprogressmtx.Lock()
|
|
||||||
authinprogress[name] = false
|
|
||||||
authprogressmtx.Unlock()
|
|
||||||
}(username)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func getauthcookie(r *http.Request) string {
|
|
||||||
cookie, err := r.Cookie("auth")
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
auth := cookie.Value
|
|
||||||
if !(len(auth) == authlen && authregex.MatchString(auth)) {
|
|
||||||
log.Printf("login: bad auth: %s", auth)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return auth
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkauthcookie(r *http.Request) (*UserInfo, bool) {
|
|
||||||
auth := getauthcookie(r)
|
|
||||||
if auth == "" {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
hasher := sha512.New512_256()
|
|
||||||
hasher.Write([]byte(auth))
|
|
||||||
authhash := hexsum(hasher)
|
|
||||||
row := stmtUserAuth.QueryRow(authhash)
|
|
||||||
var userinfo UserInfo
|
|
||||||
v := reflect.ValueOf(&userinfo).Elem()
|
|
||||||
var ptrs []interface{}
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
|
||||||
f := v.Field(i)
|
|
||||||
ptrs = append(ptrs, f.Addr().Interface())
|
|
||||||
}
|
|
||||||
err := row.Scan(ptrs...)
|
|
||||||
if err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
log.Printf("login: no auth found")
|
|
||||||
} else {
|
|
||||||
log.Printf("login: error scanning auth row: %s", err)
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
return &userinfo, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func loaduser(username string) (int64, string, bool) {
|
|
||||||
row := stmtUserName.QueryRow(username)
|
|
||||||
var userid int64
|
|
||||||
var hash string
|
|
||||||
err := row.Scan(&userid, &hash)
|
|
||||||
if err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
log.Printf("login: no username found")
|
|
||||||
} else {
|
|
||||||
log.Printf("login: error loading username: %s", err)
|
|
||||||
}
|
|
||||||
return -1, "", false
|
|
||||||
}
|
|
||||||
return userid, hash, true
|
|
||||||
}
|
|
||||||
|
|
||||||
var userregex = regexp.MustCompile("^[[:alnum:]]+$")
|
|
||||||
var userlen = 32
|
|
||||||
var passlen = 128
|
|
||||||
|
|
||||||
func hexsum(h hash.Hash) string {
|
|
||||||
return fmt.Sprintf("%x", h.Sum(nil))[0:authlen]
|
|
||||||
}
|
|
||||||
|
|
||||||
func dologin(w http.ResponseWriter, r *http.Request) {
|
|
||||||
username := r.FormValue("username")
|
|
||||||
password := r.FormValue("password")
|
|
||||||
|
|
||||||
if len(username) == 0 || len(username) > userlen ||
|
|
||||||
!userregex.MatchString(username) || len(password) == 0 ||
|
|
||||||
len(password) > passlen {
|
|
||||||
log.Printf("login: invalid password attempt")
|
|
||||||
loginredirect(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
userid, hash, ok := loaduser(username)
|
|
||||||
if !ok {
|
|
||||||
loginredirect(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !rateandwait(username) {
|
|
||||||
loginredirect(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("login: incorrect password")
|
|
||||||
loginredirect(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hasher := sha512.New512_256()
|
|
||||||
io.CopyN(hasher, rand.Reader, 32)
|
|
||||||
hash = hexsum(hasher)
|
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "auth",
|
|
||||||
Value: hash,
|
|
||||||
MaxAge: 3600 * 24 * 30,
|
|
||||||
Secure: securecookies,
|
|
||||||
HttpOnly: true,
|
|
||||||
})
|
|
||||||
|
|
||||||
hasher.Reset()
|
|
||||||
hasher.Write([]byte(hash))
|
|
||||||
authhash := hexsum(hasher)
|
|
||||||
|
|
||||||
_, err = stmtSaveAuth.Exec(userid, authhash)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("error saving auth: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("login: successful login")
|
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
||||||
}
|
|
||||||
|
|
||||||
func dologout(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userinfo, ok := checkauthcookie(r)
|
|
||||||
if ok && CheckCSRF("logout", r) {
|
|
||||||
_, err := stmtDeleteAuth.Exec(userinfo.UserID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("login: error deleting old auth: %s", err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: "auth",
|
|
||||||
Value: "",
|
|
||||||
MaxAge: -1,
|
|
||||||
Secure: securecookies,
|
|
||||||
HttpOnly: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
||||||
}
|
|
Loading…
Reference in new issue