From 568df8c6d6df0e2c7b52fb63f05634f5c5244f86 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 31 Mar 2026 14:25:59 +0200 Subject: [PATCH] feat(security): add configurable SQL procedure names * Introduce SQLNames struct to define stored procedure names. * Update DatabaseAuthenticator, JWTAuthenticator, and other providers to use SQLNames for procedure calls. * Remove hardcoded procedure names for better flexibility and customization. * Implement validation for SQL names to ensure they are valid identifiers. * Add tests for SQLNames functionality and merging behavior. --- pkg/security/QUICK_REFERENCE.md | 7 +- pkg/security/examples.go | 6 - pkg/security/oauth2_methods.go | 34 ++-- pkg/security/passkey_provider.go | 21 ++- pkg/security/providers.go | 133 +++++++-------- pkg/security/sql_names.go | 222 +++++++++++++++++++++++++ pkg/security/sql_names_test.go | 145 ++++++++++++++++ pkg/security/totp_provider_database.go | 28 ++-- 8 files changed, 476 insertions(+), 120 deletions(-) create mode 100644 pkg/security/sql_names.go create mode 100644 pkg/security/sql_names_test.go diff --git a/pkg/security/QUICK_REFERENCE.md b/pkg/security/QUICK_REFERENCE.md index d229b6c..ac9971d 100644 --- a/pkg/security/QUICK_REFERENCE.md +++ b/pkg/security/QUICK_REFERENCE.md @@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest) } func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error { - // Add to blacklist - return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{ - "token": req.Token, - "user_id": req.UserID, - }).Error + // Invalidate session via stored procedure + return nil } func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) { diff --git a/pkg/security/examples.go b/pkg/security/examples.go index 24e4485..1d91ea2 100644 --- a/pkg/security/examples.go +++ b/pkg/security/examples.go @@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) ( } func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error { - // For JWT, logout could involve token blacklisting - // Add token to blacklist table - // err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{ - // "token": req.Token, - // "expires_at": time.Now().Add(24 * time.Hour), - // }).Error return nil } diff --git a/pkg/security/oauth2_methods.go b/pkg/security/oauth2_methods.go index a8d653c..0b87b3b 100644 --- a/pkg/security/oauth2_methods.go +++ b/pkg/security/oauth2_methods.go @@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC var errMsg *string var userID *int - err = a.db.QueryRowContext(ctx, ` - SELECT p_success, p_error, p_user_id - FROM resolvespec_oauth_getorcreateuser($1::jsonb) - `, userJSON).Scan(&success, &errMsg, &userID) + err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error, p_user_id + FROM %s($1::jsonb) + `, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID) if err != nil { return 0, fmt.Errorf("failed to get or create user: %w", err) @@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session var success bool var errMsg *string - err = a.db.QueryRowContext(ctx, ` - SELECT p_success, p_error - FROM resolvespec_oauth_createsession($1::jsonb) - `, sessionJSON).Scan(&success, &errMsg) + err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + SELECT p_success, p_error + FROM %s($1::jsonb) + `, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg) if err != nil { return fmt.Errorf("failed to create session: %w", err) @@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var errMsg *string var sessionData []byte - err = a.db.QueryRowContext(ctx, ` + err = a.db.QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error, p_data::text - FROM resolvespec_oauth_getrefreshtoken($1) - `, refreshToken).Scan(&success, &errMsg, &sessionData) + FROM %s($1) + `, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData) if err != nil { return nil, fmt.Errorf("failed to get session by refresh token: %w", err) @@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var updateSuccess bool var updateErrMsg *string - err = a.db.QueryRowContext(ctx, ` + err = a.db.QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error - FROM resolvespec_oauth_updaterefreshtoken($1::jsonb) - `, updateJSON).Scan(&updateSuccess, &updateErrMsg) + FROM %s($1::jsonb) + `, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg) if err != nil { return nil, fmt.Errorf("failed to update session: %w", err) @@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var userErrMsg *string var userData []byte - err = a.db.QueryRowContext(ctx, ` + err = a.db.QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error, p_data::text - FROM resolvespec_oauth_getuser($1) - `, session.UserID).Scan(&userSuccess, &userErrMsg, &userData) + FROM %s($1) + `, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData) if err != nil { return nil, fmt.Errorf("failed to get user data: %w", err) diff --git a/pkg/security/passkey_provider.go b/pkg/security/passkey_provider.go index cc2603d..7abeab0 100644 --- a/pkg/security/passkey_provider.go +++ b/pkg/security/passkey_provider.go @@ -11,12 +11,14 @@ import ( ) // DatabasePasskeyProvider implements PasskeyProvider using database storage +// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabasePasskeyProvider struct { db *sql.DB rpID string // Relying Party ID (domain) rpName string // Relying Party display name rpOrigin string // Expected origin for WebAuthn timeout int64 // Timeout in milliseconds (default: 60000) + sqlNames *SQLNames } // DatabasePasskeyProviderOptions configures the passkey provider @@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct { RPOrigin string // Timeout is the timeout for operations in milliseconds (default: 60000) Timeout int64 + // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). + SQLNames *SQLNames } // NewDatabasePasskeyProvider creates a new database-backed passkey provider @@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions) opts.Timeout = 60000 // 60 seconds default } + sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames) + return &DatabasePasskeyProvider{ db: db, rpID: opts.RPID, rpName: opts.RPName, rpOrigin: opts.RPOrigin, timeout: opts.Timeout, + sqlNames: sqlNames, } } @@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user var errorMsg sql.NullString var credentialID sql.NullInt64 - query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential) err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) if err != nil { return nil, fmt.Errorf("failed to store credential: %w", err) @@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern var userID sql.NullInt64 var credentialsJSON sql.NullString - query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername) err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) @@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re var errorMsg sql.NullString var credentialJSON sql.NullString - query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential) err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) if err != nil { return 0, fmt.Errorf("failed to get credential: %w", err) @@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re var updateError sql.NullString var cloneWarning sql.NullBool - updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)` + updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter) err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) if err != nil { return 0, fmt.Errorf("failed to update counter: %w", err) @@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int var errorMsg sql.NullString var credentialsJSON sql.NullString - query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials) err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) @@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i var success bool var errorMsg sql.NullString - query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)` + query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential) err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("failed to delete credential: %w", err) @@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user var success bool var errorMsg sql.NullString - query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)` + query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName) err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("failed to update credential name: %w", err) diff --git a/pkg/security/providers.go b/pkg/security/providers.go index f8ecad3..e23bd82 100644 --- a/pkg/security/providers.go +++ b/pkg/security/providers.go @@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error // DatabaseAuthenticator provides session-based authentication with database storage // All database operations go through stored procedures for security and consistency -// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session, -// resolvespec_session_update, resolvespec_refresh_token +// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) // See database_schema.sql for procedure definitions // Also supports multiple OAuth2 providers configured with WithOAuth2() // Also supports passkey authentication configured with WithPasskey() @@ -67,6 +66,7 @@ type DatabaseAuthenticator struct { db *sql.DB cache *cache.Cache cacheTTL time.Duration + sqlNames *SQLNames // OAuth2 providers registry (multiple providers supported) oauth2Providers map[string]*OAuth2Provider @@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct { Cache *cache.Cache // PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication PasskeyProvider PasskeyProvider + // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). + // Partial overrides are supported: only set the fields you want to change. + SQLNames *SQLNames } func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { @@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO cacheInstance = cache.GetDefaultCache() } + sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames) + return &DatabaseAuthenticator{ db: db, cache: cacheInstance, cacheTTL: opts.CacheTTL, + sqlNames: sqlNames, passkeyProvider: opts.PasskeyProvider, } } @@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L return nil, fmt.Errorf("failed to marshal login request: %w", err) } - // Call resolvespec_login stored procedure var success bool var errorMsg sql.NullString var dataJSON sql.NullString - query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) if err != nil { return nil, fmt.Errorf("login query failed: %w", err) @@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques return nil, fmt.Errorf("failed to marshal register request: %w", err) } - // Call resolvespec_register stored procedure var success bool var errorMsg sql.NullString var dataJSON sql.NullString - query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) if err != nil { return nil, fmt.Errorf("register query failed: %w", err) @@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e return fmt.Errorf("failed to marshal logout request: %w", err) } - // Call resolvespec_logout stored procedure var success bool var errorMsg sql.NullString var dataJSON sql.NullString - query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout) err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) if err != nil { return fmt.Errorf("logout query failed: %w", err) @@ -266,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err var errorMsg sql.NullString var userJSON sql.NullString - query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) if err != nil { return nil, fmt.Errorf("session query failed: %w", err) @@ -338,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi return } - // Call resolvespec_session_update stored procedure var success bool var errorMsg sql.NullString var updatedUserJSON sql.NullString - query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate) _ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) } // RefreshToken implements Refreshable interface func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) { - // Call api_refresh_token stored procedure // First, we need to get the current user context for the refresh token var success bool var errorMsg sql.NullString var userJSON sql.NullString // Get current session to pass to refresh - query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) if err != nil { return nil, fmt.Errorf("refresh token query failed: %w", err) @@ -368,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s return nil, fmt.Errorf("invalid refresh token") } - // Call resolvespec_refresh_token to generate new token var newSuccess bool var newErrorMsg sql.NullString var newUserJSON sql.NullString - refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)` + refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken) err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) if err != nil { return nil, fmt.Errorf("refresh token generation failed: %w", err) @@ -401,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s // JWTAuthenticator provides JWT token-based authentication // All database operations go through stored procedures -// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout +// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) // NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported type JWTAuthenticator struct { secretKey []byte db *sql.DB + sqlNames *SQLNames } -func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator { +func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator { return &JWTAuthenticator{ secretKey: []byte(secretKey), db: db, + sqlNames: resolveSQLNames(names...), } } func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { - // Call resolvespec_jwt_login stored procedure var success bool var errorMsg sql.NullString var userJSON []byte - query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin) err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) if err != nil { return nil, fmt.Errorf("login query failed: %w", err) @@ -471,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR } func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error { - // Call resolvespec_jwt_logout stored procedure var success bool var errorMsg sql.NullString - query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)` + query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout) err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("logout query failed: %w", err) @@ -511,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { // DatabaseColumnSecurityProvider loads column security from database // All database operations go through stored procedures -// Requires stored procedure: resolvespec_column_security +// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabaseColumnSecurityProvider struct { - db *sql.DB + db *sql.DB + sqlNames *SQLNames } -func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider { - return &DatabaseColumnSecurityProvider{db: db} +func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider { + return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} } func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { var rules []ColumnSecurity - // Call resolvespec_column_security stored procedure var success bool var errorMsg sql.NullString var rulesJSON []byte - query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity) err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) if err != nil { return nil, fmt.Errorf("failed to load column security: %w", err) @@ -576,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, // DatabaseRowSecurityProvider loads row security from database // All database operations go through stored procedures -// Requires stored procedure: resolvespec_row_security +// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabaseRowSecurityProvider struct { - db *sql.DB + db *sql.DB + sqlNames *SQLNames } -func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider { - return &DatabaseRowSecurityProvider{db: db} +func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider { + return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} } func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { var template string var hasBlock bool - // Call resolvespec_row_security stored procedure - query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)` + query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity) err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) if err != nil { @@ -758,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke return nil, fmt.Errorf("passkey authentication failed: %w", err) } - // Get user data from database - var username, email, roles string - var userLevel int - query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true` - err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles) - if err != nil { - return nil, fmt.Errorf("failed to get user data: %w", err) + // Build request JSON for passkey login stored procedure + reqData := map[string]any{ + "user_id": userID, } - - // Generate session token - sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix()) - expiresAt := time.Now().Add(24 * time.Hour) - - // Extract IP and user agent from claims - ipAddress := "" - userAgent := "" if req.Claims != nil { if ip, ok := req.Claims["ip_address"].(string); ok { - ipAddress = ip + reqData["ip_address"] = ip } if ua, ok := req.Claims["user_agent"].(string); ok { - userAgent = ua + reqData["user_agent"] = ua } } - // Create session - insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at) - VALUES ($1, $2, $3, $4, $5, now())` - _, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent) + reqJSON, err := json.Marshal(reqData) if err != nil { - return nil, fmt.Errorf("failed to create session: %w", err) + return nil, fmt.Errorf("failed to marshal passkey login request: %w", err) } - // Update last login - updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1` - _, _ = a.db.ExecContext(ctx, updateQuery, userID) + var success bool + var errorMsg sql.NullString + var dataJSON sql.NullString - // Return login response - return &LoginResponse{ - Token: sessionToken, - User: &UserContext{ - UserID: userID, - UserName: username, - Email: email, - UserLevel: userLevel, - SessionID: sessionToken, - Roles: parseRoles(roles), - }, - ExpiresIn: int64(24 * time.Hour.Seconds()), - }, nil + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin) + err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + if err != nil { + return nil, fmt.Errorf("passkey login query failed: %w", err) + } + + if !success { + if errorMsg.Valid { + return nil, fmt.Errorf("%s", errorMsg.String) + } + return nil, fmt.Errorf("passkey login failed") + } + + var response LoginResponse + if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil { + return nil, fmt.Errorf("failed to parse passkey login response: %w", err) + } + + return &response, nil } // GetPasskeyCredentials returns all passkey credentials for a user diff --git a/pkg/security/sql_names.go b/pkg/security/sql_names.go new file mode 100644 index 0000000..bbb594c --- /dev/null +++ b/pkg/security/sql_names.go @@ -0,0 +1,222 @@ +package security + +import ( + "fmt" + "reflect" + "regexp" +) + +var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +// SQLNames defines all configurable SQL stored procedure and table names +// used by the security package. Override individual fields to remap +// to custom database objects. Use DefaultSQLNames() for baseline defaults, +// and MergeSQLNames() to apply partial overrides. +type SQLNames struct { + // Auth procedures (DatabaseAuthenticator) + Login string // default: "resolvespec_login" + Register string // default: "resolvespec_register" + Logout string // default: "resolvespec_logout" + Session string // default: "resolvespec_session" + SessionUpdate string // default: "resolvespec_session_update" + RefreshToken string // default: "resolvespec_refresh_token" + + // JWT procedures (JWTAuthenticator) + JWTLogin string // default: "resolvespec_jwt_login" + JWTLogout string // default: "resolvespec_jwt_logout" + + // Security policy procedures + ColumnSecurity string // default: "resolvespec_column_security" + RowSecurity string // default: "resolvespec_row_security" + + // TOTP procedures (DatabaseTwoFactorProvider) + TOTPEnable string // default: "resolvespec_totp_enable" + TOTPDisable string // default: "resolvespec_totp_disable" + TOTPGetStatus string // default: "resolvespec_totp_get_status" + TOTPGetSecret string // default: "resolvespec_totp_get_secret" + TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes" + TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code" + + // Passkey procedures (DatabasePasskeyProvider) + PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential" + PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username" + PasskeyGetCredential string // default: "resolvespec_passkey_get_credential" + PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter" + PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials" + PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential" + PasskeyUpdateName string // default: "resolvespec_passkey_update_name" + PasskeyLogin string // default: "resolvespec_passkey_login" + + // OAuth2 procedures (DatabaseAuthenticator OAuth2 methods) + OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser" + OAuthCreateSession string // default: "resolvespec_oauth_createsession" + OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken" + OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken" + OAuthGetUser string // default: "resolvespec_oauth_getuser" + +} + +// DefaultSQLNames returns an SQLNames with all default resolvespec_* values. +func DefaultSQLNames() *SQLNames { + return &SQLNames{ + Login: "resolvespec_login", + Register: "resolvespec_register", + Logout: "resolvespec_logout", + Session: "resolvespec_session", + SessionUpdate: "resolvespec_session_update", + RefreshToken: "resolvespec_refresh_token", + + JWTLogin: "resolvespec_jwt_login", + JWTLogout: "resolvespec_jwt_logout", + + ColumnSecurity: "resolvespec_column_security", + RowSecurity: "resolvespec_row_security", + + TOTPEnable: "resolvespec_totp_enable", + TOTPDisable: "resolvespec_totp_disable", + TOTPGetStatus: "resolvespec_totp_get_status", + TOTPGetSecret: "resolvespec_totp_get_secret", + TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes", + TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code", + + PasskeyStoreCredential: "resolvespec_passkey_store_credential", + PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username", + PasskeyGetCredential: "resolvespec_passkey_get_credential", + PasskeyUpdateCounter: "resolvespec_passkey_update_counter", + PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials", + PasskeyDeleteCredential: "resolvespec_passkey_delete_credential", + PasskeyUpdateName: "resolvespec_passkey_update_name", + PasskeyLogin: "resolvespec_passkey_login", + + OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser", + OAuthCreateSession: "resolvespec_oauth_createsession", + OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken", + OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken", + OAuthGetUser: "resolvespec_oauth_getuser", + } +} + +// MergeSQLNames returns a copy of base with any non-empty fields from override applied. +// If override is nil, a copy of base is returned. +func MergeSQLNames(base, override *SQLNames) *SQLNames { + if override == nil { + copied := *base + return &copied + } + merged := *base + if override.Login != "" { + merged.Login = override.Login + } + if override.Register != "" { + merged.Register = override.Register + } + if override.Logout != "" { + merged.Logout = override.Logout + } + if override.Session != "" { + merged.Session = override.Session + } + if override.SessionUpdate != "" { + merged.SessionUpdate = override.SessionUpdate + } + if override.RefreshToken != "" { + merged.RefreshToken = override.RefreshToken + } + if override.JWTLogin != "" { + merged.JWTLogin = override.JWTLogin + } + if override.JWTLogout != "" { + merged.JWTLogout = override.JWTLogout + } + if override.ColumnSecurity != "" { + merged.ColumnSecurity = override.ColumnSecurity + } + if override.RowSecurity != "" { + merged.RowSecurity = override.RowSecurity + } + if override.TOTPEnable != "" { + merged.TOTPEnable = override.TOTPEnable + } + if override.TOTPDisable != "" { + merged.TOTPDisable = override.TOTPDisable + } + if override.TOTPGetStatus != "" { + merged.TOTPGetStatus = override.TOTPGetStatus + } + if override.TOTPGetSecret != "" { + merged.TOTPGetSecret = override.TOTPGetSecret + } + if override.TOTPRegenerateBackup != "" { + merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup + } + if override.TOTPValidateBackupCode != "" { + merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode + } + if override.PasskeyStoreCredential != "" { + merged.PasskeyStoreCredential = override.PasskeyStoreCredential + } + if override.PasskeyGetCredsByUsername != "" { + merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername + } + if override.PasskeyGetCredential != "" { + merged.PasskeyGetCredential = override.PasskeyGetCredential + } + if override.PasskeyUpdateCounter != "" { + merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter + } + if override.PasskeyGetUserCredentials != "" { + merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials + } + if override.PasskeyDeleteCredential != "" { + merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential + } + if override.PasskeyUpdateName != "" { + merged.PasskeyUpdateName = override.PasskeyUpdateName + } + if override.PasskeyLogin != "" { + merged.PasskeyLogin = override.PasskeyLogin + } + if override.OAuthGetOrCreateUser != "" { + merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser + } + if override.OAuthCreateSession != "" { + merged.OAuthCreateSession = override.OAuthCreateSession + } + if override.OAuthGetRefreshToken != "" { + merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken + } + if override.OAuthUpdateRefreshToken != "" { + merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken + } + if override.OAuthGetUser != "" { + merged.OAuthGetUser = override.OAuthGetUser + } + return &merged +} + +// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers. +// Returns an error if any field contains invalid characters. +func ValidateSQLNames(names *SQLNames) error { + v := reflect.ValueOf(names).Elem() + typ := v.Type() + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.Kind() != reflect.String { + continue + } + val := field.String() + if val != "" && !validSQLIdentifier.MatchString(val) { + return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val) + } + } + return nil +} + +// resolveSQLNames merges an optional override with defaults. +// Used by constructors that accept variadic *SQLNames. +func resolveSQLNames(override ...*SQLNames) *SQLNames { + if len(override) > 0 && override[0] != nil { + return MergeSQLNames(DefaultSQLNames(), override[0]) + } + return DefaultSQLNames() +} diff --git a/pkg/security/sql_names_test.go b/pkg/security/sql_names_test.go new file mode 100644 index 0000000..be7a500 --- /dev/null +++ b/pkg/security/sql_names_test.go @@ -0,0 +1,145 @@ +package security + +import ( + "reflect" + "testing" +) + +func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) { + names := DefaultSQLNames() + v := reflect.ValueOf(names).Elem() + typ := v.Type() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.Kind() != reflect.String { + continue + } + if field.String() == "" { + t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name) + } + } +} + +func TestMergeSQLNames_PartialOverride(t *testing.T) { + base := DefaultSQLNames() + override := &SQLNames{ + Login: "custom_login", + TOTPEnable: "custom_totp_enable", + PasskeyLogin: "custom_passkey_login", + } + + merged := MergeSQLNames(base, override) + + if merged.Login != "custom_login" { + t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login") + } + if merged.TOTPEnable != "custom_totp_enable" { + t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable") + } + if merged.PasskeyLogin != "custom_passkey_login" { + t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login") + } + // Non-overridden fields should retain defaults + if merged.Logout != "resolvespec_logout" { + t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout") + } + if merged.Session != "resolvespec_session" { + t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session") + } +} + +func TestMergeSQLNames_NilOverride(t *testing.T) { + base := DefaultSQLNames() + merged := MergeSQLNames(base, nil) + + // Should be a copy, not the same pointer + if merged == base { + t.Error("MergeSQLNames with nil override should return a copy, not the same pointer") + } + + // All values should match + v1 := reflect.ValueOf(base).Elem() + v2 := reflect.ValueOf(merged).Elem() + typ := v1.Type() + + for i := 0; i < v1.NumField(); i++ { + f1 := v1.Field(i) + f2 := v2.Field(i) + if f1.Kind() != reflect.String { + continue + } + if f1.String() != f2.String() { + t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String()) + } + } +} + +func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) { + base := DefaultSQLNames() + originalLogin := base.Login + + override := &SQLNames{Login: "custom_login"} + _ = MergeSQLNames(base, override) + + if base.Login != originalLogin { + t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin) + } +} + +func TestMergeSQLNames_AllFieldsMerged(t *testing.T) { + base := DefaultSQLNames() + override := &SQLNames{} + v := reflect.ValueOf(override).Elem() + for i := 0; i < v.NumField(); i++ { + if v.Field(i).Kind() == reflect.String { + v.Field(i).SetString("custom_sentinel") + } + } + + merged := MergeSQLNames(base, override) + mv := reflect.ValueOf(merged).Elem() + typ := mv.Type() + for i := 0; i < mv.NumField(); i++ { + if mv.Field(i).Kind() != reflect.String { + continue + } + if mv.Field(i).String() != "custom_sentinel" { + t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name) + } + } +} + +func TestValidateSQLNames_Valid(t *testing.T) { + names := DefaultSQLNames() + if err := ValidateSQLNames(names); err != nil { + t.Errorf("ValidateSQLNames(defaults) error = %v", err) + } +} + +func TestValidateSQLNames_Invalid(t *testing.T) { + names := DefaultSQLNames() + names.Login = "resolvespec_login; DROP TABLE users; --" + + err := ValidateSQLNames(names) + if err == nil { + t.Error("ValidateSQLNames should reject names with invalid characters") + } +} + +func TestResolveSQLNames_NoOverride(t *testing.T) { + names := resolveSQLNames() + if names.Login != "resolvespec_login" { + t.Errorf("resolveSQLNames().Login = %q, want default", names.Login) + } +} + +func TestResolveSQLNames_WithOverride(t *testing.T) { + names := resolveSQLNames(&SQLNames{Login: "custom_login"}) + if names.Login != "custom_login" { + t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login") + } + if names.Logout != "resolvespec_logout" { + t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout) + } +} diff --git a/pkg/security/totp_provider_database.go b/pkg/security/totp_provider_database.go index f730785..6fe8c5e 100644 --- a/pkg/security/totp_provider_database.go +++ b/pkg/security/totp_provider_database.go @@ -9,23 +9,23 @@ import ( ) // DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures -// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable, -// resolvespec_totp_get_status, resolvespec_totp_get_secret, -// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code +// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) // See totp_database_schema.sql for procedure definitions type DatabaseTwoFactorProvider struct { - db *sql.DB - totpGen *TOTPGenerator + db *sql.DB + totpGen *TOTPGenerator + sqlNames *SQLNames } // NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider -func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider { +func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider { if config == nil { config = DefaultTwoFactorConfig() } return &DatabaseTwoFactorProvider{ - db: db, - totpGen: NewTOTPGenerator(config), + db: db, + totpGen: NewTOTPGenerator(config), + sqlNames: resolveSQLNames(names...), } } @@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC var success bool var errorMsg sql.NullString - query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)` + query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable) err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("enable 2FA query failed: %w", err) @@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error { var success bool var errorMsg sql.NullString - query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)` + query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("disable 2FA query failed: %w", err) @@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) { var errorMsg sql.NullString var enabled bool - query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled) if err != nil { return false, fmt.Errorf("get 2FA status query failed: %w", err) @@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) { var errorMsg sql.NullString var secret sql.NullString - query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret) err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret) if err != nil { return "", fmt.Errorf("get 2FA secret query failed: %w", err) @@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) ( var success bool var errorMsg sql.NullString - query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)` + query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup) err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg) if err != nil { return nil, fmt.Errorf("regenerate backup codes query failed: %w", err) @@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string) var errorMsg sql.NullString var valid bool - query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)` + query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode) err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid) if err != nil { return false, fmt.Errorf("validate backup code query failed: %w", err)