nidus-sync/auth/auth.go

242 lines
7.8 KiB
Go
Raw Normal View History

package auth
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/Gleipnir-Technology/bob/dialect/psql"
"github.com/Gleipnir-Technology/bob/dialect/psql/sm"
"github.com/Gleipnir-Technology/nidus-sync/db"
"github.com/Gleipnir-Technology/nidus-sync/db/enums"
"github.com/Gleipnir-Technology/nidus-sync/db/models"
"github.com/Gleipnir-Technology/nidus-sync/db/sql"
"github.com/Gleipnir-Technology/nidus-sync/debug"
2025-11-06 22:31:51 +00:00
"github.com/aarondl/opt/omit"
2025-11-24 19:49:19 +00:00
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
// NoCredentialsError is returned by GetAuthenticatedUser when the request has
// neither a usable session nor HTTP Basic credentials.
type NoCredentialsError struct{}

func (NoCredentialsError) Error() string {
	return "No credentials were present in the request"
}

// NoUserError is returned by findUser when no user row has the requested ID.
type NoUserError struct{}

func (NoUserError) Error() string {
	return "That user does not exist"
}

// InvalidCredentials is returned by validateUser when the supplied password
// does not match the stored hash.
type InvalidCredentials struct{}

func (InvalidCredentials) Error() string {
	return "No username with that password exists"
}

// InvalidUsername is returned by validateUser when no user matches the
// supplied username.
type InvalidUsername struct{}

func (InvalidUsername) Error() string {
	return "That username doesn't exist"
}
type (
	// AuthenticatedHandler is an HTTP handler that additionally receives the
	// user that EnsureAuth resolved for the request.
	AuthenticatedHandler func(http.ResponseWriter, *http.Request, *models.User)

	// EnsureAuth is an http.Handler that authenticates each request before
	// delegating to an AuthenticatedHandler. Build one with NewEnsureAuth.
	EnsureAuth struct {
		handler AuthenticatedHandler // called only once authentication succeeds
	}
)
// AddUserSession records user in the session attached to r: the ID is stored
// under "user_id" as a decimal string and the username under "username".
// GetAuthenticatedUser reads both keys back on later requests.
//
// NOTE(review): if sessionManager is an scs.SessionManager, renewing the
// session token here (RenewToken) on login would guard against session
// fixation — confirm the session manager's type before adding it.
func AddUserSession(r *http.Request, user *models.User) {
	ctx := r.Context()
	sessionManager.Put(ctx, "user_id", strconv.FormatInt(int64(user.ID), 10))
	sessionManager.Put(ctx, "username", user.Username)
}
// GetAuthenticatedUser resolves the user making request r.
//
// The session takes priority. When it carries a non-empty "user_id" that
// parses to a positive int, and a non-empty "username", the user is loaded
// with findUser and that result is returned as-is. Otherwise the request's
// HTTP Basic credentials are checked with validateUser. On success they are
// recorded in the session via AddUserSession.
//
// Errors:
//   - a wrapped strconv error if the session's "user_id" is not an integer;
//   - *NoCredentialsError if there is no usable session and no Basic
//     Authorization header;
//   - otherwise whatever findUser or validateUser returned.
func GetAuthenticatedUser(r *http.Request) (*models.User, error) {
	ctx := r.Context()

	if idText := sessionManager.GetString(ctx, "user_id"); idText != "" {
		id, convErr := strconv.Atoi(idText)
		if convErr != nil {
			return nil, fmt.Errorf("Failed to convert user_id to int: %w", convErr)
		}
		sessionName := sessionManager.GetString(ctx, "username")
		if id > 0 && sessionName != "" {
			return findUser(ctx, id)
		}
	}

	// No usable session: fall back to the Authorization header.
	username, password, hasBasic := r.BasicAuth()
	if !hasBasic {
		return nil, &NoCredentialsError{}
	}
	user, err := validateUser(ctx, username, password)
	if err != nil {
		return nil, err
	}
	AddUserSession(r, user)
	return user, nil
}
// NewEnsureAuth wraps handlerToWrap so that it runs only for requests that
// pass authentication in (*EnsureAuth).ServeHTTP.
func NewEnsureAuth(handlerToWrap AuthenticatedHandler) *EnsureAuth {
	wrapper := &EnsureAuth{handler: handlerToWrap}
	return wrapper
}
func (ea *EnsureAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// If this is an API request respond with a more machine-readable error state
accept := r.Header.Values("Accept")
offers := []string{"application/json", "text/html"}
content_type := NegotiateContent(accept, offers)
user, err := GetAuthenticatedUser(r)
if err != nil || user == nil {
var msg []byte
// Separate return codes for different authentication failures
if _, ok := err.(*NoCredentialsError); ok {
fmt.Println("No credentials present and no session")
w.Header().Set("WWW-Authenticate-Error", "no-credentials")
msg = []byte("Please provide credentials.\n")
} else if _, ok := err.(*NoUserError); ok {
w.Header().Set("WWW-Authenticate-Error", "invalid-credentials")
msg = []byte("Invalid credentials provided.\n")
} else if _, ok := err.(*InvalidCredentials); ok {
w.Header().Set("WWW-Authenticate-Error", "invalid-credentials")
msg = []byte("Invalid credentials provided.\n")
}
if content_type == "text/html" {
http.Redirect(w, r, "/signin?next="+r.URL.Path, http.StatusSeeOther)
return
}
w.Header().Set("WWW-Authenticate", `Basic realm="Nidus Sync"`)
w.WriteHeader(401)
w.Write(msg)
return
}
ea.handler(w, r, user)
}
// SigninUser checks username and password with validateUser. On success it
// records the user in the session attached to r and returns it.
//
// Errors from validateUser are passed through unchanged. A nil user
// without an error is reported as "No matching user".
func SigninUser(r *http.Request, username string, password string) (*models.User, error) {
	user, err := validateUser(r.Context(), username, password)
	switch {
	case err != nil:
		return nil, err
	case user == nil:
		return nil, errors.New("No matching user")
	}
	AddUserSession(r, user)
	return user, nil
}
// SignoutUser blanks the "user_id" and "username" session keys written by
// AddUserSession, then logs the end of the session. user must be non-nil;
// it is used only for the log entry.
func SignoutUser(r *http.Request, user *models.User) {
	ctx := r.Context()
	for _, key := range []string{"user_id", "username"} {
		sessionManager.Put(ctx, key, "")
	}
	log.Info().Str("username", user.Username).Int32("user_id", user.ID).Msg("Ended user session")
}
// SignupUser registers a new account. It creates an organization named
// "<username>'s organization", then a user in that organization with role
// enums.UserroleAccountOwner, display name `name`, and a bcrypt hash of
// password (HashPassword).
//
// The username is first looked up with sql.UserByUsername, the same query
// validateUser uses at sign-in, and signup is refused if it already matches a
// user. Doing this before any insert has two benefits:
//   - A duplicate username no longer leaves an orphaned organization behind.
//     The organization row is written before the user row.
//   - Signup cannot create a second matching row. validateUser rejects an
//     ambiguous match as an error.
//
// NOTE(review): the two inserts are still not wrapped in a transaction, so any
// other failure of the user insert leaves the new organization in place.
func SignupUser(ctx context.Context, username string, name string, password string) (*models.User, error) {
	existing, err := sql.UserByUsername(username).All(ctx, db.PGInstance.BobDB)
	if err != nil {
		return nil, fmt.Errorf("Cannot signup user, failed to check for existing username: %w", err)
	}
	if len(existing) > 0 {
		return nil, fmt.Errorf("Cannot signup user, username %q is already taken", username)
	}

	// Hash only after the cheap existence check; bcrypt at cost 14 is
	// deliberately expensive.
	passwordHash, err := HashPassword(password)
	if err != nil {
		return nil, fmt.Errorf("Cannot signup user, failed to create hashed password: %w", err)
	}

	o_setter := models.OrganizationSetter{
		Name: omit.From(fmt.Sprintf("%s's organization", username)),
	}
	o, err := models.Organizations.Insert(&o_setter).One(ctx, db.PGInstance.BobDB)
	if err != nil {
		return nil, fmt.Errorf("Failed to create organization: %w", err)
	}
	log.Info().Int32("id", o.ID).Msg("Created organization")

	u_setter := models.UserSetter{
		DisplayName:      omit.From(name),
		OrganizationID:   omit.From(o.ID),
		PasswordHash:     omit.From(passwordHash),
		PasswordHashType: omit.From(enums.HashtypeBcrypt14),
		Role:             omit.From(enums.UserroleAccountOwner),
		Username:         omit.From(username),
	}
	u, err := models.Users.Insert(&u_setter).One(ctx, db.PGInstance.BobDB)
	if err != nil {
		return nil, fmt.Errorf("Failed to create user: %w", err)
	}
	log.Info().Int32("id", u.ID).Str("username", u.Username).Msg("Created user")
	return u, nil
}
// findUser loads the user with the given ID, preloading its Organization.
//
// The "not found" outcomes ("No such user" / "sql: no rows in result set",
// matched by error text) are translated into *NoUserError so callers can
// branch on the type. Any other error is logged, with type details via
// debug.LogErrorTypeInfo, and returned unchanged.
//
// NOTE(review): matching on error text is brittle; errors.Is against
// database/sql's ErrNoRows would be sturdier. This package's `sql` identifier
// is the project's db/sql, so the stdlib package would need an alias.
func findUser(ctx context.Context, userID int) (*models.User, error) {
	user, err := models.Users.Query(
		models.Preload.User.Organization(),
		sm.Where(models.Users.Columns.ID.EQ(psql.Arg(userID))),
	).One(ctx, db.PGInstance.BobDB)
	if err == nil {
		return user, nil
	}

	switch err.Error() {
	case "No such user", "sql: no rows in result set":
		return nil, &NoUserError{}
	default:
		debug.LogErrorTypeInfo(err)
		log.Error().Err(err).Msg("Unrecognized error. This should be updated in the findUser code")
		return nil, err
	}
}
// HashPassword returns the bcrypt hash of password at cost 14. That is the
// cost SignupUser records as enums.HashtypeBcrypt14 alongside the hash. On
// failure it returns an empty string and bcrypt's error.
func HashPassword(password string) (string, error) {
	const bcryptCost = 14
	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
	if err != nil {
		return "", err
	}
	return string(hashed), nil
}
// redact masks s for logging. The first two and last two characters are
// kept and everything between is replaced by '*', one per character.
//
// Strings of four characters or fewer are masked entirely; keeping 2+2 of
// them would reveal the whole value. Characters are counted as runes, so
// multi-byte UTF-8 input is never split mid-character.
func redact(s string) string {
	runes := []rune(s)
	n := len(runes)
	if n <= 4 {
		return strings.Repeat("*", n)
	}
	return string(runes[:2]) + strings.Repeat("*", n-4) + string(runes[n-2:])
}
// validatePassword reports whether password matches the bcrypt hash. Any
// comparison error, whether a mismatch or a malformed hash, counts as no match.
func validatePassword(password, hash string) bool {
	if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
		return false
	}
	return true
}
// validateUser looks up username with sql.UserByUsername and checks password
// against the row's stored bcrypt hash.
//
// Returns:
//   - InvalidUsername{} (value) when no row matches username;
//   - InvalidCredentials{} (value) when the password does not match;
//   - a wrapped error if the query fails, or an error if it matches more
//     than one row;
//   - otherwise a *models.User built from the matched row.
//
// Nothing derived from the submitted password is logged: not the plaintext,
// not a partial redaction, and not a hash. The previous version hashed every
// attempted password at bcrypt cost 14 just to log it. That doubled the CPU
// cost of each sign-in, and a logged hash of a wrong (often mistyped)
// password can be attacked offline.
//
// NOTE(review): the unknown-username path returns without running bcrypt, so
// response timing may reveal whether a username exists.
// NOTE(review): the returned user only carries the fields copied below (e.g.
// Role is not set) and, unlike findUser, has no Organization preloaded;
// confirm callers of GetAuthenticatedUser cope with both shapes.
func validateUser(ctx context.Context, username string, password string) (*models.User, error) {
	result, err := sql.UserByUsername(username).All(ctx, db.PGInstance.BobDB)
	if err != nil {
		return nil, fmt.Errorf("Failed to query for user: %w", err)
	}
	switch len(result) {
	case 0:
		log.Info().Str("username", username).Msg("Invalid username")
		return nil, InvalidUsername{}
	case 1:
		row := result[0]
		if !validatePassword(password, row.PasswordHash) {
			log.Info().Str("username", username).Msg("Invalid password for user")
			return nil, InvalidCredentials{}
		}
		user := models.User{
			ID:                        row.ID,
			ArcgisAccessToken:         row.ArcgisAccessToken,
			ArcgisLicense:             row.ArcgisLicense,
			ArcgisRefreshToken:        row.ArcgisRefreshToken,
			ArcgisRefreshTokenExpires: row.ArcgisRefreshTokenExpires,
			ArcgisRole:                row.ArcgisRole,
			DisplayName:               row.DisplayName,
			Email:                     row.Email,
			OrganizationID:            row.OrganizationID,
			Username:                  row.Username,
		}
		log.Info().Str("username", username).Msg("Validated user")
		return &user, nil
	default:
		return nil, errors.New("More than one matching row, this should be impossible.")
	}
}