mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 09:26:24 +00:00
Compare commits
2 Commits
feature-au
...
feature-ke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79a3912f93 | ||
|
|
a9bf08f58b |
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
|||||||
// This demonstrates how the abstraction works with different ORMs
|
// This demonstrates how the abstraction works with different ORMs
|
||||||
type BunAdapter struct {
|
type BunAdapter struct {
|
||||||
db *bun.DB
|
db *bun.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*bun.DB, error)
|
||||||
driverName string
|
driverName string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
|||||||
return adapter
|
return adapter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
|
||||||
|
b.dbFactory = factory
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BunAdapter) getDB() *bun.DB {
|
||||||
|
b.dbMu.RLock()
|
||||||
|
defer b.dbMu.RUnlock()
|
||||||
|
return b.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BunAdapter) reconnectDB() error {
|
||||||
|
if b.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := b.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.dbMu.Lock()
|
||||||
|
b.db = newDB
|
||||||
|
b.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||||
// This is useful for debugging preload queries that may be failing
|
// This is useful for debugging preload queries that may be failing
|
||||||
func (b *BunAdapter) EnableQueryDebug() {
|
func (b *BunAdapter) EnableQueryDebug() {
|
||||||
b.db.AddQueryHook(&QueryDebugHook{})
|
b.getDB().AddQueryHook(&QueryDebugHook{})
|
||||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() {
|
|||||||
|
|
||||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||||
return &BunSelectQuery{
|
return &BunSelectQuery{
|
||||||
query: b.db.NewSelect(),
|
query: b.getDB().NewSelect(),
|
||||||
db: b.db,
|
db: b.db,
|
||||||
driverName: b.driverName,
|
driverName: b.driverName,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||||
return &BunInsertQuery{query: b.db.NewInsert()}
|
return &BunInsertQuery{query: b.getDB().NewInsert()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &BunUpdateQuery{query: b.db.NewUpdate()}
|
return &BunUpdateQuery{query: b.getDB().NewUpdate()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
return &BunDeleteQuery{query: b.getDB().NewDelete()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
@@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
|||||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result, err := b.db.ExecContext(ctx, query, args...)
|
var result sql.Result
|
||||||
|
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
|
|||||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||||
|
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
|
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -194,7 +236,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
|||||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||||
// Create adapter with transaction
|
// Create adapter with transaction
|
||||||
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
|
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
|
||||||
return fn(adapter)
|
return fn(adapter)
|
||||||
@@ -202,7 +244,7 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||||
return b.db
|
return b.getDB()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) DriverName() string {
|
func (b *BunAdapter) DriverName() string {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -17,6 +18,8 @@ import (
|
|||||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||||
type PgSQLAdapter struct {
|
type PgSQLAdapter struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
driverName string
|
driverName string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,6 +34,36 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
|||||||
return &PgSQLAdapter{db: db, driverName: name}
|
return &PgSQLAdapter{db: db, driverName: name}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
|
||||||
|
p.dbFactory = factory
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLAdapter) getDB() *sql.DB {
|
||||||
|
p.dbMu.RLock()
|
||||||
|
defer p.dbMu.RUnlock()
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLAdapter) reconnectDB() error {
|
||||||
|
if p.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := p.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbMu.Lock()
|
||||||
|
p.db = newDB
|
||||||
|
p.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDBClosed(err error) bool {
|
||||||
|
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||||
|
}
|
||||||
|
|
||||||
// EnableQueryDebug enables query debugging for development
|
// EnableQueryDebug enables query debugging for development
|
||||||
func (p *PgSQLAdapter) EnableQueryDebug() {
|
func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||||
logger.Info("PgSQL query debug mode - logging enabled via logger")
|
logger.Info("PgSQL query debug mode - logging enabled via logger")
|
||||||
@@ -38,7 +71,7 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
|||||||
|
|
||||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||||
return &PgSQLSelectQuery{
|
return &PgSQLSelectQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
driverName: p.driverName,
|
driverName: p.driverName,
|
||||||
columns: []string{"*"},
|
columns: []string{"*"},
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
@@ -47,7 +80,7 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
|||||||
|
|
||||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||||
return &PgSQLInsertQuery{
|
return &PgSQLInsertQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
driverName: p.driverName,
|
driverName: p.driverName,
|
||||||
values: make(map[string]interface{}),
|
values: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
@@ -55,7 +88,7 @@ func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
|||||||
|
|
||||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &PgSQLUpdateQuery{
|
return &PgSQLUpdateQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
driverName: p.driverName,
|
driverName: p.driverName,
|
||||||
sets: make(map[string]interface{}),
|
sets: make(map[string]interface{}),
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
|||||||
|
|
||||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &PgSQLDeleteQuery{
|
return &PgSQLDeleteQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
driverName: p.driverName,
|
driverName: p.driverName,
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
whereClauses: make([]string, 0),
|
whereClauses: make([]string, 0),
|
||||||
@@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
||||||
result, err := p.db.ExecContext(ctx, query, args...)
|
var result sql.Result
|
||||||
|
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL Exec failed: %v", err)
|
logger.Error("PgSQL Exec failed: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
||||||
rows, err := p.db.QueryContext(ctx, query, args...)
|
var rows *sql.Rows
|
||||||
|
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL Query failed: %v", err)
|
logger.Error("PgSQL Query failed: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
tx, err := p.db.BeginTx(ctx, nil)
|
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tx, err := p.db.BeginTx(ctx, nil)
|
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.mongodb.org/mongo-driver/mongo"
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||||
|
func isDBClosed(err error) bool {
|
||||||
|
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||||
|
}
|
||||||
|
|
||||||
// Common errors
|
// Common errors
|
||||||
var (
|
var (
|
||||||
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
||||||
@@ -15,6 +16,8 @@ import (
|
|||||||
// SQLiteProvider implements Provider for SQLite databases
|
// SQLiteProvider implements Provider for SQLite databases
|
||||||
type SQLiteProvider struct {
|
type SQLiteProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
config ConnectionConfig
|
config ConnectionConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
|||||||
|
|
||||||
// Execute a simple query to verify the database is accessible
|
// Execute a simple query to verify the database is accessible
|
||||||
var result int
|
var result int
|
||||||
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
|
run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
|
||||||
|
err := run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("health check failed: %w", err)
|
return fmt.Errorf("health check failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
|
||||||
|
p.dbFactory = factory
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLiteProvider) getDB() *sql.DB {
|
||||||
|
p.dbMu.RLock()
|
||||||
|
defer p.dbMu.RUnlock()
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLiteProvider) reconnectDB() error {
|
||||||
|
if p.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := p.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbMu.Lock()
|
||||||
|
p.db = newDB
|
||||||
|
p.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetNative returns the native *sql.DB connection
|
// GetNative returns the native *sql.DB connection
|
||||||
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
||||||
if p.db == nil {
|
if p.db == nil {
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
|||||||
return server.NewSSEServer(
|
return server.NewSSEServer(
|
||||||
h.mcpServer,
|
h.mcpServer,
|
||||||
server.WithBaseURL(baseURL),
|
server.WithBaseURL(baseURL),
|
||||||
server.WithBasePath(basePath),
|
server.WithStaticBasePath(basePath),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -695,7 +695,7 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
|
|||||||
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
|
func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
|
||||||
switch filter.Operator {
|
switch filter.Operator {
|
||||||
case "eq", "=":
|
case "eq", "=":
|
||||||
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||||
@@ -725,7 +725,8 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []in
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||||
for _, preload := range preloads {
|
for i := range preloads {
|
||||||
|
preload := &preloads[i]
|
||||||
if preload.Relation == "" {
|
if preload.Relation == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
153
pkg/security/KEYSTORE.md
Normal file
153
pkg/security/KEYSTORE.md
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
# Keystore
|
||||||
|
|
||||||
|
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
|
||||||
|
|
||||||
|
## Key types
|
||||||
|
|
||||||
|
| Constant | Value | Use case |
|
||||||
|
|---|---|---|
|
||||||
|
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
|
||||||
|
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
|
||||||
|
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
|
||||||
|
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
|
||||||
|
|
||||||
|
## Storage backends
|
||||||
|
|
||||||
|
### ConfigKeyStore
|
||||||
|
|
||||||
|
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
|
||||||
|
store := security.NewConfigKeyStore([]security.UserKey{
|
||||||
|
{
|
||||||
|
UserID: 1,
|
||||||
|
KeyType: security.KeyTypeGenericAPI,
|
||||||
|
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
|
||||||
|
Name: "CI deploy",
|
||||||
|
Scopes: []string{"deploy"},
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### DatabaseKeyStore
|
||||||
|
|
||||||
|
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
|
||||||
|
|
||||||
|
```go
|
||||||
|
db, _ := sql.Open("postgres", dsn)
|
||||||
|
|
||||||
|
store := security.NewDatabaseKeyStore(db)
|
||||||
|
|
||||||
|
// With options
|
||||||
|
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||||
|
CacheTTL: 5 * time.Minute,
|
||||||
|
SQLNames: &security.KeyStoreSQLNames{
|
||||||
|
ValidateKey: "myapp_keystore_validate", // override one procedure name
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Managing keys
|
||||||
|
|
||||||
|
```go
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create — raw key returned once; store it securely
|
||||||
|
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
|
||||||
|
UserID: 42,
|
||||||
|
KeyType: security.KeyTypeGenericAPI,
|
||||||
|
Name: "mobile app",
|
||||||
|
Scopes: []string{"read", "write"},
|
||||||
|
})
|
||||||
|
fmt.Println(resp.RawKey) // only shown here; hashed internally
|
||||||
|
|
||||||
|
// List
|
||||||
|
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
|
||||||
|
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
|
||||||
|
|
||||||
|
// Revoke
|
||||||
|
err = store.DeleteKey(ctx, 42, resp.Key.ID)
|
||||||
|
|
||||||
|
// Validate (used by authenticators internally)
|
||||||
|
key, err := store.ValidateKey(ctx, rawKey, "")
|
||||||
|
```
|
||||||
|
|
||||||
|
## HTTP authentication
|
||||||
|
|
||||||
|
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
|
||||||
|
|
||||||
|
Keys are extracted from the request in this order:
|
||||||
|
|
||||||
|
1. `Authorization: Bearer <key>`
|
||||||
|
2. `Authorization: ApiKey <key>`
|
||||||
|
3. `X-API-Key: <key>`
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
|
||||||
|
// Restrict to a specific type:
|
||||||
|
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
|
||||||
|
```
|
||||||
|
|
||||||
|
Plug it into a handler:
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler := resolvespec.NewHandler(db, registry,
|
||||||
|
resolvespec.WithAuthenticator(auth),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
|
||||||
|
|
||||||
|
On successful validation the request context receives a `UserContext` where:
|
||||||
|
|
||||||
|
- `UserID` — from the key
|
||||||
|
- `Roles` — the key's `Scopes`
|
||||||
|
- `Claims["key_type"]` — key type string
|
||||||
|
- `Claims["key_name"]` — key name
|
||||||
|
|
||||||
|
## Database setup
|
||||||
|
|
||||||
|
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
|
||||||
|
|
||||||
|
```sql
|
||||||
|
\i pkg/security/keystore_schema.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
This creates:
|
||||||
|
|
||||||
|
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
|
||||||
|
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
|
||||||
|
- `resolvespec_keystore_create_key(p_request jsonb)`
|
||||||
|
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
|
||||||
|
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
|
||||||
|
|
||||||
|
### Custom procedure names
|
||||||
|
|
||||||
|
```go
|
||||||
|
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||||
|
SQLNames: &security.KeyStoreSQLNames{
|
||||||
|
GetUserKeys: "myschema_get_keys",
|
||||||
|
CreateKey: "myschema_create_key",
|
||||||
|
DeleteKey: "myschema_delete_key",
|
||||||
|
ValidateKey: "myschema_validate_key",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Validate names at startup
|
||||||
|
names := &security.KeyStoreSQLNames{
|
||||||
|
GetUserKeys: "myschema_get_keys",
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Security notes
|
||||||
|
|
||||||
|
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
|
||||||
|
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
|
||||||
|
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
|
||||||
|
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.
|
||||||
81
pkg/security/keystore.go
Normal file
81
pkg/security/keystore.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
|
||||||
|
// Used by all keystore implementations to hash raw keys before storage or lookup.
|
||||||
|
func hashSHA256Hex(raw string) string {
|
||||||
|
sum := sha256.Sum256([]byte(raw))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyType identifies the category of an auth key.
|
||||||
|
type KeyType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
|
||||||
|
KeyTypeJWTSecret KeyType = "jwt_secret"
|
||||||
|
// KeyTypeHeaderAPI is a static API key sent via a request header.
|
||||||
|
KeyTypeHeaderAPI KeyType = "header_api"
|
||||||
|
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
|
||||||
|
KeyTypeOAuth2 KeyType = "oauth2"
|
||||||
|
// KeyTypeGenericAPI is a generic application API key.
|
||||||
|
KeyTypeGenericAPI KeyType = "api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserKey represents a single named auth key belonging to a user.
|
||||||
|
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
|
||||||
|
type UserKey struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
KeyType KeyType `json:"key_type"`
|
||||||
|
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
|
||||||
|
Name string `json:"name"`
|
||||||
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
Meta map[string]any `json:"meta,omitempty"`
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||||
|
IsActive bool `json:"is_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKeyRequest specifies the parameters for a new key.
|
||||||
|
type CreateKeyRequest struct {
|
||||||
|
UserID int
|
||||||
|
KeyType KeyType
|
||||||
|
Name string
|
||||||
|
Scopes []string
|
||||||
|
Meta map[string]any
|
||||||
|
ExpiresAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKeyResponse is returned exactly once when a key is created.
|
||||||
|
// The caller is responsible for persisting RawKey; it is not stored anywhere.
|
||||||
|
type CreateKeyResponse struct {
|
||||||
|
Key UserKey
|
||||||
|
RawKey string // crypto/rand 32 bytes, base64url-encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyStore manages per-user auth keys with pluggable storage backends.
|
||||||
|
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
|
||||||
|
type KeyStore interface {
|
||||||
|
// CreateKey generates a new key, stores its hash, and returns the raw key once.
|
||||||
|
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
|
||||||
|
|
||||||
|
// GetUserKeys returns all active, non-expired keys for a user.
|
||||||
|
// Pass an empty KeyType to return all types.
|
||||||
|
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
|
||||||
|
|
||||||
|
// DeleteKey soft-deletes a key by ID after verifying ownership.
|
||||||
|
DeleteKey(ctx context.Context, userID int, keyID int64) error
|
||||||
|
|
||||||
|
// ValidateKey checks a raw key, returns the matching UserKey on success.
|
||||||
|
// The implementation hashes the raw key before any lookup.
|
||||||
|
// Pass an empty KeyType to accept any type.
|
||||||
|
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
|
||||||
|
}
|
||||||
97
pkg/security/keystore_authenticator.go
Normal file
97
pkg/security/keystore_authenticator.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
|
||||||
|
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
|
||||||
|
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
|
||||||
|
// is managed directly through the KeyStore.
|
||||||
|
//
|
||||||
|
// Key extraction order:
|
||||||
|
// 1. Authorization: Bearer <key>
|
||||||
|
// 2. Authorization: ApiKey <key>
|
||||||
|
// 3. X-API-Key header
|
||||||
|
type KeyStoreAuthenticator struct {
|
||||||
|
keyStore KeyStore
|
||||||
|
keyType KeyType // empty = accept any type
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
|
||||||
|
// Pass an empty keyType to accept keys of any type.
|
||||||
|
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
|
||||||
|
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login is not supported for keystore authentication.
|
||||||
|
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
|
||||||
|
return nil, fmt.Errorf("keystore authenticator does not support login")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout is not supported for keystore authentication.
|
||||||
|
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate extracts an API key from the request and validates it against the KeyStore.
|
||||||
|
// Returns a UserContext built from the matching UserKey on success.
|
||||||
|
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||||
|
rawKey := extractAPIKey(r)
|
||||||
|
if rawKey == "" {
|
||||||
|
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
|
||||||
|
}
|
||||||
|
|
||||||
|
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid API key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return userKeyToUserContext(userKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractAPIKey extracts a raw key from the request using the following precedence:
|
||||||
|
// 1. Authorization: Bearer <key>
|
||||||
|
// 2. Authorization: ApiKey <key>
|
||||||
|
// 3. X-API-Key header
|
||||||
|
func extractAPIKey(r *http.Request) string {
|
||||||
|
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||||
|
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||||||
|
return strings.TrimSpace(after)
|
||||||
|
}
|
||||||
|
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
|
||||||
|
return strings.TrimSpace(after)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(r.Header.Get("X-API-Key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// userKeyToUserContext converts a UserKey into a UserContext.
|
||||||
|
// Scopes are mapped to Roles. Key type and name are stored in Claims.
|
||||||
|
func userKeyToUserContext(k *UserKey) *UserContext {
|
||||||
|
claims := map[string]any{
|
||||||
|
"key_type": string(k.KeyType),
|
||||||
|
"key_name": k.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := k.Meta
|
||||||
|
if meta == nil {
|
||||||
|
meta = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
roles := k.Scopes
|
||||||
|
if roles == nil {
|
||||||
|
roles = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserContext{
|
||||||
|
UserID: k.UserID,
|
||||||
|
SessionID: fmt.Sprintf("key:%d", k.ID),
|
||||||
|
Roles: roles,
|
||||||
|
Claims: claims,
|
||||||
|
Meta: meta,
|
||||||
|
}
|
||||||
|
}
|
||||||
149
pkg/security/keystore_config.go
Normal file
149
pkg/security/keystore_config.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
//
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
type ConfigKeyStore struct {
	mu   sync.RWMutex // guards keys; ValidateKey takes the write lock (it stamps LastUsedAt)
	keys []UserKey
	next int64 // monotonic ID counter for runtime-created keys (atomic)
}
|
||||||
|
|
||||||
|
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
|
||||||
|
// Pass nil or an empty slice to start with no pre-loaded keys.
|
||||||
|
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
|
||||||
|
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
|
||||||
|
var maxID int64
|
||||||
|
copied := make([]UserKey, len(keys))
|
||||||
|
copy(copied, keys)
|
||||||
|
for i := range copied {
|
||||||
|
if copied[i].CreatedAt.IsZero() {
|
||||||
|
copied[i].IsActive = true
|
||||||
|
copied[i].CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if copied[i].ID > maxID {
|
||||||
|
maxID = copied[i].ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &ConfigKeyStore{keys: copied, next: maxID}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
|
||||||
|
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||||
|
rawBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(rawBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
||||||
|
}
|
||||||
|
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
||||||
|
hash := hashSHA256Hex(rawKey)
|
||||||
|
|
||||||
|
id := atomic.AddInt64(&s.next, 1)
|
||||||
|
key := UserKey{
|
||||||
|
ID: id,
|
||||||
|
UserID: req.UserID,
|
||||||
|
KeyType: req.KeyType,
|
||||||
|
KeyHash: hash,
|
||||||
|
Name: req.Name,
|
||||||
|
Scopes: req.Scopes,
|
||||||
|
Meta: req.Meta,
|
||||||
|
ExpiresAt: req.ExpiresAt,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.keys = append(s.keys, key)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserKeys returns all active, non-expired keys for the given user.
|
||||||
|
// Pass an empty KeyType to return all types.
|
||||||
|
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
var result []UserKey
|
||||||
|
for i := range s.keys {
|
||||||
|
k := &s.keys[i]
|
||||||
|
if k.UserID != userID || !k.IsActive {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if keyType != "" && k.KeyType != keyType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, *k)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
|
||||||
|
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range s.keys {
|
||||||
|
if s.keys[i].ID == keyID {
|
||||||
|
if s.keys[i].UserID != userID {
|
||||||
|
return fmt.Errorf("key not found or permission denied")
|
||||||
|
}
|
||||||
|
s.keys[i].IsActive = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateKey hashes rawKey with SHA-256 and scans for a matching entry that
// is active, unexpired, and (when keyType is non-empty) of the requested type.
// Hash comparison uses crypto/subtle's constant-time compare to prevent
// timing side-channels. On success it stamps LastUsedAt and returns a copy
// of the matched key.
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
	hash := hashSHA256Hex(rawKey)
	// hash comes from hashSHA256Hex so it is always valid hex; the decode
	// error can safely be ignored.
	hashBytes, _ := hex.DecodeString(hash)
	now := time.Now()

	// Write lock: ValidateKey updates LastUsedAt on the matched entry.
	s.mu.Lock()
	defer s.mu.Unlock()

	for i := range s.keys {
		k := &s.keys[i]
		if !k.IsActive {
			continue
		}
		if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
			continue
		}
		if keyType != "" && k.KeyType != keyType {
			continue
		}
		// A malformed configured KeyHash decodes to a short/empty slice;
		// ConstantTimeCompare returns 0 on length mismatch, so the entry
		// simply never matches.
		stored, _ := hex.DecodeString(k.KeyHash)
		if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
			continue
		}
		k.LastUsedAt = &now
		// Return a copy so callers cannot mutate the stored entry.
		result := *k
		return &result, nil
	}
	return nil, fmt.Errorf("invalid or expired key")
}
|
||||||
256
pkg/security/keystore_database.go
Normal file
256
pkg/security/keystore_database.go
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
// The zero value is valid: defaults are applied by NewDatabaseKeyStore.
type DatabaseKeyStoreOptions struct {
	// Cache is an optional cache instance. If nil, uses the default cache.
	Cache *cache.Cache
	// CacheTTL is the duration to cache ValidateKey results.
	// Default: 2 minutes.
	CacheTTL time.Duration
	// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
	SQLNames *KeyStoreSQLNames
	// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
	// If nil, reconnection is disabled.
	DBFactory func() (*sql.DB, error)
}
|
||||||
|
|
||||||
|
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
// All DB operations go through configurable procedure names; the raw key is
// never passed to the database — only its SHA-256 hex digest.
//
// See keystore_schema.sql for the required table and procedure definitions.
//
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
// (default 2 minutes) if the cache entry cannot be invalidated.
type DatabaseKeyStore struct {
	db        *sql.DB
	dbMu      sync.RWMutex // guards db for swap-on-reconnect (see getDB/reconnectDB)
	dbFactory func() (*sql.DB, error)
	sqlNames  *KeyStoreSQLNames
	cache     *cache.Cache
	cacheTTL  time.Duration
}
|
||||||
|
|
||||||
|
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
|
||||||
|
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
|
||||||
|
o := DatabaseKeyStoreOptions{}
|
||||||
|
if len(opts) > 0 {
|
||||||
|
o = opts[0]
|
||||||
|
}
|
||||||
|
if o.CacheTTL == 0 {
|
||||||
|
o.CacheTTL = 2 * time.Minute
|
||||||
|
}
|
||||||
|
c := o.Cache
|
||||||
|
if c == nil {
|
||||||
|
c = cache.GetDefaultCache()
|
||||||
|
}
|
||||||
|
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
|
||||||
|
return &DatabaseKeyStore{
|
||||||
|
db: db,
|
||||||
|
dbFactory: o.DBFactory,
|
||||||
|
sqlNames: names,
|
||||||
|
cache: c,
|
||||||
|
cacheTTL: o.CacheTTL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *DatabaseKeyStore) getDB() *sql.DB {
|
||||||
|
ks.dbMu.RLock()
|
||||||
|
defer ks.dbMu.RUnlock()
|
||||||
|
return ks.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *DatabaseKeyStore) reconnectDB() error {
|
||||||
|
if ks.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := ks.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ks.dbMu.Lock()
|
||||||
|
ks.db = newDB
|
||||||
|
ks.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKey generates a raw key, stores its SHA-256 hash via the create
// procedure, and returns the raw key once — it is not recoverable afterwards.
//
// The raw key is 32 bytes of crypto/rand material, base64url-encoded; only
// its hex SHA-256 digest ever reaches the database.
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
	rawBytes := make([]byte, 32)
	if _, err := rand.Read(rawBytes); err != nil {
		return nil, fmt.Errorf("failed to generate key material: %w", err)
	}
	rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
	hash := hashSHA256Hex(rawKey)

	// Wire shape expected by the create procedure's jsonb parameter.
	type createRequest struct {
		UserID    int            `json:"user_id"`
		KeyType   KeyType        `json:"key_type"`
		KeyHash   string         `json:"key_hash"`
		Name      string         `json:"name"`
		Scopes    []string       `json:"scopes,omitempty"`
		Meta      map[string]any `json:"meta,omitempty"`
		ExpiresAt *time.Time     `json:"expires_at,omitempty"`
	}

	reqJSON, err := json.Marshal(createRequest{
		UserID:    req.UserID,
		KeyType:   req.KeyType,
		KeyHash:   hash,
		Name:      req.Name,
		Scopes:    req.Scopes,
		Meta:      req.Meta,
		ExpiresAt: req.ExpiresAt,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal create key request: %w", err)
	}

	var success bool
	var errorMsg sql.NullString
	var keyJSON sql.NullString

	// NOTE(review): unlike ValidateKey, this path does not retry through
	// reconnectDB when the connection is closed — confirm that is intended.
	query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
	if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
		return nil, fmt.Errorf("create key procedure failed: %w", err)
	}
	if !success {
		return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
	}

	// keyJSON.String is "" when the procedure returned NULL key data;
	// Unmarshal then fails and surfaces that as a parse error.
	var key UserKey
	if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
		return nil, fmt.Errorf("failed to parse created key: %w", err)
	}

	return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
}
|
||||||
|
|
||||||
|
// GetUserKeys returns all active, non-expired keys for the given user.
// Pass an empty KeyType to return all types. The procedure returns the keys
// as a JSON array that excludes key hashes.
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
	var success bool
	var errorMsg sql.NullString
	var keysJSON sql.NullString

	query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
	if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
		return nil, fmt.Errorf("get user keys procedure failed: %w", err)
	}
	if !success {
		return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
	}

	// Skip decoding for NULL, empty, or "[]" payloads — all mean "no keys".
	var keys []UserKey
	if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
		if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
			return nil, fmt.Errorf("failed to parse user keys: %w", err)
		}
	}
	// Normalize nil to an empty slice so callers (and JSON encoding) see [].
	if keys == nil {
		keys = []UserKey{}
	}
	return keys, nil
}
|
||||||
|
|
||||||
|
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
// The delete procedure returns the key_hash so no separate lookup is needed.
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
	var success bool
	var errorMsg sql.NullString
	var keyHash sql.NullString

	query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
	if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
		return fmt.Errorf("delete key procedure failed: %w", err)
	}
	if !success {
		return errors.New(nullStringOr(errorMsg, "delete key failed"))
	}

	// Drop the ValidateKey cache entry keyed by the hash. The Delete error
	// is deliberately ignored: on failure the entry simply ages out after
	// CacheTTL, as documented on DatabaseKeyStore.
	if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
		_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
	}
	return nil
}
|
||||||
|
|
||||||
|
// ValidateKey hashes the raw key and calls the validate procedure.
// Results are cached for CacheTTL to reduce DB load on hot paths.
//
// Only successful validations are written to the cache; a cached entry with
// IsActive=false is treated as invalid defensively.
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
	hash := hashSHA256Hex(rawKey)
	cacheKey := keystoreCacheKey(hash)

	// Fast path: serve from cache when a prior validation is still fresh.
	if ks.cache != nil {
		var cached UserKey
		if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
			if cached.IsActive {
				return &cached, nil
			}
			return nil, errors.New("invalid or expired key")
		}
	}

	var success bool
	var errorMsg sql.NullString
	var keyJSON sql.NullString

	// runQuery is retried at most once, after a reconnect, when the pooled
	// connection was closed underneath us (see isDBClosed / reconnectDB).
	runQuery := func() error {
		query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
		return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
	}
	if err := runQuery(); err != nil {
		if isDBClosed(err) {
			if reconnErr := ks.reconnectDB(); reconnErr == nil {
				err = runQuery()
			}
			if err != nil {
				return nil, fmt.Errorf("validate key procedure failed: %w", err)
			}
		} else {
			return nil, fmt.Errorf("validate key procedure failed: %w", err)
		}
	}
	if !success {
		return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
	}

	var key UserKey
	if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
		return nil, fmt.Errorf("failed to parse validated key: %w", err)
	}

	// Best-effort cache write; a failure only costs an extra DB round-trip.
	if ks.cache != nil {
		_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
	}

	return &key, nil
}
|
||||||
|
|
||||||
|
// keystoreCacheKey builds the cache key under which a validated UserKey is
// stored, namespaced by the key's SHA-256 hex hash.
func keystoreCacheKey(hash string) string {
	const prefix = "keystore:validate:"
	return prefix + hash
}
|
||||||
|
|
||||||
|
// nullStringOr returns s.String if valid, otherwise the fallback.
|
||||||
|
func nullStringOr(s sql.NullString, fallback string) string {
|
||||||
|
if s.Valid && s.String != "" {
|
||||||
|
return s.String
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
187
pkg/security/keystore_schema.sql
Normal file
187
pkg/security/keystore_schema.sql
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
-- Keystore schema for per-user auth keys
-- Apply alongside database_schema.sql (requires the users table)

CREATE TABLE IF NOT EXISTS user_keys (
    id BIGSERIAL PRIMARY KEY,
    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    key_type VARCHAR(50) NOT NULL,
    key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars); the raw key is never stored
    name VARCHAR(255) NOT NULL DEFAULT '',
    scopes TEXT, -- JSON array, e.g. '["read","write"]'
    meta JSONB,
    expires_at TIMESTAMP, -- NULL = never expires; NOTE(review): timestamp WITHOUT time zone — confirm servers run UTC
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    last_used_at TIMESTAMP, -- stamped by the validate procedure on each successful auth
    is_active BOOLEAN DEFAULT true -- soft-delete flag; the delete procedure flips this to false
);

-- key_hash lookups drive validate_key; user_id drives get_user_keys/delete_key.
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
|
||||||
|
|
||||||
|
-- resolvespec_keystore_get_user_keys
-- Returns all active, non-expired keys for a user as a JSONB array.
-- Pass empty p_key_type to return all key types.
-- key_hash is deliberately omitted from the output.
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
    p_user_id INTEGER,
    p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
LANGUAGE plpgsql AS $$
DECLARE
    v_keys JSONB;
BEGIN
    -- COALESCE so an empty result yields '[]' rather than NULL.
    SELECT COALESCE(
        jsonb_agg(
            jsonb_build_object(
                'id', k.id,
                'user_id', k.user_id,
                'key_type', k.key_type,
                'name', k.name,
                'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
                'meta', COALESCE(k.meta, '{}'::jsonb),
                'expires_at', k.expires_at,
                'created_at', k.created_at,
                'last_used_at', k.last_used_at,
                'is_active', k.is_active
            )
        ),
        '[]'::jsonb
    )
    INTO v_keys
    FROM user_keys k
    WHERE k.user_id = p_user_id
      AND k.is_active = true
      AND (k.expires_at IS NULL OR k.expires_at > NOW())
      AND (p_key_type = '' OR k.key_type = p_key_type);

    RETURN QUERY SELECT true, NULL::TEXT, v_keys;
EXCEPTION WHEN OTHERS THEN
    RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
|
||||||
|
|
||||||
|
-- resolvespec_keystore_create_key
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
-- Returns the created key record (without key_hash).
-- NOTE(review): 'expires_at' is echoed back as its JSON text form while
-- get_user_keys emits it as a timestamp — confirm both decode identically in Go.
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
    p_request JSONB
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
    v_id BIGINT;
    v_created_at TIMESTAMP;
    v_key JSONB;
BEGIN
    INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
    VALUES (
        (p_request->>'user_id')::INTEGER,
        p_request->>'key_type',
        p_request->>'key_hash',
        COALESCE(p_request->>'name', ''),
        p_request->>'scopes', -- JSON array text stored into the TEXT column
        p_request->'meta',
        CASE WHEN p_request->>'expires_at' IS NOT NULL
             THEN (p_request->>'expires_at')::TIMESTAMP
             ELSE NULL
        END
    )
    RETURNING id, created_at INTO v_id, v_created_at;

    -- Echo back the stored record; key_hash is intentionally excluded.
    v_key := jsonb_build_object(
        'id', v_id,
        'user_id', (p_request->>'user_id')::INTEGER,
        'key_type', p_request->>'key_type',
        'name', COALESCE(p_request->>'name', ''),
        'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
                       THEN (p_request->>'scopes')::jsonb
                       ELSE '[]'::jsonb END,
        'meta', COALESCE(p_request->'meta', '{}'::jsonb),
        'expires_at', p_request->>'expires_at',
        'created_at', v_created_at,
        'is_active', true
    );

    RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
    RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
|
||||||
|
|
||||||
|
-- resolvespec_keystore_delete_key
-- Soft-deletes a key (is_active = false) after verifying ownership.
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
    p_user_id INTEGER,
    p_key_id BIGINT
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
LANGUAGE plpgsql AS $$
DECLARE
    v_hash TEXT;
BEGIN
    -- Ownership is enforced in the WHERE clause; a wrong user_id matches
    -- no row and falls through to the NOT FOUND branch below.
    UPDATE user_keys
    SET is_active = false
    WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
    RETURNING key_hash INTO v_hash;

    IF NOT FOUND THEN
        RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
        RETURN;
    END IF;

    RETURN QUERY SELECT true, NULL::TEXT, v_hash;
EXCEPTION WHEN OTHERS THEN
    RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
END;
$$;
|
||||||
|
|
||||||
|
-- resolvespec_keystore_validate_key
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
-- updates last_used_at, and returns the key record.
-- p_key_type can be empty to accept any key type.
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
    p_key_hash TEXT,
    p_key_type TEXT DEFAULT ''
)
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
LANGUAGE plpgsql AS $$
DECLARE
    v_key_rec user_keys%ROWTYPE;
    v_key JSONB;
BEGIN
    -- key_hash is UNIQUE, so at most one row can match.
    SELECT * INTO v_key_rec
    FROM user_keys
    WHERE key_hash = p_key_hash
      AND is_active = true
      AND (expires_at IS NULL OR expires_at > NOW())
      AND (p_key_type = '' OR key_type = p_key_type);

    IF NOT FOUND THEN
        RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
        RETURN;
    END IF;

    -- Stamp usage; the same NOW() is echoed in the returned record.
    UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;

    -- key_hash is intentionally excluded from the returned record.
    v_key := jsonb_build_object(
        'id', v_key_rec.id,
        'user_id', v_key_rec.user_id,
        'key_type', v_key_rec.key_type,
        'name', v_key_rec.name,
        'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
                       THEN v_key_rec.scopes::jsonb
                       ELSE '[]'::jsonb END,
        'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
        'expires_at', v_key_rec.expires_at,
        'created_at', v_key_rec.created_at,
        'last_used_at', NOW(),
        'is_active', v_key_rec.is_active
    );

    RETURN QUERY SELECT true, NULL::TEXT, v_key;
EXCEPTION WHEN OTHERS THEN
    RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
END;
$$;
|
||||||
61
pkg/security/keystore_sql_names.go
Normal file
61
pkg/security/keystore_sql_names.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
// Names are interpolated into SQL text, so run ValidateKeyStoreSQLNames on
// externally supplied values before use.
type KeyStoreSQLNames struct {
	GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
	CreateKey   string // default: "resolvespec_keystore_create_key"
	DeleteKey   string // default: "resolvespec_keystore_delete_key"
	ValidateKey string // default: "resolvespec_keystore_validate_key"
}
|
||||||
|
|
||||||
|
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
|
||||||
|
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
|
||||||
|
return &KeyStoreSQLNames{
|
||||||
|
GetUserKeys: "resolvespec_keystore_get_user_keys",
|
||||||
|
CreateKey: "resolvespec_keystore_create_key",
|
||||||
|
DeleteKey: "resolvespec_keystore_delete_key",
|
||||||
|
ValidateKey: "resolvespec_keystore_validate_key",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||||
|
// If override is nil, a copy of base is returned.
|
||||||
|
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
|
||||||
|
if override == nil {
|
||||||
|
copied := *base
|
||||||
|
return &copied
|
||||||
|
}
|
||||||
|
merged := *base
|
||||||
|
if override.GetUserKeys != "" {
|
||||||
|
merged.GetUserKeys = override.GetUserKeys
|
||||||
|
}
|
||||||
|
if override.CreateKey != "" {
|
||||||
|
merged.CreateKey = override.CreateKey
|
||||||
|
}
|
||||||
|
if override.DeleteKey != "" {
|
||||||
|
merged.DeleteKey = override.DeleteKey
|
||||||
|
}
|
||||||
|
if override.ValidateKey != "" {
|
||||||
|
merged.ValidateKey = override.ValidateKey
|
||||||
|
}
|
||||||
|
return &merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
|
||||||
|
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
|
||||||
|
fields := map[string]string{
|
||||||
|
"GetUserKeys": names.GetUserKeys,
|
||||||
|
"CreateKey": names.CreateKey,
|
||||||
|
"DeleteKey": names.DeleteKey,
|
||||||
|
"ValidateKey": names.ValidateKey,
|
||||||
|
}
|
||||||
|
for field, val := range fields {
|
||||||
|
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||||
|
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -244,7 +244,7 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
|||||||
var errMsg *string
|
var errMsg *string
|
||||||
var userID *int
|
var userID *int
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error, p_user_id
|
SELECT p_success, p_error, p_user_id
|
||||||
FROM %s($1::jsonb)
|
FROM %s($1::jsonb)
|
||||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||||
@@ -287,7 +287,7 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
|||||||
var success bool
|
var success bool
|
||||||
var errMsg *string
|
var errMsg *string
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error
|
SELECT p_success, p_error
|
||||||
FROM %s($1::jsonb)
|
FROM %s($1::jsonb)
|
||||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||||
@@ -385,7 +385,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
|||||||
var errMsg *string
|
var errMsg *string
|
||||||
var sessionData []byte
|
var sessionData []byte
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error, p_data::text
|
SELECT p_success, p_error, p_data::text
|
||||||
FROM %s($1)
|
FROM %s($1)
|
||||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||||
@@ -451,7 +451,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
|||||||
var updateSuccess bool
|
var updateSuccess bool
|
||||||
var updateErrMsg *string
|
var updateErrMsg *string
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error
|
SELECT p_success, p_error
|
||||||
FROM %s($1::jsonb)
|
FROM %s($1::jsonb)
|
||||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||||
@@ -472,7 +472,7 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
|||||||
var userErrMsg *string
|
var userErrMsg *string
|
||||||
var userData []byte
|
var userData []byte
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
err = a.getDB().QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error, p_data::text
|
SELECT p_success, p_error, p_data::text
|
||||||
FROM %s($1)
|
FROM %s($1)
|
||||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,6 +15,8 @@ import (
|
|||||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
type DatabasePasskeyProvider struct {
|
type DatabasePasskeyProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
rpID string // Relying Party ID (domain)
|
rpID string // Relying Party ID (domain)
|
||||||
rpName string // Relying Party display name
|
rpName string // Relying Party display name
|
||||||
rpOrigin string // Expected origin for WebAuthn
|
rpOrigin string // Expected origin for WebAuthn
|
||||||
@@ -33,6 +36,9 @@ type DatabasePasskeyProviderOptions struct {
|
|||||||
Timeout int64
|
Timeout int64
|
||||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||||
SQLNames *SQLNames
|
SQLNames *SQLNames
|
||||||
|
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||||
|
// If nil, reconnection is disabled.
|
||||||
|
DBFactory func() (*sql.DB, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||||
@@ -45,6 +51,7 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
|||||||
|
|
||||||
return &DatabasePasskeyProvider{
|
return &DatabasePasskeyProvider{
|
||||||
db: db,
|
db: db,
|
||||||
|
dbFactory: opts.DBFactory,
|
||||||
rpID: opts.RPID,
|
rpID: opts.RPID,
|
||||||
rpName: opts.RPName,
|
rpName: opts.RPName,
|
||||||
rpOrigin: opts.RPOrigin,
|
rpOrigin: opts.RPOrigin,
|
||||||
@@ -53,6 +60,26 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DatabasePasskeyProvider) getDB() *sql.DB {
|
||||||
|
p.dbMu.RLock()
|
||||||
|
defer p.dbMu.RUnlock()
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DatabasePasskeyProvider) reconnectDB() error {
|
||||||
|
if p.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := p.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbMu.Lock()
|
||||||
|
p.db = newDB
|
||||||
|
p.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// BeginRegistration creates registration options for a new passkey
|
// BeginRegistration creates registration options for a new passkey
|
||||||
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
func (p *DatabasePasskeyProvider) BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error) {
|
||||||
// Generate challenge
|
// Generate challenge
|
||||||
@@ -140,7 +167,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
|||||||
var credentialID sql.NullInt64
|
var credentialID sql.NullInt64
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
err = p.getDB().QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||||
}
|
}
|
||||||
@@ -181,7 +208,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
|||||||
var credentialsJSON sql.NullString
|
var credentialsJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
err := p.getDB().QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||||
}
|
}
|
||||||
@@ -240,8 +267,16 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var credentialJSON sql.NullString
|
var credentialJSON sql.NullString
|
||||||
|
|
||||||
|
runQuery := func() error {
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
return p.getDB().QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||||
|
}
|
||||||
|
err := runQuery()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = runQuery()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||||
}
|
}
|
||||||
@@ -272,7 +307,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
|||||||
var cloneWarning sql.NullBool
|
var cloneWarning sql.NullBool
|
||||||
|
|
||||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
err = p.getDB().QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||||
}
|
}
|
||||||
@@ -291,7 +326,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
|||||||
var credentialsJSON sql.NullString
|
var credentialsJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
err := p.getDB().QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||||
}
|
}
|
||||||
@@ -370,7 +405,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
err = p.getDB().QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete credential: %w", err)
|
return fmt.Errorf("failed to delete credential: %w", err)
|
||||||
}
|
}
|
||||||
@@ -396,7 +431,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
err = p.getDB().QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update credential name: %w", err)
|
return fmt.Errorf("failed to update credential name: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
|||||||
// Also supports passkey authentication configured with WithPasskey()
|
// Also supports passkey authentication configured with WithPasskey()
|
||||||
type DatabaseAuthenticator struct {
|
type DatabaseAuthenticator struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
cache *cache.Cache
|
cache *cache.Cache
|
||||||
cacheTTL time.Duration
|
cacheTTL time.Duration
|
||||||
sqlNames *SQLNames
|
sqlNames *SQLNames
|
||||||
@@ -88,6 +90,9 @@ type DatabaseAuthenticatorOptions struct {
|
|||||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||||
// Partial overrides are supported: only set the fields you want to change.
|
// Partial overrides are supported: only set the fields you want to change.
|
||||||
SQLNames *SQLNames
|
SQLNames *SQLNames
|
||||||
|
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||||
|
// If nil, reconnection is disabled.
|
||||||
|
DBFactory func() (*sql.DB, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||||
@@ -110,6 +115,7 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
|||||||
|
|
||||||
return &DatabaseAuthenticator{
|
return &DatabaseAuthenticator{
|
||||||
db: db,
|
db: db,
|
||||||
|
dbFactory: opts.DBFactory,
|
||||||
cache: cacheInstance,
|
cache: cacheInstance,
|
||||||
cacheTTL: opts.CacheTTL,
|
cacheTTL: opts.CacheTTL,
|
||||||
sqlNames: sqlNames,
|
sqlNames: sqlNames,
|
||||||
@@ -117,6 +123,26 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *DatabaseAuthenticator) getDB() *sql.DB {
|
||||||
|
a.dbMu.RLock()
|
||||||
|
defer a.dbMu.RUnlock()
|
||||||
|
return a.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *DatabaseAuthenticator) reconnectDB() error {
|
||||||
|
if a.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := a.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.dbMu.Lock()
|
||||||
|
a.db = newDB
|
||||||
|
a.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||||
// Convert LoginRequest to JSON
|
// Convert LoginRequest to JSON
|
||||||
reqJSON, err := json.Marshal(req)
|
reqJSON, err := json.Marshal(req)
|
||||||
@@ -128,8 +154,16 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
|
runLoginQuery := func() error {
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
|
}
|
||||||
|
err = runLoginQuery()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||||
|
err = runLoginQuery()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login query failed: %w", err)
|
return nil, fmt.Errorf("login query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -163,7 +197,7 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
|||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("register query failed: %w", err)
|
return nil, fmt.Errorf("register query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -196,7 +230,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
|||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
err = a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("logout query failed: %w", err)
|
return fmt.Errorf("logout query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -270,7 +304,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
|||||||
var userJSON sql.NullString
|
var userJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
err := a.getDB().QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("session query failed: %w", err)
|
return nil, fmt.Errorf("session query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -346,7 +380,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
|||||||
var updatedUserJSON sql.NullString
|
var updatedUserJSON sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
_ = a.getDB().QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken implements Refreshable interface
|
// RefreshToken implements Refreshable interface
|
||||||
@@ -357,7 +391,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
|||||||
var userJSON sql.NullString
|
var userJSON sql.NullString
|
||||||
// Get current session to pass to refresh
|
// Get current session to pass to refresh
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
err := a.getDB().QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -374,7 +408,7 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
|||||||
var newUserJSON sql.NullString
|
var newUserJSON sql.NullString
|
||||||
|
|
||||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
err = a.getDB().QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -406,6 +440,8 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
|||||||
type JWTAuthenticator struct {
|
type JWTAuthenticator struct {
|
||||||
secretKey []byte
|
secretKey []byte
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
sqlNames *SQLNames
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -417,13 +453,47 @@ func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTA
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (a *JWTAuthenticator) WithDBFactory(factory func() (*sql.DB, error)) *JWTAuthenticator {
|
||||||
|
a.dbFactory = factory
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *JWTAuthenticator) getDB() *sql.DB {
|
||||||
|
a.dbMu.RLock()
|
||||||
|
defer a.dbMu.RUnlock()
|
||||||
|
return a.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *JWTAuthenticator) reconnectDB() error {
|
||||||
|
if a.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := a.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.dbMu.Lock()
|
||||||
|
a.db = newDB
|
||||||
|
a.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var userJSON []byte
|
var userJSON []byte
|
||||||
|
|
||||||
|
runLoginQuery := func() error {
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
return a.getDB().QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||||
|
}
|
||||||
|
err := runLoginQuery()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||||
|
err = runLoginQuery()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login query failed: %w", err)
|
return nil, fmt.Errorf("login query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -476,7 +546,7 @@ func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
err := a.getDB().QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("logout query failed: %w", err)
|
return fmt.Errorf("logout query failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -514,6 +584,8 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
|||||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
type DatabaseColumnSecurityProvider struct {
|
type DatabaseColumnSecurityProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
sqlNames *SQLNames
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -521,6 +593,31 @@ func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *Database
|
|||||||
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DatabaseColumnSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseColumnSecurityProvider {
|
||||||
|
p.dbFactory = factory
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DatabaseColumnSecurityProvider) getDB() *sql.DB {
|
||||||
|
p.dbMu.RLock()
|
||||||
|
defer p.dbMu.RUnlock()
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DatabaseColumnSecurityProvider) reconnectDB() error {
|
||||||
|
if p.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := p.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbMu.Lock()
|
||||||
|
p.db = newDB
|
||||||
|
p.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||||
var rules []ColumnSecurity
|
var rules []ColumnSecurity
|
||||||
|
|
||||||
@@ -528,8 +625,16 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var rulesJSON []byte
|
var rulesJSON []byte
|
||||||
|
|
||||||
|
runQuery := func() error {
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
return p.getDB().QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||||
|
}
|
||||||
|
err := runQuery()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = runQuery()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||||
}
|
}
|
||||||
@@ -579,6 +684,8 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
|||||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
type DatabaseRowSecurityProvider struct {
|
type DatabaseRowSecurityProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
sqlNames *SQLNames
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,13 +693,45 @@ func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRow
|
|||||||
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DatabaseRowSecurityProvider) WithDBFactory(factory func() (*sql.DB, error)) *DatabaseRowSecurityProvider {
|
||||||
|
p.dbFactory = factory
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DatabaseRowSecurityProvider) getDB() *sql.DB {
|
||||||
|
p.dbMu.RLock()
|
||||||
|
defer p.dbMu.RUnlock()
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DatabaseRowSecurityProvider) reconnectDB() error {
|
||||||
|
if p.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := p.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbMu.Lock()
|
||||||
|
p.db = newDB
|
||||||
|
p.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||||
var template string
|
var template string
|
||||||
var hasBlock bool
|
var hasBlock bool
|
||||||
|
|
||||||
|
runQuery := func() error {
|
||||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||||
|
return p.getDB().QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
}
|
||||||
|
err := runQuery()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = runQuery()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
|
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
|
||||||
}
|
}
|
||||||
@@ -662,6 +801,11 @@ func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID i
|
|||||||
// Helper functions
|
// Helper functions
|
||||||
// ================
|
// ================
|
||||||
|
|
||||||
|
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||||
|
func isDBClosed(err error) bool {
|
||||||
|
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||||
|
}
|
||||||
|
|
||||||
func parseRoles(rolesStr string) []string {
|
func parseRoles(rolesStr string) []string {
|
||||||
if rolesStr == "" {
|
if rolesStr == "" {
|
||||||
return []string{}
|
return []string{}
|
||||||
@@ -780,8 +924,16 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
|
runPasskeyQuery := func() error {
|
||||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
return a.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
|
}
|
||||||
|
err = runPasskeyQuery()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := a.reconnectDB(); reconnErr == nil {
|
||||||
|
err = runPasskeyQuery()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user