mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 09:26:24 +00:00
* Added a database factory function to allow reconnection when the database is closed. * Implemented mutex locks for safe concurrent access to the database connection. * Updated all database query methods to handle reconnection attempts on closed connections. * Enhanced error handling for database operations across multiple providers.
257 lines
7.8 KiB
Go
257 lines
7.8 KiB
Go
package security
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
|
)
|
|
|
|
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
|
|
type DatabaseKeyStoreOptions struct {
|
|
// Cache is an optional cache instance. If nil, uses the default cache.
|
|
Cache *cache.Cache
|
|
// CacheTTL is the duration to cache ValidateKey results.
|
|
// Default: 2 minutes.
|
|
CacheTTL time.Duration
|
|
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
|
|
SQLNames *KeyStoreSQLNames
|
|
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
|
// If nil, reconnection is disabled.
|
|
DBFactory func() (*sql.DB, error)
|
|
}
|
|
|
|
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
|
|
// All DB operations go through configurable procedure names; the raw key is
|
|
// never passed to the database.
|
|
//
|
|
// See keystore_schema.sql for the required table and procedure definitions.
|
|
//
|
|
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
|
|
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
|
|
// (default 2 minutes) if the cache entry cannot be invalidated.
|
|
type DatabaseKeyStore struct {
|
|
db *sql.DB
|
|
dbMu sync.RWMutex
|
|
dbFactory func() (*sql.DB, error)
|
|
sqlNames *KeyStoreSQLNames
|
|
cache *cache.Cache
|
|
cacheTTL time.Duration
|
|
}
|
|
|
|
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
|
|
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
|
|
o := DatabaseKeyStoreOptions{}
|
|
if len(opts) > 0 {
|
|
o = opts[0]
|
|
}
|
|
if o.CacheTTL == 0 {
|
|
o.CacheTTL = 2 * time.Minute
|
|
}
|
|
c := o.Cache
|
|
if c == nil {
|
|
c = cache.GetDefaultCache()
|
|
}
|
|
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
|
|
return &DatabaseKeyStore{
|
|
db: db,
|
|
dbFactory: o.DBFactory,
|
|
sqlNames: names,
|
|
cache: c,
|
|
cacheTTL: o.CacheTTL,
|
|
}
|
|
}
|
|
|
|
func (ks *DatabaseKeyStore) getDB() *sql.DB {
|
|
ks.dbMu.RLock()
|
|
defer ks.dbMu.RUnlock()
|
|
return ks.db
|
|
}
|
|
|
|
func (ks *DatabaseKeyStore) reconnectDB() error {
|
|
if ks.dbFactory == nil {
|
|
return fmt.Errorf("no db factory configured for reconnect")
|
|
}
|
|
newDB, err := ks.dbFactory()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ks.dbMu.Lock()
|
|
ks.db = newDB
|
|
ks.dbMu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
|
|
// and returns the raw key once.
|
|
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
|
rawBytes := make([]byte, 32)
|
|
if _, err := rand.Read(rawBytes); err != nil {
|
|
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
|
}
|
|
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
|
hash := hashSHA256Hex(rawKey)
|
|
|
|
type createRequest struct {
|
|
UserID int `json:"user_id"`
|
|
KeyType KeyType `json:"key_type"`
|
|
KeyHash string `json:"key_hash"`
|
|
Name string `json:"name"`
|
|
Scopes []string `json:"scopes,omitempty"`
|
|
Meta map[string]any `json:"meta,omitempty"`
|
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
|
}
|
|
|
|
reqJSON, err := json.Marshal(createRequest{
|
|
UserID: req.UserID,
|
|
KeyType: req.KeyType,
|
|
KeyHash: hash,
|
|
Name: req.Name,
|
|
Scopes: req.Scopes,
|
|
Meta: req.Meta,
|
|
ExpiresAt: req.ExpiresAt,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
|
|
}
|
|
|
|
var success bool
|
|
var errorMsg sql.NullString
|
|
var keyJSON sql.NullString
|
|
|
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
|
|
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
|
return nil, fmt.Errorf("create key procedure failed: %w", err)
|
|
}
|
|
if !success {
|
|
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
|
|
}
|
|
|
|
var key UserKey
|
|
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
|
return nil, fmt.Errorf("failed to parse created key: %w", err)
|
|
}
|
|
|
|
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
|
}
|
|
|
|
// GetUserKeys returns all active, non-expired keys for the given user.
|
|
// Pass an empty KeyType to return all types.
|
|
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
|
var success bool
|
|
var errorMsg sql.NullString
|
|
var keysJSON sql.NullString
|
|
|
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
|
|
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
|
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
|
|
}
|
|
if !success {
|
|
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
|
|
}
|
|
|
|
var keys []UserKey
|
|
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
|
|
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
|
|
return nil, fmt.Errorf("failed to parse user keys: %w", err)
|
|
}
|
|
}
|
|
if keys == nil {
|
|
keys = []UserKey{}
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
|
|
// The delete procedure returns the key_hash so no separate lookup is needed.
|
|
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
|
|
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
|
|
var success bool
|
|
var errorMsg sql.NullString
|
|
var keyHash sql.NullString
|
|
|
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
|
|
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
|
return fmt.Errorf("delete key procedure failed: %w", err)
|
|
}
|
|
if !success {
|
|
return errors.New(nullStringOr(errorMsg, "delete key failed"))
|
|
}
|
|
|
|
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
|
|
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ValidateKey hashes the raw key and calls the validate procedure.
|
|
// Results are cached for CacheTTL to reduce DB load on hot paths.
|
|
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
|
|
hash := hashSHA256Hex(rawKey)
|
|
cacheKey := keystoreCacheKey(hash)
|
|
|
|
if ks.cache != nil {
|
|
var cached UserKey
|
|
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
|
|
if cached.IsActive {
|
|
return &cached, nil
|
|
}
|
|
return nil, errors.New("invalid or expired key")
|
|
}
|
|
}
|
|
|
|
var success bool
|
|
var errorMsg sql.NullString
|
|
var keyJSON sql.NullString
|
|
|
|
runQuery := func() error {
|
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
|
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
|
|
}
|
|
if err := runQuery(); err != nil {
|
|
if isDBClosed(err) {
|
|
if reconnErr := ks.reconnectDB(); reconnErr == nil {
|
|
err = runQuery()
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
|
}
|
|
}
|
|
if !success {
|
|
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
|
|
}
|
|
|
|
var key UserKey
|
|
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
|
return nil, fmt.Errorf("failed to parse validated key: %w", err)
|
|
}
|
|
|
|
if ks.cache != nil {
|
|
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
|
|
}
|
|
|
|
return &key, nil
|
|
}
|
|
|
|
// keystoreCacheKey builds the cache key under which the validation result
// for the given key hash is stored.
func keystoreCacheKey(hash string) string {
	const prefix = "keystore:validate:"
	return prefix + hash
}
|
|
|
|
// nullStringOr yields fallback when s is NULL or holds an empty string, and
// s.String in every other case.
func nullStringOr(s sql.NullString, fallback string) string {
	if !s.Valid || s.String == "" {
		return fallback
	}
	return s.String
}
|