Added cache, funcspec and implemented total cache

This commit is contained in:
Hein 2025-12-01 14:40:54 +02:00
parent 6bbe0ec8b0
commit 1643a5e920
22 changed files with 4866 additions and 26 deletions

4
go.mod
View File

@ -19,7 +19,10 @@ require (
) )
require ( require (
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
@ -30,6 +33,7 @@ require (
github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/redis/go-redis/v9 v9.17.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect

8
go.sum
View File

@ -1,6 +1,12 @@
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
@ -31,6 +37,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

340
pkg/cache/README.md vendored Normal file
View File

@ -0,0 +1,340 @@
# Cache Package
A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache.
## Features
- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends
- **Pluggable Architecture**: Easy to add custom cache providers
- **Type-Safe API**: Automatic JSON serialization/deserialization
- **TTL Support**: Configurable time-to-live for cache entries
- **Context-Aware**: All operations support Go contexts
- **Statistics**: Built-in cache statistics and monitoring
- **Pattern Deletion**: Delete keys by pattern (Redis)
- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/cache
```
For Redis support:
```bash
go get github.com/redis/go-redis/v9
```
For Memcache support:
```bash
go get github.com/bradfitz/gomemcache/memcache
```
## Quick Start
### In-Memory Cache
```go
package main
import (
"context"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
func main() {
// Initialize with in-memory provider
cache.UseMemory(&cache.Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
defer cache.Close()
ctx := context.Background()
c := cache.GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John"}
c.Set(ctx, "user:1", user, 10*time.Minute)
// Retrieve a value
var retrieved User
c.Get(ctx, "user:1", &retrieved)
}
```
### Redis Cache
```go
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "",
DB: 0,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
### Memcache
```go
cache.UseMemcache(&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
## API Reference
### Core Methods
#### Set
```go
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
```
Stores a value in the cache with automatic JSON serialization.
#### Get
```go
Get(ctx context.Context, key string, dest interface{}) error
```
Retrieves and deserializes a value from the cache.
#### SetBytes / GetBytes
```go
SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error
GetBytes(ctx context.Context, key string) ([]byte, error)
```
Store and retrieve raw bytes without serialization.
#### Delete
```go
Delete(ctx context.Context, key string) error
```
Removes a key from the cache.
#### DeleteByPattern
```go
DeleteByPattern(ctx context.Context, pattern string) error
```
Removes all keys matching a pattern (Redis only).
#### Clear
```go
Clear(ctx context.Context) error
```
Removes all items from the cache.
#### Exists
```go
Exists(ctx context.Context, key string) bool
```
Checks if a key exists in the cache.
#### GetOrSet
```go
GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration,
loader func() (interface{}, error)) error
```
Retrieves a value from cache, or loads and caches it if not found (lazy loading).
#### Stats
```go
Stats(ctx context.Context) (*CacheStats, error)
```
Returns cache statistics including hits, misses, and key counts.
## Provider Configuration
### In-Memory Options
```go
&cache.Options{
DefaultTTL: 5 * time.Minute, // Default expiration time
MaxSize: 10000, // Maximum number of items
EvictionPolicy: "LRU", // Eviction strategy (future)
}
```
### Redis Configuration
```go
&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Optional authentication
DB: 0, // Database number
PoolSize: 10, // Connection pool size
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
### Memcache Configuration
```go
&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
MaxIdleConns: 2,
Timeout: 1 * time.Second,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
## Advanced Usage
### Custom Provider
```go
// Create a custom provider instance
memProvider := cache.NewMemoryProvider(&cache.Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
cache.Initialize(memProvider)
```
### Lazy Loading Pattern
```go
var data ExpensiveData
err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if key is not in cache
return computeExpensiveData(), nil
})
```
### Query API Cache
The package includes specialized functions for caching query results:
```go
// Cache a query result
api := "GetUsers"
query := "SELECT * FROM users WHERE active = true"
tablenames := "users"
total := int64(150)
cache.PutQueryAPICache(ctx, api, query, tablenames, total)
// Retrieve cached query
hash := cache.HashQueryAPICache(api, query)
cachedQuery, err := cache.FetchQueryAPICache(ctx, hash)
```
## Provider Comparison
| Feature | In-Memory | Redis | Memcache |
|---------|-----------|-------|----------|
| Persistence | No | Yes | No |
| Distributed | No | Yes | Yes |
| Pattern Delete | No | Yes | No |
| Statistics | Full | Full | Limited |
| Atomic Operations | Yes | Yes | Yes |
| Max Item Size | Memory | 512MB | 1MB |
## Best Practices
1. **Use contexts**: Always pass context for cancellation and timeout control
2. **Set appropriate TTLs**: Balance between freshness and performance
3. **Handle errors**: Cache misses and errors should be handled gracefully
4. **Monitor statistics**: Use Stats() to monitor cache performance
5. **Clean up**: Always call Close() when shutting down
6. **Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field")
## Example: Complete Application
```go
package main
import (
"context"
"log"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
type UserService struct {
cache *cache.Cache
}
func NewUserService() *UserService {
// Initialize with Redis in production, memory for testing
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Options: &cache.Options{
DefaultTTL: 10 * time.Minute,
},
})
return &UserService{
cache: cache.GetDefaultCache(),
}
}
func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) {
var user User
cacheKey := fmt.Sprintf("user:%d", userID)
// Try to get from cache first
err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) {
// Load from database if not in cache
return s.loadUserFromDB(userID)
})
if err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) InvalidateUser(ctx context.Context, userID int) error {
cacheKey := fmt.Sprintf("user:%d", userID)
return s.cache.Delete(ctx, cacheKey)
}
func main() {
service := NewUserService()
defer cache.Close()
ctx := context.Background()
user, err := service.GetUser(ctx, 123)
if err != nil {
log.Fatal(err)
}
log.Printf("User: %+v", user)
}
```
## Performance Considerations
- **In-Memory**: Fastest but limited by RAM and not distributed
- **Redis**: Great for distributed systems, persistent, but network overhead
- **Memcache**: Good for distributed caching, simpler than Redis but less features
Choose based on your needs:
- Single instance? Use in-memory
- Need persistence or advanced features? Use Redis
- Simple distributed cache? Use Memcache
## License
See repository license.

76
pkg/cache/cache.go vendored Normal file
View File

@ -0,0 +1,76 @@
package cache
import (
"context"
"fmt"
"time"
)
var (
defaultCache *Cache
)
// Initialize initializes the cache with a provider.
// If not called, the package will use an in-memory provider by default.
func Initialize(provider Provider) {
defaultCache = NewCache(provider)
}
// UseMemory configures the cache to use in-memory storage.
func UseMemory(opts *Options) error {
provider := NewMemoryProvider(opts)
defaultCache = NewCache(provider)
return nil
}
// UseRedis configures the cache to use Redis storage.
func UseRedis(config *RedisConfig) error {
provider, err := NewRedisProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Redis provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// UseMemcache configures the cache to use Memcache storage.
func UseMemcache(config *MemcacheConfig) error {
provider, err := NewMemcacheProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Memcache provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// GetDefaultCache returns the default cache instance.
// Initializes with in-memory provider if not already initialized.
func GetDefaultCache() *Cache {
if defaultCache == nil {
UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
}
return defaultCache
}
// SetDefaultCache sets a custom cache instance as the default cache.
// This is useful for testing or when you want to use a pre-configured cache instance.
func SetDefaultCache(cache *Cache) {
defaultCache = cache
}
// GetStats returns cache statistics.
func GetStats(ctx context.Context) (*CacheStats, error) {
cache := GetDefaultCache()
return cache.Stats(ctx)
}
// Close closes the cache and releases resources.
func Close() error {
if defaultCache != nil {
return defaultCache.Close()
}
return nil
}

147
pkg/cache/cache_manager.go vendored Normal file
View File

@ -0,0 +1,147 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
)
// Cache is the main cache manager that wraps a Provider.
type Cache struct {
provider Provider
}
// NewCache creates a new cache manager with the specified provider.
func NewCache(provider Provider) *Cache {
return &Cache{
provider: provider,
}
}
// Get retrieves and deserializes a value from the cache.
func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error {
data, exists := c.provider.Get(ctx, key)
if !exists {
return fmt.Errorf("key not found: %s", key)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize: %w", err)
}
return nil
}
// GetBytes retrieves raw bytes from the cache.
func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) {
data, exists := c.provider.Get(ctx, key)
if !exists {
return nil, fmt.Errorf("key not found: %s", key)
}
return data, nil
}
// Set serializes and stores a value in the cache with the specified TTL.
func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize: %w", err)
}
return c.provider.Set(ctx, key, data, ttl)
}
// SetBytes stores raw bytes in the cache with the specified TTL.
func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return c.provider.Set(ctx, key, value, ttl)
}
// Delete removes a key from the cache.
func (c *Cache) Delete(ctx context.Context, key string) error {
return c.provider.Delete(ctx, key)
}
// DeleteByPattern removes all keys matching the pattern.
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
return c.provider.DeleteByPattern(ctx, pattern)
}
// Clear removes all items from the cache.
func (c *Cache) Clear(ctx context.Context) error {
return c.provider.Clear(ctx)
}
// Exists checks if a key exists in the cache.
func (c *Cache) Exists(ctx context.Context, key string) bool {
return c.provider.Exists(ctx, key)
}
// Stats returns statistics about the cache.
func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) {
return c.provider.Stats(ctx)
}
// Close closes the cache and releases any resources.
func (c *Cache) Close() error {
return c.provider.Close()
}
// GetOrSet retrieves a value from cache, or sets it if it doesn't exist.
// The loader function is called only if the key is not found in cache.
func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error {
// Try to get from cache first
err := c.Get(ctx, key, dest)
if err == nil {
return nil
}
// Load the value
value, err := loader()
if err != nil {
return fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return fmt.Errorf("failed to cache value: %w", err)
}
// Populate dest with the loaded value
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize loaded value: %w", err)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize loaded value: %w", err)
}
return nil
}
// Remember is a convenience function that caches the result of a function call.
// It's similar to GetOrSet but returns the value directly.
func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) {
// Try to get from cache first as bytes
data, err := c.GetBytes(ctx, key)
if err == nil {
var result interface{}
if err := json.Unmarshal(data, &result); err == nil {
return result, nil
}
}
// Load the value
value, err := loader()
if err != nil {
return nil, fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return nil, fmt.Errorf("failed to cache value: %w", err)
}
return value, nil
}

69
pkg/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,69 @@
package cache
import (
"context"
"testing"
"time"
)
func TestSetDefaultCache(t *testing.T) {
// Create a custom cache instance
provider := NewMemoryProvider(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 50,
})
customCache := NewCache(provider)
// Set it as the default
SetDefaultCache(customCache)
// Verify it's now the default
retrievedCache := GetDefaultCache()
if retrievedCache != customCache {
t.Error("SetDefaultCache did not set the cache correctly")
}
// Test that we can use it
ctx := context.Background()
testKey := "test_key"
testValue := "test_value"
err := retrievedCache.Set(ctx, testKey, testValue, time.Minute)
if err != nil {
t.Fatalf("Failed to set value: %v", err)
}
var result string
err = retrievedCache.Get(ctx, testKey, &result)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
if result != testValue {
t.Errorf("Expected %s, got %s", testValue, result)
}
// Clean up - reset to default
SetDefaultCache(nil)
}
func TestGetDefaultCacheInitialization(t *testing.T) {
// Reset to nil first
SetDefaultCache(nil)
// GetDefaultCache should auto-initialize
cache := GetDefaultCache()
if cache == nil {
t.Error("GetDefaultCache should auto-initialize, got nil")
}
// Should be usable
ctx := context.Background()
err := cache.Set(ctx, "test", "value", time.Minute)
if err != nil {
t.Errorf("Failed to use auto-initialized cache: %v", err)
}
// Clean up
SetDefaultCache(nil)
}

252
pkg/cache/example_usage.go vendored Normal file
View File

@ -0,0 +1,252 @@
package cache
import (
"context"
"fmt"
"log"
"time"
)
// ExampleInMemoryCache demonstrates using the in-memory cache provider.
func ExampleInMemoryCache() {
// Initialize with in-memory provider
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
defer Close()
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John Doe"}
err = cache.Set(ctx, "user:1", user, 10*time.Minute)
if err != nil {
log.Fatal(err)
}
// Retrieve a value
var retrieved User
err = cache.Get(ctx, "user:1", &retrieved)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Retrieved user: %+v\n", retrieved)
// Check if key exists
exists := cache.Exists(ctx, "user:1")
fmt.Printf("Key exists: %v\n", exists)
// Delete a key
err = cache.Delete(ctx, "user:1")
if err != nil {
log.Fatal(err)
}
// Get statistics
stats, err := cache.Stats(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Cache stats: %+v\n", stats)
}
// ExampleRedisCache demonstrates using the Redis cache provider.
func ExampleRedisCache() {
// Initialize with Redis provider
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Set if Redis requires authentication
DB: 0,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
defer Close()
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store raw bytes
data := []byte("Hello, Redis!")
err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour)
if err != nil {
log.Fatal(err)
}
// Retrieve raw bytes
retrieved, err := cache.GetBytes(ctx, "greeting")
if err != nil {
log.Fatal(err)
}
fmt.Printf("Retrieved data: %s\n", string(retrieved))
// Clear all cache
err = cache.Clear(ctx)
if err != nil {
log.Fatal(err)
}
}
// ExampleMemcacheCache demonstrates using the Memcache cache provider.
func ExampleMemcacheCache() {
// Initialize with Memcache provider
err := UseMemcache(&MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
defer Close()
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type Product struct {
ID int
Name string
Price float64
}
product := Product{ID: 100, Name: "Widget", Price: 29.99}
err = cache.Set(ctx, "product:100", product, 30*time.Minute)
if err != nil {
log.Fatal(err)
}
// Retrieve a value
var retrieved Product
err = cache.Get(ctx, "product:100", &retrieved)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Retrieved product: %+v\n", retrieved)
}
// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading.
func ExampleGetOrSet() {
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
defer Close()
ctx := context.Background()
cache := GetDefaultCache()
type ExpensiveData struct {
Result string
}
var data ExpensiveData
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if the key is not in cache
fmt.Println("Computing expensive result...")
time.Sleep(1 * time.Second)
return ExpensiveData{Result: "computed value"}, nil
})
if err != nil {
log.Fatal(err)
}
fmt.Printf("Data: %+v\n", data)
// Second call will use cached value
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
fmt.Println("This won't be called!")
return ExpensiveData{Result: "new value"}, nil
})
if err != nil {
log.Fatal(err)
}
fmt.Printf("Cached data: %+v\n", data)
}
// ExampleCustomProvider demonstrates using a custom provider.
func ExampleCustomProvider() {
// Create a custom provider
memProvider := NewMemoryProvider(&Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
Initialize(memProvider)
defer Close()
ctx := context.Background()
cache := GetDefaultCache()
// Use the cache
err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute)
if err != nil {
log.Fatal(err)
}
// Clean expired items (memory provider specific)
if mp, ok := cache.provider.(*MemoryProvider); ok {
count := mp.CleanExpired(ctx)
fmt.Printf("Cleaned %d expired items\n", count)
}
}
// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only).
func ExampleDeleteByPattern() {
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
defer Close()
ctx := context.Background()
cache := GetDefaultCache()
// Store multiple keys with a pattern
cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute)
cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute)
cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute)
// Delete all keys matching pattern (Redis glob pattern)
err = cache.DeleteByPattern(ctx, "user:*:profile")
if err != nil {
log.Fatal(err)
}
fmt.Println("Deleted all user profile keys")
}

57
pkg/cache/provider.go vendored Normal file
View File

@ -0,0 +1,57 @@
package cache
import (
"context"
"time"
)
// Provider defines the interface that all cache providers must implement.
type Provider interface {
// Get retrieves a value from the cache by key.
// Returns nil, false if key doesn't exist or is expired.
Get(ctx context.Context, key string) ([]byte, bool)
// Set stores a value in the cache with the specified TTL.
// If ttl is 0, the item never expires.
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Delete removes a key from the cache.
Delete(ctx context.Context, key string) error
// DeleteByPattern removes all keys matching the pattern.
// Pattern syntax depends on the provider implementation.
DeleteByPattern(ctx context.Context, pattern string) error
// Clear removes all items from the cache.
Clear(ctx context.Context) error
// Exists checks if a key exists in the cache.
Exists(ctx context.Context, key string) bool
// Close closes the provider and releases any resources.
Close() error
// Stats returns statistics about the cache provider.
Stats(ctx context.Context) (*CacheStats, error)
}
// CacheStats contains cache statistics.
type CacheStats struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
Keys int64 `json:"keys"`
ProviderType string `json:"provider_type"`
ProviderStats map[string]any `json:"provider_stats,omitempty"`
}
// Options contains configuration options for cache providers.
type Options struct {
// DefaultTTL is the default time-to-live for cache items.
DefaultTTL time.Duration
// MaxSize is the maximum number of items (for in-memory provider).
MaxSize int
// EvictionPolicy determines how items are evicted (LRU, LFU, etc).
EvictionPolicy string
}

144
pkg/cache/provider_memcache.go vendored Normal file
View File

@ -0,0 +1,144 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/bradfitz/gomemcache/memcache"
)
// MemcacheProvider is a Memcache implementation of the Provider interface.
type MemcacheProvider struct {
client *memcache.Client
options *Options
}
// MemcacheConfig contains Memcache-specific configuration.
type MemcacheConfig struct {
// Servers is a list of memcache server addresses (e.g., "localhost:11211")
Servers []string
// MaxIdleConns is the maximum number of idle connections (default: 2)
MaxIdleConns int
// Timeout for connection operations (default: 1 second)
Timeout time.Duration
// Options contains general cache options
Options *Options
}
// NewMemcacheProvider creates a new Memcache cache provider.
func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) {
if config == nil {
config = &MemcacheConfig{
Servers: []string{"localhost:11211"},
}
}
if len(config.Servers) == 0 {
config.Servers = []string{"localhost:11211"}
}
if config.MaxIdleConns == 0 {
config.MaxIdleConns = 2
}
if config.Timeout == 0 {
config.Timeout = 1 * time.Second
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := memcache.New(config.Servers...)
client.MaxIdleConns = config.MaxIdleConns
client.Timeout = config.Timeout
// Test connection
if err := client.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to Memcache: %w", err)
}
return &MemcacheProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) {
item, err := m.client.Get(key)
if err == memcache.ErrCacheMiss {
return nil, false
}
if err != nil {
return nil, false
}
return item.Value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = m.options.DefaultTTL
}
item := &memcache.Item{
Key: key,
Value: value,
Expiration: int32(ttl.Seconds()),
}
return m.client.Set(item)
}
// Delete removes a key from the cache.
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
err := m.client.Delete(key)
if err == memcache.ErrCacheMiss {
return nil
}
return err
}
// DeleteByPattern removes all keys matching the pattern.
// Note: Memcache does not support pattern-based deletion natively.
// This is a no-op for memcache and returns an error.
func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error {
return fmt.Errorf("pattern-based deletion is not supported by Memcache")
}
// Clear removes all items from the cache.
func (m *MemcacheProvider) Clear(ctx context.Context) error {
return m.client.FlushAll()
}
// Exists checks if a key exists in the cache.
func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool {
_, err := m.client.Get(key)
return err == nil
}
// Close closes the provider and releases any resources.
func (m *MemcacheProvider) Close() error {
// Memcache client doesn't have a close method
return nil
}
// Stats returns statistics about the cache provider.
// Note: Memcache provider returns limited statistics.
func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) {
stats := &CacheStats{
ProviderType: "memcache",
ProviderStats: map[string]any{
"note": "Memcache does not provide detailed statistics through the standard client",
},
}
return stats, nil
}

226
pkg/cache/provider_memory.go vendored Normal file
View File

@ -0,0 +1,226 @@
package cache
import (
"context"
"fmt"
"regexp"
"sync"
"time"
)
// memoryItem represents a cached item in memory.
type memoryItem struct {
Value []byte
Expiration time.Time
LastAccess time.Time
HitCount int64
}
// isExpired checks if the item has expired.
func (m *memoryItem) isExpired() bool {
if m.Expiration.IsZero() {
return false
}
return time.Now().After(m.Expiration)
}
// MemoryProvider is an in-memory implementation of the Provider interface.
type MemoryProvider struct {
mu sync.RWMutex
items map[string]*memoryItem
options *Options
hits int64
misses int64
}
// NewMemoryProvider creates a new in-memory cache provider.
func NewMemoryProvider(opts *Options) *MemoryProvider {
if opts == nil {
opts = &Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
}
}
return &MemoryProvider{
items: make(map[string]*memoryItem),
options: opts,
}
}
// Get retrieves a value from the cache by key.
func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) {
m.mu.Lock()
defer m.mu.Unlock()
item, exists := m.items[key]
if !exists {
m.misses++
return nil, false
}
if item.isExpired() {
delete(m.items, key)
m.misses++
return nil, false
}
item.LastAccess = time.Now()
item.HitCount++
m.hits++
return item.Value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
if ttl == 0 {
ttl = m.options.DefaultTTL
}
var expiration time.Time
if ttl > 0 {
expiration = time.Now().Add(ttl)
}
// Check max size and evict if necessary
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
if _, exists := m.items[key]; !exists {
m.evictOne()
}
}
m.items[key] = &memoryItem{
Value: value,
Expiration: expiration,
LastAccess: time.Now(),
}
return nil
}
// Delete removes a key from the cache.
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.items, key)
return nil
}
// DeleteByPattern removes all keys matching the pattern.
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
m.mu.Lock()
defer m.mu.Unlock()
re, err := regexp.Compile(pattern)
if err != nil {
return fmt.Errorf("invalid pattern: %w", err)
}
for key := range m.items {
if re.MatchString(key) {
delete(m.items, key)
}
}
return nil
}
// Clear removes all items from the cache.
func (m *MemoryProvider) Clear(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryItem)
m.hits = 0
m.misses = 0
return nil
}
// Exists checks if a key exists in the cache.
func (m *MemoryProvider) Exists(ctx context.Context, key string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
item, exists := m.items[key]
if !exists {
return false
}
return !item.isExpired()
}
// Close closes the provider and releases any resources.
func (m *MemoryProvider) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = nil
return nil
}
// Stats returns statistics about the cache provider.
func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
// Clean expired items first
validKeys := 0
for _, item := range m.items {
if !item.isExpired() {
validKeys++
}
}
return &CacheStats{
Hits: m.hits,
Misses: m.misses,
Keys: int64(validKeys),
ProviderType: "memory",
ProviderStats: map[string]any{
"capacity": m.options.MaxSize,
},
}, nil
}
// evictOne removes one item from the cache using LRU strategy.
func (m *MemoryProvider) evictOne() {
var oldestKey string
var oldestTime time.Time
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
return
}
if oldestKey == "" || item.LastAccess.Before(oldestTime) {
oldestKey = key
oldestTime = item.LastAccess
}
}
if oldestKey != "" {
delete(m.items, oldestKey)
}
}
// CleanExpired removes all expired items from the cache.
func (m *MemoryProvider) CleanExpired(ctx context.Context) int {
m.mu.Lock()
defer m.mu.Unlock()
count := 0
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
count++
}
}
return count
}

185
pkg/cache/provider_redis.go vendored Normal file
View File

@ -0,0 +1,185 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// RedisProvider is a Redis implementation of the Provider interface.
type RedisProvider struct {
client *redis.Client
options *Options
}
// RedisConfig contains Redis-specific configuration.
type RedisConfig struct {
// Host is the Redis server host (default: localhost)
Host string
// Port is the Redis server port (default: 6379)
Port int
// Password for Redis authentication (optional)
Password string
// DB is the Redis database number (default: 0)
DB int
// PoolSize is the maximum number of connections (default: 10)
PoolSize int
// Options contains general cache options
Options *Options
}
// NewRedisProvider creates a new Redis cache provider.
func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) {
if config == nil {
config = &RedisConfig{
Host: "localhost",
Port: 6379,
DB: 0,
}
}
if config.Host == "" {
config.Host = "localhost"
}
if config.Port == 0 {
config.Port = 6379
}
if config.PoolSize == 0 {
config.PoolSize = 10
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
Password: config.Password,
DB: config.DB,
PoolSize: config.PoolSize,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
return &RedisProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) {
val, err := r.client.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, false
}
if err != nil {
return nil, false
}
return val, true
}
// Set stores a value in the cache with the specified TTL.
func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = r.options.DefaultTTL
}
return r.client.Set(ctx, key, value, ttl).Err()
}
// Delete removes a key from the cache.
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
return r.client.Del(ctx, key).Err()
}
// DeleteByPattern removes all keys matching the pattern.
func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error {
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
pipe := r.client.Pipeline()
count := 0
for iter.Next(ctx) {
pipe.Del(ctx, iter.Val())
count++
// Execute pipeline in batches of 100
if count%100 == 0 {
if _, err := pipe.Exec(ctx); err != nil {
return err
}
pipe = r.client.Pipeline()
}
}
if err := iter.Err(); err != nil {
return err
}
// Execute remaining commands
if count%100 != 0 {
_, err := pipe.Exec(ctx)
return err
}
return nil
}
// Clear removes all items from the cache.
func (r *RedisProvider) Clear(ctx context.Context) error {
return r.client.FlushDB(ctx).Err()
}
// Exists checks if a key exists in the cache.
func (r *RedisProvider) Exists(ctx context.Context, key string) bool {
result, err := r.client.Exists(ctx, key).Result()
if err != nil {
return false
}
return result > 0
}
// Close closes the provider and releases any resources.
func (r *RedisProvider) Close() error {
return r.client.Close()
}
// Stats returns statistics about the cache provider.
func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) {
info, err := r.client.Info(ctx, "stats", "keyspace").Result()
if err != nil {
return nil, fmt.Errorf("failed to get Redis stats: %w", err)
}
dbSize, err := r.client.DBSize(ctx).Result()
if err != nil {
return nil, fmt.Errorf("failed to get DB size: %w", err)
}
// Parse stats from INFO command
// This is a simplified version - you may want to parse more detailed stats
stats := &CacheStats{
Keys: dbSize,
ProviderType: "redis",
ProviderStats: map[string]any{
"info": info,
},
}
return stats, nil
}

127
pkg/cache/query_cache.go vendored Normal file
View File

@ -0,0 +1,127 @@
package cache
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// QueryCacheKey represents the components used to build a cache key for query total count
type QueryCacheKey struct {
TableName string `json:"table_name"`
Filters []common.FilterOption `json:"filters"`
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
Expand []ExpandOptionKey `json:"expand,omitempty"`
Distinct bool `json:"distinct,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
CursorBackward string `json:"cursor_backward,omitempty"`
}
// ExpandOptionKey represents expand options for cache key
type ExpandOptionKey struct {
Relation string `json:"relation"`
Where string `json:"where,omitempty"`
}
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
// This is used to cache the total count of records matching a query
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
}
return hashString(string(jsonData))
}
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// Includes expand, distinct, and cursor pagination options
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
Distinct: distinct,
CursorForward: cursorFwd,
CursorBackward: cursorBwd,
}
// Convert expand options to cache key format
if len(expandOpts) > 0 {
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
for _, exp := range expandOpts {
// Type assert to get the expand option fields we care about for caching
if expMap, ok := exp.(map[string]interface{}); ok {
expKey := ExpandOptionKey{}
if rel, ok := expMap["relation"].(string); ok {
expKey.Relation = rel
}
if where, ok := expMap["where"].(string); ok {
expKey.Where = where
}
key.Expand = append(key.Expand, expKey)
}
}
// Sort expand options for consistent hashing (already sorted by relation name above)
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
}
return hashString(string(jsonData))
}
// hashString computes SHA256 hash of a string
func hashString(s string) string {
h := sha256.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func GetQueryTotalCacheKey(hash string) string {
return fmt.Sprintf("query_total:%s", hash)
}
// CachedTotal represents a cached total count
type CachedTotal struct {
Total int `json:"total"`
}
// InvalidateCacheForTable removes all cached totals for a specific table
// This should be called when data in the table changes (insert/update/delete)
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
cache := GetDefaultCache()
// Build a pattern to match all query totals for this table
// Note: This requires pattern matching support in the provider
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
return cache.DeleteByPattern(ctx, pattern)
}

151
pkg/cache/query_cache_test.go vendored Normal file
View File

@ -0,0 +1,151 @@
package cache
import (
"context"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestBuildQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "age", Operator: "gt", Value: 25},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
// Generate cache key
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
// Same parameters should generate same key
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
}
// Different parameters should generate different key
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
}
}
func TestBuildExtendedQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
expandOpts := []interface{}{
map[string]interface{}{
"relation": "posts",
"where": "status = 'published'",
},
}
// Generate cache key
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
// Same parameters should generate same key
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters")
}
// Different distinct value should generate different key
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different distinct values")
}
}
func TestGetQueryTotalCacheKey(t *testing.T) {
hash := "abc123"
key := GetQueryTotalCacheKey(hash)
expected := "query_total:abc123"
if key != expected {
t.Errorf("Expected %s, got %s", expected, key)
}
}
func TestCachedTotalIntegration(t *testing.T) {
// Initialize cache with memory provider for testing
UseMemory(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 100,
})
ctx := context.Background()
// Create test data
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
}
sorts := []common.SortOption{
{Column: "created_at", Direction: "desc"},
}
// Build cache key
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
// Store a total count in cache
totalToCache := CachedTotal{Total: 42}
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
if err != nil {
t.Fatalf("Failed to set cache: %v", err)
}
// Retrieve from cache
var cachedTotal CachedTotal
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err != nil {
t.Fatalf("Failed to get from cache: %v", err)
}
if cachedTotal.Total != 42 {
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
}
// Test cache miss
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
var missedTotal CachedTotal
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
if err == nil {
t.Errorf("Expected error for cache miss, got nil")
}
}
func TestHashString(t *testing.T) {
input1 := "test string"
input2 := "test string"
input3 := "different string"
hash1 := hashString(input1)
hash2 := hashString(input2)
hash3 := hashString(input3)
// Same input should produce same hash
if hash1 != hash2 {
t.Errorf("Expected same hash for identical inputs")
}
// Different input should produce different hash
if hash1 == hash3 {
t.Errorf("Expected different hash for different inputs")
}
// Hash should be hex encoded SHA256 (64 characters)
if len(hash1) != 64 {
t.Errorf("Expected hash length of 64, got %d", len(hash1))
}
}

View File

@ -21,15 +21,23 @@ import (
// Handler handles function-based SQL API requests // Handler handles function-based SQL API requests
type Handler struct { type Handler struct {
db common.Database db common.Database
hooks *HookRegistry
} }
// NewHandler creates a new function API handler // NewHandler creates a new function API handler
func NewHandler(db common.Database) *Handler { func NewHandler(db common.Database) *Handler {
return &Handler{ return &Handler{
db: db, db: db,
hooks: NewHookRegistry(),
} }
} }
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// HTTPFuncType is a function type for HTTP handlers // HTTPFuncType is a function type for HTTP handlers
type HTTPFuncType func(http.ResponseWriter, *http.Request) type HTTPFuncType func(http.ResponseWriter, *http.Request)
@ -64,18 +72,77 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
// Initialize hook context
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Request: r,
Writer: w,
SQLQuery: sqlquery,
Variables: variables,
InputVars: inputvars,
MetaInfo: metainfo,
PropQry: propQry,
UserContext: userCtx,
NoCount: pNoCount,
BlankParams: pBlankparms,
AllowFilter: pAllowFilter,
ComplexAPI: complexAPI,
}
// Execute BeforeQueryList hook
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
logger.Error("BeforeQueryList hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Check if hook aborted the operation
if hookCtx.Abort {
if hookCtx.AbortCode == 0 {
hookCtx.AbortCode = http.StatusBadRequest
}
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
return
}
// Use potentially modified SQL query and variables from hooks
sqlquery = hookCtx.SQLQuery
variables = hookCtx.Variables
complexAPI = hookCtx.ComplexAPI
// Extract input variables from SQL query (placeholders like [variable]) // Extract input variables from SQL query (placeholders like [variable])
sqlquery = h.extractInputVariables(sqlquery, &inputvars) sqlquery = h.extractInputVariables(sqlquery, &inputvars)
// Merge URL path parameters // Merge URL path parameters
sqlquery = h.mergePathParams(r, sqlquery, variables) sqlquery = h.mergePathParams(r, sqlquery, variables)
// Parse comprehensive parameters from headers and query string
reqParams := h.ParseParameters(r)
complexAPI = reqParams.ComplexAPI
// Merge query string parameters // Merge query string parameters
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry) sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
// Merge header parameters // Merge header parameters
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
if !pAllowFilter {
sqlquery = h.ApplyFilters(sqlquery, reqParams)
}
// Apply field selection
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
// Apply DISTINCT if requested
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
// Override pNoCount if skipcount is specified
if reqParams.SkipCount {
pNoCount = true
}
// Build metainfo // Build metainfo
metainfo["ipaddress"] = getIPAddress(r) metainfo["ipaddress"] = getIPAddress(r)
metainfo["url"] = r.RequestURI metainfo["url"] = r.RequestURI
@ -95,12 +162,32 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
} }
} }
// Update hook context with latest SQL query and variables
hookCtx.SQLQuery = sqlquery
hookCtx.Variables = variables
hookCtx.InputVars = inputvars
// Execute query within transaction // Execute query within transaction
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
sqlqueryCnt := sqlquery sqlqueryCnt := sqlquery
// Parse sorting and pagination parameters // Parse sorting and pagination parameters
sortcols, limit, offset := h.parsePaginationParams(r) sortcols, limit, offset := h.parsePaginationParams(r)
// Override with parsed parameters if available
if reqParams.SortColumns != "" {
sortcols = reqParams.SortColumns
}
if reqParams.Limit > 0 {
limit = reqParams.Limit
}
if reqParams.Offset > 0 {
offset = reqParams.Offset
}
hookCtx.SortColumns = sortcols
hookCtx.Limit = limit
hookCtx.Offset = offset
fromPos := strings.Index(strings.ToLower(sqlquery), "from ") fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by") orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
@ -127,6 +214,16 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
total = countResult.Count total = countResult.Count
} }
// Execute BeforeSQLExec hook
hookCtx.SQLQuery = sqlquery
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
logger.Error("BeforeSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified SQL query from hook
sqlquery = hookCtx.SQLQuery
// Execute main query // Execute main query
rows := make([]map[string]interface{}, 0) rows := make([]map[string]interface{}, 0)
if err := tx.Query(ctx, &rows, sqlquery); err != nil { if err := tx.Query(ctx, &rows, sqlquery); err != nil {
@ -140,6 +237,20 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
total = int64(len(dbobjlist)) total = int64(len(dbobjlist))
} }
// Execute AfterSQLExec hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
logger.Error("AfterSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
total = hookCtx.Total
return nil return nil
}) })
@ -148,6 +259,21 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
return return
} }
// Execute AfterQueryList hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
hookCtx.Error = err
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
logger.Error("AfterQueryList hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
total = hookCtx.Total
// Set response headers // Set response headers
respOffset := 0 respOffset := 0
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
@ -159,11 +285,63 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total)) w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
logger.Info("Serving: Records %d of %d", len(dbobjlist), total) logger.Info("Serving: Records %d of %d", len(dbobjlist), total)
// Execute BeforeResponse hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
logger.Error("BeforeResponse hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
if len(dbobjlist) == 0 { if len(dbobjlist) == 0 {
w.Write([]byte("[]")) w.Write([]byte("[]"))
return return
} }
// Format response based on response format
switch reqParams.ResponseFormat {
case "syncfusion":
// Syncfusion format: { result: data, count: total }
response := map[string]interface{}{
"result": dbobjlist,
"count": total,
}
data, err := json.Marshal(response)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
w.Write(data)
}
case "detail":
// Detail format: complex API with metadata
metaobj := map[string]interface{}{
"items": dbobjlist,
"count": fmt.Sprintf("%d", len(dbobjlist)),
"total": fmt.Sprintf("%d", total),
"tablename": r.URL.Path,
"tableprefix": "gsql",
}
data, err := json.Marshal(metaobj)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
w.Write(data)
}
default:
// Simple format: just return the data array (or complex API if requested)
if complexAPI { if complexAPI {
metaobj := map[string]interface{}{ metaobj := map[string]interface{}{
"items": dbobjlist, "items": dbobjlist,
@ -172,7 +350,6 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
"tablename": r.URL.Path, "tablename": r.URL.Path,
"tableprefix": "gsql", "tableprefix": "gsql",
} }
data, err := json.Marshal(metaobj) data, err := json.Marshal(metaobj)
if err != nil { if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
@ -195,6 +372,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
} }
} }
} }
}
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record // SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
@ -215,6 +393,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
metainfo := make(map[string]interface{}) metainfo := make(map[string]interface{})
variables := make(map[string]interface{}) variables := make(map[string]interface{})
dbobj := make(map[string]interface{}) dbobj := make(map[string]interface{})
complexAPI := false
// Get user context from security package // Get user context from security package
userCtx, ok := security.GetUserContext(ctx) userCtx, ok := security.GetUserContext(ctx)
@ -225,18 +404,67 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
// Initialize hook context
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Request: r,
Writer: w,
SQLQuery: sqlquery,
Variables: variables,
InputVars: inputvars,
MetaInfo: metainfo,
PropQry: propQry,
UserContext: userCtx,
BlankParams: pBlankparms,
ComplexAPI: complexAPI,
}
// Execute BeforeQuery hook
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
logger.Error("BeforeQuery hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Check if hook aborted the operation
if hookCtx.Abort {
if hookCtx.AbortCode == 0 {
hookCtx.AbortCode = http.StatusBadRequest
}
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
return
}
// Use potentially modified SQL query and variables from hooks
sqlquery = hookCtx.SQLQuery
variables = hookCtx.Variables
// Extract input variables from SQL query // Extract input variables from SQL query
sqlquery = h.extractInputVariables(sqlquery, &inputvars) sqlquery = h.extractInputVariables(sqlquery, &inputvars)
// Merge URL path parameters // Merge URL path parameters
sqlquery = h.mergePathParams(r, sqlquery, variables) sqlquery = h.mergePathParams(r, sqlquery, variables)
// Parse comprehensive parameters from headers and query string
reqParams := h.ParseParameters(r)
complexAPI = reqParams.ComplexAPI
// Merge query string parameters // Merge query string parameters
sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry) sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)
// Merge header parameters // Merge header parameters
complexAPI := false
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
hookCtx.ComplexAPI = complexAPI
// Apply filters from parsed parameters
sqlquery = h.ApplyFilters(sqlquery, reqParams)
// Apply field selection
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
// Apply DISTINCT if requested
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
// Build metainfo // Build metainfo
metainfo["ipaddress"] = getIPAddress(r) metainfo["ipaddress"] = getIPAddress(r)
@ -272,8 +500,22 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
} }
} }
// Update hook context with latest SQL query and variables
hookCtx.SQLQuery = sqlquery
hookCtx.Variables = variables
hookCtx.InputVars = inputvars
// Execute query within transaction // Execute query within transaction
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Execute BeforeSQLExec hook
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
logger.Error("BeforeSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified SQL query from hook
sqlquery = hookCtx.SQLQuery
// Execute main query // Execute main query
rows := make([]map[string]interface{}, 0) rows := make([]map[string]interface{}, 0)
if err := tx.Query(ctx, &rows, sqlquery); err != nil { if err := tx.Query(ctx, &rows, sqlquery); err != nil {
@ -285,6 +527,18 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
dbobj = rows[0] dbobj = rows[0]
} }
// Execute AfterSQLExec hook
hookCtx.Result = dbobj
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
logger.Error("AfterSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
return nil return nil
}) })
@ -293,6 +547,31 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
return return
} }
// Execute AfterQuery hook
hookCtx.Result = dbobj
hookCtx.Error = err
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
logger.Error("AfterQuery hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
// Execute BeforeResponse hook
hookCtx.Result = dbobj
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
logger.Error("BeforeResponse hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
// Check if response should be root-level data // Check if response should be root-level data
if val, ok := dbobj["root_as_data"]; ok { if val, ok := dbobj["root_as_data"]; ok {
data, err := json.Marshal(val) data, err := json.Marshal(val)

View File

@ -0,0 +1,837 @@
package funcspec
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// MockDatabase implements common.Database interface for testing
type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
}
func (m *MockDatabase) NewSelect() common.SelectQuery {
return nil
}
func (m *MockDatabase) NewInsert() common.InsertQuery {
return nil
}
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
return nil
}
func (m *MockDatabase) NewDelete() common.DeleteQuery {
return nil
}
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
if m.ExecFunc != nil {
return m.ExecFunc(ctx, query, args...)
}
return &MockResult{rows: 0}, nil
}
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if m.QueryFunc != nil {
return m.QueryFunc(ctx, dest, query, args...)
}
return nil
}
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
return m, nil
}
func (m *MockDatabase) CommitTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
if m.RunInTransactionFunc != nil {
return m.RunInTransactionFunc(ctx, fn)
}
return fn(m)
}
// MockResult implements common.Result interface for testing
type MockResult struct {
rows int64
id int64
}
func (m *MockResult) RowsAffected() int64 {
return m.rows
}
func (m *MockResult) LastInsertId() (int64, error) {
return m.id, nil
}
// Helper function to create a test request with user context
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
u, _ := url.Parse(path)
if queryParams != nil {
q := u.Query()
for k, v := range queryParams {
q.Set(k, v)
}
u.RawQuery = q.Encode()
}
var bodyReader *bytes.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
} else {
bodyReader = bytes.NewReader([]byte{})
}
req := httptest.NewRequest(method, u.String(), bodyReader)
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
// Add user context
userCtx := &security.UserContext{
UserID: 1,
UserName: "testuser",
SessionID: "test-session-123",
}
ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx)
req = req.WithContext(ctx)
return req
}
// TestNewHandler tests handler creation
func TestNewHandler(t *testing.T) {
db := &MockDatabase{}
handler := NewHandler(db)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.db != db {
t.Error("Expected handler to have the provided database")
}
if handler.hooks == nil {
t.Error("Expected handler to have a hook registry")
}
}
// TestHandlerHooks tests the Hooks method
func TestHandlerHooks(t *testing.T) {
handler := NewHandler(&MockDatabase{})
hooks := handler.Hooks()
if hooks == nil {
t.Fatal("Expected hooks registry to be non-nil")
}
// Should return the same instance
hooks2 := handler.Hooks()
if hooks != hooks2 {
t.Error("Expected Hooks() to return the same registry instance")
}
}
// TestExtractInputVariables tests the extractInputVariables function
func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
expectedVars []string
}{
{
name: "No variables",
sqlQuery: "SELECT * FROM users",
expectedVars: []string{},
},
{
name: "Single variable",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
expectedVars: []string{"[user_id]"},
},
{
name: "Multiple variables",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]",
expectedVars: []string{"[user_id]", "[user_name]"},
},
{
name: "Nested brackets",
sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]",
expectedVars: []string{"[field]", "[user_id]"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inputvars := make([]string, 0)
result := handler.extractInputVariables(tt.sqlQuery, &inputvars)
if result != tt.sqlQuery {
t.Errorf("Expected SQL query to be unchanged, got %s", result)
}
if len(inputvars) != len(tt.expectedVars) {
t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars)
return
}
for i, expected := range tt.expectedVars {
if inputvars[i] != expected {
t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i])
}
}
})
}
}
// TestValidSQL tests the SQL sanitization function
func TestValidSQL(t *testing.T) {
tests := []struct {
name string
input string
mode string
expected string
}{
{
name: "Column name with valid characters",
input: "user_id",
mode: "colname",
expected: "user_id",
},
{
name: "Column name with dots (table.column)",
input: "users.user_id",
mode: "colname",
expected: "users.user_id",
},
{
name: "Column name with SQL injection attempt",
input: "id'; DROP TABLE users--",
mode: "colname",
expected: "idDROPTABLEusers",
},
{
name: "Column value with single quotes",
input: "O'Brien",
mode: "colvalue",
expected: "O''Brien",
},
{
name: "Select with dangerous keywords",
input: "name, email; DROP TABLE users",
mode: "select",
expected: "name, email TABLE users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ValidSQL(tt.input, tt.mode)
if result != tt.expected {
t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected)
}
})
}
}
// TestIsNumeric tests the IsNumeric function
func TestIsNumeric(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"123", true},
{"123.45", true},
{"-123", true},
{"-123.45", true},
{"0", true},
{"abc", false},
{"12.34.56", false},
{"", false},
{"123abc", false},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := IsNumeric(tt.input)
if result != tt.expected {
t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected)
}
})
}
}
// TestSqlQryWhere tests the WHERE clause manipulation
func TestSqlQryWhere(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
expected string
}{
{
name: "Add WHERE to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ",
},
{
name: "Add AND to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ",
},
{
name: "Add WHERE before ORDER BY",
sqlQuery: "SELECT * FROM users ORDER BY name",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name",
},
{
name: "Add WHERE before GROUP BY",
sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department",
condition: "status = 'active'",
expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department",
},
{
name: "Add WHERE before LIMIT",
sqlQuery: "SELECT * FROM users LIMIT 10",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhere(tt.sqlQuery, tt.condition)
if result != tt.expected {
t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) {
tests := []struct {
name string
setupReq func() *http.Request
expected string
}{
{
name: "X-Forwarded-For header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
return req
},
expected: "192.168.1.100",
},
{
name: "X-Real-IP header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.200")
return req
},
expected: "192.168.1.200",
},
{
name: "RemoteAddr fallback",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
return req
},
expected: "192.168.1.1:12345",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupReq()
result := getIPAddress(req)
if result != tt.expected {
t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestParsePaginationParams tests pagination parameter parsing
func TestParsePaginationParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
expectedSort string
expectedLimit int
expectedOffset int
}{
{
name: "No parameters - defaults",
queryParams: map[string]string{},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "All parameters provided",
queryParams: map[string]string{
"sort": "name,-created_at",
"limit": "100",
"offset": "50",
},
expectedSort: "name,-created_at",
expectedLimit: 100,
expectedOffset: 50,
},
{
name: "Invalid limit - use default",
queryParams: map[string]string{
"limit": "invalid",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "Negative offset - use default",
queryParams: map[string]string{
"offset": "-10",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
sort, limit, offset := handler.parsePaginationParams(req)
if sort != tt.expectedSort {
t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort)
}
if limit != tt.expectedLimit {
t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit)
}
if offset != tt.expectedOffset {
t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset)
}
})
}
}
// TestSqlQuery tests the SqlQuery handler for single record queries
func TestSqlQuery(t *testing.T) {
tests := []struct {
name string
sqlQuery string
blankParams bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, body []byte)
}{
{
name: "Basic query - returns single record",
sqlQuery: "SELECT * FROM users WHERE id = 1",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User", "email": "test@example.com"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if result["name"] != "Test User" {
t.Errorf("Expected name='Test User', got %v", result["name"])
}
},
},
{
name: "Query with no results",
sqlQuery: "SELECT * FROM users WHERE id = 999",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
// Return empty array
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 0 {
t.Errorf("Expected empty result, got %v", result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
if tt.validateResp != nil {
tt.validateResp(t, w.Body.Bytes())
}
})
}
}
// TestSqlQueryList tests the SqlQueryList handler for list queries
func TestSqlQueryList(t *testing.T) {
tests := []struct {
name string
sqlQuery string
noCount bool
blankParams bool
allowFilter bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, w *httptest.ResponseRecorder)
}{
{
name: "Basic list query",
sqlQuery: "SELECT * FROM users",
noCount: false,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
callCount := 0
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
callCount++
if strings.Contains(query, "COUNT") {
// Count query
countResult := dest.(*struct{ Count int64 })
countResult.Count = 2
} else {
// Main query
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
{"id": float64(2), "name": "User 2"},
}
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 2 {
t.Errorf("Expected 2 results, got %d", len(result))
}
// Check Content-Range header
contentRange := w.Header().Get("Content-Range")
if !strings.Contains(contentRange, "2") {
t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
}
},
},
{
name: "List query with noCount",
sqlQuery: "SELECT * FROM users",
noCount: true,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if strings.Contains(query, "COUNT") {
t.Error("Count query should not be executed when noCount is true")
}
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 1 {
t.Errorf("Expected 1 result, got %d", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
}
if tt.validateResp != nil {
tt.validateResp(t, w)
}
})
}
}
// TestMergeQueryParams tests query parameter merging
func TestMergeQueryParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
queryParams map[string]string
allowFilter bool
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Replace placeholder with parameter",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
queryParams: map[string]string{"p-user_id": "123"},
allowFilter: false,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["p-user_id"] != "123" {
t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
}
},
},
{
name: "Add filter when allowed",
sqlQuery: "SELECT * FROM users",
queryParams: map[string]string{"status": "active"},
allowFilter: true,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["status"] != "active" {
t.Errorf("Expected status=active, got %v", vars["status"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestMergeHeaderParams tests header parameter merging
func TestMergeHeaderParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
headers map[string]string
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Field filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-FieldFilter-Status": "1"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-fieldfilter-status"] != "1" {
t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
}
},
},
{
name: "Search filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-SearchFilter-Name": "john"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-searchfilter-name"] != "john" {
t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, tt.headers, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
complexAPI := false
result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestReplaceMetaVariables tests meta variable replacement
func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{
UserID: 123,
UserName: "testuser",
SessionID: "session-abc",
}
metainfo := map[string]interface{}{
"ipaddress": "192.168.1.1",
"url": "/api/test",
}
variables := map[string]interface{}{
"param1": "value1",
}
tests := []struct {
name string
sqlQuery string
expectedCheck func(result string) bool
}{
{
name: "Replace [rid_user]",
sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "123")
},
},
{
name: "Replace [user]",
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "'testuser'")
},
},
{
name: "Replace [rid_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "'session-abc'")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, nil, nil)
result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
if !tt.expectedCheck(result) {
t.Errorf("Meta variable replacement failed. Query: %s", result)
}
})
}
}

160
pkg/funcspec/hooks.go Normal file
View File

@ -0,0 +1,160 @@
package funcspec
import (
"context"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Query operation hooks (for SqlQuery - single record)
BeforeQuery HookType = "before_query"
AfterQuery HookType = "after_query"
// Query list operation hooks (for SqlQueryList - multiple records)
BeforeQueryList HookType = "before_query_list"
AfterQueryList HookType = "after_query_list"
// SQL execution hooks (just before SQL is executed)
BeforeSQLExec HookType = "before_sql_exec"
AfterSQLExec HookType = "after_sql_exec"
// Response hooks (before response is sent)
BeforeResponse HookType = "before_response"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database
Request *http.Request
Writer http.ResponseWriter
// SQL query and variables
SQLQuery string // The SQL query being executed (can be modified by hooks)
Variables map[string]interface{} // Variables extracted from request
InputVars []string // Input variable placeholders found in query
MetaInfo map[string]interface{} // Metadata about the request
PropQry map[string]string // Property query parameters
// User context
UserContext *security.UserContext
// Pagination and filtering (for list queries)
SortColumns string
Limit int
Offset int
// Results
Result interface{} // Query result (single record or list)
Total int64 // Total count (for list queries)
Error error // Error if operation failed
ComplexAPI bool // Whether complex API response format is requested
NoCount bool // Whether count query should be skipped
BlankParams bool // Whether blank parameters should be removed
AllowFilter bool // Whether filtering is allowed
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all funcspec hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all funcspec hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@ -0,0 +1,137 @@
package funcspec
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Example hook functions demonstrating various use cases
// ExampleLoggingHook logs all SQL queries before execution
func ExampleLoggingHook(ctx *HookContext) error {
logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery)
return nil
}
// ExampleSecurityHook validates user permissions before executing queries
func ExampleSecurityHook(ctx *HookContext) error {
// Example: Block queries that try to access sensitive tables
if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") {
if ctx.UserContext.UserID != 1 { // Only admin can access
ctx.Abort = true
ctx.AbortCode = 403
ctx.AbortMessage = "Access denied: insufficient permissions"
return fmt.Errorf("access denied to sensitive_table")
}
}
return nil
}
// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering
func ExampleQueryModificationHook(ctx *HookContext) error {
// Example: Automatically add user_id filter for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
// Add WHERE clause to filter by user_id
if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") {
ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID)
} else {
ctx.SQLQuery = strings.Replace(
ctx.SQLQuery,
"WHERE",
fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID),
1,
)
}
logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery)
}
return nil
}
// ExampleResultFilterHook filters results after query execution
func ExampleResultFilterHook(ctx *HookContext) error {
// Example: Remove sensitive fields from results for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
switch result := ctx.Result.(type) {
case []map[string]interface{}:
// Filter list results
for i := range result {
delete(result[i], "password")
delete(result[i], "ssn")
delete(result[i], "credit_card")
}
case map[string]interface{}:
// Filter single record
delete(result, "password")
delete(result, "ssn")
delete(result, "credit_card")
}
}
return nil
}
// ExampleAuditHook logs all queries and results for audit purposes
func ExampleAuditHook(ctx *HookContext) error {
// Log to audit table or external system
logger.Info("AUDIT: User %s (%d) executed query from %s",
ctx.UserContext.UserName,
ctx.UserContext.UserID,
ctx.Request.RemoteAddr,
)
// In a real implementation, you might:
// - Insert into an audit log table
// - Send to a logging service
// - Write to a file
return nil
}
// ExampleCacheHook implements simple response caching
func ExampleCacheHook(ctx *HookContext) error {
// This is a simplified example - real caching would use a proper cache store
// Check if we have a cached result for this query
// cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery)
// if cachedResult := checkCache(cacheKey); cachedResult != nil {
// ctx.Result = cachedResult
// ctx.Abort = true // Skip query execution
// ctx.AbortMessage = "Serving from cache"
// }
return nil
}
// ExampleErrorHandlingHook provides custom error handling
func ExampleErrorHandlingHook(ctx *HookContext) error {
if ctx.Error != nil {
// Log error with context
logger.Error("Query failed for user %s: %v\nQuery: %s",
ctx.UserContext.UserName,
ctx.Error,
ctx.SQLQuery,
)
// You could send notifications, update metrics, etc.
}
return nil
}
// Example of registering hooks:
//
// func SetupHooks(handler *Handler) {
// hooks := handler.Hooks()
//
// // Register security hook before query execution
// hooks.Register(BeforeQuery, ExampleSecurityHook)
// hooks.Register(BeforeQueryList, ExampleSecurityHook)
//
// // Register logging hook before SQL execution
// hooks.Register(BeforeSQLExec, ExampleLoggingHook)
//
// // Register result filtering after query
// hooks.Register(AfterQuery, ExampleResultFilterHook)
// hooks.Register(AfterQueryList, ExampleResultFilterHook)
//
// // Register audit hook after execution
// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook)
// }

589
pkg/funcspec/hooks_test.go Normal file
View File

@ -0,0 +1,589 @@
package funcspec
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// TestNewHookRegistry tests hook registry creation
func TestNewHookRegistry(t *testing.T) {
registry := NewHookRegistry()
if registry == nil {
t.Fatal("Expected registry to be created, got nil")
}
if registry.hooks == nil {
t.Error("Expected hooks map to be initialized")
}
}
// TestRegisterHook tests registering a single hook
func TestRegisterHook(t *testing.T) {
registry := NewHookRegistry()
hookCalled := false
testHook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
registry.Register(BeforeQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected hook to be registered")
}
if registry.Count(BeforeQuery) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
}
// Execute the hook
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !hookCalled {
t.Error("Expected hook to be called")
}
}
// TestRegisterMultipleHooks tests registering multiple hooks for same type
func TestRegisterMultipleHooks(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
if registry.Count(BeforeQuery) != 3 {
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
}
// Execute hooks
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
// Verify hooks were called in order
if len(callOrder) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
}
for i, expected := range []int{1, 2, 3} {
if callOrder[i] != expected {
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
}
}
}
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
func TestRegisterMultipleHookTypes(t *testing.T) {
registry := NewHookRegistry()
callCount := 0
testHook := func(ctx *HookContext) error {
callCount++
return nil
}
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
registry.RegisterMultiple(hookTypes, testHook)
// Verify hook is registered for all types
for _, hookType := range hookTypes {
if !registry.HasHooks(hookType) {
t.Errorf("Expected hook to be registered for %s", hookType)
}
if registry.Count(hookType) != 1 {
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
}
}
// Execute each hook type
ctx := &HookContext{}
for _, hookType := range hookTypes {
if err := registry.Execute(hookType, ctx); err != nil {
t.Errorf("Hook execution failed for %s: %v", hookType, err)
}
}
if callCount != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
}
}
// TestHookError tests hook error handling
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
expectedError := fmt.Errorf("test error")
errorHook := func(ctx *HookContext) error {
return expectedError
}
registry.Register(BeforeQuery, errorHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook, got nil")
}
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
t.Errorf("Expected error message to contain hook error, got: %v", err)
}
}
// TestHookAbort tests hook abort functionality
func TestHookAbort(t *testing.T) {
registry := NewHookRegistry()
abortHook := func(ctx *HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "Operation aborted by hook"
ctx.AbortCode = 403
return nil
}
registry.Register(BeforeQuery, abortHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error when hook aborts, got nil")
}
if !ctx.Abort {
t.Error("Expected Abort to be true")
}
if ctx.AbortMessage != "Operation aborted by hook" {
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
}
}
// TestHookChainWithError tests that hook chain stops on first error
func TestHookChainWithError(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return fmt.Errorf("error in hook 2")
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook chain")
}
// Only first two hooks should have been called
if len(callOrder) != 2 {
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
}
if callOrder[0] != 1 || callOrder[1] != 2 {
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
}
}
// TestClearHooks tests clearing hooks
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hook to be registered")
}
registry.Clear(BeforeQuery)
if registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hooks to be cleared")
}
if !registry.HasHooks(AfterQuery) {
t.Error("Expected AfterQuery hook to still be registered")
}
}
// TestClearAllHooks tests clearing all hooks
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
registry.Register(BeforeSQLExec, testHook)
registry.ClearAll()
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
t.Error("Expected all hooks to be cleared")
}
}
// TestGetAllHookTypes tests getting all registered hook types
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
types := registry.GetAllHookTypes()
if len(types) != 2 {
t.Errorf("Expected 2 hook types, got %d", len(types))
}
// Verify the types are present
foundBefore := false
foundAfter := false
for _, hookType := range types {
if hookType == BeforeQuery {
foundBefore = true
}
if hookType == AfterQuery {
foundAfter = true
}
}
if !foundBefore || !foundAfter {
t.Error("Expected both BeforeQuery and AfterQuery hook types")
}
}
// TestHookContextModification tests that hooks can modify the context
func TestHookContextModification(t *testing.T) {
registry := NewHookRegistry()
// Hook that modifies SQL query
modifyHook := func(ctx *HookContext) error {
ctx.SQLQuery = "SELECT * FROM modified_table"
ctx.Variables["new_var"] = "new_value"
return nil
}
registry.Register(BeforeQuery, modifyHook)
ctx := &HookContext{
SQLQuery: "SELECT * FROM original_table",
Variables: make(map[string]interface{}),
}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if ctx.SQLQuery != "SELECT * FROM modified_table" {
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
}
if ctx.Variables["new_var"] != "new_value" {
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
}
}
// TestExampleHooks tests the example hooks
func TestExampleLoggingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleLoggingHook(ctx)
if err != nil {
t.Errorf("ExampleLoggingHook failed: %v", err)
}
}
func TestExampleSecurityHook(t *testing.T) {
tests := []struct {
name string
sqlQuery string
userID int
shouldAbort bool
}{
{
name: "Admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 1,
shouldAbort: false,
},
{
name: "Non-admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 2,
shouldAbort: true,
},
{
name: "Non-admin accessing normal table",
sqlQuery: "SELECT * FROM users",
userID: 2,
shouldAbort: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: tt.sqlQuery,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
_ = ExampleSecurityHook(ctx)
if tt.shouldAbort {
if !ctx.Abort {
t.Error("Expected security hook to abort operation")
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
}
} else {
if ctx.Abort {
t.Error("Expected security hook not to abort operation")
}
}
})
}
}
func TestExampleResultFilterHook(t *testing.T) {
tests := []struct {
name string
userID int
result interface{}
validate func(t *testing.T, result interface{})
}{
{
name: "Admin user - no filtering",
userID: 1,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; !exists {
t.Error("Expected password field to remain for admin")
}
},
},
{
name: "Regular user - sensitive fields removed",
userID: 2,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
"ssn": "123-45-6789",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed")
}
if _, exists := m["ssn"]; exists {
t.Error("Expected ssn field to be removed")
}
if _, exists := m["name"]; !exists {
t.Error("Expected name field to remain")
}
},
},
{
name: "Regular user - list results filtered",
userID: 2,
result: []map[string]interface{}{
{"id": 1, "name": "User 1", "password": "secret1"},
{"id": 2, "name": "User 2", "password": "secret2"},
},
validate: func(t *testing.T, result interface{}) {
list := result.([]map[string]interface{})
for _, m := range list {
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed from list")
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
Result: tt.result,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
err := ExampleResultFilterHook(ctx)
if err != nil {
t.Errorf("Hook failed: %v", err)
}
if tt.validate != nil {
tt.validate(t, ctx.Result)
}
})
}
}
func TestExampleAuditHook(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
ctx := &HookContext{
Context: context.Background(),
Request: req,
UserContext: &security.UserContext{
UserID: 123,
UserName: "testuser",
},
}
err := ExampleAuditHook(ctx)
if err != nil {
t.Errorf("ExampleAuditHook failed: %v", err)
}
}
func TestExampleErrorHandlingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
Error: fmt.Errorf("test error"),
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleErrorHandlingHook(ctx)
if err != nil {
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
}
}
// TestHookIntegrationWithHandler tests hooks integrated with the handler
func TestHookIntegrationWithHandler(t *testing.T) {
db := &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
queryDB := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User"},
}
return nil
},
}
return fn(queryDB)
},
}
handler := NewHandler(db)
// Register a hook that modifies the SQL query
hookCalled := false
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
hookCalled = true
// Verify we can access context data
if ctx.SQLQuery == "" {
t.Error("Expected SQL query to be set")
}
if ctx.UserContext == nil {
t.Error("Expected user context to be set")
}
return nil
})
// Execute a query
req := createTestRequest("GET", "/test", nil, nil, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
handlerFunc(w, req)
if !hookCalled {
t.Error("Expected hook to be called during query execution")
}
if w.Code != 200 {
t.Errorf("Expected status 200, got %d", w.Code)
}
}

411
pkg/funcspec/parameters.go Normal file
View File

@ -0,0 +1,411 @@
package funcspec
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// RequestParameters holds parsed parameters from headers and query string
type RequestParameters struct {
// Field selection
SelectFields []string
NotSelectFields []string
Distinct bool
// Filtering
FieldFilters map[string]string // column -> value (exact match)
SearchFilters map[string]string // column -> value (ILIKE)
SearchOps map[string]FilterOperator // column -> {operator, value, logic}
CustomSQLWhere string
CustomSQLOr string
// Sorting & Pagination
SortColumns string
Limit int
Offset int
// Advanced features
SkipCount bool
SkipCache bool
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
ComplexAPI bool // true if NOT simple API
}
// FilterOperator represents a filter with operator
type FilterOperator struct {
Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc.
Value string
Logic string // AND or OR
}
// ParseParameters parses all parameters from request headers and query string
func (h *Handler) ParseParameters(r *http.Request) *RequestParameters {
params := &RequestParameters{
FieldFilters: make(map[string]string),
SearchFilters: make(map[string]string),
SearchOps: make(map[string]FilterOperator),
Limit: 20, // Default limit
Offset: 0, // Default offset
ResponseFormat: "simple", // Default format
ComplexAPI: false, // Default to simple API
}
// Merge headers and query parameters
combined := make(map[string]string)
// Add all headers (normalize to lowercase)
for key, values := range r.Header {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Add all query parameters (override headers)
for key, values := range r.URL.Query() {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Parse each parameter
for key, value := range combined {
// Decode value if base64 encoded
decodedValue := h.decodeValue(value)
switch {
// Field Selection
case strings.HasPrefix(key, "x-select-fields"):
params.SelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-not-select-fields"):
params.NotSelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-distinct"):
params.Distinct = strings.EqualFold(decodedValue, "true")
// Filtering
case strings.HasPrefix(key, "x-fieldfilter-"):
colName := strings.TrimPrefix(key, "x-fieldfilter-")
params.FieldFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchfilter-"):
colName := strings.TrimPrefix(key, "x-searchfilter-")
params.SearchFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(params, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-custom-sql-w"):
if params.CustomSQLWhere != "" {
params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue)
} else {
params.CustomSQLWhere = decodedValue
}
case strings.HasPrefix(key, "x-custom-sql-or"):
if params.CustomSQLOr != "" {
params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue)
} else {
params.CustomSQLOr = decodedValue
}
// Sorting & Pagination
case key == "sort" || strings.HasPrefix(key, "x-sort"):
params.SortColumns = decodedValue
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
// Handle sort(col1,-col2) syntax
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
params.SortColumns = sortValue
case key == "limit" || strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 {
params.Limit = limit
}
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
// Handle limit(offset,limit) or limit(limit) syntax
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
parts := strings.Split(limitValue, ",")
if len(parts) > 1 {
if offset, err := strconv.Atoi(parts[0]); err == nil {
params.Offset = offset
}
if limit, err := strconv.Atoi(parts[1]); err == nil {
params.Limit = limit
}
} else {
if limit, err := strconv.Atoi(parts[0]); err == nil {
params.Limit = limit
}
}
case key == "offset" || strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 {
params.Offset = offset
}
// Advanced features
case strings.HasPrefix(key, "x-skipcount"):
params.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-skipcache"):
params.SkipCache = strings.EqualFold(decodedValue, "true")
// Response Format
case strings.HasPrefix(key, "x-simpleapi"):
params.ResponseFormat = "simple"
params.ComplexAPI = !(decodedValue == "1" || strings.EqualFold(decodedValue, "true"))
case strings.HasPrefix(key, "x-detailapi"):
params.ResponseFormat = "detail"
params.ComplexAPI = true
case strings.HasPrefix(key, "x-syncfusion"):
params.ResponseFormat = "syncfusion"
params.ComplexAPI = true
}
}
return params
}
// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column}
func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) {
var prefix string
if logic == "OR" {
prefix = "x-searchor-"
} else {
prefix = "x-searchop-"
if strings.HasPrefix(headerKey, "x-searchand-") {
prefix = "x-searchand-"
}
}
rest := strings.TrimPrefix(headerKey, prefix)
parts := strings.SplitN(rest, "-", 2)
if len(parts) != 2 {
logger.Warn("Invalid search operator header format: %s", headerKey)
return
}
operator := parts[0]
colName := parts[1]
params.SearchOps[colName] = FilterOperator{
Operator: operator,
Value: value,
Logic: logic,
}
logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value)
}
// decodeValue decodes base64 encoded values (ZIP_ or __ prefix)
func (h *Handler) decodeValue(value string) string {
decoded, _ := restheadspec.DecodeParam(value)
return decoded
}
// parseCommaSeparated parses comma-separated values
func (h *Handler) parseCommaSeparated(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
}
}
return result
}
// ApplyFieldSelection applies column selection to SQL query
func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string {
if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 {
return sqlQuery
}
// This is a simplified implementation
// A full implementation would parse the SQL and replace the SELECT clause
// For now, we log a warning that this feature needs manual implementation
if len(params.SelectFields) > 0 {
logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields)
}
if len(params.NotSelectFields) > 0 {
logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields)
}
return sqlQuery
}
// ApplyFilters applies all filters to the SQL query
func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string {
// Apply field filters (exact match)
for colName, value := range params.FieldFilters {
condition := ""
if value == "" || value == "0" {
condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
} else {
condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
}
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied field filter: %s", condition)
}
// Apply search filters (ILIKE)
for colName, value := range params.SearchFilters {
sval := strings.ReplaceAll(value, "'", "")
if sval != "" {
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied search filter: %s", condition)
}
}
// Apply search operators
for colName, filterOp := range params.SearchOps {
condition := h.buildFilterCondition(colName, filterOp)
if condition != "" {
if filterOp.Logic == "OR" {
sqlQuery = sqlQryWhereOr(sqlQuery, condition)
} else {
sqlQuery = sqlQryWhere(sqlQuery, condition)
}
logger.Debug("Applied search operator: %s", condition)
}
}
// Apply custom SQL WHERE
if params.CustomSQLWhere != "" {
colval := ValidSQL(params.CustomSQLWhere, "select")
if colval != "" {
sqlQuery = sqlQryWhere(sqlQuery, colval)
logger.Debug("Applied custom SQL WHERE: %s", colval)
}
}
// Apply custom SQL OR
if params.CustomSQLOr != "" {
colval := ValidSQL(params.CustomSQLOr, "select")
if colval != "" {
sqlQuery = sqlQryWhereOr(sqlQuery, colval)
logger.Debug("Applied custom SQL OR: %s", colval)
}
}
return sqlQuery
}
// buildFilterCondition builds a SQL condition from a FilterOperator
func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string {
safCol := ValidSQL(colName, "colname")
operator := strings.ToLower(op.Operator)
value := op.Value
switch operator {
case "contains", "contain", "like":
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
case "beginswith", "startswith":
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
case "endswith":
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
case "equals", "eq", "=":
if IsNumeric(value) {
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
case "notequals", "neq", "ne", "!=", "<>":
if IsNumeric(value) {
return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue"))
case "greaterthan", "gt", ">":
return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue"))
case "lessthan", "lt", "<":
return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue"))
case "greaterthanorequal", "gte", "ge", ">=":
return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue"))
case "lessthanorequal", "lte", "le", "<=":
return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue"))
case "between":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "betweeninclusive":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "in":
values := strings.Split(value, ",")
safeValues := make([]string, len(values))
for i, v := range values {
safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue"))
}
return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", "))
case "empty", "isnull", "null":
return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol)
case "notempty", "isnotnull", "notnull":
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol)
default:
logger.Warn("Unknown filter operator: %s, defaulting to equals", operator)
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
}
return ""
}
// ApplyDistinct adds DISTINCT to SQL query if requested
func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string {
if !params.Distinct {
return sqlQuery
}
// Add DISTINCT after SELECT
selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT")
if selectPos >= 0 {
beforeSelect := sqlQuery[:selectPos+6] // "SELECT"
afterSelect := sqlQuery[selectPos+6:]
sqlQuery = beforeSelect + " DISTINCT" + afterSelect
logger.Debug("Applied DISTINCT to query")
}
return sqlQuery
}
// sqlQryWhereOr adds a WHERE clause with OR logic
func sqlQryWhereOr(sqlquery, condition string) string {
lowerQuery := strings.ToLower(sqlquery)
wherePos := strings.Index(lowerQuery, " where ")
groupPos := strings.Index(lowerQuery, " group by")
orderPos := strings.Index(lowerQuery, " order by")
limitPos := strings.Index(lowerQuery, " limit ")
// Find the insertion point
insertPos := len(sqlquery)
if groupPos > 0 && groupPos < insertPos {
insertPos = groupPos
}
if orderPos > 0 && orderPos < insertPos {
insertPos = orderPos
}
if limitPos > 0 && limitPos < insertPos {
insertPos = limitPos
}
if wherePos > 0 {
// WHERE exists, add OR condition
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s OR (%s) %s", before, condition, after)
} else {
// No WHERE exists, add it
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
}
}

View File

@ -0,0 +1,549 @@
package funcspec
import (
"strings"
"testing"
)
// TestParseParameters tests the comprehensive parameter parsing
func TestParseParameters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
headers map[string]string
validate func(t *testing.T, params *RequestParameters)
}{
{
name: "Parse field selection",
headers: map[string]string{
"X-Select-Fields": "id,name,email",
"X-Not-Select-Fields": "password,ssn",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SelectFields) != 3 {
t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields))
}
if len(params.NotSelectFields) != 2 {
t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields))
}
},
},
{
name: "Parse distinct flag",
headers: map[string]string{
"X-Distinct": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.Distinct {
t.Error("Expected Distinct to be true")
}
},
},
{
name: "Parse field filters",
headers: map[string]string{
"X-FieldFilter-Status": "active",
"X-FieldFilter-Type": "admin",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.FieldFilters) != 2 {
t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters))
}
if params.FieldFilters["status"] != "active" {
t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"])
}
},
},
{
name: "Parse search filters",
headers: map[string]string{
"X-SearchFilter-Name": "john",
"X-SearchFilter-Email": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchFilters) != 2 {
t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters))
}
},
},
{
name: "Parse sort columns",
queryParams: map[string]string{
"sort": "-created_at,name",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.SortColumns != "-created_at,name" {
t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns)
}
},
},
{
name: "Parse limit and offset",
queryParams: map[string]string{
"limit": "100",
"offset": "50",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.Limit != 100 {
t.Errorf("Expected limit=100, got %d", params.Limit)
}
if params.Offset != 50 {
t.Errorf("Expected offset=50, got %d", params.Offset)
}
},
},
{
name: "Parse skip count",
headers: map[string]string{
"X-SkipCount": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.SkipCount {
t.Error("Expected SkipCount to be true")
}
},
},
{
name: "Parse response format - syncfusion",
headers: map[string]string{
"X-Syncfusion": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "syncfusion" {
t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat)
}
if !params.ComplexAPI {
t.Error("Expected ComplexAPI to be true for syncfusion format")
}
},
},
{
name: "Parse response format - detail",
headers: map[string]string{
"X-DetailAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "detail" {
t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat)
}
},
},
{
name: "Parse simple API",
headers: map[string]string{
"X-SimpleAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "simple" {
t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat)
}
if params.ComplexAPI {
t.Error("Expected ComplexAPI to be false for simple API")
}
},
},
{
name: "Parse custom SQL WHERE",
headers: map[string]string{
"X-Custom-SQL-W": "status = 'active' AND deleted = false",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set")
}
},
},
{
name: "Parse search operators - AND",
headers: map[string]string{
"X-SearchOp-Eq-Name": "john",
"X-SearchOp-Gt-Age": "18",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchOps) != 2 {
t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps))
}
if op, exists := params.SearchOps["name"]; exists {
if op.Operator != "eq" {
t.Errorf("Expected operator=eq for name, got %s", op.Operator)
}
if op.Logic != "AND" {
t.Errorf("Expected logic=AND, got %s", op.Logic)
}
} else {
t.Error("Expected name search operator to exist")
}
},
},
{
name: "Parse search operators - OR",
headers: map[string]string{
"X-SearchOr-Like-Description": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if op, exists := params.SearchOps["description"]; exists {
if op.Logic != "OR" {
t.Errorf("Expected logic=OR, got %s", op.Logic)
}
} else {
t.Error("Expected description search operator to exist")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
params := handler.ParseParameters(req)
if tt.validate != nil {
tt.validate(t, params)
}
})
}
}
// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
colName string
operator FilterOperator
expected string
}{
{
name: "Equals operator - numeric",
colName: "age",
operator: FilterOperator{
Operator: "eq",
Value: "25",
Logic: "AND",
},
expected: "age = 25",
},
{
name: "Equals operator - string",
colName: "name",
operator: FilterOperator{
Operator: "eq",
Value: "john",
Logic: "AND",
},
expected: "name = 'john'",
},
{
name: "Not equals operator",
colName: "status",
operator: FilterOperator{
Operator: "neq",
Value: "inactive",
Logic: "AND",
},
expected: "status != 'inactive'",
},
{
name: "Greater than operator",
colName: "age",
operator: FilterOperator{
Operator: "gt",
Value: "18",
Logic: "AND",
},
expected: "age > 18",
},
{
name: "Less than operator",
colName: "price",
operator: FilterOperator{
Operator: "lt",
Value: "100",
Logic: "AND",
},
expected: "price < 100",
},
{
name: "Contains operator",
colName: "description",
operator: FilterOperator{
Operator: "contains",
Value: "test",
Logic: "AND",
},
expected: "description ILIKE '%test%'",
},
{
name: "Starts with operator",
colName: "name",
operator: FilterOperator{
Operator: "startswith",
Value: "john",
Logic: "AND",
},
expected: "name ILIKE 'john%'",
},
{
name: "Ends with operator",
colName: "email",
operator: FilterOperator{
Operator: "endswith",
Value: "@example.com",
Logic: "AND",
},
expected: "email ILIKE '%@example.com'",
},
{
name: "Between operator",
colName: "age",
operator: FilterOperator{
Operator: "between",
Value: "18,65",
Logic: "AND",
},
expected: "age > 18 AND age < 65",
},
{
name: "IN operator",
colName: "status",
operator: FilterOperator{
Operator: "in",
Value: "active,pending,approved",
Logic: "AND",
},
expected: "status IN ('active', 'pending', 'approved')",
},
{
name: "IS NULL operator",
colName: "deleted_at",
operator: FilterOperator{
Operator: "null",
Value: "",
Logic: "AND",
},
expected: "(deleted_at IS NULL OR deleted_at = '')",
},
{
name: "IS NOT NULL operator",
colName: "created_at",
operator: FilterOperator{
Operator: "notnull",
Value: "",
Logic: "AND",
},
expected: "(created_at IS NOT NULL AND created_at != '')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.buildFilterCondition(tt.colName, tt.operator)
if result != tt.expected {
t.Errorf("Expected: %s\nGot: %s", tt.expected, result)
}
})
}
}
// TestApplyFilters tests the filter application to SQL queries
func TestApplyFilters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
params *RequestParameters
expectedSQL string
shouldContain []string
}{
{
name: "Apply field filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
FieldFilters: map[string]string{
"status": "active",
},
},
shouldContain: []string{"WHERE", "status"},
},
{
name: "Apply search filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchFilters: map[string]string{
"name": "john",
},
},
shouldContain: []string{"WHERE", "name", "ILIKE"},
},
{
name: "Apply search operators",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchOps: map[string]FilterOperator{
"age": {
Operator: "gt",
Value: "18",
Logic: "AND",
},
},
},
shouldContain: []string{"WHERE", "age", ">", "18"},
},
{
name: "Apply custom SQL WHERE",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
CustomSQLWhere: "deleted = false",
},
shouldContain: []string{"WHERE", "deleted"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.ApplyFilters(tt.sqlQuery, tt.params)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}
// TestApplyDistinct tests DISTINCT application
func TestApplyDistinct(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
distinct bool
shouldHave string
}{
{
name: "Apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: true,
shouldHave: "SELECT DISTINCT",
},
{
name: "Do not apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: false,
shouldHave: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
params := &RequestParameters{Distinct: tt.distinct}
result := handler.ApplyDistinct(tt.sqlQuery, params)
if tt.shouldHave != "" {
if !strings.Contains(result, tt.shouldHave) {
t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result)
}
} else {
// Should not have DISTINCT when not requested
if strings.Contains(result, "DISTINCT") && !tt.distinct {
t.Errorf("SQL should not contain DISTINCT when not requested: %s", result)
}
}
})
}
}
// TestParseCommaSeparated tests comma-separated value parsing
func TestParseCommaSeparated(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
input string
expected []string
}{
{
name: "Simple comma-separated",
input: "id,name,email",
expected: []string{"id", "name", "email"},
},
{
name: "With spaces",
input: "id, name, email",
expected: []string{"id", "name", "email"},
},
{
name: "Empty string",
input: "",
expected: nil,
},
{
name: "Single value",
input: "id",
expected: []string{"id"},
},
{
name: "With extra commas",
input: "id,,name,,email",
expected: []string{"id", "name", "email"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.parseCommaSeparated(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("Expected %d values, got %d", len(tt.expected), len(result))
return
}
for i, expected := range tt.expected {
if result[i] != expected {
t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i])
}
}
})
}
}
// TestSqlQryWhereOr tests OR WHERE clause manipulation
func TestSqlQryWhereOr(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
shouldContain []string
}{
{
name: "Add WHERE with OR to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "status = 'inactive'"},
},
{
name: "Add OR to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhereOr(tt.sqlQuery, tt.condition)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}

View File

@ -8,7 +8,9 @@ import (
"reflect" "reflect"
"runtime/debug" "runtime/debug"
"strings" "strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
@ -233,13 +235,46 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
// Get total count before pagination // Get total count before pagination
total, err := query.Count(ctx) var total int
// Try to get from cache first
cacheKeyHash := cache.BuildQueryCacheKey(
tableName,
options.Filters,
options.Sort,
"", // No custom SQL WHERE in resolvespec
"", // No custom SQL OR in resolvespec
)
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
var cachedTotal cache.CachedTotal
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err == nil {
total = cachedTotal.Total
logger.Debug("Total records (from cache): %d", total)
} else {
// Cache miss - execute count query
logger.Debug("Cache miss for query total")
count, err := query.Count(ctx)
if err != nil { if err != nil {
logger.Error("Error counting records: %v", err) logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return return
} }
logger.Debug("Total records before filtering: %d", total) total = count
logger.Debug("Total records (from query): %d", total)
// Store in cache
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
logger.Debug("Cached query total with key: %s", cacheKey)
}
}
// Apply pagination // Apply pagination
if options.Limit != nil && *options.Limit > 0 { if options.Limit != nil && *options.Limit > 0 {

View File

@ -9,7 +9,9 @@ import (
"runtime/debug" "runtime/debug"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
@ -436,6 +438,48 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Get total count before pagination (unless skip count is requested) // Get total count before pagination (unless skip count is requested)
var total int var total int
if !options.SkipCount { if !options.SkipCount {
// Try to get from cache first (unless SkipCache is true)
var cachedTotal *cache.CachedTotal
var cacheKey string
if !options.SkipCache {
// Build cache key from query parameters
// Convert expand options to interface slice for the cache key builder
expandOpts := make([]interface{}, len(options.Expand))
for i, exp := range options.Expand {
expandOpts[i] = map[string]interface{}{
"relation": exp.Relation,
"where": exp.Where,
}
}
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
tableName,
options.Filters,
options.Sort,
options.CustomSQLWhere,
options.CustomSQLOr,
expandOpts,
options.Distinct,
options.CursorForward,
options.CursorBackward,
)
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
cachedTotal = &cache.CachedTotal{}
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
if err == nil {
total = cachedTotal.Total
logger.Debug("Total records (from cache): %d", total)
} else {
logger.Debug("Cache miss for query total")
cachedTotal = nil
}
}
// If not in cache or cache skip, execute count query
if cachedTotal == nil {
count, err := query.Count(ctx) count, err := query.Count(ctx)
if err != nil { if err != nil {
logger.Error("Error counting records: %v", err) logger.Error("Error counting records: %v", err)
@ -443,7 +487,20 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
return return
} }
total = count total = count
logger.Debug("Total records: %d", total) logger.Debug("Total records (from query): %d", total)
// Store in cache (if caching is enabled)
if !options.SkipCache && cacheKey != "" {
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := &cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
logger.Debug("Cached query total with key: %s", cacheKey)
}
}
}
} else { } else {
logger.Debug("Skipping count as requested") logger.Debug("Skipping count as requested")
total = -1 // Indicate count was skipped total = -1 // Indicate count was skipped