diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index d36337b..458db47 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" "github.com/uptrace/bun" @@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error { // This demonstrates how the abstraction works with different ORMs type BunAdapter struct { db *bun.DB + dbMu sync.RWMutex + dbFactory func() (*bun.DB, error) driverName string } @@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter { return adapter } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. +func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter { + b.dbFactory = factory + return b +} + +func (b *BunAdapter) getDB() *bun.DB { + b.dbMu.RLock() + defer b.dbMu.RUnlock() + return b.db +} + +func (b *BunAdapter) reconnectDB() error { + if b.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := b.dbFactory() + if err != nil { + return err + } + b.dbMu.Lock() + b.db = newDB + b.dbMu.Unlock() + return nil +} + // EnableQueryDebug enables query debugging which logs all SQL queries including preloads // This is useful for debugging preload queries that may be failing func (b *BunAdapter) EnableQueryDebug() { - b.db.AddQueryHook(&QueryDebugHook{}) + b.getDB().AddQueryHook(&QueryDebugHook{}) logger.Info("Bun query debug mode enabled - all SQL queries will be logged") } @@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() { func (b *BunAdapter) NewSelect() common.SelectQuery { return &BunSelectQuery{ - query: b.db.NewSelect(), + query: b.getDB().NewSelect(), db: b.db, driverName: b.driverName, } } func (b *BunAdapter) NewInsert() common.InsertQuery { - return &BunInsertQuery{query: b.db.NewInsert()} + return &BunInsertQuery{query: b.getDB().NewInsert()} } func 
(b *BunAdapter) NewUpdate() common.UpdateQuery { - return &BunUpdateQuery{query: b.db.NewUpdate()} + return &BunUpdateQuery{query: b.getDB().NewUpdate()} } func (b *BunAdapter) NewDelete() common.DeleteQuery { - return &BunDeleteQuery{query: b.db.NewDelete()} + return &BunDeleteQuery{query: b.getDB().NewDelete()} } func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { @@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{} err = logger.HandlePanic("BunAdapter.Exec", r) } }() - result, err := b.db.ExecContext(ctx, query, args...) + var result sql.Result + run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e } + err = run() + if isDBClosed(err) { + if reconnErr := b.reconnectDB(); reconnErr == nil { + err = run() + } + } return &BunResult{result: result}, err } @@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, err = logger.HandlePanic("BunAdapter.Query", r) } }() - return b.db.NewRaw(query, args...).Scan(ctx, dest) + err = b.getDB().NewRaw(query, args...).Scan(ctx, dest) + if isDBClosed(err) { + if reconnErr := b.reconnectDB(); reconnErr == nil { + err = b.getDB().NewRaw(query, args...).Scan(ctx, dest) + } + } + return err } func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) { - tx, err := b.db.BeginTx(ctx, &sql.TxOptions{}) + tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{}) if err != nil { return nil, err } @@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa err = logger.HandlePanic("BunAdapter.RunInTransaction", r) } }() - return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { // Create adapter with transaction adapter := &BunTxAdapter{tx: tx, driverName: 
b.driverName} return fn(adapter) @@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa } func (b *BunAdapter) GetUnderlyingDB() interface{} { - return b.db + return b.getDB() } func (b *BunAdapter) DriverName() string { diff --git a/pkg/common/adapters/database/pgsql.go b/pkg/common/adapters/database/pgsql.go index 87f3631..336e2af 100644 --- a/pkg/common/adapters/database/pgsql.go +++ b/pkg/common/adapters/database/pgsql.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" @@ -17,6 +18,8 @@ import ( // This provides a lightweight PostgreSQL adapter without ORM overhead type PgSQLAdapter struct { db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) driverName string } @@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter { return &PgSQLAdapter{db: db, driverName: name} } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. 
+func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter { + p.dbFactory = factory + return p +} + +func (p *PgSQLAdapter) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *PgSQLAdapter) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + +func isDBClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "sql: database is closed") +} + // EnableQueryDebug enables query debugging for development func (p *PgSQLAdapter) EnableQueryDebug() { logger.Info("PgSQL query debug mode - logging enabled via logger") @@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() { func (p *PgSQLAdapter) NewSelect() common.SelectQuery { return &PgSQLSelectQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, columns: []string{"*"}, args: make([]interface{}, 0), @@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery { func (p *PgSQLAdapter) NewInsert() common.InsertQuery { return &PgSQLInsertQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, values: make(map[string]interface{}), } @@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery { func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { return &PgSQLUpdateQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, sets: make(map[string]interface{}), args: make([]interface{}, 0), @@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { func (p *PgSQLAdapter) NewDelete() common.DeleteQuery { return &PgSQLDeleteQuery{ - db: p.db, + db: p.getDB(), driverName: p.driverName, args: make([]interface{}, 0), whereClauses: make([]string, 0), @@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface } }() logger.Debug("PgSQL Exec: %s [args: %v]", 
query, args) - result, err := p.db.ExecContext(ctx, query, args...) + var result sql.Result + run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e } + err = run() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = run() + } + } if err != nil { logger.Error("PgSQL Exec failed: %v", err) return nil, err @@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string } }() logger.Debug("PgSQL Query: %s [args: %v]", query, args) - rows, err := p.db.QueryContext(ctx, query, args...) + var rows *sql.Rows + run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e } + err = run() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = run() + } + } if err != nil { logger.Error("PgSQL Query failed: %v", err) return err @@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string } func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) { - tx, err := p.db.BeginTx(ctx, nil) + tx, err := p.getDB().BeginTx(ctx, nil) if err != nil { return nil, err } @@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data } }() - tx, err := p.db.BeginTx(ctx, nil) + tx, err := p.getDB().BeginTx(ctx, nil) if err != nil { return err } diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index caecdd8..572d353 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -98,8 +98,8 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( } } - // Filter regularData to only include fields that exist in the model - // Use MapToStruct to validate and filter fields + // Filter regularData to only include fields that exist in the model, + // and translate JSON keys to their actual database column names. 
regularData = p.filterValidFields(regularData, model) // Inject parent IDs for foreign key resolution @@ -191,14 +191,15 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str return "" } -// filterValidFields filters input data to only include fields that exist in the model -// Uses reflection.MapToStruct to validate fields and extract only those that match the model +// filterValidFields filters input data to only include fields that exist in the model, +// and translates JSON key names to their actual database column names. +// For example, a field tagged `json:"_changed_date" bun:"changed_date"` will be +// included in the result as "changed_date", not "_changed_date". func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} { if len(data) == 0 { return data } - // Create a new instance of the model to use with MapToStruct modelType := reflect.TypeOf(model) for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { modelType = modelType.Elem() @@ -208,25 +209,16 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode return data } - // Create a new instance of the model - tempModel := reflect.New(modelType).Interface() + // Build a mapping from JSON key -> DB column name for all writable fields. + // This both validates which fields belong to the model and translates their names + // to the correct column names for use in SQL insert/update queries. 
+ jsonToDBCol := reflection.BuildJSONToDBColumnMap(modelType) - // Use MapToStruct to map the data - this will only map valid fields - err := reflection.MapToStruct(data, tempModel) - if err != nil { - logger.Debug("Error mapping data to model: %v", err) - return data - } - - // Extract the mapped fields back into a map - // This effectively filters out any fields that don't exist in the model filteredData := make(map[string]interface{}) - tempModelValue := reflect.ValueOf(tempModel).Elem() - for key, value := range data { - // Check if the field was successfully mapped - if fieldWasMapped(tempModelValue, modelType, key) { - filteredData[key] = value + dbColName, exists := jsonToDBCol[key] + if exists { + filteredData[dbColName] = value } else { logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType) } @@ -235,72 +227,8 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode return filteredData } -// fieldWasMapped checks if a field with the given key was mapped to the model -func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool { - // Look for the field by JSON tag or field name - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - - // Skip unexported fields - if !field.IsExported() { - continue - } - - // Check JSON tag - jsonTag := field.Tag.Get("json") - if jsonTag != "" && jsonTag != "-" { - parts := strings.Split(jsonTag, ",") - if len(parts) > 0 && parts[0] == key { - return true - } - } - - // Check bun tag - bunTag := field.Tag.Get("bun") - if bunTag != "" && bunTag != "-" { - if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key { - return true - } - } - - // Check gorm tag - gormTag := field.Tag.Get("gorm") - if gormTag != "" && gormTag != "-" { - if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key { - return true - } - } - - // Check lowercase field name - if strings.EqualFold(field.Name, key) { - 
return true - } - - // Handle embedded structs recursively - if field.Anonymous { - fieldType := field.Type - if fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - if fieldType.Kind() == reflect.Struct { - embeddedValue := modelValue.Field(i) - if embeddedValue.Kind() == reflect.Ptr { - if embeddedValue.IsNil() { - continue - } - embeddedValue = embeddedValue.Elem() - } - if fieldWasMapped(embeddedValue, fieldType, key) { - return true - } - } - } - } - - return false -} - -// injectForeignKeys injects parent IDs into data for foreign key fields +// injectForeignKeys injects parent IDs into data for foreign key fields. +// data is expected to be keyed by DB column names (as returned by filterValidFields). func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) { if len(parentIDs) == 0 { return @@ -319,10 +247,11 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode if strings.EqualFold(jsonName, parentKey+"_id") || strings.EqualFold(jsonName, parentKey+"id") || strings.EqualFold(field.Name, parentKey+"ID") { - // Only inject if not already present - if _, exists := data[jsonName]; !exists { - logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID) - data[jsonName] = parentID + // Use the DB column name as the key, since data is keyed by DB column names + dbColName := reflection.GetColumnName(field) + if _, exists := data[dbColName]; !exists { + logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID) + data[dbColName] = parentID } } } diff --git a/pkg/dbmanager/providers/provider.go b/pkg/dbmanager/providers/provider.go index 65dbc6f..a541f2b 100644 --- a/pkg/dbmanager/providers/provider.go +++ b/pkg/dbmanager/providers/provider.go @@ -4,11 +4,17 @@ import ( "context" "database/sql" "errors" + "strings" "time" "go.mongodb.org/mongo-driver/mongo" ) +// isDBClosed reports whether err indicates the *sql.DB has been 
closed. +func isDBClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "sql: database is closed") +} + // Common errors var ( // ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database diff --git a/pkg/dbmanager/providers/sqlite.go b/pkg/dbmanager/providers/sqlite.go index 6f70970..4306b8d 100644 --- a/pkg/dbmanager/providers/sqlite.go +++ b/pkg/dbmanager/providers/sqlite.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "sync" "time" _ "github.com/glebarez/sqlite" // Pure Go SQLite driver @@ -14,8 +15,10 @@ import ( // SQLiteProvider implements Provider for SQLite databases type SQLiteProvider struct { - db *sql.DB - config ConnectionConfig + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + config ConnectionConfig } // NewSQLiteProvider creates a new SQLite provider @@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error { // Execute a simple query to verify the database is accessible var result int - err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result) + run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) } + err := run() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = run() + } + } if err != nil { return fmt.Errorf("health check failed: %w", err) } @@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error { return nil } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. 
+func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider { + p.dbFactory = factory + return p +} + +func (p *SQLiteProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *SQLiteProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + // GetNative returns the native *sql.DB connection func (p *SQLiteProvider) GetNative() (*sql.DB, error) { if p.db == nil { diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 350fd72..672aee0 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -196,6 +196,92 @@ func collectColumnsFromType(typ reflect.Type, columns *[]string) { } } +// GetColumnName extracts the database column name from a struct field. +// Priority: bun tag -> gorm tag -> json tag -> lowercase field name. +// This is the exported version for use by other packages. +func GetColumnName(field reflect.StructField) string { + return getColumnNameFromField(field) +} + +// BuildJSONToDBColumnMap returns a map from JSON key names to database column names +// for the given model type. Only writable, non-relation fields are included. +// This is used to translate incoming request data (keyed by JSON names) into +// properly named database columns before insert/update operations. 
+func BuildJSONToDBColumnMap(modelType reflect.Type) map[string]string { + result := make(map[string]string) + buildJSONToDBMap(modelType, result, false) + return result +} + +func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly bool) { + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + if !field.IsExported() { + continue + } + + bunTag := field.Tag.Get("bun") + gormTag := field.Tag.Get("gorm") + + // Handle embedded structs + if field.Anonymous { + ft := field.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + isScanOnly := scanOnly + if bunTag != "" && isBunFieldScanOnly(bunTag) { + isScanOnly = true + } + if ft.Kind() == reflect.Struct { + buildJSONToDBMap(ft, result, isScanOnly) + continue + } + } + + if scanOnly { + continue + } + + // Skip explicitly excluded fields + if bunTag == "-" || gormTag == "-" { + continue + } + + // Skip scan-only fields + if bunTag != "" && isBunFieldScanOnly(bunTag) { + continue + } + + // Skip bun relation fields + if bunTag != "" && (strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") || strings.Contains(bunTag, "m2m:")) { + continue + } + + // Skip gorm relation fields + if gormTag != "" && (strings.Contains(gormTag, "foreignKey:") || strings.Contains(gormTag, "references:") || strings.Contains(gormTag, "many2many:")) { + continue + } + + // Get JSON key (how the field appears in incoming request data) + jsonKey := "" + if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" { + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" { + jsonKey = parts[0] + } + } + if jsonKey == "" { + jsonKey = strings.ToLower(field.Name) + } + + // Get the actual DB column name (bun > gorm > json > field name) + dbColName := getColumnNameFromField(field) + + result[jsonKey] = dbColName + } +} + // getColumnNameFromField extracts the column name from a struct field // Priority: bun tag -> gorm tag -> json tag -> lowercase field 
name func getColumnNameFromField(field reflect.StructField) string { diff --git a/pkg/reflection/model_utils_test.go b/pkg/reflection/model_utils_test.go index 41b6529..0041a81 100644 --- a/pkg/reflection/model_utils_test.go +++ b/pkg/reflection/model_utils_test.go @@ -823,12 +823,12 @@ func TestToSnakeCase(t *testing.T) { { name: "UserID", input: "UserID", - expected: "user_i_d", + expected: "user_id", }, { name: "HTTPServer", input: "HTTPServer", - expected: "h_t_t_p_server", + expected: "http_server", }, { name: "lowercase", @@ -838,7 +838,7 @@ func TestToSnakeCase(t *testing.T) { { name: "UPPERCASE", input: "UPPERCASE", - expected: "u_p_p_e_r_c_a_s_e", + expected: "uppercase", }, { name: "Single", diff --git a/pkg/resolvemcp/handler.go b/pkg/resolvemcp/handler.go index 4ed7c9c..de8f578 100644 --- a/pkg/resolvemcp/handler.go +++ b/pkg/resolvemcp/handler.go @@ -717,7 +717,7 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi return query.Where("("+strings.Join(conditions, " OR ")+")", args...) } -func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) { +func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) { switch filter.Operator { case "eq", "=": return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value} @@ -747,7 +747,8 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []in } func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { - for _, preload := range preloads { + for i := range preloads { + preload := &preloads[i] if preload.Relation == "" { continue } diff --git a/pkg/security/KEYSTORE.md b/pkg/security/KEYSTORE.md new file mode 100644 index 0000000..dab4d6e --- /dev/null +++ b/pkg/security/KEYSTORE.md @@ -0,0 +1,153 @@ +# Keystore + +Per-user named auth keys with pluggable storage. 
Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata. + +## Key types + +| Constant | Value | Use case | +|---|---|---| +| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret | +| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header | +| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials | +| `KeyTypeGenericAPI` | `api` | General-purpose application key | + +## Storage backends + +### ConfigKeyStore + +In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart. + +```go +// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key) +store := security.NewConfigKeyStore([]security.UserKey{ + { + UserID: 1, + KeyType: security.KeyTypeGenericAPI, + KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey) + Name: "CI deploy", + Scopes: []string{"deploy"}, + IsActive: true, + }, +}) +``` + +### DatabaseKeyStore + +Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use. 
+```go +db, _ := sql.Open("postgres", dsn) + +store := security.NewDatabaseKeyStore(db) + +// With options +store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{ + CacheTTL: 5 * time.Minute, + SQLNames: &security.KeyStoreSQLNames{ + ValidateKey: "myapp_keystore_validate", // override one procedure name + }, +}) +``` + +## Managing keys + +```go +ctx := context.Background() + +// Create — raw key returned once; store it securely +resp, err := store.CreateKey(ctx, security.CreateKeyRequest{ + UserID: 42, + KeyType: security.KeyTypeGenericAPI, + Name: "mobile app", + Scopes: []string{"read", "write"}, +}) +fmt.Println(resp.RawKey) // only shown here; hashed internally + +// List +keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types +keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI) + +// Revoke +err = store.DeleteKey(ctx, 42, resp.Key.ID) + +// Validate (used by authenticators internally) +key, err := store.ValidateKey(ctx, rawKey, "") +``` + +## HTTP authentication + +`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`. + +Keys are extracted from the request in this order: + +1. `Authorization: Bearer ` +2. `Authorization: ApiKey ` +3. `X-API-Key: ` + +```go +auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type +// Restrict to a specific type: +auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI) +``` + +Plug it into a handler: + +```go +handler := resolvespec.NewHandler(db, registry, + resolvespec.WithAuthenticator(auth), +) +``` + +`Login` returns an error and `Logout` is a no-op — key lifecycle is managed through `KeyStore` directly. 
+ +On successful validation the request context receives a `UserContext` where: + +- `UserID` — from the key +- `Roles` — the key's `Scopes` +- `Claims["key_type"]` — key type string +- `Claims["key_name"]` — key name + +## Database setup + +Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`. + +```sql +\i pkg/security/keystore_schema.sql +``` + +This creates: + +- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type` +- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)` +- `resolvespec_keystore_create_key(p_request jsonb)` +- `resolvespec_keystore_delete_key(p_user_id, p_key_id)` +- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)` + +### Custom procedure names + +```go +store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{ + SQLNames: &security.KeyStoreSQLNames{ + GetUserKeys: "myschema_get_keys", + CreateKey: "myschema_create_key", + DeleteKey: "myschema_delete_key", + ValidateKey: "myschema_validate_key", + }, +}) + +// Validate names at startup +names := &security.KeyStoreSQLNames{ + GetUserKeys: "myschema_get_keys", + // ... +} +if err := security.ValidateKeyStoreSQLNames(names); err != nil { + log.Fatal(err) +} +``` + +## Security notes + +- Raw keys are never stored. Only the SHA-256 hex digest is persisted. +- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`. +- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels. +- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required. 
diff --git a/pkg/security/keystore.go b/pkg/security/keystore.go new file mode 100644 index 0000000..1d442a1 --- /dev/null +++ b/pkg/security/keystore.go @@ -0,0 +1,81 @@ +package security + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "time" +) + +// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string. +// Used by all keystore implementations to hash raw keys before storage or lookup. +func hashSHA256Hex(raw string) string { + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} + +// KeyType identifies the category of an auth key. +type KeyType string + +const ( + // KeyTypeJWTSecret is a per-user JWT signing secret for token generation. + KeyTypeJWTSecret KeyType = "jwt_secret" + // KeyTypeHeaderAPI is a static API key sent via a request header. + KeyTypeHeaderAPI KeyType = "header_api" + // KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret). + KeyTypeOAuth2 KeyType = "oauth2" + // KeyTypeGenericAPI is a generic application API key. + KeyTypeGenericAPI KeyType = "api" +) + +// UserKey represents a single named auth key belonging to a user. +// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted. +type UserKey struct { + ID int64 `json:"id"` + UserID int `json:"user_id"` + KeyType KeyType `json:"key_type"` + KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key + Name string `json:"name"` + Scopes []string `json:"scopes,omitempty"` + Meta map[string]any `json:"meta,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + IsActive bool `json:"is_active"` +} + +// CreateKeyRequest specifies the parameters for a new key. +type CreateKeyRequest struct { + UserID int + KeyType KeyType + Name string + Scopes []string + Meta map[string]any + ExpiresAt *time.Time +} + +// CreateKeyResponse is returned exactly once when a key is created. 
+// The caller is responsible for persisting RawKey; it is not stored anywhere. +type CreateKeyResponse struct { + Key UserKey + RawKey string // crypto/rand 32 bytes, base64url-encoded +} + +// KeyStore manages per-user auth keys with pluggable storage backends. +// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures). +type KeyStore interface { + // CreateKey generates a new key, stores its hash, and returns the raw key once. + CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) + + // GetUserKeys returns all active, non-expired keys for a user. + // Pass an empty KeyType to return all types. + GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) + + // DeleteKey soft-deletes a key by ID after verifying ownership. + DeleteKey(ctx context.Context, userID int, keyID int64) error + + // ValidateKey checks a raw key, returns the matching UserKey on success. + // The implementation hashes the raw key before any lookup. + // Pass an empty KeyType to accept any type. + ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) +} diff --git a/pkg/security/keystore_authenticator.go b/pkg/security/keystore_authenticator.go new file mode 100644 index 0000000..dd8cae2 --- /dev/null +++ b/pkg/security/keystore_authenticator.go @@ -0,0 +1,97 @@ +package security + +import ( + "context" + "fmt" + "net/http" + "strings" +) + +// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore. +// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.) +// rather than interactive sessions. Login and Logout are not supported — key lifecycle +// is managed directly through the KeyStore. +// +// Key extraction order: +// 1. Authorization: Bearer +// 2. Authorization: ApiKey +// 3. 
X-API-Key header +type KeyStoreAuthenticator struct { + keyStore KeyStore + keyType KeyType // empty = accept any type +} + +// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator. +// Pass an empty keyType to accept keys of any type. +func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator { + return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType} +} + +// Login is not supported for keystore authentication. +func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) { + return nil, fmt.Errorf("keystore authenticator does not support login") +} + +// Logout is not supported for keystore authentication. +func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error { + return nil +} + +// Authenticate extracts an API key from the request and validates it against the KeyStore. +// Returns a UserContext built from the matching UserKey on success. +func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { + rawKey := extractAPIKey(r) + if rawKey == "" { + return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey or X-API-Key header)") + } + + userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType) + if err != nil { + return nil, fmt.Errorf("invalid API key: %w", err) + } + + return userKeyToUserContext(userKey), nil +} + +// extractAPIKey extracts a raw key from the request using the following precedence: +// 1. Authorization: Bearer +// 2. Authorization: ApiKey +// 3. X-API-Key header +func extractAPIKey(r *http.Request) string { + if auth := r.Header.Get("Authorization"); auth != "" { + if after, ok := strings.CutPrefix(auth, "Bearer "); ok { + return strings.TrimSpace(after) + } + if after, ok := strings.CutPrefix(auth, "ApiKey "); ok { + return strings.TrimSpace(after) + } + } + return strings.TrimSpace(r.Header.Get("X-API-Key")) +} + +// userKeyToUserContext converts a UserKey into a UserContext. 
+// Scopes are mapped to Roles. Key type and name are stored in Claims.
+func userKeyToUserContext(k *UserKey) *UserContext {
+	claims := map[string]any{
+		"key_type": string(k.KeyType),
+		"key_name": k.Name,
+	}
+
+	meta := k.Meta
+	if meta == nil {
+		meta = map[string]any{}
+	}
+
+	roles := k.Scopes
+	if roles == nil {
+		roles = []string{}
+	}
+
+	return &UserContext{
+		UserID:    k.UserID,
+		SessionID: fmt.Sprintf("key:%d", k.ID),
+		Roles:     roles,
+		Claims:    claims,
+		Meta:      meta,
+	}
+}
diff --git a/pkg/security/keystore_config.go b/pkg/security/keystore_config.go
new file mode 100644
index 0000000..353cf14
--- /dev/null
+++ b/pkg/security/keystore_config.go
@@ -0,0 +1,149 @@
+package security
+
+import (
+	"context"
+	"crypto/rand"
+	"crypto/subtle"
+	"encoding/base64"
+	"encoding/hex"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
+// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
+// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
+//
+// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
+// Keys created at runtime via CreateKey are held in memory only and lost on restart.
+type ConfigKeyStore struct {
+	mu   sync.RWMutex
+	keys []UserKey
+	next int64 // monotonic ID counter for runtime-created keys (atomic)
+}
+
+// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
+// Pass nil or an empty slice to start with no pre-loaded keys.
+// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
+func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
+	var maxID int64
+	copied := make([]UserKey, len(keys))
+	copy(copied, keys)
+	for i := range copied {
+		if copied[i].CreatedAt.IsZero() {
+			copied[i].IsActive = true
+			copied[i].CreatedAt = time.Now()
+		}
+		if copied[i].ID > maxID {
+			maxID = copied[i].ID
+		}
+	}
+	return &ConfigKeyStore{keys: copied, next: maxID}
+}
+
+// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
+func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
+	rawBytes := make([]byte, 32)
+	if _, err := rand.Read(rawBytes); err != nil {
+		return nil, fmt.Errorf("failed to generate key material: %w", err)
+	}
+	rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
+	hash := hashSHA256Hex(rawKey)
+
+	// Clone caller-owned Scopes/Meta: the stored key outlives the request, and
+	// aliasing the caller's slice/map would let later caller mutation corrupt
+	// the store behind the mutex.
+	var scopes []string
+	if req.Scopes != nil {
+		scopes = append([]string(nil), req.Scopes...)
+	}
+	var meta map[string]any
+	if req.Meta != nil {
+		meta = make(map[string]any, len(req.Meta))
+		for k, v := range req.Meta {
+			meta[k] = v
+		}
+	}
+
+	id := atomic.AddInt64(&s.next, 1)
+	key := UserKey{
+		ID:        id,
+		UserID:    req.UserID,
+		KeyType:   req.KeyType,
+		KeyHash:   hash,
+		Name:      req.Name,
+		Scopes:    scopes,
+		Meta:      meta,
+		ExpiresAt: req.ExpiresAt,
+		CreatedAt: time.Now(),
+		IsActive:  true,
+	}
+
+	s.mu.Lock()
+	s.keys = append(s.keys, key)
+	s.mu.Unlock()
+
+	return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
+}
+
+// GetUserKeys returns all active, non-expired keys for the given user.
+// Pass an empty KeyType to return all types.
+func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
+	now := time.Now()
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	var result []UserKey
+	for i := range s.keys {
+		k := &s.keys[i]
+		if k.UserID != userID || !k.IsActive {
+			continue
+		}
+		if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
+			continue
+		}
+		if keyType != "" && k.KeyType != keyType {
+			continue
+		}
+		result = append(result, *k)
+	}
+	return result, nil
+}
+
+// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
+func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for i := range s.keys {
+		if s.keys[i].ID == keyID {
+			if s.keys[i].UserID != userID {
+				return fmt.Errorf("key not found or permission denied")
+			}
+			s.keys[i].IsActive = false
+			return nil
+		}
+	}
+	return fmt.Errorf("key not found")
+}
+
+// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
+// Uses constant-time comparison to prevent timing side-channels.
+// Pass an empty KeyType to accept any type.
+func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
+	hash := hashSHA256Hex(rawKey)
+	hashBytes, _ := hex.DecodeString(hash)
+	now := time.Now()
+
+	// Write lock: ValidateKey updates LastUsedAt on the matched entry.
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for i := range s.keys {
+		k := &s.keys[i]
+		if !k.IsActive {
+			continue
+		}
+		if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
+			continue
+		}
+		if keyType != "" && k.KeyType != keyType {
+			continue
+		}
+		stored, _ := hex.DecodeString(k.KeyHash)
+		if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
+			continue
+		}
+		k.LastUsedAt = &now
+		result := *k
+		return &result, nil
+	}
+	return nil, fmt.Errorf("invalid or expired key")
+}
diff --git a/pkg/security/keystore_database.go b/pkg/security/keystore_database.go
new file mode 100644
index 0000000..75e7eb1
--- /dev/null
+++ b/pkg/security/keystore_database.go
@@ -0,0 +1,256 @@
+package security
+
+import (
+	"context"
+	"crypto/rand"
+	"database/sql"
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/bitechdev/ResolveSpec/pkg/cache"
+)
+
+// DatabaseKeyStoreOptions configures DatabaseKeyStore.
+type DatabaseKeyStoreOptions struct {
+	// Cache is an optional cache instance. If nil, uses the default cache.
+	Cache *cache.Cache
+	// CacheTTL is the duration to cache ValidateKey results.
+	// Default: 2 minutes.
+	CacheTTL time.Duration
+	// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
+	SQLNames *KeyStoreSQLNames
+	// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
+	// If nil, reconnection is disabled.
+	DBFactory func() (*sql.DB, error)
+}
+
+// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
+// All DB operations go through configurable procedure names; the raw key is
+// never passed to the database.
+//
+// See keystore_schema.sql for the required table and procedure definitions.
+//
+// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
+// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
+// (default 2 minutes) if the cache entry cannot be invalidated.
+type DatabaseKeyStore struct {
+	db        *sql.DB
+	dbMu      sync.RWMutex
+	dbFactory func() (*sql.DB, error)
+	sqlNames  *KeyStoreSQLNames
+	cache     *cache.Cache
+	cacheTTL  time.Duration
+}
+
+// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
+func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
+	o := DatabaseKeyStoreOptions{}
+	if len(opts) > 0 {
+		o = opts[0]
+	}
+	if o.CacheTTL == 0 {
+		o.CacheTTL = 2 * time.Minute
+	}
+	c := o.Cache
+	if c == nil {
+		c = cache.GetDefaultCache()
+	}
+	names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
+	return &DatabaseKeyStore{
+		db:        db,
+		dbFactory: o.DBFactory,
+		sqlNames:  names,
+		cache:     c,
+		cacheTTL:  o.CacheTTL,
+	}
+}
+
+func (ks *DatabaseKeyStore) getDB() *sql.DB {
+	ks.dbMu.RLock()
+	defer ks.dbMu.RUnlock()
+	return ks.db
+}
+
+func (ks *DatabaseKeyStore) reconnectDB() error {
+	if ks.dbFactory == nil {
+		return fmt.Errorf("no db factory configured for reconnect")
+	}
+	newDB, err := ks.dbFactory()
+	if err != nil {
+		return err
+	}
+	ks.dbMu.Lock()
+	ks.db = newDB
+	ks.dbMu.Unlock()
+	return nil
+}
+
+// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
+// and returns the raw key once.
+func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
+	rawBytes := make([]byte, 32)
+	if _, err := rand.Read(rawBytes); err != nil {
+		return nil, fmt.Errorf("failed to generate key material: %w", err)
+	}
+	rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
+	hash := hashSHA256Hex(rawKey)
+
+	type createRequest struct {
+		UserID    int            `json:"user_id"`
+		KeyType   KeyType        `json:"key_type"`
+		KeyHash   string         `json:"key_hash"`
+		Name      string         `json:"name"`
+		Scopes    []string       `json:"scopes,omitempty"`
+		Meta      map[string]any `json:"meta,omitempty"`
+		ExpiresAt *time.Time     `json:"expires_at,omitempty"`
+	}
+
+	reqJSON, err := json.Marshal(createRequest{
+		UserID:    req.UserID,
+		KeyType:   req.KeyType,
+		KeyHash:   hash,
+		Name:      req.Name,
+		Scopes:    req.Scopes,
+		Meta:      req.Meta,
+		ExpiresAt: req.ExpiresAt,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal create key request: %w", err)
+	}
+
+	var success bool
+	var errorMsg sql.NullString
+	var keyJSON sql.NullString
+
+	runQuery := func() error {
+		query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
+		return ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON)
+	}
+	if err = runQuery(); err != nil {
+		// Mirror ValidateKey's reconnect handling for consistency. A closed
+		// *sql.DB is the one retry-safe failure here: the statement never
+		// reached the server, so re-running cannot double-insert.
+		if isDBClosed(err) {
+			if reconnErr := ks.reconnectDB(); reconnErr == nil {
+				err = runQuery()
+			}
+		}
+		if err != nil {
+			return nil, fmt.Errorf("create key procedure failed: %w", err)
+		}
+	}
+	if !success {
+		return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
+	}
+	if !keyJSON.Valid {
+		// Guard against a procedure that reports success but returns NULL;
+		// without this, Unmarshal would fail with an opaque JSON error.
+		return nil, errors.New("create key procedure returned no key")
+	}
+
+	var key UserKey
+	if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
+		return nil, fmt.Errorf("failed to parse created key: %w", err)
+	}
+
+	return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
+}
+
+// GetUserKeys returns all active, non-expired keys for the given user.
+// Pass an empty KeyType to return all types.
+func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
+	var success bool
+	var errorMsg sql.NullString
+	var keysJSON sql.NullString
+
+	runQuery := func() error {
+		query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
+		return ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON)
+	}
+	if err := runQuery(); err != nil {
+		// Read-only, so always safe to retry once after reopening a closed DB.
+		if isDBClosed(err) {
+			if reconnErr := ks.reconnectDB(); reconnErr == nil {
+				err = runQuery()
+			}
+		}
+		if err != nil {
+			return nil, fmt.Errorf("get user keys procedure failed: %w", err)
+		}
+	}
+	if !success {
+		return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
+	}
+
+	var keys []UserKey
+	if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
+		if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
+			return nil, fmt.Errorf("failed to parse user keys: %w", err)
+		}
+	}
+	if keys == nil {
+		keys = []UserKey{}
+	}
+	return keys, nil
+}
+
+// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
+// The delete procedure returns the key_hash so no separate lookup is needed.
+// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
+func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
+	var success bool
+	var errorMsg sql.NullString
+	var keyHash sql.NullString
+
+	query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
+	if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
+		return fmt.Errorf("delete key procedure failed: %w", err)
+	}
+	if !success {
+		return errors.New(nullStringOr(errorMsg, "delete key failed"))
+	}
+
+	if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
+		_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
+	}
+	return nil
+}
+
+// ValidateKey hashes the raw key and calls the validate procedure.
+// Results are cached for CacheTTL to reduce DB load on hot paths.
+func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
+	hash := hashSHA256Hex(rawKey)
+	cacheKey := keystoreCacheKey(hash)
+
+	if ks.cache != nil {
+		var cached UserKey
+		if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
+			// The cache entry is keyed by hash only, so it may have been stored
+			// by a lookup that requested a different key type. Honor the hit
+			// only when the entry is active, matches the requested type, and
+			// has not expired inside the TTL window; an active-but-mismatched
+			// entry falls through to the database for an authoritative answer.
+			typeOK := keyType == "" || cached.KeyType == keyType
+			notExpired := cached.ExpiresAt == nil || cached.ExpiresAt.After(time.Now())
+			if cached.IsActive && typeOK && notExpired {
+				return &cached, nil
+			}
+			if !cached.IsActive {
+				return nil, errors.New("invalid or expired key")
+			}
+		}
+	}
+
+	var success bool
+	var errorMsg sql.NullString
+	var keyJSON sql.NullString
+
+	runQuery := func() error {
+		query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
+		return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
+	}
+	if err := runQuery(); err != nil {
+		if isDBClosed(err) {
+			if reconnErr := ks.reconnectDB(); reconnErr == nil {
+				err = runQuery()
+			}
+			if err != nil {
+				return nil, fmt.Errorf("validate key procedure failed: %w", err)
+			}
+		} else {
+			return nil, fmt.Errorf("validate key procedure failed: %w", err)
+		}
+	}
+	if !success {
+		return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
+	}
+	if !keyJSON.Valid {
+		// Defensive: success without a key payload is a procedure bug.
+		return nil, errors.New("invalid or expired key")
+	}
+
+	var key UserKey
+	if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
+		return nil, fmt.Errorf("failed to parse validated key: %w", err)
+	}
+
+	if ks.cache != nil {
+		_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
+	}
+
+	return &key, nil
+}
+
+func keystoreCacheKey(hash string) string {
+	return "keystore:validate:" + hash
+}
+
+// nullStringOr returns s.String if valid, otherwise the fallback.
+func nullStringOr(s sql.NullString, fallback string) string {
+	if s.Valid && s.String != "" {
+		return s.String
+	}
+	return fallback
+}
diff --git a/pkg/security/keystore_schema.sql b/pkg/security/keystore_schema.sql
new file mode 100644
index 0000000..5e527ef
--- /dev/null
+++ b/pkg/security/keystore_schema.sql
@@ -0,0 +1,187 @@
+-- Keystore schema for per-user auth keys
+-- Apply alongside database_schema.sql (requires the users table)
+
+CREATE TABLE IF NOT EXISTS user_keys (
+    id BIGSERIAL PRIMARY KEY,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    key_type VARCHAR(50) NOT NULL,
+    key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
+    name VARCHAR(255) NOT NULL DEFAULT '',
+    scopes TEXT, -- JSON array, e.g. '["read","write"]'
+    meta JSONB,
+    expires_at TIMESTAMP,
+    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    last_used_at TIMESTAMP,
+    is_active BOOLEAN DEFAULT true
+);
+
+CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
+CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
+
+-- resolvespec_keystore_get_user_keys
+-- Returns all active, non-expired keys for a user.
+-- Pass empty p_key_type to return all key types.
+CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
+    p_user_id INTEGER,
+    p_key_type TEXT DEFAULT ''
+)
+RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
+LANGUAGE plpgsql AS $$
+DECLARE
+    v_keys JSONB;
+BEGIN
+    SELECT COALESCE(
+        jsonb_agg(
+            jsonb_build_object(
+                'id', k.id,
+                'user_id', k.user_id,
+                'key_type', k.key_type,
+                'name', k.name,
+                'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
+                'meta', COALESCE(k.meta, '{}'::jsonb),
+                'expires_at', k.expires_at,
+                'created_at', k.created_at,
+                'last_used_at', k.last_used_at,
+                'is_active', k.is_active
+            )
+        ),
+        '[]'::jsonb
+    )
+    INTO v_keys
+    FROM user_keys k
+    WHERE k.user_id = p_user_id
+      AND k.is_active = true
+      AND (k.expires_at IS NULL OR k.expires_at > NOW())
+      AND (p_key_type = '' OR k.key_type = p_key_type);
+
+    RETURN QUERY SELECT true, NULL::TEXT, v_keys;
+EXCEPTION WHEN OTHERS THEN
+    RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
+END;
+$$;
+
+-- resolvespec_keystore_create_key
+-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
+-- Returns the created key record (without key_hash).
+CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
+    p_request JSONB
+)
+RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
+LANGUAGE plpgsql AS $$
+DECLARE
+    v_id BIGINT;
+    v_created_at TIMESTAMP;
+    v_key JSONB;
+BEGIN
+    INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
+    VALUES (
+        (p_request->>'user_id')::INTEGER,
+        p_request->>'key_type',
+        p_request->>'key_hash',
+        COALESCE(p_request->>'name', ''),
+        p_request->>'scopes',
+        p_request->'meta',
+        CASE WHEN p_request->>'expires_at' IS NOT NULL
+             THEN (p_request->>'expires_at')::TIMESTAMP
+             ELSE NULL
+        END
+    )
+    RETURNING id, created_at INTO v_id, v_created_at;
+
+    v_key := jsonb_build_object(
+        'id', v_id,
+        'user_id', (p_request->>'user_id')::INTEGER,
+        'key_type', p_request->>'key_type',
+        'name', COALESCE(p_request->>'name', ''),
+        'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
+                       THEN (p_request->>'scopes')::jsonb
+                       ELSE '[]'::jsonb END,
+        'meta', COALESCE(p_request->'meta', '{}'::jsonb),
+        'expires_at', p_request->>'expires_at',
+        'created_at', v_created_at,
+        'is_active', true
+    );
+
+    RETURN QUERY SELECT true, NULL::TEXT, v_key;
+EXCEPTION WHEN OTHERS THEN
+    RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
+END;
+$$;
+
+-- resolvespec_keystore_delete_key
+-- Soft-deletes a key (is_active = false) after verifying ownership.
+-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
+CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
+    p_user_id INTEGER,
+    p_key_id BIGINT
+)
+RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
+LANGUAGE plpgsql AS $$
+DECLARE
+    v_hash TEXT;
+BEGIN
+    UPDATE user_keys
+    SET is_active = false
+    WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
+    RETURNING key_hash INTO v_hash;
+
+    IF NOT FOUND THEN
+        RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
+        RETURN;
+    END IF;
+
+    RETURN QUERY SELECT true, NULL::TEXT, v_hash;
+EXCEPTION WHEN OTHERS THEN
+    RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
+END;
+$$;
+
+-- resolvespec_keystore_validate_key
+-- Looks up a key by its SHA-256 hash, checks active status and expiry,
+-- updates last_used_at, and returns the key record.
+-- p_key_type can be empty to accept any key type.
+CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
+    p_key_hash TEXT,
+    p_key_type TEXT DEFAULT ''
+)
+RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
+LANGUAGE plpgsql AS $$
+DECLARE
+    v_key_rec user_keys%ROWTYPE;
+    v_key JSONB;
+BEGIN
+    SELECT * INTO v_key_rec
+    FROM user_keys
+    WHERE key_hash = p_key_hash
+      AND is_active = true
+      AND (expires_at IS NULL OR expires_at > NOW())
+      AND (p_key_type = '' OR key_type = p_key_type);
+
+    IF NOT FOUND THEN
+        RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
+        RETURN;
+    END IF;
+
+    UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
+
+    v_key := jsonb_build_object(
+        'id', v_key_rec.id,
+        'user_id', v_key_rec.user_id,
+        'key_type', v_key_rec.key_type,
+        'name', v_key_rec.name,
+        'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
+                       THEN v_key_rec.scopes::jsonb
+                       ELSE '[]'::jsonb END,
+        'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
+        'expires_at', v_key_rec.expires_at,
+        'created_at', v_key_rec.created_at,
+        'last_used_at', NOW(),
+        'is_active', v_key_rec.is_active
+    );
+
+    RETURN QUERY SELECT true, NULL::TEXT, v_key;
+EXCEPTION WHEN OTHERS THEN
+    RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
+END;
+$$;
diff --git a/pkg/security/keystore_sql_names.go b/pkg/security/keystore_sql_names.go
new file mode 100644
index 0000000..ab182c6
--- /dev/null
+++ b/pkg/security/keystore_sql_names.go
@@ -0,0 +1,61 @@
+package security
+
+import "fmt"
+
+// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
+// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
+type KeyStoreSQLNames struct {
+	GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
+	CreateKey   string // default: "resolvespec_keystore_create_key"
+	DeleteKey   string // default: "resolvespec_keystore_delete_key"
+	ValidateKey string // default: "resolvespec_keystore_validate_key"
+}
+
+// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
+func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
+	return &KeyStoreSQLNames{
+		GetUserKeys: "resolvespec_keystore_get_user_keys",
+		CreateKey:   "resolvespec_keystore_create_key",
+		DeleteKey:   "resolvespec_keystore_delete_key",
+		ValidateKey: "resolvespec_keystore_validate_key",
+	}
+}
+
+// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
+// If override is nil, a copy of base is returned.
+func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
+	if override == nil {
+		copied := *base
+		return &copied
+	}
+	merged := *base
+	if override.GetUserKeys != "" {
+		merged.GetUserKeys = override.GetUserKeys
+	}
+	if override.CreateKey != "" {
+		merged.CreateKey = override.CreateKey
+	}
+	if override.DeleteKey != "" {
+		merged.DeleteKey = override.DeleteKey
+	}
+	if override.ValidateKey != "" {
+		merged.ValidateKey = override.ValidateKey
+	}
+	return &merged
+}
+
+// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
+func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error { + fields := map[string]string{ + "GetUserKeys": names.GetUserKeys, + "CreateKey": names.CreateKey, + "DeleteKey": names.DeleteKey, + "ValidateKey": names.ValidateKey, + } + for field, val := range fields { + if val != "" && !validSQLIdentifier.MatchString(val) { + return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val) + } + } + return nil +} diff --git a/pkg/security/oauth2_methods.go b/pkg/security/oauth2_methods.go index 0b87b3b..79a58de 100644 --- a/pkg/security/oauth2_methods.go +++ b/pkg/security/oauth2_methods.go @@ -244,7 +244,7 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC var errMsg *string var userID *int - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error, p_user_id FROM %s($1::jsonb) `, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID) @@ -287,7 +287,7 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session var success bool var errMsg *string - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error FROM %s($1::jsonb) `, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg) @@ -385,7 +385,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var errMsg *string var sessionData []byte - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error, p_data::text FROM %s($1) `, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData) @@ -451,7 +451,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var updateSuccess bool var updateErrMsg *string - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, 
p_error FROM %s($1::jsonb) `, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg) @@ -472,7 +472,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT var userErrMsg *string var userData []byte - err = a.db.QueryRowContext(ctx, fmt.Sprintf(` + err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(` SELECT p_success, p_error, p_data::text FROM %s($1) `, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData) diff --git a/pkg/security/passkey_provider.go b/pkg/security/passkey_provider.go index 7abeab0..fb153c1 100644 --- a/pkg/security/passkey_provider.go +++ b/pkg/security/passkey_provider.go @@ -7,18 +7,21 @@ import ( "encoding/base64" "encoding/json" "fmt" + "sync" "time" ) // DatabasePasskeyProvider implements PasskeyProvider using database storage // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabasePasskeyProvider struct { - db *sql.DB - rpID string // Relying Party ID (domain) - rpName string // Relying Party display name - rpOrigin string // Expected origin for WebAuthn - timeout int64 // Timeout in milliseconds (default: 60000) - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + rpID string // Relying Party ID (domain) + rpName string // Relying Party display name + rpOrigin string // Expected origin for WebAuthn + timeout int64 // Timeout in milliseconds (default: 60000) + sqlNames *SQLNames } // DatabasePasskeyProviderOptions configures the passkey provider @@ -33,6 +36,9 @@ type DatabasePasskeyProviderOptions struct { Timeout int64 // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). SQLNames *SQLNames + // DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed. + // If nil, reconnection is disabled. 
+ DBFactory func() (*sql.DB, error) } // NewDatabasePasskeyProvider creates a new database-backed passkey provider @@ -44,15 +50,36 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions) sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames) return &DatabasePasskeyProvider{ - db: db, - rpID: opts.RPID, - rpName: opts.RPName, - rpOrigin: opts.RPOrigin, - timeout: opts.Timeout, - sqlNames: sqlNames, + db: db, + dbFactory: opts.DBFactory, + rpID: opts.RPID, + rpName: opts.RPName, + rpOrigin: opts.RPOrigin, + timeout: opts.Timeout, + sqlNames: sqlNames, } } +func (p *DatabasePasskeyProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *DatabasePasskeyProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + // BeginRegistration creates registration options for a new passkey func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) { // Generate challenge @@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user var credentialID sql.NullInt64 query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential) - err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) + err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) if err != nil { return nil, fmt.Errorf("failed to store credential: %w", err) } @@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern var credentialsJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM 
%s($1)`, p.sqlNames.PasskeyGetCredsByUsername) - err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) + err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) } @@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re var errorMsg sql.NullString var credentialJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential) - err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) + runQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential) + return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) + } + err := runQuery() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = runQuery() + } + } if err != nil { return 0, fmt.Errorf("failed to get credential: %w", err) } @@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re var cloneWarning sql.NullBool updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter) - err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) + err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) if err != nil { return 0, fmt.Errorf("failed to update counter: %w", err) } @@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int var credentialsJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, 
p.sqlNames.PasskeyGetUserCredentials) - err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) + err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) } @@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i var errorMsg sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential) - err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) + err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("failed to delete credential: %w", err) } @@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user var errorMsg sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName) - err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) + err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("failed to update credential name: %w", err) } diff --git a/pkg/security/providers.go b/pkg/security/providers.go index e23bd82..af57172 100644 --- a/pkg/security/providers.go +++ b/pkg/security/providers.go @@ -63,10 +63,12 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error // Also supports multiple OAuth2 providers configured with WithOAuth2() // Also supports passkey authentication configured with WithPasskey() type DatabaseAuthenticator struct { - db *sql.DB - cache *cache.Cache - cacheTTL time.Duration - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + cache *cache.Cache + cacheTTL time.Duration + sqlNames *SQLNames // OAuth2 providers registry (multiple providers supported) 
oauth2Providers map[string]*OAuth2Provider @@ -88,6 +90,9 @@ type DatabaseAuthenticatorOptions struct { // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). // Partial overrides are supported: only set the fields you want to change. SQLNames *SQLNames + // DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed. + // If nil, reconnection is disabled. + DBFactory func() (*sql.DB, error) } func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { @@ -110,6 +115,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO return &DatabaseAuthenticator{ db: db, + dbFactory: opts.DBFactory, cache: cacheInstance, cacheTTL: opts.CacheTTL, sqlNames: sqlNames, @@ -117,6 +123,26 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO } } +func (a *DatabaseAuthenticator) getDB() *sql.DB { + a.dbMu.RLock() + defer a.dbMu.RUnlock() + return a.db +} + +func (a *DatabaseAuthenticator) reconnectDB() error { + if a.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := a.dbFactory() + if err != nil { + return err + } + a.dbMu.Lock() + a.db = newDB + a.dbMu.Unlock() + return nil +} + func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { // Convert LoginRequest to JSON reqJSON, err := json.Marshal(req) @@ -128,8 +154,16 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L var errorMsg sql.NullString var dataJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login) - err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + runLoginQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login) + return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, 
&errorMsg, &dataJSON) + } + err = runLoginQuery() + if isDBClosed(err) { + if reconnErr := a.reconnectDB(); reconnErr == nil { + err = runLoginQuery() + } + } if err != nil { return nil, fmt.Errorf("login query failed: %w", err) } @@ -163,7 +197,7 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques var dataJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register) - err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) if err != nil { return nil, fmt.Errorf("register query failed: %w", err) } @@ -196,7 +230,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e var dataJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout) - err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) if err != nil { return fmt.Errorf("logout query failed: %w", err) } @@ -270,7 +304,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err var userJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) - err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) + err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) if err != nil { return nil, fmt.Errorf("session query failed: %w", err) } @@ -346,7 +380,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi var updatedUserJSON sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate) - _ 
= a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) + _ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) } // RefreshToken implements Refreshable interface @@ -357,7 +391,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s var userJSON sql.NullString // Get current session to pass to refresh query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) - err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) + err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) if err != nil { return nil, fmt.Errorf("refresh token query failed: %w", err) } @@ -374,7 +408,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s var newUserJSON sql.NullString refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken) - err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) + err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) if err != nil { return nil, fmt.Errorf("refresh token generation failed: %w", err) } @@ -406,6 +440,8 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s type JWTAuthenticator struct { secretKey []byte db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) sqlNames *SQLNames } @@ -417,13 +453,47 @@ func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTA } } +// WithDBFactory configures a factory used to reopen the database connection if it is closed. 
+func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator { + a.dbFactory = factory + return a +} + +func (a *JWTAuthenticator) getDB() *sql.DB { + a.dbMu.RLock() + defer a.dbMu.RUnlock() + return a.db +} + +func (a *JWTAuthenticator) reconnectDB() error { + if a.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := a.dbFactory() + if err != nil { + return err + } + a.dbMu.Lock() + a.db = newDB + a.dbMu.Unlock() + return nil +} + func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { var success bool var errorMsg sql.NullString var userJSON []byte - query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin) - err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) + runLoginQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin) + return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) + } + err := runLoginQuery() + if isDBClosed(err) { + if reconnErr := a.reconnectDB(); reconnErr == nil { + err = runLoginQuery() + } + } if err != nil { return nil, fmt.Errorf("login query failed: %w", err) } @@ -476,7 +546,7 @@ func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error var errorMsg sql.NullString query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout) - err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) + err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) if err != nil { return fmt.Errorf("logout query failed: %w", err) } @@ -513,14 +583,41 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) { // All database operations go through stored procedures // Procedure names are 
configurable via SQLNames (see DefaultSQLNames for defaults) type DatabaseColumnSecurityProvider struct { - db *sql.DB - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + sqlNames *SQLNames } func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider { return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} } +func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider { + p.dbFactory = factory + return p +} + +func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *DatabaseColumnSecurityProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { var rules []ColumnSecurity @@ -528,8 +625,16 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, var errorMsg sql.NullString var rulesJSON []byte - query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity) - err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) + runQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity) + return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) + } + err := runQuery() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = runQuery() + } + } if err != nil { return nil, fmt.Errorf("failed to load column security: %w", err) } @@ -578,21 
+683,55 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, // All database operations go through stored procedures // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) type DatabaseRowSecurityProvider struct { - db *sql.DB - sqlNames *SQLNames + db *sql.DB + dbMu sync.RWMutex + dbFactory func() (*sql.DB, error) + sqlNames *SQLNames } func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider { return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} } +func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider { + p.dbFactory = factory + return p +} + +func (p *DatabaseRowSecurityProvider) getDB() *sql.DB { + p.dbMu.RLock() + defer p.dbMu.RUnlock() + return p.db +} + +func (p *DatabaseRowSecurityProvider) reconnectDB() error { + if p.dbFactory == nil { + return fmt.Errorf("no db factory configured for reconnect") + } + newDB, err := p.dbFactory() + if err != nil { + return err + } + p.dbMu.Lock() + p.db = newDB + p.dbMu.Unlock() + return nil +} + func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { var template string var hasBlock bool - query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity) - - err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) + runQuery := func() error { + query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity) + return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) + } + err := runQuery() + if isDBClosed(err) { + if reconnErr := p.reconnectDB(); reconnErr == nil { + err = runQuery() + } + } if err != nil { return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err) } @@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) 
GetRowSecurity(ctx context.Context, userID i // Helper functions // ================ +// isDBClosed reports whether err indicates the *sql.DB has been closed. +func isDBClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "sql: database is closed") +} + func parseRoles(rolesStr string) []string { if rolesStr == "" { return []string{} @@ -780,8 +924,16 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke var errorMsg sql.NullString var dataJSON sql.NullString - query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin) - err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + runPasskeyQuery := func() error { + query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin) + return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) + } + err = runPasskeyQuery() + if isDBClosed(err) { + if reconnErr := a.reconnectDB(); reconnErr == nil { + err = runPasskeyQuery() + } + } if err != nil { return nil, fmt.Errorf("passkey login query failed: %w", err) }