Compare commits

...

5 Commits

Author SHA1 Message Date
Hein
fb051b5577 fix(spectypes): correct quoting logic in formatPostgresStringArray
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Failing after -33m2s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -28m38s
Build , Vet Test, and Lint / Lint Code (push) Successful in -27m26s
Build , Vet Test, and Lint / Build (push) Successful in -29m3s
Tests / Integration Tests (push) Failing after -33m9s
Tests / Unit Tests (push) Successful in -30m2s
2026-04-30 15:38:21 +02:00
Hein
cc9c4337fd feat(spectypes): add PostgreSQL array types and parsing functions 2026-04-30 15:37:33 +02:00
Hein
0aaeff63a2 fix(db): guard against non-existent relations in preload
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -33m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -32m35s
Build , Vet Test, and Lint / Lint Code (push) Successful in -32m33s
Build , Vet Test, and Lint / Build (push) Successful in -32m50s
Tests / Integration Tests (push) Failing after -33m35s
Tests / Unit Tests (push) Successful in -33m22s
2026-04-20 17:15:33 +02:00
Hein
325769be4e feat(reflection): add support for nested struct mapping
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -33m9s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -32m37s
Build , Vet Test, and Lint / Build (push) Successful in -32m41s
Build , Vet Test, and Lint / Lint Code (push) Successful in -31m55s
Tests / Unit Tests (push) Successful in -33m29s
Tests / Integration Tests (push) Failing after -33m42s
2026-04-16 13:45:46 +02:00
f79a400772 feat(security): add self-service password reset functionality
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -33m14s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -32m42s
Build , Vet Test, and Lint / Build (push) Successful in -32m59s
Build , Vet Test, and Lint / Lint Code (push) Successful in -32m26s
Tests / Integration Tests (push) Failing after -33m40s
Tests / Unit Tests (push) Successful in -33m35s
* Implement password reset request and completion procedures
* Update database schema for password reset tokens
* Add new request and response types for password reset
2026-04-15 21:46:33 +02:00
9 changed files with 1231 additions and 0 deletions

View File

@@ -597,6 +597,19 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
if !b.skipAutoDetect {
model := b.query.GetModel()
if model != nil && model.Value() != nil {
// Guard against relations that don't exist on the model. Without this,
// bun panics inside Count/Scan with `model=X does not have relation="Y"`.
// Only validate the root segment so nested paths (e.g. "PRM.CHILD") still
// fall through to bun's native resolution.
rootRelation := relation
if idx := strings.Index(rootRelation, "."); idx >= 0 {
rootRelation = rootRelation[:idx]
}
if reflection.GetRelationType(model.Value(), rootRelation) == reflection.RelationUnknown {
logger.Warn("Skipping preload '%s': relation '%s' is not declared on model %T", relation, rootRelation, model.Value())
return b
}
relType := reflection.GetRelationType(model.Value(), relation)
// Log the detected relationship type

View File

@@ -1244,6 +1244,16 @@ func setFieldValue(field reflect.Value, value interface{}) error {
}
}
// Handle map[string]interface{} → nested struct (e.g. relation fields like AFN, DEF)
if m, ok := value.(map[string]interface{}); ok {
if field.CanAddr() {
if err := MapToStruct(m, field.Addr().Interface()); err != nil {
return err
}
return nil
}
}
// Fallback: Try to find a "Val" field (for SqlNull types) and set it directly
valField := field.FieldByName("Val")
if valField.IsValid() && valField.CanSet() {

View File

@@ -221,6 +221,124 @@ func TestMapToStruct_AllSqlTypes(t *testing.T) {
t.Logf(" - SqlJSONB (Tags): %v", tagsValue)
}
// TestMapToStruct_NestedStructPointer tests that a map[string]interface{} value is
// correctly converted into a pointer-to-struct field (e.g. AFN *ModelCoreActionfunction).
func TestMapToStruct_NestedStructPointer(t *testing.T) {
type Inner struct {
ID spectypes.SqlInt32 `bun:"rid_inner,pk" json:"rid_inner"`
Name spectypes.SqlString `bun:"name" json:"name"`
}
type Outer struct {
ID spectypes.SqlInt32 `bun:"rid_outer,pk" json:"rid_outer"`
Inner *Inner `json:"inner,omitempty" bun:"rel:has-one"`
}
dataMap := map[string]interface{}{
"rid_outer": int64(1),
"inner": map[string]interface{}{
"rid_inner": int64(42),
"name": "hello",
},
}
var result Outer
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
if !result.ID.Valid || result.ID.Val != 1 {
t.Errorf("ID = %v, want 1", result.ID)
}
if result.Inner == nil {
t.Fatal("Inner is nil, want non-nil")
}
if !result.Inner.ID.Valid || result.Inner.ID.Val != 42 {
t.Errorf("Inner.ID = %v, want 42", result.Inner.ID)
}
if !result.Inner.Name.Valid || result.Inner.Name.Val != "hello" {
t.Errorf("Inner.Name = %v, want 'hello'", result.Inner.Name)
}
}
// TestMapToStruct_NestedStructNilPointer tests that a nil map value leaves the pointer nil.
func TestMapToStruct_NestedStructNilPointer(t *testing.T) {
type Inner struct {
ID spectypes.SqlInt32 `bun:"rid_inner,pk" json:"rid_inner"`
}
type Outer struct {
ID spectypes.SqlInt32 `bun:"rid_outer,pk" json:"rid_outer"`
Inner *Inner `json:"inner,omitempty" bun:"rel:has-one"`
}
dataMap := map[string]interface{}{
"rid_outer": int64(5),
"inner": nil,
}
var result Outer
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
if result.Inner != nil {
t.Errorf("Inner = %v, want nil", result.Inner)
}
}
// TestMapToStruct_NestedStructWithSpectypes mirrors the real-world case of
// ModelCoreActionoption.AFN being populated from map[string]interface{}.
func TestMapToStruct_NestedStructWithSpectypes(t *testing.T) {
type ActionFunction struct {
Ridactionfunction spectypes.SqlInt32 `bun:"rid_actionfunction,pk" json:"rid_actionfunction"`
Functionname spectypes.SqlString `bun:"functionname" json:"functionname"`
Fntype spectypes.SqlString `bun:"fntype" json:"fntype"`
}
type ActionOption struct {
Ridactionoption spectypes.SqlInt32 `bun:"rid_actionoption,pk" json:"rid_actionoption"`
Ridactionfunction spectypes.SqlInt32 `bun:"rid_actionfunction" json:"rid_actionfunction"`
Description spectypes.SqlString `bun:"description" json:"description"`
AFN *ActionFunction `json:"AFN,omitempty" bun:"rel:has-one"`
}
dataMap := map[string]interface{}{
"rid_actionoption": int64(10),
"rid_actionfunction": int64(99),
"description": "test option",
"AFN": map[string]interface{}{
"rid_actionfunction": int64(99),
"functionname": "MyFunction",
"fntype": "action",
},
}
var result ActionOption
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
if !result.Ridactionoption.Valid || result.Ridactionoption.Val != 10 {
t.Errorf("Ridactionoption = %v, want 10", result.Ridactionoption)
}
if !result.Description.Valid || result.Description.Val != "test option" {
t.Errorf("Description = %v, want 'test option'", result.Description)
}
if result.AFN == nil {
t.Fatal("AFN is nil, want non-nil")
}
if !result.AFN.Ridactionfunction.Valid || result.AFN.Ridactionfunction.Val != 99 {
t.Errorf("AFN.Ridactionfunction = %v, want 99", result.AFN.Ridactionfunction)
}
if !result.AFN.Functionname.Valid || result.AFN.Functionname.Val != "MyFunction" {
t.Errorf("AFN.Functionname = %v, want 'MyFunction'", result.AFN.Functionname)
}
if !result.AFN.Fntype.Valid || result.AFN.Fntype.Val != "action" {
t.Errorf("AFN.Fntype = %v, want 'action'", result.AFN.Fntype)
}
}
func TestMapToStruct_SqlNull_NilValues(t *testing.T) {
// Test that SqlNull types handle nil values correctly
type TestModel struct {

View File

@@ -13,6 +13,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
-**Extensible** - Implement custom providers for your needs
-**Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
-**OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
-**Password Reset** - Self-service password reset with secure token generation and session invalidation
## Stored Procedure Architecture
@@ -45,6 +46,8 @@ Type-safe, composable security system for ResolveSpec with support for authentic
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
| `resolvespec_password_reset_request` | Create password reset token | DatabaseAuthenticator |
| `resolvespec_password_reset` | Validate token and set new password | DatabaseAuthenticator |
See `database_schema.sql` for complete stored procedure definitions and examples.
@@ -904,6 +907,66 @@ securityList := security.NewSecurityList(provider)
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
```
## Password Reset
`DatabaseAuthenticator` implements `PasswordResettable` for self-service password reset.
### Flow
1. User submits email or username → `RequestPasswordReset` → server generates a token and returns it for out-of-band delivery (email, SMS, etc.)
2. User submits the raw token + new password → `CompletePasswordReset` → password updated, all sessions invalidated
### DB Requirements
Run the migrations in `database_schema.sql`:
- `user_password_resets` table (`user_id`, `token_hash` SHA-256, `expires_at`, `used`, `used_at`)
- `resolvespec_password_reset_request` stored procedure
- `resolvespec_password_reset` stored procedure
Requires the `pgcrypto` extension (`gen_random_bytes`, `digest`) — already used by `resolvespec_login`.
### Usage
```go
auth := security.NewDatabaseAuthenticator(db)
// Step 1 — initiate reset (call after user submits their email)
resp, err := auth.RequestPasswordReset(ctx, security.PasswordResetRequest{
Email: "user@example.com",
})
// resp.Token is the raw token — deliver it out-of-band
// resp.ExpiresIn is 3600 (1 hour)
// Always returns success regardless of whether the user exists (anti-enumeration)
// Step 2 — complete reset (call after user submits token + new password)
err = auth.CompletePasswordReset(ctx, security.PasswordResetCompleteRequest{
Token: rawToken,
NewPassword: "newSecurePassword",
})
// On success: password updated, all active sessions deleted
```
### Security Notes
- The raw token is never stored; only its SHA-256 hash is persisted
- Requesting a reset invalidates any previous unused tokens for that user
- Tokens expire after 1 hour
- Completing a reset deletes all active sessions, forcing re-login
- `RequestPasswordReset` always returns success even when the email/username is not found, preventing user enumeration
- Hash the new password with bcrypt before storing (pgcrypto `crypt`/`gen_salt`) — see the TODO comment in `resolvespec_password_reset`
### SQLNames
```go
type SQLNames struct {
// ...
PasswordResetRequest string // default: "resolvespec_password_reset_request"
PasswordResetComplete string // default: "resolvespec_password_reset"
}
```
---
## OAuth2 Authorization Server
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
@@ -1110,6 +1173,14 @@ type Cacheable interface {
}
```
**PasswordResettable** - Self-service password reset:
```go
type PasswordResettable interface {
RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error)
CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error
}
```
## Benefits Over Callbacks
| Feature | Old (Callbacks) | New (Interfaces) |

View File

@@ -1398,6 +1398,158 @@ $$ LANGUAGE plpgsql;
-- Get credentials by username
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
-- ============================================
-- Password Reset Tables
-- ============================================
-- Password reset tokens table
CREATE TABLE IF NOT EXISTS user_password_resets (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex of the raw token
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
used BOOLEAN DEFAULT false,
used_at TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_pw_reset_token_hash ON user_password_resets(token_hash);
CREATE INDEX IF NOT EXISTS idx_pw_reset_user_id ON user_password_resets(user_id);
CREATE INDEX IF NOT EXISTS idx_pw_reset_expires_at ON user_password_resets(expires_at);
-- ============================================
-- Stored Procedures for Password Reset
-- ============================================
-- 1. resolvespec_password_reset_request - Creates a password reset token for a user
-- Input: p_request jsonb {email: string, username: string}
-- Output: p_success (bool), p_error (text), p_data jsonb {token: string, expires_in: int}
-- NOTE: The raw token is returned so the caller can deliver it out-of-band (e.g. email).
-- Only the SHA-256 hash is stored. Invalidates any previous unused tokens for the user.
CREATE OR REPLACE FUNCTION resolvespec_password_reset_request(p_request jsonb)
RETURNS TABLE(p_success boolean, p_error text, p_data jsonb) AS $$
DECLARE
v_user_id INTEGER;
v_email TEXT;
v_username TEXT;
v_raw_token TEXT;
v_token_hash TEXT;
v_expires_at TIMESTAMP;
BEGIN
v_email := p_request->>'email';
v_username := p_request->>'username';
-- Require at least one identifier
IF (v_email IS NULL OR v_email = '') AND (v_username IS NULL OR v_username = '') THEN
RETURN QUERY SELECT false, 'email or username is required'::text, NULL::jsonb;
RETURN;
END IF;
-- Look up active user
IF v_email IS NOT NULL AND v_email <> '' THEN
SELECT id INTO v_user_id FROM users WHERE email = v_email AND is_active = true;
ELSE
SELECT id INTO v_user_id FROM users WHERE username = v_username AND is_active = true;
END IF;
-- Return generic success even when user not found to avoid user enumeration
IF NOT FOUND THEN
RETURN QUERY SELECT true, NULL::text, jsonb_build_object('token', '', 'expires_in', 0);
RETURN;
END IF;
-- Invalidate previous unused tokens for this user
DELETE FROM user_password_resets WHERE user_id = v_user_id AND used = false;
-- Generate a random 32-byte token and store its SHA-256 hash
v_raw_token := encode(gen_random_bytes(32), 'hex');
v_token_hash := encode(digest(v_raw_token, 'sha256'), 'hex');
v_expires_at := now() + interval '1 hour';
INSERT INTO user_password_resets (user_id, token_hash, expires_at)
VALUES (v_user_id, v_token_hash, v_expires_at);
RETURN QUERY SELECT
true,
NULL::text,
jsonb_build_object(
'token', v_raw_token,
'expires_in', 3600
);
EXCEPTION
WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM::text, NULL::jsonb;
END;
$$ LANGUAGE plpgsql;
-- 2. resolvespec_password_reset - Validates the token and updates the user's password
-- Input: p_request jsonb {token: string, new_password: string}
-- Output: p_success (bool), p_error (text)
-- NOTE: Hash the new_password with bcrypt before storing (pgcrypto crypt/gen_salt).
-- The TODO below mirrors the convention used in resolvespec_register.
CREATE OR REPLACE FUNCTION resolvespec_password_reset(p_request jsonb)
RETURNS TABLE(p_success boolean, p_error text) AS $$
DECLARE
v_raw_token TEXT;
v_token_hash TEXT;
v_new_pw TEXT;
v_reset_id INTEGER;
v_user_id INTEGER;
v_expires_at TIMESTAMP;
BEGIN
v_raw_token := p_request->>'token';
v_new_pw := p_request->>'new_password';
IF v_raw_token IS NULL OR v_raw_token = '' THEN
RETURN QUERY SELECT false, 'token is required'::text;
RETURN;
END IF;
IF v_new_pw IS NULL OR v_new_pw = '' THEN
RETURN QUERY SELECT false, 'new_password is required'::text;
RETURN;
END IF;
v_token_hash := encode(digest(v_raw_token, 'sha256'), 'hex');
-- Find valid, unused reset token
SELECT id, user_id, expires_at
INTO v_reset_id, v_user_id, v_expires_at
FROM user_password_resets
WHERE token_hash = v_token_hash AND used = false;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'invalid or expired token'::text;
RETURN;
END IF;
IF v_expires_at <= now() THEN
RETURN QUERY SELECT false, 'invalid or expired token'::text;
RETURN;
END IF;
-- TODO: Hash new password with pgcrypto before storing
-- Enable pgcrypto: CREATE EXTENSION IF NOT EXISTS pgcrypto;
-- v_new_pw := crypt(v_new_pw, gen_salt('bf'));
-- Update password and invalidate all sessions
UPDATE users SET password = v_new_pw, updated_at = now() WHERE id = v_user_id;
DELETE FROM user_sessions WHERE user_id = v_user_id;
-- Mark token as used
UPDATE user_password_resets SET used = true, used_at = now() WHERE id = v_reset_id;
RETURN QUERY SELECT true, NULL::text;
EXCEPTION
WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM::text;
END;
$$ LANGUAGE plpgsql;
-- Example: Test password reset stored procedures
-- SELECT * FROM resolvespec_password_reset_request('{"email": "user@example.com"}'::jsonb);
-- SELECT * FROM resolvespec_password_reset('{"token": "<raw_token>", "new_password": "newpass123"}'::jsonb);
-- ============================================
-- OAuth2 Server Tables (OAuthServer persistence)
-- ============================================

View File

@@ -57,6 +57,27 @@ type LogoutRequest struct {
UserID int `json:"user_id"`
}
// PasswordResetRequest initiates a password reset for a user
type PasswordResetRequest struct {
Email string `json:"email,omitempty"`
Username string `json:"username,omitempty"`
}
// PasswordResetResponse is returned when a reset is initiated
type PasswordResetResponse struct {
// Token is the reset token to be delivered out-of-band (e.g. email).
// The stored procedure may return it for delivery or leave it empty
// if the delivery is handled entirely in the database.
Token string `json:"token"`
ExpiresIn int64 `json:"expires_in"` // seconds
}
// PasswordResetCompleteRequest completes a password reset using the token
type PasswordResetCompleteRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
// Authenticator handles user authentication operations
type Authenticator interface {
// Login authenticates credentials and returns a token
@@ -114,3 +135,12 @@ type Cacheable interface {
// ClearCache clears cached security rules for a user/entity
ClearCache(ctx context.Context, userID int, schema, table string) error
}
// PasswordResettable allows providers to support self-service password reset
type PasswordResettable interface {
// RequestPasswordReset creates a reset token for the given email/username
RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error)
// CompletePasswordReset validates the token and sets the new password
CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error
}

View File

@@ -868,6 +868,75 @@ func generateRandomString(length int) string {
// return ""
// }
// Password reset methods
// ======================
// RequestPasswordReset implements PasswordResettable. It calls the stored procedure
// resolvespec_password_reset_request and returns the reset token and expiry.
func (a *DatabaseAuthenticator) RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error) {
reqJSON, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal password reset request: %w", err)
}
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasswordResetRequest)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
})
if err != nil {
return nil, fmt.Errorf("password reset request query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("password reset request failed")
}
var response PasswordResetResponse
if dataJSON.Valid && dataJSON.String != "" {
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
return nil, fmt.Errorf("failed to parse password reset response: %w", err)
}
}
return &response, nil
}
// CompletePasswordReset implements PasswordResettable. It validates the token and
// updates the user's password via resolvespec_password_reset.
func (a *DatabaseAuthenticator) CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error {
reqJSON, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal password reset complete request: %w", err)
}
var success bool
var errorMsg sql.NullString
err = a.runDBOpWithReconnect(func(db *sql.DB) error {
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1::jsonb)`, a.sqlNames.PasswordResetComplete)
return db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg)
})
if err != nil {
return fmt.Errorf("password reset complete query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return fmt.Errorf("%s", errorMsg.String)
}
return fmt.Errorf("password reset failed")
}
return nil
}
// Passkey authentication methods
// ==============================

View File

@@ -47,6 +47,10 @@ type SQLNames struct {
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
PasskeyLogin string // default: "resolvespec_passkey_login"
// Password reset procedures (DatabaseAuthenticator)
PasswordResetRequest string // default: "resolvespec_password_reset_request"
PasswordResetComplete string // default: "resolvespec_password_reset"
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
@@ -95,6 +99,9 @@ func DefaultSQLNames() *SQLNames {
PasskeyUpdateName: "resolvespec_passkey_update_name",
PasskeyLogin: "resolvespec_passkey_login",
PasswordResetRequest: "resolvespec_password_reset_request",
PasswordResetComplete: "resolvespec_password_reset",
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
OAuthCreateSession: "resolvespec_oauth_createsession",
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
@@ -190,6 +197,12 @@ func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override.PasskeyLogin != "" {
merged.PasskeyLogin = override.PasskeyLogin
}
if override.PasswordResetRequest != "" {
merged.PasswordResetRequest = override.PasswordResetRequest
}
if override.PasswordResetComplete != "" {
merged.PasswordResetComplete = override.PasswordResetComplete
}
if override.OAuthGetOrCreateUser != "" {
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
}

View File

@@ -0,0 +1,755 @@
package spectypes
import (
"database/sql/driver"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/google/uuid"
)
// parsePostgresArrayElements parses a PostgreSQL array literal (e.g. `{a,"b,c",d}`)
// into a slice of raw string elements. Each element retains its unquoted/unescaped value.
func parsePostgresArrayElements(s string) ([]string, error) {
s = strings.TrimSpace(s)
if s == "" || strings.EqualFold(s, "null") || strings.EqualFold(s, "NULL") {
return nil, nil
}
if !strings.HasPrefix(s, "{") || !strings.HasSuffix(s, "}") {
return nil, fmt.Errorf("not a valid PostgreSQL array literal: %q", s)
}
inner := s[1 : len(s)-1]
if inner == "" {
return []string{}, nil
}
var result []string
var cur strings.Builder
inQuotes := false
i := 0
for i < len(inner) {
c := inner[i]
switch {
case c == '"' && !inQuotes:
inQuotes = true
case c == '"' && inQuotes:
if i+1 < len(inner) && inner[i+1] == '"' {
cur.WriteByte('"')
i++
} else {
inQuotes = false
}
case c == '\\' && inQuotes:
if i+1 < len(inner) {
cur.WriteByte(inner[i+1])
i++
}
case c == ',' && !inQuotes:
result = append(result, cur.String())
cur.Reset()
default:
cur.WriteByte(c)
}
i++
}
result = append(result, cur.String())
return result, nil
}
// formatPostgresStringArray formats a []string back into a PostgreSQL array literal.
func formatPostgresStringArray(vals []string) string {
if vals == nil {
return "NULL"
}
parts := make([]string, len(vals))
for i, v := range vals {
// Quote if value contains comma, double-quote, backslash, braces, whitespace, or is empty.
needsQuote := v == "" || strings.ContainsAny(v, `,"\\{}`+"\t\n\r ")
if needsQuote {
v = strings.ReplaceAll(v, `\`, `\\`)
v = strings.ReplaceAll(v, `"`, `""`)
parts[i] = `"` + v + `"`
} else {
parts[i] = v
}
}
return "{" + strings.Join(parts, ",") + "}"
}
// ── SqlStringArray ───────────────────────────────────────────────────────────
// SqlStringArray is a nullable PostgreSQL text[] / varchar[] array.
type SqlStringArray struct {
Val []string
Valid bool
}
func (a *SqlStringArray) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlStringArray: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = elems
a.Valid = true
return nil
}
func (a SqlStringArray) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
return formatPostgresStringArray(a.Val), nil
}
func (a SqlStringArray) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlStringArray) UnmarshalJSON(b []byte) error {
s := strings.TrimSpace(string(b))
if s == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []string
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlStringArray(v []string) SqlStringArray {
return SqlStringArray{Val: v, Valid: true}
}
// ── SqlInt16Array ────────────────────────────────────────────────────────────
type SqlInt16Array struct {
Val []int16
Valid bool
}
func (a *SqlInt16Array) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlInt16Array: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = make([]int16, len(elems))
for i, e := range elems {
n, err := strconv.ParseInt(strings.TrimSpace(e), 10, 16)
if err != nil {
return fmt.Errorf("SqlInt16Array: element %d %q: %w", i, e, err)
}
a.Val[i] = int16(n)
}
a.Valid = true
return nil
}
func (a SqlInt16Array) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
parts := make([]string, len(a.Val))
for i, v := range a.Val {
parts[i] = strconv.FormatInt(int64(v), 10)
}
return "{" + strings.Join(parts, ",") + "}", nil
}
func (a SqlInt16Array) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlInt16Array) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []int16
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlInt16Array(v []int16) SqlInt16Array {
return SqlInt16Array{Val: v, Valid: true}
}
// ── SqlInt32Array ────────────────────────────────────────────────────────────
type SqlInt32Array struct {
Val []int32
Valid bool
}
func (a *SqlInt32Array) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlInt32Array: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = make([]int32, len(elems))
for i, e := range elems {
n, err := strconv.ParseInt(strings.TrimSpace(e), 10, 32)
if err != nil {
return fmt.Errorf("SqlInt32Array: element %d %q: %w", i, e, err)
}
a.Val[i] = int32(n)
}
a.Valid = true
return nil
}
func (a SqlInt32Array) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
parts := make([]string, len(a.Val))
for i, v := range a.Val {
parts[i] = strconv.FormatInt(int64(v), 10)
}
return "{" + strings.Join(parts, ",") + "}", nil
}
func (a SqlInt32Array) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlInt32Array) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []int32
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlInt32Array(v []int32) SqlInt32Array {
return SqlInt32Array{Val: v, Valid: true}
}
// ── SqlInt64Array ────────────────────────────────────────────────────────────
type SqlInt64Array struct {
Val []int64
Valid bool
}
func (a *SqlInt64Array) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlInt64Array: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = make([]int64, len(elems))
for i, e := range elems {
n, err := strconv.ParseInt(strings.TrimSpace(e), 10, 64)
if err != nil {
return fmt.Errorf("SqlInt64Array: element %d %q: %w", i, e, err)
}
a.Val[i] = n
}
a.Valid = true
return nil
}
func (a SqlInt64Array) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
parts := make([]string, len(a.Val))
for i, v := range a.Val {
parts[i] = strconv.FormatInt(v, 10)
}
return "{" + strings.Join(parts, ",") + "}", nil
}
func (a SqlInt64Array) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlInt64Array) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []int64
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlInt64Array(v []int64) SqlInt64Array {
return SqlInt64Array{Val: v, Valid: true}
}
// ── SqlFloat32Array ──────────────────────────────────────────────────────────
type SqlFloat32Array struct {
Val []float32
Valid bool
}
func (a *SqlFloat32Array) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlFloat32Array: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = make([]float32, len(elems))
for i, e := range elems {
f, err := strconv.ParseFloat(strings.TrimSpace(e), 32)
if err != nil {
return fmt.Errorf("SqlFloat32Array: element %d %q: %w", i, e, err)
}
a.Val[i] = float32(f)
}
a.Valid = true
return nil
}
func (a SqlFloat32Array) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
parts := make([]string, len(a.Val))
for i, v := range a.Val {
parts[i] = strconv.FormatFloat(float64(v), 'f', -1, 32)
}
return "{" + strings.Join(parts, ",") + "}", nil
}
func (a SqlFloat32Array) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlFloat32Array) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []float32
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlFloat32Array(v []float32) SqlFloat32Array {
return SqlFloat32Array{Val: v, Valid: true}
}
// ── SqlFloat64Array ──────────────────────────────────────────────────────────
type SqlFloat64Array struct {
Val []float64
Valid bool
}
func (a *SqlFloat64Array) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlFloat64Array: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = make([]float64, len(elems))
for i, e := range elems {
f, err := strconv.ParseFloat(strings.TrimSpace(e), 64)
if err != nil {
return fmt.Errorf("SqlFloat64Array: element %d %q: %w", i, e, err)
}
a.Val[i] = f
}
a.Valid = true
return nil
}
func (a SqlFloat64Array) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
parts := make([]string, len(a.Val))
for i, v := range a.Val {
parts[i] = strconv.FormatFloat(v, 'f', -1, 64)
}
return "{" + strings.Join(parts, ",") + "}", nil
}
func (a SqlFloat64Array) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlFloat64Array) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []float64
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlFloat64Array(v []float64) SqlFloat64Array {
return SqlFloat64Array{Val: v, Valid: true}
}
// ── SqlBoolArray ─────────────────────────────────────────────────────────────
type SqlBoolArray struct {
Val []bool
Valid bool
}
func (a *SqlBoolArray) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlBoolArray: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = make([]bool, len(elems))
for i, e := range elems {
e = strings.ToLower(strings.TrimSpace(e))
a.Val[i] = e == "t" || e == "true" || e == "1" || e == "yes"
}
a.Valid = true
return nil
}
func (a SqlBoolArray) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
parts := make([]string, len(a.Val))
for i, v := range a.Val {
if v {
parts[i] = "t"
} else {
parts[i] = "f"
}
}
return "{" + strings.Join(parts, ",") + "}", nil
}
func (a SqlBoolArray) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlBoolArray) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []bool
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlBoolArray(v []bool) SqlBoolArray {
return SqlBoolArray{Val: v, Valid: true}
}
// ── SqlUUIDArray ─────────────────────────────────────────────────────────────
type SqlUUIDArray struct {
Val []uuid.UUID
Valid bool
}
func (a *SqlUUIDArray) Scan(value any) error {
if value == nil {
a.Valid = false
a.Val = nil
return nil
}
var s string
switch v := value.(type) {
case string:
s = v
case []byte:
s = string(v)
default:
return fmt.Errorf("SqlUUIDArray: cannot scan type %T", value)
}
elems, err := parsePostgresArrayElements(s)
if err != nil {
return err
}
a.Val = make([]uuid.UUID, len(elems))
for i, e := range elems {
u, err := uuid.Parse(strings.TrimSpace(e))
if err != nil {
return fmt.Errorf("SqlUUIDArray: element %d %q: %w", i, e, err)
}
a.Val[i] = u
}
a.Valid = true
return nil
}
func (a SqlUUIDArray) Value() (driver.Value, error) {
if !a.Valid {
return nil, nil
}
parts := make([]string, len(a.Val))
for i, v := range a.Val {
parts[i] = v.String()
}
return "{" + strings.Join(parts, ",") + "}", nil
}
func (a SqlUUIDArray) MarshalJSON() ([]byte, error) {
if !a.Valid {
return []byte("null"), nil
}
return json.Marshal(a.Val)
}
func (a *SqlUUIDArray) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
a.Valid = false
a.Val = nil
return nil
}
var vals []uuid.UUID
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
a.Val = vals
a.Valid = true
return nil
}
func NewSqlUUIDArray(v []uuid.UUID) SqlUUIDArray {
return SqlUUIDArray{Val: v, Valid: true}
}
// ── SqlVector ────────────────────────────────────────────────────────────────
// SqlVector is a nullable pgvector `vector` type backed by []float32.
// Wire format: `[1.0,2.0,3.0]` (square brackets, comma-separated floats).
type SqlVector struct {
Val []float32
Valid bool
}
func (v *SqlVector) Scan(value any) error {
if value == nil {
v.Valid = false
v.Val = nil
return nil
}
var s string
switch val := value.(type) {
case string:
s = val
case []byte:
s = string(val)
default:
return fmt.Errorf("SqlVector: cannot scan type %T", value)
}
s = strings.TrimSpace(s)
if !strings.HasPrefix(s, "[") || !strings.HasSuffix(s, "]") {
return fmt.Errorf("SqlVector: not a valid vector literal: %q", s)
}
inner := s[1 : len(s)-1]
if inner == "" {
v.Val = []float32{}
v.Valid = true
return nil
}
parts := strings.Split(inner, ",")
v.Val = make([]float32, len(parts))
for i, p := range parts {
f, err := strconv.ParseFloat(strings.TrimSpace(p), 32)
if err != nil {
return fmt.Errorf("SqlVector: element %d %q: %w", i, p, err)
}
v.Val[i] = float32(f)
}
v.Valid = true
return nil
}
func (v SqlVector) Value() (driver.Value, error) {
if !v.Valid {
return nil, nil
}
parts := make([]string, len(v.Val))
for i, f := range v.Val {
parts[i] = strconv.FormatFloat(float64(f), 'f', -1, 32)
}
return "[" + strings.Join(parts, ",") + "]", nil
}
func (v SqlVector) MarshalJSON() ([]byte, error) {
if !v.Valid {
return []byte("null"), nil
}
return json.Marshal(v.Val)
}
func (v *SqlVector) UnmarshalJSON(b []byte) error {
if strings.TrimSpace(string(b)) == "null" {
v.Valid = false
v.Val = nil
return nil
}
var vals []float32
if err := json.Unmarshal(b, &vals); err != nil {
return err
}
v.Val = vals
v.Valid = true
return nil
}
func NewSqlVector(val []float32) SqlVector {
return SqlVector{Val: val, Valid: true}
}