mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 17:36:23 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4adf94fe37 | |||
|
|
405a04a192 | ||
|
|
c1b16d363a | ||
|
|
568df8c6d6 |
2
go.mod
2
go.mod
@@ -15,6 +15,7 @@ require (
|
|||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/jackc/pgx/v5 v5.8.0
|
github.com/jackc/pgx/v5 v5.8.0
|
||||||
github.com/klauspost/compress v1.18.2
|
github.com/klauspost/compress v1.18.2
|
||||||
|
github.com/mark3labs/mcp-go v0.46.0
|
||||||
github.com/mattn/go-sqlite3 v1.14.33
|
github.com/mattn/go-sqlite3 v1.14.33
|
||||||
github.com/microsoft/go-mssqldb v1.9.5
|
github.com/microsoft/go-mssqldb v1.9.5
|
||||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||||
@@ -88,7 +89,6 @@ require (
|
|||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
github.com/magiconair/properties v1.8.10 // indirect
|
github.com/magiconair/properties v1.8.10 // indirect
|
||||||
github.com/mark3labs/mcp-go v0.46.0 // indirect
|
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
github.com/moby/go-archive v0.1.0 // indirect
|
github.com/moby/go-archive v0.1.0 // indirect
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ type Connection interface {
|
|||||||
Bun() (*bun.DB, error)
|
Bun() (*bun.DB, error)
|
||||||
GORM() (*gorm.DB, error)
|
GORM() (*gorm.DB, error)
|
||||||
Native() (*sql.DB, error)
|
Native() (*sql.DB, error)
|
||||||
|
DB() (*sql.DB, error)
|
||||||
|
|
||||||
// Common Database interface (for SQL databases)
|
// Common Database interface (for SQL databases)
|
||||||
Database() (common.Database, error)
|
Database() (common.Database, error)
|
||||||
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
|
|||||||
return c.nativeDB, nil
|
return c.nativeDB, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB returns the underlying *sql.DB connection
|
||||||
|
func (c *sqlConnection) DB() (*sql.DB, error) {
|
||||||
|
return c.Native()
|
||||||
|
}
|
||||||
|
|
||||||
// Bun returns a Bun ORM instance wrapping the native connection
|
// Bun returns a Bun ORM instance wrapping the native connection
|
||||||
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -645,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
|
|||||||
return nil, ErrNotSQLDatabase
|
return nil, ErrNotSQLDatabase
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB returns an error for MongoDB connections
|
||||||
|
func (c *mongoConnection) DB() (*sql.DB, error) {
|
||||||
|
return nil, ErrNotSQLDatabase
|
||||||
|
}
|
||||||
|
|
||||||
// Database returns an error for MongoDB connections
|
// Database returns an error for MongoDB connections
|
||||||
func (c *mongoConnection) Database() (common.Database, error) {
|
func (c *mongoConnection) Database() (common.Database, error) {
|
||||||
return nil, ErrNotSQLDatabase
|
return nil, ErrNotSQLDatabase
|
||||||
|
|||||||
@@ -119,6 +119,83 @@ Add middleware before the MCP routes. The handler itself has no auth layer.
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Security
|
||||||
|
|
||||||
|
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
|
||||||
|
|
||||||
|
### Wiring security hooks
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
|
||||||
|
securityList := security.NewSecurityList(mySecurityProvider)
|
||||||
|
resolvemcp.RegisterSecurityHooks(handler, securityList)
|
||||||
|
```
|
||||||
|
|
||||||
|
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
|
||||||
|
|
||||||
|
| Hook | Effect |
|
||||||
|
|---|---|
|
||||||
|
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
|
||||||
|
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
|
||||||
|
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
|
||||||
|
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
|
||||||
|
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
|
||||||
|
|
||||||
|
### Per-entity operation rules
|
||||||
|
|
||||||
|
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
|
||||||
|
// Read-only entity
|
||||||
|
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
|
||||||
|
CanRead: true,
|
||||||
|
CanCreate: false,
|
||||||
|
CanUpdate: false,
|
||||||
|
CanDelete: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Public read, authenticated write
|
||||||
|
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
|
||||||
|
CanPublicRead: true,
|
||||||
|
CanRead: true,
|
||||||
|
CanCreate: true,
|
||||||
|
CanUpdate: true,
|
||||||
|
CanDelete: false,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
To update rules for an already-registered model:
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler.SetModelRules("public", "users", modelregistry.ModelRules{
|
||||||
|
CanRead: true,
|
||||||
|
CanCreate: true,
|
||||||
|
CanUpdate: true,
|
||||||
|
CanDelete: false,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
|
||||||
|
|
||||||
|
### ModelRules fields
|
||||||
|
|
||||||
|
| Field | Default | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `CanPublicRead` | `false` | Allow unauthenticated reads |
|
||||||
|
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
|
||||||
|
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
|
||||||
|
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
|
||||||
|
| `CanRead` | `true` | Allow authenticated reads |
|
||||||
|
| `CanCreate` | `true` | Allow authenticated creates |
|
||||||
|
| `CanUpdate` | `true` | Allow authenticated updates |
|
||||||
|
| `CanDelete` | `true` | Allow authenticated deletes |
|
||||||
|
| `SecurityDisabled` | `false` | Skip all security checks for this model |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## MCP Tools
|
## MCP Tools
|
||||||
|
|
||||||
### Tool Naming
|
### Tool Naming
|
||||||
@@ -204,6 +281,35 @@ Delete a record by primary key. **Irreversible.**
|
|||||||
{ "success": true, "data": { ...deleted record... } }
|
{ "success": true, "data": { ...deleted record... } }
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Annotation Tool — `resolvespec_annotate`
|
||||||
|
|
||||||
|
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
|
||||||
|
|
||||||
|
| Argument | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
|
||||||
|
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
|
||||||
|
|
||||||
|
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
|
||||||
|
```json
|
||||||
|
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
|
||||||
|
```
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{ "success": true, "tool_name": "read_public_users", "action": "set" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
|
||||||
|
```json
|
||||||
|
{ "tool_name": "read_public_users" }
|
||||||
|
```
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### Resource — `{schema}.{entity}`
|
### Resource — `{schema}.{entity}`
|
||||||
|
|
||||||
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
||||||
|
|||||||
107
pkg/resolvemcp/annotation.go
Normal file
107
pkg/resolvemcp/annotation.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const annotationToolName = "resolvespec_annotate"
|
||||||
|
|
||||||
|
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
|
||||||
|
// The tool lets models/entities store and retrieve freeform annotation records
|
||||||
|
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
|
||||||
|
func registerAnnotationTool(h *Handler) {
|
||||||
|
tool := mcp.NewTool(annotationToolName,
|
||||||
|
mcp.WithDescription(
|
||||||
|
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
|
||||||
|
"To set annotations: provide both 'tool_name' and 'annotations'. "+
|
||||||
|
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
|
||||||
|
"To get annotations: provide only 'tool_name'. "+
|
||||||
|
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
|
||||||
|
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
|
||||||
|
"a model/entity name (e.g. 'public.users'), or any other key.",
|
||||||
|
),
|
||||||
|
mcp.WithString("tool_name",
|
||||||
|
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
mcp.WithObject("annotations",
|
||||||
|
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
|
||||||
|
toolName, ok := args["tool_name"].(string)
|
||||||
|
if !ok || toolName == "" {
|
||||||
|
return mcp.NewToolResultError("missing required argument: tool_name"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations, hasAnnotations := args["annotations"]
|
||||||
|
|
||||||
|
if hasAnnotations && annotations != nil {
|
||||||
|
return executeSetAnnotation(ctx, h, toolName, annotations)
|
||||||
|
}
|
||||||
|
return executeGetAnnotation(ctx, h, toolName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
|
||||||
|
jsonBytes, err := json.Marshal(annotations)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"tool_name": toolName,
|
||||||
|
"action": "set",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
|
||||||
|
var rows []map[string]interface{}
|
||||||
|
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var annotations interface{}
|
||||||
|
if len(rows) > 0 {
|
||||||
|
// The procedure returns a single value; extract the first column of the first row.
|
||||||
|
for _, v := range rows[0] {
|
||||||
|
annotations = v
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
|
||||||
|
switch v := annotations.(type) {
|
||||||
|
case []byte:
|
||||||
|
var decoded interface{}
|
||||||
|
if json.Unmarshal(v, &decoded) == nil {
|
||||||
|
annotations = decoded
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
var decoded interface{}
|
||||||
|
if json.Unmarshal([]byte(v), &decoded) == nil {
|
||||||
|
annotations = decoded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"tool_name": toolName,
|
||||||
|
"action": "get",
|
||||||
|
"annotations": annotations,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ type Handler struct {
|
|||||||
|
|
||||||
// NewHandler creates a Handler with the given database, model registry, and config.
|
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||||
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
||||||
return &Handler{
|
h := &Handler{
|
||||||
db: db,
|
db: db,
|
||||||
registry: registry,
|
registry: registry,
|
||||||
hooks: NewHookRegistry(),
|
hooks: NewHookRegistry(),
|
||||||
@@ -39,6 +40,8 @@ func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *
|
|||||||
name: "resolvemcp",
|
name: "resolvemcp",
|
||||||
version: "1.0.0",
|
version: "1.0.0",
|
||||||
}
|
}
|
||||||
|
registerAnnotationTool(h)
|
||||||
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hooks returns the hook registry.
|
// Hooks returns the hook registry.
|
||||||
@@ -123,6 +126,32 @@ func (h *Handler) RegisterModel(schema, entity string, model interface{}) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterModelWithRules registers a model and sets per-entity operation rules
|
||||||
|
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
|
||||||
|
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||||
|
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
|
||||||
|
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||||
|
}
|
||||||
|
fullName := buildModelName(schema, entity)
|
||||||
|
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registerModelTools(h, schema, entity, model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRules updates the operation rules for an already-registered model.
|
||||||
|
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||||
|
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
|
||||||
|
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||||
|
}
|
||||||
|
return reg.SetModelRules(buildModelName(schema, entity), rules)
|
||||||
|
}
|
||||||
|
|
||||||
// buildModelName builds the registry key for a model (same format as resolvespec).
|
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||||
func buildModelName(schema, entity string) string {
|
func buildModelName(schema, entity string) string {
|
||||||
if schema == "" {
|
if schema == "" {
|
||||||
|
|||||||
115
pkg/resolvemcp/security_hooks.go
Normal file
115
pkg/resolvemcp/security_hooks.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks wires the security package's access-control layer into the
|
||||||
|
// resolvemcp handler. Call it once after creating the handler, before registering models.
|
||||||
|
//
|
||||||
|
// The following controls are applied:
|
||||||
|
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
|
||||||
|
// stored via RegisterModelWithRules / SetModelRules.
|
||||||
|
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
|
||||||
|
// - Column-level security: sensitive columns masked/hidden in read results.
|
||||||
|
// - Audit logging after each read.
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// BeforeHandle: enforce model-level operation rules (auth check).
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
|
||||||
|
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// AfterRead (2nd): audit log.
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.LogDataAccess(newSecurityContext(hookCtx))
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeUpdate: enforce CanUpdate rule.
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeDelete: enforce CanDelete rule.
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("Security hooks registered for resolvemcp handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type securityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &securityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetContext() context.Context {
|
||||||
|
return s.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetUserID() (int, bool) {
|
||||||
|
return security.GetUserID(s.ctx.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetSchema() string {
|
||||||
|
return s.ctx.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetEntity() string {
|
||||||
|
return s.ctx.Entity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetModel() interface{} {
|
||||||
|
return s.ctx.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetQuery() interface{} {
|
||||||
|
return s.ctx.Query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetQuery(query interface{}) {
|
||||||
|
if q, ok := query.(common.SelectQuery); ok {
|
||||||
|
s.ctx.Query = q
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetResult() interface{} {
|
||||||
|
return s.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetResult(result interface{}) {
|
||||||
|
s.ctx.Result = result
|
||||||
|
}
|
||||||
@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||||
// Add to blacklist
|
// Invalidate session via stored procedure
|
||||||
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
|
return nil
|
||||||
"token": req.Token,
|
|
||||||
"user_id": req.UserID,
|
|
||||||
}).Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||||
|
|||||||
@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
||||||
// For JWT, logout could involve token blacklisting
|
|
||||||
// Add token to blacklist table
|
|
||||||
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
|
|
||||||
// "token": req.Token,
|
|
||||||
// "expires_at": time.Now().Add(24 * time.Hour),
|
|
||||||
// }).Error
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
|||||||
var errMsg *string
|
var errMsg *string
|
||||||
var userID *int
|
var userID *int
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, `
|
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error, p_user_id
|
SELECT p_success, p_error, p_user_id
|
||||||
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
|
FROM %s($1::jsonb)
|
||||||
`, userJSON).Scan(&success, &errMsg, &userID)
|
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||||
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
|||||||
var success bool
|
var success bool
|
||||||
var errMsg *string
|
var errMsg *string
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, `
|
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error
|
SELECT p_success, p_error
|
||||||
FROM resolvespec_oauth_createsession($1::jsonb)
|
FROM %s($1::jsonb)
|
||||||
`, sessionJSON).Scan(&success, &errMsg)
|
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create session: %w", err)
|
return fmt.Errorf("failed to create session: %w", err)
|
||||||
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
|||||||
var errMsg *string
|
var errMsg *string
|
||||||
var sessionData []byte
|
var sessionData []byte
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, `
|
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error, p_data::text
|
SELECT p_success, p_error, p_data::text
|
||||||
FROM resolvespec_oauth_getrefreshtoken($1)
|
FROM %s($1)
|
||||||
`, refreshToken).Scan(&success, &errMsg, &sessionData)
|
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
||||||
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
|||||||
var updateSuccess bool
|
var updateSuccess bool
|
||||||
var updateErrMsg *string
|
var updateErrMsg *string
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, `
|
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error
|
SELECT p_success, p_error
|
||||||
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
|
FROM %s($1::jsonb)
|
||||||
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to update session: %w", err)
|
return nil, fmt.Errorf("failed to update session: %w", err)
|
||||||
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
|||||||
var userErrMsg *string
|
var userErrMsg *string
|
||||||
var userData []byte
|
var userData []byte
|
||||||
|
|
||||||
err = a.db.QueryRowContext(ctx, `
|
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||||
SELECT p_success, p_error, p_data::text
|
SELECT p_success, p_error, p_data::text
|
||||||
FROM resolvespec_oauth_getuser($1)
|
FROM %s($1)
|
||||||
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||||
|
|||||||
@@ -11,12 +11,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||||
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
type DatabasePasskeyProvider struct {
|
type DatabasePasskeyProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
rpID string // Relying Party ID (domain)
|
rpID string // Relying Party ID (domain)
|
||||||
rpName string // Relying Party display name
|
rpName string // Relying Party display name
|
||||||
rpOrigin string // Expected origin for WebAuthn
|
rpOrigin string // Expected origin for WebAuthn
|
||||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||||
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||||
@@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct {
|
|||||||
RPOrigin string
|
RPOrigin string
|
||||||
// Timeout is the timeout for operations in milliseconds (default: 60000)
|
// Timeout is the timeout for operations in milliseconds (default: 60000)
|
||||||
Timeout int64
|
Timeout int64
|
||||||
|
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||||
|
SQLNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||||
@@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
|||||||
opts.Timeout = 60000 // 60 seconds default
|
opts.Timeout = 60000 // 60 seconds default
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||||
|
|
||||||
return &DatabasePasskeyProvider{
|
return &DatabasePasskeyProvider{
|
||||||
db: db,
|
db: db,
|
||||||
rpID: opts.RPID,
|
rpID: opts.RPID,
|
||||||
rpName: opts.RPName,
|
rpName: opts.RPName,
|
||||||
rpOrigin: opts.RPOrigin,
|
rpOrigin: opts.RPOrigin,
|
||||||
timeout: opts.Timeout,
|
timeout: opts.Timeout,
|
||||||
|
sqlNames: sqlNames,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var credentialID sql.NullInt64
|
var credentialID sql.NullInt64
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||||
@@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
|||||||
var userID sql.NullInt64
|
var userID sql.NullInt64
|
||||||
var credentialsJSON sql.NullString
|
var credentialsJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||||
@@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var credentialJSON sql.NullString
|
var credentialJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||||
@@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
|||||||
var updateError sql.NullString
|
var updateError sql.NullString
|
||||||
var cloneWarning sql.NullBool
|
var cloneWarning sql.NullBool
|
||||||
|
|
||||||
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||||
@@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var credentialsJSON sql.NullString
|
var credentialsJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||||
@@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
|||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete credential: %w", err)
|
return fmt.Errorf("failed to delete credential: %w", err)
|
||||||
@@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
|||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update credential name: %w", err)
|
return fmt.Errorf("failed to update credential name: %w", err)
|
||||||
|
|||||||
@@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
|||||||
|
|
||||||
// DatabaseAuthenticator provides session-based authentication with database storage
|
// DatabaseAuthenticator provides session-based authentication with database storage
|
||||||
// All database operations go through stored procedures for security and consistency
|
// All database operations go through stored procedures for security and consistency
|
||||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
// resolvespec_session_update, resolvespec_refresh_token
|
|
||||||
// See database_schema.sql for procedure definitions
|
// See database_schema.sql for procedure definitions
|
||||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||||
// Also supports passkey authentication configured with WithPasskey()
|
// Also supports passkey authentication configured with WithPasskey()
|
||||||
@@ -67,6 +66,7 @@ type DatabaseAuthenticator struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
cache *cache.Cache
|
cache *cache.Cache
|
||||||
cacheTTL time.Duration
|
cacheTTL time.Duration
|
||||||
|
sqlNames *SQLNames
|
||||||
|
|
||||||
// OAuth2 providers registry (multiple providers supported)
|
// OAuth2 providers registry (multiple providers supported)
|
||||||
oauth2Providers map[string]*OAuth2Provider
|
oauth2Providers map[string]*OAuth2Provider
|
||||||
@@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct {
|
|||||||
Cache *cache.Cache
|
Cache *cache.Cache
|
||||||
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||||
PasskeyProvider PasskeyProvider
|
PasskeyProvider PasskeyProvider
|
||||||
|
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||||
|
// Partial overrides are supported: only set the fields you want to change.
|
||||||
|
SQLNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||||
@@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
|||||||
cacheInstance = cache.GetDefaultCache()
|
cacheInstance = cache.GetDefaultCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||||
|
|
||||||
return &DatabaseAuthenticator{
|
return &DatabaseAuthenticator{
|
||||||
db: db,
|
db: db,
|
||||||
cache: cacheInstance,
|
cache: cacheInstance,
|
||||||
cacheTTL: opts.CacheTTL,
|
cacheTTL: opts.CacheTTL,
|
||||||
|
sqlNames: sqlNames,
|
||||||
passkeyProvider: opts.PasskeyProvider,
|
passkeyProvider: opts.PasskeyProvider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
|||||||
return nil, fmt.Errorf("failed to marshal login request: %w", err)
|
return nil, fmt.Errorf("failed to marshal login request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call resolvespec_login stored procedure
|
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login query failed: %w", err)
|
return nil, fmt.Errorf("login query failed: %w", err)
|
||||||
@@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
|||||||
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call resolvespec_register stored procedure
|
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("register query failed: %w", err)
|
return nil, fmt.Errorf("register query failed: %w", err)
|
||||||
@@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
|||||||
return fmt.Errorf("failed to marshal logout request: %w", err)
|
return fmt.Errorf("failed to marshal logout request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call resolvespec_logout stored procedure
|
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("logout query failed: %w", err)
|
return fmt.Errorf("logout query failed: %w", err)
|
||||||
@@ -266,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var userJSON sql.NullString
|
var userJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("session query failed: %w", err)
|
return nil, fmt.Errorf("session query failed: %w", err)
|
||||||
@@ -338,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call resolvespec_session_update stored procedure
|
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var updatedUserJSON sql.NullString
|
var updatedUserJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken implements Refreshable interface
|
// RefreshToken implements Refreshable interface
|
||||||
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
||||||
// Call api_refresh_token stored procedure
|
|
||||||
// First, we need to get the current user context for the refresh token
|
// First, we need to get the current user context for the refresh token
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var userJSON sql.NullString
|
var userJSON sql.NullString
|
||||||
// Get current session to pass to refresh
|
// Get current session to pass to refresh
|
||||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||||
@@ -368,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
|||||||
return nil, fmt.Errorf("invalid refresh token")
|
return nil, fmt.Errorf("invalid refresh token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call resolvespec_refresh_token to generate new token
|
|
||||||
var newSuccess bool
|
var newSuccess bool
|
||||||
var newErrorMsg sql.NullString
|
var newErrorMsg sql.NullString
|
||||||
var newUserJSON sql.NullString
|
var newUserJSON sql.NullString
|
||||||
|
|
||||||
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)`
|
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||||
@@ -401,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
|||||||
|
|
||||||
// JWTAuthenticator provides JWT token-based authentication
|
// JWTAuthenticator provides JWT token-based authentication
|
||||||
// All database operations go through stored procedures
|
// All database operations go through stored procedures
|
||||||
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
|
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
|
||||||
type JWTAuthenticator struct {
|
type JWTAuthenticator struct {
|
||||||
secretKey []byte
|
secretKey []byte
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
|
func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
|
||||||
return &JWTAuthenticator{
|
return &JWTAuthenticator{
|
||||||
secretKey: []byte(secretKey),
|
secretKey: []byte(secretKey),
|
||||||
db: db,
|
db: db,
|
||||||
|
sqlNames: resolveSQLNames(names...),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||||
// Call resolvespec_jwt_login stored procedure
|
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var userJSON []byte
|
var userJSON []byte
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login query failed: %w", err)
|
return nil, fmt.Errorf("login query failed: %w", err)
|
||||||
@@ -471,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||||
// Call resolvespec_jwt_logout stored procedure
|
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("logout query failed: %w", err)
|
return fmt.Errorf("logout query failed: %w", err)
|
||||||
@@ -511,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
|||||||
|
|
||||||
// DatabaseColumnSecurityProvider loads column security from database
|
// DatabaseColumnSecurityProvider loads column security from database
|
||||||
// All database operations go through stored procedures
|
// All database operations go through stored procedures
|
||||||
// Requires stored procedure: resolvespec_column_security
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
type DatabaseColumnSecurityProvider struct {
|
type DatabaseColumnSecurityProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
|
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
|
||||||
return &DatabaseColumnSecurityProvider{db: db}
|
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||||
var rules []ColumnSecurity
|
var rules []ColumnSecurity
|
||||||
|
|
||||||
// Call resolvespec_column_security stored procedure
|
|
||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var rulesJSON []byte
|
var rulesJSON []byte
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||||
@@ -576,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
|||||||
|
|
||||||
// DatabaseRowSecurityProvider loads row security from database
|
// DatabaseRowSecurityProvider loads row security from database
|
||||||
// All database operations go through stored procedures
|
// All database operations go through stored procedures
|
||||||
// Requires stored procedure: resolvespec_row_security
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
type DatabaseRowSecurityProvider struct {
|
type DatabaseRowSecurityProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
|
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
|
||||||
return &DatabaseRowSecurityProvider{db: db}
|
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||||
var template string
|
var template string
|
||||||
var hasBlock bool
|
var hasBlock bool
|
||||||
|
|
||||||
// Call resolvespec_row_security stored procedure
|
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||||
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
|
|
||||||
|
|
||||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -758,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
|||||||
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user data from database
|
// Build request JSON for passkey login stored procedure
|
||||||
var username, email, roles string
|
reqData := map[string]any{
|
||||||
var userLevel int
|
"user_id": userID,
|
||||||
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
|
||||||
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate session token
|
|
||||||
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
|
||||||
expiresAt := time.Now().Add(24 * time.Hour)
|
|
||||||
|
|
||||||
// Extract IP and user agent from claims
|
|
||||||
ipAddress := ""
|
|
||||||
userAgent := ""
|
|
||||||
if req.Claims != nil {
|
if req.Claims != nil {
|
||||||
if ip, ok := req.Claims["ip_address"].(string); ok {
|
if ip, ok := req.Claims["ip_address"].(string); ok {
|
||||||
ipAddress = ip
|
reqData["ip_address"] = ip
|
||||||
}
|
}
|
||||||
if ua, ok := req.Claims["user_agent"].(string); ok {
|
if ua, ok := req.Claims["user_agent"].(string); ok {
|
||||||
userAgent = ua
|
reqData["user_agent"] = ua
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create session
|
reqJSON, err := json.Marshal(reqData)
|
||||||
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
|
|
||||||
VALUES ($1, $2, $3, $4, $5, now())`
|
|
||||||
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last login
|
var success bool
|
||||||
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
|
var errorMsg sql.NullString
|
||||||
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
// Return login response
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||||
return &LoginResponse{
|
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
Token: sessionToken,
|
if err != nil {
|
||||||
User: &UserContext{
|
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||||
UserID: userID,
|
}
|
||||||
UserName: username,
|
|
||||||
Email: email,
|
if !success {
|
||||||
UserLevel: userLevel,
|
if errorMsg.Valid {
|
||||||
SessionID: sessionToken,
|
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||||
Roles: parseRoles(roles),
|
}
|
||||||
},
|
return nil, fmt.Errorf("passkey login failed")
|
||||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
}
|
||||||
}, nil
|
|
||||||
|
var response LoginResponse
|
||||||
|
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPasskeyCredentials returns all passkey credentials for a user
|
// GetPasskeyCredentials returns all passkey credentials for a user
|
||||||
|
|||||||
222
pkg/security/sql_names.go
Normal file
222
pkg/security/sql_names.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||||
|
|
||||||
|
// SQLNames defines all configurable SQL stored procedure and table names
|
||||||
|
// used by the security package. Override individual fields to remap
|
||||||
|
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
|
||||||
|
// and MergeSQLNames() to apply partial overrides.
|
||||||
|
type SQLNames struct {
|
||||||
|
// Auth procedures (DatabaseAuthenticator)
|
||||||
|
Login string // default: "resolvespec_login"
|
||||||
|
Register string // default: "resolvespec_register"
|
||||||
|
Logout string // default: "resolvespec_logout"
|
||||||
|
Session string // default: "resolvespec_session"
|
||||||
|
SessionUpdate string // default: "resolvespec_session_update"
|
||||||
|
RefreshToken string // default: "resolvespec_refresh_token"
|
||||||
|
|
||||||
|
// JWT procedures (JWTAuthenticator)
|
||||||
|
JWTLogin string // default: "resolvespec_jwt_login"
|
||||||
|
JWTLogout string // default: "resolvespec_jwt_logout"
|
||||||
|
|
||||||
|
// Security policy procedures
|
||||||
|
ColumnSecurity string // default: "resolvespec_column_security"
|
||||||
|
RowSecurity string // default: "resolvespec_row_security"
|
||||||
|
|
||||||
|
// TOTP procedures (DatabaseTwoFactorProvider)
|
||||||
|
TOTPEnable string // default: "resolvespec_totp_enable"
|
||||||
|
TOTPDisable string // default: "resolvespec_totp_disable"
|
||||||
|
TOTPGetStatus string // default: "resolvespec_totp_get_status"
|
||||||
|
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
|
||||||
|
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
|
||||||
|
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
|
||||||
|
|
||||||
|
// Passkey procedures (DatabasePasskeyProvider)
|
||||||
|
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
|
||||||
|
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
|
||||||
|
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
|
||||||
|
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
|
||||||
|
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
|
||||||
|
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
|
||||||
|
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
|
||||||
|
PasskeyLogin string // default: "resolvespec_passkey_login"
|
||||||
|
|
||||||
|
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
|
||||||
|
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
|
||||||
|
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
|
||||||
|
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
|
||||||
|
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
|
||||||
|
OAuthGetUser string // default: "resolvespec_oauth_getuser"
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
|
||||||
|
func DefaultSQLNames() *SQLNames {
|
||||||
|
return &SQLNames{
|
||||||
|
Login: "resolvespec_login",
|
||||||
|
Register: "resolvespec_register",
|
||||||
|
Logout: "resolvespec_logout",
|
||||||
|
Session: "resolvespec_session",
|
||||||
|
SessionUpdate: "resolvespec_session_update",
|
||||||
|
RefreshToken: "resolvespec_refresh_token",
|
||||||
|
|
||||||
|
JWTLogin: "resolvespec_jwt_login",
|
||||||
|
JWTLogout: "resolvespec_jwt_logout",
|
||||||
|
|
||||||
|
ColumnSecurity: "resolvespec_column_security",
|
||||||
|
RowSecurity: "resolvespec_row_security",
|
||||||
|
|
||||||
|
TOTPEnable: "resolvespec_totp_enable",
|
||||||
|
TOTPDisable: "resolvespec_totp_disable",
|
||||||
|
TOTPGetStatus: "resolvespec_totp_get_status",
|
||||||
|
TOTPGetSecret: "resolvespec_totp_get_secret",
|
||||||
|
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
|
||||||
|
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
|
||||||
|
|
||||||
|
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
|
||||||
|
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
|
||||||
|
PasskeyGetCredential: "resolvespec_passkey_get_credential",
|
||||||
|
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
|
||||||
|
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
|
||||||
|
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
|
||||||
|
PasskeyUpdateName: "resolvespec_passkey_update_name",
|
||||||
|
PasskeyLogin: "resolvespec_passkey_login",
|
||||||
|
|
||||||
|
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
|
||||||
|
OAuthCreateSession: "resolvespec_oauth_createsession",
|
||||||
|
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
|
||||||
|
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
|
||||||
|
OAuthGetUser: "resolvespec_oauth_getuser",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||||
|
// If override is nil, a copy of base is returned.
|
||||||
|
func MergeSQLNames(base, override *SQLNames) *SQLNames {
|
||||||
|
if override == nil {
|
||||||
|
copied := *base
|
||||||
|
return &copied
|
||||||
|
}
|
||||||
|
merged := *base
|
||||||
|
if override.Login != "" {
|
||||||
|
merged.Login = override.Login
|
||||||
|
}
|
||||||
|
if override.Register != "" {
|
||||||
|
merged.Register = override.Register
|
||||||
|
}
|
||||||
|
if override.Logout != "" {
|
||||||
|
merged.Logout = override.Logout
|
||||||
|
}
|
||||||
|
if override.Session != "" {
|
||||||
|
merged.Session = override.Session
|
||||||
|
}
|
||||||
|
if override.SessionUpdate != "" {
|
||||||
|
merged.SessionUpdate = override.SessionUpdate
|
||||||
|
}
|
||||||
|
if override.RefreshToken != "" {
|
||||||
|
merged.RefreshToken = override.RefreshToken
|
||||||
|
}
|
||||||
|
if override.JWTLogin != "" {
|
||||||
|
merged.JWTLogin = override.JWTLogin
|
||||||
|
}
|
||||||
|
if override.JWTLogout != "" {
|
||||||
|
merged.JWTLogout = override.JWTLogout
|
||||||
|
}
|
||||||
|
if override.ColumnSecurity != "" {
|
||||||
|
merged.ColumnSecurity = override.ColumnSecurity
|
||||||
|
}
|
||||||
|
if override.RowSecurity != "" {
|
||||||
|
merged.RowSecurity = override.RowSecurity
|
||||||
|
}
|
||||||
|
if override.TOTPEnable != "" {
|
||||||
|
merged.TOTPEnable = override.TOTPEnable
|
||||||
|
}
|
||||||
|
if override.TOTPDisable != "" {
|
||||||
|
merged.TOTPDisable = override.TOTPDisable
|
||||||
|
}
|
||||||
|
if override.TOTPGetStatus != "" {
|
||||||
|
merged.TOTPGetStatus = override.TOTPGetStatus
|
||||||
|
}
|
||||||
|
if override.TOTPGetSecret != "" {
|
||||||
|
merged.TOTPGetSecret = override.TOTPGetSecret
|
||||||
|
}
|
||||||
|
if override.TOTPRegenerateBackup != "" {
|
||||||
|
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
|
||||||
|
}
|
||||||
|
if override.TOTPValidateBackupCode != "" {
|
||||||
|
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
|
||||||
|
}
|
||||||
|
if override.PasskeyStoreCredential != "" {
|
||||||
|
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
|
||||||
|
}
|
||||||
|
if override.PasskeyGetCredsByUsername != "" {
|
||||||
|
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
|
||||||
|
}
|
||||||
|
if override.PasskeyGetCredential != "" {
|
||||||
|
merged.PasskeyGetCredential = override.PasskeyGetCredential
|
||||||
|
}
|
||||||
|
if override.PasskeyUpdateCounter != "" {
|
||||||
|
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
|
||||||
|
}
|
||||||
|
if override.PasskeyGetUserCredentials != "" {
|
||||||
|
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
|
||||||
|
}
|
||||||
|
if override.PasskeyDeleteCredential != "" {
|
||||||
|
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
|
||||||
|
}
|
||||||
|
if override.PasskeyUpdateName != "" {
|
||||||
|
merged.PasskeyUpdateName = override.PasskeyUpdateName
|
||||||
|
}
|
||||||
|
if override.PasskeyLogin != "" {
|
||||||
|
merged.PasskeyLogin = override.PasskeyLogin
|
||||||
|
}
|
||||||
|
if override.OAuthGetOrCreateUser != "" {
|
||||||
|
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
|
||||||
|
}
|
||||||
|
if override.OAuthCreateSession != "" {
|
||||||
|
merged.OAuthCreateSession = override.OAuthCreateSession
|
||||||
|
}
|
||||||
|
if override.OAuthGetRefreshToken != "" {
|
||||||
|
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
|
||||||
|
}
|
||||||
|
if override.OAuthUpdateRefreshToken != "" {
|
||||||
|
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
|
||||||
|
}
|
||||||
|
if override.OAuthGetUser != "" {
|
||||||
|
merged.OAuthGetUser = override.OAuthGetUser
|
||||||
|
}
|
||||||
|
return &merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
|
||||||
|
// Returns an error if any field contains invalid characters.
|
||||||
|
func ValidateSQLNames(names *SQLNames) error {
|
||||||
|
v := reflect.ValueOf(names).Elem()
|
||||||
|
typ := v.Type()
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
field := v.Field(i)
|
||||||
|
if field.Kind() != reflect.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val := field.String()
|
||||||
|
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||||
|
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveSQLNames merges an optional override with defaults.
|
||||||
|
// Used by constructors that accept variadic *SQLNames.
|
||||||
|
func resolveSQLNames(override ...*SQLNames) *SQLNames {
|
||||||
|
if len(override) > 0 && override[0] != nil {
|
||||||
|
return MergeSQLNames(DefaultSQLNames(), override[0])
|
||||||
|
}
|
||||||
|
return DefaultSQLNames()
|
||||||
|
}
|
||||||
145
pkg/security/sql_names_test.go
Normal file
145
pkg/security/sql_names_test.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
|
||||||
|
names := DefaultSQLNames()
|
||||||
|
v := reflect.ValueOf(names).Elem()
|
||||||
|
typ := v.Type()
|
||||||
|
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
field := v.Field(i)
|
||||||
|
if field.Kind() != reflect.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if field.String() == "" {
|
||||||
|
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeSQLNames_PartialOverride(t *testing.T) {
|
||||||
|
base := DefaultSQLNames()
|
||||||
|
override := &SQLNames{
|
||||||
|
Login: "custom_login",
|
||||||
|
TOTPEnable: "custom_totp_enable",
|
||||||
|
PasskeyLogin: "custom_passkey_login",
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := MergeSQLNames(base, override)
|
||||||
|
|
||||||
|
if merged.Login != "custom_login" {
|
||||||
|
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
|
||||||
|
}
|
||||||
|
if merged.TOTPEnable != "custom_totp_enable" {
|
||||||
|
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
|
||||||
|
}
|
||||||
|
if merged.PasskeyLogin != "custom_passkey_login" {
|
||||||
|
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
|
||||||
|
}
|
||||||
|
// Non-overridden fields should retain defaults
|
||||||
|
if merged.Logout != "resolvespec_logout" {
|
||||||
|
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
|
||||||
|
}
|
||||||
|
if merged.Session != "resolvespec_session" {
|
||||||
|
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeSQLNames_NilOverride(t *testing.T) {
|
||||||
|
base := DefaultSQLNames()
|
||||||
|
merged := MergeSQLNames(base, nil)
|
||||||
|
|
||||||
|
// Should be a copy, not the same pointer
|
||||||
|
if merged == base {
|
||||||
|
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
// All values should match
|
||||||
|
v1 := reflect.ValueOf(base).Elem()
|
||||||
|
v2 := reflect.ValueOf(merged).Elem()
|
||||||
|
typ := v1.Type()
|
||||||
|
|
||||||
|
for i := 0; i < v1.NumField(); i++ {
|
||||||
|
f1 := v1.Field(i)
|
||||||
|
f2 := v2.Field(i)
|
||||||
|
if f1.Kind() != reflect.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f1.String() != f2.String() {
|
||||||
|
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
|
||||||
|
base := DefaultSQLNames()
|
||||||
|
originalLogin := base.Login
|
||||||
|
|
||||||
|
override := &SQLNames{Login: "custom_login"}
|
||||||
|
_ = MergeSQLNames(base, override)
|
||||||
|
|
||||||
|
if base.Login != originalLogin {
|
||||||
|
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
|
||||||
|
base := DefaultSQLNames()
|
||||||
|
override := &SQLNames{}
|
||||||
|
v := reflect.ValueOf(override).Elem()
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
if v.Field(i).Kind() == reflect.String {
|
||||||
|
v.Field(i).SetString("custom_sentinel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := MergeSQLNames(base, override)
|
||||||
|
mv := reflect.ValueOf(merged).Elem()
|
||||||
|
typ := mv.Type()
|
||||||
|
for i := 0; i < mv.NumField(); i++ {
|
||||||
|
if mv.Field(i).Kind() != reflect.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if mv.Field(i).String() != "custom_sentinel" {
|
||||||
|
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSQLNames_Valid(t *testing.T) {
|
||||||
|
names := DefaultSQLNames()
|
||||||
|
if err := ValidateSQLNames(names); err != nil {
|
||||||
|
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSQLNames_Invalid(t *testing.T) {
|
||||||
|
names := DefaultSQLNames()
|
||||||
|
names.Login = "resolvespec_login; DROP TABLE users; --"
|
||||||
|
|
||||||
|
err := ValidateSQLNames(names)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("ValidateSQLNames should reject names with invalid characters")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSQLNames_NoOverride(t *testing.T) {
|
||||||
|
names := resolveSQLNames()
|
||||||
|
if names.Login != "resolvespec_login" {
|
||||||
|
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveSQLNames_WithOverride(t *testing.T) {
|
||||||
|
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
|
||||||
|
if names.Login != "custom_login" {
|
||||||
|
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
|
||||||
|
}
|
||||||
|
if names.Logout != "resolvespec_logout" {
|
||||||
|
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,23 +9,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
|
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
|
||||||
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
|
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||||
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
|
|
||||||
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
|
|
||||||
// See totp_database_schema.sql for procedure definitions
|
// See totp_database_schema.sql for procedure definitions
|
||||||
type DatabaseTwoFactorProvider struct {
|
type DatabaseTwoFactorProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
totpGen *TOTPGenerator
|
totpGen *TOTPGenerator
|
||||||
|
sqlNames *SQLNames
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
||||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
|
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
config = DefaultTwoFactorConfig()
|
config = DefaultTwoFactorConfig()
|
||||||
}
|
}
|
||||||
return &DatabaseTwoFactorProvider{
|
return &DatabaseTwoFactorProvider{
|
||||||
db: db,
|
db: db,
|
||||||
totpGen: NewTOTPGenerator(config),
|
totpGen: NewTOTPGenerator(config),
|
||||||
|
sqlNames: resolveSQLNames(names...),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
|
|||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
|
||||||
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("enable 2FA query failed: %w", err)
|
return fmt.Errorf("enable 2FA query failed: %w", err)
|
||||||
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
|
|||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
|
||||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("disable 2FA query failed: %w", err)
|
return fmt.Errorf("disable 2FA query failed: %w", err)
|
||||||
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var enabled bool
|
var enabled bool
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
|
||||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
||||||
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var secret sql.NullString
|
var secret sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
|
||||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
||||||
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
|
|||||||
var success bool
|
var success bool
|
||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
|
||||||
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||||
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
|
|||||||
var errorMsg sql.NullString
|
var errorMsg sql.NullString
|
||||||
var valid bool
|
var valid bool
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
|
||||||
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user