mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-10 18:03:57 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dfb63c3328 | ||
|
|
e8d0ab28c3 | ||
|
|
4fc25c60ae | ||
|
|
16a960d973 | ||
|
|
2afee9d238 | ||
|
|
1e89124c97 | ||
|
|
ca0545e144 | ||
|
|
850ad2b2ab | ||
|
|
2a2e33da0c | ||
|
|
17808a8121 | ||
|
|
134ff85c59 | ||
|
|
bacddc58a6 | ||
|
|
f1ad83d966 | ||
|
|
79a3912f93 | ||
| 6502b55797 | |||
|
|
a9bf08f58b |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ test.db
|
||||
tests/data/
|
||||
node_modules/
|
||||
resolvespec-js/dist/
|
||||
.codex
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
@@ -94,22 +95,57 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
db *bun.DB
|
||||
driverName string
|
||||
db *bun.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*bun.DB, error)
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
// NewBunAdapter creates a new Bun adapter
|
||||
func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
adapter := &BunAdapter{db: db}
|
||||
adapter := &BunAdapter{db: db, metricsEnabled: true}
|
||||
// Initialize driver name
|
||||
adapter.driverName = adapter.DriverName()
|
||||
return adapter
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
|
||||
b.dbFactory = factory
|
||||
return b
|
||||
}
|
||||
|
||||
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||
func (b *BunAdapter) SetMetricsEnabled(enabled bool) *BunAdapter {
|
||||
b.metricsEnabled = enabled
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunAdapter) getDB() *bun.DB {
|
||||
b.dbMu.RLock()
|
||||
defer b.dbMu.RUnlock()
|
||||
return b.db
|
||||
}
|
||||
|
||||
func (b *BunAdapter) reconnectDB() error {
|
||||
if b.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := b.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.dbMu.Lock()
|
||||
b.db = newDB
|
||||
b.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (b *BunAdapter) EnableQueryDebug() {
|
||||
b.db.AddQueryHook(&QueryDebugHook{})
|
||||
b.getDB().AddQueryHook(&QueryDebugHook{})
|
||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||
}
|
||||
|
||||
@@ -130,22 +166,23 @@ func (b *BunAdapter) DisableQueryDebug() {
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.db.NewSelect(),
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
query: b.getDB().NewSelect(),
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
metricsEnabled: b.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||
return &BunInsertQuery{query: b.db.NewInsert()}
|
||||
return &BunInsertQuery{query: b.getDB().NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &BunUpdateQuery{query: b.db.NewUpdate()}
|
||||
return &BunUpdateQuery{query: b.getDB().NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
||||
return &BunDeleteQuery{query: b.getDB().NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -154,7 +191,17 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.db.ExecContext(ctx, query, args...)
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -164,16 +211,29 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
|
||||
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
tx, err = b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// For Bun, we'll return a special wrapper that holds the transaction
|
||||
return &BunTxAdapter{tx: tx, driverName: b.driverName}, nil
|
||||
return &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}, nil
|
||||
}
|
||||
|
||||
func (b *BunAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -194,15 +254,23 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create adapter with transaction
|
||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
run := func() error {
|
||||
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||
return b.db
|
||||
return b.getDB()
|
||||
}
|
||||
|
||||
func (b *BunAdapter) DriverName() string {
|
||||
@@ -226,25 +294,24 @@ type BunSelectQuery struct {
|
||||
hasModel bool // Track if Model() was called
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
entity string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
skipAutoDetect bool // Skip auto-detection to prevent circular calls
|
||||
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true // Mark that we have a model
|
||||
|
||||
// Try to get table name from model if it implements TableNameProvider
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
b.schema, b.tableName = parseTableName(fullTableName, b.driverName)
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
b.tableAlias = provider.TableAlias()
|
||||
@@ -258,6 +325,9 @@ func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -563,9 +633,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
||||
|
||||
// Wrap the incoming *bun.SelectQuery in our adapter
|
||||
wrapper := &BunSelectQuery{
|
||||
query: sq,
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
query: sq,
|
||||
db: b.db,
|
||||
driverName: b.driverName,
|
||||
metricsEnabled: b.metricsEnabled,
|
||||
}
|
||||
|
||||
// Try to extract table name and alias from the preload model
|
||||
@@ -816,7 +887,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
|
||||
|
||||
// Apply user's functions (if any)
|
||||
if isLast && len(applyFuncs) > 0 {
|
||||
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName}
|
||||
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
for _, fn := range applyFuncs {
|
||||
if fn != nil {
|
||||
wrapper = fn(wrapper).(*BunSelectQuery)
|
||||
@@ -1168,27 +1239,28 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
if dest == nil {
|
||||
return fmt.Errorf("destination cannot be nil")
|
||||
err = fmt.Errorf("destination cannot be nil")
|
||||
return err
|
||||
}
|
||||
|
||||
err = b.query.Scan(ctx, dest)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Enhanced panic recovery with model information
|
||||
@@ -1198,7 +1270,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
modelValue := model.Value()
|
||||
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
||||
|
||||
// Try to get the model's underlying struct type
|
||||
v := reflect.ValueOf(modelValue)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
@@ -1218,9 +1289,11 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
|
||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
if b.query.GetModel() == nil {
|
||||
return fmt.Errorf("model is nil")
|
||||
err = fmt.Errorf("model is nil")
|
||||
return err
|
||||
}
|
||||
|
||||
// Optional: Enable detailed field-level debugging (set to true to debug)
|
||||
@@ -1236,7 +1309,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
@@ -1245,7 +1317,7 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
// After main query, load custom preloads using separate queries
|
||||
if len(b.customPreloads) > 0 {
|
||||
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
|
||||
if err := b.loadCustomPreloads(ctx); err != nil {
|
||||
if err = b.loadCustomPreloads(ctx); err != nil {
|
||||
logger.Error("Failed to load custom preloads: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -1255,21 +1327,22 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
// If Model() was set, use bun's native Count() which works properly
|
||||
if b.hasModel {
|
||||
count, err := b.query.Count(ctx)
|
||||
count, err = b.query.Count(ctx) // assign to named returns, not shadow vars
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||
@@ -1279,39 +1352,49 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
ColumnExpr("COUNT(*)")
|
||||
err = countQuery.Scan(ctx, &count)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := countQuery.String()
|
||||
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
return
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
}()
|
||||
exists, err = b.query.Exists(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return exists, err
|
||||
return
|
||||
}
|
||||
|
||||
// BunInsertQuery implements InsertQuery for Bun
|
||||
type BunInsertQuery struct {
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
hasModel bool
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
hasModel bool
|
||||
driverName string
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1320,6 +1403,10 @@ func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||
return b
|
||||
}
|
||||
b.query = b.query.Table(table)
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1349,6 +1436,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
if len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
// If no model was set, use the values map as the model
|
||||
@@ -1362,29 +1450,45 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
}
|
||||
}
|
||||
result, err := b.query.Exec(ctx)
|
||||
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
// BunUpdateQuery implements UpdateQuery for Bun
|
||||
type BunUpdateQuery struct {
|
||||
query *bun.UpdateQuery
|
||||
model interface{}
|
||||
query *bun.UpdateQuery
|
||||
model interface{}
|
||||
driverName string
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.model = model
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
b.query = b.query.Table(table)
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
if b.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
b.model = model
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
}
|
||||
}
|
||||
return b
|
||||
@@ -1435,27 +1539,43 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "UPDATE", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
// BunDeleteQuery implements DeleteQuery for Bun
|
||||
type BunDeleteQuery struct {
|
||||
query *bun.DeleteQuery
|
||||
query *bun.DeleteQuery
|
||||
driverName string
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||
if b.tableName == "" {
|
||||
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||
}
|
||||
b.entity = entityNameFromModel(model, b.tableName)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
b.query = b.query.Table(table)
|
||||
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||
if b.entity == "" {
|
||||
b.entity = cleanMetricIdentifier(b.tableName)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -1470,12 +1590,14 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(b.metricsEnabled, "DELETE", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -1501,37 +1623,46 @@ func (b *BunResult) LastInsertId() (int64, error) {
|
||||
|
||||
// BunTxAdapter wraps a Bun transaction to implement the Database interface
|
||||
type BunTxAdapter struct {
|
||||
tx bun.Tx
|
||||
driverName string
|
||||
tx bun.Tx
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.tx.NewSelect(),
|
||||
db: b.tx,
|
||||
driverName: b.driverName,
|
||||
query: b.tx.NewSelect(),
|
||||
db: b.tx,
|
||||
driverName: b.driverName,
|
||||
metricsEnabled: b.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
|
||||
return &BunInsertQuery{query: b.tx.NewInsert()}
|
||||
return &BunInsertQuery{query: b.tx.NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &BunUpdateQuery{query: b.tx.NewUpdate()}
|
||||
return &BunUpdateQuery{query: b.tx.NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.tx.NewDelete()}
|
||||
return &BunDeleteQuery{query: b.tx.NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
result, err := b.tx.ExecContext(ctx, query, args...)
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return b.tx.NewRaw(query, args...).Scan(ctx, dest)
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||
err := b.tx.NewRaw(query, args...).Scan(ctx, dest)
|
||||
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -15,22 +17,93 @@ import (
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
type GormAdapter struct {
|
||||
db *gorm.DB
|
||||
driverName string
|
||||
dbMu sync.RWMutex
|
||||
db *gorm.DB
|
||||
dbFactory func() (*gorm.DB, error)
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
// NewGormAdapter creates a new GORM adapter
|
||||
func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
adapter := &GormAdapter{db: db}
|
||||
adapter := &GormAdapter{db: db, metricsEnabled: true}
|
||||
// Initialize driver name
|
||||
adapter.driverName = adapter.DriverName()
|
||||
return adapter
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
|
||||
g.dbFactory = factory
|
||||
return g
|
||||
}
|
||||
|
||||
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||
func (g *GormAdapter) SetMetricsEnabled(enabled bool) *GormAdapter {
|
||||
g.metricsEnabled = enabled
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormAdapter) getDB() *gorm.DB {
|
||||
g.dbMu.RLock()
|
||||
defer g.dbMu.RUnlock()
|
||||
return g.db
|
||||
}
|
||||
|
||||
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
|
||||
if g.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
|
||||
freshDB, err := g.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.dbMu.Lock()
|
||||
previous := g.db
|
||||
g.db = freshDB
|
||||
g.driverName = normalizeGormDriverName(freshDB)
|
||||
g.dbMu.Unlock()
|
||||
|
||||
if previous != nil {
|
||||
syncGormConnPool(previous, freshDB)
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
if target != nil && target != previous {
|
||||
syncGormConnPool(target, freshDB)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncGormConnPool(target, fresh *gorm.DB) {
|
||||
if target == nil || fresh == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if target.Config != nil && fresh.Config != nil {
|
||||
target.ConnPool = fresh.ConnPool
|
||||
}
|
||||
|
||||
if target.Statement != nil {
|
||||
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
|
||||
target.Statement.ConnPool = fresh.Statement.ConnPool
|
||||
} else if fresh.Config != nil {
|
||||
target.Statement.ConnPool = fresh.ConnPool
|
||||
}
|
||||
target.Statement.DB = target
|
||||
}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||
g.dbMu.Lock()
|
||||
g.db = g.db.Debug()
|
||||
g.dbMu.Unlock()
|
||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||
return g
|
||||
}
|
||||
@@ -44,19 +117,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db, driverName: g.driverName}
|
||||
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||
return &GormInsertQuery{db: g.db}
|
||||
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &GormUpdateQuery{db: g.db}
|
||||
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||
return &GormDeleteQuery{db: g.db}
|
||||
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
@@ -65,7 +138,18 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Exec(query, args...)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
@@ -75,15 +159,35 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
|
||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx := g.db.WithContext(ctx).Begin()
|
||||
run := func() *gorm.DB {
|
||||
return g.getDB().WithContext(ctx).Begin()
|
||||
}
|
||||
tx := run()
|
||||
if isDBClosed(tx.Error) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
tx = run()
|
||||
}
|
||||
}
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return &GormAdapter{db: tx, driverName: g.driverName}, nil
|
||||
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}, nil
|
||||
}
|
||||
|
||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -100,24 +204,37 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, driverName: g.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
run := func() error {
|
||||
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
return g.getDB()
|
||||
}
|
||||
|
||||
func (g *GormAdapter) DriverName() string {
|
||||
if g.db.Dialector == nil {
|
||||
return normalizeGormDriverName(g.getDB())
|
||||
}
|
||||
|
||||
func normalizeGormDriverName(db *gorm.DB) string {
|
||||
if db == nil || db.Dialector == nil {
|
||||
return ""
|
||||
}
|
||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||
switch name := g.db.Name(); name {
|
||||
switch name := db.Name(); name {
|
||||
case "sqlserver":
|
||||
return "mssql"
|
||||
case "sqlite3":
|
||||
@@ -130,24 +247,21 @@ func (g *GormAdapter) DriverName() string {
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
entity string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
g.db = g.db.Model(model)
|
||||
|
||||
// Try to get table name from model if it implements TableNameProvider
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
|
||||
}
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
g.tableAlias = provider.TableAlias()
|
||||
@@ -161,6 +275,9 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
|
||||
return g
|
||||
}
|
||||
@@ -346,8 +463,10 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
driverName: g.driverName,
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
metricsEnabled: g.metricsEnabled,
|
||||
}
|
||||
|
||||
current := common.SelectQuery(wrapper)
|
||||
@@ -385,9 +504,11 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
reconnect: g.reconnect,
|
||||
driverName: g.driverName,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
metricsEnabled: g.metricsEnabled,
|
||||
}
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
@@ -444,7 +565,16 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
err = g.db.WithContext(ctx).Find(dest).Error
|
||||
startedAt := time.Now()
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -452,6 +582,7 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
})
|
||||
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -464,7 +595,16 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
startedAt := time.Now()
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -472,6 +612,7 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
})
|
||||
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -482,8 +623,17 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Count(&count64).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -491,6 +641,7 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
})
|
||||
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "COUNT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return int(count64), err
|
||||
}
|
||||
|
||||
@@ -501,8 +652,17 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
var count int64
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
run := func() error {
|
||||
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
}
|
||||
err = run()
|
||||
if isDBClosed(err) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -510,24 +670,37 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
})
|
||||
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "EXISTS", g.schema, g.entity, g.tableName, startedAt, err)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GormInsertQuery implements InsertQuery for GORM
|
||||
type GormInsertQuery struct {
|
||||
db *gorm.DB
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
g.model = model
|
||||
g.db = g.db.Model(model)
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Table(table string) common.InsertQuery {
|
||||
g.db = g.db.Table(table)
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -555,38 +728,60 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
var result *gorm.DB
|
||||
switch {
|
||||
case g.model != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.model)
|
||||
case g.values != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.values)
|
||||
default:
|
||||
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
startedAt := time.Now()
|
||||
run := func() *gorm.DB {
|
||||
switch {
|
||||
case g.model != nil:
|
||||
return g.db.WithContext(ctx).Create(g.model)
|
||||
case g.values != nil:
|
||||
return g.db.WithContext(ctx).Create(g.values)
|
||||
default:
|
||||
return g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
}
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "INSERT", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
// GormUpdateQuery implements UpdateQuery for GORM
|
||||
type GormUpdateQuery struct {
|
||||
db *gorm.DB
|
||||
model interface{}
|
||||
updates interface{}
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
updates interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
g.model = model
|
||||
g.db = g.db.Model(model)
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
g.db = g.db.Table(table)
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
if g.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
g.model = model
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
}
|
||||
}
|
||||
return g
|
||||
@@ -647,7 +842,16 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
startedAt := time.Now()
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Updates(g.updates)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -655,23 +859,36 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
})
|
||||
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "UPDATE", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
// GormDeleteQuery implements DeleteQuery for GORM
|
||||
type GormDeleteQuery struct {
|
||||
db *gorm.DB
|
||||
model interface{}
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
g.model = model
|
||||
g.db = g.db.Model(model)
|
||||
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||
g.entity = entityNameFromModel(model, g.tableName)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
g.db = g.db.Table(table)
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
if g.entity == "" {
|
||||
g.entity = cleanMetricIdentifier(g.tableName)
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -686,7 +903,16 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
startedAt := time.Now()
|
||||
run := func() *gorm.DB {
|
||||
return g.db.WithContext(ctx).Delete(g.model)
|
||||
}
|
||||
result := run()
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = run()
|
||||
}
|
||||
}
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@@ -694,6 +920,7 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
})
|
||||
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
recordQueryMetrics(g.metricsEnabled, "DELETE", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -16,8 +19,11 @@ import (
|
||||
// PgSQLAdapter adapts standard database/sql to work with our Database interface
|
||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||
type PgSQLAdapter struct {
|
||||
db *sql.DB
|
||||
driverName string
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
|
||||
@@ -28,7 +34,43 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
||||
if len(driverName) > 0 && driverName[0] != "" {
|
||||
name = driverName[0]
|
||||
}
|
||||
return &PgSQLAdapter{db: db, driverName: name}
|
||||
return &PgSQLAdapter{db: db, driverName: name, metricsEnabled: true}
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||
func (p *PgSQLAdapter) SetMetricsEnabled(enabled bool) *PgSQLAdapter {
|
||||
p.metricsEnabled = enabled
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging for development
|
||||
@@ -38,37 +80,41 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
|
||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
db: p.getDB(),
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,12 +124,23 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
|
||||
err = logger.HandlePanic("PgSQLAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
||||
result, err := p.db.ExecContext(ctx, query, args...)
|
||||
var result sql.Result
|
||||
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Exec failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return nil, err
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
|
||||
return &PgSQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
@@ -93,23 +150,35 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
||||
err = logger.HandlePanic("PgSQLAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
||||
rows, err := p.db.QueryContext(ctx, query, args...)
|
||||
var rows *sql.Rows
|
||||
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
|
||||
err = run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Query failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows, dest)
|
||||
err = scanRows(rows, dest)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
|
||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -127,12 +196,12 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
}
|
||||
}()
|
||||
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
|
||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
@@ -175,34 +244,34 @@ type relationMetadata struct {
|
||||
|
||||
// PgSQLSelectQuery implements SelectQuery for PostgreSQL
|
||||
type PgSQLSelectQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
model interface{}
|
||||
tableName string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
columns []string
|
||||
columnExprs []string
|
||||
whereClauses []string
|
||||
orClauses []string
|
||||
joins []string
|
||||
orderBy []string
|
||||
groupBy []string
|
||||
havingClauses []string
|
||||
limit int
|
||||
offset int
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
preloads []preloadConfig
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
model interface{}
|
||||
entity string
|
||||
tableName string
|
||||
schema string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
columns []string
|
||||
columnExprs []string
|
||||
whereClauses []string
|
||||
orClauses []string
|
||||
joins []string
|
||||
orderBy []string
|
||||
groupBy []string
|
||||
havingClauses []string
|
||||
limit int
|
||||
offset int
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
preloads []preloadConfig
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
p.tableAlias = provider.TableAlias()
|
||||
}
|
||||
@@ -211,7 +280,10 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
|
||||
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -421,6 +493,7 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
|
||||
err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
|
||||
// Apply preloads that use JOINs
|
||||
p.applyJoinPreloads()
|
||||
@@ -437,17 +510,21 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
|
||||
|
||||
if err != nil {
|
||||
logger.Error("PgSQL SELECT failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
err = scanRows(rows, dest)
|
||||
if err != nil {
|
||||
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply preloads that use separate queries
|
||||
return p.applySubqueryPreloads(ctx, dest)
|
||||
err = p.applySubqueryPreloads(ctx, dest)
|
||||
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
||||
@@ -457,15 +534,8 @@ func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
||||
return p.Scan(ctx, p.model)
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
|
||||
// Build a COUNT query
|
||||
// countInternal executes the COUNT query and returns the result without recording metrics.
|
||||
func (p *PgSQLSelectQuery) countInternal(ctx context.Context) (int, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("SELECT COUNT(*) FROM ")
|
||||
sb.WriteString(p.tableName)
|
||||
@@ -499,10 +569,26 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
row = p.db.QueryRowContext(ctx, query, p.args...)
|
||||
}
|
||||
|
||||
err = row.Scan(&count)
|
||||
var count int
|
||||
if err := row.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
count, err = p.countInternal(ctx)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL COUNT failed: %v", err)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "COUNT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -513,27 +599,32 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
|
||||
count, err := p.Count(ctx)
|
||||
startedAt := time.Now()
|
||||
count, err := p.countInternal(ctx)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL EXISTS failed: %v", err)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "EXISTS", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
|
||||
type PgSQLInsertQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
values map[string]interface{}
|
||||
returning []string
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
values map[string]interface{}
|
||||
valueOrder []string
|
||||
returning []string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
// Extract values from model using reflection
|
||||
// This is a simplified implementation
|
||||
return p
|
||||
@@ -541,11 +632,17 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
|
||||
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
|
||||
if _, exists := p.values[column]; !exists {
|
||||
p.valueOrder = append(p.valueOrder, column)
|
||||
}
|
||||
p.values[column] = value
|
||||
return p
|
||||
}
|
||||
@@ -561,25 +658,27 @@ func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
}()
|
||||
|
||||
if len(p.values) == 0 {
|
||||
return nil, fmt.Errorf("no values to insert")
|
||||
err = fmt.Errorf("no values to insert")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(p.values))
|
||||
placeholders := make([]string, 0, len(p.values))
|
||||
args := make([]interface{}, 0, len(p.values))
|
||||
|
||||
i := 1
|
||||
for col, val := range p.values {
|
||||
for _, col := range p.valueOrder {
|
||||
columns = append(columns, col)
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
|
||||
args = append(args, val)
|
||||
args = append(args, p.values[col])
|
||||
i++
|
||||
}
|
||||
|
||||
@@ -611,35 +710,40 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
|
||||
|
||||
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
|
||||
type PgSQLUpdateQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
model interface{}
|
||||
sets map[string]interface{}
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
returning []string
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
model interface{}
|
||||
sets map[string]interface{}
|
||||
setOrder []string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
returning []string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
if p.model == nil {
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
p.model = model
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
}
|
||||
}
|
||||
return p
|
||||
@@ -649,6 +753,9 @@ func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQu
|
||||
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
||||
return p
|
||||
}
|
||||
if _, exists := p.sets[column]; !exists {
|
||||
p.setOrder = append(p.setOrder, column)
|
||||
}
|
||||
p.sets[column] = value
|
||||
return p
|
||||
}
|
||||
@@ -659,13 +766,23 @@ func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQu
|
||||
pkName = reflection.GetPrimaryKeyName(p.model)
|
||||
}
|
||||
|
||||
for column, value := range values {
|
||||
orderedColumns := make([]string, 0, len(values))
|
||||
for column := range values {
|
||||
orderedColumns = append(orderedColumns, column)
|
||||
}
|
||||
sort.Strings(orderedColumns)
|
||||
|
||||
for _, column := range orderedColumns {
|
||||
value := values[column]
|
||||
if pkName != "" && column == pkName {
|
||||
continue
|
||||
}
|
||||
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
||||
continue
|
||||
}
|
||||
if _, exists := p.sets[column]; !exists {
|
||||
p.setOrder = append(p.setOrder, column)
|
||||
}
|
||||
p.sets[column] = value
|
||||
}
|
||||
return p
|
||||
@@ -694,24 +811,26 @@ func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) strin
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
}()
|
||||
|
||||
if len(p.sets) == 0 {
|
||||
return nil, fmt.Errorf("no values to update")
|
||||
err = fmt.Errorf("no values to update")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setClauses := make([]string, 0, len(p.sets))
|
||||
setArgs := make([]interface{}, 0, len(p.sets))
|
||||
|
||||
// SET parameters start at $1
|
||||
i := 1
|
||||
for col, val := range p.sets {
|
||||
for _, col := range p.setOrder {
|
||||
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i))
|
||||
setArgs = append(setArgs, val)
|
||||
setArgs = append(setArgs, p.sets[col])
|
||||
i++
|
||||
}
|
||||
|
||||
@@ -773,27 +892,30 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
|
||||
|
||||
// PgSQLDeleteQuery implements DeleteQuery for PostgreSQL
|
||||
type PgSQLDeleteQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||
p.entity = entityNameFromModel(model, p.tableName)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.entity == "" {
|
||||
p.entity = cleanMetricIdentifier(p.tableName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -815,10 +937,12 @@ func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) strin
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
}()
|
||||
|
||||
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
|
||||
@@ -866,66 +990,80 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
|
||||
|
||||
// PgSQLTxAdapter wraps a PostgreSQL transaction
|
||||
type PgSQLTxAdapter struct {
|
||||
tx *sql.Tx
|
||||
driverName string
|
||||
tx *sql.Tx
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
metricsEnabled: p.metricsEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)
|
||||
result, err := p.tx.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Tx Exec failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return nil, err
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
|
||||
return &PgSQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
startedAt := time.Now()
|
||||
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||
logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)
|
||||
rows, err := p.tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
logger.Error("PgSQL Tx Query failed: %v", err)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanRows(rows, dest)
|
||||
err = scanRows(rows, dest)
|
||||
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
|
||||
206
pkg/common/adapters/database/query_metrics.go
Normal file
206
pkg/common/adapters/database/query_metrics.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
func recordQueryMetrics(enabled bool, operation, schema, entity, table string, startedAt time.Time, err error) {
|
||||
if !enabled {
|
||||
return
|
||||
}
|
||||
|
||||
metrics.GetProvider().RecordDBQuery(
|
||||
normalizeMetricOperation(operation),
|
||||
normalizeMetricSchema(schema),
|
||||
normalizeMetricEntity(entity, table),
|
||||
normalizeMetricTable(table),
|
||||
time.Since(startedAt),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
func normalizeMetricOperation(operation string) string {
|
||||
operation = strings.ToUpper(strings.TrimSpace(operation))
|
||||
if operation == "" {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return operation
|
||||
}
|
||||
|
||||
func normalizeMetricSchema(schema string) string {
|
||||
schema = cleanMetricIdentifier(schema)
|
||||
if schema == "" {
|
||||
return "default"
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func normalizeMetricEntity(entity, table string) string {
|
||||
entity = cleanMetricIdentifier(entity)
|
||||
if entity != "" {
|
||||
return entity
|
||||
}
|
||||
|
||||
table = cleanMetricIdentifier(table)
|
||||
if table != "" {
|
||||
return table
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func normalizeMetricTable(table string) string {
|
||||
table = cleanMetricIdentifier(table)
|
||||
if table == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func entityNameFromModel(model interface{}, table string) string {
|
||||
if model == nil {
|
||||
return cleanMetricIdentifier(table)
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil {
|
||||
return cleanMetricIdentifier(table)
|
||||
}
|
||||
|
||||
if modelType.Kind() == reflect.Struct && modelType.Name() != "" {
|
||||
return reflection.ToSnakeCase(modelType.Name())
|
||||
}
|
||||
|
||||
return cleanMetricIdentifier(table)
|
||||
}
|
||||
|
||||
func schemaAndTableFromModel(model interface{}, driverName string) (schema, table string) {
|
||||
provider, ok := tableNameProviderFromModel(model)
|
||||
if !ok {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
return parseTableName(provider.TableName(), driverName)
|
||||
}
|
||||
|
||||
// tableNameProviderType is cached to avoid repeated reflection on every call.
|
||||
var tableNameProviderType = reflect.TypeOf((*common.TableNameProvider)(nil)).Elem()
|
||||
|
||||
func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bool) {
|
||||
if model == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
return provider, true
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check whether *T implements TableNameProvider before allocating.
|
||||
ptrType := reflect.PointerTo(modelType)
|
||||
if !ptrType.Implements(tableNameProviderType) && !modelType.Implements(tableNameProviderType) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
modelValue := reflect.New(modelType)
|
||||
if provider, ok := modelValue.Interface().(common.TableNameProvider); ok {
|
||||
return provider, true
|
||||
}
|
||||
|
||||
if provider, ok := modelValue.Elem().Interface().(common.TableNameProvider); ok {
|
||||
return provider, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func metricTargetFromRawQuery(query, driverName string) (operation, schema, entity, table string) {
|
||||
operation = normalizeMetricOperation(firstQueryKeyword(query))
|
||||
tableRef := tableFromRawQuery(query, operation)
|
||||
if tableRef == "" {
|
||||
return operation, "", "unknown", "unknown"
|
||||
}
|
||||
|
||||
schema, table = parseTableName(tableRef, driverName)
|
||||
entity = cleanMetricIdentifier(table)
|
||||
return operation, schema, entity, table
|
||||
}
|
||||
|
||||
func firstQueryKeyword(query string) string {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
fields := strings.Fields(query)
|
||||
if len(fields) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fields[0]
|
||||
}
|
||||
|
||||
func tableFromRawQuery(query, operation string) string {
|
||||
tokens := tokenizeQuery(query)
|
||||
if len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch operation {
|
||||
case "SELECT":
|
||||
return tokenAfter(tokens, "FROM")
|
||||
case "INSERT":
|
||||
return tokenAfter(tokens, "INTO")
|
||||
case "UPDATE":
|
||||
return tokenAfter(tokens, "UPDATE")
|
||||
case "DELETE":
|
||||
return tokenAfter(tokens, "FROM")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func tokenAfter(tokens []string, keyword string) string {
|
||||
for idx, token := range tokens {
|
||||
if strings.EqualFold(token, keyword) && idx+1 < len(tokens) {
|
||||
return cleanMetricIdentifier(tokens[idx+1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func tokenizeQuery(query string) []string {
|
||||
replacer := strings.NewReplacer(
|
||||
"\n", " ",
|
||||
"\t", " ",
|
||||
"(", " ",
|
||||
")", " ",
|
||||
",", " ",
|
||||
)
|
||||
return strings.Fields(replacer.Replace(query))
|
||||
}
|
||||
|
||||
func cleanMetricIdentifier(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
value = strings.Trim(value, "\"'`[]")
|
||||
value = strings.TrimRight(value, ";")
|
||||
return value
|
||||
}
|
||||
306
pkg/common/adapters/database/query_metrics_test.go
Normal file
306
pkg/common/adapters/database/query_metrics_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
"github.com/uptrace/bun/driver/sqliteshim"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
)
|
||||
|
||||
type queryMetricCall struct {
|
||||
operation string
|
||||
schema string
|
||||
entity string
|
||||
table string
|
||||
}
|
||||
|
||||
type capturingMetricsProvider struct {
|
||||
mu sync.Mutex
|
||||
calls []queryMetricCall
|
||||
}
|
||||
|
||||
func (c *capturingMetricsProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
|
||||
}
|
||||
func (c *capturingMetricsProvider) IncRequestsInFlight() {}
|
||||
func (c *capturingMetricsProvider) DecRequestsInFlight() {}
|
||||
func (c *capturingMetricsProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.calls = append(c.calls, queryMetricCall{
|
||||
operation: operation,
|
||||
schema: schema,
|
||||
entity: entity,
|
||||
table: table,
|
||||
})
|
||||
}
|
||||
func (c *capturingMetricsProvider) RecordCacheHit(provider string) {}
|
||||
func (c *capturingMetricsProvider) RecordCacheMiss(provider string) {}
|
||||
func (c *capturingMetricsProvider) UpdateCacheSize(provider string, size int64) {
|
||||
}
|
||||
func (c *capturingMetricsProvider) RecordEventPublished(source, eventType string) {}
|
||||
func (c *capturingMetricsProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||
}
|
||||
func (c *capturingMetricsProvider) UpdateEventQueueSize(size int64) {}
|
||||
func (c *capturingMetricsProvider) RecordPanic(methodName string) {}
|
||||
func (c *capturingMetricsProvider) Handler() http.Handler { return http.NewServeMux() }
|
||||
|
||||
func (c *capturingMetricsProvider) snapshot() []queryMetricCall {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
out := make([]queryMetricCall, len(c.calls))
|
||||
copy(out, c.calls)
|
||||
return out
|
||||
}
|
||||
|
||||
type queryMetricsGormUser struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
|
||||
func (queryMetricsGormUser) TableName() string {
|
||||
return "metrics_gorm_users"
|
||||
}
|
||||
|
||||
type queryMetricsBunUser struct {
|
||||
bun.BaseModel `bun:"table:metrics_bun_users"`
|
||||
ID int64 `bun:"id,pk,autoincrement"`
|
||||
Name string `bun:"name"`
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsSchemaEntityTableMetrics(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`UPDATE users SET name = \$1 WHERE id = \$2`).
|
||||
WithArgs("Alice", 1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
_, err = adapter.NewUpdate().
|
||||
Table("public.users").
|
||||
Set("name", "Alice").
|
||||
Where("id = ?", 1).
|
||||
Exec(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "UPDATE", calls[0].operation)
|
||||
assert.Equal(t, "public", calls[0].schema)
|
||||
assert.Equal(t, "users", calls[0].entity)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterDisableMetricsSuppressesEmission(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`DELETE FROM users WHERE id = \$1`).
|
||||
WithArgs(1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db).SetMetricsEnabled(false)
|
||||
_, err = adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
Exec(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
assert.Empty(t, provider.snapshot())
|
||||
}
|
||||
|
||||
func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.AutoMigrate(&queryMetricsGormUser{}))
|
||||
require.NoError(t, db.Create(&queryMetricsGormUser{Name: "Alice"}).Error)
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
adapter := NewGormAdapter(db)
|
||||
var users []queryMetricsGormUser
|
||||
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, users)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "SELECT", calls[0].operation)
|
||||
assert.Equal(t, "default", calls[0].schema)
|
||||
assert.Equal(t, "query_metrics_gorm_user", calls[0].entity)
|
||||
assert.Equal(t, "metrics_gorm_users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsErrorMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`INSERT INTO users`).
|
||||
WillReturnError(fmt.Errorf("unique constraint violation"))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
_, err = adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Alice").
|
||||
Exec(context.Background())
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "INSERT", calls[0].operation)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsExistsMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(3))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
exists, err := adapter.NewSelect().Table("users").Exists(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "EXISTS", calls[0].operation)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRecordsCountMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
count, err := adapter.NewSelect().Table("users").Count(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, count)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "COUNT", calls[0].operation)
|
||||
assert.Equal(t, "users", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
mock.ExpectExec(`UPDATE public\.orders SET status = \$1`).
|
||||
WithArgs("shipped").
|
||||
WillReturnResult(sqlmock.NewResult(0, 2))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
_, err = adapter.Exec(context.Background(), `UPDATE public.orders SET status = $1`, "shipped")
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "UPDATE", calls[0].operation)
|
||||
assert.Equal(t, "public", calls[0].schema)
|
||||
assert.Equal(t, "orders", calls[0].table)
|
||||
}
|
||||
|
||||
func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||
require.NoError(t, err)
|
||||
defer sqldb.Close()
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.NewCreateTable().
|
||||
Model((*queryMetricsBunUser)(nil)).
|
||||
IfNotExists().
|
||||
Exec(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.NewInsert().Model(&queryMetricsBunUser{Name: "Alice"}).Exec(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
var users []queryMetricsBunUser
|
||||
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, users)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "SELECT", calls[0].operation)
|
||||
assert.Equal(t, "default", calls[0].schema)
|
||||
assert.Equal(t, "query_metrics_bun_user", calls[0].entity)
|
||||
assert.Equal(t, "metrics_bun_users", calls[0].table)
|
||||
}
|
||||
@@ -98,8 +98,8 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
}
|
||||
|
||||
// Filter regularData to only include fields that exist in the model
|
||||
// Use MapToStruct to validate and filter fields
|
||||
// Filter regularData to only include fields that exist in the model,
|
||||
// and translate JSON keys to their actual database column names.
|
||||
regularData = p.filterValidFields(regularData, model)
|
||||
|
||||
// Inject parent IDs for foreign key resolution
|
||||
@@ -191,14 +191,15 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
|
||||
return ""
|
||||
}
|
||||
|
||||
// filterValidFields filters input data to only include fields that exist in the model
|
||||
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
|
||||
// filterValidFields filters input data to only include fields that exist in the model,
|
||||
// and translates JSON key names to their actual database column names.
|
||||
// For example, a field tagged `json:"_changed_date" bun:"changed_date"` will be
|
||||
// included in the result as "changed_date", not "_changed_date".
|
||||
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model to use with MapToStruct
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
@@ -208,25 +209,16 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model
|
||||
tempModel := reflect.New(modelType).Interface()
|
||||
// Build a mapping from JSON key -> DB column name for all writable fields.
|
||||
// This both validates which fields belong to the model and translates their names
|
||||
// to the correct column names for use in SQL insert/update queries.
|
||||
jsonToDBCol := reflection.BuildJSONToDBColumnMap(modelType)
|
||||
|
||||
// Use MapToStruct to map the data - this will only map valid fields
|
||||
err := reflection.MapToStruct(data, tempModel)
|
||||
if err != nil {
|
||||
logger.Debug("Error mapping data to model: %v", err)
|
||||
return data
|
||||
}
|
||||
|
||||
// Extract the mapped fields back into a map
|
||||
// This effectively filters out any fields that don't exist in the model
|
||||
filteredData := make(map[string]interface{})
|
||||
tempModelValue := reflect.ValueOf(tempModel).Elem()
|
||||
|
||||
for key, value := range data {
|
||||
// Check if the field was successfully mapped
|
||||
if fieldWasMapped(tempModelValue, modelType, key) {
|
||||
filteredData[key] = value
|
||||
dbColName, exists := jsonToDBCol[key]
|
||||
if exists {
|
||||
filteredData[dbColName] = value
|
||||
} else {
|
||||
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
|
||||
}
|
||||
@@ -235,72 +227,8 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
|
||||
return filteredData
|
||||
}
|
||||
|
||||
// fieldWasMapped checks if a field with the given key was mapped to the model
|
||||
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
|
||||
// Look for the field by JSON tag or field name
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check bun tag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check lowercase field name
|
||||
if strings.EqualFold(field.Name, key) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle embedded structs recursively
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
embeddedValue := modelValue.Field(i)
|
||||
if embeddedValue.Kind() == reflect.Ptr {
|
||||
if embeddedValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
embeddedValue = embeddedValue.Elem()
|
||||
}
|
||||
if fieldWasMapped(embeddedValue, fieldType, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// injectForeignKeys injects parent IDs into data for foreign key fields
|
||||
// injectForeignKeys injects parent IDs into data for foreign key fields.
|
||||
// data is expected to be keyed by DB column names (as returned by filterValidFields).
|
||||
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
||||
if len(parentIDs) == 0 {
|
||||
return
|
||||
@@ -319,10 +247,11 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode
|
||||
if strings.EqualFold(jsonName, parentKey+"_id") ||
|
||||
strings.EqualFold(jsonName, parentKey+"id") ||
|
||||
strings.EqualFold(field.Name, parentKey+"ID") {
|
||||
// Only inject if not already present
|
||||
if _, exists := data[jsonName]; !exists {
|
||||
logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID)
|
||||
data[jsonName] = parentID
|
||||
// Use the DB column name as the key, since data is keyed by DB column names
|
||||
dbColName := reflection.GetColumnName(field)
|
||||
if _, exists := data[dbColName]; !exists {
|
||||
logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID)
|
||||
data[dbColName] = parentID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
|
||||
return stats
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reconnectForAdapter() error {
|
||||
timeout := c.config.ConnectTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
return c.Reconnect(ctx)
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Bun()
|
||||
}
|
||||
|
||||
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
|
||||
if err := c.reconnectForAdapter(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.GORM()
|
||||
}
|
||||
|
||||
// getBunAdapter returns or creates the Bun adapter
|
||||
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
if c == nil {
|
||||
@@ -391,7 +427,9 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||
c.bunDB = bun.NewDB(native, dialect)
|
||||
}
|
||||
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB)
|
||||
c.bunAdapter = database.NewBunAdapter(c.bunDB).
|
||||
WithDBFactory(c.reopenBunForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
return c.bunAdapter, nil
|
||||
}
|
||||
|
||||
@@ -432,7 +470,9 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
||||
c.gormDB = db
|
||||
}
|
||||
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB)
|
||||
c.gormAdapter = database.NewGormAdapter(c.gormDB).
|
||||
WithDBFactory(c.reopenGORMForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
return c.gormAdapter, nil
|
||||
}
|
||||
|
||||
@@ -473,11 +513,17 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
||||
// Create a native adapter based on database type
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||
WithDBFactory(c.reopenNativeForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
case DatabaseTypeSQLite:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||
WithDBFactory(c.reopenNativeForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
case DatabaseTypeMSSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||
WithDBFactory(c.reopenNativeForAdapter).
|
||||
SetMetricsEnabled(c.config.EnableMetrics)
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabase
|
||||
}
|
||||
|
||||
@@ -4,8 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
)
|
||||
|
||||
func TestNewConnectionFromDB(t *testing.T) {
|
||||
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-native",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeNative),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.PgSQLAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected PgSQLAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-bun",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeBun),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.BunAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BunAdapter, got %T", db)
|
||||
}
|
||||
|
||||
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
if err := underlyingBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying Bun database: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
|
||||
}
|
||||
|
||||
underlyingAfter := adapter.GetUnderlyingDB()
|
||||
if underlyingAfter == underlyingBefore {
|
||||
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
|
||||
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
|
||||
Name: "test-gorm",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
DefaultORM: string(ORMTypeGORM),
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
}, providers.NewSQLiteProvider())
|
||||
|
||||
ctx := context.Background()
|
||||
if err := conn.Connect(ctx); err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
db, err := conn.Database()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get database adapter: %v", err)
|
||||
}
|
||||
|
||||
adapter, ok := db.(*database.GormAdapter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected GormAdapter, got %T", db)
|
||||
}
|
||||
|
||||
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlBefore, err := gormBefore.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if err := sqlBefore.Close(); err != nil {
|
||||
t.Fatalf("Failed to close underlying database: %v", err)
|
||||
}
|
||||
|
||||
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
|
||||
}
|
||||
if count < 0 {
|
||||
t.Fatalf("Expected non-negative count, got %d", count)
|
||||
}
|
||||
|
||||
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||
if !ok {
|
||||
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||
}
|
||||
|
||||
sqlAfter, err := gormAfter.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
|
||||
}
|
||||
|
||||
if sqlAfter == sqlBefore {
|
||||
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
|
||||
"connection", item.name,
|
||||
"error", err)
|
||||
|
||||
// Attempt reconnection if enabled
|
||||
if m.config.EnableAutoReconnect {
|
||||
// Only reconnect when the client handle itself is closed/disconnected.
|
||||
// For transient database restarts or network blips, *sql.DB can recover
|
||||
// on its own; forcing Close()+Connect() here invalidates any cached ORM
|
||||
// wrappers and callers that still hold the old handle.
|
||||
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
|
||||
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||
if err := item.conn.Reconnect(ctx); err != nil {
|
||||
logger.Error("Reconnection failed",
|
||||
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
|
||||
} else {
|
||||
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||
}
|
||||
} else if m.config.EnableAutoReconnect {
|
||||
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldReconnectAfterHealthCheck(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrConnectionClosed) {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
@@ -3,12 +3,38 @@ package dbmanager
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type healthCheckStubConnection struct {
|
||||
healthErr error
|
||||
reconnectCalls int
|
||||
}
|
||||
|
||||
func (c *healthCheckStubConnection) Name() string { return "stub" }
|
||||
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
|
||||
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
|
||||
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
|
||||
func (c *healthCheckStubConnection) Close() error { return nil }
|
||||
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
|
||||
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
|
||||
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
|
||||
|
||||
func TestBackgroundHealthChecker(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
|
||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 0 {
|
||||
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
|
||||
conn := &healthCheckStubConnection{
|
||||
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
|
||||
}
|
||||
|
||||
mgr := &connectionManager{
|
||||
connections: map[string]Connection{"primary": conn},
|
||||
config: ManagerConfig{
|
||||
EnableAutoReconnect: true,
|
||||
},
|
||||
}
|
||||
|
||||
mgr.performHealthCheck()
|
||||
|
||||
if conn.reconnectCalls != 1 {
|
||||
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,17 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
||||
@@ -14,8 +15,10 @@ import (
|
||||
|
||||
// SQLiteProvider implements Provider for SQLite databases
|
||||
type SQLiteProvider struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
config ConnectionConfig
|
||||
}
|
||||
|
||||
// NewSQLiteProvider creates a new SQLite provider
|
||||
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||
|
||||
// Execute a simple query to verify the database is accessible
|
||||
var result int
|
||||
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
|
||||
run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
|
||||
err := run()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = run()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNative returns the native *sql.DB connection
|
||||
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
||||
if p.db == nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -19,7 +20,7 @@ type Provider interface {
|
||||
DecRequestsInFlight()
|
||||
|
||||
// RecordDBQuery records metrics for a database query
|
||||
RecordDBQuery(operation, table string, duration time.Duration, err error)
|
||||
RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error)
|
||||
|
||||
// RecordCacheHit records a cache hit
|
||||
RecordCacheHit(provider string)
|
||||
@@ -46,21 +47,28 @@ type Provider interface {
|
||||
Handler() http.Handler
|
||||
}
|
||||
|
||||
// globalProvider is the global metrics provider
|
||||
var globalProvider Provider
|
||||
// globalProvider is the global metrics provider, protected by globalProviderMu.
|
||||
var (
|
||||
globalProviderMu sync.RWMutex
|
||||
globalProvider Provider
|
||||
)
|
||||
|
||||
// SetProvider sets the global metrics provider
|
||||
// SetProvider sets the global metrics provider.
|
||||
func SetProvider(p Provider) {
|
||||
globalProviderMu.Lock()
|
||||
globalProvider = p
|
||||
globalProviderMu.Unlock()
|
||||
}
|
||||
|
||||
// GetProvider returns the current metrics provider
|
||||
// GetProvider returns the current metrics provider.
|
||||
func GetProvider() Provider {
|
||||
if globalProvider == nil {
|
||||
// Return no-op provider if none is set
|
||||
globalProviderMu.RLock()
|
||||
p := globalProvider
|
||||
globalProviderMu.RUnlock()
|
||||
if p == nil {
|
||||
return &NoOpProvider{}
|
||||
}
|
||||
return globalProvider
|
||||
return p
|
||||
}
|
||||
|
||||
// NoOpProvider is a no-op implementation of Provider
|
||||
@@ -69,7 +77,7 @@ type NoOpProvider struct{}
|
||||
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
|
||||
func (n *NoOpProvider) IncRequestsInFlight() {}
|
||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
func (n *NoOpProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||
}
|
||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||
|
||||
@@ -83,14 +83,14 @@ func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
|
||||
Help: "Database query duration in seconds",
|
||||
Buckets: cfg.DBQueryBuckets,
|
||||
},
|
||||
[]string{"operation", "table"},
|
||||
[]string{"operation", "schema", "entity", "table"},
|
||||
),
|
||||
dbQueryTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: metricName("db_queries_total"),
|
||||
Help: "Total number of database queries",
|
||||
},
|
||||
[]string{"operation", "table", "status"},
|
||||
[]string{"operation", "schema", "entity", "table", "status"},
|
||||
),
|
||||
cacheHits: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
@@ -204,13 +204,13 @@ func (p *PrometheusProvider) DecRequestsInFlight() {
|
||||
}
|
||||
|
||||
// RecordDBQuery implements Provider interface
|
||||
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
func (p *PrometheusProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||
status := "success"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
|
||||
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
|
||||
p.dbQueryDuration.WithLabelValues(operation, schema, entity, table).Observe(duration.Seconds())
|
||||
p.dbQueryTotal.WithLabelValues(operation, schema, entity, table, status).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheHit implements Provider interface
|
||||
|
||||
@@ -196,6 +196,92 @@ func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetColumnName extracts the database column name from a struct field.
|
||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name.
|
||||
// This is the exported version for use by other packages.
|
||||
func GetColumnName(field reflect.StructField) string {
|
||||
return getColumnNameFromField(field)
|
||||
}
|
||||
|
||||
// BuildJSONToDBColumnMap returns a map from JSON key names to database column names
|
||||
// for the given model type. Only writable, non-relation fields are included.
|
||||
// This is used to translate incoming request data (keyed by JSON names) into
|
||||
// properly named database columns before insert/update operations.
|
||||
func BuildJSONToDBColumnMap(modelType reflect.Type) map[string]string {
|
||||
result := make(map[string]string)
|
||||
buildJSONToDBMap(modelType, result, false)
|
||||
return result
|
||||
}
|
||||
|
||||
func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly bool) {
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
bunTag := field.Tag.Get("bun")
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
|
||||
// Handle embedded structs
|
||||
if field.Anonymous {
|
||||
ft := field.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
isScanOnly := scanOnly
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
isScanOnly = true
|
||||
}
|
||||
if ft.Kind() == reflect.Struct {
|
||||
buildJSONToDBMap(ft, result, isScanOnly)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if scanOnly {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip explicitly excluded fields
|
||||
if bunTag == "-" || gormTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip scan-only fields
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip bun relation fields
|
||||
if bunTag != "" && (strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") || strings.Contains(bunTag, "m2m:")) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip gorm relation fields
|
||||
if gormTag != "" && (strings.Contains(gormTag, "foreignKey:") || strings.Contains(gormTag, "references:") || strings.Contains(gormTag, "many2many:")) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get JSON key (how the field appears in incoming request data)
|
||||
jsonKey := ""
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
jsonKey = parts[0]
|
||||
}
|
||||
}
|
||||
if jsonKey == "" {
|
||||
jsonKey = strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
// Get the actual DB column name (bun > gorm > json > field name)
|
||||
dbColName := getColumnNameFromField(field)
|
||||
|
||||
result[jsonKey] = dbColName
|
||||
}
|
||||
}
|
||||
|
||||
// getColumnNameFromField extracts the column name from a struct field
|
||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||
func getColumnNameFromField(field reflect.StructField) string {
|
||||
|
||||
@@ -823,12 +823,12 @@ func TestToSnakeCase(t *testing.T) {
|
||||
{
|
||||
name: "UserID",
|
||||
input: "UserID",
|
||||
expected: "user_i_d",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "HTTPServer",
|
||||
input: "HTTPServer",
|
||||
expected: "h_t_t_p_server",
|
||||
expected: "http_server",
|
||||
},
|
||||
{
|
||||
name: "lowercase",
|
||||
@@ -838,7 +838,7 @@ func TestToSnakeCase(t *testing.T) {
|
||||
{
|
||||
name: "UPPERCASE",
|
||||
input: "UPPERCASE",
|
||||
expected: "u_p_p_e_r_c_a_s_e",
|
||||
expected: "uppercase",
|
||||
},
|
||||
{
|
||||
name: "Single",
|
||||
|
||||
@@ -142,9 +142,138 @@ e.Any("/mcp", echo.WrapHandler(h)) // Echo
|
||||
|
||||
---
|
||||
|
||||
### Authentication
|
||||
## OAuth2 Authentication
|
||||
|
||||
Add middleware before the MCP routes. The handler itself has no auth layer.
|
||||
`resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
|
||||
|
||||
It can operate as:
|
||||
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
|
||||
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
|
||||
- **Both simultaneously**
|
||||
|
||||
### Standard endpoints served
|
||||
|
||||
| Path | Spec | Purpose |
|
||||
|---|---|---|
|
||||
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
|
||||
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
|
||||
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
|
||||
| `POST /oauth/authorize` | — | Login form submission |
|
||||
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
|
||||
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
|
||||
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
|
||||
|
||||
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
|
||||
|
||||
---
|
||||
|
||||
### Mode 1 — Direct login (server as identity provider)
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
db, _ := sql.Open("postgres", dsn)
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
|
||||
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
|
||||
BaseURL: "https://api.example.com",
|
||||
BasePath: "/mcp",
|
||||
})
|
||||
|
||||
// Enable the OAuth2 server — auth enables the login form
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
security.RegisterSecurityHooks(handler, securityList)
|
||||
|
||||
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
```
|
||||
|
||||
MCP client flow:
|
||||
1. Discovers server at `/.well-known/oauth-authorization-server`
|
||||
2. Registers itself at `/oauth/register`
|
||||
3. Redirects user to `/oauth/authorize` → login form appears
|
||||
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
|
||||
5. Uses token on all MCP tool calls
|
||||
|
||||
---
|
||||
|
||||
### Mode 2 — External provider (Google, GitHub, etc.)
|
||||
|
||||
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
RedirectURL: "https://api.example.com/oauth/provider/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
})
|
||||
|
||||
// Pass `auth` so the OAuth server supports persistence, introspection, and revocation.
|
||||
// Google handles the end-user authentication flow via redirect.
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
handler.RegisterOAuth2Provider(auth, "google")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Mode 3 — Both (login form + external providers)
|
||||
|
||||
```go
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
LoginTitle: "My App Login",
|
||||
}, auth) // auth enables the username/password form
|
||||
|
||||
handler.RegisterOAuth2Provider(googleAuth, "google")
|
||||
handler.RegisterOAuth2Provider(githubAuth, "github")
|
||||
```
|
||||
|
||||
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
|
||||
|
||||
---
|
||||
|
||||
### Using `security.OAuthServer` standalone
|
||||
|
||||
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
|
||||
|
||||
```go
|
||||
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
oauthSrv.RegisterExternalProvider(googleAuth, "google")
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
|
||||
mux.Handle("/mcp/", myMCPHandler)
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Cookie-based flow (legacy)
|
||||
|
||||
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
ProviderName: "google",
|
||||
LoginPath: "/auth/google/login",
|
||||
CallbackPath: "/auth/google/callback",
|
||||
AfterLoginRedirect: "/",
|
||||
})
|
||||
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -16,17 +16,20 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// Handler exposes registered database models as MCP tools and resources.
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
oauth2Regs []oauth2Registration
|
||||
oauthSrv *security.OAuthServer
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||
@@ -717,7 +720,7 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
|
||||
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
|
||||
switch filter.Operator {
|
||||
case "eq", "=":
|
||||
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||
@@ -747,7 +750,8 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []in
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for _, preload := range preloads {
|
||||
for i := range preloads {
|
||||
preload := &preloads[i]
|
||||
if preload.Relation == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
264
pkg/resolvemcp/oauth2.go
Normal file
264
pkg/resolvemcp/oauth2.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 registration on the Handler
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// oauth2Registration stores a configured auth provider and its route config.
|
||||
type oauth2Registration struct {
|
||||
auth *security.DatabaseAuthenticator
|
||||
cfg OAuth2RouteConfig
|
||||
}
|
||||
|
||||
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
|
||||
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
|
||||
// Call this once per provider before serving requests.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google",
|
||||
// LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback",
|
||||
// AfterLoginRedirect: "/",
|
||||
// })
|
||||
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
|
||||
}
|
||||
|
||||
// HTTPHandler returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP SSE transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(...)
|
||||
// handler.RegisterOAuth2(auth, cfg)
|
||||
// handler.EnableOAuthServer(security.OAuthServerConfig{Issuer: "https://api.example.com"})
|
||||
// security.RegisterSecurityHooks(handler, securityList)
|
||||
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedSSEServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/sse", mcpHandler)
|
||||
mux.Handle(basePath+"/message", mcpHandler)
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// StreamableHTTPMux returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP streamable HTTP transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
|
||||
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
mux.Handle(basePath, mcpHandler)
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
|
||||
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
|
||||
for _, reg := range h.oauth2Regs {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if reg.cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
|
||||
}
|
||||
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
|
||||
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Auth-wrapped transports
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
|
||||
// The middleware reads the session cookie / Authorization header and populates the user
|
||||
// context into the request context, making it available to BeforeHandle security hooks.
|
||||
// Unauthenticated requests receive 401 before reaching any MCP tool.
|
||||
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
|
||||
// Unauthenticated requests continue as guest rather than returning 401.
|
||||
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
|
||||
// to allow mixed public/private access.
|
||||
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
|
||||
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
|
||||
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 route config and standalone handlers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
|
||||
type OAuth2RouteConfig struct {
|
||||
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
|
||||
// (e.g. "google", "github", "microsoft").
|
||||
ProviderName string
|
||||
|
||||
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
|
||||
// (e.g. "/auth/google/login").
|
||||
LoginPath string
|
||||
|
||||
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
|
||||
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
|
||||
CallbackPath string
|
||||
|
||||
// AfterLoginRedirect is the URL to redirect the browser to after a successful
|
||||
// login. When empty the LoginResponse JSON is written directly to the response.
|
||||
AfterLoginRedirect string
|
||||
|
||||
// CookieOptions customises the session cookie written on successful login.
|
||||
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
|
||||
CookieOptions *security.SessionCookieOptions
|
||||
}
|
||||
|
||||
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
|
||||
// the OAuth2 provider's authorization URL.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
|
||||
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := auth.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to generate state", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
|
||||
// callback: exchanges the authorization code for a session token, writes the session
|
||||
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
|
||||
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
security.SetSessionCookie(w, loginResp, cookieOpts...)
|
||||
|
||||
if afterLoginRedirect != "" {
|
||||
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Gorilla Mux convenience helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google", LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
|
||||
// })
|
||||
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
|
||||
}
|
||||
|
||||
muxRouter.Handle(cfg.LoginPath,
|
||||
OAuth2LoginHandler(auth, cfg.ProviderName),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
muxRouter.Handle(cfg.CallbackPath,
|
||||
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
|
||||
).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
|
||||
// with required authentication middleware applied.
|
||||
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedSSEServer(securityList)
|
||||
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
|
||||
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
|
||||
// Gorilla Mux router with required authentication middleware applied.
|
||||
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedStreamableHTTPServer(securityList)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
51
pkg/resolvemcp/oauth2_server.go
Normal file
51
pkg/resolvemcp/oauth2_server.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
|
||||
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
|
||||
// only external providers registered via RegisterOAuth2Provider.
|
||||
//
|
||||
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
|
||||
// endpoints required by MCP clients alongside the MCP transport:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
|
||||
// POST /oauth/authorize Login form submission (password flow)
|
||||
// POST /oauth/token Bearer token exchange + refresh
|
||||
// GET /oauth/provider/callback External provider redirect target
|
||||
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
|
||||
h.oauthSrv = security.NewOAuthServer(cfg, auth)
|
||||
// Wire any external providers already registered via RegisterOAuth2
|
||||
for _, reg := range h.oauth2Regs {
|
||||
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
|
||||
// EnableOAuthServer must be called before this. The auth must have been configured with
|
||||
// WithOAuth2(providerName, ...) for the given provider name.
|
||||
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
|
||||
if h.oauthSrv != nil {
|
||||
h.oauthSrv.RegisterExternalProvider(auth, providerName)
|
||||
}
|
||||
}
|
||||
|
||||
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
|
||||
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
|
||||
oauthHandler := h.oauthSrv.HTTPHandler()
|
||||
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
|
||||
mux.Handle("/.well-known/", oauthHandler)
|
||||
mux.Handle("/oauth/", oauthHandler)
|
||||
if h.oauthSrv != nil {
|
||||
// Also mount the external provider callback path if it differs from /oauth/
|
||||
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
|
||||
}
|
||||
}
|
||||
153
pkg/security/KEYSTORE.md
Normal file
153
pkg/security/KEYSTORE.md
Normal file
@@ -0,0 +1,153 @@
|
||||
# Keystore
|
||||
|
||||
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
|
||||
|
||||
## Key types
|
||||
|
||||
| Constant | Value | Use case |
|
||||
|---|---|---|
|
||||
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
|
||||
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
|
||||
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
|
||||
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
|
||||
|
||||
## Storage backends
|
||||
|
||||
### ConfigKeyStore
|
||||
|
||||
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
|
||||
|
||||
```go
|
||||
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
|
||||
store := security.NewConfigKeyStore([]security.UserKey{
|
||||
{
|
||||
UserID: 1,
|
||||
KeyType: security.KeyTypeGenericAPI,
|
||||
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
|
||||
Name: "CI deploy",
|
||||
Scopes: []string{"deploy"},
|
||||
IsActive: true,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### DatabaseKeyStore
|
||||
|
||||
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
|
||||
|
||||
```go
|
||||
db, _ := sql.Open("postgres", dsn)
|
||||
|
||||
store := security.NewDatabaseKeyStore(db)
|
||||
|
||||
// With options
|
||||
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||
CacheTTL: 5 * time.Minute,
|
||||
SQLNames: &security.KeyStoreSQLNames{
|
||||
ValidateKey: "myapp_keystore_validate", // override one procedure name
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Managing keys
|
||||
|
||||
```go
|
||||
ctx := context.Background()
|
||||
|
||||
// Create — raw key returned once; store it securely
|
||||
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
|
||||
UserID: 42,
|
||||
KeyType: security.KeyTypeGenericAPI,
|
||||
Name: "mobile app",
|
||||
Scopes: []string{"read", "write"},
|
||||
})
|
||||
fmt.Println(resp.RawKey) // only shown here; hashed internally
|
||||
|
||||
// List
|
||||
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
|
||||
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
|
||||
|
||||
// Revoke
|
||||
err = store.DeleteKey(ctx, 42, resp.Key.ID)
|
||||
|
||||
// Validate (used by authenticators internally)
|
||||
key, err := store.ValidateKey(ctx, rawKey, "")
|
||||
```
|
||||
|
||||
## HTTP authentication
|
||||
|
||||
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
|
||||
|
||||
Keys are extracted from the request in this order:
|
||||
|
||||
1. `Authorization: Bearer <key>`
|
||||
2. `Authorization: ApiKey <key>`
|
||||
3. `X-API-Key: <key>`
|
||||
|
||||
```go
|
||||
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
|
||||
// Restrict to a specific type:
|
||||
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
|
||||
```
|
||||
|
||||
Plug it into a handler:
|
||||
|
||||
```go
|
||||
handler := resolvespec.NewHandler(db, registry,
|
||||
resolvespec.WithAuthenticator(auth),
|
||||
)
|
||||
```
|
||||
|
||||
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
|
||||
|
||||
On successful validation the request context receives a `UserContext` where:
|
||||
|
||||
- `UserID` — from the key
|
||||
- `Roles` — the key's `Scopes`
|
||||
- `Claims["key_type"]` — key type string
|
||||
- `Claims["key_name"]` — key name
|
||||
|
||||
## Database setup
|
||||
|
||||
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
|
||||
|
||||
```sql
|
||||
\i pkg/security/keystore_schema.sql
|
||||
```
|
||||
|
||||
This creates:
|
||||
|
||||
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
|
||||
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
|
||||
- `resolvespec_keystore_create_key(p_request jsonb)`
|
||||
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
|
||||
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
|
||||
|
||||
### Custom procedure names
|
||||
|
||||
```go
|
||||
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||
SQLNames: &security.KeyStoreSQLNames{
|
||||
GetUserKeys: "myschema_get_keys",
|
||||
CreateKey: "myschema_create_key",
|
||||
DeleteKey: "myschema_delete_key",
|
||||
ValidateKey: "myschema_validate_key",
|
||||
},
|
||||
})
|
||||
|
||||
// Validate names at startup
|
||||
names := &security.KeyStoreSQLNames{
|
||||
GetUserKeys: "myschema_get_keys",
|
||||
// ...
|
||||
}
|
||||
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
## Security notes
|
||||
|
||||
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
|
||||
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
|
||||
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
|
||||
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.
|
||||
@@ -12,6 +12,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
- ✅ **Testable** - Easy to mock and test
|
||||
- ✅ **Extensible** - Implement custom providers for your needs
|
||||
- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
|
||||
- ✅ **OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
|
||||
|
||||
## Stored Procedure Architecture
|
||||
|
||||
@@ -38,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
|
||||
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
|
||||
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
|
||||
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
|
||||
|
||||
See `database_schema.sql` for complete stored procedure definitions and examples.
|
||||
|
||||
@@ -897,6 +904,156 @@ securityList := security.NewSecurityList(provider)
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
||||
```
|
||||
|
||||
## OAuth2 Authorization Server
|
||||
|
||||
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Method | Path | RFC |
|
||||
|--------|------|-----|
|
||||
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
|
||||
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
|
||||
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
|
||||
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
|
||||
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
|
||||
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
|
||||
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
|
||||
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
|
||||
|
||||
### Config
|
||||
|
||||
```go
|
||||
cfg := security.OAuthServerConfig{
|
||||
Issuer: "https://example.com", // Required — token issuer URL
|
||||
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
|
||||
LoginTitle: "My App Login", // HTML login page title
|
||||
PersistClients: true, // Store clients in DB (multi-instance safe)
|
||||
PersistCodes: true, // Store codes in DB (multi-instance safe)
|
||||
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
|
||||
AccessTokenTTL: time.Hour,
|
||||
AuthCodeTTL: 5 * time.Minute,
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Default | Notes |
|
||||
|-------|---------|-------|
|
||||
| `Issuer` | — | Required; trailing slash is trimmed automatically |
|
||||
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
|
||||
| `LoginTitle` | `"Sign in"` | |
|
||||
| `PersistClients` | `false` | Set `true` for multi-instance |
|
||||
| `PersistCodes` | `false` | Set `true` for multi-instance; does not require `PersistClients` |
|
||||
| `DefaultScopes` | `["openid","profile","email"]` | |
|
||||
| `AccessTokenTTL` | `24h` | Also used as `expires_in` in token responses |
|
||||
| `AuthCodeTTL` | `2m` | |
|
||||
|
||||
### Operating Modes
|
||||
|
||||
**Mode 1 — Direct login (username/password form)**
|
||||
|
||||
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
```
|
||||
|
||||
**Mode 2 — External provider federation**
|
||||
|
||||
Pass a `*DatabaseAuthenticator` for persistence (authorization codes, revoke, introspect) and register external providers. The authorize endpoint redirects to the specified provider (via the `provider` query param) or to the first registered provider by default.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
srv.RegisterExternalProvider(githubAuth, "github")
|
||||
```
|
||||
|
||||
**Mode 3 — Both**
|
||||
|
||||
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
|
||||
|
||||
```go
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
```
|
||||
|
||||
### Standalone Usage
|
||||
|
||||
```go
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/.well-known/", srv.HTTPHandler())
|
||||
mux.Handle("/oauth/", srv.HTTPHandler())
|
||||
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
|
||||
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
### DB Persistence
|
||||
|
||||
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
|
||||
|
||||
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
|
||||
|
||||
#### New DB Types
|
||||
|
||||
```go
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
#### DatabaseAuthenticator OAuth Methods
|
||||
|
||||
```go
|
||||
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
|
||||
auth.OAuthGetClient(ctx, clientID) // retrieve client
|
||||
auth.OAuthSaveCode(ctx, code) // persist authorization code
|
||||
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
|
||||
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
|
||||
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
|
||||
```
|
||||
|
||||
#### SQLNames Fields
|
||||
|
||||
```go
|
||||
type SQLNames struct {
|
||||
// ... existing fields ...
|
||||
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||
}
|
||||
```
|
||||
|
||||
The main changes:
|
||||
1. Security package no longer knows about specific spec types
|
||||
2. Each spec registers its own security hooks
|
||||
|
||||
@@ -1397,3 +1397,180 @@ $$ LANGUAGE plpgsql;
|
||||
|
||||
-- Get credentials by username
|
||||
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
|
||||
|
||||
-- ============================================
|
||||
-- OAuth2 Server Tables (OAuthServer persistence)
|
||||
-- ============================================
|
||||
|
||||
-- oauth_clients: persistent RFC 7591 registered clients
|
||||
CREATE TABLE IF NOT EXISTS oauth_clients (
|
||||
id SERIAL PRIMARY KEY,
|
||||
client_id VARCHAR(255) NOT NULL UNIQUE,
|
||||
redirect_uris TEXT[] NOT NULL,
|
||||
client_name VARCHAR(255),
|
||||
grant_types TEXT[] DEFAULT ARRAY['authorization_code'],
|
||||
allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'],
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- oauth_codes: short-lived authorization codes (for multi-instance deployments)
|
||||
-- Note: client_id is stored without a foreign key so codes can be persisted even
|
||||
-- when OAuth clients are managed in memory rather than persisted in oauth_clients.
|
||||
CREATE TABLE IF NOT EXISTS oauth_codes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
code VARCHAR(255) NOT NULL UNIQUE,
|
||||
client_id VARCHAR(255) NOT NULL,
|
||||
redirect_uri TEXT NOT NULL,
|
||||
client_state TEXT,
|
||||
code_challenge VARCHAR(255) NOT NULL,
|
||||
code_challenge_method VARCHAR(10) DEFAULT 'S256',
|
||||
session_token TEXT NOT NULL,
|
||||
refresh_token TEXT,
|
||||
scopes TEXT[],
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code);
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at);
|
||||
|
||||
-- ============================================
|
||||
-- OAuth2 Server Stored Procedures
|
||||
-- ============================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_client_id text;
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
v_client_id := p_data->>'client_id';
|
||||
|
||||
INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes)
|
||||
VALUES (
|
||||
v_client_id,
|
||||
ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')),
|
||||
p_data->>'client_name',
|
||||
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']),
|
||||
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email'])
|
||||
)
|
||||
RETURNING to_jsonb(oauth_clients.*) INTO v_row;
|
||||
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, null::jsonb;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
SELECT to_jsonb(oauth_clients.*)
|
||||
INTO v_row
|
||||
FROM oauth_clients
|
||||
WHERE client_id = p_client_id AND is_active = true;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT false, 'client not found'::text, null::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb)
|
||||
RETURNS TABLE(p_success bool, p_error text)
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, refresh_token, scopes, expires_at)
|
||||
VALUES (
|
||||
p_data->>'code',
|
||||
p_data->>'client_id',
|
||||
p_data->>'redirect_uri',
|
||||
p_data->>'client_state',
|
||||
p_data->>'code_challenge',
|
||||
COALESCE(p_data->>'code_challenge_method', 'S256'),
|
||||
p_data->>'session_token',
|
||||
p_data->>'refresh_token',
|
||||
ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')),
|
||||
(p_data->>'expires_at')::timestamp
|
||||
);
|
||||
|
||||
RETURN QUERY SELECT true, null::text;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
DELETE FROM oauth_codes
|
||||
WHERE code = p_code AND expires_at > now()
|
||||
RETURNING jsonb_build_object(
|
||||
'client_id', client_id,
|
||||
'redirect_uri', redirect_uri,
|
||||
'client_state', client_state,
|
||||
'code_challenge', code_challenge,
|
||||
'code_challenge_method', code_challenge_method,
|
||||
'session_token', session_token,
|
||||
'refresh_token', refresh_token,
|
||||
'scopes', to_jsonb(scopes)
|
||||
) INTO v_row;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
SELECT jsonb_build_object(
|
||||
'active', true,
|
||||
'sub', u.id::text,
|
||||
'username', u.username,
|
||||
'email', u.email,
|
||||
'user_level', u.user_level,
|
||||
-- NULLIF converts empty string to NULL; string_to_array(NULL) returns NULL;
|
||||
-- to_jsonb(NULL) returns NULL; COALESCE then returns '[]' for NULL/empty roles.
|
||||
'roles', COALESCE(to_jsonb(string_to_array(NULLIF(u.roles, ''), ',')), '[]'::jsonb),
|
||||
'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint,
|
||||
'iat', EXTRACT(EPOCH FROM s.created_at)::bigint
|
||||
)
|
||||
INTO v_row
|
||||
FROM user_sessions s
|
||||
JOIN users u ON u.id = s.user_id
|
||||
WHERE s.session_token = p_token
|
||||
AND s.expires_at > now()
|
||||
AND u.is_active = true;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text)
|
||||
RETURNS TABLE(p_success bool, p_error text)
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
DELETE FROM user_sessions WHERE session_token = p_token;
|
||||
RETURN QUERY SELECT true, null::text;
|
||||
END;
|
||||
$$;
|
||||
|
||||
81
pkg/security/keystore.go
Normal file
81
pkg/security/keystore.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
)
|
||||
|
||||
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
|
||||
// Used by all keystore implementations to hash raw keys before storage or lookup.
|
||||
func hashSHA256Hex(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// KeyType identifies the category of an auth key.
|
||||
type KeyType string
|
||||
|
||||
const (
|
||||
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
|
||||
KeyTypeJWTSecret KeyType = "jwt_secret"
|
||||
// KeyTypeHeaderAPI is a static API key sent via a request header.
|
||||
KeyTypeHeaderAPI KeyType = "header_api"
|
||||
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
|
||||
KeyTypeOAuth2 KeyType = "oauth2"
|
||||
// KeyTypeGenericAPI is a generic application API key.
|
||||
KeyTypeGenericAPI KeyType = "api"
|
||||
)
|
||||
|
||||
// UserKey represents a single named auth key belonging to a user.
|
||||
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
|
||||
type UserKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int `json:"user_id"`
|
||||
KeyType KeyType `json:"key_type"`
|
||||
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// CreateKeyRequest specifies the parameters for a new key.
|
||||
type CreateKeyRequest struct {
|
||||
UserID int
|
||||
KeyType KeyType
|
||||
Name string
|
||||
Scopes []string
|
||||
Meta map[string]any
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
// CreateKeyResponse is returned exactly once when a key is created.
|
||||
// The caller is responsible for persisting RawKey; it is not stored anywhere.
|
||||
type CreateKeyResponse struct {
|
||||
Key UserKey
|
||||
RawKey string // crypto/rand 32 bytes, base64url-encoded
|
||||
}
|
||||
|
||||
// KeyStore manages per-user auth keys with pluggable storage backends.
|
||||
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
|
||||
type KeyStore interface {
|
||||
// CreateKey generates a new key, stores its hash, and returns the raw key once.
|
||||
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
|
||||
|
||||
// GetUserKeys returns all active, non-expired keys for a user.
|
||||
// Pass an empty KeyType to return all types.
|
||||
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
|
||||
|
||||
// DeleteKey soft-deletes a key by ID after verifying ownership.
|
||||
DeleteKey(ctx context.Context, userID int, keyID int64) error
|
||||
|
||||
// ValidateKey checks a raw key, returns the matching UserKey on success.
|
||||
// The implementation hashes the raw key before any lookup.
|
||||
// Pass an empty KeyType to accept any type.
|
||||
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
|
||||
}
|
||||
97
pkg/security/keystore_authenticator.go
Normal file
97
pkg/security/keystore_authenticator.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
|
||||
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
|
||||
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
|
||||
// is managed directly through the KeyStore.
|
||||
//
|
||||
// Key extraction order:
|
||||
// 1. Authorization: Bearer <key>
|
||||
// 2. Authorization: ApiKey <key>
|
||||
// 3. X-API-Key header
|
||||
type KeyStoreAuthenticator struct {
|
||||
keyStore KeyStore
|
||||
keyType KeyType // empty = accept any type
|
||||
}
|
||||
|
||||
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
|
||||
// Pass an empty keyType to accept keys of any type.
|
||||
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
|
||||
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
|
||||
}
|
||||
|
||||
// Login is not supported for keystore authentication.
|
||||
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
|
||||
return nil, fmt.Errorf("keystore authenticator does not support login")
|
||||
}
|
||||
|
||||
// Logout is not supported for keystore authentication.
|
||||
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticate extracts an API key from the request and validates it against the KeyStore.
|
||||
// Returns a UserContext built from the matching UserKey on success.
|
||||
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
rawKey := extractAPIKey(r)
|
||||
if rawKey == "" {
|
||||
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
|
||||
}
|
||||
|
||||
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid API key: %w", err)
|
||||
}
|
||||
|
||||
return userKeyToUserContext(userKey), nil
|
||||
}
|
||||
|
||||
// extractAPIKey extracts a raw key from the request using the following precedence:
|
||||
// 1. Authorization: Bearer <key>
|
||||
// 2. Authorization: ApiKey <key>
|
||||
// 3. X-API-Key header
|
||||
func extractAPIKey(r *http.Request) string {
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||||
return strings.TrimSpace(after)
|
||||
}
|
||||
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
|
||||
return strings.TrimSpace(after)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(r.Header.Get("X-API-Key"))
|
||||
}
|
||||
|
||||
// userKeyToUserContext converts a UserKey into a UserContext.
|
||||
// Scopes are mapped to Roles. Key type and name are stored in Claims.
|
||||
func userKeyToUserContext(k *UserKey) *UserContext {
|
||||
claims := map[string]any{
|
||||
"key_type": string(k.KeyType),
|
||||
"key_name": k.Name,
|
||||
}
|
||||
|
||||
meta := k.Meta
|
||||
if meta == nil {
|
||||
meta = map[string]any{}
|
||||
}
|
||||
|
||||
roles := k.Scopes
|
||||
if roles == nil {
|
||||
roles = []string{}
|
||||
}
|
||||
|
||||
return &UserContext{
|
||||
UserID: k.UserID,
|
||||
SessionID: fmt.Sprintf("key:%d", k.ID),
|
||||
Roles: roles,
|
||||
Claims: claims,
|
||||
Meta: meta,
|
||||
}
|
||||
}
|
||||
149
pkg/security/keystore_config.go
Normal file
149
pkg/security/keystore_config.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
|
||||
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
|
||||
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
|
||||
//
|
||||
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
|
||||
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
|
||||
type ConfigKeyStore struct {
|
||||
mu sync.RWMutex
|
||||
keys []UserKey
|
||||
next int64 // monotonic ID counter for runtime-created keys (atomic)
|
||||
}
|
||||
|
||||
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
|
||||
// Pass nil or an empty slice to start with no pre-loaded keys.
|
||||
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
|
||||
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
|
||||
var maxID int64
|
||||
copied := make([]UserKey, len(keys))
|
||||
copy(copied, keys)
|
||||
for i := range copied {
|
||||
if copied[i].CreatedAt.IsZero() {
|
||||
copied[i].IsActive = true
|
||||
copied[i].CreatedAt = time.Now()
|
||||
}
|
||||
if copied[i].ID > maxID {
|
||||
maxID = copied[i].ID
|
||||
}
|
||||
}
|
||||
return &ConfigKeyStore{keys: copied, next: maxID}
|
||||
}
|
||||
|
||||
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
|
||||
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||
rawBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(rawBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
||||
}
|
||||
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
|
||||
id := atomic.AddInt64(&s.next, 1)
|
||||
key := UserKey{
|
||||
ID: id,
|
||||
UserID: req.UserID,
|
||||
KeyType: req.KeyType,
|
||||
KeyHash: hash,
|
||||
Name: req.Name,
|
||||
Scopes: req.Scopes,
|
||||
Meta: req.Meta,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
CreatedAt: time.Now(),
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.keys = append(s.keys, key)
|
||||
s.mu.Unlock()
|
||||
|
||||
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
||||
}
|
||||
|
||||
// GetUserKeys returns all active, non-expired keys for the given user.
|
||||
// Pass an empty KeyType to return all types.
|
||||
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
||||
now := time.Now()
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []UserKey
|
||||
for i := range s.keys {
|
||||
k := &s.keys[i]
|
||||
if k.UserID != userID || !k.IsActive {
|
||||
continue
|
||||
}
|
||||
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
|
||||
continue
|
||||
}
|
||||
if keyType != "" && k.KeyType != keyType {
|
||||
continue
|
||||
}
|
||||
result = append(result, *k)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
|
||||
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for i := range s.keys {
|
||||
if s.keys[i].ID == keyID {
|
||||
if s.keys[i].UserID != userID {
|
||||
return fmt.Errorf("key not found or permission denied")
|
||||
}
|
||||
s.keys[i].IsActive = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("key not found")
|
||||
}
|
||||
|
||||
// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
|
||||
// Uses constant-time comparison to prevent timing side-channels.
|
||||
// Pass an empty KeyType to accept any type.
|
||||
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
hashBytes, _ := hex.DecodeString(hash)
|
||||
now := time.Now()
|
||||
|
||||
// Write lock: ValidateKey updates LastUsedAt on the matched entry.
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for i := range s.keys {
|
||||
k := &s.keys[i]
|
||||
if !k.IsActive {
|
||||
continue
|
||||
}
|
||||
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
|
||||
continue
|
||||
}
|
||||
if keyType != "" && k.KeyType != keyType {
|
||||
continue
|
||||
}
|
||||
stored, _ := hex.DecodeString(k.KeyHash)
|
||||
if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
|
||||
continue
|
||||
}
|
||||
k.LastUsedAt = &now
|
||||
result := *k
|
||||
return &result, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired key")
|
||||
}
|
||||
256
pkg/security/keystore_database.go
Normal file
256
pkg/security/keystore_database.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
)
|
||||
|
||||
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
|
||||
type DatabaseKeyStoreOptions struct {
|
||||
// Cache is an optional cache instance. If nil, uses the default cache.
|
||||
Cache *cache.Cache
|
||||
// CacheTTL is the duration to cache ValidateKey results.
|
||||
// Default: 2 minutes.
|
||||
CacheTTL time.Duration
|
||||
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
|
||||
SQLNames *KeyStoreSQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
|
||||
// All DB operations go through configurable procedure names; the raw key is
|
||||
// never passed to the database.
|
||||
//
|
||||
// See keystore_schema.sql for the required table and procedure definitions.
|
||||
//
|
||||
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
|
||||
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
|
||||
// (default 2 minutes) if the cache entry cannot be invalidated.
|
||||
type DatabaseKeyStore struct {
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *KeyStoreSQLNames
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
|
||||
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
|
||||
o := DatabaseKeyStoreOptions{}
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
if o.CacheTTL == 0 {
|
||||
o.CacheTTL = 2 * time.Minute
|
||||
}
|
||||
c := o.Cache
|
||||
if c == nil {
|
||||
c = cache.GetDefaultCache()
|
||||
}
|
||||
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
|
||||
return &DatabaseKeyStore{
|
||||
db: db,
|
||||
dbFactory: o.DBFactory,
|
||||
sqlNames: names,
|
||||
cache: c,
|
||||
cacheTTL: o.CacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) getDB() *sql.DB {
|
||||
ks.dbMu.RLock()
|
||||
defer ks.dbMu.RUnlock()
|
||||
return ks.db
|
||||
}
|
||||
|
||||
func (ks *DatabaseKeyStore) reconnectDB() error {
|
||||
if ks.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := ks.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ks.dbMu.Lock()
|
||||
ks.db = newDB
|
||||
ks.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
|
||||
// and returns the raw key once.
|
||||
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||
rawBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(rawBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
||||
}
|
||||
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
|
||||
type createRequest struct {
|
||||
UserID int `json:"user_id"`
|
||||
KeyType KeyType `json:"key_type"`
|
||||
KeyHash string `json:"key_hash"`
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
Meta map[string]any `json:"meta,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
reqJSON, err := json.Marshal(createRequest{
|
||||
UserID: req.UserID,
|
||||
KeyType: req.KeyType,
|
||||
KeyHash: hash,
|
||||
Name: req.Name,
|
||||
Scopes: req.Scopes,
|
||||
Meta: req.Meta,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keyJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
|
||||
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||
return nil, fmt.Errorf("create key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
|
||||
}
|
||||
|
||||
var key UserKey
|
||||
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse created key: %w", err)
|
||||
}
|
||||
|
||||
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
||||
}
|
||||
|
||||
// GetUserKeys returns all active, non-expired keys for the given user.
|
||||
// Pass an empty KeyType to return all types.
|
||||
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keysJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
||||
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
|
||||
}
|
||||
|
||||
var keys []UserKey
|
||||
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
|
||||
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user keys: %w", err)
|
||||
}
|
||||
}
|
||||
if keys == nil {
|
||||
keys = []UserKey{}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
|
||||
// The delete procedure returns the key_hash so no separate lookup is needed.
|
||||
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
|
||||
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keyHash sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
|
||||
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
||||
return fmt.Errorf("delete key procedure failed: %w", err)
|
||||
}
|
||||
if !success {
|
||||
return errors.New(nullStringOr(errorMsg, "delete key failed"))
|
||||
}
|
||||
|
||||
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
|
||||
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey hashes the raw key and calls the validate procedure.
|
||||
// Results are cached for CacheTTL to reduce DB load on hot paths.
|
||||
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
|
||||
hash := hashSHA256Hex(rawKey)
|
||||
cacheKey := keystoreCacheKey(hash)
|
||||
|
||||
if ks.cache != nil {
|
||||
var cached UserKey
|
||||
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
|
||||
if cached.IsActive {
|
||||
return &cached, nil
|
||||
}
|
||||
return nil, errors.New("invalid or expired key")
|
||||
}
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var keyJSON sql.NullString
|
||||
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
||||
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
|
||||
}
|
||||
if err := runQuery(); err != nil {
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := ks.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||
}
|
||||
}
|
||||
if !success {
|
||||
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
|
||||
}
|
||||
|
||||
var key UserKey
|
||||
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse validated key: %w", err)
|
||||
}
|
||||
|
||||
if ks.cache != nil {
|
||||
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
|
||||
}
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func keystoreCacheKey(hash string) string {
|
||||
return "keystore:validate:" + hash
|
||||
}
|
||||
|
||||
// nullStringOr returns s.String if valid, otherwise the fallback.
|
||||
func nullStringOr(s sql.NullString, fallback string) string {
|
||||
if s.Valid && s.String != "" {
|
||||
return s.String
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
187
pkg/security/keystore_schema.sql
Normal file
187
pkg/security/keystore_schema.sql
Normal file
@@ -0,0 +1,187 @@
|
||||
-- Keystore schema for per-user auth keys
|
||||
-- Apply alongside database_schema.sql (requires the users table)
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_keys (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
key_type VARCHAR(50) NOT NULL,
|
||||
key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
|
||||
name VARCHAR(255) NOT NULL DEFAULT '',
|
||||
scopes TEXT, -- JSON array, e.g. '["read","write"]'
|
||||
meta JSONB,
|
||||
expires_at TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_used_at TIMESTAMP,
|
||||
is_active BOOLEAN DEFAULT true
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
|
||||
|
||||
-- resolvespec_keystore_get_user_keys
|
||||
-- Returns all active, non-expired keys for a user.
|
||||
-- Pass empty p_key_type to return all key types.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
|
||||
p_user_id INTEGER,
|
||||
p_key_type TEXT DEFAULT ''
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_keys JSONB;
|
||||
BEGIN
|
||||
SELECT COALESCE(
|
||||
jsonb_agg(
|
||||
jsonb_build_object(
|
||||
'id', k.id,
|
||||
'user_id', k.user_id,
|
||||
'key_type', k.key_type,
|
||||
'name', k.name,
|
||||
'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
|
||||
'meta', COALESCE(k.meta, '{}'::jsonb),
|
||||
'expires_at', k.expires_at,
|
||||
'created_at', k.created_at,
|
||||
'last_used_at', k.last_used_at,
|
||||
'is_active', k.is_active
|
||||
)
|
||||
),
|
||||
'[]'::jsonb
|
||||
)
|
||||
INTO v_keys
|
||||
FROM user_keys k
|
||||
WHERE k.user_id = p_user_id
|
||||
AND k.is_active = true
|
||||
AND (k.expires_at IS NULL OR k.expires_at > NOW())
|
||||
AND (p_key_type = '' OR k.key_type = p_key_type);
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_keys;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- resolvespec_keystore_create_key
|
||||
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
|
||||
-- Returns the created key record (without key_hash).
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
|
||||
p_request JSONB
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_id BIGINT;
|
||||
v_created_at TIMESTAMP;
|
||||
v_key JSONB;
|
||||
BEGIN
|
||||
INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
|
||||
VALUES (
|
||||
(p_request->>'user_id')::INTEGER,
|
||||
p_request->>'key_type',
|
||||
p_request->>'key_hash',
|
||||
COALESCE(p_request->>'name', ''),
|
||||
p_request->>'scopes',
|
||||
p_request->'meta',
|
||||
CASE WHEN p_request->>'expires_at' IS NOT NULL
|
||||
THEN (p_request->>'expires_at')::TIMESTAMP
|
||||
ELSE NULL
|
||||
END
|
||||
)
|
||||
RETURNING id, created_at INTO v_id, v_created_at;
|
||||
|
||||
v_key := jsonb_build_object(
|
||||
'id', v_id,
|
||||
'user_id', (p_request->>'user_id')::INTEGER,
|
||||
'key_type', p_request->>'key_type',
|
||||
'name', COALESCE(p_request->>'name', ''),
|
||||
'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
|
||||
THEN (p_request->>'scopes')::jsonb
|
||||
ELSE '[]'::jsonb END,
|
||||
'meta', COALESCE(p_request->'meta', '{}'::jsonb),
|
||||
'expires_at', p_request->>'expires_at',
|
||||
'created_at', v_created_at,
|
||||
'is_active', true
|
||||
);
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_key;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- resolvespec_keystore_delete_key
|
||||
-- Soft-deletes a key (is_active = false) after verifying ownership.
|
||||
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
|
||||
p_user_id INTEGER,
|
||||
p_key_id BIGINT
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_hash TEXT;
|
||||
BEGIN
|
||||
UPDATE user_keys
|
||||
SET is_active = false
|
||||
WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
|
||||
RETURNING key_hash INTO v_hash;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_hash;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- resolvespec_keystore_validate_key
|
||||
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
|
||||
-- updates last_used_at, and returns the key record.
|
||||
-- p_key_type can be empty to accept any key type.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
|
||||
p_key_hash TEXT,
|
||||
p_key_type TEXT DEFAULT ''
|
||||
)
|
||||
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_key_rec user_keys%ROWTYPE;
|
||||
v_key JSONB;
|
||||
BEGIN
|
||||
SELECT * INTO v_key_rec
|
||||
FROM user_keys
|
||||
WHERE key_hash = p_key_hash
|
||||
AND is_active = true
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
AND (p_key_type = '' OR key_type = p_key_type);
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
|
||||
|
||||
v_key := jsonb_build_object(
|
||||
'id', v_key_rec.id,
|
||||
'user_id', v_key_rec.user_id,
|
||||
'key_type', v_key_rec.key_type,
|
||||
'name', v_key_rec.name,
|
||||
'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
|
||||
THEN v_key_rec.scopes::jsonb
|
||||
ELSE '[]'::jsonb END,
|
||||
'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
|
||||
'expires_at', v_key_rec.expires_at,
|
||||
'created_at', v_key_rec.created_at,
|
||||
'last_used_at', NOW(),
|
||||
'is_active', v_key_rec.is_active
|
||||
);
|
||||
|
||||
RETURN QUERY SELECT true, NULL::TEXT, v_key;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||
END;
|
||||
$$;
|
||||
61
pkg/security/keystore_sql_names.go
Normal file
61
pkg/security/keystore_sql_names.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package security
|
||||
|
||||
import "fmt"
|
||||
|
||||
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
|
||||
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
|
||||
type KeyStoreSQLNames struct {
|
||||
GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
|
||||
CreateKey string // default: "resolvespec_keystore_create_key"
|
||||
DeleteKey string // default: "resolvespec_keystore_delete_key"
|
||||
ValidateKey string // default: "resolvespec_keystore_validate_key"
|
||||
}
|
||||
|
||||
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
|
||||
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
|
||||
return &KeyStoreSQLNames{
|
||||
GetUserKeys: "resolvespec_keystore_get_user_keys",
|
||||
CreateKey: "resolvespec_keystore_create_key",
|
||||
DeleteKey: "resolvespec_keystore_delete_key",
|
||||
ValidateKey: "resolvespec_keystore_validate_key",
|
||||
}
|
||||
}
|
||||
|
||||
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||
// If override is nil, a copy of base is returned.
|
||||
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
|
||||
if override == nil {
|
||||
copied := *base
|
||||
return &copied
|
||||
}
|
||||
merged := *base
|
||||
if override.GetUserKeys != "" {
|
||||
merged.GetUserKeys = override.GetUserKeys
|
||||
}
|
||||
if override.CreateKey != "" {
|
||||
merged.CreateKey = override.CreateKey
|
||||
}
|
||||
if override.DeleteKey != "" {
|
||||
merged.DeleteKey = override.DeleteKey
|
||||
}
|
||||
if override.ValidateKey != "" {
|
||||
merged.ValidateKey = override.ValidateKey
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
|
||||
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
|
||||
fields := map[string]string{
|
||||
"GetUserKeys": names.GetUserKeys,
|
||||
"CreateKey": names.CreateKey,
|
||||
"DeleteKey": names.DeleteKey,
|
||||
"ValidateKey": names.ValidateKey,
|
||||
}
|
||||
for field, val := range fields {
|
||||
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -244,7 +244,7 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||
@@ -287,7 +287,7 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||
@@ -385,7 +385,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
@@ -451,7 +451,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var updateSuccess bool
|
||||
var updateErrMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
@@ -472,7 +472,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var userErrMsg *string
|
||||
var userData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
|
||||
917
pkg/security/oauth_server.go
Normal file
917
pkg/security/oauth_server.go
Normal file
@@ -0,0 +1,917 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerConfig configures the MCP-standard OAuth2 authorization server.
|
||||
type OAuthServerConfig struct {
|
||||
// Issuer is the public base URL of this server (e.g. "https://api.example.com").
|
||||
// Used in /.well-known/oauth-authorization-server and to build endpoint URLs.
|
||||
Issuer string
|
||||
|
||||
// ProviderCallbackPath is the path on this server that external OAuth2 providers
|
||||
// redirect back to. Defaults to "/oauth/provider/callback".
|
||||
ProviderCallbackPath string
|
||||
|
||||
// LoginTitle is shown on the built-in login form when the server acts as its own
|
||||
// identity provider. Defaults to "Sign in".
|
||||
LoginTitle string
|
||||
|
||||
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
|
||||
// Clients registered during a session survive server restarts.
|
||||
PersistClients bool
|
||||
|
||||
// PersistCodes stores authorization codes in the database.
|
||||
// Useful for multi-instance deployments. Defaults to in-memory.
|
||||
PersistCodes bool
|
||||
|
||||
// DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"].
|
||||
DefaultScopes []string
|
||||
|
||||
// AccessTokenTTL is the issued token lifetime. Defaults to 24h.
|
||||
AccessTokenTTL time.Duration
|
||||
|
||||
// AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes.
|
||||
AuthCodeTTL time.Duration
|
||||
}
|
||||
|
||||
// oauthClient is a dynamically registered OAuth2 client (RFC 7591).
|
||||
type oauthClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// pendingAuth tracks an in-progress authorization code exchange.
|
||||
type pendingAuth struct {
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
ClientState string
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
ProviderName string // empty = password login
|
||||
ExpiresAt time.Time
|
||||
SessionToken string // set after authentication completes
|
||||
RefreshToken string // set after authentication completes when refresh tokens are issued
|
||||
Scopes []string // requested scopes
|
||||
}
|
||||
|
||||
// externalProvider pairs a DatabaseAuthenticator with its provider name.
|
||||
type externalProvider struct {
|
||||
auth *DatabaseAuthenticator
|
||||
providerName string
|
||||
}
|
||||
|
||||
// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE).
|
||||
//
|
||||
// It can act as both:
|
||||
// - A direct identity provider using DatabaseAuthenticator username/password login
|
||||
// - A federation layer that delegates authentication to external OAuth2 providers
|
||||
// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider
|
||||
//
|
||||
// The server exposes these RFC-compliant endpoints:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization
|
||||
// POST /oauth/authorize Direct login form submission
|
||||
// POST /oauth/token Token exchange and refresh
|
||||
// POST /oauth/revoke RFC 7009 — token revocation
|
||||
// POST /oauth/introspect RFC 7662 — token introspection
|
||||
// GET {ProviderCallbackPath} Internal — external provider callback
|
||||
type OAuthServer struct {
|
||||
cfg OAuthServerConfig
|
||||
auth *DatabaseAuthenticator // nil = only external providers
|
||||
providers []externalProvider
|
||||
|
||||
mu sync.RWMutex
|
||||
clients map[string]*oauthClient
|
||||
pending map[string]*pendingAuth // provider_state → pending (external flow)
|
||||
codes map[string]*pendingAuth // auth_code → pending (post-auth)
|
||||
|
||||
done chan struct{} // closed by Close() to stop background goroutines
|
||||
}
|
||||
|
||||
// NewOAuthServer creates a new MCP OAuth2 authorization server.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
|
||||
// acts as its own identity provider). Pass nil to use only external providers.
|
||||
// External providers are added separately via RegisterExternalProvider.
|
||||
//
|
||||
// Call Close() to stop background goroutines when the server is no longer needed.
|
||||
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
|
||||
if cfg.ProviderCallbackPath == "" {
|
||||
cfg.ProviderCallbackPath = "/oauth/provider/callback"
|
||||
}
|
||||
if cfg.LoginTitle == "" {
|
||||
cfg.LoginTitle = "Sign in"
|
||||
}
|
||||
if len(cfg.DefaultScopes) == 0 {
|
||||
cfg.DefaultScopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
if cfg.AccessTokenTTL == 0 {
|
||||
cfg.AccessTokenTTL = 24 * time.Hour
|
||||
}
|
||||
if cfg.AuthCodeTTL == 0 {
|
||||
cfg.AuthCodeTTL = 2 * time.Minute
|
||||
}
|
||||
// Normalize issuer: remove trailing slash to ensure consistent endpoint URL construction.
|
||||
cfg.Issuer = strings.TrimSuffix(cfg.Issuer, "/")
|
||||
s := &OAuthServer{
|
||||
cfg: cfg,
|
||||
auth: auth,
|
||||
clients: make(map[string]*oauthClient),
|
||||
pending: make(map[string]*pendingAuth),
|
||||
codes: make(map[string]*pendingAuth),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go s.cleanupExpired()
|
||||
return s
|
||||
}
|
||||
|
||||
// Close stops the background goroutines started by NewOAuthServer.
|
||||
// It is safe to call Close multiple times.
|
||||
func (s *OAuthServer) Close() {
|
||||
select {
|
||||
case <-s.done:
|
||||
// already closed
|
||||
default:
|
||||
close(s.done)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
|
||||
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
|
||||
// configured with WithOAuth2(providerName, ...) before calling this.
|
||||
// Multiple providers can be registered; the first is used as the default.
|
||||
// All providers must be registered before the server starts serving requests.
|
||||
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
|
||||
s.mu.Lock()
|
||||
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// ProviderCallbackPath returns the configured path for external provider callbacks.
|
||||
func (s *OAuthServer) ProviderCallbackPath() string {
|
||||
return s.cfg.ProviderCallbackPath
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints.
|
||||
// Mount it at the root of your HTTP server alongside the MCP transport.
|
||||
//
|
||||
// mux := http.NewServeMux()
|
||||
// mux.Handle("/", oauthServer.HTTPHandler())
|
||||
// mux.Handle("/mcp/", mcpTransport)
|
||||
func (s *OAuthServer) HTTPHandler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler)
|
||||
mux.HandleFunc("/oauth/register", s.registerHandler)
|
||||
mux.HandleFunc("/oauth/authorize", s.authorizeHandler)
|
||||
mux.HandleFunc("/oauth/token", s.tokenHandler)
|
||||
mux.HandleFunc("/oauth/revoke", s.revokeHandler)
|
||||
mux.HandleFunc("/oauth/introspect", s.introspectHandler)
|
||||
mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler)
|
||||
return mux
|
||||
}
|
||||
|
||||
// cleanupExpired removes stale pending auths and codes every 5 minutes.
|
||||
func (s *OAuthServer) cleanupExpired() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, p := range s.pending {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.pending, k)
|
||||
}
|
||||
}
|
||||
for k, p := range s.codes {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.codes, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 8414 — Server metadata
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
issuer := s.cfg.Issuer
|
||||
meta := map[string]interface{}{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": issuer + "/oauth/authorize",
|
||||
"token_endpoint": issuer + "/oauth/token",
|
||||
"registration_endpoint": issuer + "/oauth/register",
|
||||
"revocation_endpoint": issuer + "/oauth/revoke",
|
||||
"introspection_endpoint": issuer + "/oauth/introspect",
|
||||
"scopes_supported": s.cfg.DefaultScopes,
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"token_endpoint_auth_methods_supported": []string{"none"},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(meta) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7591 — Dynamic client registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(req.RedirectURIs) == 0 {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
grantTypes := req.GrantTypes
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = []string{"authorization_code"}
|
||||
}
|
||||
allowedScopes := req.AllowedScopes
|
||||
if len(allowedScopes) == 0 {
|
||||
allowedScopes = s.cfg.DefaultScopes
|
||||
}
|
||||
clientID, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
client := &oauthClient{
|
||||
ClientID: clientID,
|
||||
RedirectURIs: req.RedirectURIs,
|
||||
ClientName: req.ClientName,
|
||||
GrantTypes: grantTypes,
|
||||
AllowedScopes: allowedScopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistClients && s.auth != nil {
|
||||
dbClient := &OAuthServerClient{
|
||||
ClientID: client.ClientID,
|
||||
RedirectURIs: client.RedirectURIs,
|
||||
ClientName: client.ClientName,
|
||||
GrantTypes: client.GrantTypes,
|
||||
AllowedScopes: client.AllowedScopes,
|
||||
}
|
||||
if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = client
|
||||
s.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(client) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Authorization endpoint — GET + POST /oauth/authorize
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
s.authorizeGet(w, r)
|
||||
case http.MethodPost:
|
||||
s.authorizePost(w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeGet validates the request and either:
|
||||
// - Redirects to an external provider (if providers are registered)
|
||||
// - Renders a login form (if the server is its own identity provider)
|
||||
func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
clientState := q.Get("state")
|
||||
codeChallenge := q.Get("code_challenge")
|
||||
codeChallengeMethod := q.Get("code_challenge_method")
|
||||
providerName := q.Get("provider")
|
||||
scopeStr := q.Get("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
if q.Get("response_type") != "code" {
|
||||
writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallenge == "" {
|
||||
writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallengeMethod != "" && codeChallengeMethod != "S256" {
|
||||
writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok {
|
||||
writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// External provider path
|
||||
if len(s.providers) > 0 {
|
||||
s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes)
|
||||
return
|
||||
}
|
||||
|
||||
// Direct login form path (server is its own identity provider)
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "")
|
||||
}
|
||||
|
||||
// authorizePost handles login form submission for the direct login flow.
|
||||
func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
clientID := r.FormValue("client_id")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientState := r.FormValue("client_state")
|
||||
codeChallenge := r.FormValue("code_challenge")
|
||||
codeChallengeMethod := r.FormValue("code_challenge_method")
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
scopeStr := r.FormValue("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := s.auth.Login(r.Context(), LoginRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
|
||||
}
|
||||
|
||||
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
|
||||
func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
var provider *externalProvider
|
||||
if providerName != "" {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == providerName {
|
||||
provider = &s.providers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = &s.providers[0]
|
||||
}
|
||||
|
||||
providerState, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: provider.providerName,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Scopes: scopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.pending[providerState] = pending
|
||||
s.mu.Unlock()
|
||||
|
||||
authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// External provider callback — GET {ProviderCallbackPath}
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
providerState := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
pending, ok := s.pending[providerState]
|
||||
if ok {
|
||||
delete(s.pending, providerState)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
http.Error(w, "invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
provider := s.providerByName(pending.ProviderName)
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, loginResp.RefreshToken,
|
||||
pending.ClientID, pending.RedirectURI, pending.ClientState,
|
||||
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
|
||||
}
|
||||
|
||||
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
|
||||
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, refreshToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
authCode, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: providerName,
|
||||
SessionToken: sessionToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode := &OAuthCode{
|
||||
Code: authCode,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
SessionToken: sessionToken,
|
||||
RefreshToken: refreshToken,
|
||||
Scopes: scopes,
|
||||
ExpiresAt: pending.ExpiresAt,
|
||||
}
|
||||
if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
s.codes[authCode] = pending
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
qp := redirectURL.Query()
|
||||
qp.Set("code", authCode)
|
||||
if clientState != "" {
|
||||
qp.Set("state", clientState)
|
||||
}
|
||||
redirectURL.RawQuery = qp.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Token endpoint — POST /oauth/token
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch r.FormValue("grant_type") {
|
||||
case "authorization_code":
|
||||
s.handleAuthCodeGrant(w, r)
|
||||
case "refresh_token":
|
||||
s.handleRefreshGrant(w, r)
|
||||
default:
|
||||
writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.FormValue("code")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientID := r.FormValue("client_id")
|
||||
codeVerifier := r.FormValue("code_verifier")
|
||||
|
||||
if code == "" || codeVerifier == "" {
|
||||
writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var sessionToken string
|
||||
var refreshToken string
|
||||
var scopes []string
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = oauthCode.SessionToken
|
||||
refreshToken = oauthCode.RefreshToken
|
||||
scopes = oauthCode.Scopes
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
pending, ok := s.codes[code]
|
||||
if ok {
|
||||
delete(s.codes, code)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = pending.SessionToken
|
||||
refreshToken = pending.RefreshToken
|
||||
scopes = pending.Scopes
|
||||
}
|
||||
|
||||
s.writeOAuthToken(w, sessionToken, refreshToken, scopes)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
|
||||
refreshToken := r.FormValue("refresh_token")
|
||||
providerName := r.FormValue("provider")
|
||||
if refreshToken == "" {
|
||||
writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Try external providers first, then fall back to DatabaseAuthenticator
|
||||
provider := s.providerByName(providerName)
|
||||
if provider != nil {
|
||||
loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
s.writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7009 — Token revocation
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
if token == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
|
||||
} else {
|
||||
// In external-provider-only mode, attempt revocation via the first provider's auth.
|
||||
s.mu.RLock()
|
||||
var providerAuth *DatabaseAuthenticator
|
||||
if len(s.providers) > 0 {
|
||||
providerAuth = s.providers[0].auth
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
if providerAuth != nil {
|
||||
providerAuth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7662 — Token introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if token == "" {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the authenticator to use: prefer the primary auth, then the first provider's auth.
|
||||
authToUse := s.auth
|
||||
if authToUse == nil {
|
||||
s.mu.RLock()
|
||||
if len(s.providers) > 0 {
|
||||
authToUse = s.providers[0].auth
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
if authToUse == nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
info, err := authToUse.OAuthIntrospectToken(r.Context(), token)
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(info) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Login form (direct identity provider mode)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
errHTML := ""
|
||||
if errMsg != "" {
|
||||
errHTML = `<p style="color:red">` + errMsg + `</p>`
|
||||
}
|
||||
fmt.Fprintf(w, loginFormHTML,
|
||||
s.cfg.LoginTitle,
|
||||
s.cfg.LoginTitle,
|
||||
errHTML,
|
||||
clientID,
|
||||
htmlEscape(redirectURI),
|
||||
htmlEscape(clientState),
|
||||
htmlEscape(codeChallenge),
|
||||
htmlEscape(codeChallengeMethod),
|
||||
htmlEscape(scope),
|
||||
)
|
||||
}
|
||||
|
||||
const loginFormHTML = `<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>%s</title>
|
||||
<style>body{font-family:sans-serif;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#f5f5f5}
|
||||
.card{background:#fff;padding:2rem;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.15);width:320px}
|
||||
h2{margin:0 0 1.5rem;font-size:1.25rem}
|
||||
label{display:block;margin-bottom:.25rem;font-size:.875rem;color:#555}
|
||||
input[type=text],input[type=password]{width:100%%;box-sizing:border-box;padding:.5rem;border:1px solid #ccc;border-radius:4px;margin-bottom:1rem;font-size:1rem}
|
||||
button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;border-radius:4px;font-size:1rem;cursor:pointer}
|
||||
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
|
||||
</head><body><div class="card">
|
||||
<h2>%s</h2>%s
|
||||
<form method="POST" action="authorize">
|
||||
<input type="hidden" name="client_id" value="%s">
|
||||
<input type="hidden" name="redirect_uri" value="%s">
|
||||
<input type="hidden" name="client_state" value="%s">
|
||||
<input type="hidden" name="code_challenge" value="%s">
|
||||
<input type="hidden" name="code_challenge_method" value="%s">
|
||||
<input type="hidden" name="scope" value="%s">
|
||||
<label>Username</label><input type="text" name="username" autofocus autocomplete="username">
|
||||
<label>Password</label><input type="password" name="password" autocomplete="current-password">
|
||||
<button type="submit">Sign in</button>
|
||||
</form></div></body></html>`
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled.
|
||||
func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) {
|
||||
s.mu.RLock()
|
||||
c, ok := s.clients[clientID]
|
||||
s.mu.RUnlock()
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
|
||||
if !s.cfg.PersistClients || s.auth == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
dbClient, err := s.auth.OAuthGetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c = &oauthClient{
|
||||
ClientID: dbClient.ClientID,
|
||||
RedirectURIs: dbClient.RedirectURIs,
|
||||
ClientName: dbClient.ClientName,
|
||||
GrantTypes: dbClient.GrantTypes,
|
||||
AllowedScopes: dbClient.AllowedScopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = c
|
||||
s.mu.Unlock()
|
||||
return c, true
|
||||
}
|
||||
|
||||
func (s *OAuthServer) providerByName(name string) *externalProvider {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == name {
|
||||
return &s.providers[i]
|
||||
}
|
||||
}
|
||||
// If name is empty and only one provider exists, return it
|
||||
if name == "" && len(s.providers) == 1 {
|
||||
return &s.providers[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePKCESHA256(challenge, verifier string) bool {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:]) == challenge
|
||||
}
|
||||
|
||||
func randomOAuthToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func oauthSliceContains(slice []string, s string) bool {
|
||||
for _, v := range slice {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OAuthServer) writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
|
||||
expiresIn := int64(s.cfg.AccessTokenTTL.Seconds())
|
||||
resp := map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": expiresIn,
|
||||
}
|
||||
if refreshToken != "" {
|
||||
resp["refresh_token"] = refreshToken
|
||||
}
|
||||
if len(scopes) > 0 {
|
||||
resp["scope"] = strings.Join(scopes, " ")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) {
|
||||
resp := map[string]string{"error": errCode}
|
||||
if description != "" {
|
||||
resp["error_description"] = description
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func htmlEscape(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, `"`, """)
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
return s
|
||||
}
|
||||
204
pkg/security/oauth_server_db.go
Normal file
204
pkg/security/oauth_server_db.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client.
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthCode is a short-lived authorization code.
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// OAuthTokenInfo is the RFC 7662 token introspection response.
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
UserLevel int `json:"user_level,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthRegisterClient persists an OAuth2 client registration.
|
||||
func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) {
|
||||
input, err := json.Marshal(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal client: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to register client")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registered client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthGetClient retrieves a registered client by ID.
|
||||
func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("client not found")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthSaveCode persists an authorization code.
|
||||
func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error {
|
||||
input, err := json.Marshal(code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal code: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to save code")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuthExchangeCode retrieves and deletes an authorization code (single use).
|
||||
func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired code")
|
||||
}
|
||||
|
||||
var result OAuthCode
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse code data: %w", err)
|
||||
}
|
||||
result.Code = code
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662).
|
||||
func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to introspect token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("introspection failed")
|
||||
}
|
||||
|
||||
var result OAuthTokenInfo
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token info: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthRevokeToken revokes a token by deleting the session (RFC 7009).
|
||||
func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to revoke token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -7,18 +7,21 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabasePasskeyProvider struct {
|
||||
db *sql.DB
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||
@@ -33,6 +36,9 @@ type DatabasePasskeyProviderOptions struct {
|
||||
Timeout int64
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
SQLNames *SQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
@@ -44,15 +50,36 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
db: db,
|
||||
dbFactory: opts.DBFactory,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabasePasskeyProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginRegistration creates registration options for a new passkey
|
||||
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
||||
// Generate challenge
|
||||
@@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
||||
var credentialID sql.NullInt64
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
}
|
||||
@@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
}
|
||||
@@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
}
|
||||
@@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
}
|
||||
@@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
}
|
||||
@@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
}
|
||||
|
||||
@@ -63,10 +63,12 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
@@ -88,6 +90,9 @@ type DatabaseAuthenticatorOptions struct {
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
// Partial overrides are supported: only set the fields you want to change.
|
||||
SQLNames *SQLNames
|
||||
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||
// If nil, reconnection is disabled.
|
||||
DBFactory func() (*sql.DB, error)
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -110,6 +115,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
dbFactory: opts.DBFactory,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
sqlNames: sqlNames,
|
||||
@@ -117,6 +123,42 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
}
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) getDB() *sql.DB {
|
||||
a.dbMu.RLock()
|
||||
defer a.dbMu.RUnlock()
|
||||
return a.db
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) reconnectDB() error {
|
||||
if a.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := a.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.dbMu.Lock()
|
||||
a.db = newDB
|
||||
a.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) runDBOpWithReconnect(run func(*sql.DB) error) error {
|
||||
db := a.getDB()
|
||||
if db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
err := run(db)
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = run(a.getDB())
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Convert LoginRequest to JSON
|
||||
reqJSON, err := json.Marshal(req)
|
||||
@@ -128,8 +170,10 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -162,8 +206,10 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
}
|
||||
@@ -195,8 +241,10 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -269,8 +317,10 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
return db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
}
|
||||
@@ -345,8 +395,10 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
_ = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
return db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
@@ -356,8 +408,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
err := a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
return db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
}
|
||||
@@ -373,8 +427,10 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
return db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
}
|
||||
@@ -406,6 +462,8 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
type JWTAuthenticator struct {
|
||||
secretKey []byte
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
@@ -417,13 +475,47 @@ func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTA
|
||||
}
|
||||
}
|
||||
|
||||
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
|
||||
a.dbFactory = factory
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) getDB() *sql.DB {
|
||||
a.dbMu.RLock()
|
||||
defer a.dbMu.RUnlock()
|
||||
return a.db
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) reconnectDB() error {
|
||||
if a.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := a.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.dbMu.Lock()
|
||||
a.db = newDB
|
||||
a.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON []byte
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
runLoginQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
}
|
||||
err := runLoginQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runLoginQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -476,7 +568,7 @@ func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -513,14 +605,41 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
// All database operations go through stored procedures
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseColumnSecurityProvider struct {
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||
var rules []ColumnSecurity
|
||||
|
||||
@@ -528,8 +647,16 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
var errorMsg sql.NullString
|
||||
var rulesJSON []byte
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||
}
|
||||
@@ -578,21 +705,55 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
// All database operations go through stored procedures
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseRowSecurityProvider struct {
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
db *sql.DB
|
||||
dbMu sync.RWMutex
|
||||
dbFactory func() (*sql.DB, error)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
|
||||
p.dbFactory = factory
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
|
||||
p.dbMu.RLock()
|
||||
defer p.dbMu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
|
||||
if p.dbFactory == nil {
|
||||
return fmt.Errorf("no db factory configured for reconnect")
|
||||
}
|
||||
newDB, err := p.dbFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.dbMu.Lock()
|
||||
p.db = newDB
|
||||
p.dbMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||
var template string
|
||||
var hasBlock bool
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
|
||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
runQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
}
|
||||
err := runQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||
err = runQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
|
||||
}
|
||||
@@ -662,6 +823,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
|
||||
// Helper functions
|
||||
// ================
|
||||
|
||||
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||
func isDBClosed(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||
}
|
||||
|
||||
func parseRoles(rolesStr string) []string {
|
||||
if rolesStr == "" {
|
||||
return []string{}
|
||||
@@ -780,8 +946,16 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
runPasskeyQuery := func() error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
}
|
||||
err = runPasskeyQuery()
|
||||
if isDBClosed(err) {
|
||||
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||
err = runPasskeyQuery()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package security
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -790,6 +791,211 @@ func TestDatabaseAuthenticatorRefreshToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabaseAuthenticatorReconnectsClosedDBPaths(t *testing.T) {
|
||||
newAuthWithReconnect := func(t *testing.T) (*DatabaseAuthenticator, sqlmock.Sqlmock, sqlmock.Sqlmock, func()) {
|
||||
t.Helper()
|
||||
|
||||
primaryDB, primaryMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create primary mock db: %v", err)
|
||||
}
|
||||
|
||||
reconnectDB, reconnectMock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
primaryDB.Close()
|
||||
t.Fatalf("failed to create reconnect mock db: %v", err)
|
||||
}
|
||||
|
||||
cacheProvider := cache.NewMemoryProvider(&cache.Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
})
|
||||
|
||||
auth := NewDatabaseAuthenticatorWithOptions(primaryDB, DatabaseAuthenticatorOptions{
|
||||
Cache: cache.NewCache(cacheProvider),
|
||||
DBFactory: func() (*sql.DB, error) {
|
||||
return reconnectDB, nil
|
||||
},
|
||||
})
|
||||
|
||||
cleanup := func() {
|
||||
_ = primaryDB.Close()
|
||||
_ = reconnectDB.Close()
|
||||
}
|
||||
|
||||
return auth, primaryMock, reconnectMock, cleanup
|
||||
}
|
||||
|
||||
t.Run("Authenticate reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer reconnect-auth-token")
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":7,"user_name":"reconnect-user","session_id":"reconnect-auth-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs("reconnect-auth-token", "authenticate").
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
userCtx, err := auth.Authenticate(req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected authenticate to reconnect, got %v", err)
|
||||
}
|
||||
if userCtx.UserID != 7 {
|
||||
t.Fatalf("expected user ID 7, got %d", userCtx.UserID)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Register reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := RegisterRequest{
|
||||
Username: "reconnect-register",
|
||||
Password: "password123",
|
||||
Email: "reconnect@example.com",
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, `{"token":"reconnected-register-token","user":{"user_id":8,"user_name":"reconnect-register"},"expires_in":86400}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_register`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
resp, err := auth.Register(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("expected register to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "reconnected-register-token" {
|
||||
t.Fatalf("expected refreshed token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Logout reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
req := LogoutRequest{Token: "logout-reconnect-token", UserID: 9}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_data"}).
|
||||
AddRow(true, nil, nil)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_data::text FROM resolvespec_logout`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
if err := auth.Logout(context.Background(), req); err != nil {
|
||||
t.Fatalf("expected logout to reconnect, got %v", err)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RefreshToken reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
refreshToken := "refresh-reconnect-token"
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
sessionRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||
WithArgs(refreshToken, "refresh").
|
||||
WillReturnRows(sessionRows)
|
||||
|
||||
refreshRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":10,"user_name":"refresh-user","session_id":"refreshed-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token`).
|
||||
WithArgs(refreshToken, sqlmock.AnyArg()).
|
||||
WillReturnRows(refreshRows)
|
||||
|
||||
resp, err := auth.RefreshToken(context.Background(), refreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("expected refresh token to reconnect, got %v", err)
|
||||
}
|
||||
if resp.Token != "refreshed-token" {
|
||||
t.Fatalf("expected refreshed-token, got %s", resp.Token)
|
||||
}
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updateSessionActivity reconnects after closed database", func(t *testing.T) {
|
||||
auth, primaryMock, reconnectMock, cleanup := newAuthWithReconnect(t)
|
||||
defer cleanup()
|
||||
|
||||
userCtx := &UserContext{UserID: 11, UserName: "activity-user", SessionID: "activity-token"}
|
||||
|
||||
primaryMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnError(fmt.Errorf("sql: database is closed"))
|
||||
|
||||
reconnectRows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||
AddRow(true, nil, `{"user_id":11,"user_name":"activity-user","session_id":"activity-token"}`)
|
||||
|
||||
reconnectMock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session_update`).
|
||||
WithArgs("activity-token", sqlmock.AnyArg()).
|
||||
WillReturnRows(reconnectRows)
|
||||
|
||||
auth.updateSessionActivity(context.Background(), "activity-token", userCtx)
|
||||
|
||||
if err := primaryMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("primary db expectations not met: %v", err)
|
||||
}
|
||||
if err := reconnectMock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("reconnect db expectations not met: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test JWTAuthenticator
|
||||
func TestJWTAuthenticator(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
|
||||
@@ -54,6 +54,13 @@ type SQLNames struct {
|
||||
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
|
||||
OAuthGetUser string // default: "resolvespec_oauth_getuser"
|
||||
|
||||
// OAuth2 server procedures (OAuthServer persistence)
|
||||
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||
}
|
||||
|
||||
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
|
||||
@@ -93,6 +100,13 @@ func DefaultSQLNames() *SQLNames {
|
||||
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
|
||||
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
|
||||
OAuthGetUser: "resolvespec_oauth_getuser",
|
||||
|
||||
OAuthRegisterClient: "resolvespec_oauth_register_client",
|
||||
OAuthGetClient: "resolvespec_oauth_get_client",
|
||||
OAuthSaveCode: "resolvespec_oauth_save_code",
|
||||
OAuthExchangeCode: "resolvespec_oauth_exchange_code",
|
||||
OAuthIntrospect: "resolvespec_oauth_introspect",
|
||||
OAuthRevoke: "resolvespec_oauth_revoke",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,6 +205,24 @@ func MergeSQLNames(base, override *SQLNames) *SQLNames {
|
||||
if override.OAuthGetUser != "" {
|
||||
merged.OAuthGetUser = override.OAuthGetUser
|
||||
}
|
||||
if override.OAuthRegisterClient != "" {
|
||||
merged.OAuthRegisterClient = override.OAuthRegisterClient
|
||||
}
|
||||
if override.OAuthGetClient != "" {
|
||||
merged.OAuthGetClient = override.OAuthGetClient
|
||||
}
|
||||
if override.OAuthSaveCode != "" {
|
||||
merged.OAuthSaveCode = override.OAuthSaveCode
|
||||
}
|
||||
if override.OAuthExchangeCode != "" {
|
||||
merged.OAuthExchangeCode = override.OAuthExchangeCode
|
||||
}
|
||||
if override.OAuthIntrospect != "" {
|
||||
merged.OAuthIntrospect = override.OAuthIntrospect
|
||||
}
|
||||
if override.OAuthRevoke != "" {
|
||||
merged.OAuthRevoke = override.OAuthRevoke
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user