mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-16 12:53:53 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f79a400772 | |||
|
|
aef1f96c10 | ||
|
|
354ed2a8dc |
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
const maxMetricFallbackEntityLength = 120
|
||||
|
||||
func recordQueryMetrics(enabled bool, operation, schema, entity, table string, startedAt time.Time, err error) {
|
||||
if !enabled {
|
||||
return
|
||||
@@ -136,7 +138,7 @@ func metricTargetFromRawQuery(query, driverName string) (operation, schema, enti
|
||||
operation = normalizeMetricOperation(firstQueryKeyword(query))
|
||||
tableRef := tableFromRawQuery(query, operation)
|
||||
if tableRef == "" {
|
||||
return operation, "", "unknown", "unknown"
|
||||
return operation, "", fallbackMetricEntityFromQuery(query), "unknown"
|
||||
}
|
||||
|
||||
schema, table = parseTableName(tableRef, driverName)
|
||||
@@ -144,6 +146,133 @@ func metricTargetFromRawQuery(query, driverName string) (operation, schema, enti
|
||||
return operation, schema, entity, table
|
||||
}
|
||||
|
||||
func fallbackMetricEntityFromQuery(query string) string {
|
||||
query = sanitizeMetricQueryShape(query)
|
||||
if query == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
if len(query) > maxMetricFallbackEntityLength {
|
||||
return query[:maxMetricFallbackEntityLength-3] + "..."
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func sanitizeMetricQueryShape(query string) string {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
for i := 0; i < len(query); {
|
||||
if query[i] == '\'' {
|
||||
out.WriteByte('?')
|
||||
i++
|
||||
for i < len(query) {
|
||||
if query[i] == '\'' {
|
||||
if i+1 < len(query) && query[i+1] == '\'' {
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
i++
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if query[i] == '?' {
|
||||
out.WriteByte('?')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if query[i] == '$' && i+1 < len(query) && isASCIIDigit(query[i+1]) {
|
||||
out.WriteByte('?')
|
||||
i++
|
||||
for i < len(query) && isASCIIDigit(query[i]) {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if query[i] == ':' && (i == 0 || query[i-1] != ':') && i+1 < len(query) && isIdentifierStart(query[i+1]) {
|
||||
out.WriteByte('?')
|
||||
i++
|
||||
for i < len(query) && isIdentifierPart(query[i]) {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if query[i] == '@' && (i == 0 || query[i-1] != '@') && i+1 < len(query) && isIdentifierStart(query[i+1]) {
|
||||
out.WriteByte('?')
|
||||
i++
|
||||
for i < len(query) && isIdentifierPart(query[i]) {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if startsNumericLiteral(query, i) {
|
||||
out.WriteByte('?')
|
||||
i++
|
||||
for i < len(query) && (isASCIIDigit(query[i]) || query[i] == '.') {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
out.WriteByte(query[i])
|
||||
i++
|
||||
}
|
||||
|
||||
return strings.Join(strings.Fields(out.String()), " ")
|
||||
}
|
||||
|
||||
func startsNumericLiteral(query string, idx int) bool {
|
||||
if idx >= len(query) {
|
||||
return false
|
||||
}
|
||||
|
||||
start := idx
|
||||
if query[idx] == '-' {
|
||||
if idx+1 >= len(query) || !isASCIIDigit(query[idx+1]) {
|
||||
return false
|
||||
}
|
||||
start++
|
||||
}
|
||||
|
||||
if !isASCIIDigit(query[start]) {
|
||||
return false
|
||||
}
|
||||
|
||||
if idx > 0 && isIdentifierPart(query[idx-1]) {
|
||||
return false
|
||||
}
|
||||
|
||||
if start+1 < len(query) && query[start] == '0' && (query[start+1] == 'x' || query[start+1] == 'X') {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isASCIIDigit(ch byte) bool {
|
||||
return ch >= '0' && ch <= '9'
|
||||
}
|
||||
|
||||
func isIdentifierStart(ch byte) bool {
|
||||
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
|
||||
}
|
||||
|
||||
func isIdentifierPart(ch byte) bool {
|
||||
return isIdentifierStart(ch) || isASCIIDigit(ch)
|
||||
}
|
||||
|
||||
func firstQueryKeyword(query string) string {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -268,6 +269,47 @@ func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) {
|
||||
assert.Equal(t, "orders", calls[0].table)
|
||||
}
|
||||
|
||||
func TestPgSQLAdapterRawExecUsesSQLAsEntityWhenTargetUnknown(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
provider := &capturingMetricsProvider{}
|
||||
prev := metrics.GetProvider()
|
||||
metrics.SetProvider(provider)
|
||||
defer metrics.SetProvider(prev)
|
||||
|
||||
query := `select core.c_setuserid($1)`
|
||||
mock.ExpectExec(`select core\.c_setuserid\(\$1\)`).
|
||||
WithArgs(42).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
_, err = adapter.Exec(context.Background(), query, 42)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := provider.snapshot()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "SELECT", calls[0].operation)
|
||||
assert.Equal(t, "default", calls[0].schema)
|
||||
assert.Equal(t, "select core.c_setuserid(?)", calls[0].entity)
|
||||
assert.Equal(t, "unknown", calls[0].table)
|
||||
}
|
||||
|
||||
func TestFallbackMetricEntityFromQuerySanitizesAndTruncates(t *testing.T) {
|
||||
entity := fallbackMetricEntityFromQuery(" \n SELECT some_function(1, 'abc', $2, ?, :name, @p1, true, null) \t ")
|
||||
assert.Equal(t, "SELECT some_function(?, ?, ?, ?, ?, ?, true, null)", entity)
|
||||
|
||||
entity = fallbackMetricEntityFromQuery("SELECT price::numeric, id FROM logs WHERE code = -42")
|
||||
assert.Equal(t, "SELECT price::numeric, id FROM logs WHERE code = ?", entity)
|
||||
|
||||
longQuery := "SELECT " + strings.Repeat("x", maxMetricFallbackEntityLength)
|
||||
entity = fallbackMetricEntityFromQuery(longQuery)
|
||||
assert.Len(t, entity, maxMetricFallbackEntityLength)
|
||||
assert.True(t, strings.HasSuffix(entity, "..."))
|
||||
}
|
||||
|
||||
func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -739,7 +739,7 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
|
||||
colval = strings.ReplaceAll(colval, "\\", "\\\\")
|
||||
colval = strings.ReplaceAll(colval, "'", "''")
|
||||
if colval != "*" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
|
||||
}
|
||||
} else if val == "" || val == "0" {
|
||||
// For empty/zero values, treat as literal 0 or empty string with quotes
|
||||
@@ -806,7 +806,7 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
|
||||
colname := strings.ReplaceAll(k, "x-searchfilter-", "")
|
||||
sval := strings.ReplaceAll(val, "'", "")
|
||||
if sval != "" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) strin
|
||||
for colName, value := range params.SearchFilters {
|
||||
sval := strings.ReplaceAll(value, "'", "")
|
||||
if sval != "" {
|
||||
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
|
||||
condition := fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
|
||||
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||
logger.Debug("Applied search filter: %s", condition)
|
||||
}
|
||||
@@ -307,11 +307,11 @@ func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string
|
||||
|
||||
switch operator {
|
||||
case "contains", "contain", "like":
|
||||
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "beginswith", "startswith":
|
||||
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "endswith":
|
||||
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
|
||||
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "equals", "eq", "=":
|
||||
if IsNumeric(value) {
|
||||
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
|
||||
|
||||
@@ -274,7 +274,7 @@ func TestBuildFilterCondition(t *testing.T) {
|
||||
Value: "test",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "description ILIKE '%test%'",
|
||||
expected: "CAST(description AS TEXT) ILIKE '%test%'",
|
||||
},
|
||||
{
|
||||
name: "Starts with operator",
|
||||
@@ -284,7 +284,7 @@ func TestBuildFilterCondition(t *testing.T) {
|
||||
Value: "john",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "name ILIKE 'john%'",
|
||||
expected: "CAST(name AS TEXT) ILIKE 'john%'",
|
||||
},
|
||||
{
|
||||
name: "Ends with operator",
|
||||
@@ -294,7 +294,7 @@ func TestBuildFilterCondition(t *testing.T) {
|
||||
Value: "@example.com",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "email ILIKE '%@example.com'",
|
||||
expected: "CAST(email AS TEXT) ILIKE '%@example.com'",
|
||||
},
|
||||
{
|
||||
name: "Between operator",
|
||||
|
||||
@@ -702,7 +702,12 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
if hookCtx.Options != nil {
|
||||
// Apply filters
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
op := strings.ToLower(filter.Operator)
|
||||
if op == "like" || op == "ilike" {
|
||||
query = query.Where(fmt.Sprintf("CAST(%s AS TEXT) %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
} else {
|
||||
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
@@ -743,7 +748,12 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
if hookCtx.Options != nil {
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
op := strings.ToLower(filter.Operator)
|
||||
if op == "like" || op == "ilike" {
|
||||
countQuery = countQuery.Where(fmt.Sprintf("CAST(%s AS TEXT) %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
} else {
|
||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
count, _ := countQuery.Count(hookCtx.Context)
|
||||
|
||||
@@ -735,9 +735,9 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition st
|
||||
case "lte", "<=":
|
||||
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
return fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return condition, args
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestBuildFilterCondition(t *testing.T) {
|
||||
Operator: "like",
|
||||
Value: "%@example.com",
|
||||
},
|
||||
expectedCondition: "email LIKE ?",
|
||||
expectedCondition: "CAST(email AS TEXT) LIKE ?",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1545,10 +1545,10 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
||||
condition = fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
condition = fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
@@ -1589,10 +1589,10 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
||||
condition = fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
condition = fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
|
||||
@@ -2118,11 +2118,12 @@ func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
|
||||
|
||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption, tableName string, needsCast bool, logicOp string) common.SelectQuery {
|
||||
// Qualify the column name with table name if not already qualified
|
||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
rawQualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
qualifiedColumn := rawQualifiedColumn
|
||||
|
||||
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||
if needsCast {
|
||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", rawQualifiedColumn)
|
||||
}
|
||||
|
||||
// Helper function to apply the correct Where method based on logic operator
|
||||
@@ -2147,11 +2148,11 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
case "lte", "less_than_equals", "le":
|
||||
return applyWhere(fmt.Sprintf("%s <= ?", qualifiedColumn), filter.Value)
|
||||
case "like":
|
||||
return applyWhere(fmt.Sprintf("%s LIKE ?", qualifiedColumn), filter.Value)
|
||||
// Always cast to TEXT for LIKE/ILIKE to support date/time/timestamp columns
|
||||
return applyWhere(fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", rawQualifiedColumn), filter.Value)
|
||||
case "ilike":
|
||||
// Use ILIKE for case-insensitive search (PostgreSQL)
|
||||
// Column is already cast to TEXT if needed
|
||||
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
||||
// Always cast to TEXT for LIKE/ILIKE to support date/time/timestamp columns
|
||||
return applyWhere(fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", rawQualifiedColumn), filter.Value)
|
||||
case "in":
|
||||
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||
if cond == "" {
|
||||
@@ -2203,11 +2204,16 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
|
||||
|
||||
for i, filter := range filters {
|
||||
// Qualify the column name with table name if not already qualified
|
||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
rawQualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
qualifiedColumn := rawQualifiedColumn
|
||||
|
||||
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||
if castInfo[i].NeedsCast {
|
||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
||||
op := strings.ToLower(filter.Operator)
|
||||
if op == "like" || op == "ilike" {
|
||||
// Always cast to TEXT for LIKE/ILIKE to support date/time/timestamp columns
|
||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", rawQualifiedColumn)
|
||||
} else if castInfo[i].NeedsCast {
|
||||
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", rawQualifiedColumn)
|
||||
}
|
||||
|
||||
// Build the condition based on operator
|
||||
|
||||
@@ -13,6 +13,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
- ✅ **Extensible** - Implement custom providers for your needs
|
||||
- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
|
||||
- ✅ **OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
|
||||
- ✅ **Password Reset** - Self-service password reset with secure token generation and session invalidation
|
||||
|
||||
## Stored Procedure Architecture
|
||||
|
||||
@@ -45,6 +46,8 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_password_reset_request` | Create password reset token | DatabaseAuthenticator |
|
||||
| `resolvespec_password_reset` | Validate token and set new password | DatabaseAuthenticator |
|
||||
|
||||
See `database_schema.sql` for complete stored procedure definitions and examples.
|
||||
|
||||
@@ -904,6 +907,66 @@ securityList := security.NewSecurityList(provider)
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
||||
```
|
||||
|
||||
## Password Reset
|
||||
|
||||
`DatabaseAuthenticator` implements `PasswordResettable` for self-service password reset.
|
||||
|
||||
### Flow
|
||||
|
||||
1. User submits email or username → `RequestPasswordReset` → server generates a token and returns it for out-of-band delivery (email, SMS, etc.)
|
||||
2. User submits the raw token + new password → `CompletePasswordReset` → password updated, all sessions invalidated
|
||||
|
||||
### DB Requirements
|
||||
|
||||
Run the migrations in `database_schema.sql`:
|
||||
- `user_password_resets` table (`user_id`, `token_hash` SHA-256, `expires_at`, `used`, `used_at`)
|
||||
- `resolvespec_password_reset_request` stored procedure
|
||||
- `resolvespec_password_reset` stored procedure
|
||||
|
||||
Requires the `pgcrypto` extension (`gen_random_bytes`, `digest`) — already used by `resolvespec_login`.
|
||||
|
||||
### Usage
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
|
||||
// Step 1 — initiate reset (call after user submits their email)
|
||||
resp, err := auth.RequestPasswordReset(ctx, security.PasswordResetRequest{
|
||||
Email: "user@example.com",
|
||||
})
|
||||
// resp.Token is the raw token — deliver it out-of-band
|
||||
// resp.ExpiresIn is 3600 (1 hour)
|
||||
// Always returns success regardless of whether the user exists (anti-enumeration)
|
||||
|
||||
// Step 2 — complete reset (call after user submits token + new password)
|
||||
err = auth.CompletePasswordReset(ctx, security.PasswordResetCompleteRequest{
|
||||
Token: rawToken,
|
||||
NewPassword: "newSecurePassword",
|
||||
})
|
||||
// On success: password updated, all active sessions deleted
|
||||
```
|
||||
|
||||
### Security Notes
|
||||
|
||||
- The raw token is never stored; only its SHA-256 hash is persisted
|
||||
- Requesting a reset invalidates any previous unused tokens for that user
|
||||
- Tokens expire after 1 hour
|
||||
- Completing a reset deletes all active sessions, forcing re-login
|
||||
- `RequestPasswordReset` always returns success even when the email/username is not found, preventing user enumeration
|
||||
- Hash the new password with bcrypt before storing (pgcrypto `crypt`/`gen_salt`) — see the TODO comment in `resolvespec_password_reset`
|
||||
|
||||
### SQLNames
|
||||
|
||||
```go
|
||||
type SQLNames struct {
|
||||
// ...
|
||||
PasswordResetRequest string // default: "resolvespec_password_reset_request"
|
||||
PasswordResetComplete string // default: "resolvespec_password_reset"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## OAuth2 Authorization Server
|
||||
|
||||
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
|
||||
@@ -1110,6 +1173,14 @@ type Cacheable interface {
|
||||
}
|
||||
```
|
||||
|
||||
**PasswordResettable** - Self-service password reset:
|
||||
```go
|
||||
type PasswordResettable interface {
|
||||
RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error)
|
||||
CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error
|
||||
}
|
||||
```
|
||||
|
||||
## Benefits Over Callbacks
|
||||
|
||||
| Feature | Old (Callbacks) | New (Interfaces) |
|
||||
|
||||
@@ -1398,6 +1398,158 @@ $$ LANGUAGE plpgsql;
|
||||
-- Get credentials by username
|
||||
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
|
||||
|
||||
-- ============================================
|
||||
-- Password Reset Tables
|
||||
-- ============================================
|
||||
|
||||
-- Password reset tokens table
|
||||
CREATE TABLE IF NOT EXISTS user_password_resets (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
token_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex of the raw token
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
used BOOLEAN DEFAULT false,
|
||||
used_at TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_pw_reset_token_hash ON user_password_resets(token_hash);
|
||||
CREATE INDEX IF NOT EXISTS idx_pw_reset_user_id ON user_password_resets(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_pw_reset_expires_at ON user_password_resets(expires_at);
|
||||
|
||||
-- ============================================
|
||||
-- Stored Procedures for Password Reset
|
||||
-- ============================================
|
||||
|
||||
-- 1. resolvespec_password_reset_request - Creates a password reset token for a user
|
||||
-- Input: p_request jsonb {email: string, username: string}
|
||||
-- Output: p_success (bool), p_error (text), p_data jsonb {token: string, expires_in: int}
|
||||
-- NOTE: The raw token is returned so the caller can deliver it out-of-band (e.g. email).
|
||||
-- Only the SHA-256 hash is stored. Invalidates any previous unused tokens for the user.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_password_reset_request(p_request jsonb)
|
||||
RETURNS TABLE(p_success boolean, p_error text, p_data jsonb) AS $$
|
||||
DECLARE
|
||||
v_user_id INTEGER;
|
||||
v_email TEXT;
|
||||
v_username TEXT;
|
||||
v_raw_token TEXT;
|
||||
v_token_hash TEXT;
|
||||
v_expires_at TIMESTAMP;
|
||||
BEGIN
|
||||
v_email := p_request->>'email';
|
||||
v_username := p_request->>'username';
|
||||
|
||||
-- Require at least one identifier
|
||||
IF (v_email IS NULL OR v_email = '') AND (v_username IS NULL OR v_username = '') THEN
|
||||
RETURN QUERY SELECT false, 'email or username is required'::text, NULL::jsonb;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
-- Look up active user
|
||||
IF v_email IS NOT NULL AND v_email <> '' THEN
|
||||
SELECT id INTO v_user_id FROM users WHERE email = v_email AND is_active = true;
|
||||
ELSE
|
||||
SELECT id INTO v_user_id FROM users WHERE username = v_username AND is_active = true;
|
||||
END IF;
|
||||
|
||||
-- Return generic success even when user not found to avoid user enumeration
|
||||
IF NOT FOUND THEN
|
||||
RETURN QUERY SELECT true, NULL::text, jsonb_build_object('token', '', 'expires_in', 0);
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
-- Invalidate previous unused tokens for this user
|
||||
DELETE FROM user_password_resets WHERE user_id = v_user_id AND used = false;
|
||||
|
||||
-- Generate a random 32-byte token and store its SHA-256 hash
|
||||
v_raw_token := encode(gen_random_bytes(32), 'hex');
|
||||
v_token_hash := encode(digest(v_raw_token, 'sha256'), 'hex');
|
||||
v_expires_at := now() + interval '1 hour';
|
||||
|
||||
INSERT INTO user_password_resets (user_id, token_hash, expires_at)
|
||||
VALUES (v_user_id, v_token_hash, v_expires_at);
|
||||
|
||||
RETURN QUERY SELECT
|
||||
true,
|
||||
NULL::text,
|
||||
jsonb_build_object(
|
||||
'token', v_raw_token,
|
||||
'expires_in', 3600
|
||||
);
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM::text, NULL::jsonb;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- 2. resolvespec_password_reset - Validates the token and updates the user's password
|
||||
-- Input: p_request jsonb {token: string, new_password: string}
|
||||
-- Output: p_success (bool), p_error (text)
|
||||
-- NOTE: Hash the new_password with bcrypt before storing (pgcrypto crypt/gen_salt).
|
||||
-- The TODO below mirrors the convention used in resolvespec_register.
|
||||
CREATE OR REPLACE FUNCTION resolvespec_password_reset(p_request jsonb)
|
||||
RETURNS TABLE(p_success boolean, p_error text) AS $$
|
||||
DECLARE
|
||||
v_raw_token TEXT;
|
||||
v_token_hash TEXT;
|
||||
v_new_pw TEXT;
|
||||
v_reset_id INTEGER;
|
||||
v_user_id INTEGER;
|
||||
v_expires_at TIMESTAMP;
|
||||
BEGIN
|
||||
v_raw_token := p_request->>'token';
|
||||
v_new_pw := p_request->>'new_password';
|
||||
|
||||
IF v_raw_token IS NULL OR v_raw_token = '' THEN
|
||||
RETURN QUERY SELECT false, 'token is required'::text;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
IF v_new_pw IS NULL OR v_new_pw = '' THEN
|
||||
RETURN QUERY SELECT false, 'new_password is required'::text;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
v_token_hash := encode(digest(v_raw_token, 'sha256'), 'hex');
|
||||
|
||||
-- Find valid, unused reset token
|
||||
SELECT id, user_id, expires_at
|
||||
INTO v_reset_id, v_user_id, v_expires_at
|
||||
FROM user_password_resets
|
||||
WHERE token_hash = v_token_hash AND used = false;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN QUERY SELECT false, 'invalid or expired token'::text;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
IF v_expires_at <= now() THEN
|
||||
RETURN QUERY SELECT false, 'invalid or expired token'::text;
|
||||
RETURN;
|
||||
END IF;
|
||||
|
||||
-- TODO: Hash new password with pgcrypto before storing
|
||||
-- Enable pgcrypto: CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||
-- v_new_pw := crypt(v_new_pw, gen_salt('bf'));
|
||||
|
||||
-- Update password and invalidate all sessions
|
||||
UPDATE users SET password = v_new_pw, updated_at = now() WHERE id = v_user_id;
|
||||
DELETE FROM user_sessions WHERE user_id = v_user_id;
|
||||
|
||||
-- Mark token as used
|
||||
UPDATE user_password_resets SET used = true, used_at = now() WHERE id = v_reset_id;
|
||||
|
||||
RETURN QUERY SELECT true, NULL::text;
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM::text;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Example: Test password reset stored procedures
|
||||
-- SELECT * FROM resolvespec_password_reset_request('{"email": "user@example.com"}'::jsonb);
|
||||
-- SELECT * FROM resolvespec_password_reset('{"token": "<raw_token>", "new_password": "newpass123"}'::jsonb);
|
||||
|
||||
-- ============================================
|
||||
-- OAuth2 Server Tables (OAuthServer persistence)
|
||||
-- ============================================
|
||||
|
||||
@@ -57,6 +57,27 @@ type LogoutRequest struct {
|
||||
UserID int `json:"user_id"`
|
||||
}
|
||||
|
||||
// PasswordResetRequest initiates a password reset for a user
|
||||
type PasswordResetRequest struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
}
|
||||
|
||||
// PasswordResetResponse is returned when a reset is initiated
|
||||
type PasswordResetResponse struct {
|
||||
// Token is the reset token to be delivered out-of-band (e.g. email).
|
||||
// The stored procedure may return it for delivery or leave it empty
|
||||
// if the delivery is handled entirely in the database.
|
||||
Token string `json:"token"`
|
||||
ExpiresIn int64 `json:"expires_in"` // seconds
|
||||
}
|
||||
|
||||
// PasswordResetCompleteRequest completes a password reset using the token
|
||||
type PasswordResetCompleteRequest struct {
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
// Authenticator handles user authentication operations
|
||||
type Authenticator interface {
|
||||
// Login authenticates credentials and returns a token
|
||||
@@ -114,3 +135,12 @@ type Cacheable interface {
|
||||
// ClearCache clears cached security rules for a user/entity
|
||||
ClearCache(ctx context.Context, userID int, schema, table string) error
|
||||
}
|
||||
|
||||
// PasswordResettable allows providers to support self-service password reset
|
||||
type PasswordResettable interface {
|
||||
// RequestPasswordReset creates a reset token for the given email/username
|
||||
RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error)
|
||||
|
||||
// CompletePasswordReset validates the token and sets the new password
|
||||
CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error
|
||||
}
|
||||
|
||||
@@ -868,6 +868,75 @@ func generateRandomString(length int) string {
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// Password reset methods
|
||||
// ======================
|
||||
|
||||
// RequestPasswordReset implements PasswordResettable. It calls the stored procedure
|
||||
// resolvespec_password_reset_request and returns the reset token and expiry.
|
||||
func (a *DatabaseAuthenticator) RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error) {
|
||||
reqJSON, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal password reset request: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasswordResetRequest)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("password reset request query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("password reset request failed")
|
||||
}
|
||||
|
||||
var response PasswordResetResponse
|
||||
if dataJSON.Valid && dataJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse password reset response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// CompletePasswordReset implements PasswordResettable. It validates the token and
|
||||
// updates the user's password via resolvespec_password_reset.
|
||||
func (a *DatabaseAuthenticator) CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error {
|
||||
reqJSON, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal password reset complete request: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1::jsonb)`, a.sqlNames.PasswordResetComplete)
|
||||
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("password reset complete query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return fmt.Errorf("password reset failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Passkey authentication methods
|
||||
// ==============================
|
||||
|
||||
|
||||
@@ -47,6 +47,10 @@ type SQLNames struct {
|
||||
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
|
||||
PasskeyLogin string // default: "resolvespec_passkey_login"
|
||||
|
||||
// Password reset procedures (DatabaseAuthenticator)
|
||||
PasswordResetRequest string // default: "resolvespec_password_reset_request"
|
||||
PasswordResetComplete string // default: "resolvespec_password_reset"
|
||||
|
||||
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
|
||||
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
|
||||
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
|
||||
@@ -95,6 +99,9 @@ func DefaultSQLNames() *SQLNames {
|
||||
PasskeyUpdateName: "resolvespec_passkey_update_name",
|
||||
PasskeyLogin: "resolvespec_passkey_login",
|
||||
|
||||
PasswordResetRequest: "resolvespec_password_reset_request",
|
||||
PasswordResetComplete: "resolvespec_password_reset",
|
||||
|
||||
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
|
||||
OAuthCreateSession: "resolvespec_oauth_createsession",
|
||||
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
|
||||
@@ -190,6 +197,12 @@ func MergeSQLNames(base, override *SQLNames) *SQLNames {
|
||||
if override.PasskeyLogin != "" {
|
||||
merged.PasskeyLogin = override.PasskeyLogin
|
||||
}
|
||||
if override.PasswordResetRequest != "" {
|
||||
merged.PasswordResetRequest = override.PasswordResetRequest
|
||||
}
|
||||
if override.PasswordResetComplete != "" {
|
||||
merged.PasswordResetComplete = override.PasswordResetComplete
|
||||
}
|
||||
if override.OAuthGetOrCreateUser != "" {
|
||||
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
|
||||
}
|
||||
|
||||
@@ -807,6 +807,11 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
||||
cond, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return cond, args
|
||||
}
|
||||
op := strings.ToLower(filter.Operator)
|
||||
if op == "like" || op == "ilike" {
|
||||
operatorSQL := h.getOperatorSQL(filter.Operator)
|
||||
return fmt.Sprintf("CAST(%s AS TEXT) %s ?", filter.Column, operatorSQL), []interface{}{filter.Value}
|
||||
}
|
||||
operatorSQL := h.getOperatorSQL(filter.Operator)
|
||||
return fmt.Sprintf("%s %s ?", filter.Column, operatorSQL), []interface{}{filter.Value}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user