feat(security): add configurable SQL procedure names
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m9s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -24m29s
Build , Vet Test, and Lint / Build (push) Successful in -30m5s
Build , Vet Test, and Lint / Lint Code (push) Failing after -28m58s
Tests / Integration Tests (push) Failing after -30m26s
Tests / Unit Tests (push) Successful in -28m7s

* Introduce SQLNames struct to define stored procedure names.
* Update DatabaseAuthenticator, JWTAuthenticator, and other providers to use SQLNames for procedure calls.
* Remove hardcoded procedure names for better flexibility and customization.
* Implement validation for SQL names to ensure they are valid identifiers.
* Add tests for SQLNames functionality and merging behavior.
This commit is contained in:
Hein
2026-03-31 14:25:59 +02:00
parent aa362c77da
commit 568df8c6d6
8 changed files with 476 additions and 120 deletions

View File

@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
} }
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error { func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
// Add to blacklist // Invalidate session via stored procedure
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{ return nil
"token": req.Token,
"user_id": req.UserID,
}).Error
} }
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {

View File

@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
} }
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error { func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For JWT, logout could involve token blacklisting
// Add token to blacklist table
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
// "token": req.Token,
// "expires_at": time.Now().Add(24 * time.Hour),
// }).Error
return nil return nil
} }

View File

@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
var errMsg *string var errMsg *string
var userID *int var userID *int
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_user_id SELECT p_success, p_error, p_user_id
FROM resolvespec_oauth_getorcreateuser($1::jsonb) FROM %s($1::jsonb)
`, userJSON).Scan(&success, &errMsg, &userID) `, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get or create user: %w", err) return 0, fmt.Errorf("failed to get or create user: %w", err)
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
var success bool var success bool
var errMsg *string var errMsg *string
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM resolvespec_oauth_createsession($1::jsonb) FROM %s($1::jsonb)
`, sessionJSON).Scan(&success, &errMsg) `, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to create session: %w", err) return fmt.Errorf("failed to create session: %w", err)
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var errMsg *string var errMsg *string
var sessionData []byte var sessionData []byte
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getrefreshtoken($1) FROM %s($1)
`, refreshToken).Scan(&success, &errMsg, &sessionData) `, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get session by refresh token: %w", err) return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var updateSuccess bool var updateSuccess bool
var updateErrMsg *string var updateErrMsg *string
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb) FROM %s($1::jsonb)
`, updateJSON).Scan(&updateSuccess, &updateErrMsg) `, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update session: %w", err) return nil, fmt.Errorf("failed to update session: %w", err)
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var userErrMsg *string var userErrMsg *string
var userData []byte var userData []byte
err = a.db.QueryRowContext(ctx, ` err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getuser($1) FROM %s($1)
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData) `, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err) return nil, fmt.Errorf("failed to get user data: %w", err)

View File

@@ -11,12 +11,14 @@ import (
) )
// DatabasePasskeyProvider implements PasskeyProvider using database storage // DatabasePasskeyProvider implements PasskeyProvider using database storage
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabasePasskeyProvider struct { type DatabasePasskeyProvider struct {
db *sql.DB db *sql.DB
rpID string // Relying Party ID (domain) rpID string // Relying Party ID (domain)
rpName string // Relying Party display name rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000) timeout int64 // Timeout in milliseconds (default: 60000)
sqlNames *SQLNames
} }
// DatabasePasskeyProviderOptions configures the passkey provider // DatabasePasskeyProviderOptions configures the passkey provider
@@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct {
RPOrigin string RPOrigin string
// Timeout is the timeout for operations in milliseconds (default: 60000) // Timeout is the timeout for operations in milliseconds (default: 60000)
Timeout int64 Timeout int64
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
SQLNames *SQLNames
} }
// NewDatabasePasskeyProvider creates a new database-backed passkey provider // NewDatabasePasskeyProvider creates a new database-backed passkey provider
@@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
opts.Timeout = 60000 // 60 seconds default opts.Timeout = 60000 // 60 seconds default
} }
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabasePasskeyProvider{ return &DatabasePasskeyProvider{
db: db, db: db,
rpID: opts.RPID, rpID: opts.RPID,
rpName: opts.RPName, rpName: opts.RPName,
rpOrigin: opts.RPOrigin, rpOrigin: opts.RPOrigin,
timeout: opts.Timeout, timeout: opts.Timeout,
sqlNames: sqlNames,
} }
} }
@@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialID sql.NullInt64 var credentialID sql.NullInt64
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err) return nil, fmt.Errorf("failed to store credential: %w", err)
@@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
var userID sql.NullInt64 var userID sql.NullInt64
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
@@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialJSON sql.NullString var credentialJSON sql.NullString
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err) return 0, fmt.Errorf("failed to get credential: %w", err)
@@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var updateError sql.NullString var updateError sql.NullString
var cloneWarning sql.NullBool var cloneWarning sql.NullBool
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)` updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err) return 0, fmt.Errorf("failed to update counter: %w", err)
@@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
@@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete credential: %w", err) return fmt.Errorf("failed to delete credential: %w", err)
@@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to update credential name: %w", err) return fmt.Errorf("failed to update credential name: %w", err)

View File

@@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// DatabaseAuthenticator provides session-based authentication with database storage // DatabaseAuthenticator provides session-based authentication with database storage
// All database operations go through stored procedures for security and consistency // All database operations go through stored procedures for security and consistency
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session, // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// resolvespec_session_update, resolvespec_refresh_token
// See database_schema.sql for procedure definitions // See database_schema.sql for procedure definitions
// Also supports multiple OAuth2 providers configured with WithOAuth2() // Also supports multiple OAuth2 providers configured with WithOAuth2()
// Also supports passkey authentication configured with WithPasskey() // Also supports passkey authentication configured with WithPasskey()
@@ -67,6 +66,7 @@ type DatabaseAuthenticator struct {
db *sql.DB db *sql.DB
cache *cache.Cache cache *cache.Cache
cacheTTL time.Duration cacheTTL time.Duration
sqlNames *SQLNames
// OAuth2 providers registry (multiple providers supported) // OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider oauth2Providers map[string]*OAuth2Provider
@@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct {
Cache *cache.Cache Cache *cache.Cache
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication // PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
PasskeyProvider PasskeyProvider PasskeyProvider PasskeyProvider
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
// Partial overrides are supported: only set the fields you want to change.
SQLNames *SQLNames
} }
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
cacheInstance = cache.GetDefaultCache() cacheInstance = cache.GetDefaultCache()
} }
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabaseAuthenticator{ return &DatabaseAuthenticator{
db: db, db: db,
cache: cacheInstance, cache: cacheInstance,
cacheTTL: opts.CacheTTL, cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider, passkeyProvider: opts.PasskeyProvider,
} }
} }
@@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
return nil, fmt.Errorf("failed to marshal login request: %w", err) return nil, fmt.Errorf("failed to marshal login request: %w", err)
} }
// Call resolvespec_login stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
@@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
return nil, fmt.Errorf("failed to marshal register request: %w", err) return nil, fmt.Errorf("failed to marshal register request: %w", err)
} }
// Call resolvespec_register stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("register query failed: %w", err) return nil, fmt.Errorf("register query failed: %w", err)
@@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
return fmt.Errorf("failed to marshal logout request: %w", err) return fmt.Errorf("failed to marshal logout request: %w", err)
} }
// Call resolvespec_logout stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
@@ -266,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON sql.NullString var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("session query failed: %w", err) return nil, fmt.Errorf("session query failed: %w", err)
@@ -338,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
return return
} }
// Call resolvespec_session_update stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var updatedUserJSON sql.NullString var updatedUserJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) _ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
} }
// RefreshToken implements Refreshable interface // RefreshToken implements Refreshable interface
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
// Call api_refresh_token stored procedure
// First, we need to get the current user context for the refresh token // First, we need to get the current user context for the refresh token
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON sql.NullString var userJSON sql.NullString
// Get current session to pass to refresh // Get current session to pass to refresh
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err) return nil, fmt.Errorf("refresh token query failed: %w", err)
@@ -368,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
return nil, fmt.Errorf("invalid refresh token") return nil, fmt.Errorf("invalid refresh token")
} }
// Call resolvespec_refresh_token to generate new token
var newSuccess bool var newSuccess bool
var newErrorMsg sql.NullString var newErrorMsg sql.NullString
var newUserJSON sql.NullString var newUserJSON sql.NullString
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)` refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err) return nil, fmt.Errorf("refresh token generation failed: %w", err)
@@ -401,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
// JWTAuthenticator provides JWT token-based authentication // JWTAuthenticator provides JWT token-based authentication
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported // NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
type JWTAuthenticator struct { type JWTAuthenticator struct {
secretKey []byte secretKey []byte
db *sql.DB db *sql.DB
sqlNames *SQLNames
} }
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator { func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
return &JWTAuthenticator{ return &JWTAuthenticator{
secretKey: []byte(secretKey), secretKey: []byte(secretKey),
db: db, db: db,
sqlNames: resolveSQLNames(names...),
} }
} }
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Call resolvespec_jwt_login stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON []byte var userJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
@@ -471,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
} }
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Call resolvespec_jwt_logout stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
@@ -511,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// DatabaseColumnSecurityProvider loads column security from database // DatabaseColumnSecurityProvider loads column security from database
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedure: resolvespec_column_security // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseColumnSecurityProvider struct { type DatabaseColumnSecurityProvider struct {
db *sql.DB db *sql.DB
sqlNames *SQLNames
} }
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider { func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db} return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
} }
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity var rules []ColumnSecurity
// Call resolvespec_column_security stored procedure
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var rulesJSON []byte var rulesJSON []byte
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)` query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err) return nil, fmt.Errorf("failed to load column security: %w", err)
@@ -576,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
// DatabaseRowSecurityProvider loads row security from database // DatabaseRowSecurityProvider loads row security from database
// All database operations go through stored procedures // All database operations go through stored procedures
// Requires stored procedure: resolvespec_row_security // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseRowSecurityProvider struct { type DatabaseRowSecurityProvider struct {
db *sql.DB db *sql.DB
sqlNames *SQLNames
} }
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider { func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db} return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
} }
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string var template string
var hasBlock bool var hasBlock bool
// Call resolvespec_row_security stored procedure query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
if err != nil { if err != nil {
@@ -758,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
return nil, fmt.Errorf("passkey authentication failed: %w", err) return nil, fmt.Errorf("passkey authentication failed: %w", err)
} }
// Get user data from database // Build request JSON for passkey login stored procedure
var username, email, roles string reqData := map[string]any{
var userLevel int "user_id": userID,
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)
} }
// Generate session token
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Extract IP and user agent from claims
ipAddress := ""
userAgent := ""
if req.Claims != nil { if req.Claims != nil {
if ip, ok := req.Claims["ip_address"].(string); ok { if ip, ok := req.Claims["ip_address"].(string); ok {
ipAddress = ip reqData["ip_address"] = ip
} }
if ua, ok := req.Claims["user_agent"].(string); ok { if ua, ok := req.Claims["user_agent"].(string); ok {
userAgent = ua reqData["user_agent"] = ua
} }
} }
// Create session reqJSON, err := json.Marshal(reqData)
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES ($1, $2, $3, $4, $5, now())`
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err) return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
} }
// Update last login var success bool
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1` var errorMsg sql.NullString
_, _ = a.db.ExecContext(ctx, updateQuery, userID) var dataJSON sql.NullString
// Return login response query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
return &LoginResponse{ err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
Token: sessionToken, if err != nil {
User: &UserContext{ return nil, fmt.Errorf("passkey login query failed: %w", err)
UserID: userID, }
UserName: username,
Email: email, if !success {
UserLevel: userLevel, if errorMsg.Valid {
SessionID: sessionToken, return nil, fmt.Errorf("%s", errorMsg.String)
Roles: parseRoles(roles), }
}, return nil, fmt.Errorf("passkey login failed")
ExpiresIn: int64(24 * time.Hour.Seconds()), }
}, nil
var response LoginResponse
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
}
return &response, nil
} }
// GetPasskeyCredentials returns all passkey credentials for a user // GetPasskeyCredentials returns all passkey credentials for a user

222
pkg/security/sql_names.go Normal file
View File

@@ -0,0 +1,222 @@
package security
import (
"fmt"
"reflect"
"regexp"
)
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// SQLNames defines all configurable SQL stored procedure and table names
// used by the security package. Override individual fields to remap
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
// and MergeSQLNames() to apply partial overrides.
type SQLNames struct {
// Auth procedures (DatabaseAuthenticator)
Login string // default: "resolvespec_login"
Register string // default: "resolvespec_register"
Logout string // default: "resolvespec_logout"
Session string // default: "resolvespec_session"
SessionUpdate string // default: "resolvespec_session_update"
RefreshToken string // default: "resolvespec_refresh_token"
// JWT procedures (JWTAuthenticator)
JWTLogin string // default: "resolvespec_jwt_login"
JWTLogout string // default: "resolvespec_jwt_logout"
// Security policy procedures
ColumnSecurity string // default: "resolvespec_column_security"
RowSecurity string // default: "resolvespec_row_security"
// TOTP procedures (DatabaseTwoFactorProvider)
TOTPEnable string // default: "resolvespec_totp_enable"
TOTPDisable string // default: "resolvespec_totp_disable"
TOTPGetStatus string // default: "resolvespec_totp_get_status"
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
// Passkey procedures (DatabasePasskeyProvider)
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
PasskeyLogin string // default: "resolvespec_passkey_login"
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
OAuthGetUser string // default: "resolvespec_oauth_getuser"
}
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
func DefaultSQLNames() *SQLNames {
return &SQLNames{
Login: "resolvespec_login",
Register: "resolvespec_register",
Logout: "resolvespec_logout",
Session: "resolvespec_session",
SessionUpdate: "resolvespec_session_update",
RefreshToken: "resolvespec_refresh_token",
JWTLogin: "resolvespec_jwt_login",
JWTLogout: "resolvespec_jwt_logout",
ColumnSecurity: "resolvespec_column_security",
RowSecurity: "resolvespec_row_security",
TOTPEnable: "resolvespec_totp_enable",
TOTPDisable: "resolvespec_totp_disable",
TOTPGetStatus: "resolvespec_totp_get_status",
TOTPGetSecret: "resolvespec_totp_get_secret",
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
PasskeyGetCredential: "resolvespec_passkey_get_credential",
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
PasskeyUpdateName: "resolvespec_passkey_update_name",
PasskeyLogin: "resolvespec_passkey_login",
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
OAuthCreateSession: "resolvespec_oauth_createsession",
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
OAuthGetUser: "resolvespec_oauth_getuser",
}
}
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.Login != "" {
merged.Login = override.Login
}
if override.Register != "" {
merged.Register = override.Register
}
if override.Logout != "" {
merged.Logout = override.Logout
}
if override.Session != "" {
merged.Session = override.Session
}
if override.SessionUpdate != "" {
merged.SessionUpdate = override.SessionUpdate
}
if override.RefreshToken != "" {
merged.RefreshToken = override.RefreshToken
}
if override.JWTLogin != "" {
merged.JWTLogin = override.JWTLogin
}
if override.JWTLogout != "" {
merged.JWTLogout = override.JWTLogout
}
if override.ColumnSecurity != "" {
merged.ColumnSecurity = override.ColumnSecurity
}
if override.RowSecurity != "" {
merged.RowSecurity = override.RowSecurity
}
if override.TOTPEnable != "" {
merged.TOTPEnable = override.TOTPEnable
}
if override.TOTPDisable != "" {
merged.TOTPDisable = override.TOTPDisable
}
if override.TOTPGetStatus != "" {
merged.TOTPGetStatus = override.TOTPGetStatus
}
if override.TOTPGetSecret != "" {
merged.TOTPGetSecret = override.TOTPGetSecret
}
if override.TOTPRegenerateBackup != "" {
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
}
if override.TOTPValidateBackupCode != "" {
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
}
if override.PasskeyStoreCredential != "" {
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
}
if override.PasskeyGetCredsByUsername != "" {
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
}
if override.PasskeyGetCredential != "" {
merged.PasskeyGetCredential = override.PasskeyGetCredential
}
if override.PasskeyUpdateCounter != "" {
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
}
if override.PasskeyGetUserCredentials != "" {
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
}
if override.PasskeyDeleteCredential != "" {
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
}
if override.PasskeyUpdateName != "" {
merged.PasskeyUpdateName = override.PasskeyUpdateName
}
if override.PasskeyLogin != "" {
merged.PasskeyLogin = override.PasskeyLogin
}
if override.OAuthGetOrCreateUser != "" {
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
}
if override.OAuthCreateSession != "" {
merged.OAuthCreateSession = override.OAuthCreateSession
}
if override.OAuthGetRefreshToken != "" {
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
}
if override.OAuthUpdateRefreshToken != "" {
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
}
if override.OAuthGetUser != "" {
merged.OAuthGetUser = override.OAuthGetUser
}
return &merged
}
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
// Returns an error if any field contains invalid characters.
func ValidateSQLNames(names *SQLNames) error {
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
val := field.String()
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
}
}
return nil
}
// resolveSQLNames merges an optional override with defaults.
// Used by constructors that accept variadic *SQLNames.
func resolveSQLNames(override ...*SQLNames) *SQLNames {
if len(override) > 0 && override[0] != nil {
return MergeSQLNames(DefaultSQLNames(), override[0])
}
return DefaultSQLNames()
}

View File

@@ -0,0 +1,145 @@
package security
import (
"reflect"
"testing"
)
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
names := DefaultSQLNames()
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
if field.String() == "" {
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
}
}
}
func TestMergeSQLNames_PartialOverride(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{
Login: "custom_login",
TOTPEnable: "custom_totp_enable",
PasskeyLogin: "custom_passkey_login",
}
merged := MergeSQLNames(base, override)
if merged.Login != "custom_login" {
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
}
if merged.TOTPEnable != "custom_totp_enable" {
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
}
if merged.PasskeyLogin != "custom_passkey_login" {
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
}
// Non-overridden fields should retain defaults
if merged.Logout != "resolvespec_logout" {
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
}
if merged.Session != "resolvespec_session" {
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
}
}
func TestMergeSQLNames_NilOverride(t *testing.T) {
base := DefaultSQLNames()
merged := MergeSQLNames(base, nil)
// Should be a copy, not the same pointer
if merged == base {
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
}
// All values should match
v1 := reflect.ValueOf(base).Elem()
v2 := reflect.ValueOf(merged).Elem()
typ := v1.Type()
for i := 0; i < v1.NumField(); i++ {
f1 := v1.Field(i)
f2 := v2.Field(i)
if f1.Kind() != reflect.String {
continue
}
if f1.String() != f2.String() {
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
}
}
}
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
base := DefaultSQLNames()
originalLogin := base.Login
override := &SQLNames{Login: "custom_login"}
_ = MergeSQLNames(base, override)
if base.Login != originalLogin {
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
}
}
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{}
v := reflect.ValueOf(override).Elem()
for i := 0; i < v.NumField(); i++ {
if v.Field(i).Kind() == reflect.String {
v.Field(i).SetString("custom_sentinel")
}
}
merged := MergeSQLNames(base, override)
mv := reflect.ValueOf(merged).Elem()
typ := mv.Type()
for i := 0; i < mv.NumField(); i++ {
if mv.Field(i).Kind() != reflect.String {
continue
}
if mv.Field(i).String() != "custom_sentinel" {
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
}
}
}
func TestValidateSQLNames_Valid(t *testing.T) {
names := DefaultSQLNames()
if err := ValidateSQLNames(names); err != nil {
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
}
}
func TestValidateSQLNames_Invalid(t *testing.T) {
names := DefaultSQLNames()
names.Login = "resolvespec_login; DROP TABLE users; --"
err := ValidateSQLNames(names)
if err == nil {
t.Error("ValidateSQLNames should reject names with invalid characters")
}
}
func TestResolveSQLNames_NoOverride(t *testing.T) {
names := resolveSQLNames()
if names.Login != "resolvespec_login" {
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
}
}
func TestResolveSQLNames_WithOverride(t *testing.T) {
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
if names.Login != "custom_login" {
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
}
if names.Logout != "resolvespec_logout" {
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
}
}

View File

@@ -9,23 +9,23 @@ import (
) )
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures // DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable, // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
// See totp_database_schema.sql for procedure definitions // See totp_database_schema.sql for procedure definitions
type DatabaseTwoFactorProvider struct { type DatabaseTwoFactorProvider struct {
db *sql.DB db *sql.DB
totpGen *TOTPGenerator totpGen *TOTPGenerator
sqlNames *SQLNames
} }
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider // NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider { func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
if config == nil { if config == nil {
config = DefaultTwoFactorConfig() config = DefaultTwoFactorConfig()
} }
return &DatabaseTwoFactorProvider{ return &DatabaseTwoFactorProvider{
db: db, db: db,
totpGen: NewTOTPGenerator(config), totpGen: NewTOTPGenerator(config),
sqlNames: resolveSQLNames(names...),
} }
} }
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg) err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("enable 2FA query failed: %w", err) return fmt.Errorf("enable 2FA query failed: %w", err)
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("disable 2FA query failed: %w", err) return fmt.Errorf("disable 2FA query failed: %w", err)
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
var errorMsg sql.NullString var errorMsg sql.NullString
var enabled bool var enabled bool
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
if err != nil { if err != nil {
return false, fmt.Errorf("get 2FA status query failed: %w", err) return false, fmt.Errorf("get 2FA status query failed: %w", err)
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var errorMsg sql.NullString var errorMsg sql.NullString
var secret sql.NullString var secret sql.NullString
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)` query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
if err != nil { if err != nil {
return "", fmt.Errorf("get 2FA secret query failed: %w", err) return "", fmt.Errorf("get 2FA secret query failed: %w", err)
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)` query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg) err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err) return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
var errorMsg sql.NullString var errorMsg sql.NullString
var valid bool var valid bool
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)` query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid) err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
if err != nil { if err != nil {
return false, fmt.Errorf("validate backup code query failed: %w", err) return false, fmt.Errorf("validate backup code query failed: %w", err)