Compare commits

..

12 Commits

Author SHA1 Message Date
Hein Puth (Warkanum)
1e89124c97 Merge pull request #18 from bitechdev/feature-auth-mcp
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m1s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m29s
Build , Vet Test, and Lint / Build (push) Successful in -29m37s
Build , Vet Test, and Lint / Lint Code (push) Successful in -28m58s
Tests / Unit Tests (push) Successful in -30m23s
Tests / Integration Tests (push) Failing after -30m31s
feat(security): implement OAuth2 authorization server with database s…
2026-04-09 16:18:18 +02:00
copilot-swe-agent[bot]
ca0545e144 fix(security): address validation review comments - mutex safety and issuer normalization
Agent-Logs-Url: https://github.com/bitechdev/ResolveSpec/sessions/e886b781-c910-425f-aa6f-06d13c46dcc7

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2026-04-09 14:07:45 +00:00
copilot-swe-agent[bot]
850ad2b2ab fix(security): address all OAuth2 PR review issues
Agent-Logs-Url: https://github.com/bitechdev/ResolveSpec/sessions/e886b781-c910-425f-aa6f-06d13c46dcc7

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2026-04-09 14:04:53 +00:00
Hein
2a2e33da0c Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into feature-auth-mcp 2026-04-09 15:52:26 +02:00
Hein Puth (Warkanum)
17808a8121 Merge pull request #19 from bitechdev/feature-keystore
Feature keystore
2026-04-09 15:50:36 +02:00
Hein
134ff85c59 Merge branch 'main' of https://github.com/bitechdev/ResolveSpec into feature-keystore 2026-04-09 15:47:54 +02:00
Hein
bacddc58a6 style(recursive_crud): remove unnecessary blank line 2026-04-09 15:37:13 +02:00
Hein
f1ad83d966 feat(reflection): add JSON to DB column name mapping functions
* Implement BuildJSONToDBColumnMap for translating JSON keys to DB column names
* Enhance GetColumnName to extract column names with priority
* Update filterValidFields to utilize new mapping for improved data handling
* Fix TestToSnakeCase expected values for consistency
2026-04-09 15:36:52 +02:00
Hein
79a3912f93 fix(db): improve database connection handling and reconnection logic
* Added a database factory function to allow reconnection when the database is closed.
* Implemented mutex locks for safe concurrent access to the database connection.
* Updated all database query methods to handle reconnection attempts on closed connections.
* Enhanced error handling for database operations across multiple providers.
2026-04-09 09:19:28 +02:00
6502b55797 feat(security): implement OAuth2 authorization server with database support
- Add OAuthServer for handling OAuth2 flows including authorization, token exchange, and client registration.
- Introduce DatabaseAuthenticator for persisting clients and authorization codes.
- Implement SQL procedures for client registration, code saving, and token introspection.
- Support for external OAuth2 providers and PKCE (Proof Key for Code Exchange).
2026-04-07 22:56:05 +02:00
aa095d6bfd fix(tests): replace panic with log.Fatal for better error handling
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m52s
Build , Vet Test, and Lint / Build (push) Successful in -29m52s
Tests / Integration Tests (push) Failing after -30m46s
Tests / Unit Tests (push) Successful in -28m51s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m17s
Build , Vet Test, and Lint / Lint Code (push) Failing after -29m23s
2026-04-07 20:38:22 +02:00
Hein
a9bf08f58b feat(security): implement keystore for user authentication keys
* Add ConfigKeyStore for in-memory key management
* Introduce DatabaseKeyStore for PostgreSQL-backed key storage
* Create KeyStoreAuthenticator for API key validation
* Define SQL procedures for key management in PostgreSQL
* Document keystore functionality and usage in KEYSTORE.md
2026-04-07 17:09:17 +02:00
28 changed files with 3495 additions and 228 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"time" "time"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
// This demonstrates how the abstraction works with different ORMs // This demonstrates how the abstraction works with different ORMs
type BunAdapter struct { type BunAdapter struct {
db *bun.DB db *bun.DB
dbMu sync.RWMutex
dbFactory func() (*bun.DB, error)
driverName string driverName string
} }
@@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
return adapter return adapter
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
b.dbFactory = factory
return b
}
func (b *BunAdapter) getDB() *bun.DB {
b.dbMu.RLock()
defer b.dbMu.RUnlock()
return b.db
}
func (b *BunAdapter) reconnectDB() error {
if b.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := b.dbFactory()
if err != nil {
return err
}
b.dbMu.Lock()
b.db = newDB
b.dbMu.Unlock()
return nil
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads // EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing // This is useful for debugging preload queries that may be failing
func (b *BunAdapter) EnableQueryDebug() { func (b *BunAdapter) EnableQueryDebug() {
b.db.AddQueryHook(&QueryDebugHook{}) b.getDB().AddQueryHook(&QueryDebugHook{})
logger.Info("Bun query debug mode enabled - all SQL queries will be logged") logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
} }
@@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() {
func (b *BunAdapter) NewSelect() common.SelectQuery { func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{ return &BunSelectQuery{
query: b.db.NewSelect(), query: b.getDB().NewSelect(),
db: b.db, db: b.db,
driverName: b.driverName, driverName: b.driverName,
} }
} }
func (b *BunAdapter) NewInsert() common.InsertQuery { func (b *BunAdapter) NewInsert() common.InsertQuery {
return &BunInsertQuery{query: b.db.NewInsert()} return &BunInsertQuery{query: b.getDB().NewInsert()}
} }
func (b *BunAdapter) NewUpdate() common.UpdateQuery { func (b *BunAdapter) NewUpdate() common.UpdateQuery {
return &BunUpdateQuery{query: b.db.NewUpdate()} return &BunUpdateQuery{query: b.getDB().NewUpdate()}
} }
func (b *BunAdapter) NewDelete() common.DeleteQuery { func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.db.NewDelete()} return &BunDeleteQuery{query: b.getDB().NewDelete()}
} }
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
err = logger.HandlePanic("BunAdapter.Exec", r) err = logger.HandlePanic("BunAdapter.Exec", r)
} }
}() }()
result, err := b.db.ExecContext(ctx, query, args...) var result sql.Result
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = run()
}
}
return &BunResult{result: result}, err return &BunResult{result: result}, err
} }
@@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
err = logger.HandlePanic("BunAdapter.Query", r) err = logger.HandlePanic("BunAdapter.Query", r)
} }
}() }()
return b.db.NewRaw(query, args...).Scan(ctx, dest) err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
if isDBClosed(err) {
if reconnErr := b.reconnectDB(); reconnErr == nil {
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
}
}
return err
} }
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) { func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{}) tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
err = logger.HandlePanic("BunAdapter.RunInTransaction", r) err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
} }
}() }()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction // Create adapter with transaction
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName} adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
return fn(adapter) return fn(adapter)
@@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
} }
func (b *BunAdapter) GetUnderlyingDB() interface{} { func (b *BunAdapter) GetUnderlyingDB() interface{} {
return b.db return b.getDB()
} }
func (b *BunAdapter) DriverName() string { func (b *BunAdapter) DriverName() string {

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -17,6 +18,8 @@ import (
// This provides a lightweight PostgreSQL adapter without ORM overhead // This provides a lightweight PostgreSQL adapter without ORM overhead
type PgSQLAdapter struct { type PgSQLAdapter struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
driverName string driverName string
} }
@@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
return &PgSQLAdapter{db: db, driverName: name} return &PgSQLAdapter{db: db, driverName: name}
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
p.dbFactory = factory
return p
}
func (p *PgSQLAdapter) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *PgSQLAdapter) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// EnableQueryDebug enables query debugging for development // EnableQueryDebug enables query debugging for development
func (p *PgSQLAdapter) EnableQueryDebug() { func (p *PgSQLAdapter) EnableQueryDebug() {
logger.Info("PgSQL query debug mode - logging enabled via logger") logger.Info("PgSQL query debug mode - logging enabled via logger")
@@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
func (p *PgSQLAdapter) NewSelect() common.SelectQuery { func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
return &PgSQLSelectQuery{ return &PgSQLSelectQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
columns: []string{"*"}, columns: []string{"*"},
args: make([]interface{}, 0), args: make([]interface{}, 0),
@@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
func (p *PgSQLAdapter) NewInsert() common.InsertQuery { func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
return &PgSQLInsertQuery{ return &PgSQLInsertQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
values: make(map[string]interface{}), values: make(map[string]interface{}),
} }
@@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
return &PgSQLUpdateQuery{ return &PgSQLUpdateQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
sets: make(map[string]interface{}), sets: make(map[string]interface{}),
args: make([]interface{}, 0), args: make([]interface{}, 0),
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery { func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
return &PgSQLDeleteQuery{ return &PgSQLDeleteQuery{
db: p.db, db: p.getDB(),
driverName: p.driverName, driverName: p.driverName,
args: make([]interface{}, 0), args: make([]interface{}, 0),
whereClauses: make([]string, 0), whereClauses: make([]string, 0),
@@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
} }
}() }()
logger.Debug("PgSQL Exec: %s [args: %v]", query, args) logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
result, err := p.db.ExecContext(ctx, query, args...) var result sql.Result
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil { if err != nil {
logger.Error("PgSQL Exec failed: %v", err) logger.Error("PgSQL Exec failed: %v", err)
return nil, err return nil, err
@@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
} }
}() }()
logger.Debug("PgSQL Query: %s [args: %v]", query, args) logger.Debug("PgSQL Query: %s [args: %v]", query, args)
rows, err := p.db.QueryContext(ctx, query, args...) var rows *sql.Rows
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
err = run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil { if err != nil {
logger.Error("PgSQL Query failed: %v", err) logger.Error("PgSQL Query failed: %v", err)
return err return err
@@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
} }
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) { func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
tx, err := p.db.BeginTx(ctx, nil) tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
} }
}() }()
tx, err := p.db.BeginTx(ctx, nil) tx, err := p.getDB().BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -98,8 +98,8 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
} }
} }
// Filter regularData to only include fields that exist in the model // Filter regularData to only include fields that exist in the model,
// Use MapToStruct to validate and filter fields // and translate JSON keys to their actual database column names.
regularData = p.filterValidFields(regularData, model) regularData = p.filterValidFields(regularData, model)
// Inject parent IDs for foreign key resolution // Inject parent IDs for foreign key resolution
@@ -191,14 +191,15 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
return "" return ""
} }
// filterValidFields filters input data to only include fields that exist in the model // filterValidFields filters input data to only include fields that exist in the model,
// Uses reflection.MapToStruct to validate fields and extract only those that match the model // and translates JSON key names to their actual database column names.
// For example, a field tagged `json:"_changed_date" bun:"changed_date"` will be
// included in the result as "changed_date", not "_changed_date".
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} { func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
if len(data) == 0 { if len(data) == 0 {
return data return data
} }
// Create a new instance of the model to use with MapToStruct
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
@@ -208,25 +209,16 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
return data return data
} }
// Create a new instance of the model // Build a mapping from JSON key -> DB column name for all writable fields.
tempModel := reflect.New(modelType).Interface() // This both validates which fields belong to the model and translates their names
// to the correct column names for use in SQL insert/update queries.
jsonToDBCol := reflection.BuildJSONToDBColumnMap(modelType)
// Use MapToStruct to map the data - this will only map valid fields
err := reflection.MapToStruct(data, tempModel)
if err != nil {
logger.Debug("Error mapping data to model: %v", err)
return data
}
// Extract the mapped fields back into a map
// This effectively filters out any fields that don't exist in the model
filteredData := make(map[string]interface{}) filteredData := make(map[string]interface{})
tempModelValue := reflect.ValueOf(tempModel).Elem()
for key, value := range data { for key, value := range data {
// Check if the field was successfully mapped dbColName, exists := jsonToDBCol[key]
if fieldWasMapped(tempModelValue, modelType, key) { if exists {
filteredData[key] = value filteredData[dbColName] = value
} else { } else {
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType) logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
} }
@@ -235,72 +227,8 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
return filteredData return filteredData
} }
// fieldWasMapped checks if a field with the given key was mapped to the model // injectForeignKeys injects parent IDs into data for foreign key fields.
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool { // data is expected to be keyed by DB column names (as returned by filterValidFields).
// Look for the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] == key {
return true
}
}
// Check bun tag
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
return true
}
}
// Check gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" && gormTag != "-" {
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
return true
}
}
// Check lowercase field name
if strings.EqualFold(field.Name, key) {
return true
}
// Handle embedded structs recursively
if field.Anonymous {
fieldType := field.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
if fieldType.Kind() == reflect.Struct {
embeddedValue := modelValue.Field(i)
if embeddedValue.Kind() == reflect.Ptr {
if embeddedValue.IsNil() {
continue
}
embeddedValue = embeddedValue.Elem()
}
if fieldWasMapped(embeddedValue, fieldType, key) {
return true
}
}
}
}
return false
}
// injectForeignKeys injects parent IDs into data for foreign key fields
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) { func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
if len(parentIDs) == 0 { if len(parentIDs) == 0 {
return return
@@ -319,10 +247,11 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode
if strings.EqualFold(jsonName, parentKey+"_id") || if strings.EqualFold(jsonName, parentKey+"_id") ||
strings.EqualFold(jsonName, parentKey+"id") || strings.EqualFold(jsonName, parentKey+"id") ||
strings.EqualFold(field.Name, parentKey+"ID") { strings.EqualFold(field.Name, parentKey+"ID") {
// Only inject if not already present // Use the DB column name as the key, since data is keyed by DB column names
if _, exists := data[jsonName]; !exists { dbColName := reflection.GetColumnName(field)
logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID) if _, exists := data[dbColName]; !exists {
data[jsonName] = parentID logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID)
data[dbColName] = parentID
} }
} }
} }

View File

@@ -3,6 +3,7 @@ package providers_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager" "github.com/bitechdev/ResolveSpec/pkg/dbmanager"
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Subscribe to a channel with a handler // Subscribe to a channel with a handler
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
fmt.Printf("Received notification on %s: %s\n", channel, payload) fmt.Printf("Received notification on %s: %s\n", channel, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen: %v", err)) log.Fatalf("Failed to listen: %v", err)
} }
// Send a notification // Send a notification
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`) err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to notify: %v", err)) log.Fatalf("Failed to notify: %v", err)
} }
// Wait for notification to be processed // Wait for notification to be processed
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
// Unsubscribe from the channel // Unsubscribe from the channel
if err := listener.Unlisten("user_events"); err != nil { if err := listener.Unlisten("user_events"); err != nil {
panic(fmt.Sprintf("Failed to unlisten: %v", err)) log.Fatalf("Failed to unlisten: %v", err)
} }
} }
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Listen to multiple channels // Listen to multiple channels
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
fmt.Printf("[%s] %s\n", ch, payload) fmt.Printf("[%s] %s\n", ch, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err)) log.Fatalf("Failed to listen on %s: %v", channel, err)
} }
} }
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
provider := providers.NewPostgresProvider() provider := providers.NewPostgresProvider()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(err) log.Fatal(err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Subscribe to application events // Subscribe to application events
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// The listener automatically reconnects if the connection is lost // The listener automatically reconnects if the connection is lost

View File

@@ -4,11 +4,17 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"strings"
"time" "time"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
) )
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
// Common errors // Common errors
var ( var (
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database // ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"sync"
"time" "time"
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver _ "github.com/glebarez/sqlite" // Pure Go SQLite driver
@@ -15,6 +16,8 @@ import (
// SQLiteProvider implements Provider for SQLite databases // SQLiteProvider implements Provider for SQLite databases
type SQLiteProvider struct { type SQLiteProvider struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
config ConnectionConfig config ConnectionConfig
} }
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
// Execute a simple query to verify the database is accessible // Execute a simple query to verify the database is accessible
var result int var result int
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result) run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
err := run()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = run()
}
}
if err != nil { if err != nil {
return fmt.Errorf("health check failed: %w", err) return fmt.Errorf("health check failed: %w", err)
} }
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
return nil return nil
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
p.dbFactory = factory
return p
}
func (p *SQLiteProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *SQLiteProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// GetNative returns the native *sql.DB connection // GetNative returns the native *sql.DB connection
func (p *SQLiteProvider) GetNative() (*sql.DB, error) { func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
if p.db == nil { if p.db == nil {

View File

@@ -196,6 +196,92 @@ func collectColumnsFromType(typ reflect.Type, columns *[]string) {
} }
} }
// GetColumnName extracts the database column name from a struct field.
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name.
// This is the exported version for use by other packages.
func GetColumnName(field reflect.StructField) string {
return getColumnNameFromField(field)
}
// BuildJSONToDBColumnMap returns a map from JSON key names to database column names
// for the given model type. Only writable, non-relation fields are included.
// This is used to translate incoming request data (keyed by JSON names) into
// properly named database columns before insert/update operations.
func BuildJSONToDBColumnMap(modelType reflect.Type) map[string]string {
result := make(map[string]string)
buildJSONToDBMap(modelType, result, false)
return result
}
func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly bool) {
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if !field.IsExported() {
continue
}
bunTag := field.Tag.Get("bun")
gormTag := field.Tag.Get("gorm")
// Handle embedded structs
if field.Anonymous {
ft := field.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isScanOnly := scanOnly
if bunTag != "" && isBunFieldScanOnly(bunTag) {
isScanOnly = true
}
if ft.Kind() == reflect.Struct {
buildJSONToDBMap(ft, result, isScanOnly)
continue
}
}
if scanOnly {
continue
}
// Skip explicitly excluded fields
if bunTag == "-" || gormTag == "-" {
continue
}
// Skip scan-only fields
if bunTag != "" && isBunFieldScanOnly(bunTag) {
continue
}
// Skip bun relation fields
if bunTag != "" && (strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") || strings.Contains(bunTag, "m2m:")) {
continue
}
// Skip gorm relation fields
if gormTag != "" && (strings.Contains(gormTag, "foreignKey:") || strings.Contains(gormTag, "references:") || strings.Contains(gormTag, "many2many:")) {
continue
}
// Get JSON key (how the field appears in incoming request data)
jsonKey := ""
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
jsonKey = parts[0]
}
}
if jsonKey == "" {
jsonKey = strings.ToLower(field.Name)
}
// Get the actual DB column name (bun > gorm > json > field name)
dbColName := getColumnNameFromField(field)
result[jsonKey] = dbColName
}
}
// getColumnNameFromField extracts the column name from a struct field // getColumnNameFromField extracts the column name from a struct field
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name // Priority: bun tag -> gorm tag -> json tag -> lowercase field name
func getColumnNameFromField(field reflect.StructField) string { func getColumnNameFromField(field reflect.StructField) string {

View File

@@ -823,12 +823,12 @@ func TestToSnakeCase(t *testing.T) {
{ {
name: "UserID", name: "UserID",
input: "UserID", input: "UserID",
expected: "user_i_d", expected: "user_id",
}, },
{ {
name: "HTTPServer", name: "HTTPServer",
input: "HTTPServer", input: "HTTPServer",
expected: "h_t_t_p_server", expected: "http_server",
}, },
{ {
name: "lowercase", name: "lowercase",
@@ -838,7 +838,7 @@ func TestToSnakeCase(t *testing.T) {
{ {
name: "UPPERCASE", name: "UPPERCASE",
input: "UPPERCASE", input: "UPPERCASE",
expected: "u_p_p_e_r_c_a_s_e", expected: "uppercase",
}, },
{ {
name: "Single", name: "Single",

View File

@@ -142,9 +142,138 @@ e.Any("/mcp", echo.WrapHandler(h)) // Echo
--- ---
### Authentication ## OAuth2 Authentication
Add middleware before the MCP routes. The handler itself has no auth layer. `resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
It can operate as:
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
- **Both simultaneously**
### Standard endpoints served
| Path | Spec | Purpose |
|---|---|---|
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
| `POST /oauth/authorize` | — | Login form submission |
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
---
### Mode 1 — Direct login (server as identity provider)
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
db, _ := sql.Open("postgres", dsn)
auth := security.NewDatabaseAuthenticator(db)
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
BaseURL: "https://api.example.com",
BasePath: "/mcp",
})
// Enable the OAuth2 server — auth enables the login form
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList, _ := security.NewSecurityList(provider)
security.RegisterSecurityHooks(handler, securityList)
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
```
MCP client flow:
1. Discovers server at `/.well-known/oauth-authorization-server`
2. Registers itself at `/oauth/register`
3. Redirects user to `/oauth/authorize` → login form appears
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
5. Uses token on all MCP tool calls
---
### Mode 2 — External provider (Google, GitHub, etc.)
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
```go
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
RedirectURL: "https://api.example.com/oauth/provider/callback",
Scopes: []string{"openid", "profile", "email"},
AuthURL: "https://accounts.google.com/o/oauth2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
ProviderName: "google",
})
// Pass `auth` so the OAuth server supports persistence, introspection, and revocation.
// Google handles the end-user authentication flow via redirect.
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
handler.RegisterOAuth2Provider(auth, "google")
```
---
### Mode 3 — Both (login form + external providers)
```go
handler.EnableOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
LoginTitle: "My App Login",
}, auth) // auth enables the username/password form
handler.RegisterOAuth2Provider(googleAuth, "google")
handler.RegisterOAuth2Provider(githubAuth, "github")
```
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
---
### Using `security.OAuthServer` standalone
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
```go
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
Issuer: "https://api.example.com",
}, auth)
oauthSrv.RegisterExternalProvider(googleAuth, "google")
mux := http.NewServeMux()
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
mux.Handle("/mcp/", myMCPHandler)
http.ListenAndServe(":8080", mux)
```
---
### Cookie-based flow (legacy)
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
```go
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
ProviderName: "google",
LoginPath: "/auth/google/login",
CallbackPath: "/auth/google/callback",
AfterLoginRedirect: "/",
})
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
```
--- ---

View File

@@ -16,6 +16,7 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
"github.com/bitechdev/ResolveSpec/pkg/security"
) )
// Handler exposes registered database models as MCP tools and resources. // Handler exposes registered database models as MCP tools and resources.
@@ -27,6 +28,8 @@ type Handler struct {
config Config config Config
name string name string
version string version string
oauth2Regs []oauth2Registration
oauthSrv *security.OAuthServer
} }
// NewHandler creates a Handler with the given database, model registry, and config. // NewHandler creates a Handler with the given database, model registry, and config.
@@ -197,8 +200,19 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return defaultSchema, entity return defaultSchema, entity
} }
// recoverPanic catches a panic from the current goroutine and returns it as an error.
// Usage: defer recoverPanic(&returnedErr)
func recoverPanic(err *error) {
if r := recover(); r != nil {
msg := fmt.Sprintf("%v", r)
logger.Error("[resolvemcp] panic recovered: %s", msg)
*err = fmt.Errorf("internal error: %s", msg)
}
}
// executeRead reads records from the database and returns raw data + metadata. // executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) { func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err) return nil, nil, fmt.Errorf("model not found: %w", err)
@@ -254,15 +268,6 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name)) query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
} }
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters // Filters
query = h.applyFilters(query, options.Filters) query = h.applyFilters(query, options.Filters)
@@ -304,7 +309,7 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
} }
// Count // Count — must happen before preloads are applied; Bun panics when counting with relations.
total, err := query.Count(ctx) total, err := query.Count(ctx)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err) return nil, nil, fmt.Errorf("error counting records: %w", err)
@@ -318,6 +323,15 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.Offset(*options.Offset) query = query.Offset(*options.Offset)
} }
// Preloads — applied after count to avoid Bun panic when counting with relations.
if len(options.Preload) > 0 {
var preloadErr error
query, preloadErr = h.applyPreloads(model, query, options.Preload)
if preloadErr != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
}
}
// BeforeRead hook // BeforeRead hook
hookCtx.Query = query hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil { if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
@@ -378,7 +392,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
// executeCreate inserts one or more records. // executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) { func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -462,7 +477,8 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
} }
// executeUpdate updates a record by ID. // executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) { func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -572,7 +588,8 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
} }
// executeDelete deletes a record by ID. // executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) { func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
if id == "" { if id == "" {
return nil, fmt.Errorf("delete requires an ID") return nil, fmt.Errorf("delete requires an ID")
} }
@@ -703,7 +720,7 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
return query.Where("("+strings.Join(conditions, " OR ")+")", args...) return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
} }
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) { func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
switch filter.Operator { switch filter.Operator {
case "eq", "=": case "eq", "=":
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value} return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
@@ -733,7 +750,8 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []in
} }
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for _, preload := range preloads { for i := range preloads {
preload := &preloads[i]
if preload.Relation == "" { if preload.Relation == "" {
continue continue
} }

264
pkg/resolvemcp/oauth2.go Normal file
View File

@@ -0,0 +1,264 @@
package resolvemcp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// --------------------------------------------------------------------------
// OAuth2 registration on the Handler
// --------------------------------------------------------------------------
// oauth2Registration stores a configured auth provider and its route config.
type oauth2Registration struct {
auth *security.DatabaseAuthenticator
cfg OAuth2RouteConfig
}
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
// Call this once per provider before serving requests.
//
// Example:
//
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
// ProviderName: "google",
// LoginPath: "/auth/google/login",
// CallbackPath: "/auth/google/callback",
// AfterLoginRedirect: "/",
// })
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
}
// HTTPHandler returns a single http.Handler that serves:
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
// - The MCP SSE transport wrapped with required authentication middleware
//
// Example:
//
// auth := security.NewGoogleAuthenticator(...)
// handler.RegisterOAuth2(auth, cfg)
// handler.EnableOAuthServer(security.OAuthServerConfig{Issuer: "https://api.example.com"})
// security.RegisterSecurityHooks(handler, securityList)
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
mux := http.NewServeMux()
if h.oauthSrv != nil {
h.mountOAuthServerRoutes(mux)
}
h.mountOAuth2Routes(mux)
mcpHandler := h.AuthedSSEServer(securityList)
basePath := h.config.BasePath
if basePath == "" {
basePath = "/mcp"
}
mux.Handle(basePath+"/sse", mcpHandler)
mux.Handle(basePath+"/message", mcpHandler)
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
return mux
}
// StreamableHTTPMux returns a single http.Handler that serves:
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
// - The MCP streamable HTTP transport wrapped with required authentication middleware
//
// Example:
//
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
mux := http.NewServeMux()
if h.oauthSrv != nil {
h.mountOAuthServerRoutes(mux)
}
h.mountOAuth2Routes(mux)
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
basePath := h.config.BasePath
if basePath == "" {
basePath = "/mcp"
}
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
mux.Handle(basePath, mcpHandler)
return mux
}
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
for _, reg := range h.oauth2Regs {
var cookieOpts []security.SessionCookieOptions
if reg.cfg.CookieOptions != nil {
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
}
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
}
}
// --------------------------------------------------------------------------
// Auth-wrapped transports
// --------------------------------------------------------------------------
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
// The middleware reads the session cookie / Authorization header and populates the user
// context into the request context, making it available to BeforeHandle security hooks.
// Unauthenticated requests receive 401 before reaching any MCP tool.
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
return security.NewAuthMiddleware(securityList)(h.SSEServer())
}
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
// Unauthenticated requests continue as guest rather than returning 401.
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
// to allow mixed public/private access.
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
}
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
}
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
}
// --------------------------------------------------------------------------
// OAuth2 route config and standalone handlers
// --------------------------------------------------------------------------
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
type OAuth2RouteConfig struct {
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
// (e.g. "google", "github", "microsoft").
ProviderName string
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
// (e.g. "/auth/google/login").
LoginPath string
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
CallbackPath string
// AfterLoginRedirect is the URL to redirect the browser to after a successful
// login. When empty the LoginResponse JSON is written directly to the response.
AfterLoginRedirect string
// CookieOptions customises the session cookie written on successful login.
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
CookieOptions *security.SessionCookieOptions
}
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
// the OAuth2 provider's authorization URL.
//
// Register it on any router:
//
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := auth.OAuth2GenerateState()
if err != nil {
http.Error(w, "failed to generate state", http.StatusInternalServerError)
return
}
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
}
}
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
// callback: exchanges the authorization code for a session token, writes the session
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
//
// Register it on any router:
//
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" {
http.Error(w, "missing code parameter", http.StatusBadRequest)
return
}
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
security.SetSessionCookie(w, loginResp, cookieOpts...)
if afterLoginRedirect != "" {
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
}
}
// --------------------------------------------------------------------------
// Gorilla Mux convenience helpers
// --------------------------------------------------------------------------
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
//
// Example:
//
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
// ProviderName: "google", LoginPath: "/auth/google/login",
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
// })
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
var cookieOpts []security.SessionCookieOptions
if cfg.CookieOptions != nil {
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
}
muxRouter.Handle(cfg.LoginPath,
OAuth2LoginHandler(auth, cfg.ProviderName),
).Methods(http.MethodGet)
muxRouter.Handle(cfg.CallbackPath,
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
).Methods(http.MethodGet)
}
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
// with required authentication middleware applied.
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
basePath := handler.config.BasePath
h := handler.AuthedSSEServer(securityList)
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
// Gorilla Mux router with required authentication middleware applied.
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
basePath := handler.config.BasePath
h := handler.AuthedStreamableHTTPServer(securityList)
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}

View File

@@ -0,0 +1,51 @@
package resolvemcp
import (
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
//
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
// only external providers registered via RegisterOAuth2Provider.
//
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
// endpoints required by MCP clients alongside the MCP transport:
//
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
// POST /oauth/register RFC 7591 — dynamic client registration
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
// POST /oauth/authorize Login form submission (password flow)
// POST /oauth/token Bearer token exchange + refresh
// GET /oauth/provider/callback External provider redirect target
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
h.oauthSrv = security.NewOAuthServer(cfg, auth)
// Wire any external providers already registered via RegisterOAuth2
for _, reg := range h.oauth2Regs {
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
}
}
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
// EnableOAuthServer must be called before this. The auth must have been configured with
// WithOAuth2(providerName, ...) for the given provider name.
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
if h.oauthSrv != nil {
h.oauthSrv.RegisterExternalProvider(auth, providerName)
}
}
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
oauthHandler := h.oauthSrv.HTTPHandler()
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
mux.Handle("/.well-known/", oauthHandler)
mux.Handle("/oauth/", oauthHandler)
if h.oauthSrv != nil {
// Also mount the external provider callback path if it differs from /oauth/
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
}
}

153
pkg/security/KEYSTORE.md Normal file
View File

@@ -0,0 +1,153 @@
# Keystore
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
## Key types
| Constant | Value | Use case |
|---|---|---|
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
## Storage backends
### ConfigKeyStore
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
```go
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
store := security.NewConfigKeyStore([]security.UserKey{
{
UserID: 1,
KeyType: security.KeyTypeGenericAPI,
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
Name: "CI deploy",
Scopes: []string{"deploy"},
IsActive: true,
},
})
```
### DatabaseKeyStore
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
```go
db, _ := sql.Open("postgres", dsn)
store := security.NewDatabaseKeyStore(db)
// With options
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
CacheTTL: 5 * time.Minute,
SQLNames: &security.KeyStoreSQLNames{
ValidateKey: "myapp_keystore_validate", // override one procedure name
},
})
```
## Managing keys
```go
ctx := context.Background()
// Create — raw key returned once; store it securely
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
UserID: 42,
KeyType: security.KeyTypeGenericAPI,
Name: "mobile app",
Scopes: []string{"read", "write"},
})
fmt.Println(resp.RawKey) // only shown here; hashed internally
// List
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
// Revoke
err = store.DeleteKey(ctx, 42, resp.Key.ID)
// Validate (used by authenticators internally)
key, err := store.ValidateKey(ctx, rawKey, "")
```
## HTTP authentication
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
Keys are extracted from the request in this order:
1. `Authorization: Bearer <key>`
2. `Authorization: ApiKey <key>`
3. `X-API-Key: <key>`
```go
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
// Restrict to a specific type:
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
```
Plug it into a handler:
```go
handler := resolvespec.NewHandler(db, registry,
resolvespec.WithAuthenticator(auth),
)
```
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
On successful validation the request context receives a `UserContext` where:
- `UserID` — from the key
- `Roles` — the key's `Scopes`
- `Claims["key_type"]` — key type string
- `Claims["key_name"]` — key name
## Database setup
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
```sql
\i pkg/security/keystore_schema.sql
```
This creates:
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
- `resolvespec_keystore_create_key(p_request jsonb)`
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
### Custom procedure names
```go
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
SQLNames: &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
CreateKey: "myschema_create_key",
DeleteKey: "myschema_delete_key",
ValidateKey: "myschema_validate_key",
},
})
// Validate names at startup
names := &security.KeyStoreSQLNames{
GetUserKeys: "myschema_get_keys",
// ...
}
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
log.Fatal(err)
}
```
## Security notes
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.

View File

@@ -12,6 +12,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
-**Testable** - Easy to mock and test -**Testable** - Easy to mock and test
-**Extensible** - Implement custom providers for your needs -**Extensible** - Implement custom providers for your needs
-**Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability -**Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
-**OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
## Stored Procedure Architecture ## Stored Procedure Architecture
@@ -38,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator | | `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider | | `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider | | `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
See `database_schema.sql` for complete stored procedure definitions and examples. See `database_schema.sql` for complete stored procedure definitions and examples.
@@ -897,6 +904,156 @@ securityList := security.NewSecurityList(provider)
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
``` ```
## OAuth2 Authorization Server
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
### Endpoints
| Method | Path | RFC |
|--------|------|-----|
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
### Config
```go
cfg := security.OAuthServerConfig{
Issuer: "https://example.com", // Required — token issuer URL
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
LoginTitle: "My App Login", // HTML login page title
PersistClients: true, // Store clients in DB (multi-instance safe)
PersistCodes: true, // Store codes in DB (multi-instance safe)
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
AccessTokenTTL: time.Hour,
AuthCodeTTL: 5 * time.Minute,
}
```
| Field | Default | Notes |
|-------|---------|-------|
| `Issuer` | — | Required; trailing slash is trimmed automatically |
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
| `LoginTitle` | `"Sign in"` | |
| `PersistClients` | `false` | Set `true` for multi-instance |
| `PersistCodes` | `false` | Set `true` for multi-instance; does not require `PersistClients` |
| `DefaultScopes` | `["openid","profile","email"]` | |
| `AccessTokenTTL` | `24h` | Also used as `expires_in` in token responses |
| `AuthCodeTTL` | `2m` | |
### Operating Modes
**Mode 1 — Direct login (username/password form)**
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
```go
auth := security.NewDatabaseAuthenticator(db)
srv := security.NewOAuthServer(cfg, auth)
```
**Mode 2 — External provider federation**
Pass a `*DatabaseAuthenticator` for persistence (authorization codes, revoke, introspect) and register external providers. The authorize endpoint redirects to the specified provider (via the `provider` query param) or to the first registered provider by default.
```go
auth := security.NewDatabaseAuthenticator(db)
srv := security.NewOAuthServer(cfg, auth)
srv.RegisterExternalProvider(googleAuth, "google")
srv.RegisterExternalProvider(githubAuth, "github")
```
**Mode 3 — Both**
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
```go
srv := security.NewOAuthServer(cfg, auth)
srv.RegisterExternalProvider(googleAuth, "google")
```
### Standalone Usage
```go
mux := http.NewServeMux()
mux.Handle("/.well-known/", srv.HTTPHandler())
mux.Handle("/oauth/", srv.HTTPHandler())
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
http.ListenAndServe(":8080", mux)
```
### DB Persistence
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
#### New DB Types
```go
type OAuthServerClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
type OAuthCode struct {
Code string `json:"code"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
ClientState string `json:"client_state,omitempty"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
SessionToken string `json:"session_token"`
Scopes []string `json:"scopes,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
}
type OAuthTokenInfo struct {
Active bool `json:"active"`
Sub string `json:"sub,omitempty"`
Username string `json:"username,omitempty"`
Email string `json:"email,omitempty"`
Roles []string `json:"roles,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
}
```
#### DatabaseAuthenticator OAuth Methods
```go
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
auth.OAuthGetClient(ctx, clientID) // retrieve client
auth.OAuthSaveCode(ctx, code) // persist authorization code
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
```
#### SQLNames Fields
```go
type SQLNames struct {
// ... existing fields ...
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
OAuthGetClient string // default: "resolvespec_oauth_get_client"
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
OAuthRevoke string // default: "resolvespec_oauth_revoke"
}
```
The main changes: The main changes:
1. Security package no longer knows about specific spec types 1. Security package no longer knows about specific spec types
2. Each spec registers its own security hooks 2. Each spec registers its own security hooks

View File

@@ -1397,3 +1397,180 @@ $$ LANGUAGE plpgsql;
-- Get credentials by username -- Get credentials by username
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin'); -- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
-- ============================================
-- OAuth2 Server Tables (OAuthServer persistence)
-- ============================================
-- oauth_clients: persistent RFC 7591 registered clients
CREATE TABLE IF NOT EXISTS oauth_clients (
id SERIAL PRIMARY KEY,
client_id VARCHAR(255) NOT NULL UNIQUE,
redirect_uris TEXT[] NOT NULL,
client_name VARCHAR(255),
grant_types TEXT[] DEFAULT ARRAY['authorization_code'],
allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'],
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- oauth_codes: short-lived authorization codes (for multi-instance deployments)
-- Note: client_id is stored without a foreign key so codes can be persisted even
-- when OAuth clients are managed in memory rather than persisted in oauth_clients.
CREATE TABLE IF NOT EXISTS oauth_codes (
id SERIAL PRIMARY KEY,
code VARCHAR(255) NOT NULL UNIQUE,
client_id VARCHAR(255) NOT NULL,
redirect_uri TEXT NOT NULL,
client_state TEXT,
code_challenge VARCHAR(255) NOT NULL,
code_challenge_method VARCHAR(10) DEFAULT 'S256',
session_token TEXT NOT NULL,
refresh_token TEXT,
scopes TEXT[],
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code);
CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at);
-- ============================================
-- OAuth2 Server Stored Procedures
-- ============================================
CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_client_id text;
v_row jsonb;
BEGIN
v_client_id := p_data->>'client_id';
INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes)
VALUES (
v_client_id,
ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')),
p_data->>'client_name',
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']),
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email'])
)
RETURNING to_jsonb(oauth_clients.*) INTO v_row;
RETURN QUERY SELECT true, null::text, v_row;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, null::jsonb;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
SELECT to_jsonb(oauth_clients.*)
INTO v_row
FROM oauth_clients
WHERE client_id = p_client_id AND is_active = true;
IF v_row IS NULL THEN
RETURN QUERY SELECT false, 'client not found'::text, null::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb)
RETURNS TABLE(p_success bool, p_error text)
LANGUAGE plpgsql AS $$
BEGIN
INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, refresh_token, scopes, expires_at)
VALUES (
p_data->>'code',
p_data->>'client_id',
p_data->>'redirect_uri',
p_data->>'client_state',
p_data->>'code_challenge',
COALESCE(p_data->>'code_challenge_method', 'S256'),
p_data->>'session_token',
p_data->>'refresh_token',
ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')),
(p_data->>'expires_at')::timestamp
);
RETURN QUERY SELECT true, null::text;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
DELETE FROM oauth_codes
WHERE code = p_code AND expires_at > now()
RETURNING jsonb_build_object(
'client_id', client_id,
'redirect_uri', redirect_uri,
'client_state', client_state,
'code_challenge', code_challenge,
'code_challenge_method', code_challenge_method,
'session_token', session_token,
'refresh_token', refresh_token,
'scopes', to_jsonb(scopes)
) INTO v_row;
IF v_row IS NULL THEN
RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text)
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
LANGUAGE plpgsql AS $$
DECLARE
v_row jsonb;
BEGIN
SELECT jsonb_build_object(
'active', true,
'sub', u.id::text,
'username', u.username,
'email', u.email,
'user_level', u.user_level,
-- NULLIF converts empty string to NULL; string_to_array(NULL) returns NULL;
-- to_jsonb(NULL) returns NULL; COALESCE then returns '[]' for NULL/empty roles.
'roles', COALESCE(to_jsonb(string_to_array(NULLIF(u.roles, ''), ',')), '[]'::jsonb),
'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint,
'iat', EXTRACT(EPOCH FROM s.created_at)::bigint
)
INTO v_row
FROM user_sessions s
JOIN users u ON u.id = s.user_id
WHERE s.session_token = p_token
AND s.expires_at > now()
AND u.is_active = true;
IF v_row IS NULL THEN
RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb;
ELSE
RETURN QUERY SELECT true, null::text, v_row;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text)
RETURNS TABLE(p_success bool, p_error text)
LANGUAGE plpgsql AS $$
BEGIN
DELETE FROM user_sessions WHERE session_token = p_token;
RETURN QUERY SELECT true, null::text;
END;
$$;

81
pkg/security/keystore.go Normal file
View File

@@ -0,0 +1,81 @@
package security
import (
"context"
"crypto/sha256"
"encoding/hex"
"time"
)
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
// Used by all keystore implementations to hash raw keys before storage or lookup.
func hashSHA256Hex(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
// KeyType identifies the category of an auth key.
type KeyType string
const (
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
KeyTypeJWTSecret KeyType = "jwt_secret"
// KeyTypeHeaderAPI is a static API key sent via a request header.
KeyTypeHeaderAPI KeyType = "header_api"
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
KeyTypeOAuth2 KeyType = "oauth2"
// KeyTypeGenericAPI is a generic application API key.
KeyTypeGenericAPI KeyType = "api"
)
// UserKey represents a single named auth key belonging to a user.
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
type UserKey struct {
ID int64 `json:"id"`
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
IsActive bool `json:"is_active"`
}
// CreateKeyRequest specifies the parameters for a new key.
type CreateKeyRequest struct {
UserID int
KeyType KeyType
Name string
Scopes []string
Meta map[string]any
ExpiresAt *time.Time
}
// CreateKeyResponse is returned exactly once when a key is created.
// The caller is responsible for persisting RawKey; it is not stored anywhere.
type CreateKeyResponse struct {
Key UserKey
RawKey string // crypto/rand 32 bytes, base64url-encoded
}
// KeyStore manages per-user auth keys with pluggable storage backends.
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
type KeyStore interface {
// CreateKey generates a new key, stores its hash, and returns the raw key once.
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
// GetUserKeys returns all active, non-expired keys for a user.
// Pass an empty KeyType to return all types.
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
// DeleteKey soft-deletes a key by ID after verifying ownership.
DeleteKey(ctx context.Context, userID int, keyID int64) error
// ValidateKey checks a raw key, returns the matching UserKey on success.
// The implementation hashes the raw key before any lookup.
// Pass an empty KeyType to accept any type.
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
}

View File

@@ -0,0 +1,97 @@
package security
import (
"context"
"fmt"
"net/http"
"strings"
)
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
// is managed directly through the KeyStore.
//
// Key extraction order:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
type KeyStoreAuthenticator struct {
keyStore KeyStore
keyType KeyType // empty = accept any type
}
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
// Pass an empty keyType to accept keys of any type.
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
}
// Login is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
return nil, fmt.Errorf("keystore authenticator does not support login")
}
// Logout is not supported for keystore authentication.
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
return nil
}
// Authenticate extracts an API key from the request and validates it against the KeyStore.
// Returns a UserContext built from the matching UserKey on success.
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
rawKey := extractAPIKey(r)
if rawKey == "" {
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
}
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return userKeyToUserContext(userKey), nil
}
// extractAPIKey extracts a raw key from the request using the following precedence:
// 1. Authorization: Bearer <key>
// 2. Authorization: ApiKey <key>
// 3. X-API-Key header
func extractAPIKey(r *http.Request) string {
if auth := r.Header.Get("Authorization"); auth != "" {
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
return strings.TrimSpace(after)
}
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
return strings.TrimSpace(after)
}
}
return strings.TrimSpace(r.Header.Get("X-API-Key"))
}
// userKeyToUserContext converts a UserKey into a UserContext.
// Scopes are mapped to Roles. Key type and name are stored in Claims.
func userKeyToUserContext(k *UserKey) *UserContext {
claims := map[string]any{
"key_type": string(k.KeyType),
"key_name": k.Name,
}
meta := k.Meta
if meta == nil {
meta = map[string]any{}
}
roles := k.Scopes
if roles == nil {
roles = []string{}
}
return &UserContext{
UserID: k.UserID,
SessionID: fmt.Sprintf("key:%d", k.ID),
Roles: roles,
Claims: claims,
Meta: meta,
}
}

View File

@@ -0,0 +1,149 @@
package security
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"sync"
"sync/atomic"
"time"
)
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
//
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
type ConfigKeyStore struct {
mu sync.RWMutex
keys []UserKey
next int64 // monotonic ID counter for runtime-created keys (atomic)
}
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
// Pass nil or an empty slice to start with no pre-loaded keys.
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
var maxID int64
copied := make([]UserKey, len(keys))
copy(copied, keys)
for i := range copied {
if copied[i].CreatedAt.IsZero() {
copied[i].IsActive = true
copied[i].CreatedAt = time.Now()
}
if copied[i].ID > maxID {
maxID = copied[i].ID
}
}
return &ConfigKeyStore{keys: copied, next: maxID}
}
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
id := atomic.AddInt64(&s.next, 1)
key := UserKey{
ID: id,
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
CreatedAt: time.Now(),
IsActive: true,
}
s.mu.Lock()
s.keys = append(s.keys, key)
s.mu.Unlock()
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
now := time.Now()
s.mu.RLock()
defer s.mu.RUnlock()
var result []UserKey
for i := range s.keys {
k := &s.keys[i]
if k.UserID != userID || !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
result = append(result, *k)
}
return result, nil
}
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
if s.keys[i].ID == keyID {
if s.keys[i].UserID != userID {
return fmt.Errorf("key not found or permission denied")
}
s.keys[i].IsActive = false
return nil
}
}
return fmt.Errorf("key not found")
}
// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
// Uses constant-time comparison to prevent timing side-channels.
// Pass an empty KeyType to accept any type.
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
hashBytes, _ := hex.DecodeString(hash)
now := time.Now()
// Write lock: ValidateKey updates LastUsedAt on the matched entry.
s.mu.Lock()
defer s.mu.Unlock()
for i := range s.keys {
k := &s.keys[i]
if !k.IsActive {
continue
}
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
continue
}
if keyType != "" && k.KeyType != keyType {
continue
}
stored, _ := hex.DecodeString(k.KeyHash)
if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
continue
}
k.LastUsedAt = &now
result := *k
return &result, nil
}
return nil, fmt.Errorf("invalid or expired key")
}

View File

@@ -0,0 +1,256 @@
package security
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
type DatabaseKeyStoreOptions struct {
// Cache is an optional cache instance. If nil, uses the default cache.
Cache *cache.Cache
// CacheTTL is the duration to cache ValidateKey results.
// Default: 2 minutes.
CacheTTL time.Duration
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
SQLNames *KeyStoreSQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
}
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
// All DB operations go through configurable procedure names; the raw key is
// never passed to the database.
//
// See keystore_schema.sql for the required table and procedure definitions.
//
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
// (default 2 minutes) if the cache entry cannot be invalidated.
type DatabaseKeyStore struct {
db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *KeyStoreSQLNames
cache *cache.Cache
cacheTTL time.Duration
}
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
o := DatabaseKeyStoreOptions{}
if len(opts) > 0 {
o = opts[0]
}
if o.CacheTTL == 0 {
o.CacheTTL = 2 * time.Minute
}
c := o.Cache
if c == nil {
c = cache.GetDefaultCache()
}
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
return &DatabaseKeyStore{
db: db,
dbFactory: o.DBFactory,
sqlNames: names,
cache: c,
cacheTTL: o.CacheTTL,
}
}
func (ks *DatabaseKeyStore) getDB() *sql.DB {
ks.dbMu.RLock()
defer ks.dbMu.RUnlock()
return ks.db
}
func (ks *DatabaseKeyStore) reconnectDB() error {
if ks.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := ks.dbFactory()
if err != nil {
return err
}
ks.dbMu.Lock()
ks.db = newDB
ks.dbMu.Unlock()
return nil
}
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
// and returns the raw key once.
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
rawBytes := make([]byte, 32)
if _, err := rand.Read(rawBytes); err != nil {
return nil, fmt.Errorf("failed to generate key material: %w", err)
}
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
hash := hashSHA256Hex(rawKey)
type createRequest struct {
UserID int `json:"user_id"`
KeyType KeyType `json:"key_type"`
KeyHash string `json:"key_hash"`
Name string `json:"name"`
Scopes []string `json:"scopes,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
reqJSON, err := json.Marshal(createRequest{
UserID: req.UserID,
KeyType: req.KeyType,
KeyHash: hash,
Name: req.Name,
Scopes: req.Scopes,
Meta: req.Meta,
ExpiresAt: req.ExpiresAt,
})
if err != nil {
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
return nil, fmt.Errorf("create key procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
}
var key UserKey
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse created key: %w", err)
}
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types.
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
var success bool
var errorMsg sql.NullString
var keysJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
}
var keys []UserKey
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
return nil, fmt.Errorf("failed to parse user keys: %w", err)
}
}
if keys == nil {
keys = []UserKey{}
}
return keys, nil
}
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
// The delete procedure returns the key_hash so no separate lookup is needed.
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
var success bool
var errorMsg sql.NullString
var keyHash sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
return fmt.Errorf("delete key procedure failed: %w", err)
}
if !success {
return errors.New(nullStringOr(errorMsg, "delete key failed"))
}
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
}
return nil
}
// ValidateKey hashes the raw key and calls the validate procedure.
// Results are cached for CacheTTL to reduce DB load on hot paths.
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
hash := hashSHA256Hex(rawKey)
cacheKey := keystoreCacheKey(hash)
if ks.cache != nil {
var cached UserKey
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
if cached.IsActive {
return &cached, nil
}
return nil, errors.New("invalid or expired key")
}
}
var success bool
var errorMsg sql.NullString
var keyJSON sql.NullString
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
}
if err := runQuery(); err != nil {
if isDBClosed(err) {
if reconnErr := ks.reconnectDB(); reconnErr == nil {
err = runQuery()
}
if err != nil {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
} else {
return nil, fmt.Errorf("validate key procedure failed: %w", err)
}
}
if !success {
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
}
var key UserKey
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
return nil, fmt.Errorf("failed to parse validated key: %w", err)
}
if ks.cache != nil {
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
}
return &key, nil
}
func keystoreCacheKey(hash string) string {
return "keystore:validate:" + hash
}
// nullStringOr returns s.String if valid, otherwise the fallback.
func nullStringOr(s sql.NullString, fallback string) string {
if s.Valid && s.String != "" {
return s.String
}
return fallback
}

View File

@@ -0,0 +1,187 @@
-- Keystore schema for per-user auth keys
-- Apply alongside database_schema.sql (requires the users table)
CREATE TABLE IF NOT EXISTS user_keys (
id BIGSERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
key_type VARCHAR(50) NOT NULL,
key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
name VARCHAR(255) NOT NULL DEFAULT '',
scopes TEXT, -- JSON array, e.g. '["read","write"]'
meta JSONB,
expires_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_used_at TIMESTAMP,
is_active BOOLEAN DEFAULT true
);
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
-- resolvespec_keystore_get_user_keys
-- Returns all active, non-expired keys for a user.
-- Pass empty p_key_type to return all key types.
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
p_user_id INTEGER,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_keys JSONB;
BEGIN
SELECT COALESCE(
jsonb_agg(
jsonb_build_object(
'id', k.id,
'user_id', k.user_id,
'key_type', k.key_type,
'name', k.name,
'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
'meta', COALESCE(k.meta, '{}'::jsonb),
'expires_at', k.expires_at,
'created_at', k.created_at,
'last_used_at', k.last_used_at,
'is_active', k.is_active
)
),
'[]'::jsonb
)
INTO v_keys
FROM user_keys k
WHERE k.user_id = p_user_id
AND k.is_active = true
AND (k.expires_at IS NULL OR k.expires_at > NOW())
AND (p_key_type = '' OR k.key_type = p_key_type);
RETURN QUERY SELECT true, NULL::TEXT, v_keys;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_create_key
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
-- Returns the created key record (without key_hash).
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
p_request JSONB
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_id BIGINT;
v_created_at TIMESTAMP;
v_key JSONB;
BEGIN
INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
VALUES (
(p_request->>'user_id')::INTEGER,
p_request->>'key_type',
p_request->>'key_hash',
COALESCE(p_request->>'name', ''),
p_request->>'scopes',
p_request->'meta',
CASE WHEN p_request->>'expires_at' IS NOT NULL
THEN (p_request->>'expires_at')::TIMESTAMP
ELSE NULL
END
)
RETURNING id, created_at INTO v_id, v_created_at;
v_key := jsonb_build_object(
'id', v_id,
'user_id', (p_request->>'user_id')::INTEGER,
'key_type', p_request->>'key_type',
'name', COALESCE(p_request->>'name', ''),
'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
THEN (p_request->>'scopes')::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(p_request->'meta', '{}'::jsonb),
'expires_at', p_request->>'expires_at',
'created_at', v_created_at,
'is_active', true
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
-- resolvespec_keystore_delete_key
-- Soft-deletes a key (is_active = false) after verifying ownership.
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
p_user_id INTEGER,
p_key_id BIGINT
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
LANGUAGE plpgsql AS $$
DECLARE
v_hash TEXT;
BEGIN
UPDATE user_keys
SET is_active = false
WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
RETURNING key_hash INTO v_hash;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
RETURN;
END IF;
RETURN QUERY SELECT true, NULL::TEXT, v_hash;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
END;
$$;
-- resolvespec_keystore_validate_key
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
-- updates last_used_at, and returns the key record.
-- p_key_type can be empty to accept any key type.
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
p_key_hash TEXT,
p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
v_key_rec user_keys%ROWTYPE;
v_key JSONB;
BEGIN
SELECT * INTO v_key_rec
FROM user_keys
WHERE key_hash = p_key_hash
AND is_active = true
AND (expires_at IS NULL OR expires_at > NOW())
AND (p_key_type = '' OR key_type = p_key_type);
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
RETURN;
END IF;
UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
v_key := jsonb_build_object(
'id', v_key_rec.id,
'user_id', v_key_rec.user_id,
'key_type', v_key_rec.key_type,
'name', v_key_rec.name,
'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
THEN v_key_rec.scopes::jsonb
ELSE '[]'::jsonb END,
'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
'expires_at', v_key_rec.expires_at,
'created_at', v_key_rec.created_at,
'last_used_at', NOW(),
'is_active', v_key_rec.is_active
);
RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;

View File

@@ -0,0 +1,61 @@
package security
import "fmt"
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
type KeyStoreSQLNames struct {
GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
CreateKey string // default: "resolvespec_keystore_create_key"
DeleteKey string // default: "resolvespec_keystore_delete_key"
ValidateKey string // default: "resolvespec_keystore_validate_key"
}
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
return &KeyStoreSQLNames{
GetUserKeys: "resolvespec_keystore_get_user_keys",
CreateKey: "resolvespec_keystore_create_key",
DeleteKey: "resolvespec_keystore_delete_key",
ValidateKey: "resolvespec_keystore_validate_key",
}
}
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.GetUserKeys != "" {
merged.GetUserKeys = override.GetUserKeys
}
if override.CreateKey != "" {
merged.CreateKey = override.CreateKey
}
if override.DeleteKey != "" {
merged.DeleteKey = override.DeleteKey
}
if override.ValidateKey != "" {
merged.ValidateKey = override.ValidateKey
}
return &merged
}
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
fields := map[string]string{
"GetUserKeys": names.GetUserKeys,
"CreateKey": names.CreateKey,
"DeleteKey": names.DeleteKey,
"ValidateKey": names.ValidateKey,
}
for field, val := range fields {
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
}
}
return nil
}

View File

@@ -244,7 +244,7 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
var errMsg *string var errMsg *string
var userID *int var userID *int
err = a.db.QueryRowContext(ctx, fmt.Sprintf(` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_user_id SELECT p_success, p_error, p_user_id
FROM %s($1::jsonb) FROM %s($1::jsonb)
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID) `, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
@@ -287,7 +287,7 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
var success bool var success bool
var errMsg *string var errMsg *string
err = a.db.QueryRowContext(ctx, fmt.Sprintf(` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM %s($1::jsonb) FROM %s($1::jsonb)
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg) `, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
@@ -385,7 +385,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var errMsg *string var errMsg *string
var sessionData []byte var sessionData []byte
err = a.db.QueryRowContext(ctx, fmt.Sprintf(` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM %s($1) FROM %s($1)
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData) `, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
@@ -451,7 +451,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var updateSuccess bool var updateSuccess bool
var updateErrMsg *string var updateErrMsg *string
err = a.db.QueryRowContext(ctx, fmt.Sprintf(` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error SELECT p_success, p_error
FROM %s($1::jsonb) FROM %s($1::jsonb)
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg) `, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
@@ -472,7 +472,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var userErrMsg *string var userErrMsg *string
var userData []byte var userData []byte
err = a.db.QueryRowContext(ctx, fmt.Sprintf(` err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text SELECT p_success, p_error, p_data::text
FROM %s($1) FROM %s($1)
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData) `, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)

View File

@@ -0,0 +1,917 @@
package security
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// OAuthServerConfig configures the MCP-standard OAuth2 authorization server.
type OAuthServerConfig struct {
// Issuer is the public base URL of this server (e.g. "https://api.example.com").
// Used in /.well-known/oauth-authorization-server and to build endpoint URLs.
Issuer string
// ProviderCallbackPath is the path on this server that external OAuth2 providers
// redirect back to. Defaults to "/oauth/provider/callback".
ProviderCallbackPath string
// LoginTitle is shown on the built-in login form when the server acts as its own
// identity provider. Defaults to "Sign in".
LoginTitle string
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
// Clients registered during a session survive server restarts.
PersistClients bool
// PersistCodes stores authorization codes in the database.
// Useful for multi-instance deployments. Defaults to in-memory.
PersistCodes bool
// DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"].
DefaultScopes []string
// AccessTokenTTL is the issued token lifetime. Defaults to 24h.
AccessTokenTTL time.Duration
// AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes.
AuthCodeTTL time.Duration
}
// oauthClient is a dynamically registered OAuth2 client (RFC 7591).
type oauthClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
// pendingAuth tracks an in-progress authorization code exchange.
type pendingAuth struct {
ClientID string
RedirectURI string
ClientState string
CodeChallenge string
CodeChallengeMethod string
ProviderName string // empty = password login
ExpiresAt time.Time
SessionToken string // set after authentication completes
RefreshToken string // set after authentication completes when refresh tokens are issued
Scopes []string // requested scopes
}
// externalProvider pairs a DatabaseAuthenticator with its provider name.
type externalProvider struct {
auth *DatabaseAuthenticator
providerName string
}
// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE).
//
// It can act as both:
// - A direct identity provider using DatabaseAuthenticator username/password login
// - A federation layer that delegates authentication to external OAuth2 providers
// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider
//
// The server exposes these RFC-compliant endpoints:
//
// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery
// POST /oauth/register RFC 7591 — dynamic client registration
// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization
// POST /oauth/authorize Direct login form submission
// POST /oauth/token Token exchange and refresh
// POST /oauth/revoke RFC 7009 — token revocation
// POST /oauth/introspect RFC 7662 — token introspection
// GET {ProviderCallbackPath} Internal — external provider callback
type OAuthServer struct {
cfg OAuthServerConfig
auth *DatabaseAuthenticator // nil = only external providers
providers []externalProvider
mu sync.RWMutex
clients map[string]*oauthClient
pending map[string]*pendingAuth // provider_state → pending (external flow)
codes map[string]*pendingAuth // auth_code → pending (post-auth)
done chan struct{} // closed by Close() to stop background goroutines
}
// NewOAuthServer creates a new MCP OAuth2 authorization server.
//
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
// acts as its own identity provider). Pass nil to use only external providers.
// External providers are added separately via RegisterExternalProvider.
//
// Call Close() to stop background goroutines when the server is no longer needed.
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
if cfg.ProviderCallbackPath == "" {
cfg.ProviderCallbackPath = "/oauth/provider/callback"
}
if cfg.LoginTitle == "" {
cfg.LoginTitle = "Sign in"
}
if len(cfg.DefaultScopes) == 0 {
cfg.DefaultScopes = []string{"openid", "profile", "email"}
}
if cfg.AccessTokenTTL == 0 {
cfg.AccessTokenTTL = 24 * time.Hour
}
if cfg.AuthCodeTTL == 0 {
cfg.AuthCodeTTL = 2 * time.Minute
}
// Normalize issuer: remove trailing slash to ensure consistent endpoint URL construction.
cfg.Issuer = strings.TrimSuffix(cfg.Issuer, "/")
s := &OAuthServer{
cfg: cfg,
auth: auth,
clients: make(map[string]*oauthClient),
pending: make(map[string]*pendingAuth),
codes: make(map[string]*pendingAuth),
done: make(chan struct{}),
}
go s.cleanupExpired()
return s
}
// Close stops the background goroutines started by NewOAuthServer.
// It is safe to call Close multiple times.
func (s *OAuthServer) Close() {
select {
case <-s.done:
// already closed
default:
close(s.done)
}
}
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
// configured with WithOAuth2(providerName, ...) before calling this.
// Multiple providers can be registered; the first is used as the default.
// All providers must be registered before the server starts serving requests.
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
s.mu.Lock()
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
s.mu.Unlock()
}
// ProviderCallbackPath returns the configured path for external provider callbacks.
func (s *OAuthServer) ProviderCallbackPath() string {
return s.cfg.ProviderCallbackPath
}
// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints.
// Mount it at the root of your HTTP server alongside the MCP transport.
//
// mux := http.NewServeMux()
// mux.Handle("/", oauthServer.HTTPHandler())
// mux.Handle("/mcp/", mcpTransport)
func (s *OAuthServer) HTTPHandler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler)
mux.HandleFunc("/oauth/register", s.registerHandler)
mux.HandleFunc("/oauth/authorize", s.authorizeHandler)
mux.HandleFunc("/oauth/token", s.tokenHandler)
mux.HandleFunc("/oauth/revoke", s.revokeHandler)
mux.HandleFunc("/oauth/introspect", s.introspectHandler)
mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler)
return mux
}
// cleanupExpired removes stale pending auths and codes every 5 minutes.
func (s *OAuthServer) cleanupExpired() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.done:
return
case <-ticker.C:
now := time.Now()
s.mu.Lock()
for k, p := range s.pending {
if now.After(p.ExpiresAt) {
delete(s.pending, k)
}
}
for k, p := range s.codes {
if now.After(p.ExpiresAt) {
delete(s.codes, k)
}
}
s.mu.Unlock()
}
}
}
// --------------------------------------------------------------------------
// RFC 8414 — Server metadata
// --------------------------------------------------------------------------
func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) {
issuer := s.cfg.Issuer
meta := map[string]interface{}{
"issuer": issuer,
"authorization_endpoint": issuer + "/oauth/authorize",
"token_endpoint": issuer + "/oauth/token",
"registration_endpoint": issuer + "/oauth/register",
"revocation_endpoint": issuer + "/oauth/revoke",
"introspection_endpoint": issuer + "/oauth/introspect",
"scopes_supported": s.cfg.DefaultScopes,
"response_types_supported": []string{"code"},
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"code_challenge_methods_supported": []string{"S256"},
"token_endpoint_auth_methods_supported": []string{"none"},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(meta) //nolint:errcheck
}
// --------------------------------------------------------------------------
// RFC 7591 — Dynamic client registration
// --------------------------------------------------------------------------
func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest)
return
}
if len(req.RedirectURIs) == 0 {
writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest)
return
}
grantTypes := req.GrantTypes
if len(grantTypes) == 0 {
grantTypes = []string{"authorization_code"}
}
allowedScopes := req.AllowedScopes
if len(allowedScopes) == 0 {
allowedScopes = s.cfg.DefaultScopes
}
clientID, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
client := &oauthClient{
ClientID: clientID,
RedirectURIs: req.RedirectURIs,
ClientName: req.ClientName,
GrantTypes: grantTypes,
AllowedScopes: allowedScopes,
}
if s.cfg.PersistClients && s.auth != nil {
dbClient := &OAuthServerClient{
ClientID: client.ClientID,
RedirectURIs: client.RedirectURIs,
ClientName: client.ClientName,
GrantTypes: client.GrantTypes,
AllowedScopes: client.AllowedScopes,
}
if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
}
s.mu.Lock()
s.clients[clientID] = client
s.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(client) //nolint:errcheck
}
// --------------------------------------------------------------------------
// Authorization endpoint — GET + POST /oauth/authorize
// --------------------------------------------------------------------------
func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
s.authorizeGet(w, r)
case http.MethodPost:
s.authorizePost(w, r)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// authorizeGet validates the request and either:
// - Redirects to an external provider (if providers are registered)
// - Renders a login form (if the server is its own identity provider)
func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
clientID := q.Get("client_id")
redirectURI := q.Get("redirect_uri")
clientState := q.Get("state")
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
providerName := q.Get("provider")
scopeStr := q.Get("scope")
var scopes []string
if scopeStr != "" {
scopes = strings.Fields(scopeStr)
}
if q.Get("response_type") != "code" {
writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest)
return
}
if codeChallenge == "" {
writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest)
return
}
if codeChallengeMethod != "" && codeChallengeMethod != "S256" {
writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest)
return
}
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
if !ok {
writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest)
return
}
if !oauthSliceContains(client.RedirectURIs, redirectURI) {
writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest)
return
}
// External provider path
if len(s.providers) > 0 {
s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes)
return
}
// Direct login form path (server is its own identity provider)
if s.auth == nil {
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
return
}
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "")
}
// authorizePost handles login form submission for the direct login flow.
func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid form", http.StatusBadRequest)
return
}
clientID := r.FormValue("client_id")
redirectURI := r.FormValue("redirect_uri")
clientState := r.FormValue("client_state")
codeChallenge := r.FormValue("code_challenge")
codeChallengeMethod := r.FormValue("code_challenge_method")
username := r.FormValue("username")
password := r.FormValue("password")
scopeStr := r.FormValue("scope")
var scopes []string
if scopeStr != "" {
scopes = strings.Fields(scopeStr)
}
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) {
http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest)
return
}
if s.auth == nil {
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
return
}
loginResp, err := s.auth.Login(r.Context(), LoginRequest{
Username: username,
Password: password,
})
if err != nil {
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password")
return
}
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
}
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
var provider *externalProvider
if providerName != "" {
for i := range s.providers {
if s.providers[i].providerName == providerName {
provider = &s.providers[i]
break
}
}
if provider == nil {
http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest)
return
}
} else {
provider = &s.providers[0]
}
providerState, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
pending := &pendingAuth{
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
ProviderName: provider.providerName,
ExpiresAt: time.Now().Add(10 * time.Minute),
Scopes: scopes,
}
s.mu.Lock()
s.pending[providerState] = pending
s.mu.Unlock()
authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, authURL, http.StatusFound)
}
// --------------------------------------------------------------------------
// External provider callback — GET {ProviderCallbackPath}
// --------------------------------------------------------------------------
func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
providerState := r.URL.Query().Get("state")
if code == "" {
http.Error(w, "missing code", http.StatusBadRequest)
return
}
s.mu.Lock()
pending, ok := s.pending[providerState]
if ok {
delete(s.pending, providerState)
}
s.mu.Unlock()
if !ok || time.Now().After(pending.ExpiresAt) {
http.Error(w, "invalid or expired state", http.StatusBadRequest)
return
}
provider := s.providerByName(pending.ProviderName)
if provider == nil {
http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError)
return
}
loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken,
pending.ClientID, pending.RedirectURI, pending.ClientState,
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
}
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, refreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
authCode, err := randomOAuthToken()
if err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
pending := &pendingAuth{
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
ProviderName: providerName,
SessionToken: sessionToken,
RefreshToken: refreshToken,
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
Scopes: scopes,
}
if s.cfg.PersistCodes && s.auth != nil {
oauthCode := &OAuthCode{
Code: authCode,
ClientID: clientID,
RedirectURI: redirectURI,
ClientState: clientState,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
SessionToken: sessionToken,
RefreshToken: refreshToken,
Scopes: scopes,
ExpiresAt: pending.ExpiresAt,
}
if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return
}
} else {
s.mu.Lock()
s.codes[authCode] = pending
s.mu.Unlock()
}
redirectURL, err := url.Parse(redirectURI)
if err != nil {
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
return
}
qp := redirectURL.Query()
qp.Set("code", authCode)
if clientState != "" {
qp.Set("state", clientState)
}
redirectURL.RawQuery = qp.Encode()
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// --------------------------------------------------------------------------
// Token endpoint — POST /oauth/token
// --------------------------------------------------------------------------
func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest)
return
}
switch r.FormValue("grant_type") {
case "authorization_code":
s.handleAuthCodeGrant(w, r)
case "refresh_token":
s.handleRefreshGrant(w, r)
default:
writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest)
}
}
func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
code := r.FormValue("code")
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
codeVerifier := r.FormValue("code_verifier")
if code == "" || codeVerifier == "" {
writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest)
return
}
var sessionToken string
var refreshToken string
var scopes []string
if s.cfg.PersistCodes && s.auth != nil {
oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code)
if err != nil {
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
return
}
if oauthCode.ClientID != clientID {
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
return
}
if oauthCode.RedirectURI != redirectURI {
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
return
}
if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) {
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
return
}
sessionToken = oauthCode.SessionToken
refreshToken = oauthCode.RefreshToken
scopes = oauthCode.Scopes
} else {
s.mu.Lock()
pending, ok := s.codes[code]
if ok {
delete(s.codes, code)
}
s.mu.Unlock()
if !ok || time.Now().After(pending.ExpiresAt) {
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
return
}
if pending.ClientID != clientID {
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
return
}
if pending.RedirectURI != redirectURI {
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
return
}
if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) {
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
return
}
sessionToken = pending.SessionToken
refreshToken = pending.RefreshToken
scopes = pending.Scopes
}
s.writeOAuthToken(w, sessionToken, refreshToken, scopes)
}
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
refreshToken := r.FormValue("refresh_token")
providerName := r.FormValue("provider")
if refreshToken == "" {
writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest)
return
}
// Try external providers first, then fall back to DatabaseAuthenticator
provider := s.providerByName(providerName)
if provider != nil {
loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName)
if err != nil {
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
return
}
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
return
}
if s.auth != nil {
loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken)
if err != nil {
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
return
}
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
return
}
writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest)
}
// --------------------------------------------------------------------------
// RFC 7009 — Token revocation
// --------------------------------------------------------------------------
func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusOK)
return
}
token := r.FormValue("token")
if token == "" {
w.WriteHeader(http.StatusOK)
return
}
if s.auth != nil {
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
} else {
// In external-provider-only mode, attempt revocation via the first provider's auth.
s.mu.RLock()
var providerAuth *DatabaseAuthenticator
if len(s.providers) > 0 {
providerAuth = s.providers[0].auth
}
s.mu.RUnlock()
if providerAuth != nil {
providerAuth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
}
}
w.WriteHeader(http.StatusOK)
}
// --------------------------------------------------------------------------
// RFC 7662 — Token introspection
// --------------------------------------------------------------------------
func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
token := r.FormValue("token")
w.Header().Set("Content-Type", "application/json")
if token == "" {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
// Resolve the authenticator to use: prefer the primary auth, then the first provider's auth.
authToUse := s.auth
if authToUse == nil {
s.mu.RLock()
if len(s.providers) > 0 {
authToUse = s.providers[0].auth
}
s.mu.RUnlock()
}
if authToUse == nil {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
info, err := authToUse.OAuthIntrospectToken(r.Context(), token)
if err != nil {
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
return
}
json.NewEncoder(w).Encode(info) //nolint:errcheck
}
// --------------------------------------------------------------------------
// Login form (direct identity provider mode)
// --------------------------------------------------------------------------
func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
errHTML := ""
if errMsg != "" {
errHTML = `<p style="color:red">` + errMsg + `</p>`
}
fmt.Fprintf(w, loginFormHTML,
s.cfg.LoginTitle,
s.cfg.LoginTitle,
errHTML,
clientID,
htmlEscape(redirectURI),
htmlEscape(clientState),
htmlEscape(codeChallenge),
htmlEscape(codeChallengeMethod),
htmlEscape(scope),
)
}
const loginFormHTML = `<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>%s</title>
<style>body{font-family:sans-serif;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#f5f5f5}
.card{background:#fff;padding:2rem;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.15);width:320px}
h2{margin:0 0 1.5rem;font-size:1.25rem}
label{display:block;margin-bottom:.25rem;font-size:.875rem;color:#555}
input[type=text],input[type=password]{width:100%%;box-sizing:border-box;padding:.5rem;border:1px solid #ccc;border-radius:4px;margin-bottom:1rem;font-size:1rem}
button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;border-radius:4px;font-size:1rem;cursor:pointer}
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
</head><body><div class="card">
<h2>%s</h2>%s
<form method="POST" action="authorize">
<input type="hidden" name="client_id" value="%s">
<input type="hidden" name="redirect_uri" value="%s">
<input type="hidden" name="client_state" value="%s">
<input type="hidden" name="code_challenge" value="%s">
<input type="hidden" name="code_challenge_method" value="%s">
<input type="hidden" name="scope" value="%s">
<label>Username</label><input type="text" name="username" autofocus autocomplete="username">
<label>Password</label><input type="password" name="password" autocomplete="current-password">
<button type="submit">Sign in</button>
</form></div></body></html>`
// --------------------------------------------------------------------------
// Helpers
// --------------------------------------------------------------------------
// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled.
func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) {
s.mu.RLock()
c, ok := s.clients[clientID]
s.mu.RUnlock()
if ok {
return c, true
}
if !s.cfg.PersistClients || s.auth == nil {
return nil, false
}
dbClient, err := s.auth.OAuthGetClient(ctx, clientID)
if err != nil {
return nil, false
}
c = &oauthClient{
ClientID: dbClient.ClientID,
RedirectURIs: dbClient.RedirectURIs,
ClientName: dbClient.ClientName,
GrantTypes: dbClient.GrantTypes,
AllowedScopes: dbClient.AllowedScopes,
}
s.mu.Lock()
s.clients[clientID] = c
s.mu.Unlock()
return c, true
}
func (s *OAuthServer) providerByName(name string) *externalProvider {
for i := range s.providers {
if s.providers[i].providerName == name {
return &s.providers[i]
}
}
// If name is empty and only one provider exists, return it
if name == "" && len(s.providers) == 1 {
return &s.providers[0]
}
return nil
}
func validatePKCESHA256(challenge, verifier string) bool {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:]) == challenge
}
func randomOAuthToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func oauthSliceContains(slice []string, s string) bool {
for _, v := range slice {
if v == s {
return true
}
}
return false
}
func (s *OAuthServer) writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
expiresIn := int64(s.cfg.AccessTokenTTL.Seconds())
resp := map[string]interface{}{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": expiresIn,
}
if refreshToken != "" {
resp["refresh_token"] = refreshToken
}
if len(scopes) > 0 {
resp["scope"] = strings.Join(scopes, " ")
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
json.NewEncoder(w).Encode(resp) //nolint:errcheck
}
func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) {
resp := map[string]string{"error": errCode}
if description != "" {
resp["error_description"] = description
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(resp) //nolint:errcheck
}
func htmlEscape(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, `"`, "&#34;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
return s
}

View File

@@ -0,0 +1,204 @@
package security
import (
"context"
"encoding/json"
"fmt"
"time"
)
// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client.
type OAuthServerClient struct {
ClientID string `json:"client_id"`
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
GrantTypes []string `json:"grant_types"`
AllowedScopes []string `json:"allowed_scopes,omitempty"`
}
// OAuthCode is a short-lived authorization code.
type OAuthCode struct {
Code string `json:"code"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
ClientState string `json:"client_state,omitempty"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
SessionToken string `json:"session_token"`
RefreshToken string `json:"refresh_token,omitempty"`
Scopes []string `json:"scopes,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
}
// OAuthTokenInfo is the RFC 7662 token introspection response.
type OAuthTokenInfo struct {
Active bool `json:"active"`
Sub string `json:"sub,omitempty"`
Username string `json:"username,omitempty"`
Email string `json:"email,omitempty"`
UserLevel int `json:"user_level,omitempty"`
Roles []string `json:"roles,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
}
// OAuthRegisterClient persists an OAuth2 client registration.
func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) {
input, err := json.Marshal(client)
if err != nil {
return nil, fmt.Errorf("failed to marshal client: %w", err)
}
var success bool
var errMsg *string
var data []byte
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1::jsonb)
`, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("failed to register client")
}
var result OAuthServerClient
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse registered client: %w", err)
}
return &result, nil
}
// OAuthGetClient retrieves a registered client by ID.
func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("client not found")
}
var result OAuthServerClient
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse client: %w", err)
}
return &result, nil
}
// OAuthSaveCode persists an authorization code.
func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error {
input, err := json.Marshal(code)
if err != nil {
return fmt.Errorf("failed to marshal code: %w", err)
}
var success bool
var errMsg *string
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1::jsonb)
`, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to save code: %w", err)
}
if !success {
if errMsg != nil {
return fmt.Errorf("%s", *errMsg)
}
return fmt.Errorf("failed to save code")
}
return nil
}
// OAuthExchangeCode retrieves and deletes an authorization code (single use).
func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("invalid or expired code")
}
var result OAuthCode
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse code data: %w", err)
}
result.Code = code
return &result, nil
}
// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662).
func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) {
var success bool
var errMsg *string
var data []byte
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM %s($1)
`, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data)
if err != nil {
return nil, fmt.Errorf("failed to introspect token: %w", err)
}
if !success {
if errMsg != nil {
return nil, fmt.Errorf("%s", *errMsg)
}
return nil, fmt.Errorf("introspection failed")
}
var result OAuthTokenInfo
if err := json.Unmarshal(data, &result); err != nil {
return nil, fmt.Errorf("failed to parse token info: %w", err)
}
return &result, nil
}
// OAuthRevokeToken revokes a token by deleting the session (RFC 7009).
func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error {
var success bool
var errMsg *string
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1)
`, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to revoke token: %w", err)
}
if !success {
if errMsg != nil {
return fmt.Errorf("%s", *errMsg)
}
return fmt.Errorf("failed to revoke token")
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
) )
@@ -14,6 +15,8 @@ import (
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabasePasskeyProvider struct { type DatabasePasskeyProvider struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
rpID string // Relying Party ID (domain) rpID string // Relying Party ID (domain)
rpName string // Relying Party display name rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn rpOrigin string // Expected origin for WebAuthn
@@ -33,6 +36,9 @@ type DatabasePasskeyProviderOptions struct {
Timeout int64 Timeout int64
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
SQLNames *SQLNames SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
} }
// NewDatabasePasskeyProvider creates a new database-backed passkey provider // NewDatabasePasskeyProvider creates a new database-backed passkey provider
@@ -45,6 +51,7 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
return &DatabasePasskeyProvider{ return &DatabasePasskeyProvider{
db: db, db: db,
dbFactory: opts.DBFactory,
rpID: opts.RPID, rpID: opts.RPID,
rpName: opts.RPName, rpName: opts.RPName,
rpOrigin: opts.RPOrigin, rpOrigin: opts.RPOrigin,
@@ -53,6 +60,26 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
} }
} }
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabasePasskeyProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
// BeginRegistration creates registration options for a new passkey // BeginRegistration creates registration options for a new passkey
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) { func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
// Generate challenge // Generate challenge
@@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
var credentialID sql.NullInt64 var credentialID sql.NullInt64
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential) query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID) err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err) return nil, fmt.Errorf("failed to store credential: %w", err)
} }
@@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername) query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON) err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
} }
@@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var errorMsg sql.NullString var errorMsg sql.NullString
var credentialJSON sql.NullString var credentialJSON sql.NullString
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential) query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON) return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err) return 0, fmt.Errorf("failed to get credential: %w", err)
} }
@@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var cloneWarning sql.NullBool var cloneWarning sql.NullBool
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter) updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning) err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err) return 0, fmt.Errorf("failed to update counter: %w", err)
} }
@@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
var credentialsJSON sql.NullString var credentialsJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials) query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON) err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err) return nil, fmt.Errorf("failed to get credentials: %w", err)
} }
@@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
var errorMsg sql.NullString var errorMsg sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential) query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg) err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete credential: %w", err) return fmt.Errorf("failed to delete credential: %w", err)
} }
@@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
var errorMsg sql.NullString var errorMsg sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName) query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg) err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("failed to update credential name: %w", err) return fmt.Errorf("failed to update credential name: %w", err)
} }

View File

@@ -64,6 +64,8 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// Also supports passkey authentication configured with WithPasskey() // Also supports passkey authentication configured with WithPasskey()
type DatabaseAuthenticator struct { type DatabaseAuthenticator struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
cache *cache.Cache cache *cache.Cache
cacheTTL time.Duration cacheTTL time.Duration
sqlNames *SQLNames sqlNames *SQLNames
@@ -88,6 +90,9 @@ type DatabaseAuthenticatorOptions struct {
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames(). // SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
// Partial overrides are supported: only set the fields you want to change. // Partial overrides are supported: only set the fields you want to change.
SQLNames *SQLNames SQLNames *SQLNames
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
// If nil, reconnection is disabled.
DBFactory func() (*sql.DB, error)
} }
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator { func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -110,6 +115,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
return &DatabaseAuthenticator{ return &DatabaseAuthenticator{
db: db, db: db,
dbFactory: opts.DBFactory,
cache: cacheInstance, cache: cacheInstance,
cacheTTL: opts.CacheTTL, cacheTTL: opts.CacheTTL,
sqlNames: sqlNames, sqlNames: sqlNames,
@@ -117,6 +123,26 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
} }
} }
func (a *DatabaseAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *DatabaseAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON // Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req) reqJSON, err := json.Marshal(req)
@@ -128,8 +154,16 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
runLoginQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login) query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
} }
@@ -163,7 +197,7 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
var dataJSON sql.NullString var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register) query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("register query failed: %w", err) return nil, fmt.Errorf("register query failed: %w", err)
} }
@@ -196,7 +230,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
var dataJSON sql.NullString var dataJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout) query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
} }
@@ -270,7 +304,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var userJSON sql.NullString var userJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON) err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("session query failed: %w", err) return nil, fmt.Errorf("session query failed: %w", err)
} }
@@ -346,7 +380,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
var updatedUserJSON sql.NullString var updatedUserJSON sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate) query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON) _ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
} }
// RefreshToken implements Refreshable interface // RefreshToken implements Refreshable interface
@@ -357,7 +391,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var userJSON sql.NullString var userJSON sql.NullString
// Get current session to pass to refresh // Get current session to pass to refresh
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session) query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON) err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err) return nil, fmt.Errorf("refresh token query failed: %w", err)
} }
@@ -374,7 +408,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
var newUserJSON sql.NullString var newUserJSON sql.NullString
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken) refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON) err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
if err != nil { if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err) return nil, fmt.Errorf("refresh token generation failed: %w", err)
} }
@@ -406,6 +440,8 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
type JWTAuthenticator struct { type JWTAuthenticator struct {
secretKey []byte secretKey []byte
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames sqlNames *SQLNames
} }
@@ -417,13 +453,47 @@ func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTA
} }
} }
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
a.dbFactory = factory
return a
}
func (a *JWTAuthenticator) getDB() *sql.DB {
a.dbMu.RLock()
defer a.dbMu.RUnlock()
return a.db
}
func (a *JWTAuthenticator) reconnectDB() error {
if a.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := a.dbFactory()
if err != nil {
return err
}
a.dbMu.Lock()
a.db = newDB
a.dbMu.Unlock()
return nil
}
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) { func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
var success bool var success bool
var errorMsg sql.NullString var errorMsg sql.NullString
var userJSON []byte var userJSON []byte
runLoginQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin) query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON) return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
}
err := runLoginQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runLoginQuery()
}
}
if err != nil { if err != nil {
return nil, fmt.Errorf("login query failed: %w", err) return nil, fmt.Errorf("login query failed: %w", err)
} }
@@ -476,7 +546,7 @@ func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error
var errorMsg sql.NullString var errorMsg sql.NullString
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout) query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg) err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil { if err != nil {
return fmt.Errorf("logout query failed: %w", err) return fmt.Errorf("logout query failed: %w", err)
} }
@@ -514,6 +584,8 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseColumnSecurityProvider struct { type DatabaseColumnSecurityProvider struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames sqlNames *SQLNames
} }
@@ -521,6 +593,31 @@ func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *Database
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
} }
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) { func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity var rules []ColumnSecurity
@@ -528,8 +625,16 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
var errorMsg sql.NullString var errorMsg sql.NullString
var rulesJSON []byte var rulesJSON []byte
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity) query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON) return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
}
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err) return nil, fmt.Errorf("failed to load column security: %w", err)
} }
@@ -579,6 +684,8 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults) // Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseRowSecurityProvider struct { type DatabaseRowSecurityProvider struct {
db *sql.DB db *sql.DB
dbMu sync.RWMutex
dbFactory func() (*sql.DB, error)
sqlNames *SQLNames sqlNames *SQLNames
} }
@@ -586,13 +693,45 @@ func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRow
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)} return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
} }
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
p.dbFactory = factory
return p
}
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
p.dbMu.RLock()
defer p.dbMu.RUnlock()
return p.db
}
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
if p.dbFactory == nil {
return fmt.Errorf("no db factory configured for reconnect")
}
newDB, err := p.dbFactory()
if err != nil {
return err
}
p.dbMu.Lock()
p.db = newDB
p.dbMu.Unlock()
return nil
}
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) { func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string var template string
var hasBlock bool var hasBlock bool
runQuery := func() error {
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity) query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock) }
err := runQuery()
if isDBClosed(err) {
if reconnErr := p.reconnectDB(); reconnErr == nil {
err = runQuery()
}
}
if err != nil { if err != nil {
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err) return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
} }
@@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
// Helper functions // Helper functions
// ================ // ================
// isDBClosed reports whether err indicates the *sql.DB has been closed.
func isDBClosed(err error) bool {
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
}
func parseRoles(rolesStr string) []string { func parseRoles(rolesStr string) []string {
if rolesStr == "" { if rolesStr == "" {
return []string{} return []string{}
@@ -780,8 +924,16 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
var errorMsg sql.NullString var errorMsg sql.NullString
var dataJSON sql.NullString var dataJSON sql.NullString
runPasskeyQuery := func() error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin) query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON) return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
}
err = runPasskeyQuery()
if isDBClosed(err) {
if reconnErr := a.reconnectDB(); reconnErr == nil {
err = runPasskeyQuery()
}
}
if err != nil { if err != nil {
return nil, fmt.Errorf("passkey login query failed: %w", err) return nil, fmt.Errorf("passkey login query failed: %w", err)
} }

View File

@@ -54,6 +54,13 @@ type SQLNames struct {
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken" OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
OAuthGetUser string // default: "resolvespec_oauth_getuser" OAuthGetUser string // default: "resolvespec_oauth_getuser"
// OAuth2 server procedures (OAuthServer persistence)
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
OAuthGetClient string // default: "resolvespec_oauth_get_client"
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
OAuthRevoke string // default: "resolvespec_oauth_revoke"
} }
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values. // DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
@@ -93,6 +100,13 @@ func DefaultSQLNames() *SQLNames {
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken", OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken", OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
OAuthGetUser: "resolvespec_oauth_getuser", OAuthGetUser: "resolvespec_oauth_getuser",
OAuthRegisterClient: "resolvespec_oauth_register_client",
OAuthGetClient: "resolvespec_oauth_get_client",
OAuthSaveCode: "resolvespec_oauth_save_code",
OAuthExchangeCode: "resolvespec_oauth_exchange_code",
OAuthIntrospect: "resolvespec_oauth_introspect",
OAuthRevoke: "resolvespec_oauth_revoke",
} }
} }
@@ -191,6 +205,24 @@ func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override.OAuthGetUser != "" { if override.OAuthGetUser != "" {
merged.OAuthGetUser = override.OAuthGetUser merged.OAuthGetUser = override.OAuthGetUser
} }
if override.OAuthRegisterClient != "" {
merged.OAuthRegisterClient = override.OAuthRegisterClient
}
if override.OAuthGetClient != "" {
merged.OAuthGetClient = override.OAuthGetClient
}
if override.OAuthSaveCode != "" {
merged.OAuthSaveCode = override.OAuthSaveCode
}
if override.OAuthExchangeCode != "" {
merged.OAuthExchangeCode = override.OAuthExchangeCode
}
if override.OAuthIntrospect != "" {
merged.OAuthIntrospect = override.OAuthIntrospect
}
if override.OAuthRevoke != "" {
merged.OAuthRevoke = override.OAuthRevoke
}
return &merged return &merged
} }

View File

@@ -3,6 +3,7 @@ package server_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"time" "time"
@@ -29,18 +30,18 @@ func ExampleManager_basic() {
GZIP: true, // Enable GZIP compression GZIP: true, // Enable GZIP compression
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Server is now running... // Server is now running...
// When done, stop gracefully // When done, stop gracefully
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -61,7 +62,7 @@ func ExampleManager_https() {
SSLKey: "/path/to/key.pem", SSLKey: "/path/to/key.pem",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 2: Self-signed certificate (for development) // Option 2: Self-signed certificate (for development)
@@ -73,7 +74,7 @@ func ExampleManager_https() {
SelfSignedSSL: true, SelfSignedSSL: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 3: Let's Encrypt / AutoTLS (for production) // Option 3: Let's Encrypt / AutoTLS (for production)
@@ -88,12 +89,12 @@ func ExampleManager_https() {
AutoTLSCacheDir: "./certs-cache", AutoTLSCacheDir: "./certs-cache",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Cleanup // Cleanup
@@ -136,7 +137,7 @@ func ExampleManager_gracefulShutdown() {
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start servers and block until shutdown signal (SIGINT/SIGTERM) // Start servers and block until shutdown signal (SIGINT/SIGTERM)
@@ -164,7 +165,7 @@ func ExampleManager_healthChecks() {
Handler: mux, Handler: mux,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Add health and readiness endpoints // Add health and readiness endpoints
@@ -173,7 +174,7 @@ func ExampleManager_healthChecks() {
// Start the server // Start the server
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Health check returns: // Health check returns:
@@ -204,7 +205,7 @@ func ExampleManager_multipleServers() {
GZIP: true, GZIP: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Admin API server (different port) // Admin API server (different port)
@@ -218,7 +219,7 @@ func ExampleManager_multipleServers() {
Handler: adminHandler, Handler: adminHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Metrics server (internal only) // Metrics server (internal only)
@@ -232,18 +233,18 @@ func ExampleManager_multipleServers() {
Handler: metricsHandler, Handler: metricsHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers at once // Start all servers at once
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Get specific server instance // Get specific server instance
publicInstance, err := mgr.Get("public-api") publicInstance, err := mgr.Get("public-api")
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
fmt.Printf("Public API running on: %s\n", publicInstance.Addr()) fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
@@ -253,7 +254,7 @@ func ExampleManager_multipleServers() {
// Stop all servers gracefully (in parallel) // Stop all servers gracefully (in parallel)
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -273,11 +274,11 @@ func ExampleManager_monitoring() {
Handler: handler, Handler: handler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Check server status // Check server status