Compare commits

...

27 Commits

Author SHA1 Message Date
Hein
ff72e04428 Added meta operation. 2025-12-03 11:59:58 +02:00
Hein
e35f8a4f14 Fix session id that is an integer. 2025-12-03 11:49:19 +02:00
Hein
5ff9a8a24e Fixed blank params on funcspec 2025-12-03 11:42:32 +02:00
Hein
81b87af6e4 Updated doc 2025-12-03 11:30:59 +02:00
Hein
f3ba314640 Refectored the mux routers. 2025-12-03 10:42:26 +02:00
Hein
93df33e274 UnderlyingRequest and UnderlyingResponseWriter
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-12-02 17:40:44 +02:00
Hein
abd045493a mux UnderlyingRequest 2025-12-02 17:34:18 +02:00
Hein
a61556d857 Added FallbackHandler 2025-12-02 17:16:34 +02:00
Hein
eaf1133575 Fixed security rules not loading 2025-12-02 16:55:12 +02:00
Hein
8172c0495d More generic security solution. 2025-12-02 16:35:08 +02:00
Hein
7a3c368121 Pass through to default handler 2025-12-02 16:09:36 +02:00
Hein
9c5c7689e9 More common handler interface 2025-12-02 15:45:24 +02:00
Hein
08050c960d Optional Authentication 2025-12-02 14:14:38 +02:00
Hein
78029fb34f Fixed formatting issues
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-01 14:56:30 +02:00
Hein
1643a5e920 Added cache, funcspec and implemented total cache 2025-12-01 14:40:54 +02:00
Hein
6bbe0ec8b0 Added function api prototype
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-24 17:00:15 +02:00
Hein
e32ec9e17e Updated the security package 2025-11-24 17:00:05 +02:00
Hein
26c175e65e Added make release to vscode tasks
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-24 10:15:23 +02:00
Hein
aa99e8e4bc Added WrapHTTPRequest 2025-11-24 10:13:48 +02:00
Hein
163593901f Huge preload chains causing errors, workaround to do seperate selects.
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-21 17:09:11 +02:00
Hein
1261960e97 Ability to handle multiple x-custom- headers
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-21 12:15:07 +02:00
Hein
76bbf33db2 Fixed SingleRecordAsObject true when handleRead with no id 2025-11-21 11:49:08 +02:00
Hein
02c9b96b0c Better SanitizeWhereClause 2025-11-21 11:42:01 +02:00
Hein
9a3564f05f SanitizeWhereClause with tablename on handlers. 2025-11-21 11:00:44 +02:00
Hein
a931b8cdd2 Better preloads 2025-11-21 10:41:58 +02:00
Hein
7e76977dcc Lots of refactoring, Fixes to preloads 2025-11-21 10:17:20 +02:00
Hein
7853a3f56a cql_columns parsing and recursive preloading. Also added legacy header support for limt(s,e) ,sort(x,y,-z) 2025-11-21 09:15:40 +02:00
63 changed files with 11954 additions and 2117 deletions

View File

@@ -86,7 +86,6 @@
"emptyFallthrough",
"equalFold",
"flagName",
"ifElseChain",
"indexAlloc",
"initClause",
"methodExprCall",
@@ -106,6 +105,9 @@
"unnecessaryBlock",
"weakCond",
"yodaStyleExpr"
],
"disabled-checks": [
"ifElseChain"
]
},
"revive": {

9
.vscode/tasks.json vendored
View File

@@ -6,7 +6,7 @@
"label": "go: build workspace",
"command": "build",
"options": {
"env": {
"env": {
"CGO_ENABLED": "0"
},
"cwd": "${workspaceFolder}/bin",
@@ -18,7 +18,6 @@
"$go"
],
"group": "build",
},
{
"type": "go",
@@ -81,6 +80,12 @@
"kind": "test",
"isDefault": false
}
},
{
"type": "shell",
"label": "Make Release",
"problemMatcher": [],
"command": "sh ${workspaceFolder}/make_release.sh",
}
]
}

View File

@@ -13,6 +13,8 @@ Both share the same core architecture and provide dynamic data querying, relatio
**🆕 New in v2.1**: RestHeadSpec (HeaderSpec) - Header-based REST API with lifecycle hooks, cursor pagination, and advanced filtering.
**🆕 New in v3.0**: Explicit route registration - Routes are now created per registered model for better flexibility and control. OPTIONS method support with full CORS headers for cross-origin requests.
![slogan](./generated_slogan.webp)
## Table of Contents
@@ -65,6 +67,12 @@ Both share the same core architecture and provide dynamic data querying, relatio
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
- **🆕 Base64 Encoding**: Support for base64-encoded header values
### Routing & CORS (v3.0+)
- **🆕 Explicit Route Registration**: Routes created per registered model instead of dynamic lookups
- **🆕 OPTIONS Method Support**: Full OPTIONS method support returning model metadata
- **🆕 CORS Headers**: Comprehensive CORS support with all HeadSpec headers allowed
- **🆕 Better Route Control**: Customize routes per model with more flexibility
## API Structure
### URL Patterns
@@ -123,13 +131,15 @@ import "github.com/gorilla/mux"
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// Register models using schema.table format
// IMPORTANT: Register models BEFORE setting up routes
// Routes are created explicitly for each registered model
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Setup routes
// Setup routes (creates explicit routes for each registered model)
// This replaces the old dynamic route lookup approach
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
restheadspec.SetupMuxRoutes(router, handler, nil)
// Start server
http.ListenAndServe(":8080", router)
@@ -172,6 +182,42 @@ restheadspec.SetupMuxRoutes(router, handler)
For complete header documentation, see [pkg/restheadspec/HEADERS.md](pkg/restheadspec/HEADERS.md).
### CORS & OPTIONS Support
ResolveSpec and RestHeadSpec include comprehensive CORS support for cross-origin requests:
**OPTIONS Method**:
```http
OPTIONS /public/users HTTP/1.1
```
Returns metadata with appropriate CORS headers:
```http
Access-Control-Allow-Origin: *
Access-Control-Allow-Methods: GET, POST, OPTIONS
Access-Control-Allow-Headers: Content-Type, Authorization, X-Select-Fields, X-FieldFilter-*, ...
Access-Control-Max-Age: 86400
Access-Control-Allow-Credentials: true
```
**Key Features**:
- OPTIONS returns model metadata (same as GET metadata endpoint)
- All HTTP methods include CORS headers automatically
- OPTIONS requests don't require authentication (CORS preflight)
- Supports all HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.)
- 24-hour max age to reduce preflight requests
**Configuration**:
```go
import "github.com/bitechdev/ResolveSpec/pkg/common"
// Get default CORS config
corsConfig := common.DefaultCORSConfig()
// Customize if needed
corsConfig.AllowedOrigins = []string{"https://example.com"}
corsConfig.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
```
### Lifecycle Hooks
RestHeadSpec supports lifecycle hooks for all CRUD operations:
@@ -687,15 +733,16 @@ handler := resolvespec.NewHandler(dbAdapter, registry)
```go
import "github.com/gorilla/mux"
// Backward compatible way
router := mux.NewRouter()
resolvespec.SetupRoutes(router, handler)
// Register models first
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Or manually:
router.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
handler.Handle(w, r, vars)
}).Methods("POST")
// Setup routes - creates explicit routes for each model
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
// Routes created: /public/users, /public/posts, etc.
// Each route includes GET, POST, and OPTIONS methods with CORS support
```
#### Gin (Custom Integration)
@@ -950,7 +997,28 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## What's New
### v2.1 (Latest)
### v3.0 (Latest - December 2025)
**Explicit Route Registration (🆕)**:
- **Breaking Change**: Routes are now created explicitly for each registered model
- **Better Control**: Customize routes per model with more flexibility
- **Registration Order**: Models must be registered BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
- **Benefits**: More flexible routing, easier to add custom routes per model, better performance
**OPTIONS Method & CORS Support (🆕)**:
- **OPTIONS Endpoint**: Full OPTIONS method support for CORS preflight requests
- **Metadata Response**: OPTIONS returns model metadata (same as GET /metadata)
- **CORS Headers**: Comprehensive CORS headers on all responses
- **Header Support**: All HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.) allowed
- **No Auth on OPTIONS**: CORS preflight requests don't require authentication
- **Configurable**: Customize CORS settings via `common.CORSConfig`
**Migration Notes**:
- Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
- Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
- This is a **breaking change** but provides better control and flexibility
### v2.1
**Recursive CRUD Handler (🆕 Nov 11, 2025)**:
- **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships

View File

@@ -47,8 +47,8 @@ func main() {
handler.RegisterModel("public", modelNames[i], model)
}
// Setup routes using new SetupMuxRoutes function
resolvespec.SetupMuxRoutes(r, handler)
// Setup routes using new SetupMuxRoutes function (without authentication)
resolvespec.SetupMuxRoutes(r, handler, nil)
// Start server
logger.Info("Starting server on :8080")

4
go.mod
View File

@@ -19,7 +19,10 @@ require (
)
require (
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/google/uuid v1.6.0 // indirect
@@ -30,6 +33,7 @@ require (
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/redis/go-redis/v9 v9.17.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect

8
go.sum
View File

@@ -1,6 +1,12 @@
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
@@ -31,6 +37,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

340
pkg/cache/README.md vendored Normal file
View File

@@ -0,0 +1,340 @@
# Cache Package
A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache.
## Features
- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends
- **Pluggable Architecture**: Easy to add custom cache providers
- **Type-Safe API**: Automatic JSON serialization/deserialization
- **TTL Support**: Configurable time-to-live for cache entries
- **Context-Aware**: All operations support Go contexts
- **Statistics**: Built-in cache statistics and monitoring
- **Pattern Deletion**: Delete keys by pattern (Redis)
- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/cache
```
For Redis support:
```bash
go get github.com/redis/go-redis/v9
```
For Memcache support:
```bash
go get github.com/bradfitz/gomemcache/memcache
```
## Quick Start
### In-Memory Cache
```go
package main
import (
"context"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
func main() {
// Initialize with in-memory provider
cache.UseMemory(&cache.Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
defer cache.Close()
ctx := context.Background()
c := cache.GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John"}
c.Set(ctx, "user:1", user, 10*time.Minute)
// Retrieve a value
var retrieved User
c.Get(ctx, "user:1", &retrieved)
}
```
### Redis Cache
```go
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "",
DB: 0,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
### Memcache
```go
cache.UseMemcache(&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
## API Reference
### Core Methods
#### Set
```go
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
```
Stores a value in the cache with automatic JSON serialization.
#### Get
```go
Get(ctx context.Context, key string, dest interface{}) error
```
Retrieves and deserializes a value from the cache.
#### SetBytes / GetBytes
```go
SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error
GetBytes(ctx context.Context, key string) ([]byte, error)
```
Store and retrieve raw bytes without serialization.
#### Delete
```go
Delete(ctx context.Context, key string) error
```
Removes a key from the cache.
#### DeleteByPattern
```go
DeleteByPattern(ctx context.Context, pattern string) error
```
Removes all keys matching a pattern (Redis only).
#### Clear
```go
Clear(ctx context.Context) error
```
Removes all items from the cache.
#### Exists
```go
Exists(ctx context.Context, key string) bool
```
Checks if a key exists in the cache.
#### GetOrSet
```go
GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration,
loader func() (interface{}, error)) error
```
Retrieves a value from cache, or loads and caches it if not found (lazy loading).
#### Stats
```go
Stats(ctx context.Context) (*CacheStats, error)
```
Returns cache statistics including hits, misses, and key counts.
## Provider Configuration
### In-Memory Options
```go
&cache.Options{
DefaultTTL: 5 * time.Minute, // Default expiration time
MaxSize: 10000, // Maximum number of items
EvictionPolicy: "LRU", // Eviction strategy (future)
}
```
### Redis Configuration
```go
&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Optional authentication
DB: 0, // Database number
PoolSize: 10, // Connection pool size
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
### Memcache Configuration
```go
&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
MaxIdleConns: 2,
Timeout: 1 * time.Second,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
## Advanced Usage
### Custom Provider
```go
// Create a custom provider instance
memProvider := cache.NewMemoryProvider(&cache.Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
cache.Initialize(memProvider)
```
### Lazy Loading Pattern
```go
var data ExpensiveData
err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if key is not in cache
return computeExpensiveData(), nil
})
```
### Query API Cache
The package includes specialized functions for caching query results:
```go
// Cache a query result
api := "GetUsers"
query := "SELECT * FROM users WHERE active = true"
tablenames := "users"
total := int64(150)
cache.PutQueryAPICache(ctx, api, query, tablenames, total)
// Retrieve cached query
hash := cache.HashQueryAPICache(api, query)
cachedQuery, err := cache.FetchQueryAPICache(ctx, hash)
```
## Provider Comparison
| Feature | In-Memory | Redis | Memcache |
|---------|-----------|-------|----------|
| Persistence | No | Yes | No |
| Distributed | No | Yes | Yes |
| Pattern Delete | No | Yes | No |
| Statistics | Full | Full | Limited |
| Atomic Operations | Yes | Yes | Yes |
| Max Item Size | Memory | 512MB | 1MB |
## Best Practices
1. **Use contexts**: Always pass context for cancellation and timeout control
2. **Set appropriate TTLs**: Balance between freshness and performance
3. **Handle errors**: Cache misses and errors should be handled gracefully
4. **Monitor statistics**: Use Stats() to monitor cache performance
5. **Clean up**: Always call Close() when shutting down
6. **Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field")
## Example: Complete Application
```go
package main
import (
"context"
"log"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
type UserService struct {
cache *cache.Cache
}
func NewUserService() *UserService {
// Initialize with Redis in production, memory for testing
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Options: &cache.Options{
DefaultTTL: 10 * time.Minute,
},
})
return &UserService{
cache: cache.GetDefaultCache(),
}
}
func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) {
var user User
cacheKey := fmt.Sprintf("user:%d", userID)
// Try to get from cache first
err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) {
// Load from database if not in cache
return s.loadUserFromDB(userID)
})
if err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) InvalidateUser(ctx context.Context, userID int) error {
cacheKey := fmt.Sprintf("user:%d", userID)
return s.cache.Delete(ctx, cacheKey)
}
func main() {
service := NewUserService()
defer cache.Close()
ctx := context.Background()
user, err := service.GetUser(ctx, 123)
if err != nil {
log.Fatal(err)
}
log.Printf("User: %+v", user)
}
```
## Performance Considerations
- **In-Memory**: Fastest but limited by RAM and not distributed
- **Redis**: Great for distributed systems, persistent, but network overhead
- **Memcache**: Good for distributed caching, simpler than Redis but less features
Choose based on your needs:
- Single instance? Use in-memory
- Need persistence or advanced features? Use Redis
- Simple distributed cache? Use Memcache
## License
See repository license.

76
pkg/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"context"
"fmt"
"time"
)
var (
defaultCache *Cache
)
// Initialize initializes the cache with a provider.
// If not called, the package will use an in-memory provider by default.
func Initialize(provider Provider) {
defaultCache = NewCache(provider)
}
// UseMemory configures the cache to use in-memory storage.
func UseMemory(opts *Options) error {
provider := NewMemoryProvider(opts)
defaultCache = NewCache(provider)
return nil
}
// UseRedis configures the cache to use Redis storage.
func UseRedis(config *RedisConfig) error {
provider, err := NewRedisProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Redis provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// UseMemcache configures the cache to use Memcache storage.
func UseMemcache(config *MemcacheConfig) error {
provider, err := NewMemcacheProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Memcache provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// GetDefaultCache returns the default cache instance.
// Initializes with in-memory provider if not already initialized.
func GetDefaultCache() *Cache {
if defaultCache == nil {
_ = UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
}
return defaultCache
}
// SetDefaultCache sets a custom cache instance as the default cache.
// This is useful for testing or when you want to use a pre-configured cache instance.
func SetDefaultCache(cache *Cache) {
defaultCache = cache
}
// GetStats returns cache statistics.
func GetStats(ctx context.Context) (*CacheStats, error) {
cache := GetDefaultCache()
return cache.Stats(ctx)
}
// Close closes the cache and releases resources.
func Close() error {
if defaultCache != nil {
return defaultCache.Close()
}
return nil
}

147
pkg/cache/cache_manager.go vendored Normal file
View File

@@ -0,0 +1,147 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
)
// Cache is the main cache manager that wraps a Provider.
type Cache struct {
provider Provider
}
// NewCache creates a new cache manager with the specified provider.
func NewCache(provider Provider) *Cache {
return &Cache{
provider: provider,
}
}
// Get retrieves and deserializes a value from the cache.
func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error {
data, exists := c.provider.Get(ctx, key)
if !exists {
return fmt.Errorf("key not found: %s", key)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize: %w", err)
}
return nil
}
// GetBytes retrieves raw bytes from the cache.
func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) {
data, exists := c.provider.Get(ctx, key)
if !exists {
return nil, fmt.Errorf("key not found: %s", key)
}
return data, nil
}
// Set serializes and stores a value in the cache with the specified TTL.
func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize: %w", err)
}
return c.provider.Set(ctx, key, data, ttl)
}
// SetBytes stores raw bytes in the cache with the specified TTL.
func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return c.provider.Set(ctx, key, value, ttl)
}
// Delete removes a key from the cache.
func (c *Cache) Delete(ctx context.Context, key string) error {
return c.provider.Delete(ctx, key)
}
// DeleteByPattern removes all keys matching the pattern.
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
return c.provider.DeleteByPattern(ctx, pattern)
}
// Clear removes all items from the cache.
func (c *Cache) Clear(ctx context.Context) error {
return c.provider.Clear(ctx)
}
// Exists checks if a key exists in the cache.
func (c *Cache) Exists(ctx context.Context, key string) bool {
return c.provider.Exists(ctx, key)
}
// Stats returns statistics about the cache.
func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) {
return c.provider.Stats(ctx)
}
// Close closes the cache and releases any resources.
func (c *Cache) Close() error {
return c.provider.Close()
}
// GetOrSet retrieves a value from cache, or sets it if it doesn't exist.
// The loader function is called only if the key is not found in cache.
func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error {
// Try to get from cache first
err := c.Get(ctx, key, dest)
if err == nil {
return nil
}
// Load the value
value, err := loader()
if err != nil {
return fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return fmt.Errorf("failed to cache value: %w", err)
}
// Populate dest with the loaded value
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize loaded value: %w", err)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize loaded value: %w", err)
}
return nil
}
// Remember is a convenience function that caches the result of a function call.
// It's similar to GetOrSet but returns the value directly.
func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) {
// Try to get from cache first as bytes
data, err := c.GetBytes(ctx, key)
if err == nil {
var result interface{}
if err := json.Unmarshal(data, &result); err == nil {
return result, nil
}
}
// Load the value
value, err := loader()
if err != nil {
return nil, fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return nil, fmt.Errorf("failed to cache value: %w", err)
}
return value, nil
}

69
pkg/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,69 @@
package cache
import (
"context"
"testing"
"time"
)
func TestSetDefaultCache(t *testing.T) {
// Create a custom cache instance
provider := NewMemoryProvider(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 50,
})
customCache := NewCache(provider)
// Set it as the default
SetDefaultCache(customCache)
// Verify it's now the default
retrievedCache := GetDefaultCache()
if retrievedCache != customCache {
t.Error("SetDefaultCache did not set the cache correctly")
}
// Test that we can use it
ctx := context.Background()
testKey := "test_key"
testValue := "test_value"
err := retrievedCache.Set(ctx, testKey, testValue, time.Minute)
if err != nil {
t.Fatalf("Failed to set value: %v", err)
}
var result string
err = retrievedCache.Get(ctx, testKey, &result)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
if result != testValue {
t.Errorf("Expected %s, got %s", testValue, result)
}
// Clean up - reset to default
SetDefaultCache(nil)
}
func TestGetDefaultCacheInitialization(t *testing.T) {
// Reset to nil first
SetDefaultCache(nil)
// GetDefaultCache should auto-initialize
cache := GetDefaultCache()
if cache == nil {
t.Error("GetDefaultCache should auto-initialize, got nil")
}
// Should be usable
ctx := context.Background()
err := cache.Set(ctx, "test", "value", time.Minute)
if err != nil {
t.Errorf("Failed to use auto-initialized cache: %v", err)
}
// Clean up
SetDefaultCache(nil)
}

266
pkg/cache/example_usage.go vendored Normal file
View File

@@ -0,0 +1,266 @@
package cache
import (
"context"
"fmt"
"log"
"time"
)
// ExampleInMemoryCache demonstrates using the in-memory cache provider.
func ExampleInMemoryCache() {
// Initialize with in-memory provider
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John Doe"}
err = cache.Set(ctx, "user:1", user, 10*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved User
err = cache.Get(ctx, "user:1", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved user: %+v\n", retrieved)
// Check if key exists
exists := cache.Exists(ctx, "user:1")
fmt.Printf("Key exists: %v\n", exists)
// Delete a key
err = cache.Delete(ctx, "user:1")
if err != nil {
_ = Close()
log.Fatal(err)
}
// Get statistics
stats, err := cache.Stats(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cache stats: %+v\n", stats)
_ = Close()
}
// ExampleRedisCache demonstrates using the Redis cache provider.
func ExampleRedisCache() {
// Initialize with Redis provider
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Set if Redis requires authentication
DB: 0,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store raw bytes
data := []byte("Hello, Redis!")
err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve raw bytes
retrieved, err := cache.GetBytes(ctx, "greeting")
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved data: %s\n", string(retrieved))
// Clear all cache
err = cache.Clear(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
_ = Close()
}
// ExampleMemcacheCache demonstrates using the Memcache cache provider.
func ExampleMemcacheCache() {
// Initialize with Memcache provider
err := UseMemcache(&MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type Product struct {
ID int
Name string
Price float64
}
product := Product{ID: 100, Name: "Widget", Price: 29.99}
err = cache.Set(ctx, "product:100", product, 30*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved Product
err = cache.Get(ctx, "product:100", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved product: %+v\n", retrieved)
_ = Close()
}
// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading.
func ExampleGetOrSet() {
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
type ExpensiveData struct {
Result string
}
var data ExpensiveData
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if the key is not in cache
fmt.Println("Computing expensive result...")
time.Sleep(1 * time.Second)
return ExpensiveData{Result: "computed value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Data: %+v\n", data)
// Second call will use cached value
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
fmt.Println("This won't be called!")
return ExpensiveData{Result: "new value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cached data: %+v\n", data)
_ = Close()
}
// ExampleCustomProvider demonstrates using a custom provider.
func ExampleCustomProvider() {
// Create a custom provider
memProvider := NewMemoryProvider(&Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
Initialize(memProvider)
ctx := context.Background()
cache := GetDefaultCache()
// Use the cache
err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Clean expired items (memory provider specific)
if mp, ok := cache.provider.(*MemoryProvider); ok {
count := mp.CleanExpired(ctx)
fmt.Printf("Cleaned %d expired items\n", count)
}
_ = Close()
}
// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only).
func ExampleDeleteByPattern() {
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
// Store multiple keys with a pattern
_ = cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute)
// Delete all keys matching pattern (Redis glob pattern)
err = cache.DeleteByPattern(ctx, "user:*:profile")
if err != nil {
_ = Close()
log.Print(err)
return
}
fmt.Println("Deleted all user profile keys")
_ = Close()
}

57
pkg/cache/provider.go vendored Normal file
View File

@@ -0,0 +1,57 @@
package cache
import (
"context"
"time"
)
// Provider defines the interface that all cache providers must implement.
type Provider interface {
// Get retrieves a value from the cache by key.
// Returns nil, false if key doesn't exist or is expired.
Get(ctx context.Context, key string) ([]byte, bool)
// Set stores a value in the cache with the specified TTL.
// If ttl is 0, the item never expires.
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Delete removes a key from the cache.
Delete(ctx context.Context, key string) error
// DeleteByPattern removes all keys matching the pattern.
// Pattern syntax depends on the provider implementation.
DeleteByPattern(ctx context.Context, pattern string) error
// Clear removes all items from the cache.
Clear(ctx context.Context) error
// Exists checks if a key exists in the cache.
Exists(ctx context.Context, key string) bool
// Close closes the provider and releases any resources.
Close() error
// Stats returns statistics about the cache provider.
Stats(ctx context.Context) (*CacheStats, error)
}
// CacheStats contains cache statistics.
type CacheStats struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
Keys int64 `json:"keys"`
ProviderType string `json:"provider_type"`
ProviderStats map[string]any `json:"provider_stats,omitempty"`
}
// Options contains configuration options for cache providers.
type Options struct {
// DefaultTTL is the default time-to-live for cache items.
DefaultTTL time.Duration
// MaxSize is the maximum number of items (for in-memory provider).
MaxSize int
// EvictionPolicy determines how items are evicted (LRU, LFU, etc).
EvictionPolicy string
}

144
pkg/cache/provider_memcache.go vendored Normal file
View File

@@ -0,0 +1,144 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/bradfitz/gomemcache/memcache"
)
// MemcacheProvider is a Memcache implementation of the Provider interface.
type MemcacheProvider struct {
client *memcache.Client
options *Options
}
// MemcacheConfig contains Memcache-specific configuration.
type MemcacheConfig struct {
// Servers is a list of memcache server addresses (e.g., "localhost:11211")
Servers []string
// MaxIdleConns is the maximum number of idle connections (default: 2)
MaxIdleConns int
// Timeout for connection operations (default: 1 second)
Timeout time.Duration
// Options contains general cache options
Options *Options
}
// NewMemcacheProvider creates a new Memcache cache provider.
func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) {
if config == nil {
config = &MemcacheConfig{
Servers: []string{"localhost:11211"},
}
}
if len(config.Servers) == 0 {
config.Servers = []string{"localhost:11211"}
}
if config.MaxIdleConns == 0 {
config.MaxIdleConns = 2
}
if config.Timeout == 0 {
config.Timeout = 1 * time.Second
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := memcache.New(config.Servers...)
client.MaxIdleConns = config.MaxIdleConns
client.Timeout = config.Timeout
// Test connection
if err := client.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to Memcache: %w", err)
}
return &MemcacheProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) {
item, err := m.client.Get(key)
if err == memcache.ErrCacheMiss {
return nil, false
}
if err != nil {
return nil, false
}
return item.Value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = m.options.DefaultTTL
}
item := &memcache.Item{
Key: key,
Value: value,
Expiration: int32(ttl.Seconds()),
}
return m.client.Set(item)
}
// Delete removes a key from the cache.
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
err := m.client.Delete(key)
if err == memcache.ErrCacheMiss {
return nil
}
return err
}
// DeleteByPattern removes all keys matching the pattern.
// Note: Memcache does not support pattern-based deletion natively.
// This is a no-op for memcache and returns an error.
func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error {
return fmt.Errorf("pattern-based deletion is not supported by Memcache")
}
// Clear removes all items from the cache.
func (m *MemcacheProvider) Clear(ctx context.Context) error {
return m.client.FlushAll()
}
// Exists checks if a key exists in the cache.
func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool {
_, err := m.client.Get(key)
return err == nil
}
// Close closes the provider and releases any resources.
func (m *MemcacheProvider) Close() error {
// Memcache client doesn't have a close method
return nil
}
// Stats returns statistics about the cache provider.
// Note: Memcache provider returns limited statistics.
func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) {
stats := &CacheStats{
ProviderType: "memcache",
ProviderStats: map[string]any{
"note": "Memcache does not provide detailed statistics through the standard client",
},
}
return stats, nil
}

226
pkg/cache/provider_memory.go vendored Normal file
View File

@@ -0,0 +1,226 @@
package cache
import (
"context"
"fmt"
"regexp"
"sync"
"time"
)
// memoryItem represents a cached item in memory.
type memoryItem struct {
Value []byte
Expiration time.Time
LastAccess time.Time
HitCount int64
}
// isExpired checks if the item has expired.
func (m *memoryItem) isExpired() bool {
if m.Expiration.IsZero() {
return false
}
return time.Now().After(m.Expiration)
}
// MemoryProvider is an in-memory implementation of the Provider interface.
type MemoryProvider struct {
mu sync.RWMutex
items map[string]*memoryItem
options *Options
hits int64
misses int64
}
// NewMemoryProvider creates a new in-memory cache provider.
func NewMemoryProvider(opts *Options) *MemoryProvider {
if opts == nil {
opts = &Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
}
}
return &MemoryProvider{
items: make(map[string]*memoryItem),
options: opts,
}
}
// Get retrieves a value from the cache by key.
func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) {
m.mu.Lock()
defer m.mu.Unlock()
item, exists := m.items[key]
if !exists {
m.misses++
return nil, false
}
if item.isExpired() {
delete(m.items, key)
m.misses++
return nil, false
}
item.LastAccess = time.Now()
item.HitCount++
m.hits++
return item.Value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
if ttl == 0 {
ttl = m.options.DefaultTTL
}
var expiration time.Time
if ttl > 0 {
expiration = time.Now().Add(ttl)
}
// Check max size and evict if necessary
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
if _, exists := m.items[key]; !exists {
m.evictOne()
}
}
m.items[key] = &memoryItem{
Value: value,
Expiration: expiration,
LastAccess: time.Now(),
}
return nil
}
// Delete removes a key from the cache.
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.items, key)
return nil
}
// DeleteByPattern removes all keys matching the pattern.
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
m.mu.Lock()
defer m.mu.Unlock()
re, err := regexp.Compile(pattern)
if err != nil {
return fmt.Errorf("invalid pattern: %w", err)
}
for key := range m.items {
if re.MatchString(key) {
delete(m.items, key)
}
}
return nil
}
// Clear removes all items from the cache.
func (m *MemoryProvider) Clear(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryItem)
m.hits = 0
m.misses = 0
return nil
}
// Exists checks if a key exists in the cache.
func (m *MemoryProvider) Exists(ctx context.Context, key string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
item, exists := m.items[key]
if !exists {
return false
}
return !item.isExpired()
}
// Close closes the provider and releases any resources.
func (m *MemoryProvider) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = nil
return nil
}
// Stats returns statistics about the cache provider.
func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
// Clean expired items first
validKeys := 0
for _, item := range m.items {
if !item.isExpired() {
validKeys++
}
}
return &CacheStats{
Hits: m.hits,
Misses: m.misses,
Keys: int64(validKeys),
ProviderType: "memory",
ProviderStats: map[string]any{
"capacity": m.options.MaxSize,
},
}, nil
}
// evictOne removes one item from the cache using LRU strategy.
func (m *MemoryProvider) evictOne() {
var oldestKey string
var oldestTime time.Time
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
return
}
if oldestKey == "" || item.LastAccess.Before(oldestTime) {
oldestKey = key
oldestTime = item.LastAccess
}
}
if oldestKey != "" {
delete(m.items, oldestKey)
}
}
// CleanExpired removes all expired items from the cache.
func (m *MemoryProvider) CleanExpired(ctx context.Context) int {
m.mu.Lock()
defer m.mu.Unlock()
count := 0
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
count++
}
}
return count
}

185
pkg/cache/provider_redis.go vendored Normal file
View File

@@ -0,0 +1,185 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// RedisProvider is a Redis implementation of the Provider interface.
type RedisProvider struct {
client *redis.Client
options *Options
}
// RedisConfig contains Redis-specific configuration.
type RedisConfig struct {
// Host is the Redis server host (default: localhost)
Host string
// Port is the Redis server port (default: 6379)
Port int
// Password for Redis authentication (optional)
Password string
// DB is the Redis database number (default: 0)
DB int
// PoolSize is the maximum number of connections (default: 10)
PoolSize int
// Options contains general cache options
Options *Options
}
// NewRedisProvider creates a new Redis cache provider.
func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) {
if config == nil {
config = &RedisConfig{
Host: "localhost",
Port: 6379,
DB: 0,
}
}
if config.Host == "" {
config.Host = "localhost"
}
if config.Port == 0 {
config.Port = 6379
}
if config.PoolSize == 0 {
config.PoolSize = 10
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
Password: config.Password,
DB: config.DB,
PoolSize: config.PoolSize,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
return &RedisProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) {
val, err := r.client.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, false
}
if err != nil {
return nil, false
}
return val, true
}
// Set stores a value in the cache with the specified TTL.
func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = r.options.DefaultTTL
}
return r.client.Set(ctx, key, value, ttl).Err()
}
// Delete removes a key from the cache.
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
return r.client.Del(ctx, key).Err()
}
// DeleteByPattern removes all keys matching the pattern.
func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error {
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
pipe := r.client.Pipeline()
count := 0
for iter.Next(ctx) {
pipe.Del(ctx, iter.Val())
count++
// Execute pipeline in batches of 100
if count%100 == 0 {
if _, err := pipe.Exec(ctx); err != nil {
return err
}
pipe = r.client.Pipeline()
}
}
if err := iter.Err(); err != nil {
return err
}
// Execute remaining commands
if count%100 != 0 {
_, err := pipe.Exec(ctx)
return err
}
return nil
}
// Clear removes all items from the cache.
func (r *RedisProvider) Clear(ctx context.Context) error {
return r.client.FlushDB(ctx).Err()
}
// Exists checks if a key exists in the cache.
func (r *RedisProvider) Exists(ctx context.Context, key string) bool {
result, err := r.client.Exists(ctx, key).Result()
if err != nil {
return false
}
return result > 0
}
// Close closes the provider and releases any resources.
func (r *RedisProvider) Close() error {
return r.client.Close()
}
// Stats returns statistics about the cache provider.
func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) {
info, err := r.client.Info(ctx, "stats", "keyspace").Result()
if err != nil {
return nil, fmt.Errorf("failed to get Redis stats: %w", err)
}
dbSize, err := r.client.DBSize(ctx).Result()
if err != nil {
return nil, fmt.Errorf("failed to get DB size: %w", err)
}
// Parse stats from INFO command
// This is a simplified version - you may want to parse more detailed stats
stats := &CacheStats{
Keys: dbSize,
ProviderType: "redis",
ProviderStats: map[string]any{
"info": info,
},
}
return stats, nil
}

127
pkg/cache/query_cache.go vendored Normal file
View File

@@ -0,0 +1,127 @@
package cache
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// QueryCacheKey represents the components used to build a cache key for query total count
type QueryCacheKey struct {
TableName string `json:"table_name"`
Filters []common.FilterOption `json:"filters"`
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
Expand []ExpandOptionKey `json:"expand,omitempty"`
Distinct bool `json:"distinct,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
CursorBackward string `json:"cursor_backward,omitempty"`
}
// ExpandOptionKey represents expand options for cache key
type ExpandOptionKey struct {
Relation string `json:"relation"`
Where string `json:"where,omitempty"`
}
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
// This is used to cache the total count of records matching a query
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
}
return hashString(string(jsonData))
}
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// Includes expand, distinct, and cursor pagination options
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
Distinct: distinct,
CursorForward: cursorFwd,
CursorBackward: cursorBwd,
}
// Convert expand options to cache key format
if len(expandOpts) > 0 {
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
for _, exp := range expandOpts {
// Type assert to get the expand option fields we care about for caching
if expMap, ok := exp.(map[string]interface{}); ok {
expKey := ExpandOptionKey{}
if rel, ok := expMap["relation"].(string); ok {
expKey.Relation = rel
}
if where, ok := expMap["where"].(string); ok {
expKey.Where = where
}
key.Expand = append(key.Expand, expKey)
}
}
// Sort expand options for consistent hashing (already sorted by relation name above)
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
}
return hashString(string(jsonData))
}
// hashString computes SHA256 hash of a string
func hashString(s string) string {
h := sha256.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func GetQueryTotalCacheKey(hash string) string {
return fmt.Sprintf("query_total:%s", hash)
}
// CachedTotal represents a cached total count
type CachedTotal struct {
Total int `json:"total"`
}
// InvalidateCacheForTable removes all cached totals for a specific table
// This should be called when data in the table changes (insert/update/delete)
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
cache := GetDefaultCache()
// Build a pattern to match all query totals for this table
// Note: This requires pattern matching support in the provider
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
return cache.DeleteByPattern(ctx, pattern)
}

151
pkg/cache/query_cache_test.go vendored Normal file
View File

@@ -0,0 +1,151 @@
package cache
import (
"context"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestBuildQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "age", Operator: "gt", Value: 25},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
// Generate cache key
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
// Same parameters should generate same key
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
}
// Different parameters should generate different key
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
}
}
func TestBuildExtendedQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
expandOpts := []interface{}{
map[string]interface{}{
"relation": "posts",
"where": "status = 'published'",
},
}
// Generate cache key
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
// Same parameters should generate same key
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters")
}
// Different distinct value should generate different key
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different distinct values")
}
}
func TestGetQueryTotalCacheKey(t *testing.T) {
hash := "abc123"
key := GetQueryTotalCacheKey(hash)
expected := "query_total:abc123"
if key != expected {
t.Errorf("Expected %s, got %s", expected, key)
}
}
func TestCachedTotalIntegration(t *testing.T) {
// Initialize cache with memory provider for testing
UseMemory(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 100,
})
ctx := context.Background()
// Create test data
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
}
sorts := []common.SortOption{
{Column: "created_at", Direction: "desc"},
}
// Build cache key
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
// Store a total count in cache
totalToCache := CachedTotal{Total: 42}
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
if err != nil {
t.Fatalf("Failed to set cache: %v", err)
}
// Retrieve from cache
var cachedTotal CachedTotal
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err != nil {
t.Fatalf("Failed to get from cache: %v", err)
}
if cachedTotal.Total != 42 {
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
}
// Test cache miss
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
var missedTotal CachedTotal
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
if err == nil {
t.Errorf("Expected error for cache miss, got nil")
}
}
func TestHashString(t *testing.T) {
input1 := "test string"
input2 := "test string"
input3 := "different string"
hash1 := hashString(input1)
hash2 := hashString(input2)
hash3 := hashString(input3)
// Same input should produce same hash
if hash1 != hash2 {
t.Errorf("Expected same hash for identical inputs")
}
// Different input should produce different hash
if hash1 == hash3 {
t.Errorf("Expected different hash for different inputs")
}
// Hash should be hex encoded SHA256 (64 characters)
if len(hash1) != 64 {
t.Errorf("Expected hash length of 64, got %d", len(hash1))
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"github.com/uptrace/bun"
@@ -99,12 +100,20 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
// BunSelectQuery implements SelectQuery for Bun
type BunSelectQuery struct {
query *bun.SelectQuery
db bun.IDB // Store DB connection for count queries
hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
query *bun.SelectQuery
db bun.IDB // Store DB connection for count queries
hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
deferredPreloads []deferredPreload // Preloads to execute as separate queries
}
// deferredPreload represents a preload that will be executed as a separate query
// to avoid PostgreSQL identifier length limits
type deferredPreload struct {
relation string
apply []func(common.SelectQuery) common.SelectQuery
}
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -233,11 +242,99 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
return b
}
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
// // when combined with typical column names
// func shortenAliasForPostgres(relationPath string) (string, bool) {
// // Convert relation path to the alias format Bun uses: dots become double underscores
// // Also convert to lowercase and use snake_case as Bun does
// parts := strings.Split(relationPath, ".")
// alias := strings.ToLower(strings.Join(parts, "__"))
// // PostgreSQL truncates identifiers to 63 chars
// // If the alias + typical column name would exceed this, we need to shorten
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
// const maxAliasLength = 30
// if len(alias) > maxAliasLength {
// // Create a shortened alias using a hash of the original
// hash := md5.Sum([]byte(alias))
// hashStr := hex.EncodeToString(hash[:])[:8]
// // Keep first few chars of original for readability + hash
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
// if prefixLen > len(alias) {
// prefixLen = len(alias)
// }
// shortened := alias[:prefixLen] + "_" + hashStr
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
// alias, len(alias), shortened, len(shortened))
// return shortened, true
// }
// return alias, false
// }
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
// // Bun creates aliases like: relationChain__columnName
// func estimateColumnAliasLength(relationPath string, columnName string) int {
// relationParts := strings.Split(relationPath, ".")
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// // Bun adds "__" between alias and column name
// return len(aliasChain) + 2 + len(columnName)
// }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Check if this relation chain would create problematic long aliases
relationParts := strings.Split(relation, ".")
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// PostgreSQL's identifier limit is 63 characters
const postgresIdentifierLimit = 63
const safeAliasLimit = 35 // Leave room for column names
// If the alias chain is too long, defer this preload to be executed as a separate query
if len(aliasChain) > safeAliasLimit {
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
// This avoids the long concatenated alias
if len(relationParts) > 1 {
// Load first level normally: "Parent"
firstLevel := relationParts[0]
remainingPath := strings.Join(relationParts[1:], ".")
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
firstLevel, remainingPath)
// Apply the first level preload normally
b.query = b.query.Relation(firstLevel)
// Store the remaining nested preload to be executed after the main query
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
relation: relation,
apply: apply,
})
return b
}
// Single level but still too long - just warn and continue
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
"Consider renaming the field to avoid potential issues.",
relation, len(aliasChain))
}
// Normal preload handling
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
if err != nil {
return
}
}
}()
if len(apply) == 0 {
@@ -306,7 +403,23 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
if dest == nil {
return fmt.Errorf("destination cannot be nil")
}
return b.query.Scan(ctx, dest)
// Execute the main query first
err = b.query.Scan(ctx, dest)
if err != nil {
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
err = b.executeDeferredPreloads(ctx, dest)
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
}
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
@@ -319,7 +432,132 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
return fmt.Errorf("model is nil")
}
return b.query.Scan(ctx)
// Execute the main query first
err = b.query.Scan(ctx)
if err != nil {
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
model := b.query.GetModel()
err = b.executeDeferredPreloads(ctx, model.Value())
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
}
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
if len(b.deferredPreloads) == 0 {
return nil
}
for _, dp := range b.deferredPreloads {
err := b.executeSingleDeferredPreload(ctx, dest, dp)
if err != nil {
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
}
}
return nil
}
// executeSingleDeferredPreload executes a single deferred preload
// For a relation like "Parent.Child", it:
// 1. Finds all loaded Parent records in dest
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
relationParts := strings.Split(dp.relation, ".")
if len(relationParts) < 2 {
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
}
// The parent relation that was already loaded
parentRelation := relationParts[0]
// The child relation we need to load
childRelation := strings.Join(relationParts[1:], ".")
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
// Use reflection to access the parent relation field(s) in the loaded records
// Then load the child relation for those parent records
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
// Handle both slice and single record
if destValue.Kind() == reflect.Slice {
// Iterate through each record in the slice
for i := 0; i < destValue.Len(); i++ {
record := destValue.Index(i)
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
// Continue with other records
}
}
} else {
// Single record
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
}
}
return nil
}
// loadChildRelationForRecord loads a child relation for a single parent record
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
// Ensure we're working with the actual struct value, not a pointer
if record.Kind() == reflect.Ptr {
record = record.Elem()
}
// Get the parent relation field
parentField := record.FieldByName(parentRelation)
if !parentField.IsValid() {
// Parent relation field doesn't exist
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
return nil
}
// Check if the parent field is nil (for pointer fields)
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
// Parent relation not loaded or nil, skip
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
return nil
}
// Get the interface value to pass to Bun
parentValue := parentField.Interface()
// Load the child relation on the parent record
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
return b.db.NewSelect().
Model(parentValue).
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
// Apply any custom query modifications
if len(apply) > 0 {
wrapper := &BunSelectQuery{query: sq, db: b.db}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalBun, ok := current.(*BunSelectQuery); ok {
return finalBun.query
}
}
return sq
}).
Scan(ctx)
}
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
@@ -401,7 +639,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
if b.values != nil && len(b.values) > 0 {
if len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
// Bun can insert map[string]interface{} directly

View File

@@ -141,6 +141,12 @@ func (b *BunRouterRequest) AllHeaders() map[string]string {
return headers
}
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (b *BunRouterRequest) UnderlyingRequest() *http.Request {
return b.req.Request
}
// StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers
type StandardBunRouterAdapter struct {
*BunRouterAdapter

View File

@@ -137,6 +137,12 @@ func (h *HTTPRequest) AllHeaders() map[string]string {
return headers
}
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (h *HTTPRequest) UnderlyingRequest() *http.Request {
return h.req
}
// HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter
type HTTPResponseWriter struct {
resp http.ResponseWriter
@@ -166,6 +172,12 @@ func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
return json.NewEncoder(h.resp).Encode(data)
}
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
// This is useful when you need to pass the response writer to other handlers
func (h *HTTPResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return h.resp
}
// StandardMuxAdapter creates routes compatible with standard http.HandlerFunc
type StandardMuxAdapter struct {
*MuxAdapter

119
pkg/common/cors.go Normal file
View File

@@ -0,0 +1,119 @@
package common
import (
"fmt"
"strings"
)
// CORSConfig holds CORS configuration
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
MaxAge int
}
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
func DefaultCORSConfig() CORSConfig {
return CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowedHeaders: GetHeadSpecHeaders(),
MaxAge: 86400, // 24 hours
}
}
// GetHeadSpecHeaders returns all headers used by HeadSpec
func GetHeadSpecHeaders() []string {
return []string{
// Standard headers
"Content-Type",
"Authorization",
"Accept",
"Accept-Language",
"Content-Language",
// Field Selection
"X-Select-Fields",
"X-Not-Select-Fields",
"X-Clean-JSON",
// Filtering & Search
"X-FieldFilter-*",
"X-SearchFilter-*",
"X-SearchOp-*",
"X-SearchOr-*",
"X-SearchAnd-*",
"X-SearchCols",
"X-Custom-SQL-W",
"X-Custom-SQL-W-*",
"X-Custom-SQL-Or",
"X-Custom-SQL-Or-*",
// Joins & Relations
"X-Preload",
"X-Preload-*",
"X-Expand",
"X-Expand-*",
"X-Custom-SQL-Join",
"X-Custom-SQL-Join-*",
// Sorting & Pagination
"X-Sort",
"X-Sort-*",
"X-Limit",
"X-Offset",
"X-Cursor-Forward",
"X-Cursor-Backward",
// Advanced Features
"X-AdvSQL-*",
"X-CQL-Sel-*",
"X-Distinct",
"X-SkipCount",
"X-SkipCache",
"X-Fetch-RowNumber",
"X-PKRow",
// Response Format
"X-SimpleAPI",
"X-DetailAPI",
"X-Syncfusion",
"X-Single-Record-As-Object",
// Transaction Control
"X-Transaction-Atomic",
// X-Files - comprehensive JSON configuration
"X-Files",
}
}
// SetCORSHeaders sets CORS headers on a response writer
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
// Set allowed origins
if len(config.AllowedOrigins) > 0 {
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
}
// Set allowed methods
if len(config.AllowedMethods) > 0 {
w.SetHeader("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
}
// Set allowed headers
if len(config.AllowedHeaders) > 0 {
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
}
// Set max age
if config.MaxAge > 0 {
w.SetHeader("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
}
// Allow credentials
w.SetHeader("Access-Control-Allow-Credentials", "true")
// Expose headers that clients can read
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
}

View File

@@ -0,0 +1,97 @@
package common
// Example showing how to use the common handler interfaces
// This file demonstrates the handler interface hierarchy and usage patterns
// ProcessWithAnyHandler demonstrates using the base SpecHandler interface
// which works with any handler type (resolvespec, restheadspec, or funcspec)
func ProcessWithAnyHandler(handler SpecHandler) Database {
// All handlers expose GetDatabase() through the SpecHandler interface
return handler.GetDatabase()
}
// ProcessCRUDRequest demonstrates using the CRUDHandler interface
// which works with resolvespec.Handler and restheadspec.Handler
func ProcessCRUDRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement Handle()
handler.Handle(w, r, params)
}
// ProcessMetadataRequest demonstrates getting metadata from CRUD handlers
func ProcessMetadataRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement HandleGet()
handler.HandleGet(w, r, params)
}
// Example usage patterns (not executable, just for documentation):
/*
// Example 1: Using with resolvespec.Handler
func ExampleResolveSpec() {
db := // ... get database
registry := // ... get registry
handler := resolvespec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 2: Using with restheadspec.Handler
func ExampleRestHeadSpec() {
db := // ... get database
registry := // ... get registry
handler := restheadspec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 3: Using with funcspec.Handler
func ExampleFuncSpec() {
db := // ... get database
handler := funcspec.NewHandler(db)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as QueryHandler
var queryHandler QueryHandler = handler
// funcspec has different methods: SqlQueryList() and SqlQuery()
// which return HTTP handler functions
}
// Example 4: Polymorphic handler processing
func ProcessHandlers(handlers []SpecHandler) {
for _, handler := range handlers {
// All handlers expose the database
db := handler.GetDatabase()
// Type switch for specific handler types
switch h := handler.(type) {
case CRUDHandler:
// This is resolvespec or restheadspec
// Can call Handle() and HandleGet()
_ = h
case QueryHandler:
// This is funcspec
// Can call SqlQueryList() and SqlQuery()
_ = h
}
}
}
*/

View File

@@ -1,6 +1,11 @@
package common
import "context"
import (
"context"
"encoding/json"
"io"
"net/http"
)
// Database interface designed to work with both GORM and Bun
type Database interface {
@@ -117,6 +122,7 @@ type Request interface {
PathParam(key string) string
QueryParam(key string) string
AllQueryParams() map[string]string // Get all query parameters as a map
UnderlyingRequest() *http.Request // Get the underlying *http.Request for forwarding to other handlers
}
// ResponseWriter interface abstracts HTTP response
@@ -125,11 +131,113 @@ type ResponseWriter interface {
WriteHeader(statusCode int)
Write(data []byte) (int, error)
WriteJSON(data interface{}) error
UnderlyingResponseWriter() http.ResponseWriter // Get the underlying http.ResponseWriter for forwarding to other handlers
}
// HTTPHandlerFunc type for HTTP handlers
type HTTPHandlerFunc func(ResponseWriter, Request)
// WrapHTTPRequest wraps standard http.ResponseWriter and *http.Request into common interfaces
func WrapHTTPRequest(w http.ResponseWriter, r *http.Request) (ResponseWriter, Request) {
return &StandardResponseWriter{w: w}, &StandardRequest{r: r}
}
// StandardResponseWriter adapts http.ResponseWriter to ResponseWriter interface
type StandardResponseWriter struct {
w http.ResponseWriter
status int
}
func (s *StandardResponseWriter) SetHeader(key, value string) {
s.w.Header().Set(key, value)
}
func (s *StandardResponseWriter) WriteHeader(statusCode int) {
s.status = statusCode
s.w.WriteHeader(statusCode)
}
func (s *StandardResponseWriter) Write(data []byte) (int, error) {
return s.w.Write(data)
}
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
s.SetHeader("Content-Type", "application/json")
return json.NewEncoder(s.w).Encode(data)
}
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return s.w
}
// StandardRequest adapts *http.Request to Request interface
type StandardRequest struct {
r *http.Request
body []byte
}
func (s *StandardRequest) Method() string {
return s.r.Method
}
func (s *StandardRequest) URL() string {
return s.r.URL.String()
}
func (s *StandardRequest) Header(key string) string {
return s.r.Header.Get(key)
}
func (s *StandardRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range s.r.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
return headers
}
func (s *StandardRequest) Body() ([]byte, error) {
if s.body != nil {
return s.body, nil
}
if s.r.Body == nil {
return nil, nil
}
defer s.r.Body.Close()
body, err := io.ReadAll(s.r.Body)
if err != nil {
return nil, err
}
s.body = body
return body, nil
}
func (s *StandardRequest) PathParam(key string) string {
// Standard http.Request doesn't have path params
// This should be set by the router
return ""
}
func (s *StandardRequest) QueryParam(key string) string {
return s.r.URL.Query().Get(key)
}
func (s *StandardRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range s.r.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (s *StandardRequest) UnderlyingRequest() *http.Request {
return s.r
}
// TableNameProvider interface for models that provide table names
type TableNameProvider interface {
TableName() string
@@ -148,3 +256,39 @@ type PrimaryKeyNameProvider interface {
type SchemaProvider interface {
SchemaName() string
}
// SpecHandler interface represents common functionality across all spec handlers
// This is the base interface implemented by:
// - resolvespec.Handler: Handles CRUD operations via request body with explicit operation field
// - restheadspec.Handler: Handles CRUD operations via HTTP methods (GET/POST/PUT/DELETE)
// - funcspec.Handler: Handles custom SQL query execution with dynamic parameters
//
// The interface hierarchy is:
//
// SpecHandler (base)
// ├── CRUDHandler (resolvespec, restheadspec)
// └── QueryHandler (funcspec)
type SpecHandler interface {
// GetDatabase returns the underlying database connection
GetDatabase() Database
}
// CRUDHandler interface for handlers that support CRUD operations
// This is implemented by resolvespec.Handler and restheadspec.Handler
type CRUDHandler interface {
SpecHandler
// Handle processes API requests through router-agnostic interface
Handle(w ResponseWriter, r Request, params map[string]string)
// HandleGet processes GET requests for metadata
HandleGet(w ResponseWriter, r Request, params map[string]string)
}
// QueryHandler interface for handlers that execute SQL queries
// This is implemented by funcspec.Handler
// Note: funcspec uses standard http.ResponseWriter and *http.Request instead of common interfaces
type QueryHandler interface {
SpecHandler
// Methods are defined in funcspec package due to different function signature requirements
}

View File

@@ -5,6 +5,8 @@ import (
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
@@ -96,6 +98,173 @@ func IsSQLExpression(cond string) bool {
return false
}
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
// These conditions should be removed from WHERE clauses as they have no filtering effect
func IsTrivialCondition(cond string) bool {
cond = strings.TrimSpace(cond)
lowerCond := strings.ToLower(cond)
// Conditions that always evaluate to true
trivialConditions := []string{
"1=1", "1 = 1", "1= 1", "1 =1",
"true", "true = true", "true=true", "true= true", "true =true",
"0=0", "0 = 0", "0= 0", "0 =0",
}
for _, trivial := range trivialConditions {
if lowerCond == trivial {
return true
}
}
return false
}
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
//
// Parameters:
// - where: The WHERE clause string to sanitize
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
//
// Returns:
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
// - An empty string if all conditions were trivial or the input was empty
func SanitizeWhereClause(where string, tableName string) string {
if where == "" {
return ""
}
where = strings.TrimSpace(where)
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
// Get valid columns from the model if tableName is provided
var validColumns map[string]bool
if tableName != "" {
validColumns = getValidColumnsForTable(tableName)
}
// Split by AND to handle multiple conditions
conditions := splitByAND(where)
validConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Strip parentheses from the condition before checking
condToCheck := stripOuterParentheses(cond)
// Skip trivial conditions that always evaluate to true
if IsTrivialCondition(condToCheck) {
logger.Debug("Removing trivial condition: '%s'", cond)
continue
}
// If tableName is provided and the condition doesn't already have a table prefix,
// attempt to add it
if tableName != "" && !hasTablePrefix(condToCheck) {
// Check if this is a SQL expression/literal that shouldn't be prefixed
if !IsSQLExpression(strings.ToLower(condToCheck)) {
// Extract the column name and prefix it
columnName := ExtractColumnName(condToCheck)
if columnName != "" {
// Only prefix if this is a valid column in the model
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace in the original condition (without stripped parens)
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
logger.Debug("Prefixed column in condition: '%s'", cond)
} else {
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
}
}
}
}
validConditions = append(validConditions, cond)
}
if len(validConditions) == 0 {
return ""
}
result := strings.Join(validConditions, " AND ")
if result != where {
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
}
return result
}
// stripOuterParentheses removes matching outer parentheses from a string
// It handles nested parentheses correctly
func stripOuterParentheses(s string) string {
s = strings.TrimSpace(s)
for {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
return s
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s
}
}
}
if !matched {
return s
}
// Strip the outer parentheses and continue
s = strings.TrimSpace(s[1 : len(s)-1])
}
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is a simple split that doesn't handle nested parentheses or complex expressions
func splitByAND(where string) []string {
// First try uppercase AND
conditions := strings.Split(where, " AND ")
// If we didn't split on uppercase, try lowercase
if len(conditions) == 1 {
conditions = strings.Split(where, " and ")
}
// If we still didn't split, try mixed case
if len(conditions) == 1 {
conditions = strings.Split(where, " And ")
}
return conditions
}
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
func hasTablePrefix(cond string) bool {
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
return strings.Contains(cond, ".")
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {
@@ -134,3 +303,38 @@ func IsSQLKeyword(word string) bool {
}
return false
}
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
// Returns a map of column names for fast lookup, or nil if the model is not found
func getValidColumnsForTable(tableName string) map[string]bool {
// Try to get the model from the registry
model, err := modelregistry.GetModelByName(tableName)
if err != nil {
// Model not found, return nil to indicate we should use fallback behavior
return nil
}
// Get SQL columns from the model
columns := reflection.GetSQLModelColumns(model)
if len(columns) == 0 {
// No columns found, return nil
return nil
}
// Build a map for fast lookup
columnMap := make(map[string]bool, len(columns))
for _, col := range columns {
columnMap[strings.ToLower(col)] = true
}
return columnMap
}
// isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool {
if validColumns == nil {
return true // No model info, assume valid
}
return validColumns[strings.ToLower(columnName)]
}

View File

@@ -0,0 +1,224 @@
package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
func TestSanitizeWhereClause(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "trivial conditions in parentheses",
where: "(true AND true AND true)",
tableName: "mastertask",
expected: "",
},
{
name: "trivial conditions without parentheses",
where: "true AND true AND true",
tableName: "mastertask",
expected: "",
},
{
name: "single trivial condition",
where: "true",
tableName: "mastertask",
expected: "",
},
{
name: "valid condition with parentheses",
where: "(status = 'active')",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "mixed trivial and valid conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition already with table prefix",
where: "users.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "multiple valid conditions",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "no table name provided",
where: "status = 'active'",
tableName: "",
expected: "status = 'active'",
},
{
name: "empty where clause",
where: "",
tableName: "users",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
func TestStripOuterParentheses(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "single level parentheses",
input: "(true)",
expected: "true",
},
{
name: "multiple levels",
input: "((true))",
expected: "true",
},
{
name: "no parentheses",
input: "true",
expected: "true",
},
{
name: "mismatched parentheses",
input: "(true",
expected: "(true",
},
{
name: "complex expression",
input: "(a AND b)",
expected: "a AND b",
},
{
name: "nested but not outer",
input: "(a AND (b OR c)) AND d",
expected: "(a AND (b OR c)) AND d",
},
{
name: "with spaces",
input: " ( true ) ",
expected: "true",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stripOuterParentheses(tt.input)
if result != tt.expected {
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestIsTrivialCondition(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"true", "true", true},
{"true with spaces", " true ", true},
{"TRUE uppercase", "TRUE", true},
{"1=1", "1=1", true},
{"1 = 1", "1 = 1", true},
{"true = true", "true = true", true},
{"valid condition", "status = 'active'", false},
{"false", "false", false},
{"column name", "is_active", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsTrivialCondition(tt.input)
if result != tt.expected {
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
// Test model for model-aware sanitization tests
type MasterTask struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Status string `bun:"status"`
UserID int `bun:"user_id"`
}
func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
if err != nil {
// Model might already be registered, ignore error
t.Logf("Model registration returned: %v", err)
}
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "valid column gets prefixed",
where: "status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "multiple valid columns get prefixed",
where: "status = 'active' AND user_id = 123",
tableName: "mastertask",
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
},
{
name: "invalid column does not get prefixed",
where: "invalid_column = 'value'",
tableName: "mastertask",
expected: "invalid_column = 'value'",
},
{
name: "mix of valid and trivial conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "parentheses with valid column",
where: "(status = 'active')",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}

View File

@@ -238,13 +238,13 @@ func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
var err error
if b == nil {
t = &SqlTimeStamp{}
return nil
}
s := strings.Trim(strings.Trim(string(b), " "), "\"")
if s == "null" || s == "" || s == "0" ||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
t = &SqlTimeStamp{}
return nil
}
@@ -293,7 +293,7 @@ func (t *SqlTimeStamp) Scan(value interface{}) error {
// String - Override String format of time
func (t SqlTimeStamp) String() string {
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
return time.Time(t).Format("2006-01-02T15:04:05")
}
// GetTime - Returns Time
@@ -308,7 +308,7 @@ func (t *SqlTimeStamp) SetTime(pTime time.Time) {
// Format - Formats the time
func (t SqlTimeStamp) Format(layout string) string {
return fmt.Sprintf("%s", time.Time(t).Format(layout))
return time.Time(t).Format(layout)
}
func SqlTimeStampNow() SqlTimeStamp {
@@ -420,7 +420,6 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
if s == "null" || s == "" || s == "0" ||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
s == "0001-01-01" {
t = &SqlDate{}
return nil
}
@@ -434,7 +433,7 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
// MarshalJSON - Override JSON format of time
func (t SqlDate) MarshalJSON() ([]byte, error) {
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
if strings.HasPrefix(tmstr, "0001-01-01") {
return []byte("null"), nil
}
@@ -482,7 +481,7 @@ func (t SqlDate) Int64() int64 {
// String - Override String format of time
func (t SqlDate) String() string {
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
return "0"
}
@@ -517,8 +516,8 @@ func (t *SqlTime) UnmarshalJSON(b []byte) error {
*t = SqlTime{}
return nil
}
tx := time.Time{}
tx, err = tryParseDT(s)
tx, err := tryParseDT(s)
*t = SqlTime(tx)
return err
@@ -642,9 +641,8 @@ func (n SqlJSONB) AsSlice() ([]any, error) {
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "["))
if invalid {
s = ""
return nil
}
@@ -661,7 +659,7 @@ func (n SqlJSONB) MarshalJSON() ([]byte, error) {
var obj interface{}
err := json.Unmarshal(n, &obj)
if err != nil {
//fmt.Printf("Invalid JSON %v", err)
// fmt.Printf("Invalid JSON %v", err)
return []byte("null"), nil
}
@@ -725,7 +723,6 @@ func (n *SqlUUID) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 30)
if invalid {
s = ""
return nil
}
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})

View File

@@ -32,15 +32,22 @@ type Parameter struct {
}
type PreloadOption struct {
Relation string `json:"relation"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"`
Where string `json:"where"`
Limit *int `json:"limit"`
Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated
Relation string `json:"relation"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"`
Where string `json:"where"`
Limit *int `json:"limit"`
Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
// Relationship keys from XFiles - used to build proper foreign key filters
PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
}
type FilterOption struct {

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ColumnValidator validates column names against a model's fields
@@ -92,23 +93,6 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
return strings.ToLower(field.Name)
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
@@ -125,7 +109,7 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
}
// Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := extractSourceColumn(column)
sourceColumn := reflection.ExtractSourceColumn(column)
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {

View File

@@ -2,6 +2,8 @@ package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func TestExtractSourceColumn(t *testing.T) {
@@ -49,9 +51,9 @@ func TestExtractSourceColumn(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := extractSourceColumn(tc.input)
result := reflection.ExtractSourceColumn(tc.input)
if result != tc.expected {
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
}
})
}

View File

@@ -0,0 +1,948 @@
package funcspec
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// Handler handles function-based SQL API requests
type Handler struct {
db common.Database
hooks *HookRegistry
}
// NewHandler creates a new function API handler
func NewHandler(db common.Database) *Handler {
return &Handler{
db: db,
hooks: NewHookRegistry(),
}
}
// GetDatabase returns the underlying database connection
// Implements common.SpecHandler interface
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// HTTPFuncType is a function type for HTTP handlers
type HTTPFuncType func(http.ResponseWriter, *http.Request)
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
stack := debug.Stack()
logger.Error("Panic in SqlQueryList: %v\nStack trace:\n%s", err, string(stack))
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
}
}()
ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second)
defer cancel()
var dbobjlist []map[string]interface{}
var total int64
propQry := make(map[string]string)
inputvars := make([]string, 0)
metainfo := make(map[string]interface{})
variables := make(map[string]interface{})
complexAPI := false
// Get user context from security package
userCtx, ok := security.GetUserContext(ctx)
if !ok {
logger.Warn("No user context found in request")
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
}
w.Header().Set("Content-Type", "application/json")
// Initialize hook context
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Request: r,
Writer: w,
SQLQuery: sqlquery,
Variables: variables,
InputVars: inputvars,
MetaInfo: metainfo,
PropQry: propQry,
UserContext: userCtx,
NoCount: pNoCount,
BlankParams: pBlankparms,
AllowFilter: pAllowFilter,
ComplexAPI: complexAPI,
}
// Execute BeforeQueryList hook
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
logger.Error("BeforeQueryList hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Check if hook aborted the operation
if hookCtx.Abort {
if hookCtx.AbortCode == 0 {
hookCtx.AbortCode = http.StatusBadRequest
}
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
return
}
// Use potentially modified SQL query and variables from hooks
sqlquery = hookCtx.SQLQuery
variables = hookCtx.Variables
// complexAPI = hookCtx.ComplexAPI
// Extract input variables from SQL query (placeholders like [variable])
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
// Merge URL path parameters
sqlquery = h.mergePathParams(r, sqlquery, variables)
// Parse comprehensive parameters from headers and query string
reqParams := h.ParseParameters(r)
complexAPI = reqParams.ComplexAPI
// Merge query string parameters
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
// Merge header parameters
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
if !pAllowFilter {
sqlquery = h.ApplyFilters(sqlquery, reqParams)
}
// Apply field selection
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
// Apply DISTINCT if requested
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
// Override pNoCount if skipcount is specified
if reqParams.SkipCount {
pNoCount = true
}
// Build metainfo
metainfo["ipaddress"] = getIPAddress(r)
metainfo["url"] = r.RequestURI
metainfo["user"] = userCtx.UserName
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
metainfo["method"] = r.Method
metainfo["variables"] = variables
// Replace meta variables in SQL
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
// Remove unused input variables
if pBlankparms {
for _, kw := range inputvars {
replacement := getReplacementForBlankParam(sqlquery, kw)
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
}
}
// Update hook context with latest SQL query and variables
hookCtx.SQLQuery = sqlquery
hookCtx.Variables = variables
hookCtx.InputVars = inputvars
// Execute query within transaction
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
sqlqueryCnt := sqlquery
// Parse sorting and pagination parameters
sortcols, limit, offset := h.parsePaginationParams(r)
// Override with parsed parameters if available
if reqParams.SortColumns != "" {
sortcols = reqParams.SortColumns
}
if reqParams.Limit > 0 {
limit = reqParams.Limit
}
if reqParams.Offset > 0 {
offset = reqParams.Offset
}
hookCtx.SortColumns = sortcols
hookCtx.Limit = limit
hookCtx.Offset = offset
fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
if len(sortcols) > 0 && (orderbyPos < 0 || (orderbyPos > 0 && orderbyPos < fromPos)) {
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
}
if !pNoCount {
if limit > 0 && offset > 0 {
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
} else if limit > 0 {
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, limit)
} else {
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, 20000)
}
// Get total count
countQuery := fmt.Sprintf("SELECT COUNT(1) FROM (%s) cnts", sqlqueryCnt)
var countResult struct{ Count int64 }
if err := tx.Query(ctx, &countResult, countQuery); err != nil {
sendError(w, http.StatusBadRequest, "count_failed", "Failed to retrieve record count", err)
return err
}
total = countResult.Count
}
// Execute BeforeSQLExec hook
hookCtx.SQLQuery = sqlquery
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
logger.Error("BeforeSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified SQL query from hook
sqlquery = hookCtx.SQLQuery
// Execute main query
rows := make([]map[string]interface{}, 0)
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
return err
}
dbobjlist = rows
if pNoCount {
total = int64(len(dbobjlist))
}
// Execute AfterSQLExec hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
logger.Error("AfterSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
total = hookCtx.Total
return nil
})
if err != nil {
logger.Error("Transaction failed: %v", err)
return
}
// Execute AfterQueryList hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
hookCtx.Error = err
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
logger.Error("AfterQueryList hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
total = hookCtx.Total
// Set response headers
respOffset := 0
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil {
respOffset = o
}
}
w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
logger.Info("Serving: Records %d of %d", len(dbobjlist), total)
// Execute BeforeResponse hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
logger.Error("BeforeResponse hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
if len(dbobjlist) == 0 {
_, _ = w.Write([]byte("[]"))
return
}
// Format response based on response format
switch reqParams.ResponseFormat {
case "syncfusion":
// Syncfusion format: { result: data, count: total }
response := map[string]interface{}{
"result": dbobjlist,
"count": total,
}
data, err := json.Marshal(response)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
case "detail":
// Detail format: complex API with metadata
metaobj := map[string]interface{}{
"items": dbobjlist,
"count": fmt.Sprintf("%d", len(dbobjlist)),
"total": fmt.Sprintf("%d", total),
"tablename": r.URL.Path,
"tableprefix": "gsql",
}
data, err := json.Marshal(metaobj)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
default:
// Simple format: just return the data array (or complex API if requested)
if complexAPI {
metaobj := map[string]interface{}{
"items": dbobjlist,
"count": fmt.Sprintf("%d", len(dbobjlist)),
"total": fmt.Sprintf("%d", total),
"tablename": r.URL.Path,
"tableprefix": "gsql",
}
data, err := json.Marshal(metaobj)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
} else {
data, err := json.Marshal(dbobjlist)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
}
}
}
}
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
stack := debug.Stack()
logger.Error("Panic in SqlQuery: %v\nStack trace:\n%s", err, string(stack))
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
}
}()
ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second)
defer cancel()
propQry := make(map[string]string)
inputvars := make([]string, 0)
metainfo := make(map[string]interface{})
variables := make(map[string]interface{})
dbobj := make(map[string]interface{})
complexAPI := false
// Get user context from security package
userCtx, ok := security.GetUserContext(ctx)
if !ok {
logger.Warn("No user context found in request")
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
}
w.Header().Set("Content-Type", "application/json")
// Initialize hook context
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Request: r,
Writer: w,
SQLQuery: sqlquery,
Variables: variables,
InputVars: inputvars,
MetaInfo: metainfo,
PropQry: propQry,
UserContext: userCtx,
BlankParams: pBlankparms,
ComplexAPI: complexAPI,
}
// Execute BeforeQuery hook
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
logger.Error("BeforeQuery hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Check if hook aborted the operation
if hookCtx.Abort {
if hookCtx.AbortCode == 0 {
hookCtx.AbortCode = http.StatusBadRequest
}
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
return
}
// Use potentially modified SQL query and variables from hooks
sqlquery = hookCtx.SQLQuery
variables = hookCtx.Variables
// Extract input variables from SQL query
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
// Merge URL path parameters
sqlquery = h.mergePathParams(r, sqlquery, variables)
// Parse comprehensive parameters from headers and query string
reqParams := h.ParseParameters(r)
complexAPI = reqParams.ComplexAPI
// Merge query string parameters
sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)
// Merge header parameters
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
hookCtx.ComplexAPI = complexAPI
// Apply filters from parsed parameters
sqlquery = h.ApplyFilters(sqlquery, reqParams)
// Apply field selection
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
// Apply DISTINCT if requested
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
// Build metainfo
metainfo["ipaddress"] = getIPAddress(r)
metainfo["url"] = r.RequestURI
metainfo["user"] = userCtx.UserName
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
metainfo["method"] = r.Method
metainfo["variables"] = variables
// Replace meta variables in SQL
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
// Apply field filters from headers
for k, val := range propQry {
kLower := strings.ToLower(k)
if strings.HasPrefix(kLower, "x-fieldfilter-") {
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
if strings.Contains(strings.ToLower(sqlquery), colname) {
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
}
}
// Remove unused input variables
if pBlankparms {
for _, kw := range inputvars {
replacement := getReplacementForBlankParam(sqlquery, kw)
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
}
}
// Update hook context with latest SQL query and variables
hookCtx.SQLQuery = sqlquery
hookCtx.Variables = variables
hookCtx.InputVars = inputvars
// Execute query within transaction
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Execute BeforeSQLExec hook
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
logger.Error("BeforeSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified SQL query from hook
sqlquery = hookCtx.SQLQuery
// Execute main query
rows := make([]map[string]interface{}, 0)
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
return err
}
if len(rows) > 0 {
dbobj = rows[0]
}
// Execute AfterSQLExec hook
hookCtx.Result = dbobj
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
logger.Error("AfterSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
return nil
})
if err != nil {
logger.Error("Transaction failed: %v", err)
return
}
// Execute AfterQuery hook
hookCtx.Result = dbobj
hookCtx.Error = err
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
logger.Error("AfterQuery hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
// Execute BeforeResponse hook
hookCtx.Result = dbobj
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
logger.Error("BeforeResponse hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
// Check if response should be root-level data
if val, ok := dbobj["root_as_data"]; ok {
data, err := json.Marshal(val)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
_, _ = w.Write(data)
}
return
}
// Marshal and send response
data, err := json.Marshal(dbobj)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
_, _ = w.Write(data)
}
}
}
// Helper functions
// extractInputVariables extracts placeholders like [variable] from the SQL query
func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) string {
testsqlquery := sqlquery
for i := 0; i <= strings.Count(sqlquery, "[")*4; i++ {
iStart := strings.Index(testsqlquery, "[")
if iStart < 0 {
break
}
iEnd := strings.Index(testsqlquery, "]")
if iEnd < 0 {
break
}
*inputvars = append(*inputvars, testsqlquery[iStart:iEnd+1])
testsqlquery = testsqlquery[iEnd+1:]
}
return sqlquery
}
// mergePathParams merges URL path parameters into the SQL query
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
// Note: Path parameters would typically come from a router like gorilla/mux
// For now, this is a placeholder for path parameter extraction
return sqlquery
}
// mergeQueryParams merges query string parameters into the SQL query
func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables map[string]interface{}, allowFilter bool, propQry map[string]string) string {
for parmk, parmv := range r.URL.Query() {
if len(parmk) == 0 || len(parmv) == 0 {
continue
}
val := parmv[0]
dec, err := restheadspec.DecodeParam(val)
if err == nil {
val = dec
}
kword := fmt.Sprintf("[%s]", parmk)
variables[parmk] = val
// Replace in SQL if placeholder exists
if strings.Contains(sqlquery, kword) && len(val) > 0 {
if strings.HasPrefix(parmk, "p-") {
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
}
}
// Add to propQry for x- prefixed params
if strings.HasPrefix(parmk, "x-") {
propQry[parmk] = val
}
// Apply filters if allowed
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
if len(parmv) > 1 {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ",")))
} else {
if strings.Contains(val, "match=") {
colval := strings.ReplaceAll(val, "match=", "")
if colval != "*" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
}
} else if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
} else {
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
}
}
}
}
}
return sqlquery
}
// mergeHeaderParams merges HTTP header parameters into the SQL query
func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables map[string]interface{}, propQry map[string]string, complexAPI *bool) string {
for kc, v := range r.Header {
k := strings.ToLower(kc)
if !strings.HasPrefix(k, "x-") || len(v) == 0 {
continue
}
val := v[0]
dec, err := restheadspec.DecodeParam(val)
if err == nil {
val = dec
}
variables[k] = val
propQry[k] = val
kword := fmt.Sprintf("[%s]", k)
if strings.Contains(sqlquery, kword) {
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
}
// Handle special headers
if strings.Contains(k, "x-fieldfilter-") {
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
if strings.Contains(k, "x-searchfilter-") {
colname := strings.ReplaceAll(k, "x-searchfilter-", "")
sval := strings.ReplaceAll(val, "'", "")
if sval != "" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
}
}
if strings.Contains(k, "x-custom-sql-w") {
colval := ValidSQL(val, "select")
if len(colval) > 0 {
sqlquery = sqlQryWhere(sqlquery, colval)
}
}
if strings.Contains(k, "x-simpleapi") {
*complexAPI = !strings.EqualFold(val, "1") && !strings.EqualFold(val, "true")
}
}
return sqlquery
}
// replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
if strings.Contains(sqlquery, "[p_meta_default]") {
data, _ := json.Marshal(metainfo)
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data)))
}
if strings.Contains(sqlquery, "[json_variables]") {
data, _ := json.Marshal(variables)
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data)))
}
if strings.Contains(sqlquery, "[rid_user]") {
sqlquery = strings.ReplaceAll(sqlquery, "[rid_user]", fmt.Sprintf("%d", userCtx.UserID))
}
if strings.Contains(sqlquery, "[user]") {
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName))
}
if strings.Contains(sqlquery, "[rid_session]") {
sessionID, _ := strconv.ParseInt(userCtx.SessionID, 10, 64)
sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("%d", sessionID))
}
if strings.Contains(sqlquery, "[method]") {
sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method)
}
if strings.Contains(sqlquery, "[post_body]") {
bodystr := ""
if r.Method == "POST" || r.Method == "PUT" {
if r.Body != nil {
contents, err := io.ReadAll(r.Body)
if err == nil {
bodystr = string(contents)
}
}
}
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr))
}
return sqlquery
}
// parsePaginationParams extracts sort, limit, and offset parameters from request
func (h *Handler) parsePaginationParams(r *http.Request) (sortcols string, limit, offset int) {
limit = 20
offset = 0
if sortStr := r.URL.Query().Get("sort"); sortStr != "" {
sortcols = sortStr
}
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
offset = o
}
}
return
}
// ValidSQL validates and sanitizes SQL input to prevent injection
// mode can be: "colname", "colvalue", "select"
func ValidSQL(input, mode string) string {
// Remove dangerous characters based on mode
switch mode {
case "colname":
// For column names, only allow alphanumeric, underscore, and dot
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
return reg.ReplaceAllString(input, "")
case "colvalue":
// For column values, escape single quotes
return strings.ReplaceAll(input, "'", "''")
case "select":
// For SELECT clauses, be more permissive but still safe
// Remove semicolons and common SQL injection patterns
dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "}
result := input
for _, d := range dangerous {
result = strings.ReplaceAll(result, d, "")
result = strings.ReplaceAll(result, strings.ToLower(d), "")
result = strings.ReplaceAll(result, strings.ToUpper(d), "")
}
return result
default:
return input
}
}
// sqlQryWhere adds a WHERE clause to a SQL query or appends to existing WHERE with AND
func sqlQryWhere(sqlquery, condition string) string {
lowerQuery := strings.ToLower(sqlquery)
wherePos := strings.Index(lowerQuery, " where ")
groupPos := strings.Index(lowerQuery, " group by")
orderPos := strings.Index(lowerQuery, " order by")
limitPos := strings.Index(lowerQuery, " limit ")
// Find the insertion point (before GROUP BY, ORDER BY, or LIMIT)
insertPos := len(sqlquery)
if groupPos > 0 && groupPos < insertPos {
insertPos = groupPos
}
if orderPos > 0 && orderPos < insertPos {
insertPos = orderPos
}
if limitPos > 0 && limitPos < insertPos {
insertPos = limitPos
}
if wherePos > 0 {
// WHERE exists, add AND condition before GROUP BY / ORDER BY / LIMIT
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s AND %s %s", before, condition, after)
} else {
// No WHERE exists, add it before GROUP BY / ORDER BY / LIMIT
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
}
}
// IsNumeric checks if a string contains only numeric characters
func IsNumeric(s string) bool {
_, err := strconv.ParseFloat(s, 64)
return err == nil
}
// getReplacementForBlankParam determines the replacement value for an unused parameter
// based on whether it appears within quotes in the SQL query.
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
func getReplacementForBlankParam(sqlquery, param string) string {
// Find the parameter in the query
idx := strings.Index(sqlquery, param)
if idx < 0 {
return ""
}
// Check characters immediately before and after the parameter
var charBefore, charAfter byte
if idx > 0 {
charBefore = sqlquery[idx-1]
}
endIdx := idx + len(param)
if endIdx < len(sqlquery) {
charAfter = sqlquery[endIdx]
}
// Check if parameter is surrounded by quotes (single quote or dollar sign for PostgreSQL dollar-quoted strings)
if (charBefore == '\'' || charBefore == '$') && (charAfter == '\'' || charAfter == '$') {
// Parameter is in quotes, return empty string
return ""
}
// Parameter is not in quotes, return NULL
return "NULL"
}
// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows
// func makeResultReceiver(length int) []interface{} {
// result := make([]interface{}, length)
// for i := 0; i < length; i++ {
// var v interface{}
// result[i] = &v
// }
// return result
// }
// getIPAddress extracts the real IP address from the request
func getIPAddress(r *http.Request) string {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For can contain multiple IPs, take the first one
ips := strings.Split(forwarded, ",")
return strings.TrimSpace(ips[0])
}
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return realIP
}
return r.RemoteAddr
}
// sendError sends a JSON error response
func sendError(w http.ResponseWriter, status int, code, message string, err error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
errObj := common.APIError{
Code: code,
Message: message,
}
if err != nil {
errObj.Detail = err.Error()
}
data, _ := json.Marshal(map[string]interface{}{
"success": false,
"error": errObj,
})
_, _ = w.Write(data)
}

View File

@@ -0,0 +1,899 @@
package funcspec
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// MockDatabase implements common.Database interface for testing
type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
}
func (m *MockDatabase) NewSelect() common.SelectQuery {
return nil
}
func (m *MockDatabase) NewInsert() common.InsertQuery {
return nil
}
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
return nil
}
func (m *MockDatabase) NewDelete() common.DeleteQuery {
return nil
}
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
if m.ExecFunc != nil {
return m.ExecFunc(ctx, query, args...)
}
return &MockResult{rows: 0}, nil
}
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if m.QueryFunc != nil {
return m.QueryFunc(ctx, dest, query, args...)
}
return nil
}
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
return m, nil
}
func (m *MockDatabase) CommitTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
if m.RunInTransactionFunc != nil {
return m.RunInTransactionFunc(ctx, fn)
}
return fn(m)
}
// MockResult implements common.Result interface for testing
type MockResult struct {
rows int64
id int64
}
func (m *MockResult) RowsAffected() int64 {
return m.rows
}
func (m *MockResult) LastInsertId() (int64, error) {
return m.id, nil
}
// Helper function to create a test request with user context
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
u, _ := url.Parse(path)
if queryParams != nil {
q := u.Query()
for k, v := range queryParams {
q.Set(k, v)
}
u.RawQuery = q.Encode()
}
var bodyReader *bytes.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
} else {
bodyReader = bytes.NewReader([]byte{})
}
req := httptest.NewRequest(method, u.String(), bodyReader)
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
// Add user context
userCtx := &security.UserContext{
UserID: 1,
UserName: "testuser",
SessionID: "test-session-123",
}
ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx)
req = req.WithContext(ctx)
return req
}
// TestNewHandler tests handler creation
func TestNewHandler(t *testing.T) {
db := &MockDatabase{}
handler := NewHandler(db)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.db != db {
t.Error("Expected handler to have the provided database")
}
if handler.hooks == nil {
t.Error("Expected handler to have a hook registry")
}
}
// TestHandlerHooks tests the Hooks method
func TestHandlerHooks(t *testing.T) {
handler := NewHandler(&MockDatabase{})
hooks := handler.Hooks()
if hooks == nil {
t.Fatal("Expected hooks registry to be non-nil")
}
// Should return the same instance
hooks2 := handler.Hooks()
if hooks != hooks2 {
t.Error("Expected Hooks() to return the same registry instance")
}
}
// TestExtractInputVariables tests the extractInputVariables function
func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
expectedVars []string
}{
{
name: "No variables",
sqlQuery: "SELECT * FROM users",
expectedVars: []string{},
},
{
name: "Single variable",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
expectedVars: []string{"[user_id]"},
},
{
name: "Multiple variables",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]",
expectedVars: []string{"[user_id]", "[user_name]"},
},
{
name: "Nested brackets",
sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]",
expectedVars: []string{"[field]", "[user_id]"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inputvars := make([]string, 0)
result := handler.extractInputVariables(tt.sqlQuery, &inputvars)
if result != tt.sqlQuery {
t.Errorf("Expected SQL query to be unchanged, got %s", result)
}
if len(inputvars) != len(tt.expectedVars) {
t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars)
return
}
for i, expected := range tt.expectedVars {
if inputvars[i] != expected {
t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i])
}
}
})
}
}
// TestValidSQL tests the SQL sanitization function
func TestValidSQL(t *testing.T) {
tests := []struct {
name string
input string
mode string
expected string
}{
{
name: "Column name with valid characters",
input: "user_id",
mode: "colname",
expected: "user_id",
},
{
name: "Column name with dots (table.column)",
input: "users.user_id",
mode: "colname",
expected: "users.user_id",
},
{
name: "Column name with SQL injection attempt",
input: "id'; DROP TABLE users--",
mode: "colname",
expected: "idDROPTABLEusers",
},
{
name: "Column value with single quotes",
input: "O'Brien",
mode: "colvalue",
expected: "O''Brien",
},
{
name: "Select with dangerous keywords",
input: "name, email; DROP TABLE users",
mode: "select",
expected: "name, email TABLE users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ValidSQL(tt.input, tt.mode)
if result != tt.expected {
t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected)
}
})
}
}
// TestIsNumeric tests the IsNumeric function
func TestIsNumeric(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"123", true},
{"123.45", true},
{"-123", true},
{"-123.45", true},
{"0", true},
{"abc", false},
{"12.34.56", false},
{"", false},
{"123abc", false},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := IsNumeric(tt.input)
if result != tt.expected {
t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected)
}
})
}
}
// TestSqlQryWhere tests the WHERE clause manipulation
func TestSqlQryWhere(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
expected string
}{
{
name: "Add WHERE to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ",
},
{
name: "Add AND to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ",
},
{
name: "Add WHERE before ORDER BY",
sqlQuery: "SELECT * FROM users ORDER BY name",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name",
},
{
name: "Add WHERE before GROUP BY",
sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department",
condition: "status = 'active'",
expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department",
},
{
name: "Add WHERE before LIMIT",
sqlQuery: "SELECT * FROM users LIMIT 10",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhere(tt.sqlQuery, tt.condition)
if result != tt.expected {
t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) {
tests := []struct {
name string
setupReq func() *http.Request
expected string
}{
{
name: "X-Forwarded-For header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
return req
},
expected: "192.168.1.100",
},
{
name: "X-Real-IP header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.200")
return req
},
expected: "192.168.1.200",
},
{
name: "RemoteAddr fallback",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
return req
},
expected: "192.168.1.1:12345",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupReq()
result := getIPAddress(req)
if result != tt.expected {
t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestParsePaginationParams tests pagination parameter parsing
func TestParsePaginationParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
expectedSort string
expectedLimit int
expectedOffset int
}{
{
name: "No parameters - defaults",
queryParams: map[string]string{},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "All parameters provided",
queryParams: map[string]string{
"sort": "name,-created_at",
"limit": "100",
"offset": "50",
},
expectedSort: "name,-created_at",
expectedLimit: 100,
expectedOffset: 50,
},
{
name: "Invalid limit - use default",
queryParams: map[string]string{
"limit": "invalid",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "Negative offset - use default",
queryParams: map[string]string{
"offset": "-10",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
sort, limit, offset := handler.parsePaginationParams(req)
if sort != tt.expectedSort {
t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort)
}
if limit != tt.expectedLimit {
t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit)
}
if offset != tt.expectedOffset {
t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset)
}
})
}
}
// TestSqlQuery tests the SqlQuery handler for single record queries
func TestSqlQuery(t *testing.T) {
tests := []struct {
name string
sqlQuery string
blankParams bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, body []byte)
}{
{
name: "Basic query - returns single record",
sqlQuery: "SELECT * FROM users WHERE id = 1",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User", "email": "test@example.com"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if result["name"] != "Test User" {
t.Errorf("Expected name='Test User', got %v", result["name"])
}
},
},
{
name: "Query with no results",
sqlQuery: "SELECT * FROM users WHERE id = 999",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
// Return empty array
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 0 {
t.Errorf("Expected empty result, got %v", result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
if tt.validateResp != nil {
tt.validateResp(t, w.Body.Bytes())
}
})
}
}
// TestSqlQueryList tests the SqlQueryList handler for list queries
func TestSqlQueryList(t *testing.T) {
tests := []struct {
name string
sqlQuery string
noCount bool
blankParams bool
allowFilter bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, w *httptest.ResponseRecorder)
}{
{
name: "Basic list query",
sqlQuery: "SELECT * FROM users",
noCount: false,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
callCount := 0
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
callCount++
if strings.Contains(query, "COUNT") {
// Count query
countResult := dest.(*struct{ Count int64 })
countResult.Count = 2
} else {
// Main query
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
{"id": float64(2), "name": "User 2"},
}
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 2 {
t.Errorf("Expected 2 results, got %d", len(result))
}
// Check Content-Range header
contentRange := w.Header().Get("Content-Range")
if !strings.Contains(contentRange, "2") {
t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
}
},
},
{
name: "List query with noCount",
sqlQuery: "SELECT * FROM users",
noCount: true,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if strings.Contains(query, "COUNT") {
t.Error("Count query should not be executed when noCount is true")
}
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 1 {
t.Errorf("Expected 1 result, got %d", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
}
if tt.validateResp != nil {
tt.validateResp(t, w)
}
})
}
}
// TestMergeQueryParams tests query parameter merging
func TestMergeQueryParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
queryParams map[string]string
allowFilter bool
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Replace placeholder with parameter",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
queryParams: map[string]string{"p-user_id": "123"},
allowFilter: false,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["p-user_id"] != "123" {
t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
}
},
},
{
name: "Add filter when allowed",
sqlQuery: "SELECT * FROM users",
queryParams: map[string]string{"status": "active"},
allowFilter: true,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["status"] != "active" {
t.Errorf("Expected status=active, got %v", vars["status"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestMergeHeaderParams tests header parameter merging
func TestMergeHeaderParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
headers map[string]string
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Field filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-FieldFilter-Status": "1"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-fieldfilter-status"] != "1" {
t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
}
},
},
{
name: "Search filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-SearchFilter-Name": "john"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-searchfilter-name"] != "john" {
t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, tt.headers, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
complexAPI := false
result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestReplaceMetaVariables tests meta variable replacement
func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{
UserID: 123,
UserName: "testuser",
SessionID: "456",
}
metainfo := map[string]interface{}{
"ipaddress": "192.168.1.1",
"url": "/api/test",
}
variables := map[string]interface{}{
"param1": "value1",
}
tests := []struct {
name string
sqlQuery string
expectedCheck func(result string) bool
}{
{
name: "Replace [rid_user]",
sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "123")
},
},
{
name: "Replace [user]",
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "'testuser'")
},
},
{
name: "Replace [rid_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "456")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, nil, nil)
result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
if !tt.expectedCheck(result) {
t.Errorf("Meta variable replacement failed. Query: %s", result)
}
})
}
}
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
func TestGetReplacementForBlankParam(t *testing.T) {
tests := []struct {
name string
sqlQuery string
param string
expected string
}{
{
name: "Parameter in single quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]'",
param: "[username]",
expected: "",
},
{
name: "Parameter in dollar quotes",
sqlQuery: "SELECT * FROM users WHERE data = $[jsondata]$",
param: "[jsondata]",
expected: "",
},
{
name: "Parameter not in quotes",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter not in quotes with AND",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND status = 1",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - before quote",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = 'test'",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - in quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]' AND id = 1",
param: "[username]",
expected: "",
},
{
name: "Parameter with dollar quote tag",
sqlQuery: "SELECT * FROM users WHERE body = $tag$[content]$tag$",
param: "[content]",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getReplacementForBlankParam(tt.sqlQuery, tt.param)
if result != tt.expected {
t.Errorf("Expected replacement '%s', got '%s' for query: %s", tt.expected, result, tt.sqlQuery)
}
})
}
}

160
pkg/funcspec/hooks.go Normal file
View File

@@ -0,0 +1,160 @@
package funcspec
import (
"context"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Query operation hooks (for SqlQuery - single record)
BeforeQuery HookType = "before_query"
AfterQuery HookType = "after_query"
// Query list operation hooks (for SqlQueryList - multiple records)
BeforeQueryList HookType = "before_query_list"
AfterQueryList HookType = "after_query_list"
// SQL execution hooks (just before SQL is executed)
BeforeSQLExec HookType = "before_sql_exec"
AfterSQLExec HookType = "after_sql_exec"
// Response hooks (before response is sent)
BeforeResponse HookType = "before_response"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database
Request *http.Request
Writer http.ResponseWriter
// SQL query and variables
SQLQuery string // The SQL query being executed (can be modified by hooks)
Variables map[string]interface{} // Variables extracted from request
InputVars []string // Input variable placeholders found in query
MetaInfo map[string]interface{} // Metadata about the request
PropQry map[string]string // Property query parameters
// User context
UserContext *security.UserContext
// Pagination and filtering (for list queries)
SortColumns string
Limit int
Offset int
// Results
Result interface{} // Query result (single record or list)
Total int64 // Total count (for list queries)
Error error // Error if operation failed
ComplexAPI bool // Whether complex API response format is requested
NoCount bool // Whether count query should be skipped
BlankParams bool // Whether blank parameters should be removed
AllowFilter bool // Whether filtering is allowed
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all funcspec hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all funcspec hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@@ -0,0 +1,137 @@
package funcspec
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Example hook functions demonstrating various use cases
// ExampleLoggingHook logs all SQL queries before execution
func ExampleLoggingHook(ctx *HookContext) error {
logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery)
return nil
}
// ExampleSecurityHook validates user permissions before executing queries
func ExampleSecurityHook(ctx *HookContext) error {
// Example: Block queries that try to access sensitive tables
if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") {
if ctx.UserContext.UserID != 1 { // Only admin can access
ctx.Abort = true
ctx.AbortCode = 403
ctx.AbortMessage = "Access denied: insufficient permissions"
return fmt.Errorf("access denied to sensitive_table")
}
}
return nil
}
// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering
func ExampleQueryModificationHook(ctx *HookContext) error {
// Example: Automatically add user_id filter for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
// Add WHERE clause to filter by user_id
if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") {
ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID)
} else {
ctx.SQLQuery = strings.Replace(
ctx.SQLQuery,
"WHERE",
fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID),
1,
)
}
logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery)
}
return nil
}
// ExampleResultFilterHook filters results after query execution
func ExampleResultFilterHook(ctx *HookContext) error {
// Example: Remove sensitive fields from results for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
switch result := ctx.Result.(type) {
case []map[string]interface{}:
// Filter list results
for i := range result {
delete(result[i], "password")
delete(result[i], "ssn")
delete(result[i], "credit_card")
}
case map[string]interface{}:
// Filter single record
delete(result, "password")
delete(result, "ssn")
delete(result, "credit_card")
}
}
return nil
}
// ExampleAuditHook logs all queries and results for audit purposes
func ExampleAuditHook(ctx *HookContext) error {
// Log to audit table or external system
logger.Info("AUDIT: User %s (%d) executed query from %s",
ctx.UserContext.UserName,
ctx.UserContext.UserID,
ctx.Request.RemoteAddr,
)
// In a real implementation, you might:
// - Insert into an audit log table
// - Send to a logging service
// - Write to a file
return nil
}
// ExampleCacheHook implements simple response caching
func ExampleCacheHook(ctx *HookContext) error {
// This is a simplified example - real caching would use a proper cache store
// Check if we have a cached result for this query
// cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery)
// if cachedResult := checkCache(cacheKey); cachedResult != nil {
// ctx.Result = cachedResult
// ctx.Abort = true // Skip query execution
// ctx.AbortMessage = "Serving from cache"
// }
return nil
}
// ExampleErrorHandlingHook provides custom error handling
func ExampleErrorHandlingHook(ctx *HookContext) error {
if ctx.Error != nil {
// Log error with context
logger.Error("Query failed for user %s: %v\nQuery: %s",
ctx.UserContext.UserName,
ctx.Error,
ctx.SQLQuery,
)
// You could send notifications, update metrics, etc.
}
return nil
}
// Example of registering hooks:
//
// func SetupHooks(handler *Handler) {
// hooks := handler.Hooks()
//
// // Register security hook before query execution
// hooks.Register(BeforeQuery, ExampleSecurityHook)
// hooks.Register(BeforeQueryList, ExampleSecurityHook)
//
// // Register logging hook before SQL execution
// hooks.Register(BeforeSQLExec, ExampleLoggingHook)
//
// // Register result filtering after query
// hooks.Register(AfterQuery, ExampleResultFilterHook)
// hooks.Register(AfterQueryList, ExampleResultFilterHook)
//
// // Register audit hook after execution
// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook)
// }

589
pkg/funcspec/hooks_test.go Normal file
View File

@@ -0,0 +1,589 @@
package funcspec
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// TestNewHookRegistry tests hook registry creation
func TestNewHookRegistry(t *testing.T) {
registry := NewHookRegistry()
if registry == nil {
t.Fatal("Expected registry to be created, got nil")
}
if registry.hooks == nil {
t.Error("Expected hooks map to be initialized")
}
}
// TestRegisterHook tests registering a single hook
func TestRegisterHook(t *testing.T) {
registry := NewHookRegistry()
hookCalled := false
testHook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
registry.Register(BeforeQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected hook to be registered")
}
if registry.Count(BeforeQuery) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
}
// Execute the hook
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !hookCalled {
t.Error("Expected hook to be called")
}
}
// TestRegisterMultipleHooks tests registering multiple hooks for same type
func TestRegisterMultipleHooks(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
if registry.Count(BeforeQuery) != 3 {
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
}
// Execute hooks
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
// Verify hooks were called in order
if len(callOrder) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
}
for i, expected := range []int{1, 2, 3} {
if callOrder[i] != expected {
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
}
}
}
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
func TestRegisterMultipleHookTypes(t *testing.T) {
registry := NewHookRegistry()
callCount := 0
testHook := func(ctx *HookContext) error {
callCount++
return nil
}
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
registry.RegisterMultiple(hookTypes, testHook)
// Verify hook is registered for all types
for _, hookType := range hookTypes {
if !registry.HasHooks(hookType) {
t.Errorf("Expected hook to be registered for %s", hookType)
}
if registry.Count(hookType) != 1 {
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
}
}
// Execute each hook type
ctx := &HookContext{}
for _, hookType := range hookTypes {
if err := registry.Execute(hookType, ctx); err != nil {
t.Errorf("Hook execution failed for %s: %v", hookType, err)
}
}
if callCount != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
}
}
// TestHookError tests hook error handling
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
expectedError := fmt.Errorf("test error")
errorHook := func(ctx *HookContext) error {
return expectedError
}
registry.Register(BeforeQuery, errorHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook, got nil")
}
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
t.Errorf("Expected error message to contain hook error, got: %v", err)
}
}
// TestHookAbort tests hook abort functionality
func TestHookAbort(t *testing.T) {
registry := NewHookRegistry()
abortHook := func(ctx *HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "Operation aborted by hook"
ctx.AbortCode = 403
return nil
}
registry.Register(BeforeQuery, abortHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error when hook aborts, got nil")
}
if !ctx.Abort {
t.Error("Expected Abort to be true")
}
if ctx.AbortMessage != "Operation aborted by hook" {
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
}
}
// TestHookChainWithError tests that hook chain stops on first error
func TestHookChainWithError(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return fmt.Errorf("error in hook 2")
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook chain")
}
// Only first two hooks should have been called
if len(callOrder) != 2 {
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
}
if callOrder[0] != 1 || callOrder[1] != 2 {
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
}
}
// TestClearHooks tests clearing hooks
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hook to be registered")
}
registry.Clear(BeforeQuery)
if registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hooks to be cleared")
}
if !registry.HasHooks(AfterQuery) {
t.Error("Expected AfterQuery hook to still be registered")
}
}
// TestClearAllHooks tests clearing all hooks
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
registry.Register(BeforeSQLExec, testHook)
registry.ClearAll()
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
t.Error("Expected all hooks to be cleared")
}
}
// TestGetAllHookTypes tests getting all registered hook types
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
types := registry.GetAllHookTypes()
if len(types) != 2 {
t.Errorf("Expected 2 hook types, got %d", len(types))
}
// Verify the types are present
foundBefore := false
foundAfter := false
for _, hookType := range types {
if hookType == BeforeQuery {
foundBefore = true
}
if hookType == AfterQuery {
foundAfter = true
}
}
if !foundBefore || !foundAfter {
t.Error("Expected both BeforeQuery and AfterQuery hook types")
}
}
// TestHookContextModification tests that hooks can modify the context
func TestHookContextModification(t *testing.T) {
registry := NewHookRegistry()
// Hook that modifies SQL query
modifyHook := func(ctx *HookContext) error {
ctx.SQLQuery = "SELECT * FROM modified_table"
ctx.Variables["new_var"] = "new_value"
return nil
}
registry.Register(BeforeQuery, modifyHook)
ctx := &HookContext{
SQLQuery: "SELECT * FROM original_table",
Variables: make(map[string]interface{}),
}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if ctx.SQLQuery != "SELECT * FROM modified_table" {
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
}
if ctx.Variables["new_var"] != "new_value" {
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
}
}
// TestExampleHooks tests the example hooks
func TestExampleLoggingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleLoggingHook(ctx)
if err != nil {
t.Errorf("ExampleLoggingHook failed: %v", err)
}
}
func TestExampleSecurityHook(t *testing.T) {
tests := []struct {
name string
sqlQuery string
userID int
shouldAbort bool
}{
{
name: "Admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 1,
shouldAbort: false,
},
{
name: "Non-admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 2,
shouldAbort: true,
},
{
name: "Non-admin accessing normal table",
sqlQuery: "SELECT * FROM users",
userID: 2,
shouldAbort: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: tt.sqlQuery,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
_ = ExampleSecurityHook(ctx)
if tt.shouldAbort {
if !ctx.Abort {
t.Error("Expected security hook to abort operation")
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
}
} else {
if ctx.Abort {
t.Error("Expected security hook not to abort operation")
}
}
})
}
}
func TestExampleResultFilterHook(t *testing.T) {
tests := []struct {
name string
userID int
result interface{}
validate func(t *testing.T, result interface{})
}{
{
name: "Admin user - no filtering",
userID: 1,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; !exists {
t.Error("Expected password field to remain for admin")
}
},
},
{
name: "Regular user - sensitive fields removed",
userID: 2,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
"ssn": "123-45-6789",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed")
}
if _, exists := m["ssn"]; exists {
t.Error("Expected ssn field to be removed")
}
if _, exists := m["name"]; !exists {
t.Error("Expected name field to remain")
}
},
},
{
name: "Regular user - list results filtered",
userID: 2,
result: []map[string]interface{}{
{"id": 1, "name": "User 1", "password": "secret1"},
{"id": 2, "name": "User 2", "password": "secret2"},
},
validate: func(t *testing.T, result interface{}) {
list := result.([]map[string]interface{})
for _, m := range list {
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed from list")
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
Result: tt.result,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
err := ExampleResultFilterHook(ctx)
if err != nil {
t.Errorf("Hook failed: %v", err)
}
if tt.validate != nil {
tt.validate(t, ctx.Result)
}
})
}
}
func TestExampleAuditHook(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
ctx := &HookContext{
Context: context.Background(),
Request: req,
UserContext: &security.UserContext{
UserID: 123,
UserName: "testuser",
},
}
err := ExampleAuditHook(ctx)
if err != nil {
t.Errorf("ExampleAuditHook failed: %v", err)
}
}
func TestExampleErrorHandlingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
Error: fmt.Errorf("test error"),
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleErrorHandlingHook(ctx)
if err != nil {
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
}
}
// TestHookIntegrationWithHandler tests hooks integrated with the handler
func TestHookIntegrationWithHandler(t *testing.T) {
db := &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
queryDB := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User"},
}
return nil
},
}
return fn(queryDB)
},
}
handler := NewHandler(db)
// Register a hook that modifies the SQL query
hookCalled := false
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
hookCalled = true
// Verify we can access context data
if ctx.SQLQuery == "" {
t.Error("Expected SQL query to be set")
}
if ctx.UserContext == nil {
t.Error("Expected user context to be set")
}
return nil
})
// Execute a query
req := createTestRequest("GET", "/test", nil, nil, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
handlerFunc(w, req)
if !hookCalled {
t.Error("Expected hook to be called during query execution")
}
if w.Code != 200 {
t.Errorf("Expected status 200, got %d", w.Code)
}
}

411
pkg/funcspec/parameters.go Normal file
View File

@@ -0,0 +1,411 @@
package funcspec
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// RequestParameters holds parsed parameters from headers and query string
type RequestParameters struct {
// Field selection
SelectFields []string
NotSelectFields []string
Distinct bool
// Filtering
FieldFilters map[string]string // column -> value (exact match)
SearchFilters map[string]string // column -> value (ILIKE)
SearchOps map[string]FilterOperator // column -> {operator, value, logic}
CustomSQLWhere string
CustomSQLOr string
// Sorting & Pagination
SortColumns string
Limit int
Offset int
// Advanced features
SkipCount bool
SkipCache bool
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
ComplexAPI bool // true if NOT simple API
}
// FilterOperator represents a filter with operator
type FilterOperator struct {
Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc.
Value string
Logic string // AND or OR
}
// ParseParameters parses all parameters from request headers and query string
func (h *Handler) ParseParameters(r *http.Request) *RequestParameters {
params := &RequestParameters{
FieldFilters: make(map[string]string),
SearchFilters: make(map[string]string),
SearchOps: make(map[string]FilterOperator),
Limit: 20, // Default limit
Offset: 0, // Default offset
ResponseFormat: "simple", // Default format
ComplexAPI: false, // Default to simple API
}
// Merge headers and query parameters
combined := make(map[string]string)
// Add all headers (normalize to lowercase)
for key, values := range r.Header {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Add all query parameters (override headers)
for key, values := range r.URL.Query() {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Parse each parameter
for key, value := range combined {
// Decode value if base64 encoded
decodedValue := h.decodeValue(value)
switch {
// Field Selection
case strings.HasPrefix(key, "x-select-fields"):
params.SelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-not-select-fields"):
params.NotSelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-distinct"):
params.Distinct = strings.EqualFold(decodedValue, "true")
// Filtering
case strings.HasPrefix(key, "x-fieldfilter-"):
colName := strings.TrimPrefix(key, "x-fieldfilter-")
params.FieldFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchfilter-"):
colName := strings.TrimPrefix(key, "x-searchfilter-")
params.SearchFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(params, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-custom-sql-w"):
if params.CustomSQLWhere != "" {
params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue)
} else {
params.CustomSQLWhere = decodedValue
}
case strings.HasPrefix(key, "x-custom-sql-or"):
if params.CustomSQLOr != "" {
params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue)
} else {
params.CustomSQLOr = decodedValue
}
// Sorting & Pagination
case key == "sort" || strings.HasPrefix(key, "x-sort"):
params.SortColumns = decodedValue
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
// Handle sort(col1,-col2) syntax
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
params.SortColumns = sortValue
case key == "limit" || strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 {
params.Limit = limit
}
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
// Handle limit(offset,limit) or limit(limit) syntax
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
parts := strings.Split(limitValue, ",")
if len(parts) > 1 {
if offset, err := strconv.Atoi(parts[0]); err == nil {
params.Offset = offset
}
if limit, err := strconv.Atoi(parts[1]); err == nil {
params.Limit = limit
}
} else {
if limit, err := strconv.Atoi(parts[0]); err == nil {
params.Limit = limit
}
}
case key == "offset" || strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 {
params.Offset = offset
}
// Advanced features
case strings.HasPrefix(key, "x-skipcount"):
params.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-skipcache"):
params.SkipCache = strings.EqualFold(decodedValue, "true")
// Response Format
case strings.HasPrefix(key, "x-simpleapi"):
params.ResponseFormat = "simple"
params.ComplexAPI = decodedValue != "1" && !strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-detailapi"):
params.ResponseFormat = "detail"
params.ComplexAPI = true
case strings.HasPrefix(key, "x-syncfusion"):
params.ResponseFormat = "syncfusion"
params.ComplexAPI = true
}
}
return params
}
// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column}
func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) {
var prefix string
if logic == "OR" {
prefix = "x-searchor-"
} else {
prefix = "x-searchop-"
if strings.HasPrefix(headerKey, "x-searchand-") {
prefix = "x-searchand-"
}
}
rest := strings.TrimPrefix(headerKey, prefix)
parts := strings.SplitN(rest, "-", 2)
if len(parts) != 2 {
logger.Warn("Invalid search operator header format: %s", headerKey)
return
}
operator := parts[0]
colName := parts[1]
params.SearchOps[colName] = FilterOperator{
Operator: operator,
Value: value,
Logic: logic,
}
logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value)
}
// decodeValue decodes base64 encoded values (ZIP_ or __ prefix)
func (h *Handler) decodeValue(value string) string {
decoded, _ := restheadspec.DecodeParam(value)
return decoded
}
// parseCommaSeparated parses comma-separated values
func (h *Handler) parseCommaSeparated(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
}
}
return result
}
// ApplyFieldSelection applies column selection to SQL query
func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string {
if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 {
return sqlQuery
}
// This is a simplified implementation
// A full implementation would parse the SQL and replace the SELECT clause
// For now, we log a warning that this feature needs manual implementation
if len(params.SelectFields) > 0 {
logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields)
}
if len(params.NotSelectFields) > 0 {
logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields)
}
return sqlQuery
}
// ApplyFilters applies all filters to the SQL query
func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string {
// Apply field filters (exact match)
for colName, value := range params.FieldFilters {
condition := ""
if value == "" || value == "0" {
condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
} else {
condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
}
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied field filter: %s", condition)
}
// Apply search filters (ILIKE)
for colName, value := range params.SearchFilters {
sval := strings.ReplaceAll(value, "'", "")
if sval != "" {
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied search filter: %s", condition)
}
}
// Apply search operators
for colName, filterOp := range params.SearchOps {
condition := h.buildFilterCondition(colName, filterOp)
if condition != "" {
if filterOp.Logic == "OR" {
sqlQuery = sqlQryWhereOr(sqlQuery, condition)
} else {
sqlQuery = sqlQryWhere(sqlQuery, condition)
}
logger.Debug("Applied search operator: %s", condition)
}
}
// Apply custom SQL WHERE
if params.CustomSQLWhere != "" {
colval := ValidSQL(params.CustomSQLWhere, "select")
if colval != "" {
sqlQuery = sqlQryWhere(sqlQuery, colval)
logger.Debug("Applied custom SQL WHERE: %s", colval)
}
}
// Apply custom SQL OR
if params.CustomSQLOr != "" {
colval := ValidSQL(params.CustomSQLOr, "select")
if colval != "" {
sqlQuery = sqlQryWhereOr(sqlQuery, colval)
logger.Debug("Applied custom SQL OR: %s", colval)
}
}
return sqlQuery
}
// buildFilterCondition builds a SQL condition from a FilterOperator
func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string {
safCol := ValidSQL(colName, "colname")
operator := strings.ToLower(op.Operator)
value := op.Value
switch operator {
case "contains", "contain", "like":
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
case "beginswith", "startswith":
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
case "endswith":
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
case "equals", "eq", "=":
if IsNumeric(value) {
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
case "notequals", "neq", "ne", "!=", "<>":
if IsNumeric(value) {
return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue"))
case "greaterthan", "gt", ">":
return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue"))
case "lessthan", "lt", "<":
return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue"))
case "greaterthanorequal", "gte", "ge", ">=":
return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue"))
case "lessthanorequal", "lte", "le", "<=":
return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue"))
case "between":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "betweeninclusive":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "in":
values := strings.Split(value, ",")
safeValues := make([]string, len(values))
for i, v := range values {
safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue"))
}
return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", "))
case "empty", "isnull", "null":
return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol)
case "notempty", "isnotnull", "notnull":
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol)
default:
logger.Warn("Unknown filter operator: %s, defaulting to equals", operator)
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
}
return ""
}
// ApplyDistinct adds DISTINCT to SQL query if requested
func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string {
if !params.Distinct {
return sqlQuery
}
// Add DISTINCT after SELECT
selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT")
if selectPos >= 0 {
beforeSelect := sqlQuery[:selectPos+6] // "SELECT"
afterSelect := sqlQuery[selectPos+6:]
sqlQuery = beforeSelect + " DISTINCT" + afterSelect
logger.Debug("Applied DISTINCT to query")
}
return sqlQuery
}
// sqlQryWhereOr adds a WHERE clause with OR logic
func sqlQryWhereOr(sqlquery, condition string) string {
lowerQuery := strings.ToLower(sqlquery)
wherePos := strings.Index(lowerQuery, " where ")
groupPos := strings.Index(lowerQuery, " group by")
orderPos := strings.Index(lowerQuery, " order by")
limitPos := strings.Index(lowerQuery, " limit ")
// Find the insertion point
insertPos := len(sqlquery)
if groupPos > 0 && groupPos < insertPos {
insertPos = groupPos
}
if orderPos > 0 && orderPos < insertPos {
insertPos = orderPos
}
if limitPos > 0 && limitPos < insertPos {
insertPos = limitPos
}
if wherePos > 0 {
// WHERE exists, add OR condition
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s OR (%s) %s", before, condition, after)
} else {
// No WHERE exists, add it
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
}
}

View File

@@ -0,0 +1,549 @@
package funcspec
import (
"strings"
"testing"
)
// TestParseParameters tests the comprehensive parameter parsing
func TestParseParameters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
headers map[string]string
validate func(t *testing.T, params *RequestParameters)
}{
{
name: "Parse field selection",
headers: map[string]string{
"X-Select-Fields": "id,name,email",
"X-Not-Select-Fields": "password,ssn",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SelectFields) != 3 {
t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields))
}
if len(params.NotSelectFields) != 2 {
t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields))
}
},
},
{
name: "Parse distinct flag",
headers: map[string]string{
"X-Distinct": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.Distinct {
t.Error("Expected Distinct to be true")
}
},
},
{
name: "Parse field filters",
headers: map[string]string{
"X-FieldFilter-Status": "active",
"X-FieldFilter-Type": "admin",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.FieldFilters) != 2 {
t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters))
}
if params.FieldFilters["status"] != "active" {
t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"])
}
},
},
{
name: "Parse search filters",
headers: map[string]string{
"X-SearchFilter-Name": "john",
"X-SearchFilter-Email": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchFilters) != 2 {
t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters))
}
},
},
{
name: "Parse sort columns",
queryParams: map[string]string{
"sort": "-created_at,name",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.SortColumns != "-created_at,name" {
t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns)
}
},
},
{
name: "Parse limit and offset",
queryParams: map[string]string{
"limit": "100",
"offset": "50",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.Limit != 100 {
t.Errorf("Expected limit=100, got %d", params.Limit)
}
if params.Offset != 50 {
t.Errorf("Expected offset=50, got %d", params.Offset)
}
},
},
{
name: "Parse skip count",
headers: map[string]string{
"X-SkipCount": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.SkipCount {
t.Error("Expected SkipCount to be true")
}
},
},
{
name: "Parse response format - syncfusion",
headers: map[string]string{
"X-Syncfusion": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "syncfusion" {
t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat)
}
if !params.ComplexAPI {
t.Error("Expected ComplexAPI to be true for syncfusion format")
}
},
},
{
name: "Parse response format - detail",
headers: map[string]string{
"X-DetailAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "detail" {
t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat)
}
},
},
{
name: "Parse simple API",
headers: map[string]string{
"X-SimpleAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "simple" {
t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat)
}
if params.ComplexAPI {
t.Error("Expected ComplexAPI to be false for simple API")
}
},
},
{
name: "Parse custom SQL WHERE",
headers: map[string]string{
"X-Custom-SQL-W": "status = 'active' AND deleted = false",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set")
}
},
},
{
name: "Parse search operators - AND",
headers: map[string]string{
"X-SearchOp-Eq-Name": "john",
"X-SearchOp-Gt-Age": "18",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchOps) != 2 {
t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps))
}
if op, exists := params.SearchOps["name"]; exists {
if op.Operator != "eq" {
t.Errorf("Expected operator=eq for name, got %s", op.Operator)
}
if op.Logic != "AND" {
t.Errorf("Expected logic=AND, got %s", op.Logic)
}
} else {
t.Error("Expected name search operator to exist")
}
},
},
{
name: "Parse search operators - OR",
headers: map[string]string{
"X-SearchOr-Like-Description": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if op, exists := params.SearchOps["description"]; exists {
if op.Logic != "OR" {
t.Errorf("Expected logic=OR, got %s", op.Logic)
}
} else {
t.Error("Expected description search operator to exist")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
params := handler.ParseParameters(req)
if tt.validate != nil {
tt.validate(t, params)
}
})
}
}
// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
colName string
operator FilterOperator
expected string
}{
{
name: "Equals operator - numeric",
colName: "age",
operator: FilterOperator{
Operator: "eq",
Value: "25",
Logic: "AND",
},
expected: "age = 25",
},
{
name: "Equals operator - string",
colName: "name",
operator: FilterOperator{
Operator: "eq",
Value: "john",
Logic: "AND",
},
expected: "name = 'john'",
},
{
name: "Not equals operator",
colName: "status",
operator: FilterOperator{
Operator: "neq",
Value: "inactive",
Logic: "AND",
},
expected: "status != 'inactive'",
},
{
name: "Greater than operator",
colName: "age",
operator: FilterOperator{
Operator: "gt",
Value: "18",
Logic: "AND",
},
expected: "age > 18",
},
{
name: "Less than operator",
colName: "price",
operator: FilterOperator{
Operator: "lt",
Value: "100",
Logic: "AND",
},
expected: "price < 100",
},
{
name: "Contains operator",
colName: "description",
operator: FilterOperator{
Operator: "contains",
Value: "test",
Logic: "AND",
},
expected: "description ILIKE '%test%'",
},
{
name: "Starts with operator",
colName: "name",
operator: FilterOperator{
Operator: "startswith",
Value: "john",
Logic: "AND",
},
expected: "name ILIKE 'john%'",
},
{
name: "Ends with operator",
colName: "email",
operator: FilterOperator{
Operator: "endswith",
Value: "@example.com",
Logic: "AND",
},
expected: "email ILIKE '%@example.com'",
},
{
name: "Between operator",
colName: "age",
operator: FilterOperator{
Operator: "between",
Value: "18,65",
Logic: "AND",
},
expected: "age > 18 AND age < 65",
},
{
name: "IN operator",
colName: "status",
operator: FilterOperator{
Operator: "in",
Value: "active,pending,approved",
Logic: "AND",
},
expected: "status IN ('active', 'pending', 'approved')",
},
{
name: "IS NULL operator",
colName: "deleted_at",
operator: FilterOperator{
Operator: "null",
Value: "",
Logic: "AND",
},
expected: "(deleted_at IS NULL OR deleted_at = '')",
},
{
name: "IS NOT NULL operator",
colName: "created_at",
operator: FilterOperator{
Operator: "notnull",
Value: "",
Logic: "AND",
},
expected: "(created_at IS NOT NULL AND created_at != '')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.buildFilterCondition(tt.colName, tt.operator)
if result != tt.expected {
t.Errorf("Expected: %s\nGot: %s", tt.expected, result)
}
})
}
}
// TestApplyFilters tests the filter application to SQL queries
func TestApplyFilters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
params *RequestParameters
expectedSQL string
shouldContain []string
}{
{
name: "Apply field filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
FieldFilters: map[string]string{
"status": "active",
},
},
shouldContain: []string{"WHERE", "status"},
},
{
name: "Apply search filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchFilters: map[string]string{
"name": "john",
},
},
shouldContain: []string{"WHERE", "name", "ILIKE"},
},
{
name: "Apply search operators",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchOps: map[string]FilterOperator{
"age": {
Operator: "gt",
Value: "18",
Logic: "AND",
},
},
},
shouldContain: []string{"WHERE", "age", ">", "18"},
},
{
name: "Apply custom SQL WHERE",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
CustomSQLWhere: "deleted = false",
},
shouldContain: []string{"WHERE", "deleted"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.ApplyFilters(tt.sqlQuery, tt.params)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}
// TestApplyDistinct tests DISTINCT application
func TestApplyDistinct(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
distinct bool
shouldHave string
}{
{
name: "Apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: true,
shouldHave: "SELECT DISTINCT",
},
{
name: "Do not apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: false,
shouldHave: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
params := &RequestParameters{Distinct: tt.distinct}
result := handler.ApplyDistinct(tt.sqlQuery, params)
if tt.shouldHave != "" {
if !strings.Contains(result, tt.shouldHave) {
t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result)
}
} else {
// Should not have DISTINCT when not requested
if strings.Contains(result, "DISTINCT") && !tt.distinct {
t.Errorf("SQL should not contain DISTINCT when not requested: %s", result)
}
}
})
}
}
// TestParseCommaSeparated tests comma-separated value parsing
func TestParseCommaSeparated(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
input string
expected []string
}{
{
name: "Simple comma-separated",
input: "id,name,email",
expected: []string{"id", "name", "email"},
},
{
name: "With spaces",
input: "id, name, email",
expected: []string{"id", "name", "email"},
},
{
name: "Empty string",
input: "",
expected: nil,
},
{
name: "Single value",
input: "id",
expected: []string{"id"},
},
{
name: "With extra commas",
input: "id,,name,,email",
expected: []string{"id", "name", "email"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.parseCommaSeparated(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("Expected %d values, got %d", len(tt.expected), len(result))
return
}
for i, expected := range tt.expected {
if result[i] != expected {
t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i])
}
}
})
}
}
// TestSqlQryWhereOr tests OR WHERE clause manipulation
func TestSqlQryWhereOr(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
shouldContain []string
}{
{
name: "Add WHERE with OR to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "status = 'inactive'"},
},
{
name: "Add OR to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhereOr(tt.sqlQuery, tt.condition)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}

View File

@@ -0,0 +1,83 @@
package funcspec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers security hooks for funcspec handlers
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
// We provide audit logging for data access tracking
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 1: BeforeQueryList - Audit logging before query list execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Hook 2: BeforeQuery - Audit logging before single query execution
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Note: Row-level security and column masking are challenging in funcspec
// because the SQL query is fully user-defined. Security should be implemented
// at the SQL function level or through database policies (RLS).
}
// funcSpecSecurityContext adapts funcspec.HookContext to security.SecurityContext interface
type funcSpecSecurityContext struct {
ctx *HookContext
}
func newFuncSpecSecurityContext(ctx *HookContext) security.SecurityContext {
return &funcSpecSecurityContext{ctx: ctx}
}
func (f *funcSpecSecurityContext) GetContext() context.Context {
return f.ctx.Context
}
func (f *funcSpecSecurityContext) GetUserID() (int, bool) {
if f.ctx.UserContext == nil {
return 0, false
}
return int(f.ctx.UserContext.UserID), true
}
func (f *funcSpecSecurityContext) GetSchema() string {
// funcspec doesn't have a schema concept, extract from SQL query or use default
return "public"
}
func (f *funcSpecSecurityContext) GetEntity() string {
// funcspec doesn't have an entity concept, could parse from SQL or use a placeholder
return "sql_query"
}
func (f *funcSpecSecurityContext) GetModel() interface{} {
// funcspec doesn't use models in the same way as restheadspec
return nil
}
func (f *funcSpecSecurityContext) GetQuery() interface{} {
// In funcspec, the query is a string, not a query builder object
return f.ctx.SQLQuery
}
func (f *funcSpecSecurityContext) SetQuery(query interface{}) {
// In funcspec, we could modify the SQL string, but this should be done cautiously
if sqlQuery, ok := query.(string); ok {
f.ctx.SQLQuery = sqlQuery
}
}
func (f *funcSpecSecurityContext) GetResult() interface{} {
return f.ctx.Result
}
func (f *funcSpecSecurityContext) SetResult(result interface{}) {
f.ctx.Result = result
}

View File

@@ -23,6 +23,15 @@ func Init(dev bool) {
}
func UpdateLoggerPath(path string, dev bool) {
defaultConfig := zap.NewProductionConfig()
if dev {
defaultConfig = zap.NewDevelopmentConfig()
}
defaultConfig.OutputPaths = []string{path}
UpdateLogger(&defaultConfig)
}
func UpdateLogger(config *zap.Config) {
defaultConfig := zap.NewProductionConfig()
defaultConfig.OutputPaths = []string{"resolvespec.log"}

View File

@@ -26,8 +26,7 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
}
}()
var lst []ModelFieldDetail
lst = make([]ModelFieldDetail, 0)
lst := make([]ModelFieldDetail, 0)
if !record.IsValid() {
return lst

View File

@@ -17,3 +17,33 @@ func Len(v any) int {
return 0
}
}
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
// dots, it returns everything after the last dot up to the first delimiter.
func ExtractTableNameOnly(fullName string) string {
// First, split by dot to remove schema prefix if present
lastDotIndex := -1
for i, char := range fullName {
if char == '.' {
lastDotIndex = i
}
}
// Start from after the last dot (or from beginning if no dot)
startIndex := 0
if lastDotIndex != -1 {
startIndex = lastDotIndex + 1
}
// Now find the end (first delimiter after the table name)
for i := startIndex; i < len(fullName); i++ {
char := rune(fullName[i])
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
return fullName[startIndex:i]
}
}
return fullName[startIndex:]
}

View File

@@ -1,7 +1,9 @@
package reflection
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
@@ -132,7 +134,7 @@ func findFieldByName(val reflect.Value, name string) any {
}
// Check if field name matches
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
@@ -409,8 +411,8 @@ func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbe
if bunTag != "" {
// Skip if it's a bun relation (rel:, join:, or m2m:)
if strings.Contains(bunTag, "rel:") ||
strings.Contains(bunTag, "join:") ||
strings.Contains(bunTag, "m2m:") {
strings.Contains(bunTag, "join:") ||
strings.Contains(bunTag, "m2m:") {
continue
}
}
@@ -419,9 +421,9 @@ func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbe
if gormTag != "" {
// Skip if it has gorm relationship tags
if strings.Contains(gormTag, "foreignKey:") ||
strings.Contains(gormTag, "references:") ||
strings.Contains(gormTag, "many2many:") ||
strings.Contains(gormTag, "constraint:") {
strings.Contains(gormTag, "references:") ||
strings.Contains(gormTag, "many2many:") ||
strings.Contains(gormTag, "constraint:") {
continue
}
}
@@ -472,7 +474,7 @@ func IsColumnWritable(model any, columnName string) bool {
// isColumnWritableInType recursively searches for a column and checks if it's writable
// Returns (found, writable) where found indicates if the column was found
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
@@ -561,3 +563,321 @@ func isGormFieldReadOnly(tag string) bool {
}
return false
}
// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func ExtractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ToSnakeCase converts a string from CamelCase to snake_case
func ToSnakeCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := ExtractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return reflect.Invalid
}
// Find the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := ToSnakeCase(field.Name)
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}
return reflect.Invalid
}
// IsNumericType checks if a reflect.Kind is a numeric type
func IsNumericType(kind reflect.Kind) bool {
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
}
// IsStringType checks if a reflect.Kind is a string type
func IsStringType(kind reflect.Kind) bool {
return kind == reflect.String
}
// IsNumericValue checks if a string value can be parsed as a number
func IsNumericValue(value string) bool {
value = strings.TrimSpace(value)
_, err := strconv.ParseFloat(value, 64)
return err == nil
}
// ConvertToNumericType converts a string value to the appropriate numeric type
func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
value = strings.TrimSpace(value)
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Parse as integer
bitSize := 64
switch kind {
case reflect.Int8:
bitSize = 8
case reflect.Int16:
bitSize = 16
case reflect.Int32:
bitSize = 32
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Int:
return int(intVal), nil
case reflect.Int8:
return int8(intVal), nil
case reflect.Int16:
return int16(intVal), nil
case reflect.Int32:
return int32(intVal), nil
case reflect.Int64:
return intVal, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as unsigned integer
bitSize := 64
switch kind {
case reflect.Uint8:
bitSize = 8
case reflect.Uint16:
bitSize = 16
case reflect.Uint32:
bitSize = 32
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Uint:
return uint(uintVal), nil
case reflect.Uint8:
return uint8(uintVal), nil
case reflect.Uint16:
return uint16(uintVal), nil
case reflect.Uint32:
return uint32(uintVal), nil
case reflect.Uint64:
return uintVal, nil
}
case reflect.Float32, reflect.Float64:
// Parse as float
bitSize := 64
if kind == reflect.Float32 {
bitSize = 32
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid float value: %w", err)
}
if kind == reflect.Float32 {
return float32(floatVal), nil
}
return floatVal, nil
}
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
}
// GetRelationModel gets the model type for a relation field
// It searches for the field by name in the following order (case-insensitive):
// 1. Actual field name
// 2. Bun tag name (if exists)
// 3. Gorm tag name (if exists)
// 4. JSON tag name (if exists)
//
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
// For nested fields, it traverses through each level of the struct hierarchy
func GetRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
// Split the field name by "." to handle nested/recursive relations
fieldParts := strings.Split(fieldName, ".")
// Start with the current model
currentModel := model
// Traverse through each level of the field path
for _, part := range fieldParts {
if part == "" {
continue
}
currentModel = getRelationModelSingleLevel(currentModel, part)
if currentModel == nil {
return nil
}
}
return currentModel
}
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
// This is a helper function used by GetRelationModel to handle one level at a time
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field by checking in priority order (case-insensitive)
var field *reflect.StructField
normalizedFieldName := strings.ToLower(fieldName)
for i := 0; i < modelType.NumField(); i++ {
f := modelType.Field(i)
// 1. Check actual field name (case-insensitive)
if strings.EqualFold(f.Name, fieldName) {
field = &f
break
}
// 2. Check bun tag name
bunTag := f.Tag.Get("bun")
if bunTag != "" {
bunColName := ExtractColumnFromBunTag(bunTag)
if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) {
field = &f
break
}
}
// 3. Check gorm tag name
gormTag := f.Tag.Get("gorm")
if gormTag != "" {
gormColName := ExtractColumnFromGormTag(gormTag)
if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) {
field = &f
break
}
}
// 4. Check JSON tag name
jsonTag := f.Tag.Get("json")
if jsonTag != "" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
if strings.EqualFold(parts[0], normalizedFieldName) {
field = &f
break
}
}
}
}
if field == nil {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}

View File

@@ -8,17 +8,25 @@ import (
"reflect"
"runtime/debug"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// FallbackHandler is a function that handles requests when no model is found
// It receives the same parameters as the Handle method
type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[string]string)
// Handler handles API requests using database and model abstractions
type Handler struct {
db common.Database
registry common.ModelRegistry
nestedProcessor *common.NestedCUDProcessor
hooks *HookRegistry
fallbackHandler FallbackHandler
}
// NewHandler creates a new API handler with database and registry abstractions
@@ -26,12 +34,31 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
handler := &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
}
// Initialize nested processor
handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler)
return handler
}
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// SetFallbackHandler sets a fallback handler to be called when no model is found
// If not set, the handler will simply return (pass through to next route)
func (h *Handler) SetFallbackHandler(fallback FallbackHandler) {
h.fallbackHandler = fallback
}
// GetDatabase returns the underlying database connection
// Implements common.SpecHandler interface
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
@@ -73,8 +100,14 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Get model and populate context with request-scoped data
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("Invalid entity: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
// Model not found - call fallback handler if set, otherwise pass through
logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return
}
@@ -118,6 +151,8 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
h.handleUpdate(ctx, w, id, req.ID, req.Data, req.Options)
case "delete":
h.handleDelete(ctx, w, id, req.Data)
case "meta":
h.handleMeta(ctx, w, schema, entity, model)
default:
logger.Error("Invalid operation: %s", req.Operation)
h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil)
@@ -140,8 +175,14 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("Failed to get model: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
// Model not found - call fallback handler if set, otherwise pass through
logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return
}
@@ -149,6 +190,21 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
h.sendResponse(w, metadata, nil)
}
// handleMeta processes meta operation requests
func (h *Handler) handleMeta(ctx context.Context, w common.ResponseWriter, schema, entity string, model interface{}) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleMeta", err)
}
}()
logger.Info("Getting metadata for %s.%s via meta operation", schema, entity)
metadata := h.generateMetadata(schema, entity, model)
h.sendResponse(w, metadata, nil)
}
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
@@ -199,7 +255,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply column selection
if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...)
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
}
if len(options.ComputedColumns) > 0 {
@@ -231,13 +289,46 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
// Get total count before pagination
total, err := query.Count(ctx)
if err != nil {
logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return
var total int
// Try to get from cache first
cacheKeyHash := cache.BuildQueryCacheKey(
tableName,
options.Filters,
options.Sort,
"", // No custom SQL WHERE in resolvespec
"", // No custom SQL OR in resolvespec
)
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
var cachedTotal cache.CachedTotal
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err == nil {
total = cachedTotal.Total
logger.Debug("Total records (from cache): %d", total)
} else {
// Cache miss - execute count query
logger.Debug("Cache miss for query total")
count, err := query.Count(ctx)
if err != nil {
logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return
}
total = count
logger.Debug("Total records (from query): %d", total)
// Store in cache
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
logger.Debug("Cached query total with key: %s", cacheKey)
}
}
logger.Debug("Total records before filtering: %d", total)
// Apply pagination
if options.Limit != nil && *options.Limit > 0 {
@@ -1149,6 +1240,11 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
logger.Debug("Applying preload: %s", relationFieldName)
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
preload.Columns = reflection.GetSQLModelColumns(model)
}
// Handle column selection and omission
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetSQLModelColumns(model)
// Remove omitted columns
@@ -1204,7 +1300,10 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
}
if len(preload.Where) > 0 {
sq = sq.Where(preload.Where)
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
}
if preload.Limit != nil && *preload.Limit > 0 {

152
pkg/resolvespec/hooks.go Normal file
View File

@@ -0,0 +1,152 @@
package resolvespec
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
// Create operation hooks
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
// Update operation hooks
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
// Delete operation hooks
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
// Scan/Execute operation hooks (for query building)
BeforeScan HookType = "before_scan"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database, registry, etc.
Schema string
Entity string
Model interface{}
Options common.RequestOptions
Writer common.ResponseWriter
Request common.Request
// Operation-specific fields
ID string
Data interface{} // For create/update operations
Result interface{} // For after hooks
Error error // For after hooks
// Query chain - allows hooks to modify the query before execution
Query common.SelectQuery
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered resolvespec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d resolvespec hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Resolvespec hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Resolvespec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all resolvespec hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all resolvespec hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@@ -2,12 +2,14 @@ package resolvespec
import (
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
"github.com/uptrace/bunrouter"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
@@ -37,28 +39,122 @@ func NewStandardBunRouter() *router.StandardBunRouterAdapter {
return router.NewStandardBunRouterAdapter()
}
// MiddlewareFunc is a function that wraps an http.Handler with additional functionality
type MiddlewareFunc func(http.Handler) http.Handler
// SetupMuxRoutes sets up routes for the ResolveSpec API with Mux
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars)
}).Methods("POST")
// authMiddleware is optional - if provided, routes will be protected with the middleware
// Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
// Get all registered models from the registry
allModels := handler.registry.GetAllModels()
muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars)
}).Methods("POST")
// Loop through each registered model and create explicit routes
for fullName := range allModels {
// Parse the full name (e.g., "public.users" or just "users")
schema, entity := parseModelName(fullName)
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
reqAdapter := router.NewHTTPRequest(r)
// Build the route paths
entityPath := buildRoutePath(schema, entity)
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
// Create handler functions for this specific entity
postEntityHandler := createMuxHandler(handler, schema, entity, "")
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
// Apply authentication middleware if provided
if authMiddleware != nil {
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
}
// Register routes for this entity
muxRouter.Handle(entityPath, postEntityHandler).Methods("POST")
muxRouter.Handle(entityWithIDPath, postEntityWithIDHandler).Methods("POST")
muxRouter.Handle(entityPath, getEntityHandler).Methods("GET")
muxRouter.Handle(entityPath, optionsEntityHandler).Methods("OPTIONS")
muxRouter.Handle(entityWithIDPath, optionsEntityWithIDHandler).Methods("OPTIONS")
}
}
// Helper function to create Mux handler for a specific entity with CORS support
func createMuxHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.Handle(respAdapter, reqAdapter, vars)
}
}
// Helper function to create Mux GET handler for a specific entity with CORS support
func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}).Methods("GET")
}
}
// Helper function to create Mux OPTIONS handler that returns metadata
func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMethods []string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers with the allowed methods for this route
corsConfig := common.DefaultCORSConfig()
corsConfig.AllowedMethods = allowedMethods
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
// Return metadata in the OPTIONS response body
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}
}
// parseModelName parses a model name like "public.users" into schema and entity
// If no schema is present, returns empty string for schema
func parseModelName(fullName string) (schema, entity string) {
parts := strings.Split(fullName, ".")
if len(parts) == 2 {
return parts[0], parts[1]
}
return "", fullName
}
// buildRoutePath builds a route path from schema and entity
// If schema is empty, returns just "/entity", otherwise "/{schema}/{entity}"
func buildRoutePath(schema, entity string) string {
if schema == "" {
return "/" + entity
}
return "/" + schema + "/" + entity
}
// Example usage functions for documentation:
@@ -68,12 +164,20 @@ func ExampleWithGORM(db *gorm.DB) {
// Create handler using GORM
handler := NewHandlerWithGORM(db)
// Setup router
// Setup router without authentication
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler)
SetupMuxRoutes(muxRouter, handler, nil)
// Register models
// handler.RegisterModel("public", "users", &User{})
// To add authentication, pass a middleware function:
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
// return security.NewAuthHandler(secList, h)
// }
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
}
// ExampleWithBun shows how to switch to Bun ORM
@@ -88,60 +192,118 @@ func ExampleWithBun(bunDB *bun.DB) {
// Create handler
handler := NewHandler(dbAdapter, registry)
// Setup routes
// Setup routes without authentication
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler)
SetupMuxRoutes(muxRouter, handler, nil)
}
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
r := bunRouter.GetBunRouter()
r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// Get all registered models from the registry
allModels := handler.registry.GetAllModels()
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// CORS config
corsConfig := common.DefaultCORSConfig()
r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
respAdapter := router.NewHTTPResponseWriter(w)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// Loop through each registered model and create explicit routes
for fullName := range allModels {
// Parse the full name (e.g., "public.users" or just "users")
schema, entity := parseModelName(fullName)
r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
respAdapter := router.NewHTTPResponseWriter(w)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// Build the route paths
entityPath := buildRoutePath(schema, entity)
entityWithIDPath := entityPath + "/:id"
// Create closure variables to capture current schema and entity
currentSchema := schema
currentEntity := entity
// POST route without ID
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// POST route with ID
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// GET route without ID
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// GET route with ID
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route without ID (returns metadata)
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route with ID (returns metadata)
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
}
// ExampleWithBunRouter shows how to use bunrouter from uptrace

View File

@@ -0,0 +1,85 @@
package resolvespec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LoadSecurityRules(secCtx, securityList)
})
// Hook 2: BeforeScan - Apply row-level security filters
handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyRowSecurity(secCtx, securityList)
})
// Hook 3: AfterRead - Apply column-level security (masking)
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyColumnSecurity(secCtx, securityList)
})
// Hook 4 (Optional): Audit logging
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
logger.Info("Security hooks registered for resolvespec handler")
}
// securityContext adapts resolvespec.HookContext to security.SecurityContext interface
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
if q, ok := query.(common.SelectQuery); ok {
s.ctx.Query = q
}
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

View File

@@ -9,12 +9,18 @@ import (
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// FallbackHandler is a function that handles requests when no model is found
// It receives the same parameters as the Handle method
type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[string]string)
// Handler handles API requests using database and model abstractions
// This handler reads filters, columns, and options from HTTP headers
type Handler struct {
@@ -22,6 +28,7 @@ type Handler struct {
registry common.ModelRegistry
hooks *HookRegistry
nestedProcessor *common.NestedCUDProcessor
fallbackHandler FallbackHandler
}
// NewHandler creates a new API handler with database and registry abstractions
@@ -36,12 +43,24 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
return handler
}
// GetDatabase returns the underlying database connection
// Implements common.SpecHandler interface
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// SetFallbackHandler sets a fallback handler to be called when no model is found
// If not set, the handler will simply return (pass through to next route)
func (h *Handler) SetFallbackHandler(fallback FallbackHandler) {
h.fallbackHandler = fallback
}
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
@@ -73,8 +92,14 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Get model and populate context with request-scoped data
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("Invalid entity: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
// Model not found - call fallback handler if set, otherwise pass through
logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return
}
@@ -121,13 +146,25 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
h.handleRead(ctx, w, "", options)
}
case "POST":
// Create operation
// Read request body
body, err := r.Body()
if err != nil {
logger.Error("Failed to read request body: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err)
return
}
// Try to detect if this is a meta operation request
var bodyMap map[string]interface{}
if err := json.Unmarshal(body, &bodyMap); err == nil {
if operation, ok := bodyMap["operation"].(string); ok && operation == "meta" {
logger.Info("Detected meta operation request for %s.%s", schema, entity)
h.handleMeta(ctx, w, schema, entity, model)
return
}
}
// Not a meta operation, proceed with normal create/update
var data interface{}
if err := json.Unmarshal(body, &data); err != nil {
logger.Error("Failed to decode request body: %v", err)
@@ -189,8 +226,14 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("Failed to get model: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
// Model not found - call fallback handler if set, otherwise pass through
logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return
}
@@ -198,6 +241,21 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
h.sendResponse(w, metadata, nil)
}
// handleMeta processes meta operation requests
func (h *Handler) handleMeta(ctx context.Context, w common.ResponseWriter, schema, entity string, model interface{}) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleMeta", err)
}
}()
logger.Info("Getting metadata for %s.%s via meta operation", schema, entity)
metadata := h.generateMetadata(schema, entity, model)
h.sendResponse(w, metadata, nil)
}
// parseOptionsFromHeaders is now implemented in headers.go
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
@@ -213,6 +271,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
tableName := GetTableName(ctx)
model := GetModel(ctx)
if id == "" {
options.SingleRecordAsObject = false
}
// Execute BeforeRead hooks
hookCtx := &HookContext{
Context: ctx,
@@ -299,7 +361,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply column selection
if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...)
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
}
// Apply expand (Just expand to Preload for now)
@@ -360,50 +425,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
preload.Where = fixedWhere
}
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model)
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {
addCols := true
for _, omitCol := range preload.OmitColumns {
if col == omitCol {
addCols = false
break
}
}
if addCols {
preload.Columns = append(preload.Columns, col)
}
}
}
if len(preload.Columns) > 0 {
sq = sq.Column(preload.Columns...)
}
if len(preload.Filters) > 0 {
for _, filter := range preload.Filters {
sq = h.applyFilter(sq, filter, "", false, "AND")
}
}
if len(preload.Sort) > 0 {
for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
}
if len(preload.Where) > 0 {
sq = sq.Where(preload.Where)
}
if preload.Limit != nil && *preload.Limit > 0 {
sq = sq.Limit(*preload.Limit)
}
return sq
})
// Apply the preload with recursive support
query = h.applyPreloadWithRecursion(query, preload, model, 0)
}
// Apply DISTINCT if requested
@@ -433,13 +456,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (AND condition)
if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
query = query.Where(options.CustomSQLWhere)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
if sanitizedWhere != "" {
query = query.Where(sanitizedWhere)
}
}
// Apply custom SQL WHERE clause (OR condition)
if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
query = query.WhereOr(options.CustomSQLOr)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr)
}
}
// If ID is provided, filter by ID
@@ -463,14 +494,69 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Get total count before pagination (unless skip count is requested)
var total int
if !options.SkipCount {
count, err := query.Count(ctx)
if err != nil {
logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return
// Try to get from cache first (unless SkipCache is true)
var cachedTotal *cache.CachedTotal
var cacheKey string
if !options.SkipCache {
// Build cache key from query parameters
// Convert expand options to interface slice for the cache key builder
expandOpts := make([]interface{}, len(options.Expand))
for i, exp := range options.Expand {
expandOpts[i] = map[string]interface{}{
"relation": exp.Relation,
"where": exp.Where,
}
}
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
tableName,
options.Filters,
options.Sort,
options.CustomSQLWhere,
options.CustomSQLOr,
expandOpts,
options.Distinct,
options.CursorForward,
options.CursorBackward,
)
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
cachedTotal = &cache.CachedTotal{}
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
if err == nil {
total = cachedTotal.Total
logger.Debug("Total records (from cache): %d", total)
} else {
logger.Debug("Cache miss for query total")
cachedTotal = nil
}
}
// If not in cache or cache skip, execute count query
if cachedTotal == nil {
count, err := query.Count(ctx)
if err != nil {
logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return
}
total = count
logger.Debug("Total records (from query): %d", total)
// Store in cache (if caching is enabled)
if !options.SkipCache && cacheKey != "" {
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := &cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
logger.Debug("Cached query total with key: %s", cacheKey)
}
}
}
total = count
logger.Debug("Total records: %d", total)
} else {
logger.Debug("Skipping count as requested")
total = -1 // Indicate count was skipped
@@ -515,7 +601,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply cursor filter to query
if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter)
query = query.Where(cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
if sanitizedCursor != "" {
query = query.Where(sanitizedCursor)
}
}
}
@@ -589,6 +678,142 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
h.sendFormattedResponse(w, modelPtr, metadata, options)
}
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
// Log relationship keys if they're specified (from XFiles)
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
preload.Relation, preload.PrimaryKey, preload.RelatedKey, preload.ForeignKey)
// Build a WHERE clause using the relationship keys if needed
// Note: Bun's PreloadRelation typically handles the relationship join automatically via struct tags
// However, if the relationship keys are explicitly provided from XFiles, we can use them
// to add additional filtering or validation
if preload.RelatedKey != "" && preload.Where == "" {
// For child tables: ensure the child's relatedkey column will be matched
// The actual parent value is dynamic and handled by Bun's preload mechanism
// We just log this for visibility
logger.Debug("Child table %s will be filtered by %s matching parent's primary key",
preload.Relation, preload.RelatedKey)
}
if preload.ForeignKey != "" && preload.Where == "" {
// For parent tables: ensure the parent's primary key matches the current table's foreign key
logger.Debug("Parent table %s will be filtered by primary key matching current table's %s",
preload.Relation, preload.ForeignKey)
}
}
// Apply the preload
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
// Get the related model for column operations
relatedModel := reflection.GetRelationModel(model, preload.Relation)
if relatedModel == nil {
logger.Warn("Could not get related model for preload: %s", preload.Relation)
// relatedModel = model // fallback to parent model
} else {
// If we have computed columns but no explicit columns, populate with all model columns first
// since computed columns are additions
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
}
// Apply ComputedQL fields if any
if len(preload.ComputedQL) > 0 {
for colName, colExpr := range preload.ComputedQL {
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
// Remove the computed column from selected columns to avoid duplication
for colIndex := range preload.Columns {
if preload.Columns[colIndex] == colName {
preload.Columns = append(preload.Columns[:colIndex], preload.Columns[colIndex+1:]...)
break
}
}
}
}
// Handle OmitColumns
if len(preload.OmitColumns) > 0 {
allCols := preload.Columns
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {
addCols := true
for _, omitCol := range preload.OmitColumns {
if col == omitCol {
addCols = false
break
}
}
if addCols {
preload.Columns = append(preload.Columns, col)
}
}
}
// Apply column selection
if len(preload.Columns) > 0 {
sq = sq.Column(preload.Columns...)
}
}
// Apply filters
if len(preload.Filters) > 0 {
for _, filter := range preload.Filters {
sq = h.applyFilter(sq, filter, "", false, "AND")
}
}
// Apply sorting
if len(preload.Sort) > 0 {
for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
}
// Apply WHERE clause
if len(preload.Where) > 0 {
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
}
// Apply limit
if preload.Limit != nil && *preload.Limit > 0 {
sq = sq.Limit(*preload.Limit)
}
if preload.Offset != nil && *preload.Offset > 0 {
sq = sq.Offset(*preload.Offset)
}
return sq
})
// Handle recursive preloading
if preload.Recursive && depth < 5 {
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
// For recursive relationships, we need to get the last part of the relation path
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
relationParts := strings.Split(preload.Relation, ".")
lastRelationName := relationParts[len(relationParts)-1]
// Create a recursive preload with the same configuration
// but with the relation path extended
recursivePreload := preload
recursivePreload.Relation = preload.Relation + "." + lastRelationName
// Recursively apply preload until we reach depth 5
query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1)
}
return query
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
@@ -1249,7 +1474,7 @@ func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
func (h *Handler) extractNestedRelations(
data map[string]interface{},
model interface{},
) (map[string]interface{}, map[string]interface{}, error) {
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
// Get model type for reflection
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
@@ -1678,7 +1903,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
if data == nil {
return data
return nil
}
// Use reflection to check if data is a slice or array

View File

@@ -10,6 +10,7 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ExtendedRequestOptions extends common.RequestOptions with additional features
@@ -110,8 +111,8 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
ResponseFormat: "simple", // Default response format
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
ResponseFormat: "simple", // Default response format
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
}
// Get all headers
@@ -122,106 +123,137 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
// Merge headers and query parameters - query parameters take precedence
// This allows the same parameters to be specified in either headers or query string
// Normalize keys to lowercase to ensure query params properly override headers
combinedParams := make(map[string]string)
for key, value := range headers {
combinedParams[key] = value
combinedParams[strings.ToLower(key)] = value
}
for key, value := range queryParams {
combinedParams[key] = value
combinedParams[strings.ToLower(key)] = value
}
// Process each parameter (from both headers and query params)
// Note: keys are already normalized to lowercase in combinedParams
for key, value := range combinedParams {
// Normalize parameter key to lowercase for consistent matching
normalizedKey := strings.ToLower(key)
// Decode value if it's base64 encoded
decodedValue := decodeHeaderValue(value)
// Parse based on parameter prefix/name
switch {
// Field Selection
case strings.HasPrefix(normalizedKey, "x-select-fields"):
case strings.HasPrefix(key, "x-select-fields"):
h.parseSelectFields(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-not-select-fields"):
case strings.HasPrefix(key, "x-not-select-fields"):
h.parseNotSelectFields(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-clean-json"):
case strings.HasPrefix(key, "x-clean-json"):
options.CleanJSON = strings.EqualFold(decodedValue, "true")
// Filtering & Search
case strings.HasPrefix(normalizedKey, "x-fieldfilter-"):
h.parseFieldFilter(&options, normalizedKey, decodedValue)
case strings.HasPrefix(normalizedKey, "x-searchfilter-"):
h.parseSearchFilter(&options, normalizedKey, decodedValue)
case strings.HasPrefix(normalizedKey, "x-searchop-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
case strings.HasPrefix(normalizedKey, "x-searchor-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "OR")
case strings.HasPrefix(normalizedKey, "x-searchand-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
case strings.HasPrefix(normalizedKey, "x-searchcols"):
case strings.HasPrefix(key, "x-fieldfilter-"):
h.parseFieldFilter(&options, key, decodedValue)
case strings.HasPrefix(key, "x-searchfilter-"):
h.parseSearchFilter(&options, key, decodedValue)
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(&options, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(&options, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(&options, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchcols"):
options.SearchColumns = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(normalizedKey, "x-custom-sql-w"):
options.CustomSQLWhere = decodedValue
case strings.HasPrefix(normalizedKey, "x-custom-sql-or"):
options.CustomSQLOr = decodedValue
case strings.HasPrefix(key, "x-custom-sql-w"):
if options.CustomSQLWhere != "" {
options.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", options.CustomSQLWhere, decodedValue)
} else {
options.CustomSQLWhere = decodedValue
}
case strings.HasPrefix(key, "x-custom-sql-or"):
if options.CustomSQLOr != "" {
options.CustomSQLOr = fmt.Sprintf("%s OR (%s)", options.CustomSQLOr, decodedValue)
} else {
options.CustomSQLOr = decodedValue
}
// Joins & Relations
case strings.HasPrefix(normalizedKey, "x-preload"):
if strings.HasSuffix(normalizedKey, "-where") {
case strings.HasPrefix(key, "x-preload"):
if strings.HasSuffix(key, "-where") {
continue
}
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
case strings.HasPrefix(normalizedKey, "x-expand"):
case strings.HasPrefix(key, "x-expand"):
h.parseExpand(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
case strings.HasPrefix(key, "x-custom-sql-join"):
// TODO: Implement custom SQL join
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
// Sorting & Pagination
case strings.HasPrefix(normalizedKey, "x-sort"):
case strings.HasPrefix(key, "x-sort"):
h.parseSorting(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-limit"):
// Special cases for older clients using sort(a,b,-c) syntax
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
h.parseSorting(&options, sortValue)
case strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil {
options.Limit = &limit
}
case strings.HasPrefix(normalizedKey, "x-offset"):
// Special cases for older clients using limit(n) syntax
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
limitValueParts := strings.Split(limitValue, ",")
if len(limitValueParts) > 1 {
if offset, err := strconv.Atoi(limitValueParts[0]); err == nil {
options.Offset = &offset
}
if limit, err := strconv.Atoi(limitValueParts[1]); err == nil {
options.Limit = &limit
}
} else {
if limit, err := strconv.Atoi(limitValueParts[0]); err == nil {
options.Limit = &limit
}
}
case strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil {
options.Offset = &offset
}
case strings.HasPrefix(normalizedKey, "x-cursor-forward"):
case strings.HasPrefix(key, "x-cursor-forward"):
options.CursorForward = decodedValue
case strings.HasPrefix(normalizedKey, "x-cursor-backward"):
case strings.HasPrefix(key, "x-cursor-backward"):
options.CursorBackward = decodedValue
// Advanced Features
case strings.HasPrefix(normalizedKey, "x-advsql-"):
colName := strings.TrimPrefix(normalizedKey, "x-advsql-")
case strings.HasPrefix(key, "x-advsql-"):
colName := strings.TrimPrefix(key, "x-advsql-")
options.AdvancedSQL[colName] = decodedValue
case strings.HasPrefix(normalizedKey, "x-cql-sel-"):
colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-")
case strings.HasPrefix(key, "x-cql-sel-"):
colName := strings.TrimPrefix(key, "x-cql-sel-")
options.ComputedQL[colName] = decodedValue
case strings.HasPrefix(normalizedKey, "x-distinct"):
case strings.HasPrefix(key, "x-distinct"):
options.Distinct = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-skipcount"):
case strings.HasPrefix(key, "x-skipcount"):
options.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-skipcache"):
case strings.HasPrefix(key, "x-skipcache"):
options.SkipCache = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"):
case strings.HasPrefix(key, "x-fetch-rownumber"):
options.FetchRowNumber = &decodedValue
case strings.HasPrefix(normalizedKey, "x-pkrow"):
case strings.HasPrefix(key, "x-pkrow"):
options.PKRow = &decodedValue
// Response Format
case strings.HasPrefix(normalizedKey, "x-simpleapi"):
case strings.HasPrefix(key, "x-simpleapi"):
options.ResponseFormat = "simple"
case strings.HasPrefix(normalizedKey, "x-detailapi"):
case strings.HasPrefix(key, "x-detailapi"):
options.ResponseFormat = "detail"
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
case strings.HasPrefix(key, "x-syncfusion"):
options.ResponseFormat = "syncfusion"
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
case strings.HasPrefix(key, "x-single-record-as-object"):
// Parse as boolean - "false" disables, "true" enables (default is true)
if strings.EqualFold(decodedValue, "false") {
options.SingleRecordAsObject = false
@@ -230,11 +262,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
// Transaction Control
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
case strings.HasPrefix(key, "x-transaction-atomic"):
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
// X-Files - comprehensive JSON configuration
case strings.HasPrefix(normalizedKey, "x-files"):
case strings.HasPrefix(key, "x-files"):
h.parseXFiles(&options, decodedValue)
}
}
@@ -244,6 +276,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
h.resolveRelationNamesInOptions(&options, model)
}
// Always sort according to the primary key if no sorting is specified
if len(options.Sort) == 0 {
pkName := reflection.GetPrimaryKeyName(model)
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
return options
}
@@ -697,7 +735,7 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Try to get the model type for the next level
// This allows nested resolution
if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil {
if nextModel := reflection.GetRelationModel(currentModel, resolvedPart); nextModel != nil {
currentModel = nextModel
}
}
@@ -721,58 +759,6 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
}
}
// getRelationModel gets the model type for a relation field
func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field
field, found := modelType.FieldByName(fieldName)
if !found {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}
// resolveRelationName resolves a relation name or table name to the actual field name in the model
// If the input is already a field name, it returns it as-is
// If the input is a table name, it looks up the corresponding relation field
@@ -806,7 +792,7 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
field := modelType.Field(i)
if field.Name == nameOrTable {
// It's already a field name
logger.Debug("Input '%s' is a field name", nameOrTable)
// logger.Debug("Input '%s' is a field name", nameOrTable)
return nameOrTable
}
}
@@ -935,6 +921,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// Add computed columns (CQL) -> ComputedQL
if len(xfile.CQLColumns) > 0 {
preloadOpt.ComputedQL = make(map[string]string)
for i, cqlExpr := range xfile.CQLColumns {
colName := fmt.Sprintf("cql%d", i+1)
preloadOpt.ComputedQL[colName] = cqlExpr
logger.Debug("X-Files: Added computed column %s to preload %s: %s", colName, relationPath, cqlExpr)
}
}
// Set recursive flag
preloadOpt.Recursive = xfile.Recursive
// Extract relationship keys for proper foreign key filtering
if xfile.PrimaryKey != "" {
preloadOpt.PrimaryKey = xfile.PrimaryKey
logger.Debug("X-Files: Set primary key for %s: %s", relationPath, xfile.PrimaryKey)
}
if xfile.RelatedKey != "" {
preloadOpt.RelatedKey = xfile.RelatedKey
logger.Debug("X-Files: Set related key for %s: %s", relationPath, xfile.RelatedKey)
}
if xfile.ForeignKey != "" {
preloadOpt.ForeignKey = xfile.ForeignKey
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
}
// Add the preload option
options.Preload = append(options.Preload, preloadOpt)
@@ -947,192 +960,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := extractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return reflect.Invalid
}
// Find the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := toSnakeCase(field.Name)
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}
return reflect.Invalid
}
// toSnakeCase converts a string from CamelCase to snake_case
func toSnakeCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// isNumericType checks if a reflect.Kind is a numeric type
func isNumericType(kind reflect.Kind) bool {
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
}
// isStringType checks if a reflect.Kind is a string type
func isStringType(kind reflect.Kind) bool {
return kind == reflect.String
}
// convertToNumericType converts a string value to the appropriate numeric type
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
value = strings.TrimSpace(value)
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Parse as integer
bitSize := 64
switch kind {
case reflect.Int8:
bitSize = 8
case reflect.Int16:
bitSize = 16
case reflect.Int32:
bitSize = 32
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Int:
return int(intVal), nil
case reflect.Int8:
return int8(intVal), nil
case reflect.Int16:
return int16(intVal), nil
case reflect.Int32:
return int32(intVal), nil
case reflect.Int64:
return intVal, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as unsigned integer
bitSize := 64
switch kind {
case reflect.Uint8:
bitSize = 8
case reflect.Uint16:
bitSize = 16
case reflect.Uint32:
bitSize = 32
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Uint:
return uint(uintVal), nil
case reflect.Uint8:
return uint8(uintVal), nil
case reflect.Uint16:
return uint16(uintVal), nil
case reflect.Uint32:
return uint32(uintVal), nil
case reflect.Uint64:
return uintVal, nil
}
case reflect.Float32, reflect.Float64:
// Parse as float
bitSize := 64
if kind == reflect.Float32 {
bitSize = 32
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid float value: %w", err)
}
if kind == reflect.Float32 {
return float32(floatVal), nil
}
return floatVal, nil
}
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
}
// isNumericValue checks if a string value can be parsed as a number
func isNumericValue(value string) bool {
value = strings.TrimSpace(value)
_, err := strconv.ParseFloat(value, 64)
return err == nil
}
// ColumnCastInfo holds information about whether a column needs casting
type ColumnCastInfo struct {
NeedsCast bool
@@ -1146,7 +973,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
}
colType := h.getColumnTypeFromModel(model, filter.Column)
colType := reflection.GetColumnTypeFromModel(model, filter.Column)
if colType == reflect.Invalid {
// Column not found in model, no casting needed
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
@@ -1157,18 +984,18 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
valueIsNumeric := false
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
valueIsNumeric = isNumericValue(strVal)
valueIsNumeric = reflection.IsNumericValue(strVal)
}
// Adjust based on column type
switch {
case isNumericType(colType):
case reflection.IsNumericType(colType):
// Column is numeric
if valueIsNumeric {
// Value is numeric - try to convert it
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
numericVal, err := convertToNumericType(strVal, colType)
numericVal, err := reflection.ConvertToNumericType(strVal, colType)
if err != nil {
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
@@ -1183,7 +1010,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
}
case isStringType(colType):
case reflection.IsStringType(colType):
// String columns don't need casting
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}

View File

@@ -1,6 +1,7 @@
package restheadspec
import (
"net/http"
"testing"
)
@@ -42,6 +43,12 @@ func (m *MockRequest) AllQueryParams() map[string]string {
return m.queryParams
}
func (m *MockRequest) UnderlyingRequest() *http.Request {
// For testing purposes, return nil
// In real scenarios, you might want to construct a proper http.Request
return nil
}
func TestParseOptionsFromQueryParams(t *testing.T) {
handler := NewHandler(nil, nil)

View File

@@ -54,12 +54,14 @@ package restheadspec
import (
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
"github.com/uptrace/bunrouter"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -90,31 +92,130 @@ func NewStandardBunRouter() *router.StandardBunRouterAdapter {
return router.NewStandardBunRouterAdapter()
}
// MiddlewareFunc is a function that wraps an http.Handler with additional functionality
type MiddlewareFunc func(http.Handler) http.Handler
// SetupMuxRoutes sets up routes for the RestHeadSpec API with Mux
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars)
}).Methods("GET", "POST")
// authMiddleware is optional - if provided, routes will be protected with the middleware
// Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
// Get all registered models from the registry
allModels := handler.registry.GetAllModels()
// GET, PUT, PATCH, DELETE for /{schema}/{entity}/{id}
muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars)
}).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
// Loop through each registered model and create explicit routes
for fullName := range allModels {
// Parse the full name (e.g., "public.users" or just "users")
schema, entity := parseModelName(fullName)
// GET for metadata (using HandleGet)
muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
reqAdapter := router.NewHTTPRequest(r)
// Build the route paths
entityPath := buildRoutePath(schema, entity)
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
metadataPath := buildRoutePath(schema, entity) + "/metadata"
// Create handler functions for this specific entity
entityHandler := createMuxHandler(handler, schema, entity, "")
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
// Apply authentication middleware if provided
if authMiddleware != nil {
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
}
// Register routes for this entity
// GET, POST for /{schema}/{entity}
muxRouter.Handle(entityPath, entityHandler).Methods("GET", "POST")
// GET, PUT, PATCH, DELETE, POST for /{schema}/{entity}/{id}
muxRouter.Handle(entityWithIDPath, entityWithIDHandler).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
// GET for metadata (using HandleGet)
muxRouter.Handle(metadataPath, metadataHandler).Methods("GET")
// OPTIONS for CORS preflight - returns metadata
muxRouter.Handle(entityPath, optionsEntityHandler).Methods("OPTIONS")
muxRouter.Handle(entityWithIDPath, optionsEntityWithIDHandler).Methods("OPTIONS")
}
}
// Helper function to create Mux handler for a specific entity with CORS support
func createMuxHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.Handle(respAdapter, reqAdapter, vars)
}
}
// Helper function to create Mux GET handler for a specific entity with CORS support
func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}).Methods("GET")
}
}
// Helper function to create Mux OPTIONS handler that returns metadata
func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMethods []string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers with the allowed methods for this route
corsConfig := common.DefaultCORSConfig()
corsConfig.AllowedMethods = allowedMethods
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
// Return metadata in the OPTIONS response body
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}
}
// parseModelName parses a model name like "public.users" into schema and entity
// If no schema is present, returns empty string for schema
func parseModelName(fullName string) (schema, entity string) {
parts := strings.Split(fullName, ".")
if len(parts) == 2 {
return parts[0], parts[1]
}
return "", fullName
}
// buildRoutePath builds a route path from schema and entity
// If schema is empty, returns just "/entity", otherwise "/{schema}/{entity}"
func buildRoutePath(schema, entity string) string {
if schema == "" {
return "/" + entity
}
return "/" + schema + "/" + entity
}
// Example usage functions for documentation:
@@ -124,12 +225,20 @@ func ExampleWithGORM(db *gorm.DB) {
// Create handler using GORM
handler := NewHandlerWithGORM(db)
// Setup router
// Setup router without authentication
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler)
SetupMuxRoutes(muxRouter, handler, nil)
// Register models
// handler.registry.RegisterModel("public.users", &User{})
// To add authentication, pass a middleware function:
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
// return security.NewAuthHandler(secList, h)
// }
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
}
// ExampleWithBun shows how to switch to Bun ORM
@@ -144,110 +253,169 @@ func ExampleWithBun(bunDB *bun.DB) {
// Create handler
handler := NewHandler(dbAdapter, registry)
// Setup routes
// Setup routes without authentication
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler)
SetupMuxRoutes(muxRouter, handler, nil)
}
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
r := bunRouter.GetBunRouter()
// GET and POST for /:schema/:entity
r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// Get all registered models from the registry
allModels := handler.registry.GetAllModels()
r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// CORS config
corsConfig := common.DefaultCORSConfig()
// GET, PUT, PATCH, DELETE for /:schema/:entity/:id
r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// Loop through each registered model and create explicit routes
for fullName := range allModels {
// Parse the full name (e.g., "public.users" or just "users")
schema, entity := parseModelName(fullName)
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// Build the route paths
entityPath := buildRoutePath(schema, entity)
entityWithIDPath := entityPath + "/:id"
metadataPath := entityPath + "/metadata"
r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// Create closure variables to capture current schema and entity
currentSchema := schema
currentEntity := entity
r.Handle("PATCH", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// GET and POST for /{schema}/{entity}
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("DELETE", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// Metadata endpoint
r.Handle("GET", "/:schema/:entity/metadata", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// Metadata endpoint
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route without ID (returns metadata)
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route with ID (returns metadata)
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewBunRouterRequest(req)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
}
// ExampleBunRouterWithBunDB shows usage with both BunRouter and Bun DB

View File

@@ -0,0 +1,82 @@
package restheadspec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LoadSecurityRules(secCtx, securityList)
})
// Hook 2: BeforeScan - Apply row-level security filters
handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyRowSecurity(secCtx, securityList)
})
// Hook 3: AfterRead - Apply column-level security (masking)
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyColumnSecurity(secCtx, securityList)
})
// Hook 4 (Optional): Audit logging
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
logger.Info("Security hooks registered for restheadspec handler")
}
// securityContext adapts restheadspec.HookContext to security.SecurityContext interface
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
s.ctx.Query = query
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

View File

@@ -1,662 +0,0 @@
# Security Provider Callbacks Guide
## Overview
The ResolveSpec security provider uses a **callback-based architecture** that requires you to implement three functions:
1. **AuthenticateCallback** - Extract user credentials from HTTP requests
2. **LoadColumnSecurityCallback** - Load column security rules for masking/hiding
3. **LoadRowSecurityCallback** - Load row security filters (WHERE clauses)
This design allows you to integrate the security provider with **any** authentication system and database schema.
---
## Why Callbacks?
The callback-based design provides:
**Flexibility** - Works with any auth system (JWT, session, OAuth, custom)
**Database Agnostic** - No assumptions about your security table schema
**Testability** - Easy to mock for unit tests
**Extensibility** - Add custom logic without modifying core code
---
## Quick Start
### Step 1: Implement the Three Callbacks
```go
package main
import (
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// 1. Authentication: Extract user from request
func myAuthFunction(r *http.Request) (userID int, roles string, err error) {
// Your auth logic here (JWT, session, header, etc.)
token := r.Header.Get("Authorization")
userID, roles, err = validateToken(token)
return userID, roles, err
}
// 2. Column Security: Load column masking rules
func myLoadColumnSecurity(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
// Your database query or config lookup here
return loadColumnRulesFromDatabase(userID, schema, tablename)
}
// 3. Row Security: Load row filtering rules
func myLoadRowSecurity(userID int, schema, tablename string) (security.RowSecurity, error) {
// Your database query or config lookup here
return loadRowRulesFromDatabase(userID, schema, tablename)
}
```
### Step 2: Configure the Callbacks
```go
func main() {
db := setupDatabase()
handler := restheadspec.NewHandlerWithGORM(db)
// Configure callbacks BEFORE SetupSecurityProvider
security.GlobalSecurity.AuthenticateCallback = myAuthFunction
security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurity
security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurity
// Setup security provider (validates callbacks are set)
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
log.Fatal(err) // Fails if callbacks not configured
}
// Apply middleware
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
http.ListenAndServe(":8080", router)
}
```
---
## Callback 1: AuthenticateCallback
### Function Signature
```go
func(r *http.Request) (userID int, roles string, err error)
```
### Parameters
- `r *http.Request` - The incoming HTTP request
### Returns
- `userID int` - The authenticated user's ID
- `roles string` - User's roles (comma-separated, e.g., "admin,manager")
- `err error` - Return error to reject the request (HTTP 401)
### Example Implementations
#### Simple Header-Based Auth
```go
func authenticateFromHeader(r *http.Request) (int, string, error) {
userIDStr := r.Header.Get("X-User-ID")
if userIDStr == "" {
return 0, "", fmt.Errorf("X-User-ID header required")
}
userID, err := strconv.Atoi(userIDStr)
if err != nil {
return 0, "", fmt.Errorf("invalid user ID")
}
roles := r.Header.Get("X-User-Roles") // Optional
return userID, roles, nil
}
```
#### JWT Token Auth
```go
import "github.com/golang-jwt/jwt/v5"
func authenticateFromJWT(r *http.Request) (int, string, error) {
authHeader := r.Header.Get("Authorization")
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(os.Getenv("JWT_SECRET")), nil
})
if err != nil || !token.Valid {
return 0, "", fmt.Errorf("invalid token")
}
claims := token.Claims.(jwt.MapClaims)
userID := int(claims["user_id"].(float64))
roles := claims["roles"].(string)
return userID, roles, nil
}
```
#### Session Cookie Auth
```go
func authenticateFromSession(r *http.Request) (int, string, error) {
cookie, err := r.Cookie("session_id")
if err != nil {
return 0, "", fmt.Errorf("no session cookie")
}
session, err := sessionStore.Get(cookie.Value)
if err != nil {
return 0, "", fmt.Errorf("invalid session")
}
return session.UserID, session.Roles, nil
}
```
---
## Callback 2: LoadColumnSecurityCallback
### Function Signature
```go
func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error)
```
### Parameters
- `pUserID int` - The authenticated user's ID
- `pSchema string` - Database schema (e.g., "public")
- `pTablename string` - Table name (e.g., "employees")
### Returns
- `[]ColumnSecurity` - List of column security rules
- `error` - Return error if loading fails
### ColumnSecurity Structure
```go
type ColumnSecurity struct {
Schema string // "public"
Tablename string // "employees"
Path []string // ["ssn"] or ["address", "street"]
Accesstype string // "mask" or "hide"
// Masking configuration (for Accesstype = "mask")
MaskStart int // Mask first N characters
MaskEnd int // Mask last N characters
MaskInvert bool // true = mask middle, false = mask edges
MaskChar string // Character to use for masking (default "*")
// Optional fields
ExtraFilters map[string]string
Control string
ID int
UserID int
}
```
### Example Implementations
#### Load from Database
```go
func loadColumnSecurityFromDB(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
var rules []security.ColumnSecurity
query := `
SELECT control, accesstype, jsonvalue
FROM core.secacces
WHERE rid_hub IN (
SELECT rid_hub_parent FROM core.hub_link
WHERE rid_hub_child = ? AND parent_hubtype = 'secgroup'
)
AND control ILIKE ?
`
rows, err := db.Query(query, userID, fmt.Sprintf("%s.%s%%", schema, tablename))
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var control, accesstype, jsonValue string
rows.Scan(&control, &accesstype, &jsonValue)
// Parse control: "schema.table.column"
parts := strings.Split(control, ".")
if len(parts) < 3 {
continue
}
rule := security.ColumnSecurity{
Schema: schema,
Tablename: tablename,
Path: parts[2:],
Accesstype: accesstype,
}
// Parse JSON configuration
var config map[string]interface{}
json.Unmarshal([]byte(jsonValue), &config)
if start, ok := config["start"].(float64); ok {
rule.MaskStart = int(start)
}
if end, ok := config["end"].(float64); ok {
rule.MaskEnd = int(end)
}
if char, ok := config["char"].(string); ok {
rule.MaskChar = char
}
rules = append(rules, rule)
}
return rules, nil
}
```
#### Load from Static Config
```go
func loadColumnSecurityFromConfig(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
// Define security rules in code
allRules := map[string][]security.ColumnSecurity{
"public.employees": {
{
Schema: "public",
Tablename: "employees",
Path: []string{"ssn"},
Accesstype: "mask",
MaskStart: 5,
MaskChar: "*",
},
{
Schema: "public",
Tablename: "employees",
Path: []string{"salary"},
Accesstype: "hide",
},
},
}
key := fmt.Sprintf("%s.%s", schema, tablename)
rules, ok := allRules[key]
if !ok {
return []security.ColumnSecurity{}, nil // No rules
}
return rules, nil
}
```
### Column Security Examples
**Mask SSN (show last 4 digits):**
```go
ColumnSecurity{
Path: []string{"ssn"},
Accesstype: "mask",
MaskStart: 5, // Mask first 5 characters
MaskEnd: 0, // Keep last 4 visible
MaskChar: "*",
}
// Result: "123-45-6789" → "*****6789"
```
**Hide entire field:**
```go
ColumnSecurity{
Path: []string{"salary"},
Accesstype: "hide",
}
// Result: salary field returns 0 or empty
```
**Mask credit card (show last 4 digits):**
```go
ColumnSecurity{
Path: []string{"credit_card"},
Accesstype: "mask",
MaskStart: 12,
MaskChar: "*",
}
// Result: "1234-5678-9012-3456" → "************3456"
```
---
## Callback 3: LoadRowSecurityCallback
### Function Signature
```go
func(pUserID int, pSchema, pTablename string) (RowSecurity, error)
```
### Parameters
- `pUserID int` - The authenticated user's ID
- `pSchema string` - Database schema
- `pTablename string` - Table name
### Returns
- `RowSecurity` - Row security configuration
- `error` - Return error if loading fails
### RowSecurity Structure
```go
type RowSecurity struct {
Schema string // "public"
Tablename string // "orders"
UserID int // Current user ID
Template string // WHERE clause template (e.g., "user_id = {UserID}")
HasBlock bool // If true, block ALL access to this table
}
```
### Template Variables
You can use these placeholders in the `Template` string:
- `{UserID}` - Current user's ID
- `{PrimaryKeyName}` - Primary key column name
- `{TableName}` - Table name
- `{SchemaName}` - Schema name
### Example Implementations
#### Load from Database Function
```go
func loadRowSecurityFromDB(userID int, schema, tablename string) (security.RowSecurity, error) {
var record security.RowSecurity
query := `
SELECT p_template, p_block
FROM core.api_sec_rowtemplate(?, ?, ?)
`
row := db.QueryRow(query, schema, tablename, userID)
err := row.Scan(&record.Template, &record.HasBlock)
if err != nil {
return security.RowSecurity{}, err
}
record.Schema = schema
record.Tablename = tablename
record.UserID = userID
return record, nil
}
```
#### Load from Static Config
```go
func loadRowSecurityFromConfig(userID int, schema, tablename string) (security.RowSecurity, error) {
key := fmt.Sprintf("%s.%s", schema, tablename)
// Define templates for each table
templates := map[string]string{
"public.orders": "user_id = {UserID}",
"public.documents": "user_id = {UserID} OR is_public = true",
}
// Define blocked tables
blocked := map[string]bool{
"public.admin_logs": true,
}
if blocked[key] {
return security.RowSecurity{
Schema: schema,
Tablename: tablename,
UserID: userID,
HasBlock: true,
}, nil
}
template, ok := templates[key]
if !ok {
// No row security - allow all rows
return security.RowSecurity{
Schema: schema,
Tablename: tablename,
UserID: userID,
Template: "",
HasBlock: false,
}, nil
}
return security.RowSecurity{
Schema: schema,
Tablename: tablename,
UserID: userID,
Template: template,
HasBlock: false,
}, nil
}
```
### Row Security Examples
**Users see only their own records:**
```go
RowSecurity{
Template: "user_id = {UserID}",
}
// Query: SELECT * FROM orders WHERE user_id = 123
```
**Users see their records OR public records:**
```go
RowSecurity{
Template: "user_id = {UserID} OR is_public = true",
}
```
**Complex filter with subquery:**
```go
RowSecurity{
Template: "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})",
}
```
**Block all access:**
```go
RowSecurity{
HasBlock: true,
}
// All queries to this table will be rejected
```
---
## Complete Integration Example
```go
package main
import (
"fmt"
"log"
"net/http"
"strconv"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/bitechdev/ResolveSpec/pkg/security"
"github.com/gorilla/mux"
"gorm.io/gorm"
)
func main() {
db := setupDatabase()
handler := restheadspec.NewHandlerWithGORM(db)
handler.RegisterModel("public", "orders", Order{})
// ===== CONFIGURE CALLBACKS =====
security.GlobalSecurity.AuthenticateCallback = authenticateUser
security.GlobalSecurity.LoadColumnSecurityCallback = loadColumnSec
security.GlobalSecurity.LoadRowSecurityCallback = loadRowSec
// ===== SETUP SECURITY =====
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
log.Fatal("Security setup failed:", err)
}
// ===== SETUP ROUTES =====
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
log.Println("Server starting on :8080")
http.ListenAndServe(":8080", router)
}
// Callback implementations
func authenticateUser(r *http.Request) (int, string, error) {
userIDStr := r.Header.Get("X-User-ID")
if userIDStr == "" {
return 0, "", fmt.Errorf("authentication required")
}
userID, err := strconv.Atoi(userIDStr)
return userID, "", err
}
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
// Your implementation here
return []security.ColumnSecurity{}, nil
}
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
return security.RowSecurity{
Schema: schema,
Tablename: table,
UserID: userID,
Template: "user_id = " + strconv.Itoa(userID),
}, nil
}
```
---
## Testing Your Callbacks
### Unit Test Example
```go
func TestAuthCallback(t *testing.T) {
req := httptest.NewRequest("GET", "/api/orders", nil)
req.Header.Set("X-User-ID", "123")
userID, roles, err := myAuthFunction(req)
assert.Nil(t, err)
assert.Equal(t, 123, userID)
}
func TestColumnSecurityCallback(t *testing.T) {
rules, err := myLoadColumnSecurity(123, "public", "employees")
assert.Nil(t, err)
assert.Greater(t, len(rules), 0)
assert.Equal(t, "mask", rules[0].Accesstype)
}
```
---
## Common Patterns
### Pattern 1: Role-Based Security
```go
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
roles := getUserRoles(userID)
if contains(roles, "admin") {
// Admins see everything
return []security.ColumnSecurity{}, nil
}
// Non-admins have restrictions
return []security.ColumnSecurity{
{Path: []string{"ssn"}, Accesstype: "mask"},
}, nil
}
```
### Pattern 2: Tenant Isolation
```go
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
tenantID := getUserTenant(userID)
return security.RowSecurity{
Template: fmt.Sprintf("tenant_id = %d", tenantID),
}, nil
}
```
### Pattern 3: Caching Security Rules
```go
var securityCache = cache.New(5*time.Minute, 10*time.Minute)
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
cacheKey := fmt.Sprintf("%d:%s.%s", userID, schema, table)
if cached, found := securityCache.Get(cacheKey); found {
return cached.([]security.ColumnSecurity), nil
}
rules := loadFromDatabase(userID, schema, table)
securityCache.Set(cacheKey, rules, cache.DefaultExpiration)
return rules, nil
}
```
---
## Troubleshooting
### Error: "AuthenticateCallback not set"
**Solution:** Configure all three callbacks before calling `SetupSecurityProvider`:
```go
security.GlobalSecurity.AuthenticateCallback = myAuthFunc
security.GlobalSecurity.LoadColumnSecurityCallback = myColSecFunc
security.GlobalSecurity.LoadRowSecurityCallback = myRowSecFunc
```
### Error: "Authentication failed"
**Solution:** Check your `AuthenticateCallback` implementation. Ensure it returns valid user ID or proper error.
### Security rules not applying
**Solution:**
1. Check callbacks are returning data
2. Enable debug logging
3. Verify database queries return results
4. Check user has security groups assigned
---
## Next Steps
1. ✅ Implement the three callbacks for your system
2. ✅ Configure `GlobalSecurity` with your callbacks
3. ✅ Call `SetupSecurityProvider`
4. ✅ Test with different users and verify isolation
5. ✅ Review `callbacks_example.go` for more examples
For complete working examples, see:
- `pkg/security/callbacks_example.go` - 7 example implementations
- `examples/secure_server/main.go` - Full server example
- `pkg/security/README.md` - Comprehensive documentation

View File

@@ -3,35 +3,97 @@
## 3-Step Setup
```go
// Step 1: Implement callbacks
func myAuth(r *http.Request) (int, string, error) { /* ... */ }
func myColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { /* ... */ }
func myRowSec(userID int, schema, table string) (security.RowSecurity, error) { /* ... */ }
// Step 1: Create security providers
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
// OR: auth := security.NewJWTAuthenticator("secret-key", db)
// OR: auth := security.NewHeaderAuthenticator()
// Step 2: Configure callbacks
security.GlobalSecurity.AuthenticateCallback = myAuth
security.GlobalSecurity.LoadColumnSecurityCallback = myColSec
security.GlobalSecurity.LoadRowSecurityCallback = myRowSec
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
// Step 2: Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
// Step 3: Setup and apply middleware
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
securityList := security.SetupSecurityProvider(handler, provider)
router.Use(security.NewAuthMiddleware(securityList))
router.Use(security.SetSecurityMiddleware(securityList))
```
---
## Callback Signatures
## Stored Procedures
**All database operations use PostgreSQL stored procedures** with `resolvespec_*` naming:
### Database Authenticators
```go
// DatabaseAuthenticator uses these stored procedures:
resolvespec_login(jsonb) // Login with credentials
resolvespec_logout(jsonb) // Invalidate session
resolvespec_session(text, text) // Validate session token
resolvespec_session_update(text, jsonb) // Update activity timestamp
resolvespec_refresh_token(text, jsonb) // Generate new session
// JWTAuthenticator uses these stored procedures:
resolvespec_jwt_login(text, text) // Validate credentials
resolvespec_jwt_logout(text, int) // Blacklist token
```
### Security Providers
```go
// DatabaseColumnSecurityProvider:
resolvespec_column_security(int, text, text) // Load column rules
// DatabaseRowSecurityProvider:
resolvespec_row_security(text, text, int) // Load row template
```
All stored procedures return structured results:
- Session/Login: `(p_success bool, p_error text, p_data jsonb)`
- Security: `(p_success bool, p_error text, p_rules jsonb)`
See `database_schema.sql` for complete definitions.
---
## Interface Signatures
```go
// 1. Authentication
func(r *http.Request) (userID int, roles string, err error)
// Authenticator interface
type Authenticator interface {
Login(ctx context.Context, req LoginRequest) (*LoginResponse, error)
Logout(ctx context.Context, req LogoutRequest) error
Authenticate(r *http.Request) (*UserContext, error)
}
// 2. Column Security
func(userID int, schema, tablename string) ([]ColumnSecurity, error)
// ColumnSecurityProvider interface
type ColumnSecurityProvider interface {
GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error)
}
// 3. Row Security
func(userID int, schema, tablename string) (RowSecurity, error)
// RowSecurityProvider interface
type RowSecurityProvider interface {
GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error)
}
```
---
## UserContext Structure
```go
security.UserContext{
UserID: 123, // User's unique ID
UserName: "john_doe", // Username
UserLevel: 5, // User privilege level
SessionID: "sess_abc123", // Current session ID
RemoteID: "remote_xyz", // Remote system ID
Roles: []string{"admin"}, // User roles
Email: "john@example.com", // User email
Claims: map[string]any{}, // Additional authentication claims
Meta: map[string]any{}, // Additional metadata (JSON-serializable)
}
```
---
@@ -109,70 +171,204 @@ HasBlock: true
## Example Implementations
### Simple Header Auth
### Database Session Authenticator (Recommended)
```go
func authFromHeader(r *http.Request) (int, string, error) {
// Create authenticator
auth := security.NewDatabaseAuthenticator(db)
// Requires these tables:
// - users (id, username, email, password, user_level, roles, is_active)
// - user_sessions (session_token, user_id, expires_at, created_at, last_activity_at)
// See database_schema.sql for full schema
// Features:
// - Login with username/password
// - Session management in database
// - Token refresh support (implements Refreshable)
// - Automatic session expiration
// - Tracks IP address and user agent
// - Works with Authorization header or cookie
```
### Simple Header Authenticator
```go
type HeaderAuthenticator struct{}
func NewHeaderAuthenticator() *HeaderAuthenticator {
return &HeaderAuthenticator{}
}
func (a *HeaderAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
return nil, fmt.Errorf("not supported")
}
func (a *HeaderAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
return nil
}
func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
userIDStr := r.Header.Get("X-User-ID")
if userIDStr == "" {
return 0, "", fmt.Errorf("X-User-ID required")
return nil, fmt.Errorf("X-User-ID required")
}
userID, err := strconv.Atoi(userIDStr)
return userID, "", err
userID, _ := strconv.Atoi(userIDStr)
return &security.UserContext{
UserID: userID,
UserName: r.Header.Get("X-User-Name"),
}, nil
}
```
### JWT Auth
### JWT Authenticator
```go
func authFromJWT(r *http.Request) (int, string, error) {
token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
claims, err := jwt.Parse(token, secret)
type JWTAuthenticator struct {
secretKey []byte
db *gorm.DB
}
func NewJWTAuthenticator(secret string, db *gorm.DB) *JWTAuthenticator {
return &JWTAuthenticator{secretKey: []byte(secret), db: db}
}
func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
// Validate credentials against database
var user User
err := a.db.WithContext(ctx).Where("username = ?", req.Username).First(&user).Error
if err != nil {
return 0, "", err
return nil, fmt.Errorf("invalid credentials")
}
return claims.UserID, claims.Roles, nil
// Generate JWT token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"user_id": user.ID,
"exp": time.Now().Add(24 * time.Hour).Unix(),
})
tokenString, _ := token.SignedString(a.secretKey)
return &security.LoginResponse{
Token: tokenString,
User: &security.UserContext{UserID: user.ID},
ExpiresIn: 86400,
}, nil
}
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
// Add to blacklist
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
"token": req.Token,
"user_id": req.UserID,
}).Error
}
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
tokenString := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (any, error) {
return a.secretKey, nil
})
if err != nil || !token.Valid {
return nil, fmt.Errorf("invalid token")
}
claims := token.Claims.(jwt.MapClaims)
return &security.UserContext{
UserID: int(claims["user_id"].(float64)),
}, nil
}
```
### Static Column Security
```go
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
if table == "employees" {
return []security.ColumnSecurity{
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5},
{Path: []string{"salary"}, Accesstype: "hide"},
}, nil
}
return []security.ColumnSecurity{}, nil
type ConfigColumnSecurityProvider struct {
rules map[string][]security.ColumnSecurity
}
func NewConfigColumnSecurityProvider(rules map[string][]security.ColumnSecurity) *ConfigColumnSecurityProvider {
return &ConfigColumnSecurityProvider{rules: rules}
}
func (p *ConfigColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
key := fmt.Sprintf("%s.%s", schema, table)
return p.rules[key], nil
}
```
### Database Column Security
```go
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
rows, err := db.Query(`
type DatabaseColumnSecurityProvider struct {
db *gorm.DB
}
func NewDatabaseColumnSecurityProvider(db *gorm.DB) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db}
}
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
var records []struct {
Control string
Accesstype string
JSONValue string
}
query := `
SELECT control, accesstype, jsonvalue
FROM core.secacces
WHERE rid_hub IN (...)
FROM core.secaccess
WHERE rid_hub IN (
SELECT rid_hub_parent FROM core.hub_link
WHERE rid_hub_child = ? AND parent_hubtype = 'secgroup'
)
AND control ILIKE ?
`, fmt.Sprintf("%s.%s%%", schema, table))
// ... parse and return
`
err := p.db.WithContext(ctx).Raw(query, userID, fmt.Sprintf("%s.%s%%", schema, table)).Scan(&records).Error
if err != nil {
return nil, err
}
var rules []security.ColumnSecurity
for _, rec := range records {
parts := strings.Split(rec.Control, ".")
if len(parts) < 3 {
continue
}
rules = append(rules, security.ColumnSecurity{
Schema: schema,
Tablename: table,
Path: parts[2:],
Accesstype: rec.Accesstype,
})
}
return rules, nil
}
```
### Static Row Security
```go
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
templates := map[string]string{
"orders": "user_id = {UserID}",
"documents": "user_id = {UserID} OR is_public = true",
type ConfigRowSecurityProvider struct {
templates map[string]string
blocked map[string]bool
}
func NewConfigRowSecurityProvider(templates map[string]string, blocked map[string]bool) *ConfigRowSecurityProvider {
return &ConfigRowSecurityProvider{templates: templates, blocked: blocked}
}
func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) {
key := fmt.Sprintf("%s.%s", schema, table)
if p.blocked[key] {
return security.RowSecurity{HasBlock: true}, nil
}
return security.RowSecurity{
Template: templates[table],
Schema: schema,
Tablename: table,
UserID: userID,
Template: p.templates[key],
}, nil
}
```
@@ -182,19 +378,22 @@ func loadRowSec(userID int, schema, table string) (security.RowSecurity, error)
## Testing
```go
// Test auth callback
// Test Authenticator
auth := security.NewHeaderAuthenticator()
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-User-ID", "123")
userID, roles, err := myAuth(req)
assert.Equal(t, 123, userID)
userCtx, err := auth.Authenticate(req)
assert.Equal(t, 123, userCtx.UserID)
// Test column security callback
rules, err := myColSec(123, "public", "employees")
assert.Equal(t, "mask", rules[0].Accesstype)
// Test ColumnSecurityProvider
colSec := security.NewConfigColumnSecurityProvider(rules)
cols, err := colSec.GetColumnSecurity(context.Background(), 123, "public", "employees")
assert.Equal(t, "mask", cols[0].Accesstype)
// Test row security callback
rowSec, err := myRowSec(123, "public", "orders")
assert.Equal(t, "user_id = {UserID}", rowSec.Template)
// Test RowSecurityProvider
rowSec := security.NewConfigRowSecurityProvider(templates, blocked)
row, err := rowSec.GetRowSecurity(context.Background(), 123, "public", "orders")
assert.Equal(t, "user_id = {UserID}", row.Template)
```
---
@@ -204,13 +403,13 @@ assert.Equal(t, "user_id = {UserID}", rowSec.Template)
```
HTTP Request
AuthMiddleware → calls AuthenticateCallback
↓ (adds userID to context)
SetSecurityMiddleware → adds GlobalSecurity to context
NewAuthMiddleware → calls provider.Authenticate()
↓ (adds UserContext to context)
SetSecurityMiddleware → adds SecurityList to context
Handler.Handle()
BeforeRead Hook → calls LoadColumnSecurityCallback + LoadRowSecurityCallback
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
BeforeScan Hook → applies row security (WHERE clause)
@@ -228,10 +427,13 @@ HTTP Response
### Role-Based Security
```go
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
if isAdmin(userID) {
func (p *MyColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
userCtx, _ := security.GetUserContext(ctx)
if contains(userCtx.Roles, "admin") {
return []security.ColumnSecurity{}, nil // No restrictions
}
return loadRestrictions(userID, schema, table), nil
}
```
@@ -239,7 +441,7 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er
### Tenant Isolation
```go
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
func (p *MyRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) {
tenantID := getUserTenant(userID)
return security.RowSecurity{
Template: fmt.Sprintf("tenant_id = %d", tenantID),
@@ -247,19 +449,26 @@ func loadRowSec(userID int, schema, table string) (security.RowSecurity, error)
}
```
### Caching
### Caching with Decorator
```go
var cache = make(map[string][]security.ColumnSecurity)
type CachedColumnSecurityProvider struct {
inner security.ColumnSecurityProvider
cache *cache.Cache
}
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
func (p *CachedColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
key := fmt.Sprintf("%d:%s.%s", userID, schema, table)
if cached, ok := cache[key]; ok {
return cached, nil
if cached, found := p.cache.Get(key); found {
return cached.([]security.ColumnSecurity), nil
}
rules := loadFromDB(userID, schema, table)
cache[key] = rules
return rules, nil
rules, err := p.inner.GetColumnSecurity(ctx, userID, schema, table)
if err == nil {
p.cache.Set(key, rules, cache.DefaultExpiration)
}
return rules, err
}
```
@@ -268,21 +477,20 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er
## Error Handling
```go
// Setup will fail if callbacks not configured
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
log.Fatal("Security setup failed:", err)
}
// Panic if provider is nil
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
// panics if any parameter is nil
// Auth middleware rejects if callback returns error
func myAuth(r *http.Request) (int, string, error) {
// Auth middleware returns 401 if Authenticate fails
func (a *MyAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
if invalid {
return 0, "", fmt.Errorf("invalid credentials") // Returns HTTP 401
return nil, fmt.Errorf("invalid credentials") // Returns HTTP 401
}
return userID, roles, nil
return &security.UserContext{UserID: userID}, nil
}
// Security loading can fail gracefully
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
rules, err := db.Load(...)
if err != nil {
log.Printf("Failed to load security: %v", err)
@@ -294,6 +502,45 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er
---
## Login/Logout Endpoints
```go
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
// Login
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
var req security.LoginRequest
json.NewDecoder(r.Body).Decode(&req)
resp, err := securityList.Provider().Login(r.Context(), req)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
json.NewEncoder(w).Encode(resp)
}).Methods("POST")
// Logout
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
userID, _ := security.GetUserID(r.Context())
err := securityList.Provider().Logout(r.Context(), security.LogoutRequest{
Token: token,
UserID: userID,
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}).Methods("POST")
}
```
---
## Debugging
```go
@@ -301,15 +548,15 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er
import "github.com/bitechdev/GoCore/pkg/cfg"
cfg.SetLogLevel("DEBUG")
// Log in callbacks
func myAuth(r *http.Request) (int, string, error) {
// Log in provider methods
func (a *MyAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
token := r.Header.Get("Authorization")
log.Printf("Auth: token=%s", token)
// ...
}
// Check if callbacks are called
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
// Check if methods are called
func (p *MyColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
log.Printf("Loading column security: user=%d, schema=%s, table=%s", userID, schema, table)
// ...
}
@@ -323,6 +570,7 @@ func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, er
package main
import (
"context"
"fmt"
"net/http"
"strconv"
@@ -331,29 +579,42 @@ import (
"github.com/gorilla/mux"
)
// Simple all-in-one provider
type SimpleProvider struct{}
func (p *SimpleProvider) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
return nil, fmt.Errorf("not implemented")
}
func (p *SimpleProvider) Logout(ctx context.Context, req security.LogoutRequest) error {
return nil
}
func (p *SimpleProvider) Authenticate(r *http.Request) (*security.UserContext, error) {
id, _ := strconv.Atoi(r.Header.Get("X-User-ID"))
return &security.UserContext{UserID: id}, nil
}
func (p *SimpleProvider) GetColumnSecurity(ctx context.Context, u int, s, t string) ([]security.ColumnSecurity, error) {
return []security.ColumnSecurity{}, nil
}
func (p *SimpleProvider) GetRowSecurity(ctx context.Context, u int, s, t string) (security.RowSecurity, error) {
return security.RowSecurity{Template: fmt.Sprintf("user_id = %d", u)}, nil
}
func main() {
handler := restheadspec.NewHandlerWithGORM(db)
// Configure callbacks
security.GlobalSecurity.AuthenticateCallback = func(r *http.Request) (int, string, error) {
id, _ := strconv.Atoi(r.Header.Get("X-User-ID"))
return id, "", nil
}
security.GlobalSecurity.LoadColumnSecurityCallback = func(u int, s, t string) ([]security.ColumnSecurity, error) {
return []security.ColumnSecurity{}, nil
}
security.GlobalSecurity.LoadRowSecurityCallback = func(u int, s, t string) (security.RowSecurity, error) {
return security.RowSecurity{Template: fmt.Sprintf("user_id = %d", u)}, nil
}
// Setup security
provider := &SimpleProvider{}
securityList := security.SetupSecurityProvider(handler, provider)
// Setup
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
// Middleware
// Apply middleware
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
router.Use(security.NewAuthMiddleware(securityList))
router.Use(security.SetSecurityMiddleware(securityList))
http.ListenAndServe(":8080", router)
}
@@ -361,15 +622,94 @@ func main() {
---
## Authentication Modes
```go
// Required authentication (default)
// Authentication must succeed or returns 401
router.Use(security.NewAuthMiddleware(securityList))
// Skip authentication for specific routes
// Always sets guest user context
func PublicRoute(w http.ResponseWriter, r *http.Request) {
ctx := security.SkipAuth(r.Context())
r = r.WithContext(ctx)
// Guest context will be set
}
// Optional authentication for specific routes
// Tries to authenticate, falls back to guest if it fails
func HomeRoute(w http.ResponseWriter, r *http.Request) {
ctx := security.OptionalAuth(r.Context())
r = r.WithContext(ctx)
userCtx, _ := security.GetUserContext(r.Context())
if userCtx.UserID == 0 {
// Guest user
} else {
// Authenticated user
}
}
```
**Comparison:**
- **Required**: Auth must succeed or return 401 (default)
- **SkipAuth**: Never tries to authenticate, always guest
- **OptionalAuth**: Tries to authenticate, guest on failure
---
## Standalone Handlers
```go
// NewAuthHandler - Required authentication (returns 401 on failure)
authHandler := security.NewAuthHandler(securityList, myHandler)
http.Handle("/api/protected", authHandler)
// NewOptionalAuthHandler - Optional authentication (guest on failure)
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
http.Handle("/home", optionalHandler)
// Example handler
func myHandler(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
if userCtx.UserID == 0 {
// Guest user
} else {
// Authenticated user
}
}
```
---
## Context Helpers
```go
// Get full user context
userCtx, ok := security.GetUserContext(ctx)
// Get individual fields
userID, ok := security.GetUserID(ctx)
userName, ok := security.GetUserName(ctx)
userLevel, ok := security.GetUserLevel(ctx)
sessionID, ok := security.GetSessionID(ctx)
remoteID, ok := security.GetRemoteID(ctx)
roles, ok := security.GetUserRoles(ctx)
email, ok := security.GetUserEmail(ctx)
meta, ok := security.GetUserMeta(ctx)
```
---
## Resources
| File | Description |
|------|-------------|
| `CALLBACKS_GUIDE.md` | **Start here** - Complete implementation guide |
| `callbacks_example.go` | 7 working examples to copy |
| `CALLBACKS_SUMMARY.md` | Architecture overview |
| `README.md` | Full documentation |
| `setup_example.go` | Integration examples |
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
| `examples.go` | Working provider implementations to copy |
| `setup_example.go` | 6 complete integration examples |
| `README.md` | Architecture overview and migration guide |
---
@@ -377,22 +717,22 @@ func main() {
```go
// ===== REQUIRED SETUP =====
security.GlobalSecurity.AuthenticateCallback = myAuthFunc
security.GlobalSecurity.LoadColumnSecurityCallback = myColFunc
security.GlobalSecurity.LoadRowSecurityCallback = myRowFunc
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
auth := security.NewJWTAuthenticator("secret", db)
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.SetupSecurityProvider(handler, provider)
// ===== CALLBACK SIGNATURES =====
func(r *http.Request) (int, string, error) // Auth
func(int, string, string) ([]security.ColumnSecurity, error) // Column
func(int, string, string) (security.RowSecurity, error) // Row
// ===== INTERFACE METHODS =====
Authenticate(r *http.Request) (*UserContext, error)
Login(ctx context.Context, req LoginRequest) (*LoginResponse, error)
Logout(ctx context.Context, req LogoutRequest) error
GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error)
GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error)
// ===== QUICK EXAMPLES =====
// Header auth
func(r *http.Request) (int, string, error) {
id, _ := strconv.Atoi(r.Header.Get("X-User-ID"))
return id, "", nil
}
&UserContext{UserID: 123, UserName: "john"}
// Mask SSN
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}

950
pkg/security/README.md Normal file
View File

@@ -0,0 +1,950 @@
# ResolveSpec Security Provider
Type-safe, composable security system for ResolveSpec with support for authentication, column-level security (masking), and row-level security (filtering).
## Features
-**Interface-Based** - Type-safe providers instead of callbacks
-**Login/Logout Support** - Built-in authentication lifecycle
-**Composable** - Mix and match different providers
-**No Global State** - Each handler has its own security configuration
-**Testable** - Easy to mock and test
-**Extensible** - Implement custom providers for your needs
-**Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
## Stored Procedure Architecture
**All database-backed security providers use PostgreSQL stored procedures exclusively.** No raw SQL queries are executed from Go code.
### Benefits
- **Security**: Database logic is centralized and protected
- **Maintainability**: Update database logic without recompiling Go code
- **Performance**: Stored procedures are pre-compiled and optimized
- **Testability**: Test database logic independently
- **Consistency**: Standardized `resolvespec_*` naming convention
### Available Stored Procedures
| Procedure | Purpose | Used By |
|-----------|---------|---------|
| `resolvespec_login` | Session-based login | DatabaseAuthenticator |
| `resolvespec_logout` | Session invalidation | DatabaseAuthenticator |
| `resolvespec_session` | Session validation | DatabaseAuthenticator |
| `resolvespec_session_update` | Update session activity | DatabaseAuthenticator |
| `resolvespec_refresh_token` | Token refresh | DatabaseAuthenticator |
| `resolvespec_jwt_login` | JWT user validation | JWTAuthenticator |
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
See `database_schema.sql` for complete stored procedure definitions and examples.
## Quick Start
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/security"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// 1. Create security providers
auth := security.NewJWTAuthenticator("your-secret-key", db)
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
// 2. Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
// 3. Create handler and register security hooks
handler := restheadspec.NewHandlerWithGORM(db)
securityList := security.NewSecurityList(provider)
restheadspec.RegisterSecurityHooks(handler, securityList)
// 4. Apply middleware
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
router.Use(security.NewAuthMiddleware(securityList))
router.Use(security.SetSecurityMiddleware(securityList))
```
## Architecture
### Spec-Agnostic Design
The security system is **completely spec-agnostic** - it doesn't depend on any specific spec implementation. Instead, each spec (restheadspec, funcspec, resolvespec) implements its own security integration by adapting to the `SecurityContext` interface.
```
┌─────────────────────────────────────┐
│ Security Package (Generic) │
│ - SecurityContext interface │
│ - Security providers │
│ - Core security logic │
└─────────────────────────────────────┘
▲ ▲ ▲
│ │ │
┌──────┘ │ └──────┐
│ │ │
┌───▼────┐ ┌────▼─────┐ ┌────▼──────┐
│RestHead│ │ FuncSpec │ │ResolveSpec│
│ Spec │ │ │ │ │
│ │ │ │ │ │
│Adapts │ │ Adapts │ │ Adapts │
│to │ │ to │ │ to │
│Security│ │ Security │ │ Security │
│Context │ │ Context │ │ Context │
└────────┘ └──────────┘ └───────────┘
```
**Benefits:**
- ✅ No circular dependencies
- ✅ Each spec can customize security integration
- ✅ Easy to add new specs
- ✅ Security logic is reusable across all specs
### Core Interfaces
The security system is built on three main interfaces:
#### 1. Authenticator
Handles user authentication lifecycle:
```go
type Authenticator interface {
Login(ctx context.Context, req LoginRequest) (*LoginResponse, error)
Logout(ctx context.Context, req LogoutRequest) error
Authenticate(r *http.Request) (*UserContext, error)
}
```
#### 2. ColumnSecurityProvider
Manages column-level security (masking/hiding):
```go
type ColumnSecurityProvider interface {
GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error)
}
```
#### 3. RowSecurityProvider
Manages row-level security (WHERE clause filtering):
```go
type RowSecurityProvider interface {
GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error)
}
```
### SecurityProvider
The main interface that combines all three:
```go
type SecurityProvider interface {
Authenticator
ColumnSecurityProvider
RowSecurityProvider
}
```
#### 4. SecurityContext (Spec Integration Interface)
Each spec implements this interface to integrate with the security system:
```go
type SecurityContext interface {
GetContext() context.Context
GetUserID() (int, bool)
GetSchema() string
GetEntity() string
GetModel() interface{}
GetQuery() interface{}
SetQuery(interface{})
GetResult() interface{}
SetResult(interface{})
}
```
**Implementation Examples:**
- `restheadspec`: Adapts `restheadspec.HookContext``SecurityContext`
- `funcspec`: Adapts `funcspec.HookContext``SecurityContext`
- `resolvespec`: Adapts `resolvespec.HookContext``SecurityContext`
### UserContext
Enhanced user context with complete user information:
```go
type UserContext struct {
UserID int // User's unique ID
UserName string // Username
UserLevel int // User privilege level
SessionID string // Current session ID
RemoteID string // Remote system ID
Roles []string // User roles
Email string // User email
Claims map[string]any // Additional authentication claims
Meta map[string]any // Additional metadata (can hold any JSON-serializable values)
}
```
## Available Implementations
### Authenticators
**HeaderAuthenticator** - Simple header-based authentication:
```go
auth := security.NewHeaderAuthenticator()
// Expects: X-User-ID, X-User-Name, X-User-Level, etc.
```
**DatabaseAuthenticator** - Database session-based authentication (Recommended):
```go
auth := security.NewDatabaseAuthenticator(db)
// Supports: Login, Logout, Session management, Token refresh
// All operations use stored procedures: resolvespec_login, resolvespec_logout,
// resolvespec_session, resolvespec_session_update, resolvespec_refresh_token
// Requires: users and user_sessions tables + stored procedures (see database_schema.sql)
```
**JWTAuthenticator** - JWT token authentication with login/logout:
```go
auth := security.NewJWTAuthenticator("secret-key", db)
// Supports: Login, Logout, JWT token validation
// All operations use stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
// Note: Requires JWT library installation for token signing/verification
```
### Column Security Providers
**DatabaseColumnSecurityProvider** - Loads rules from database:
```go
colSec := security.NewDatabaseColumnSecurityProvider(db)
// Uses stored procedure: resolvespec_column_security
// Queries core.secaccess and core.hub_link tables
```
**ConfigColumnSecurityProvider** - Static configuration:
```go
rules := map[string][]security.ColumnSecurity{
"public.employees": {
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5},
},
}
colSec := security.NewConfigColumnSecurityProvider(rules)
```
### Row Security Providers
**DatabaseRowSecurityProvider** - Loads filters from database:
```go
rowSec := security.NewDatabaseRowSecurityProvider(db)
// Uses stored procedure: resolvespec_row_security
```
**ConfigRowSecurityProvider** - Static templates:
```go
templates := map[string]string{
"public.orders": "user_id = {UserID}",
}
blocked := map[string]bool{
"public.admin_logs": true,
}
rowSec := security.NewConfigRowSecurityProvider(templates, blocked)
```
## Usage Examples
### Example 1: Complete Database-Backed Security with Sessions (restheadspec)
```go
func main() {
db := setupDatabase()
// Run migrations (see database_schema.sql)
// db.Exec("CREATE TABLE users ...")
// db.Exec("CREATE TABLE user_sessions ...")
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// Create security providers
auth := security.NewDatabaseAuthenticator(db) // Session-based auth
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
// Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
// Register security hooks for this spec
restheadspec.RegisterSecurityHooks(handler, securityList)
// Setup routes
router := mux.NewRouter()
// Add auth endpoints
router.HandleFunc("/auth/login", handleLogin(securityList)).Methods("POST")
router.HandleFunc("/auth/logout", handleLogout(securityList)).Methods("POST")
router.HandleFunc("/auth/refresh", handleRefresh(securityList)).Methods("POST")
// Setup API with security
apiRouter := router.PathPrefix("/api").Subrouter()
restheadspec.SetupMuxRoutes(apiRouter, handler)
apiRouter.Use(security.NewAuthMiddleware(securityList))
apiRouter.Use(security.SetSecurityMiddleware(securityList))
http.ListenAndServe(":8080", router)
}
func handleLogin(securityList *security.SecurityList) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req security.LoginRequest
json.NewDecoder(r.Body).Decode(&req)
// Add client info to claims
req.Claims = map[string]any{
"ip_address": r.RemoteAddr,
"user_agent": r.UserAgent(),
}
resp, err := securityList.Provider().Login(r.Context(), req)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Set session cookie (optional)
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: resp.Token,
Expires: time.Now().Add(24 * time.Hour),
HttpOnly: true,
Secure: true, // Use in production with HTTPS
SameSite: http.SameSiteStrictMode,
})
json.NewEncoder(w).Encode(resp)
}
}
func handleRefresh(securityList *security.SecurityList) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("X-Refresh-Token")
if refreshable, ok := securityList.Provider().(security.Refreshable); ok {
resp, err := refreshable.RefreshToken(r.Context(), token)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
json.NewEncoder(w).Encode(resp)
} else {
http.Error(w, "Refresh not supported", http.StatusNotImplemented)
}
}
}
```
### Example 2: Config-Based Security (No Database)
```go
func main() {
db := setupDatabase()
handler := restheadspec.NewHandlerWithGORM(db)
// Static column security rules
columnRules := map[string][]security.ColumnSecurity{
"public.employees": {
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5},
{Path: []string{"salary"}, Accesstype: "hide"},
},
}
// Static row security templates
rowTemplates := map[string]string{
"public.orders": "user_id = {UserID}",
}
// Create providers
auth := security.NewHeaderAuthenticator()
colSec := security.NewConfigColumnSecurityProvider(columnRules)
rowSec := security.NewConfigRowSecurityProvider(rowTemplates, nil)
// Combine providers and register hooks
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
restheadspec.RegisterSecurityHooks(handler, securityList)
// Setup routes...
}
```
### Example 3: FuncSpec Security (SQL Query API)
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/funcspec"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
func main() {
db := setupDatabase()
// Create funcspec handler
handler := funcspec.NewHandler(db)
// Create security providers
auth := security.NewJWTAuthenticator("secret-key", db)
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
// Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
// Register security hooks (audit logging)
funcspec.RegisterSecurityHooks(handler, securityList)
// Note: funcspec operates on raw SQL queries, so row/column
// security is limited. Security should be enforced at the
// SQL function level or via database policies.
// Setup routes...
}
```
### Example 4: ResolveSpec Security (REST API)
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
func main() {
db := setupDatabase()
registry := common.NewModelRegistry()
// Register models
registry.RegisterModel("public.users", &User{})
registry.RegisterModel("public.orders", &Order{})
// Create resolvespec handler
handler := resolvespec.NewHandler(db, registry)
// Create security providers
auth := security.NewDatabaseAuthenticator(db)
colSec := security.NewDatabaseColumnSecurityProvider(db)
rowSec := security.NewDatabaseRowSecurityProvider(db)
// Combine providers
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
// Register security hooks for resolvespec
resolvespec.RegisterSecurityHooks(handler, securityList)
// Setup routes...
}
```
### Example 5: Custom Provider
Implement your own provider for complete control:
```go
type MySecurityProvider struct {
db *gorm.DB
}
func (p *MySecurityProvider) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
// Your custom login logic
}
func (p *MySecurityProvider) Logout(ctx context.Context, req security.LogoutRequest) error {
// Your custom logout logic
}
func (p *MySecurityProvider) Authenticate(r *http.Request) (*security.UserContext, error) {
// Your custom authentication logic
}
func (p *MySecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
// Your custom column security logic
}
func (p *MySecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) {
// Your custom row security logic
}
// Use it with any spec
provider := &MySecurityProvider{db: db}
securityList := security.NewSecurityList(provider)
// Register with restheadspec
restheadspec.RegisterSecurityHooks(restHandler, securityList)
// Or with funcspec
funcspec.RegisterSecurityHooks(funcHandler, securityList)
// Or with resolvespec
resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
```
## Security Features
### Column Security (Masking/Hiding)
**Mask SSN (show last 4 digits):**
```go
{
Path: []string{"ssn"},
Accesstype: "mask",
MaskStart: 5,
MaskChar: "*",
}
// "123-45-6789" → "*****6789"
```
**Hide entire field:**
```go
{
Path: []string{"salary"},
Accesstype: "hide",
}
// Field returns 0 or empty
```
**Nested JSON field masking:**
```go
{
Path: []string{"address", "street"},
Accesstype: "mask",
MaskStart: 10,
}
```
### Row Security (Filtering)
**User isolation:**
```go
{
Template: "user_id = {UserID}",
}
// Users only see their own records
```
**Tenant isolation:**
```go
{
Template: "tenant_id = {TenantID} AND user_id = {UserID}",
}
```
**Block all access:**
```go
{
HasBlock: true,
}
// Completely blocks access to the table
```
**Template variables:**
- `{UserID}` - Current user's ID
- `{PrimaryKeyName}` - Primary key column
- `{TableName}` - Table name
- `{SchemaName}` - Schema name
## Request Flow
```
HTTP Request
NewAuthMiddleware (security package)
├─ Calls provider.Authenticate(request)
└─ Adds UserContext to context
SetSecurityMiddleware (security package)
└─ Adds SecurityList to context
Spec Handler (restheadspec/funcspec/resolvespec)
BeforeRead Hook (registered by spec)
├─ Adapts spec's HookContext → SecurityContext
├─ Calls security.LoadSecurityRules(secCtx, securityList)
│ ├─ Calls provider.GetColumnSecurity()
│ └─ Calls provider.GetRowSecurity()
└─ Caches security rules
BeforeScan Hook (registered by spec)
├─ Adapts spec's HookContext → SecurityContext
├─ Calls security.ApplyRowSecurity(secCtx, securityList)
└─ Applies row security (adds WHERE clause to query)
Database Query (with security filters)
AfterRead Hook (registered by spec)
├─ Adapts spec's HookContext → SecurityContext
├─ Calls security.ApplyColumnSecurity(secCtx, securityList)
├─ Applies column security (masks/hides fields)
└─ Calls security.LogDataAccess(secCtx)
HTTP Response (secured data)
```
**Key Points:**
- Security package is spec-agnostic and provides core logic
- Each spec registers its own hooks that adapt to SecurityContext
- Security rules are loaded once and cached for the request
- Row security is applied to the query (database level)
- Column security is applied to results (application level)
## Testing
The interface-based design makes testing straightforward:
```go
// Mock authenticator for tests
type MockAuthenticator struct {
UserToReturn *security.UserContext
ErrorToReturn error
}
func (m *MockAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
return m.UserToReturn, m.ErrorToReturn
}
// Use in tests
func TestMyHandler(t *testing.T) {
mockAuth := &MockAuthenticator{
UserToReturn: &security.UserContext{UserID: 123},
}
provider := security.NewCompositeSecurityProvider(
mockAuth,
&MockColumnSecurity{},
&MockRowSecurity{},
)
securityList := security.SetupSecurityProvider(handler, provider)
// ... test your handler
}
```
## Migration Guide
### From Old Callback System
If you're upgrading from the old callback-based system:
**Old:**
```go
security.GlobalSecurity.AuthenticateCallback = myAuthFunc
security.GlobalSecurity.LoadColumnSecurityCallback = myColSecFunc
security.GlobalSecurity.LoadRowSecurityCallback = myRowSecFunc
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
```
**New:**
```go
// 1. Wrap your functions in a provider
type MyProvider struct{}
func (p *MyProvider) Authenticate(r *http.Request) (*security.UserContext, error) {
userID, roles, err := myAuthFunc(r)
return &security.UserContext{UserID: userID, Roles: strings.Split(roles, ",")}, err
}
func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
return myColSecFunc(userID, schema, table)
}
func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) {
return myRowSecFunc(userID, schema, table)
}
func (p *MyProvider) Login(ctx context.Context, req security.LoginRequest) (*security.LoginResponse, error) {
return nil, fmt.Errorf("not implemented")
}
func (p *MyProvider) Logout(ctx context.Context, req security.LogoutRequest) error {
return nil
}
// 2. Create security list and register hooks
provider := &MyProvider{}
securityList := security.NewSecurityList(provider)
// 3. Register with your spec
restheadspec.RegisterSecurityHooks(handler, securityList)
```
### From Old SetupSecurityProvider API
If you're upgrading from the previous interface-based system:
**Old:**
```go
securityList := security.SetupSecurityProvider(handler, provider)
```
**New:**
```go
securityList := security.NewSecurityList(provider)
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
```
The main changes:
1. Security package no longer knows about specific spec types
2. Each spec registers its own security hooks
3. More flexible - same security provider works with all specs
## Documentation
| File | Description |
|------|-------------|
| **QUICK_REFERENCE.md** | Quick reference guide with examples |
| **INTERFACE_GUIDE.md** | Complete implementation guide |
| **examples.go** | Working provider implementations |
| **setup_example.go** | 6 complete integration examples |
## API Reference
### Context Helpers
Get user information from request context:
```go
userCtx, ok := security.GetUserContext(ctx)
userID, ok := security.GetUserID(ctx)
userName, ok := security.GetUserName(ctx)
userLevel, ok := security.GetUserLevel(ctx)
sessionID, ok := security.GetSessionID(ctx)
remoteID, ok := security.GetRemoteID(ctx)
roles, ok := security.GetUserRoles(ctx)
email, ok := security.GetUserEmail(ctx)
```
### Optional Interfaces
Implement these for additional features:
**Refreshable** - Token refresh support:
```go
type Refreshable interface {
RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error)
}
```
**Validatable** - Token validation:
```go
type Validatable interface {
ValidateToken(ctx context.Context, token string) (bool, error)
}
```
**Cacheable** - Cache management:
```go
type Cacheable interface {
ClearCache(ctx context.Context, userID int, schema, table string) error
}
```
## Benefits Over Callbacks
| Feature | Old (Callbacks) | New (Interfaces) |
|---------|----------------|------------------|
| Type Safety | ❌ Callbacks can be nil | ✅ Compile-time verification |
| Global State | ❌ GlobalSecurity variable | ✅ Dependency injection |
| Testability | ⚠️ Need to set globals | ✅ Easy to mock |
| Composability | ❌ Single provider only | ✅ Mix and match |
| Login/Logout | ❌ Not supported | ✅ Built-in |
| Extensibility | ⚠️ Limited | ✅ Optional interfaces |
## Common Patterns
### Caching Security Rules
```go
type CachedProvider struct {
inner security.ColumnSecurityProvider
cache *cache.Cache
}
func (p *CachedProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
key := fmt.Sprintf("%d:%s.%s", userID, schema, table)
if cached, found := p.cache.Get(key); found {
return cached.([]security.ColumnSecurity), nil
}
rules, err := p.inner.GetColumnSecurity(ctx, userID, schema, table)
if err == nil {
p.cache.Set(key, rules, cache.DefaultExpiration)
}
return rules, err
}
```
### Role-Based Security
```go
func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]security.ColumnSecurity, error) {
userCtx, _ := security.GetUserContext(ctx)
if contains(userCtx.Roles, "admin") {
return []security.ColumnSecurity{}, nil // No restrictions
}
return loadRestrictionsForUser(userID, schema, table), nil
}
```
### Multi-Tenant Isolation
```go
func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (security.RowSecurity, error) {
tenantID := getUserTenant(userID)
return security.RowSecurity{
Template: fmt.Sprintf("tenant_id = %d AND user_id = {UserID}", tenantID),
}, nil
}
```
## Middleware and Handler API
### NewAuthMiddleware
Standard middleware that authenticates all requests:
```go
router.Use(security.NewAuthMiddleware(securityList))
```
Routes can skip authentication using the `SkipAuth` helper:
```go
func PublicHandler(w http.ResponseWriter, r *http.Request) {
ctx := security.SkipAuth(r.Context())
// This route will bypass authentication
// A guest user context will be set instead
}
router.Handle("/public", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := security.SkipAuth(r.Context())
PublicHandler(w, r.WithContext(ctx))
}))
```
When authentication is skipped, a guest user context is automatically set:
- UserID: 0
- UserName: "guest"
- Roles: ["guest"]
- RemoteID: Request's remote address
Routes can use optional authentication with the `OptionalAuth` helper:
```go
func OptionalAuthHandler(w http.ResponseWriter, r *http.Request) {
ctx := security.OptionalAuth(r.Context())
r = r.WithContext(ctx)
// This route will try to authenticate
// If authentication succeeds, authenticated user context is set
// If authentication fails, guest user context is set instead
userCtx, _ := security.GetUserContext(r.Context())
if userCtx.UserID == 0 {
// Guest user
fmt.Fprintf(w, "Welcome, guest!")
} else {
// Authenticated user
fmt.Fprintf(w, "Welcome back, %s!", userCtx.UserName)
}
}
router.Handle("/home", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := security.OptionalAuth(r.Context())
OptionalAuthHandler(w, r.WithContext(ctx))
}))
```
**Authentication Modes Summary:**
- **Required (default)**: Authentication must succeed or returns 401
- **SkipAuth**: Bypasses authentication entirely, always sets guest context
- **OptionalAuth**: Tries authentication, falls back to guest context if it fails
### NewAuthHandler
Standalone authentication handler (without middleware wrapping):
```go
// Use when you need authentication logic without middleware
authHandler := security.NewAuthHandler(securityList, myHandler)
http.Handle("/api/protected", authHandler)
```
### NewOptionalAuthHandler
Standalone optional authentication handler that tries to authenticate but falls back to guest:
```go
// Use for routes that should work for both authenticated and guest users
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
http.Handle("/home", optionalHandler)
// Example handler that checks user context
func myHandler(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
if userCtx.UserID == 0 {
fmt.Fprintf(w, "Welcome, guest!")
} else {
fmt.Fprintf(w, "Welcome back, %s!", userCtx.UserName)
}
}
```
### Helper Functions
Extract user information from context:
```go
// Get full user context
userCtx, ok := security.GetUserContext(ctx)
// Get specific fields
userID, ok := security.GetUserID(ctx)
userName, ok := security.GetUserName(ctx)
userLevel, ok := security.GetUserLevel(ctx)
sessionID, ok := security.GetSessionID(ctx)
remoteID, ok := security.GetRemoteID(ctx)
roles, ok := security.GetUserRoles(ctx)
email, ok := security.GetUserEmail(ctx)
meta, ok := security.GetUserMeta(ctx)
```
### Metadata Support
The `Meta` field in `UserContext` can hold any JSON-serializable values:
```go
// Set metadata during login
loginReq := security.LoginRequest{
Username: "user@example.com",
Password: "password",
Meta: map[string]any{
"department": "engineering",
"location": "US",
"preferences": map[string]any{
"theme": "dark",
},
},
}
// Access metadata in handlers
meta, ok := security.GetUserMeta(ctx)
if ok {
department := meta["department"].(string)
}
```
## License
Part of the ResolveSpec project.

View File

@@ -1,414 +0,0 @@
package security
import (
"fmt"
"net/http"
"strconv"
"strings"
)
// This file provides example implementations of the required security callbacks.
// Copy these functions and modify them to match your authentication and database schema.
// =============================================================================
// EXAMPLE 1: Simple Header-Based Authentication
// =============================================================================
// ExampleAuthenticateFromHeader extracts user ID from X-User-ID header
func ExampleAuthenticateFromHeader(r *http.Request) (userID int, roles string, err error) {
userIDStr := r.Header.Get("X-User-ID")
if userIDStr == "" {
return 0, "", fmt.Errorf("X-User-ID header not provided")
}
userID, err = strconv.Atoi(userIDStr)
if err != nil {
return 0, "", fmt.Errorf("invalid user ID format: %v", err)
}
// Optionally extract roles
roles = r.Header.Get("X-User-Roles") // comma-separated: "admin,manager"
return userID, roles, nil
}
// =============================================================================
// EXAMPLE 2: JWT Token Authentication
// =============================================================================
// ExampleAuthenticateFromJWT parses a JWT token and extracts user info
// You'll need to import a JWT library like github.com/golang-jwt/jwt/v5
func ExampleAuthenticateFromJWT(r *http.Request) (userID int, roles string, err error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return 0, "", fmt.Errorf("authorization header not provided")
}
// Extract Bearer token
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
return 0, "", fmt.Errorf("invalid authorization header format")
}
// TODO: Parse and validate JWT token
// Example using github.com/golang-jwt/jwt/v5:
//
// token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// return []byte(os.Getenv("JWT_SECRET")), nil
// })
//
// if err != nil || !token.Valid {
// return 0, "", fmt.Errorf("invalid token: %v", err)
// }
//
// claims := token.Claims.(jwt.MapClaims)
// userID = int(claims["user_id"].(float64))
// roles = claims["roles"].(string)
return 0, "", fmt.Errorf("JWT parsing not implemented - see example above")
}
// =============================================================================
// EXAMPLE 3: Session Cookie Authentication
// =============================================================================
// ExampleAuthenticateFromSession validates a session cookie
func ExampleAuthenticateFromSession(r *http.Request) (userID int, roles string, err error) {
sessionCookie, err := r.Cookie("session_id")
if err != nil {
return 0, "", fmt.Errorf("session cookie not found")
}
// TODO: Validate session against your session store (Redis, database, etc.)
// Example:
//
// session, err := sessionStore.Get(sessionCookie.Value)
// if err != nil {
// return 0, "", fmt.Errorf("invalid session")
// }
//
// userID = session.UserID
// roles = session.Roles
_ = sessionCookie // Suppress unused warning until implemented
return 0, "", fmt.Errorf("session validation not implemented - see example above")
}
// =============================================================================
// EXAMPLE 4: Column Security - Database Implementation
// =============================================================================
// ExampleLoadColumnSecurityFromDatabase loads column security rules from database
// This implementation assumes the following database schema:
//
// CREATE TABLE core.secacces (
// rid_secacces SERIAL PRIMARY KEY,
// rid_hub INTEGER,
// control TEXT, -- Format: "schema.table.column"
// accesstype TEXT, -- "mask" or "hide"
// jsonvalue JSONB -- Masking configuration
// );
//
// CREATE TABLE core.hub_link (
// rid_hub_parent INTEGER, -- Security group ID
// rid_hub_child INTEGER, -- User ID
// parent_hubtype TEXT -- 'secgroup'
// );
func ExampleLoadColumnSecurityFromDatabase(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) {
colSecList := make([]ColumnSecurity, 0)
// getExtraFilters := func(pStr string) map[string]string {
// mp := make(map[string]string, 0)
// for i, val := range strings.Split(pStr, ",") {
// if i <= 1 {
// continue
// }
// vals := strings.Split(val, ":")
// if len(vals) > 1 {
// mp[vals[0]] = vals[1]
// }
// }
// return mp
// }
// rows, err := DBM.DBConn.Raw(fmt.Sprintf(`
// SELECT a.rid_secacces, a.control, a.accesstype, a.jsonvalue
// FROM core.secacces a
// WHERE a.rid_hub IN (
// SELECT l.rid_hub_parent
// FROM core.hub_link l
// WHERE l.parent_hubtype = 'secgroup'
// AND l.rid_hub_child = ?
// )
// AND control ILIKE '%s.%s%%'
// `, pSchema, pTablename), pUserID).Rows()
// defer func() {
// if rows != nil {
// rows.Close()
// }
// }()
// if err != nil {
// return colSecList, fmt.Errorf("failed to fetch column security from SQL: %v", err)
// }
// for rows.Next() {
// var rid int
// var jsondata []byte
// var control, accesstype string
// err = rows.Scan(&rid, &control, &accesstype, &jsondata)
// if err != nil {
// return colSecList, fmt.Errorf("failed to scan column security: %v", err)
// }
// parts := strings.Split(control, ",")
// ids := strings.Split(parts[0], ".")
// if len(ids) < 3 {
// continue
// }
// jsonvalue := make(map[string]interface{})
// if len(jsondata) > 1 {
// err = json.Unmarshal(jsondata, &jsonvalue)
// if err != nil {
// logger.Error("Failed to parse json: %v", err)
// }
// }
// colsec := ColumnSecurity{
// Schema: pSchema,
// Tablename: pTablename,
// UserID: pUserID,
// Path: ids[2:],
// ExtraFilters: getExtraFilters(control),
// Accesstype: accesstype,
// Control: control,
// ID: int(rid),
// }
// // Parse masking configuration from JSON
// if v, ok := jsonvalue["start"]; ok {
// if value, ok := v.(float64); ok {
// colsec.MaskStart = int(value)
// }
// }
// if v, ok := jsonvalue["end"]; ok {
// if value, ok := v.(float64); ok {
// colsec.MaskEnd = int(value)
// }
// }
// if v, ok := jsonvalue["invert"]; ok {
// if value, ok := v.(bool); ok {
// colsec.MaskInvert = value
// }
// }
// if v, ok := jsonvalue["char"]; ok {
// if value, ok := v.(string); ok {
// colsec.MaskChar = value
// }
// }
// colSecList = append(colSecList, colsec)
// }
return colSecList, nil
}
// =============================================================================
// EXAMPLE 5: Column Security - In-Memory/Static Configuration
// =============================================================================
// ExampleLoadColumnSecurityFromConfig loads column security from static config
func ExampleLoadColumnSecurityFromConfig(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) {
// Example: Define security rules in code or load from config file
securityRules := map[string][]ColumnSecurity{
"public.employees": {
{
Schema: "public",
Tablename: "employees",
Path: []string{"ssn"},
Accesstype: "mask",
MaskStart: 5,
MaskEnd: 0,
MaskChar: "*",
},
{
Schema: "public",
Tablename: "employees",
Path: []string{"salary"},
Accesstype: "hide",
},
},
"public.customers": {
{
Schema: "public",
Tablename: "customers",
Path: []string{"credit_card"},
Accesstype: "mask",
MaskStart: 12,
MaskEnd: 0,
MaskChar: "*",
},
},
}
key := fmt.Sprintf("%s.%s", pSchema, pTablename)
rules, ok := securityRules[key]
if !ok {
return []ColumnSecurity{}, nil // No rules for this table
}
// Filter by user ID if needed
// For this example, all rules apply to all users
return rules, nil
}
// =============================================================================
// EXAMPLE 6: Row Security - Database Implementation
// =============================================================================
// ExampleLoadRowSecurityFromDatabase loads row security rules from database
// This implementation assumes a PostgreSQL function:
//
// CREATE FUNCTION core.api_sec_rowtemplate(
// p_schema TEXT,
// p_table TEXT,
// p_userid INTEGER
// ) RETURNS TABLE (
// p_retval INTEGER,
// p_errmsg TEXT,
// p_template TEXT,
// p_block BOOLEAN
// );
func ExampleLoadRowSecurityFromDatabase(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
record := RowSecurity{
Schema: pSchema,
Tablename: pTablename,
UserID: pUserID,
}
// rows, err := DBM.DBConn.Raw(`
// SELECT r.p_retval, r.p_errmsg, r.p_template, r.p_block
// FROM core.api_sec_rowtemplate(?, ?, ?) r
// `, pSchema, pTablename, pUserID).Rows()
// defer func() {
// if rows != nil {
// rows.Close()
// }
// }()
// if err != nil {
// return record, fmt.Errorf("failed to fetch row security from SQL: %v", err)
// }
// for rows.Next() {
// var retval int
// var errmsg string
// err = rows.Scan(&retval, &errmsg, &record.Template, &record.HasBlock)
// if err != nil {
// return record, fmt.Errorf("failed to scan row security: %v", err)
// }
// if retval != 0 {
// return RowSecurity{}, fmt.Errorf("api_sec_rowtemplate error: %s", errmsg)
// }
// }
return record, nil
}
// =============================================================================
// EXAMPLE 7: Row Security - Static Configuration
// =============================================================================
// ExampleLoadRowSecurityFromConfig loads row security from static config
func ExampleLoadRowSecurityFromConfig(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
// Define row security templates based on entity
templates := map[string]string{
"public.orders": "user_id = {UserID}", // Users see only their orders
"public.documents": "user_id = {UserID} OR is_public = true", // Users see their docs + public docs
"public.employees": "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})", // Complex filter
}
// Define blocked entities (no access at all)
blockedEntities := map[string][]int{
"public.admin_logs": {}, // All users blocked (empty list = block all)
"public.audit_logs": {1, 2, 3}, // Block users 1, 2, 3
}
key := fmt.Sprintf("%s.%s", pSchema, pTablename)
// Check if entity is blocked for this user
if blockedUsers, ok := blockedEntities[key]; ok {
if len(blockedUsers) == 0 {
// Block all users
return RowSecurity{
Schema: pSchema,
Tablename: pTablename,
UserID: pUserID,
HasBlock: true,
}, nil
}
// Check if specific user is blocked
for _, blockedUserID := range blockedUsers {
if blockedUserID == pUserID {
return RowSecurity{
Schema: pSchema,
Tablename: pTablename,
UserID: pUserID,
HasBlock: true,
}, nil
}
}
}
// Get template for this entity
template, ok := templates[key]
if !ok {
// No row security defined - allow all rows
return RowSecurity{
Schema: pSchema,
Tablename: pTablename,
UserID: pUserID,
Template: "",
HasBlock: false,
}, nil
}
return RowSecurity{
Schema: pSchema,
Tablename: pTablename,
UserID: pUserID,
Template: template,
HasBlock: false,
}, nil
}
// =============================================================================
// SETUP HELPER: Configure All Callbacks
// =============================================================================
// SetupCallbacksExample shows how to configure all callbacks
func SetupCallbacksExample() {
// Option 1: Use database-backed security (production)
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
// Option 2: Use static configuration (development/testing)
// GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
// GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
// GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig
// Option 3: Mix and match
// GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT
// GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
// GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
}

105
pkg/security/composite.go Normal file
View File

@@ -0,0 +1,105 @@
package security
import (
"context"
"fmt"
"net/http"
)
// CompositeSecurityProvider combines multiple security providers
// Allows separating authentication, column security, and row security concerns
type CompositeSecurityProvider struct {
auth Authenticator
colSec ColumnSecurityProvider
rowSec RowSecurityProvider
}
// NewCompositeSecurityProvider creates a composite provider
// All parameters are required
func NewCompositeSecurityProvider(
auth Authenticator,
colSec ColumnSecurityProvider,
rowSec RowSecurityProvider,
) *CompositeSecurityProvider {
if auth == nil {
panic("authenticator cannot be nil")
}
if colSec == nil {
panic("column security provider cannot be nil")
}
if rowSec == nil {
panic("row security provider cannot be nil")
}
return &CompositeSecurityProvider{
auth: auth,
colSec: colSec,
rowSec: rowSec,
}
}
// Login delegates to the authenticator
func (c *CompositeSecurityProvider) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
return c.auth.Login(ctx, req)
}
// Logout delegates to the authenticator
func (c *CompositeSecurityProvider) Logout(ctx context.Context, req LogoutRequest) error {
return c.auth.Logout(ctx, req)
}
// Authenticate delegates to the authenticator
func (c *CompositeSecurityProvider) Authenticate(r *http.Request) (*UserContext, error) {
return c.auth.Authenticate(r)
}
// GetColumnSecurity delegates to the column security provider
func (c *CompositeSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
return c.colSec.GetColumnSecurity(ctx, userID, schema, table)
}
// GetRowSecurity delegates to the row security provider
func (c *CompositeSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
return c.rowSec.GetRowSecurity(ctx, userID, schema, table)
}
// Optional interface implementations (if wrapped providers support them)
// RefreshToken implements Refreshable if the authenticator supports it
func (c *CompositeSecurityProvider) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
if refreshable, ok := c.auth.(Refreshable); ok {
return refreshable.RefreshToken(ctx, refreshToken)
}
return nil, fmt.Errorf("authenticator does not support token refresh")
}
// ValidateToken implements Validatable if the authenticator supports it
func (c *CompositeSecurityProvider) ValidateToken(ctx context.Context, token string) (bool, error) {
if validatable, ok := c.auth.(Validatable); ok {
return validatable.ValidateToken(ctx, token)
}
return false, fmt.Errorf("authenticator does not support token validation")
}
// ClearCache implements Cacheable if any provider supports it
func (c *CompositeSecurityProvider) ClearCache(ctx context.Context, userID int, schema, table string) error {
var errs []error
if cacheable, ok := c.colSec.(Cacheable); ok {
if err := cacheable.ClearCache(ctx, userID, schema, table); err != nil {
errs = append(errs, fmt.Errorf("column security cache clear failed: %w", err))
}
}
if cacheable, ok := c.rowSec.(Cacheable); ok {
if err := cacheable.ClearCache(ctx, userID, schema, table); err != nil {
errs = append(errs, fmt.Errorf("row security cache clear failed: %w", err))
}
}
if len(errs) > 0 {
return fmt.Errorf("cache clear errors: %v", errs)
}
return nil
}

View File

@@ -0,0 +1,428 @@
-- Database Schema for DatabaseAuthenticator
-- ============================================
-- Users table
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
email VARCHAR(255) NOT NULL UNIQUE,
password VARCHAR(255) NOT NULL, -- bcrypt hashed password
user_level INTEGER DEFAULT 0,
roles VARCHAR(500), -- Comma-separated roles: "admin,manager,user"
is_active BOOLEAN DEFAULT true,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_login_at TIMESTAMP
);
-- User sessions table for DatabaseAuthenticator
CREATE TABLE IF NOT EXISTS user_sessions (
id SERIAL PRIMARY KEY,
session_token VARCHAR(500) NOT NULL UNIQUE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address VARCHAR(45), -- IPv4 or IPv6
user_agent TEXT
);
CREATE INDEX IF NOT EXISTS idx_session_token ON user_sessions(session_token);
CREATE INDEX IF NOT EXISTS idx_user_id ON user_sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_expires_at ON user_sessions(expires_at);
-- Optional: Token blacklist for logout tracking (useful for JWT too)
CREATE TABLE IF NOT EXISTS token_blacklist (
id SERIAL PRIMARY KEY,
token VARCHAR(500) NOT NULL,
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_token ON token_blacklist(token);
CREATE INDEX IF NOT EXISTS idx_blacklist_expires_at ON token_blacklist(expires_at);
-- Example: Seed admin user (password should be hashed with bcrypt)
-- INSERT INTO users (username, email, password, user_level, roles, is_active)
-- VALUES ('admin', 'admin@example.com', '$2a$10$...', 10, 'admin,user', true);
-- Cleanup expired sessions (run periodically)
-- DELETE FROM user_sessions WHERE expires_at < NOW();
-- Cleanup expired blacklisted tokens (run periodically)
-- DELETE FROM token_blacklist WHERE expires_at < NOW();
-- ============================================
-- Stored Procedures for DatabaseAuthenticator
-- ============================================
-- 1. resolvespec_login - Authenticates user and creates session
-- Input: LoginRequest as jsonb {username: string, password: string, claims: object}
-- Output: p_success (bool), p_error (text), p_data (LoginResponse as jsonb)
CREATE OR REPLACE FUNCTION resolvespec_login(p_request jsonb)
RETURNS TABLE(p_success boolean, p_error text, p_data jsonb) AS $$
DECLARE
v_user_id INTEGER;
v_username TEXT;
v_email TEXT;
v_user_level INTEGER;
v_roles TEXT;
v_password_hash TEXT;
v_session_token TEXT;
v_expires_at TIMESTAMP;
v_ip_address TEXT;
v_user_agent TEXT;
BEGIN
-- Extract login request fields
v_username := p_request->>'username';
v_ip_address := p_request->'claims'->>'ip_address';
v_user_agent := p_request->'claims'->>'user_agent';
-- Validate user credentials
SELECT id, username, email, password, user_level, roles
INTO v_user_id, v_username, v_email, v_password_hash, v_user_level, v_roles
FROM users
WHERE username = v_username AND is_active = true;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb;
RETURN;
END IF;
-- TODO: Verify password hash using pgcrypto extension
-- Enable pgcrypto: CREATE EXTENSION IF NOT EXISTS pgcrypto;
-- IF NOT (crypt(p_request->>'password', v_password_hash) = v_password_hash) THEN
-- RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb;
-- RETURN;
-- END IF;
-- Generate session token
v_session_token := 'sess_' || encode(gen_random_bytes(32), 'hex') || '_' || extract(epoch from now())::bigint::text;
v_expires_at := now() + interval '24 hours';
-- Create session
INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES (v_session_token, v_user_id, v_expires_at, v_ip_address, v_user_agent, now());
-- Update last login time
UPDATE users SET last_login_at = now() WHERE id = v_user_id;
-- Return success with LoginResponse
RETURN QUERY SELECT
true,
NULL::text,
jsonb_build_object(
'token', v_session_token,
'user', jsonb_build_object(
'user_id', v_user_id,
'user_name', v_username,
'email', v_email,
'user_level', v_user_level,
'roles', string_to_array(COALESCE(v_roles, ''), ','),
'session_id', v_session_token
),
'expires_in', 86400 -- 24 hours in seconds
);
END;
$$ LANGUAGE plpgsql;
-- 2. resolvespec_logout - Invalidates session
-- Input: LogoutRequest as jsonb {token: string, user_id: int}
-- Output: p_success (bool), p_error (text), p_data (jsonb)
CREATE OR REPLACE FUNCTION resolvespec_logout(p_request jsonb)
RETURNS TABLE(p_success boolean, p_error text, p_data jsonb) AS $$
DECLARE
v_token TEXT;
v_user_id INTEGER;
v_deleted INTEGER;
BEGIN
v_token := p_request->>'token';
v_user_id := (p_request->>'user_id')::integer;
-- Remove Bearer prefix if present
v_token := regexp_replace(v_token, '^Bearer ', '', 'i');
-- Delete the session
DELETE FROM user_sessions
WHERE session_token = v_token AND user_id = v_user_id;
GET DIAGNOSTICS v_deleted = ROW_COUNT;
IF v_deleted = 0 THEN
RETURN QUERY SELECT false, 'Session not found'::text, NULL::jsonb;
ELSE
RETURN QUERY SELECT true, NULL::text, jsonb_build_object('success', true);
END IF;
END;
$$ LANGUAGE plpgsql;
-- 3. resolvespec_session - Validates session and returns user context
-- Input: sessionid (text), reference (text)
-- Output: p_success (bool), p_error (text), p_user (UserContext as jsonb)
CREATE OR REPLACE FUNCTION resolvespec_session(p_session_token text, p_reference text)
RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$
DECLARE
v_user_id INTEGER;
v_username TEXT;
v_email TEXT;
v_user_level INTEGER;
v_roles TEXT;
v_session_id TEXT;
BEGIN
-- Query session and user data
SELECT
s.user_id, u.username, u.email, u.user_level, u.roles, s.session_token
INTO
v_user_id, v_username, v_email, v_user_level, v_roles, v_session_id
FROM user_sessions s
JOIN users u ON s.user_id = u.id
WHERE s.session_token = p_session_token
AND s.expires_at > now()
AND u.is_active = true;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'Invalid or expired session'::text, NULL::jsonb;
RETURN;
END IF;
-- Return UserContext
RETURN QUERY SELECT
true,
NULL::text,
jsonb_build_object(
'user_id', v_user_id,
'user_name', v_username,
'email', v_email,
'user_level', v_user_level,
'session_id', v_session_id,
'roles', string_to_array(COALESCE(v_roles, ''), ',')
);
END;
$$ LANGUAGE plpgsql;
-- 4. resolvespec_session_update - Updates session activity timestamp
-- Input: sessionid (text), user_context (jsonb)
-- Output: p_success (bool), p_error (text), p_user (UserContext as jsonb)
CREATE OR REPLACE FUNCTION resolvespec_session_update(p_session_token text, p_user_context jsonb)
RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$
DECLARE
v_updated INTEGER;
BEGIN
-- Update last activity timestamp
UPDATE user_sessions
SET last_activity_at = now()
WHERE session_token = p_session_token AND expires_at > now();
GET DIAGNOSTICS v_updated = ROW_COUNT;
IF v_updated = 0 THEN
RETURN QUERY SELECT false, 'Session not found or expired'::text, NULL::jsonb;
ELSE
-- Return the user context as-is
RETURN QUERY SELECT true, NULL::text, p_user_context;
END IF;
END;
$$ LANGUAGE plpgsql;
-- 5. resolvespec_refresh_token - Generates new session from existing one
-- Input: sessionid (text), user_context (jsonb)
-- Output: p_success (bool), p_error (text), p_user (UserContext as jsonb with new session_id)
CREATE OR REPLACE FUNCTION resolvespec_refresh_token(p_old_session_token text, p_user_context jsonb)
RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$
DECLARE
v_user_id INTEGER;
v_username TEXT;
v_email TEXT;
v_user_level INTEGER;
v_roles TEXT;
v_new_session_token TEXT;
v_expires_at TIMESTAMP;
v_ip_address TEXT;
v_user_agent TEXT;
BEGIN
-- Verify old session exists and is valid
SELECT s.user_id, u.username, u.email, u.user_level, u.roles, s.ip_address, s.user_agent
INTO v_user_id, v_username, v_email, v_user_level, v_roles, v_ip_address, v_user_agent
FROM user_sessions s
JOIN users u ON s.user_id = u.id
WHERE s.session_token = p_old_session_token
AND s.expires_at > now()
AND u.is_active = true;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'Invalid or expired refresh token'::text, NULL::jsonb;
RETURN;
END IF;
-- Generate new session token
v_new_session_token := 'sess_' || encode(gen_random_bytes(32), 'hex') || '_' || extract(epoch from now())::bigint::text;
v_expires_at := now() + interval '24 hours';
-- Create new session
INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES (v_new_session_token, v_user_id, v_expires_at, v_ip_address, v_user_agent, now());
-- Delete old session
DELETE FROM user_sessions WHERE session_token = p_old_session_token;
-- Return UserContext with new session_id
RETURN QUERY SELECT
true,
NULL::text,
jsonb_build_object(
'user_id', v_user_id,
'user_name', v_username,
'email', v_email,
'user_level', v_user_level,
'session_id', v_new_session_token,
'roles', string_to_array(COALESCE(v_roles, ''), ',')
);
END;
$$ LANGUAGE plpgsql;
-- 6. resolvespec_jwt_login - JWT-based login (queries user and returns data for JWT token generation)
-- Input: username (text), password (text)
-- Output: p_success (bool), p_error (text), p_user (user data as jsonb)
CREATE OR REPLACE FUNCTION resolvespec_jwt_login(p_username text, p_password text)
RETURNS TABLE(p_success boolean, p_error text, p_user jsonb) AS $$
DECLARE
v_user_id INTEGER;
v_username TEXT;
v_email TEXT;
v_password TEXT;
v_user_level INTEGER;
v_roles TEXT;
BEGIN
-- Query user data
SELECT id, username, email, password, user_level, roles
INTO v_user_id, v_username, v_email, v_password, v_user_level, v_roles
FROM users
WHERE username = p_username AND is_active = true;
IF NOT FOUND THEN
RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb;
RETURN;
END IF;
-- TODO: Verify password hash
-- IF NOT (crypt(p_password, v_password) = v_password) THEN
-- RETURN QUERY SELECT false, 'Invalid credentials'::text, NULL::jsonb;
-- RETURN;
-- END IF;
-- Return user data for JWT token generation
RETURN QUERY SELECT
true,
NULL::text,
jsonb_build_object(
'id', v_user_id,
'username', v_username,
'email', v_email,
'password', v_password,
'user_level', v_user_level,
'roles', v_roles
);
END;
$$ LANGUAGE plpgsql;
-- 7. resolvespec_jwt_logout - Adds token to blacklist
-- Input: token (text), user_id (int)
-- Output: p_success (bool), p_error (text)
CREATE OR REPLACE FUNCTION resolvespec_jwt_logout(p_token text, p_user_id integer)
RETURNS TABLE(p_success boolean, p_error text) AS $$
BEGIN
-- Add token to blacklist
INSERT INTO token_blacklist (token, user_id, expires_at)
VALUES (p_token, p_user_id, now() + interval '24 hours');
RETURN QUERY SELECT true, NULL::text;
EXCEPTION
WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM::text;
END;
$$ LANGUAGE plpgsql;
-- 8. resolvespec_column_security - Loads column security rules for user
-- Input: user_id (int), schema (text), table_name (text)
-- Output: p_success (bool), p_error (text), p_rules (array of security rules as jsonb)
CREATE OR REPLACE FUNCTION resolvespec_column_security(p_user_id integer, p_schema text, p_table_name text)
RETURNS TABLE(p_success boolean, p_error text, p_rules jsonb) AS $$
DECLARE
v_rules jsonb;
BEGIN
-- Query column security rules from core.secaccess
SELECT jsonb_agg(
jsonb_build_object(
'control', control,
'accesstype', accesstype,
'jsonvalue', jsonvalue
)
)
INTO v_rules
FROM core.secaccess
WHERE rid_hub IN (
SELECT rid_hub_parent
FROM core.hub_link
WHERE rid_hub_child = p_user_id AND parent_hubtype = 'secgroup'
)
AND control ILIKE (p_schema || '.' || p_table_name || '%');
IF v_rules IS NULL THEN
v_rules := '[]'::jsonb;
END IF;
RETURN QUERY SELECT true, NULL::text, v_rules;
EXCEPTION
WHEN OTHERS THEN
RETURN QUERY SELECT false, SQLERRM::text, '[]'::jsonb;
END;
$$ LANGUAGE plpgsql;
-- 9. resolvespec_row_security - Loads row security template for user (replaces core.api_sec_rowtemplate)
-- Input: schema (text), table_name (text), user_id (int)
-- Output: p_template (text), p_block (bool)
CREATE OR REPLACE FUNCTION resolvespec_row_security(p_schema text, p_table_name text, p_user_id integer)
RETURNS TABLE(p_template text, p_block boolean) AS $$
BEGIN
-- Call the existing core function if it exists, or implement your own logic
-- This is a placeholder that you should customize based on your core.api_sec_rowtemplate logic
RETURN QUERY SELECT ''::text, false;
-- Example implementation:
-- RETURN QUERY SELECT template, has_block
-- FROM core.row_security_config
-- WHERE schema_name = p_schema AND table_name = p_table_name AND user_id = p_user_id;
END;
$$ LANGUAGE plpgsql;
-- ============================================
-- Example: Test stored procedures
-- ============================================
-- Test login
-- SELECT * FROM resolvespec_login('{"username": "admin", "password": "test123", "claims": {"ip_address": "127.0.0.1", "user_agent": "test"}}'::jsonb);
-- Test session validation
-- SELECT * FROM resolvespec_session('sess_abc123', 'test_reference');
-- Test session update
-- SELECT * FROM resolvespec_session_update('sess_abc123', '{"user_id": 1, "user_name": "admin"}'::jsonb);
-- Test token refresh
-- SELECT * FROM resolvespec_refresh_token('sess_abc123', '{"user_id": 1, "user_name": "admin"}'::jsonb);
-- Test logout
-- SELECT * FROM resolvespec_logout('{"token": "sess_abc123", "user_id": 1}'::jsonb);
-- Test JWT login
-- SELECT * FROM resolvespec_jwt_login('admin', 'password123');
-- Test JWT logout
-- SELECT * FROM resolvespec_jwt_logout('jwt_token_here', 1);
-- Test column security
-- SELECT * FROM resolvespec_column_security(1, 'public', 'users');
-- Test row security
-- SELECT * FROM resolvespec_row_security('public', 'users', 1);

391
pkg/security/examples.go Normal file
View File

@@ -0,0 +1,391 @@
package security
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"
// Optional: Uncomment if you want to use JWT authentication
// "github.com/golang-jwt/jwt/v5"
"gorm.io/gorm"
)
// Example 1: Simple Header-Based Authenticator
// =============================================
type HeaderAuthenticatorExample struct {
// Optional: Add any dependencies here (e.g., database, cache)
}
func NewHeaderAuthenticatorExample() *HeaderAuthenticatorExample {
return &HeaderAuthenticatorExample{}
}
func (a *HeaderAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// For header-based auth, login might not be used
// Could validate credentials against a database here
return nil, fmt.Errorf("header authentication does not support login")
}
func (a *HeaderAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For header-based auth, logout is a no-op
return nil
}
func (a *HeaderAuthenticatorExample) Authenticate(r *http.Request) (*UserContext, error) {
userIDStr := r.Header.Get("X-User-ID")
if userIDStr == "" {
return nil, fmt.Errorf("X-User-ID header required")
}
userID, err := strconv.Atoi(userIDStr)
if err != nil {
return nil, fmt.Errorf("invalid user ID: %w", err)
}
return &UserContext{
UserID: userID,
UserName: r.Header.Get("X-User-Name"),
UserLevel: parseIntHeader(r, "X-User-Level", 0),
SessionID: r.Header.Get("X-Session-ID"),
RemoteID: r.Header.Get("X-Remote-ID"),
Email: r.Header.Get("X-User-Email"),
Roles: parseRoles(r.Header.Get("X-User-Roles")),
Claims: make(map[string]any),
Meta: make(map[string]any),
}, nil
}
// Example 2: JWT Token Authenticator
// ====================================
// NOTE: To use this, uncomment the jwt import and install: go get github.com/golang-jwt/jwt/v5
type JWTAuthenticatorExample struct {
secretKey []byte
db *gorm.DB
}
func NewJWTAuthenticatorExample(secretKey string, db *gorm.DB) *JWTAuthenticatorExample {
return &JWTAuthenticatorExample{
secretKey: []byte(secretKey),
db: db,
}
}
func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Validate credentials against database
var user struct {
ID int
Username string
Email string
Password string // Should be hashed
UserLevel int
Roles string
}
err := a.db.WithContext(ctx).
Table("users").
Where("username = ?", req.Username).
First(&user).Error
if err != nil {
return nil, fmt.Errorf("invalid credentials")
}
// TODO: Verify password hash
// if !verifyPassword(user.Password, req.Password) {
// return nil, fmt.Errorf("invalid credentials")
// }
// Create JWT token
expiresAt := time.Now().Add(24 * time.Hour)
// Uncomment when using JWT:
// token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
// "user_id": user.ID,
// "username": user.Username,
// "email": user.Email,
// "user_level": user.UserLevel,
// "roles": user.Roles,
// "exp": expiresAt.Unix(),
// })
// tokenString, err := token.SignedString(a.secretKey)
// if err != nil {
// return nil, fmt.Errorf("failed to generate token: %w", err)
// }
// Placeholder token for example (replace with actual JWT)
tokenString := fmt.Sprintf("token_%d_%d", user.ID, expiresAt.Unix())
return &LoginResponse{
Token: tokenString,
User: &UserContext{
UserID: user.ID,
UserName: user.Username,
Email: user.Email,
UserLevel: user.UserLevel,
Roles: parseRoles(user.Roles),
Claims: req.Claims,
Meta: req.Meta,
},
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
}
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For JWT, logout could involve token blacklisting
// Add token to blacklist table
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
// "token": req.Token,
// "expires_at": time.Now().Add(24 * time.Hour),
// }).Error
return nil
}
func (a *JWTAuthenticatorExample) Authenticate(r *http.Request) (*UserContext, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return nil, fmt.Errorf("authorization header required")
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
return nil, fmt.Errorf("bearer token required")
}
// Uncomment when using JWT:
// token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
// return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
// }
// return a.secretKey, nil
// })
//
// if err != nil || !token.Valid {
// return nil, fmt.Errorf("invalid token: %w", err)
// }
//
// claims, ok := token.Claims.(jwt.MapClaims)
// if !ok {
// return nil, fmt.Errorf("invalid token claims")
// }
//
// return &UserContext{
// UserID: int(claims["user_id"].(float64)),
// UserName: getString(claims, "username"),
// Email: getString(claims, "email"),
// UserLevel: getInt(claims, "user_level"),
// Roles: parseRoles(getString(claims, "roles")),
// Claims: claims,
// }, nil
// Placeholder implementation (replace with actual JWT parsing)
return nil, fmt.Errorf("JWT parsing not implemented - uncomment JWT code above")
}
// Example 3: Database Session Authenticator
// ==========================================
type DatabaseAuthenticatorExample struct {
db *gorm.DB
}
func NewDatabaseAuthenticatorExample(db *gorm.DB) *DatabaseAuthenticatorExample {
return &DatabaseAuthenticatorExample{db: db}
}
func (a *DatabaseAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Query user from database
var user struct {
ID int
Username string
Email string
Password string // Should be hashed with bcrypt
UserLevel int
Roles string
IsActive bool
}
err := a.db.WithContext(ctx).
Table("users").
Where("username = ? AND is_active = true", req.Username).
First(&user).Error
if err != nil {
return nil, fmt.Errorf("invalid credentials")
}
// TODO: Verify password with bcrypt
// if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
// return nil, fmt.Errorf("invalid credentials")
// }
// Generate session token
sessionToken := fmt.Sprintf("sess_%s_%d", generateRandomString(32), time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Create session in database
err = a.db.WithContext(ctx).Table("user_sessions").Create(map[string]any{
"session_token": sessionToken,
"user_id": user.ID,
"expires_at": expiresAt,
"created_at": time.Now(),
"ip_address": req.Claims["ip_address"],
"user_agent": req.Claims["user_agent"],
}).Error
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
return &LoginResponse{
Token: sessionToken,
User: &UserContext{
UserID: user.ID,
UserName: user.Username,
Email: user.Email,
UserLevel: user.UserLevel,
Roles: parseRoles(user.Roles),
SessionID: sessionToken,
Claims: req.Claims,
Meta: req.Meta,
},
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
}
func (a *DatabaseAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// Delete session from database
err := a.db.WithContext(ctx).
Table("user_sessions").
Where("session_token = ? AND user_id = ?", req.Token, req.UserID).
Delete(nil).Error
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
return nil
}
func (a *DatabaseAuthenticatorExample) Authenticate(r *http.Request) (*UserContext, error) {
// Extract session token from header or cookie
sessionToken := r.Header.Get("Authorization")
if sessionToken == "" {
// Try cookie
cookie, err := r.Cookie("session_token")
if err == nil {
sessionToken = cookie.Value
}
} else {
// Remove "Bearer " prefix if present
sessionToken = strings.TrimPrefix(sessionToken, "Bearer ")
}
if sessionToken == "" {
return nil, fmt.Errorf("session token required")
}
// Query session and user from database
var session struct {
SessionToken string
UserID int
ExpiresAt time.Time
Username string
Email string
UserLevel int
Roles string
}
query := `
SELECT
s.session_token,
s.user_id,
s.expires_at,
u.username,
u.email,
u.user_level,
u.roles
FROM user_sessions s
JOIN users u ON s.user_id = u.id
WHERE s.session_token = ?
AND s.expires_at > ?
AND u.is_active = true
`
err := a.db.Raw(query, sessionToken, time.Now()).Scan(&session).Error
if err != nil {
return nil, fmt.Errorf("invalid or expired session")
}
// Update last activity timestamp
go a.updateSessionActivity(sessionToken)
return &UserContext{
UserID: session.UserID,
UserName: session.Username,
Email: session.Email,
UserLevel: session.UserLevel,
SessionID: sessionToken,
Roles: parseRoles(session.Roles),
Claims: make(map[string]any),
Meta: make(map[string]any),
}, nil
}
// updateSessionActivity updates the last activity timestamp for the session
func (a *DatabaseAuthenticatorExample) updateSessionActivity(sessionToken string) {
a.db.Table("user_sessions").
Where("session_token = ?", sessionToken).
Update("last_activity_at", time.Now())
}
// Optional: Implement Refreshable interface
func (a *DatabaseAuthenticatorExample) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
// Query the refresh token
var session struct {
UserID int
Username string
Email string
}
err := a.db.WithContext(ctx).Raw(`
SELECT u.id as user_id, u.username, u.email
FROM user_sessions s
JOIN users u ON s.user_id = u.id
WHERE s.session_token = ? AND s.expires_at > ?
`, refreshToken, time.Now()).Scan(&session).Error
if err != nil {
return nil, fmt.Errorf("invalid refresh token")
}
// Generate new session token
newSessionToken := fmt.Sprintf("sess_%s_%d", generateRandomString(32), time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Create new session
err = a.db.WithContext(ctx).Table("user_sessions").Create(map[string]any{
"session_token": newSessionToken,
"user_id": session.UserID,
"expires_at": expiresAt,
"created_at": time.Now(),
}).Error
if err != nil {
return nil, fmt.Errorf("failed to create new session: %w", err)
}
// Delete old session
a.db.WithContext(ctx).Table("user_sessions").Where("session_token = ?", refreshToken).Delete(nil)
return &LoginResponse{
Token: newSessionToken,
User: &UserContext{
UserID: session.UserID,
UserName: session.Username,
Email: session.Email,
SessionID: newSessionToken,
Claims: make(map[string]any),
Meta: make(map[string]any),
},
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
}

View File

@@ -1,59 +1,51 @@
package security
import (
"context"
"fmt"
"reflect"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *restheadspec.Handler, securityList *SecurityList) {
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(restheadspec.BeforeRead, func(hookCtx *restheadspec.HookContext) error {
return loadSecurityRules(hookCtx, securityList)
})
// Hook 2: BeforeScan - Apply row-level security filters
handler.Hooks().Register(restheadspec.BeforeScan, func(hookCtx *restheadspec.HookContext) error {
return applyRowSecurity(hookCtx, securityList)
})
// Hook 3: AfterRead - Apply column-level security (masking)
handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error {
return applyColumnSecurity(hookCtx, securityList)
})
// Hook 4 (Optional): Audit logging
handler.Hooks().Register(restheadspec.AfterRead, logDataAccess)
// SecurityContext is a generic interface that any spec can implement to integrate with security features
// This interface abstracts the common security context needs across different specs
type SecurityContext interface {
GetContext() context.Context
GetUserID() (int, bool)
GetSchema() string
GetEntity() string
GetModel() interface{}
GetQuery() interface{}
SetQuery(interface{})
GetResult() interface{}
SetResult(interface{})
}
// loadSecurityRules loads security configuration for the user and entity
func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
// loadSecurityRules loads security configuration for the user and entity (generic version)
func loadSecurityRules(secCtx SecurityContext, securityList *SecurityList) error {
// Extract user ID from context
userID, ok := GetUserID(hookCtx.Context)
userID, ok := secCtx.GetUserID()
if !ok {
logger.Warn("No user ID in context for security check")
return fmt.Errorf("authentication required")
return nil
}
schema := hookCtx.Schema
tablename := hookCtx.Entity
schema := secCtx.GetSchema()
tablename := secCtx.GetEntity()
logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename)
// Load column security rules from database
err := securityList.LoadColumnSecurity(userID, schema, tablename, false)
// Load column security rules using the provider
err := securityList.LoadColumnSecurity(secCtx.GetContext(), userID, schema, tablename, false)
if err != nil {
logger.Warn("Failed to load column security: %v", err)
// Don't fail the request if no security rules exist
// return err
}
// Load row security rules from database
_, err = securityList.LoadRowSecurity(userID, schema, tablename, false)
// Load row security rules using the provider
_, err = securityList.LoadRowSecurity(secCtx.GetContext(), userID, schema, tablename, false)
if err != nil {
logger.Warn("Failed to load row security: %v", err)
// Don't fail the request if no security rules exist
@@ -63,15 +55,15 @@ func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *Security
return nil
}
// applyRowSecurity applies row-level security filters to the query
func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
userID, ok := GetUserID(hookCtx.Context)
// applyRowSecurity applies row-level security filters to the query (generic version)
func applyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error {
userID, ok := secCtx.GetUserID()
if !ok {
return nil // No user context, skip
}
schema := hookCtx.Schema
tablename := hookCtx.Entity
schema := secCtx.GetSchema()
tablename := secCtx.GetEntity()
// Get row security template
rowSec, err := securityList.GetRowSecurityTemplate(userID, schema, tablename)
@@ -89,8 +81,14 @@ func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityL
// If there's a security template, apply it as a WHERE clause
if rowSec.Template != "" {
model := secCtx.GetModel()
if model == nil {
logger.Debug("No model available for row security on %s.%s", schema, tablename)
return nil
}
// Get primary key name from model
modelType := reflect.TypeOf(hookCtx.Model)
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
@@ -117,39 +115,45 @@ func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityL
userID, schema, tablename, whereClause)
// Apply the WHERE clause to the query
// The query is in hookCtx.Query
if selectQuery, ok := hookCtx.Query.(interface {
query := secCtx.GetQuery()
if selectQuery, ok := query.(interface {
Where(string, ...interface{}) interface{}
}); ok {
hookCtx.Query = selectQuery.Where(whereClause)
secCtx.SetQuery(selectQuery.Where(whereClause))
} else {
logger.Error("Unable to apply WHERE clause - query doesn't support Where method")
logger.Debug("Query doesn't support Where method, skipping row security")
}
}
return nil
}
// applyColumnSecurity applies column-level security (masking/hiding) to results
func applyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
userID, ok := GetUserID(hookCtx.Context)
// applyColumnSecurity applies column-level security (masking/hiding) to results (generic version)
func applyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) error {
userID, ok := secCtx.GetUserID()
if !ok {
return nil // No user context, skip
}
schema := hookCtx.Schema
tablename := hookCtx.Entity
schema := secCtx.GetSchema()
tablename := secCtx.GetEntity()
// Get result data
result := hookCtx.Result
result := secCtx.GetResult()
if result == nil {
return nil
}
logger.Debug("Applying column security for user=%d, schema=%s, table=%s", userID, schema, tablename)
model := secCtx.GetModel()
if model == nil {
logger.Debug("No model available for column security on %s.%s", schema, tablename)
return nil
}
// Get model type
modelType := reflect.TypeOf(hookCtx.Model)
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
@@ -169,37 +173,59 @@ func applyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *Securi
// Update the result with masked data
if maskedResult.IsValid() && maskedResult.CanInterface() {
hookCtx.Result = maskedResult.Interface()
secCtx.SetResult(maskedResult.Interface())
}
return nil
}
// logDataAccess logs all data access for audit purposes
func logDataAccess(hookCtx *restheadspec.HookContext) error {
userID, _ := GetUserID(hookCtx.Context)
// logDataAccess logs all data access for audit purposes (generic version)
func logDataAccess(secCtx SecurityContext) error {
userID, _ := secCtx.GetUserID()
logger.Info("AUDIT: User %d accessed %s.%s with filters: %+v",
logger.Info("AUDIT: User %d accessed %s.%s",
userID,
hookCtx.Schema,
hookCtx.Entity,
hookCtx.Options.Filters,
secCtx.GetSchema(),
secCtx.GetEntity(),
)
// TODO: Write to audit log table or external audit service
// auditLog := AuditLog{
// UserID: userID,
// Schema: hookCtx.Schema,
// Entity: hookCtx.Entity,
// Schema: secCtx.GetSchema(),
// Entity: secCtx.GetEntity(),
// Action: "READ",
// Timestamp: time.Now(),
// Filters: hookCtx.Options.Filters,
// }
// db.Create(&auditLog)
return nil
}
// LogDataAccess is a public wrapper for logDataAccess that accepts a SecurityContext
// This allows other packages to use the audit logging functionality
func LogDataAccess(secCtx SecurityContext) error {
return logDataAccess(secCtx)
}
// LoadSecurityRules is a public wrapper for loadSecurityRules that accepts a SecurityContext
// This allows other packages to load security rules using the generic interface
func LoadSecurityRules(secCtx SecurityContext, securityList *SecurityList) error {
return loadSecurityRules(secCtx, securityList)
}
// ApplyRowSecurity is a public wrapper for applyRowSecurity that accepts a SecurityContext
// This allows other packages to apply row-level security using the generic interface
func ApplyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error {
return applyRowSecurity(secCtx, securityList)
}
// ApplyColumnSecurity is a public wrapper for applyColumnSecurity that accepts a SecurityContext
// This allows other packages to apply column-level security using the generic interface
func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) error {
return applyColumnSecurity(secCtx, securityList)
}
// Helper functions
func contains(s, substr string) bool {

View File

@@ -0,0 +1,93 @@
package security
import (
"context"
"net/http"
)
// UserContext holds authenticated user information
type UserContext struct {
UserID int `json:"user_id"`
UserName string `json:"user_name"`
UserLevel int `json:"user_level"`
SessionID string `json:"session_id"`
RemoteID string `json:"remote_id"`
Roles []string `json:"roles"`
Email string `json:"email"`
Claims map[string]any `json:"claims"`
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
}
// LoginRequest contains credentials for login
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
Claims map[string]any `json:"claims"` // Additional login data
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
}
// LoginResponse contains the result of a login attempt
type LoginResponse struct {
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
User *UserContext `json:"user"`
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
}
// LogoutRequest contains information for logout
type LogoutRequest struct {
Token string `json:"token"`
UserID int `json:"user_id"`
}
// Authenticator handles user authentication operations
type Authenticator interface {
// Login authenticates credentials and returns a token
Login(ctx context.Context, req LoginRequest) (*LoginResponse, error)
// Logout invalidates a user's session/token
Logout(ctx context.Context, req LogoutRequest) error
// Authenticate extracts and validates user from HTTP request
// Returns UserContext or error if authentication fails
Authenticate(r *http.Request) (*UserContext, error)
}
// ColumnSecurityProvider handles column-level security (masking/hiding)
type ColumnSecurityProvider interface {
// GetColumnSecurity loads column security rules for a user and entity
GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error)
}
// RowSecurityProvider handles row-level security (filtering)
type RowSecurityProvider interface {
// GetRowSecurity loads row security rules for a user and entity
GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error)
}
// SecurityProvider is the main interface combining all security concerns
type SecurityProvider interface {
Authenticator
ColumnSecurityProvider
RowSecurityProvider
}
// Optional interfaces for advanced functionality
// Refreshable allows providers to support token refresh
type Refreshable interface {
// RefreshToken exchanges a refresh token for a new access token
RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error)
}
// Validatable allows providers to validate tokens without full authentication
type Validatable interface {
// ValidateToken checks if a token is valid without extracting full user context
ValidateToken(ctx context.Context, token string) (bool, error)
}
// Cacheable allows providers to support caching of security rules
type Cacheable interface {
// ClearCache clears cached security rules for a user/entity
ClearCache(ctx context.Context, userID int, schema, table string) error
}

View File

@@ -10,48 +10,391 @@ type contextKey string
const (
// Context keys for user information
UserIDKey contextKey = "user_id"
UserRolesKey contextKey = "user_roles"
UserTokenKey contextKey = "user_token"
UserIDKey contextKey = "user_id"
UserNameKey contextKey = "user_name"
UserLevelKey contextKey = "user_level"
SessionIDKey contextKey = "session_id"
RemoteIDKey contextKey = "remote_id"
UserRolesKey contextKey = "user_roles"
UserEmailKey contextKey = "user_email"
UserContextKey contextKey = "user_context"
UserMetaKey contextKey = "user_meta"
SkipAuthKey contextKey = "skip_auth"
OptionalAuthKey contextKey = "optional_auth"
)
// AuthMiddleware extracts user authentication from request and adds to context
// This should be applied before the ResolveSpec handler
// Uses GlobalSecurity.AuthenticateCallback if set, otherwise returns error
func AuthMiddleware(next http.Handler) http.Handler {
// SkipAuth returns a context with skip auth flag set to true
// Use this to mark routes that should bypass authentication middleware
func SkipAuth(ctx context.Context) context.Context {
return context.WithValue(ctx, SkipAuthKey, true)
}
// OptionalAuth returns a context with optional auth flag set to true
// Use this to mark routes that should try to authenticate, but fall back to guest if authentication fails
func OptionalAuth(ctx context.Context) context.Context {
return context.WithValue(ctx, OptionalAuthKey, true)
}
// createGuestContext creates a guest user context for unauthenticated requests
func createGuestContext(r *http.Request) *UserContext {
return &UserContext{
UserID: 0,
UserName: "guest",
UserLevel: 0,
SessionID: "",
RemoteID: r.RemoteAddr,
Roles: []string{"guest"},
Email: "",
Claims: map[string]any{},
Meta: map[string]any{},
}
}
// setUserContext adds a user context to the request context
func setUserContext(r *http.Request, userCtx *UserContext) *http.Request {
ctx := r.Context()
ctx = context.WithValue(ctx, UserContextKey, userCtx)
ctx = context.WithValue(ctx, UserIDKey, userCtx.UserID)
ctx = context.WithValue(ctx, UserNameKey, userCtx.UserName)
ctx = context.WithValue(ctx, UserLevelKey, userCtx.UserLevel)
ctx = context.WithValue(ctx, SessionIDKey, userCtx.SessionID)
ctx = context.WithValue(ctx, RemoteIDKey, userCtx.RemoteID)
ctx = context.WithValue(ctx, UserRolesKey, userCtx.Roles)
if userCtx.Email != "" {
ctx = context.WithValue(ctx, UserEmailKey, userCtx.Email)
}
if len(userCtx.Meta) > 0 {
ctx = context.WithValue(ctx, UserMetaKey, userCtx.Meta)
}
return r.WithContext(ctx)
}
// authenticateRequest performs authentication and adds user context to the request
// This is the shared authentication logic used by both handler and middleware
func authenticateRequest(w http.ResponseWriter, r *http.Request, provider SecurityProvider) (*http.Request, bool) {
// Call the provider's Authenticate method
userCtx, err := provider.Authenticate(r)
if err != nil {
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
return nil, false
}
return setUserContext(r, userCtx), true
}
// NewAuthHandler creates an authentication handler that can be used standalone
// This handler performs authentication and returns 401 if authentication fails
// Use this when you need authentication logic without middleware wrapping
func NewAuthHandler(securityList *SecurityList, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if callback is set
if GlobalSecurity.AuthenticateCallback == nil {
http.Error(w, "AuthenticateCallback not set - you must provide an authentication callback", http.StatusInternalServerError)
// Get the security provider
provider := securityList.Provider()
if provider == nil {
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
return
}
// Call the user-provided authentication callback
userID, roles, err := GlobalSecurity.AuthenticateCallback(r)
if err != nil {
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
return
}
// Add user information to context
ctx := context.WithValue(r.Context(), UserIDKey, userID)
if roles != "" {
ctx = context.WithValue(ctx, UserRolesKey, roles)
// Authenticate the request
authenticatedReq, ok := authenticateRequest(w, r, provider)
if !ok {
return // authenticateRequest already wrote the error response
}
// Continue with authenticated context
next.ServeHTTP(w, r.WithContext(ctx))
next.ServeHTTP(w, authenticatedReq)
})
}
// NewOptionalAuthHandler creates an optional authentication handler that can be used standalone
// This handler tries to authenticate but falls back to guest context if authentication fails
// Use this for routes that should show personalized content for authenticated users but still work for guests
func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the security provider
provider := securityList.Provider()
if provider == nil {
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
return
}
// Try to authenticate
userCtx, err := provider.Authenticate(r)
if err != nil {
// Authentication failed - set guest context and continue
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
// Authentication succeeded - set user context
next.ServeHTTP(w, setUserContext(r, userCtx))
})
}
// NewAuthMiddleware creates an authentication middleware with the given security list
// This middleware extracts user authentication from the request and adds it to context
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
// Routes can use optional authentication by setting OptionalAuthKey context value (use OptionalAuth helper)
// When authentication is skipped or fails with optional auth, a guest user context is set instead
func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if this route should skip authentication
if skip, ok := r.Context().Value(SkipAuthKey).(bool); ok && skip {
// Set guest user context for skipped routes
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
// Get the security provider
provider := securityList.Provider()
if provider == nil {
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
return
}
// Check if this route has optional authentication
optional, _ := r.Context().Value(OptionalAuthKey).(bool)
// Try to authenticate
userCtx, err := provider.Authenticate(r)
if err != nil {
if optional {
// Optional auth failed - set guest context and continue
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
// Required auth failed - return error
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
return
}
// Authentication succeeded - set user context
next.ServeHTTP(w, setUserContext(r, userCtx))
})
}
}
// SetSecurityMiddleware adds security context to requests
// This middleware should be applied after AuthMiddleware
func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, securityList)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// GetUserContext extracts the full user context from request context
func GetUserContext(ctx context.Context) (*UserContext, bool) {
userCtx, ok := ctx.Value(UserContextKey).(*UserContext)
return userCtx, ok
}
// GetUserID extracts the user ID from context
func GetUserID(ctx context.Context) (int, bool) {
userID, ok := ctx.Value(UserIDKey).(int)
return userID, ok
}
// GetUserName extracts the user name from context
func GetUserName(ctx context.Context) (string, bool) {
userName, ok := ctx.Value(UserNameKey).(string)
return userName, ok
}
// GetUserLevel extracts the user level from context
func GetUserLevel(ctx context.Context) (int, bool) {
userLevel, ok := ctx.Value(UserLevelKey).(int)
return userLevel, ok
}
// GetSessionID extracts the session ID from context
func GetSessionID(ctx context.Context) (string, bool) {
sessionID, ok := ctx.Value(SessionIDKey).(string)
return sessionID, ok
}
// GetRemoteID extracts the remote ID from context
func GetRemoteID(ctx context.Context) (string, bool) {
remoteID, ok := ctx.Value(RemoteIDKey).(string)
return remoteID, ok
}
// GetUserRoles extracts user roles from context
func GetUserRoles(ctx context.Context) (string, bool) {
roles, ok := ctx.Value(UserRolesKey).(string)
func GetUserRoles(ctx context.Context) ([]string, bool) {
roles, ok := ctx.Value(UserRolesKey).([]string)
return roles, ok
}
// GetUserEmail extracts user email from context
func GetUserEmail(ctx context.Context) (string, bool) {
email, ok := ctx.Value(UserEmailKey).(string)
return email, ok
}
// GetUserMeta extracts user metadata from context
func GetUserMeta(ctx context.Context) (map[string]any, bool) {
meta, ok := ctx.Value(UserMetaKey).(map[string]any)
return meta, ok
}
// // Handler adapters for resolvespec/restheadspec compatibility
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions
// // SpecHandlerAdapter is an interface for handler adapters that need authentication
// // Implement this interface to create adapters for custom handler types
// type SpecHandlerAdapter interface {
// // AdaptToHTTPHandler converts the custom handler to a standard http.Handler
// AdaptToHTTPHandler() http.Handler
// }
// // ResolveSpecHandlerAdapter adapts a resolvespec/restheadspec handler method to http.Handler
// type ResolveSpecHandlerAdapter struct {
// // HandlerMethod is the method to call (e.g., handler.Handle, handler.HandleGet)
// HandlerMethod func(w any, r any, params map[string]string)
// // Params are the route parameters (e.g., {"schema": "public", "entity": "users"})
// Params map[string]string
// // RequestAdapter converts *http.Request to the custom Request interface
// // Use router.NewHTTPRequest from pkg/common/adapters/router
// RequestAdapter func(*http.Request) any
// // ResponseAdapter converts http.ResponseWriter to the custom ResponseWriter interface
// // Use router.NewHTTPResponseWriter from pkg/common/adapters/router
// ResponseAdapter func(http.ResponseWriter) any
// }
// // AdaptToHTTPHandler implements SpecHandlerAdapter
// func (a *ResolveSpecHandlerAdapter) AdaptToHTTPHandler() http.Handler {
// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// req := a.RequestAdapter(r)
// resp := a.ResponseAdapter(w)
// a.HandlerMethod(resp, req, a.Params)
// })
// }
// // WrapSpecHandler wraps a spec handler adapter with authentication
// // Use this to apply NewAuthHandler or NewOptionalAuthHandler to resolvespec/restheadspec handlers
// //
// // Example with required auth:
// //
// // adapter := &security.ResolveSpecHandlerAdapter{
// // HandlerMethod: handler.Handle,
// // Params: map[string]string{"schema": "public", "entity": "users"},
// // RequestAdapter: func(r *http.Request) any { return router.NewHTTPRequest(r) },
// // ResponseAdapter: func(w http.ResponseWriter) any { return router.NewHTTPResponseWriter(w) },
// // }
// // authHandler := security.WrapSpecHandler(securityList, adapter, false)
// // muxRouter.Handle("/api/users", authHandler)
// func WrapSpecHandler(securityList *SecurityList, adapter SpecHandlerAdapter, optional bool) http.Handler {
// httpHandler := adapter.AdaptToHTTPHandler()
// if optional {
// return NewOptionalAuthHandler(securityList, httpHandler)
// }
// return NewAuthHandler(securityList, httpHandler)
// }
// // MuxRouteBuilder helps build authenticated routes with Gorilla Mux
// type MuxRouteBuilder struct {
// securityList *SecurityList
// requestAdapter func(*http.Request) any
// responseAdapter func(http.ResponseWriter) any
// paramExtractor func(*http.Request) map[string]string
// }
// // NewMuxRouteBuilder creates a route builder for Gorilla Mux with standard router adapters
// // Usage:
// //
// // builder := security.NewMuxRouteBuilder(securityList, router.NewHTTPRequest, router.NewHTTPResponseWriter)
// func NewMuxRouteBuilder(
// securityList *SecurityList,
// requestAdapter func(*http.Request) any,
// responseAdapter func(http.ResponseWriter) any,
// ) *MuxRouteBuilder {
// return &MuxRouteBuilder{
// securityList: securityList,
// requestAdapter: requestAdapter,
// responseAdapter: responseAdapter,
// paramExtractor: nil, // Will be set per route using mux.Vars
// }
// }
// // HandleAuth creates an authenticated route handler
// // pattern: the route pattern (e.g., "/{schema}/{entity}")
// // handler: the handler method to call (e.g., handler.Handle)
// // optional: true for optional auth (guest fallback), false for required auth (401 on failure)
// // methods: HTTP methods (e.g., "GET", "POST")
// //
// // Usage:
// //
// // builder.HandleAuth(router, "/{schema}/{entity}", handler.Handle, false, "POST")
// func (b *MuxRouteBuilder) HandleAuth(
// router interface {
// HandleFunc(pattern string, f func(http.ResponseWriter, *http.Request)) interface{ Methods(...string) interface{} }
// },
// pattern string,
// handlerMethod func(w any, r any, params map[string]string),
// optional bool,
// methods ...string,
// ) {
// router.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
// // Extract params using the registered extractor or default to empty map
// var params map[string]string
// if b.paramExtractor != nil {
// params = b.paramExtractor(r)
// } else {
// params = make(map[string]string)
// }
// adapter := &ResolveSpecHandlerAdapter{
// HandlerMethod: handlerMethod,
// Params: params,
// RequestAdapter: b.requestAdapter,
// ResponseAdapter: b.responseAdapter,
// }
// authHandler := WrapSpecHandler(b.securityList, adapter, optional)
// authHandler.ServeHTTP(w, r)
// }).Methods(methods...)
// }
// // SetParamExtractor sets a custom parameter extractor function
// // For Gorilla Mux, you would use: builder.SetParamExtractor(mux.Vars)
// func (b *MuxRouteBuilder) SetParamExtractor(extractor func(*http.Request) map[string]string) {
// b.paramExtractor = extractor
// }
// // SetupAuthenticatedSpecRoutes sets up all standard resolvespec/restheadspec routes with authentication
// // This is a convenience function that sets up the common route patterns
// //
// // Usage:
// //
// // security.SetupAuthenticatedSpecRoutes(router, handler, securityList, router.NewHTTPRequest, router.NewHTTPResponseWriter, mux.Vars)
// func SetupAuthenticatedSpecRoutes(
// router interface {
// HandleFunc(pattern string, f func(http.ResponseWriter, *http.Request)) interface{ Methods(...string) interface{} }
// },
// handler interface {
// Handle(w any, r any, params map[string]string)
// HandleGet(w any, r any, params map[string]string)
// },
// securityList *SecurityList,
// requestAdapter func(*http.Request) any,
// responseAdapter func(http.ResponseWriter) any,
// paramExtractor func(*http.Request) map[string]string,
// ) {
// builder := NewMuxRouteBuilder(securityList, requestAdapter, responseAdapter)
// builder.SetParamExtractor(paramExtractor)
// // POST /{schema}/{entity}
// builder.HandleAuth(router, "/{schema}/{entity}", handler.Handle, false, "POST")
// // POST /{schema}/{entity}/{id}
// builder.HandleAuth(router, "/{schema}/{entity}/{id}", handler.Handle, false, "POST")
// // GET /{schema}/{entity}
// builder.HandleAuth(router, "/{schema}/{entity}", handler.HandleGet, false, "GET")
// }

View File

@@ -3,7 +3,6 @@ package security
import (
"context"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
@@ -16,26 +15,26 @@ import (
)
type ColumnSecurity struct {
Schema string
Tablename string
Path []string
ExtraFilters map[string]string
UserID int
Accesstype string `json:"accesstype"`
MaskStart int
MaskEnd int
MaskInvert bool
MaskChar string
Control string `json:"control"`
ID int `json:"id"`
Schema string `json:"schema"`
Tablename string `json:"tablename"`
Path []string `json:"path"`
ExtraFilters map[string]string `json:"extra_filters"`
UserID int `json:"user_id"`
Accesstype string `json:"accesstype"`
MaskStart int `json:"mask_start"`
MaskEnd int `json:"mask_end"`
MaskInvert bool `json:"mask_invert"`
MaskChar string `json:"mask_char"`
Control string `json:"control"`
ID int `json:"id"`
}
type RowSecurity struct {
Schema string
Tablename string
Template string
HasBlock bool
UserID int
Schema string `json:"schema"`
Tablename string `json:"tablename"`
Template string `json:"template"`
HasBlock bool `json:"has_block"`
UserID int `json:"user_id"`
}
func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Type) string {
@@ -47,46 +46,39 @@ func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Typ
return str
}
// Callback function types for customizing security behavior
type (
// AuthenticateFunc extracts user ID and roles from HTTP request
// Return userID, roles, error. If error is not nil, request will be rejected.
AuthenticateFunc func(r *http.Request) (userID int, roles string, err error)
// LoadColumnSecurityFunc loads column security rules for a user and entity
// Override this to customize how column security is loaded from your data source
LoadColumnSecurityFunc func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error)
// LoadRowSecurityFunc loads row security rules for a user and entity
// Override this to customize how row security is loaded from your data source
LoadRowSecurityFunc func(pUserID int, pSchema, pTablename string) (RowSecurity, error)
)
// SecurityList manages security state and caching
// It wraps a SecurityProvider and provides caching and utility methods
type SecurityList struct {
provider SecurityProvider
ColumnSecurityMutex sync.RWMutex
ColumnSecurity map[string][]ColumnSecurity
RowSecurityMutex sync.RWMutex
RowSecurity map[string]RowSecurity
// Overridable callbacks
AuthenticateCallback AuthenticateFunc
LoadColumnSecurityCallback LoadColumnSecurityFunc
LoadRowSecurityCallback LoadRowSecurityFunc
}
// NewSecurityList creates a new security list with the given provider
func NewSecurityList(provider SecurityProvider) *SecurityList {
if provider == nil {
panic("security provider cannot be nil")
}
return &SecurityList{
provider: provider,
ColumnSecurity: make(map[string][]ColumnSecurity),
RowSecurity: make(map[string]RowSecurity),
}
}
// Provider returns the underlying security provider
func (m *SecurityList) Provider() SecurityProvider {
return m.provider
}
type CONTEXT_KEY string
const SECURITY_CONTEXT_KEY CONTEXT_KEY = "SecurityList"
var GlobalSecurity SecurityList
// SetSecurityMiddleware adds security context to requests
func SetSecurityMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, &GlobalSecurity)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func maskString(pString string, maskStart, maskEnd int, maskChar string, invert bool) string {
strLen := len(pString)
middleIndex := (strLen / 2)
@@ -372,10 +364,9 @@ func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType refl
return records, nil
}
func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) error {
// Use the callback if provided
if m.LoadColumnSecurityCallback == nil {
return fmt.Errorf("LoadColumnSecurityCallback not set - you must provide a callback function")
func (m *SecurityList) LoadColumnSecurity(ctx context.Context, pUserID int, pSchema, pTablename string, pOverwrite bool) error {
if m.provider == nil {
return fmt.Errorf("security provider not set")
}
m.ColumnSecurityMutex.Lock()
@@ -390,10 +381,10 @@ func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename strin
m.ColumnSecurity[secKey] = make([]ColumnSecurity, 0)
}
// Call the user-provided callback to load security rules
colSecList, err := m.LoadColumnSecurityCallback(pUserID, pSchema, pTablename)
// Call the provider to load security rules
colSecList, err := m.provider.GetColumnSecurity(ctx, pUserID, pSchema, pTablename)
if err != nil {
return fmt.Errorf("LoadColumnSecurityCallback failed: %v", err)
return fmt.Errorf("GetColumnSecurity failed: %v", err)
}
m.ColumnSecurity[secKey] = colSecList
@@ -422,10 +413,9 @@ func (m *SecurityList) ClearSecurity(pUserID int, pSchema, pTablename string) er
return nil
}
func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) {
// Use the callback if provided
if m.LoadRowSecurityCallback == nil {
return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback not set - you must provide a callback function")
func (m *SecurityList) LoadRowSecurity(ctx context.Context, pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) {
if m.provider == nil {
return RowSecurity{}, fmt.Errorf("security provider not set")
}
m.RowSecurityMutex.Lock()
@@ -436,10 +426,10 @@ func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string,
}
secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
// Call the user-provided callback to load security rules
record, err := m.LoadRowSecurityCallback(pUserID, pSchema, pTablename)
// Call the provider to load security rules
record, err := m.provider.GetRowSecurity(ctx, pUserID, pSchema, pTablename)
if err != nil {
return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback failed: %v", err)
return RowSecurity{}, fmt.Errorf("GetRowSecurity failed: %v", err)
}
m.RowSecurity[secKey] = record

552
pkg/security/providers.go Normal file
View File

@@ -0,0 +1,552 @@
package security
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
// Production-Ready Authenticators
// =================================
// HeaderAuthenticator provides simple header-based authentication
// Expects: X-User-ID, X-User-Name, X-User-Level, X-Session-ID, X-Remote-ID, X-User-Roles, X-User-Email
type HeaderAuthenticator struct{}
func NewHeaderAuthenticator() *HeaderAuthenticator {
return &HeaderAuthenticator{}
}
func (a *HeaderAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
return nil, fmt.Errorf("header authentication does not support login")
}
func (a *HeaderAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
return nil
}
func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
userIDStr := r.Header.Get("X-User-ID")
if userIDStr == "" {
return nil, fmt.Errorf("X-User-ID header required")
}
userID, err := strconv.Atoi(userIDStr)
if err != nil {
return nil, fmt.Errorf("invalid user ID: %w", err)
}
return &UserContext{
UserID: userID,
UserName: r.Header.Get("X-User-Name"),
UserLevel: parseIntHeader(r, "X-User-Level", 0),
SessionID: r.Header.Get("X-Session-ID"),
RemoteID: r.Header.Get("X-Remote-ID"),
Email: r.Header.Get("X-User-Email"),
Roles: parseRoles(r.Header.Get("X-User-Roles")),
}, nil
}
// DatabaseAuthenticator provides session-based authentication with database storage
// All database operations go through stored procedures for security and consistency
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
// resolvespec_session_update, resolvespec_refresh_token
// See database_schema.sql for procedure definitions
type DatabaseAuthenticator struct {
db *sql.DB
}
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
return &DatabaseAuthenticator{db: db}
}
func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Convert LoginRequest to JSON
reqJSON, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal login request: %w", err)
}
// Call resolvespec_login stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON []byte
query := `SELECT p_success, p_error, p_data FROM resolvespec_login($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("login failed")
}
// Parse response
var response LoginResponse
if err := json.Unmarshal(dataJSON, &response); err != nil {
return nil, fmt.Errorf("failed to parse login response: %w", err)
}
return &response, nil
}
func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Convert LogoutRequest to JSON
reqJSON, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal logout request: %w", err)
}
// Call resolvespec_logout stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON []byte
query := `SELECT p_success, p_error, p_data FROM resolvespec_logout($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return fmt.Errorf("%s", errorMsg.String)
}
return fmt.Errorf("logout failed")
}
return nil
}
func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// Extract session token from header or cookie
sessionToken := r.Header.Get("Authorization")
if sessionToken == "" {
// Try cookie
cookie, err := r.Cookie("session_token")
if err == nil {
sessionToken = cookie.Value
}
} else {
// Remove "Bearer " prefix if present
sessionToken = strings.TrimPrefix(sessionToken, "Bearer ")
}
if sessionToken == "" {
return nil, fmt.Errorf("session token required")
}
// Call resolvespec_session stored procedure
// reference could be route, controller name, or any identifier
reference := "authenticate"
var success bool
var errorMsg sql.NullString
var userJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, sessionToken, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid or expired session")
}
// Parse UserContext
var userCtx UserContext
if err := json.Unmarshal(userJSON, &userCtx); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
// Update last activity timestamp asynchronously
go a.updateSessionActivity(r.Context(), sessionToken, &userCtx)
return &userCtx, nil
}
// updateSessionActivity updates the last activity timestamp for the session
func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessionToken string, userCtx *UserContext) {
// Convert UserContext to JSON
userJSON, err := json.Marshal(userCtx)
if err != nil {
return
}
// Call resolvespec_session_update stored procedure
var success bool
var errorMsg sql.NullString
var updatedUserJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_session_update($1, $2::jsonb)`
_ = a.db.QueryRowContext(ctx, query, sessionToken, userJSON).Scan(&success, &errorMsg, &updatedUserJSON)
}
// RefreshToken implements Refreshable interface
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
// Call api_refresh_token stored procedure
// First, we need to get the current user context for the refresh token
var success bool
var errorMsg sql.NullString
var userJSON []byte
// Get current session to pass to refresh
query := `SELECT p_success, p_error, p_user FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid refresh token")
}
// Call resolvespec_refresh_token to generate new token
var newSuccess bool
var newErrorMsg sql.NullString
var newUserJSON []byte
refreshQuery := `SELECT p_success, p_error, p_user FROM resolvespec_refresh_token($1, $2::jsonb)`
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err)
}
if !newSuccess {
if newErrorMsg.Valid {
return nil, fmt.Errorf("%s", newErrorMsg.String)
}
return nil, fmt.Errorf("failed to refresh token")
}
// Parse refreshed user context
var userCtx UserContext
if err := json.Unmarshal(newUserJSON, &userCtx); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
return &LoginResponse{
Token: userCtx.SessionID, // New session token from stored procedure
User: &userCtx,
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
}
// JWTAuthenticator provides JWT token-based authentication
// All database operations go through stored procedures
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
type JWTAuthenticator struct {
secretKey []byte
db *sql.DB
}
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
return &JWTAuthenticator{
secretKey: []byte(secretKey),
db: db,
}
}
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Call resolvespec_jwt_login stored procedure
var success bool
var errorMsg sql.NullString
var userJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid credentials")
}
// Parse user data
var user struct {
ID int `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
UserLevel int `json:"user_level"`
Roles string `json:"roles"`
}
if err := json.Unmarshal(userJSON, &user); err != nil {
return nil, fmt.Errorf("failed to parse user data: %w", err)
}
// TODO: Verify password
// if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
// return nil, fmt.Errorf("invalid credentials")
// }
// Generate token (placeholder - implement JWT signing when library is available)
expiresAt := time.Now().Add(24 * time.Hour)
tokenString := fmt.Sprintf("token_%d_%d", user.ID, expiresAt.Unix())
return &LoginResponse{
Token: tokenString,
User: &UserContext{
UserID: user.ID,
UserName: user.Username,
Email: user.Email,
UserLevel: user.UserLevel,
Roles: parseRoles(user.Roles),
},
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
}
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Call resolvespec_jwt_logout stored procedure
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return fmt.Errorf("%s", errorMsg.String)
}
return fmt.Errorf("logout failed")
}
return nil
}
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return nil, fmt.Errorf("authorization header required")
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
return nil, fmt.Errorf("bearer token required")
}
// TODO: Implement JWT parsing when library is available
return nil, fmt.Errorf("JWT parsing not implemented - install github.com/golang-jwt/jwt/v5")
}
// Production-Ready Security Providers
// ====================================
// DatabaseColumnSecurityProvider loads column security from database
// All database operations go through stored procedures
// Requires stored procedure: resolvespec_column_security
type DatabaseColumnSecurityProvider struct {
db *sql.DB
}
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db}
}
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity
// Call resolvespec_column_security stored procedure
var success bool
var errorMsg sql.NullString
var rulesJSON []byte
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("failed to load column security")
}
// Parse the JSON array of security records
type SecurityRecord struct {
Control string `json:"control"`
Accesstype string `json:"accesstype"`
JSONValue string `json:"jsonvalue"`
}
var records []SecurityRecord
if err := json.Unmarshal(rulesJSON, &records); err != nil {
return nil, fmt.Errorf("failed to parse security rules: %w", err)
}
// Convert records to ColumnSecurity rules
for _, rec := range records {
parts := strings.Split(rec.Control, ".")
if len(parts) < 3 {
continue
}
rule := ColumnSecurity{
Schema: schema,
Tablename: table,
Path: parts[2:],
Accesstype: rec.Accesstype,
UserID: userID,
}
rules = append(rules, rule)
}
return rules, nil
}
// DatabaseRowSecurityProvider loads row security from database
// All database operations go through stored procedures
// Requires stored procedure: resolvespec_row_security
type DatabaseRowSecurityProvider struct {
db *sql.DB
}
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db}
}
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string
var hasBlock bool
// Call resolvespec_row_security stored procedure
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
if err != nil {
return RowSecurity{}, fmt.Errorf("failed to load row security: %w", err)
}
return RowSecurity{
Schema: schema,
Tablename: table,
UserID: userID,
Template: template,
HasBlock: hasBlock,
}, nil
}
// ConfigColumnSecurityProvider provides static column security configuration
type ConfigColumnSecurityProvider struct {
rules map[string][]ColumnSecurity
}
func NewConfigColumnSecurityProvider(rules map[string][]ColumnSecurity) *ConfigColumnSecurityProvider {
return &ConfigColumnSecurityProvider{rules: rules}
}
func (p *ConfigColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
key := fmt.Sprintf("%s.%s", schema, table)
rules, ok := p.rules[key]
if !ok {
return []ColumnSecurity{}, nil
}
return rules, nil
}
// ConfigRowSecurityProvider provides static row security configuration
type ConfigRowSecurityProvider struct {
templates map[string]string
blocked map[string]bool
}
func NewConfigRowSecurityProvider(templates map[string]string, blocked map[string]bool) *ConfigRowSecurityProvider {
return &ConfigRowSecurityProvider{
templates: templates,
blocked: blocked,
}
}
func (p *ConfigRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
key := fmt.Sprintf("%s.%s", schema, table)
if p.blocked[key] {
return RowSecurity{
Schema: schema,
Tablename: table,
UserID: userID,
HasBlock: true,
}, nil
}
template := p.templates[key]
return RowSecurity{
Schema: schema,
Tablename: table,
UserID: userID,
Template: template,
HasBlock: false,
}, nil
}
// Helper functions
// ================
func parseRoles(rolesStr string) []string {
if rolesStr == "" {
return []string{}
}
return strings.Split(rolesStr, ",")
}
func parseIntHeader(r *http.Request, key string, defaultVal int) int {
val := r.Header.Get(key)
if val == "" {
return defaultVal
}
intVal, err := strconv.Atoi(val)
if err != nil {
return defaultVal
}
return intVal
}
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[time.Now().UnixNano()%int64(len(charset))]
}
return string(b)
}
// func getClaimString(claims map[string]any, key string) string {
// if claims == nil {
// return ""
// }
// if val, ok := claims[key]; ok {
// if str, ok := val.(string); ok {
// return str
// }
// }
// return ""
// }

View File

@@ -1,155 +0,0 @@
package security
import (
"fmt"
"net/http"
"github.com/gorilla/mux"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// SetupSecurityProvider initializes and configures the security provider
// This should be called when setting up your HTTP server
//
// IMPORTANT: You MUST configure the callbacks before calling this function:
// - GlobalSecurity.AuthenticateCallback
// - GlobalSecurity.LoadColumnSecurityCallback
// - GlobalSecurity.LoadRowSecurityCallback
//
// Example usage in your main.go or server setup:
//
// // Step 1: Configure callbacks (REQUIRED)
// security.GlobalSecurity.AuthenticateCallback = myAuthFunction
// security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurityFunction
// security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurityFunction
//
// // Step 2: Setup security provider
// handler := restheadspec.NewHandlerWithGORM(db)
// security.SetupSecurityProvider(handler, &security.GlobalSecurity)
//
// // Step 3: Apply middleware
// router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
// router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
func SetupSecurityProvider(handler *restheadspec.Handler, securityList *SecurityList) error {
// Validate that required callbacks are configured
if securityList.AuthenticateCallback == nil {
return fmt.Errorf("AuthenticateCallback must be set before calling SetupSecurityProvider")
}
if securityList.LoadColumnSecurityCallback == nil {
return fmt.Errorf("LoadColumnSecurityCallback must be set before calling SetupSecurityProvider")
}
if securityList.LoadRowSecurityCallback == nil {
return fmt.Errorf("LoadRowSecurityCallback must be set before calling SetupSecurityProvider")
}
// Initialize security maps if needed
if securityList.ColumnSecurity == nil {
securityList.ColumnSecurity = make(map[string][]ColumnSecurity)
}
if securityList.RowSecurity == nil {
securityList.RowSecurity = make(map[string]RowSecurity)
}
// Register all security hooks
RegisterSecurityHooks(handler, securityList)
return nil
}
// Chain creates a middleware chain
func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
return func(final http.Handler) http.Handler {
for i := len(middlewares) - 1; i >= 0; i-- {
final = middlewares[i](final)
}
return final
}
}
// CompleteExample shows a full integration example with Gorilla Mux
func CompleteExample(db *gorm.DB) (http.Handler, error) {
// Step 1: Create the ResolveSpec handler
handler := restheadspec.NewHandlerWithGORM(db)
// Step 2: Register your models
// handler.RegisterModel("public", "users", User{})
// handler.RegisterModel("public", "orders", Order{})
// Step 3: Configure security callbacks (REQUIRED!)
// See callbacks_example.go for example implementations
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
// Step 4: Setup security provider
if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil {
return nil, fmt.Errorf("failed to setup security: %v", err)
}
// Step 5: Create Mux router and setup routes
router := mux.NewRouter()
// The routes are set up by restheadspec, which handles the conversion
// from http.Request to the internal request format
restheadspec.SetupMuxRoutes(router, handler)
// Step 6: Apply middleware to the entire router
secureRouter := Chain(
AuthMiddleware, // Extract user from token
SetSecurityMiddleware, // Add security context
)(router)
return secureRouter, nil
}
// ExampleWithMux shows a simpler integration with Mux
func ExampleWithMux(db *gorm.DB) (*mux.Router, error) {
handler := restheadspec.NewHandlerWithGORM(db)
// IMPORTANT: Configure callbacks BEFORE SetupSecurityProvider
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig
if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil {
return nil, fmt.Errorf("failed to setup security: %v", err)
}
router := mux.NewRouter()
// Setup API routes
restheadspec.SetupMuxRoutes(router, handler)
// Apply middleware to router
router.Use(mux.MiddlewareFunc(AuthMiddleware))
router.Use(mux.MiddlewareFunc(SetSecurityMiddleware))
return router, nil
}
// Example with Gin
// import "github.com/gin-gonic/gin"
//
// func ExampleWithGin(db *gorm.DB) *gin.Engine {
// handler := restheadspec.NewHandlerWithGORM(db)
// SetupSecurityProvider(handler, &GlobalSecurity)
//
// router := gin.Default()
//
// // Convert middleware to Gin middleware
// router.Use(func(c *gin.Context) {
// AuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// c.Request = r
// c.Next()
// })).ServeHTTP(c.Writer, c.Request)
// })
//
// // Setup routes
// api := router.Group("/api")
// api.Any("/:schema/:entity", gin.WrapH(http.HandlerFunc(handler.Handle)))
// api.Any("/:schema/:entity/:id", gin.WrapH(http.HandlerFunc(handler.Handle)))
//
// return router
// }