diff --git a/go.mod b/go.mod index db8f789..04129fc 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,10 @@ require ( ) require ( + github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/google/uuid v1.6.0 // indirect @@ -30,6 +33,7 @@ require ( github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect + github.com/redis/go-redis/v9 v9.17.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/go.sum b/go.sum index c8aa694..0254bc9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ +github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I= +github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= @@ -31,6 +37,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= +github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs= +github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pkg/cache/README.md b/pkg/cache/README.md new file mode 100644 index 0000000..643d364 --- /dev/null +++ b/pkg/cache/README.md @@ -0,0 +1,340 @@ +# Cache Package + +A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache. 
+ +## Features + +- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends +- **Pluggable Architecture**: Easy to add custom cache providers +- **Type-Safe API**: Automatic JSON serialization/deserialization +- **TTL Support**: Configurable time-to-live for cache entries +- **Context-Aware**: All operations support Go contexts +- **Statistics**: Built-in cache statistics and monitoring +- **Pattern Deletion**: Delete keys by pattern (Redis) +- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation + +## Installation + +```bash +go get github.com/bitechdev/ResolveSpec/pkg/cache +``` + +For Redis support: +```bash +go get github.com/redis/go-redis/v9 +``` + +For Memcache support: +```bash +go get github.com/bradfitz/gomemcache/memcache +``` + +## Quick Start + +### In-Memory Cache + +```go +package main + +import ( + "context" + "time" + "github.com/bitechdev/ResolveSpec/pkg/cache" +) + +func main() { + // Initialize with in-memory provider + cache.UseMemory(&cache.Options{ + DefaultTTL: 5 * time.Minute, + MaxSize: 10000, + }) + defer cache.Close() + + ctx := context.Background() + c := cache.GetDefaultCache() + + // Store a value + type User struct { + ID int + Name string + } + user := User{ID: 1, Name: "John"} + c.Set(ctx, "user:1", user, 10*time.Minute) + + // Retrieve a value + var retrieved User + c.Get(ctx, "user:1", &retrieved) +} +``` + +### Redis Cache + +```go +cache.UseRedis(&cache.RedisConfig{ + Host: "localhost", + Port: 6379, + Password: "", + DB: 0, + Options: &cache.Options{ + DefaultTTL: 5 * time.Minute, + }, +}) +defer cache.Close() +``` + +### Memcache + +```go +cache.UseMemcache(&cache.MemcacheConfig{ + Servers: []string{"localhost:11211"}, + Options: &cache.Options{ + DefaultTTL: 5 * time.Minute, + }, +}) +defer cache.Close() +``` + +## API Reference + +### Core Methods + +#### Set +```go +Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error +``` +Stores a value in the cache with automatic JSON serialization. + +#### Get +```go +Get(ctx context.Context, key string, dest interface{}) error +``` +Retrieves and deserializes a value from the cache. + +#### SetBytes / GetBytes +```go +SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error +GetBytes(ctx context.Context, key string) ([]byte, error) +``` +Store and retrieve raw bytes without serialization. + +#### Delete +```go +Delete(ctx context.Context, key string) error +``` +Removes a key from the cache. + +#### DeleteByPattern +```go +DeleteByPattern(ctx context.Context, pattern string) error +``` +Removes all keys matching a pattern (Redis only). + +#### Clear +```go +Clear(ctx context.Context) error +``` +Removes all items from the cache. + +#### Exists +```go +Exists(ctx context.Context, key string) bool +``` +Checks if a key exists in the cache. + +#### GetOrSet +```go +GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, + loader func() (interface{}, error)) error +``` +Retrieves a value from cache, or loads and caches it if not found (lazy loading). + +#### Stats +```go +Stats(ctx context.Context) (*CacheStats, error) +``` +Returns cache statistics including hits, misses, and key counts. 
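+A minimal monitoring sketch built on `Stats` (hedged: it assumes the standard library `log` package is imported and uses `cache.GetStats`, the package-level helper defined in `cache.go`; the field names follow the `CacheStats` struct in `provider.go`):
+
+```go
+// Report the hit ratio of the default cache. A sketch, not part of the package.
+func logCacheStats(ctx context.Context) {
+	stats, err := cache.GetStats(ctx)
+	if err != nil {
+		log.Printf("cache stats unavailable: %v", err)
+		return
+	}
+	total := stats.Hits + stats.Misses
+	ratio := 0.0
+	if total > 0 {
+		ratio = float64(stats.Hits) / float64(total)
+	}
+	log.Printf("cache[%s]: keys=%d hits=%d misses=%d hit-ratio=%.2f",
+		stats.ProviderType, stats.Keys, stats.Hits, stats.Misses, ratio)
+}
+```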
+ +## Provider Configuration + +### In-Memory Options + +```go +&cache.Options{ + DefaultTTL: 5 * time.Minute, // Default expiration time + MaxSize: 10000, // Maximum number of items + EvictionPolicy: "LRU", // Eviction strategy (future) +} +``` + +### Redis Configuration + +```go +&cache.RedisConfig{ + Host: "localhost", + Port: 6379, + Password: "", // Optional authentication + DB: 0, // Database number + PoolSize: 10, // Connection pool size + Options: &cache.Options{ + DefaultTTL: 5 * time.Minute, + }, +} +``` + +### Memcache Configuration + +```go +&cache.MemcacheConfig{ + Servers: []string{"localhost:11211"}, + MaxIdleConns: 2, + Timeout: 1 * time.Second, + Options: &cache.Options{ + DefaultTTL: 5 * time.Minute, + }, +} +``` + +## Advanced Usage + +### Custom Provider + +```go +// Create a custom provider instance +memProvider := cache.NewMemoryProvider(&cache.Options{ + DefaultTTL: 10 * time.Minute, + MaxSize: 500, +}) + +// Initialize with custom provider +cache.Initialize(memProvider) +``` + +### Lazy Loading Pattern + +```go +var data ExpensiveData +err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) { + // This expensive operation only runs if key is not in cache + return computeExpensiveData(), nil +}) +``` + +### Query API Cache + +The package includes specialized functions for caching query results: + +```go +// Cache a query result +api := "GetUsers" +query := "SELECT * FROM users WHERE active = true" +tablenames := "users" +total := int64(150) + +cache.PutQueryAPICache(ctx, api, query, tablenames, total) + +// Retrieve cached query +hash := cache.HashQueryAPICache(api, query) +cachedQuery, err := cache.FetchQueryAPICache(ctx, hash) +``` + +## Provider Comparison + +| Feature | In-Memory | Redis | Memcache | +|---------|-----------|-------|----------| +| Persistence | No | Yes | No | +| Distributed | No | Yes | Yes | +| Pattern Delete | No | Yes | No | +| Statistics | Full | Full | Limited | +| Atomic Operations | Yes | Yes | Yes | +| Max Item Size | Memory | 512MB | 1MB | + +## Best Practices + +1. **Use contexts**: Always pass context for cancellation and timeout control +2. **Set appropriate TTLs**: Balance between freshness and performance +3. **Handle errors**: Cache misses and errors should be handled gracefully +4. **Monitor statistics**: Use Stats() to monitor cache performance +5. **Clean up**: Always call Close() when shutting down +6. 
**Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field") + +## Example: Complete Application + +```go +package main + +import ( + "context" + "log" + "time" + "github.com/bitechdev/ResolveSpec/pkg/cache" +) + +type UserService struct { + cache *cache.Cache +} + +func NewUserService() *UserService { + // Initialize with Redis in production, memory for testing + cache.UseRedis(&cache.RedisConfig{ + Host: "localhost", + Port: 6379, + Options: &cache.Options{ + DefaultTTL: 10 * time.Minute, + }, + }) + + return &UserService{ + cache: cache.GetDefaultCache(), + } +} + +func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) { + var user User + cacheKey := fmt.Sprintf("user:%d", userID) + + // Try to get from cache first + err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) { + // Load from database if not in cache + return s.loadUserFromDB(userID) + }) + + if err != nil { + return nil, err + } + + return &user, nil +} + +func (s *UserService) InvalidateUser(ctx context.Context, userID int) error { + cacheKey := fmt.Sprintf("user:%d", userID) + return s.cache.Delete(ctx, cacheKey) +} + +func main() { + service := NewUserService() + defer cache.Close() + + ctx := context.Background() + user, err := service.GetUser(ctx, 123) + if err != nil { + log.Fatal(err) + } + + log.Printf("User: %+v", user) +} +``` + +## Performance Considerations + +- **In-Memory**: Fastest but limited by RAM and not distributed +- **Redis**: Great for distributed systems, persistent, but network overhead +- **Memcache**: Good for distributed caching, simpler than Redis but less features + +Choose based on your needs: +- Single instance? Use in-memory +- Need persistence or advanced features? Use Redis +- Simple distributed cache? Use Memcache + +## License + +See repository license. diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 0000000..10fc8cd --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,76 @@ +package cache + +import ( + "context" + "fmt" + "time" +) + +var ( + defaultCache *Cache +) + +// Initialize initializes the cache with a provider. +// If not called, the package will use an in-memory provider by default. +func Initialize(provider Provider) { + defaultCache = NewCache(provider) +} + +// UseMemory configures the cache to use in-memory storage. +func UseMemory(opts *Options) error { + provider := NewMemoryProvider(opts) + defaultCache = NewCache(provider) + return nil +} + +// UseRedis configures the cache to use Redis storage. +func UseRedis(config *RedisConfig) error { + provider, err := NewRedisProvider(config) + if err != nil { + return fmt.Errorf("failed to initialize Redis provider: %w", err) + } + defaultCache = NewCache(provider) + return nil +} + +// UseMemcache configures the cache to use Memcache storage. +func UseMemcache(config *MemcacheConfig) error { + provider, err := NewMemcacheProvider(config) + if err != nil { + return fmt.Errorf("failed to initialize Memcache provider: %w", err) + } + defaultCache = NewCache(provider) + return nil +} + +// GetDefaultCache returns the default cache instance. +// Initializes with in-memory provider if not already initialized. +func GetDefaultCache() *Cache { + if defaultCache == nil { + UseMemory(&Options{ + DefaultTTL: 5 * time.Minute, + MaxSize: 10000, + }) + } + return defaultCache +} + +// SetDefaultCache sets a custom cache instance as the default cache. 
+// This is useful for testing or when you want to use a pre-configured cache instance. +func SetDefaultCache(cache *Cache) { + defaultCache = cache +} + +// GetStats returns cache statistics. +func GetStats(ctx context.Context) (*CacheStats, error) { + cache := GetDefaultCache() + return cache.Stats(ctx) +} + +// Close closes the cache and releases resources. +func Close() error { + if defaultCache != nil { + return defaultCache.Close() + } + return nil +} diff --git a/pkg/cache/cache_manager.go b/pkg/cache/cache_manager.go new file mode 100644 index 0000000..e1ca28d --- /dev/null +++ b/pkg/cache/cache_manager.go @@ -0,0 +1,147 @@ +package cache + +import ( + "context" + "encoding/json" + "fmt" + "time" +) + +// Cache is the main cache manager that wraps a Provider. +type Cache struct { + provider Provider +} + +// NewCache creates a new cache manager with the specified provider. +func NewCache(provider Provider) *Cache { + return &Cache{ + provider: provider, + } +} + +// Get retrieves and deserializes a value from the cache. +func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error { + data, exists := c.provider.Get(ctx, key) + if !exists { + return fmt.Errorf("key not found: %s", key) + } + + if err := json.Unmarshal(data, dest); err != nil { + return fmt.Errorf("failed to deserialize: %w", err) + } + + return nil +} + +// GetBytes retrieves raw bytes from the cache. +func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) { + data, exists := c.provider.Get(ctx, key) + if !exists { + return nil, fmt.Errorf("key not found: %s", key) + } + return data, nil +} + +// Set serializes and stores a value in the cache with the specified TTL. +func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error { + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to serialize: %w", err) + } + + return c.provider.Set(ctx, key, data, ttl) +} + +// SetBytes stores raw bytes in the cache with the specified TTL. +func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return c.provider.Set(ctx, key, value, ttl) +} + +// Delete removes a key from the cache. +func (c *Cache) Delete(ctx context.Context, key string) error { + return c.provider.Delete(ctx, key) +} + +// DeleteByPattern removes all keys matching the pattern. +func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error { + return c.provider.DeleteByPattern(ctx, pattern) +} + +// Clear removes all items from the cache. +func (c *Cache) Clear(ctx context.Context) error { + return c.provider.Clear(ctx) +} + +// Exists checks if a key exists in the cache. +func (c *Cache) Exists(ctx context.Context, key string) bool { + return c.provider.Exists(ctx, key) +} + +// Stats returns statistics about the cache. +func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) { + return c.provider.Stats(ctx) +} + +// Close closes the cache and releases any resources. +func (c *Cache) Close() error { + return c.provider.Close() +} + +// GetOrSet retrieves a value from cache, or sets it if it doesn't exist. +// The loader function is called only if the key is not found in cache. 
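+// Note: the lookup and store are not atomic, so concurrent callers that miss may each invoke the loader once (cache stampede).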
+func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error { + // Try to get from cache first + err := c.Get(ctx, key, dest) + if err == nil { + return nil + } + + // Load the value + value, err := loader() + if err != nil { + return fmt.Errorf("loader failed: %w", err) + } + + // Store in cache + if err := c.Set(ctx, key, value, ttl); err != nil { + return fmt.Errorf("failed to cache value: %w", err) + } + + // Populate dest with the loaded value + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to serialize loaded value: %w", err) + } + + if err := json.Unmarshal(data, dest); err != nil { + return fmt.Errorf("failed to deserialize loaded value: %w", err) + } + + return nil +} + +// Remember is a convenience function that caches the result of a function call. +// It's similar to GetOrSet but returns the value directly. +func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) { + // Try to get from cache first as bytes + data, err := c.GetBytes(ctx, key) + if err == nil { + var result interface{} + if err := json.Unmarshal(data, &result); err == nil { + return result, nil + } + } + + // Load the value + value, err := loader() + if err != nil { + return nil, fmt.Errorf("loader failed: %w", err) + } + + // Store in cache + if err := c.Set(ctx, key, value, ttl); err != nil { + return nil, fmt.Errorf("failed to cache value: %w", err) + } + + return value, nil +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 0000000..224ca7c --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,69 @@ +package cache + +import ( + "context" + "testing" + "time" +) + +func TestSetDefaultCache(t *testing.T) { + // Create a custom cache instance + provider := NewMemoryProvider(&Options{ + DefaultTTL: 1 * time.Minute, + MaxSize: 50, + }) + customCache := NewCache(provider) + + // Set it as the default + SetDefaultCache(customCache) + + // Verify it's now the default + retrievedCache := GetDefaultCache() + if retrievedCache != customCache { + t.Error("SetDefaultCache did not set the cache correctly") + } + + // Test that we can use it + ctx := context.Background() + testKey := "test_key" + testValue := "test_value" + + err := retrievedCache.Set(ctx, testKey, testValue, time.Minute) + if err != nil { + t.Fatalf("Failed to set value: %v", err) + } + + var result string + err = retrievedCache.Get(ctx, testKey, &result) + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + + if result != testValue { + t.Errorf("Expected %s, got %s", testValue, result) + } + + // Clean up - reset to default + SetDefaultCache(nil) +} + +func TestGetDefaultCacheInitialization(t *testing.T) { + // Reset to nil first + SetDefaultCache(nil) + + // GetDefaultCache should auto-initialize + cache := GetDefaultCache() + if cache == nil { + t.Error("GetDefaultCache should auto-initialize, got nil") + } + + // Should be usable + ctx := context.Background() + err := cache.Set(ctx, "test", "value", time.Minute) + if err != nil { + t.Errorf("Failed to use auto-initialized cache: %v", err) + } + + // Clean up + SetDefaultCache(nil) +} diff --git a/pkg/cache/example_usage.go b/pkg/cache/example_usage.go new file mode 100644 index 0000000..8aeac41 --- /dev/null +++ b/pkg/cache/example_usage.go @@ -0,0 +1,252 @@ +package cache + +import ( + "context" + "fmt" + "log" + "time" +) + +// ExampleInMemoryCache 
demonstrates using the in-memory cache provider. +func ExampleInMemoryCache() { + // Initialize with in-memory provider + err := UseMemory(&Options{ + DefaultTTL: 5 * time.Minute, + MaxSize: 1000, + }) + if err != nil { + log.Fatal(err) + } + defer Close() + + ctx := context.Background() + + // Get the cache instance + cache := GetDefaultCache() + + // Store a value + type User struct { + ID int + Name string + } + + user := User{ID: 1, Name: "John Doe"} + err = cache.Set(ctx, "user:1", user, 10*time.Minute) + if err != nil { + log.Fatal(err) + } + + // Retrieve a value + var retrieved User + err = cache.Get(ctx, "user:1", &retrieved) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Retrieved user: %+v\n", retrieved) + + // Check if key exists + exists := cache.Exists(ctx, "user:1") + fmt.Printf("Key exists: %v\n", exists) + + // Delete a key + err = cache.Delete(ctx, "user:1") + if err != nil { + log.Fatal(err) + } + + // Get statistics + stats, err := cache.Stats(ctx) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Cache stats: %+v\n", stats) +} + +// ExampleRedisCache demonstrates using the Redis cache provider. +func ExampleRedisCache() { + // Initialize with Redis provider + err := UseRedis(&RedisConfig{ + Host: "localhost", + Port: 6379, + Password: "", // Set if Redis requires authentication + DB: 0, + Options: &Options{ + DefaultTTL: 5 * time.Minute, + }, + }) + if err != nil { + log.Fatal(err) + } + defer Close() + + ctx := context.Background() + + // Get the cache instance + cache := GetDefaultCache() + + // Store raw bytes + data := []byte("Hello, Redis!") + err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour) + if err != nil { + log.Fatal(err) + } + + // Retrieve raw bytes + retrieved, err := cache.GetBytes(ctx, "greeting") + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Retrieved data: %s\n", string(retrieved)) + + // Clear all cache + err = cache.Clear(ctx) + if err != nil { + log.Fatal(err) + } +} + +// ExampleMemcacheCache demonstrates using the Memcache cache provider. +func ExampleMemcacheCache() { + // Initialize with Memcache provider + err := UseMemcache(&MemcacheConfig{ + Servers: []string{"localhost:11211"}, + Options: &Options{ + DefaultTTL: 5 * time.Minute, + }, + }) + if err != nil { + log.Fatal(err) + } + defer Close() + + ctx := context.Background() + + // Get the cache instance + cache := GetDefaultCache() + + // Store a value + type Product struct { + ID int + Name string + Price float64 + } + + product := Product{ID: 100, Name: "Widget", Price: 29.99} + err = cache.Set(ctx, "product:100", product, 30*time.Minute) + if err != nil { + log.Fatal(err) + } + + // Retrieve a value + var retrieved Product + err = cache.Get(ctx, "product:100", &retrieved) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Retrieved product: %+v\n", retrieved) +} + +// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading. 
+func ExampleGetOrSet() { + err := UseMemory(&Options{ + DefaultTTL: 5 * time.Minute, + MaxSize: 1000, + }) + if err != nil { + log.Fatal(err) + } + defer Close() + + ctx := context.Background() + cache := GetDefaultCache() + + type ExpensiveData struct { + Result string + } + + var data ExpensiveData + err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) { + // This expensive operation only runs if the key is not in cache + fmt.Println("Computing expensive result...") + time.Sleep(1 * time.Second) + return ExpensiveData{Result: "computed value"}, nil + }) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Data: %+v\n", data) + + // Second call will use cached value + err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) { + fmt.Println("This won't be called!") + return ExpensiveData{Result: "new value"}, nil + }) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Cached data: %+v\n", data) +} + +// ExampleCustomProvider demonstrates using a custom provider. +func ExampleCustomProvider() { + // Create a custom provider + memProvider := NewMemoryProvider(&Options{ + DefaultTTL: 10 * time.Minute, + MaxSize: 500, + }) + + // Initialize with custom provider + Initialize(memProvider) + defer Close() + + ctx := context.Background() + cache := GetDefaultCache() + + // Use the cache + err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute) + if err != nil { + log.Fatal(err) + } + + // Clean expired items (memory provider specific) + if mp, ok := cache.provider.(*MemoryProvider); ok { + count := mp.CleanExpired(ctx) + fmt.Printf("Cleaned %d expired items\n", count) + } +} + +// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only). +func ExampleDeleteByPattern() { + err := UseRedis(&RedisConfig{ + Host: "localhost", + Port: 6379, + Options: &Options{ + DefaultTTL: 5 * time.Minute, + }, + }) + if err != nil { + log.Fatal(err) + } + defer Close() + + ctx := context.Background() + cache := GetDefaultCache() + + // Store multiple keys with a pattern + cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute) + cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute) + cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute) + + // Delete all keys matching pattern (Redis glob pattern) + err = cache.DeleteByPattern(ctx, "user:*:profile") + if err != nil { + log.Fatal(err) + } + + fmt.Println("Deleted all user profile keys") +} diff --git a/pkg/cache/provider.go b/pkg/cache/provider.go new file mode 100644 index 0000000..b9ae132 --- /dev/null +++ b/pkg/cache/provider.go @@ -0,0 +1,57 @@ +package cache + +import ( + "context" + "time" +) + +// Provider defines the interface that all cache providers must implement. +type Provider interface { + // Get retrieves a value from the cache by key. + // Returns nil, false if key doesn't exist or is expired. + Get(ctx context.Context, key string) ([]byte, bool) + + // Set stores a value in the cache with the specified TTL. + // If ttl is 0, the item never expires. + Set(ctx context.Context, key string, value []byte, ttl time.Duration) error + + // Delete removes a key from the cache. + Delete(ctx context.Context, key string) error + + // DeleteByPattern removes all keys matching the pattern. + // Pattern syntax depends on the provider implementation. + DeleteByPattern(ctx context.Context, pattern string) error + + // Clear removes all items from the cache. 
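+	// For shared backends this can clear the whole store (the Redis provider issues FLUSHDB).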
+ Clear(ctx context.Context) error + + // Exists checks if a key exists in the cache. + Exists(ctx context.Context, key string) bool + + // Close closes the provider and releases any resources. + Close() error + + // Stats returns statistics about the cache provider. + Stats(ctx context.Context) (*CacheStats, error) +} + +// CacheStats contains cache statistics. +type CacheStats struct { + Hits int64 `json:"hits"` + Misses int64 `json:"misses"` + Keys int64 `json:"keys"` + ProviderType string `json:"provider_type"` + ProviderStats map[string]any `json:"provider_stats,omitempty"` +} + +// Options contains configuration options for cache providers. +type Options struct { + // DefaultTTL is the default time-to-live for cache items. + DefaultTTL time.Duration + + // MaxSize is the maximum number of items (for in-memory provider). + MaxSize int + + // EvictionPolicy determines how items are evicted (LRU, LFU, etc). + EvictionPolicy string +} diff --git a/pkg/cache/provider_memcache.go b/pkg/cache/provider_memcache.go new file mode 100644 index 0000000..f764d72 --- /dev/null +++ b/pkg/cache/provider_memcache.go @@ -0,0 +1,144 @@ +package cache + +import ( + "context" + "fmt" + "time" + + "github.com/bradfitz/gomemcache/memcache" +) + +// MemcacheProvider is a Memcache implementation of the Provider interface. +type MemcacheProvider struct { + client *memcache.Client + options *Options +} + +// MemcacheConfig contains Memcache-specific configuration. +type MemcacheConfig struct { + // Servers is a list of memcache server addresses (e.g., "localhost:11211") + Servers []string + + // MaxIdleConns is the maximum number of idle connections (default: 2) + MaxIdleConns int + + // Timeout for connection operations (default: 1 second) + Timeout time.Duration + + // Options contains general cache options + Options *Options +} + +// NewMemcacheProvider creates a new Memcache cache provider. +func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) { + if config == nil { + config = &MemcacheConfig{ + Servers: []string{"localhost:11211"}, + } + } + + if len(config.Servers) == 0 { + config.Servers = []string{"localhost:11211"} + } + + if config.MaxIdleConns == 0 { + config.MaxIdleConns = 2 + } + + if config.Timeout == 0 { + config.Timeout = 1 * time.Second + } + + if config.Options == nil { + config.Options = &Options{ + DefaultTTL: 5 * time.Minute, + } + } + + client := memcache.New(config.Servers...) + client.MaxIdleConns = config.MaxIdleConns + client.Timeout = config.Timeout + + // Test connection + if err := client.Ping(); err != nil { + return nil, fmt.Errorf("failed to connect to Memcache: %w", err) + } + + return &MemcacheProvider{ + client: client, + options: config.Options, + }, nil +} + +// Get retrieves a value from the cache by key. +func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) { + item, err := m.client.Get(key) + if err == memcache.ErrCacheMiss { + return nil, false + } + if err != nil { + return nil, false + } + return item.Value, true +} + +// Set stores a value in the cache with the specified TTL. +func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + if ttl == 0 { + ttl = m.options.DefaultTTL + } + + item := &memcache.Item{ + Key: key, + Value: value, + Expiration: int32(ttl.Seconds()), + } + + return m.client.Set(item) +} + +// Delete removes a key from the cache. 
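+// Deleting a missing key is treated as success: memcache.ErrCacheMiss is swallowed and nil is returned.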
+func (m *MemcacheProvider) Delete(ctx context.Context, key string) error { + err := m.client.Delete(key) + if err == memcache.ErrCacheMiss { + return nil + } + return err +} + +// DeleteByPattern removes all keys matching the pattern. +// Note: Memcache does not support pattern-based deletion natively. +// This is a no-op for memcache and returns an error. +func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error { + return fmt.Errorf("pattern-based deletion is not supported by Memcache") +} + +// Clear removes all items from the cache. +func (m *MemcacheProvider) Clear(ctx context.Context) error { + return m.client.FlushAll() +} + +// Exists checks if a key exists in the cache. +func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool { + _, err := m.client.Get(key) + return err == nil +} + +// Close closes the provider and releases any resources. +func (m *MemcacheProvider) Close() error { + // Memcache client doesn't have a close method + return nil +} + +// Stats returns statistics about the cache provider. +// Note: Memcache provider returns limited statistics. +func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) { + stats := &CacheStats{ + ProviderType: "memcache", + ProviderStats: map[string]any{ + "note": "Memcache does not provide detailed statistics through the standard client", + }, + } + + return stats, nil +} diff --git a/pkg/cache/provider_memory.go b/pkg/cache/provider_memory.go new file mode 100644 index 0000000..ae039b0 --- /dev/null +++ b/pkg/cache/provider_memory.go @@ -0,0 +1,226 @@ +package cache + +import ( + "context" + "fmt" + "regexp" + "sync" + "time" +) + +// memoryItem represents a cached item in memory. +type memoryItem struct { + Value []byte + Expiration time.Time + LastAccess time.Time + HitCount int64 +} + +// isExpired checks if the item has expired. +func (m *memoryItem) isExpired() bool { + if m.Expiration.IsZero() { + return false + } + return time.Now().After(m.Expiration) +} + +// MemoryProvider is an in-memory implementation of the Provider interface. +type MemoryProvider struct { + mu sync.RWMutex + items map[string]*memoryItem + options *Options + hits int64 + misses int64 +} + +// NewMemoryProvider creates a new in-memory cache provider. +func NewMemoryProvider(opts *Options) *MemoryProvider { + if opts == nil { + opts = &Options{ + DefaultTTL: 5 * time.Minute, + MaxSize: 10000, + } + } + + return &MemoryProvider{ + items: make(map[string]*memoryItem), + options: opts, + } +} + +// Get retrieves a value from the cache by key. +func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + item, exists := m.items[key] + if !exists { + m.misses++ + return nil, false + } + + if item.isExpired() { + delete(m.items, key) + m.misses++ + return nil, false + } + + item.LastAccess = time.Now() + item.HitCount++ + m.hits++ + + return item.Value, true +} + +// Set stores a value in the cache with the specified TTL. 
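+// A ttl of 0 falls back to DefaultTTL; if the effective ttl is still not positive, the item never expires.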
+func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + m.mu.Lock() + defer m.mu.Unlock() + + if ttl == 0 { + ttl = m.options.DefaultTTL + } + + var expiration time.Time + if ttl > 0 { + expiration = time.Now().Add(ttl) + } + + // Check max size and evict if necessary + if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize { + if _, exists := m.items[key]; !exists { + m.evictOne() + } + } + + m.items[key] = &memoryItem{ + Value: value, + Expiration: expiration, + LastAccess: time.Now(), + } + + return nil +} + +// Delete removes a key from the cache. +func (m *MemoryProvider) Delete(ctx context.Context, key string) error { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.items, key) + return nil +} + +// DeleteByPattern removes all keys matching the pattern. +func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error { + m.mu.Lock() + defer m.mu.Unlock() + + re, err := regexp.Compile(pattern) + if err != nil { + return fmt.Errorf("invalid pattern: %w", err) + } + + for key := range m.items { + if re.MatchString(key) { + delete(m.items, key) + } + } + + return nil +} + +// Clear removes all items from the cache. +func (m *MemoryProvider) Clear(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.items = make(map[string]*memoryItem) + m.hits = 0 + m.misses = 0 + return nil +} + +// Exists checks if a key exists in the cache. +func (m *MemoryProvider) Exists(ctx context.Context, key string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + item, exists := m.items[key] + if !exists { + return false + } + + return !item.isExpired() +} + +// Close closes the provider and releases any resources. +func (m *MemoryProvider) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + m.items = nil + return nil +} + +// Stats returns statistics about the cache provider. +func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + // Clean expired items first + validKeys := 0 + for _, item := range m.items { + if !item.isExpired() { + validKeys++ + } + } + + return &CacheStats{ + Hits: m.hits, + Misses: m.misses, + Keys: int64(validKeys), + ProviderType: "memory", + ProviderStats: map[string]any{ + "capacity": m.options.MaxSize, + }, + }, nil +} + +// evictOne removes one item from the cache using LRU strategy. +func (m *MemoryProvider) evictOne() { + var oldestKey string + var oldestTime time.Time + + for key, item := range m.items { + if item.isExpired() { + delete(m.items, key) + return + } + + if oldestKey == "" || item.LastAccess.Before(oldestTime) { + oldestKey = key + oldestTime = item.LastAccess + } + } + + if oldestKey != "" { + delete(m.items, oldestKey) + } +} + +// CleanExpired removes all expired items from the cache. +func (m *MemoryProvider) CleanExpired(ctx context.Context) int { + m.mu.Lock() + defer m.mu.Unlock() + + count := 0 + for key, item := range m.items { + if item.isExpired() { + delete(m.items, key) + count++ + } + } + + return count +} diff --git a/pkg/cache/provider_redis.go b/pkg/cache/provider_redis.go new file mode 100644 index 0000000..4e04711 --- /dev/null +++ b/pkg/cache/provider_redis.go @@ -0,0 +1,185 @@ +package cache + +import ( + "context" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +// RedisProvider is a Redis implementation of the Provider interface. +type RedisProvider struct { + client *redis.Client + options *Options +} + +// RedisConfig contains Redis-specific configuration. 
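+// The provider dials a single node at Host:Port; TLS, cluster, and sentinel modes are not configured here.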
+type RedisConfig struct { + // Host is the Redis server host (default: localhost) + Host string + + // Port is the Redis server port (default: 6379) + Port int + + // Password for Redis authentication (optional) + Password string + + // DB is the Redis database number (default: 0) + DB int + + // PoolSize is the maximum number of connections (default: 10) + PoolSize int + + // Options contains general cache options + Options *Options +} + +// NewRedisProvider creates a new Redis cache provider. +func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) { + if config == nil { + config = &RedisConfig{ + Host: "localhost", + Port: 6379, + DB: 0, + } + } + + if config.Host == "" { + config.Host = "localhost" + } + if config.Port == 0 { + config.Port = 6379 + } + if config.PoolSize == 0 { + config.PoolSize = 10 + } + + if config.Options == nil { + config.Options = &Options{ + DefaultTTL: 5 * time.Minute, + } + } + + client := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", config.Host, config.Port), + Password: config.Password, + DB: config.DB, + PoolSize: config.PoolSize, + }) + + // Test connection + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("failed to connect to Redis: %w", err) + } + + return &RedisProvider{ + client: client, + options: config.Options, + }, nil +} + +// Get retrieves a value from the cache by key. +func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) { + val, err := r.client.Get(ctx, key).Bytes() + if err == redis.Nil { + return nil, false + } + if err != nil { + return nil, false + } + return val, true +} + +// Set stores a value in the cache with the specified TTL. +func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + if ttl == 0 { + ttl = r.options.DefaultTTL + } + + return r.client.Set(ctx, key, value, ttl).Err() +} + +// Delete removes a key from the cache. +func (r *RedisProvider) Delete(ctx context.Context, key string) error { + return r.client.Del(ctx, key).Err() +} + +// DeleteByPattern removes all keys matching the pattern. +func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error { + iter := r.client.Scan(ctx, 0, pattern, 0).Iterator() + pipe := r.client.Pipeline() + + count := 0 + for iter.Next(ctx) { + pipe.Del(ctx, iter.Val()) + count++ + + // Execute pipeline in batches of 100 + if count%100 == 0 { + if _, err := pipe.Exec(ctx); err != nil { + return err + } + pipe = r.client.Pipeline() + } + } + + if err := iter.Err(); err != nil { + return err + } + + // Execute remaining commands + if count%100 != 0 { + _, err := pipe.Exec(ctx) + return err + } + + return nil +} + +// Clear removes all items from the cache. +func (r *RedisProvider) Clear(ctx context.Context) error { + return r.client.FlushDB(ctx).Err() +} + +// Exists checks if a key exists in the cache. +func (r *RedisProvider) Exists(ctx context.Context, key string) bool { + result, err := r.client.Exists(ctx, key).Result() + if err != nil { + return false + } + return result > 0 +} + +// Close closes the provider and releases any resources. +func (r *RedisProvider) Close() error { + return r.client.Close() +} + +// Stats returns statistics about the cache provider. 
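+// Hits and Misses are not parsed from INFO; the raw INFO output is returned under ProviderStats["info"].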
+func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) { + info, err := r.client.Info(ctx, "stats", "keyspace").Result() + if err != nil { + return nil, fmt.Errorf("failed to get Redis stats: %w", err) + } + + dbSize, err := r.client.DBSize(ctx).Result() + if err != nil { + return nil, fmt.Errorf("failed to get DB size: %w", err) + } + + // Parse stats from INFO command + // This is a simplified version - you may want to parse more detailed stats + stats := &CacheStats{ + Keys: dbSize, + ProviderType: "redis", + ProviderStats: map[string]any{ + "info": info, + }, + } + + return stats, nil +} diff --git a/pkg/cache/query_cache.go b/pkg/cache/query_cache.go new file mode 100644 index 0000000..64593a1 --- /dev/null +++ b/pkg/cache/query_cache.go @@ -0,0 +1,127 @@ +package cache + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +// QueryCacheKey represents the components used to build a cache key for query total count +type QueryCacheKey struct { + TableName string `json:"table_name"` + Filters []common.FilterOption `json:"filters"` + Sort []common.SortOption `json:"sort"` + CustomSQLWhere string `json:"custom_sql_where,omitempty"` + CustomSQLOr string `json:"custom_sql_or,omitempty"` + Expand []ExpandOptionKey `json:"expand,omitempty"` + Distinct bool `json:"distinct,omitempty"` + CursorForward string `json:"cursor_forward,omitempty"` + CursorBackward string `json:"cursor_backward,omitempty"` +} + +// ExpandOptionKey represents expand options for cache key +type ExpandOptionKey struct { + Relation string `json:"relation"` + Where string `json:"where,omitempty"` +} + +// BuildQueryCacheKey builds a cache key from query parameters for total count caching +// This is used to cache the total count of records matching a query +func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string { + key := QueryCacheKey{ + TableName: tableName, + Filters: filters, + Sort: sort, + CustomSQLWhere: customWhere, + CustomSQLOr: customOr, + } + + // Serialize to JSON for consistent hashing + jsonData, err := json.Marshal(key) + if err != nil { + // Fallback to simple string concatenation if JSON fails + return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr)) + } + + return hashString(string(jsonData)) +} + +// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec) +// Includes expand, distinct, and cursor pagination options +func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, + customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string { + + key := QueryCacheKey{ + TableName: tableName, + Filters: filters, + Sort: sort, + CustomSQLWhere: customWhere, + CustomSQLOr: customOr, + Distinct: distinct, + CursorForward: cursorFwd, + CursorBackward: cursorBwd, + } + + // Convert expand options to cache key format + if len(expandOpts) > 0 { + key.Expand = make([]ExpandOptionKey, 0, len(expandOpts)) + for _, exp := range expandOpts { + // Type assert to get the expand option fields we care about for caching + if expMap, ok := exp.(map[string]interface{}); ok { + expKey := ExpandOptionKey{} + if rel, ok := expMap["relation"].(string); ok { + expKey.Relation = rel + } + if where, ok := expMap["where"].(string); ok { + expKey.Where = where + } + key.Expand = 
append(key.Expand, expKey) + } + } + // Sort expand options for consistent hashing (already sorted by relation name above) + } + + // Serialize to JSON for consistent hashing + jsonData, err := json.Marshal(key) + if err != nil { + // Fallback to simple string concatenation if JSON fails + return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s", + tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd)) + } + + return hashString(string(jsonData)) +} + +// hashString computes SHA256 hash of a string +func hashString(s string) string { + h := sha256.New() + h.Write([]byte(s)) + return hex.EncodeToString(h.Sum(nil)) +} + +// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count +func GetQueryTotalCacheKey(hash string) string { + return fmt.Sprintf("query_total:%s", hash) +} + +// CachedTotal represents a cached total count +type CachedTotal struct { + Total int `json:"total"` +} + +// InvalidateCacheForTable removes all cached totals for a specific table +// This should be called when data in the table changes (insert/update/delete) +func InvalidateCacheForTable(ctx context.Context, tableName string) error { + cache := GetDefaultCache() + + // Build a pattern to match all query totals for this table + // Note: This requires pattern matching support in the provider + pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName)) + + return cache.DeleteByPattern(ctx, pattern) +} diff --git a/pkg/cache/query_cache_test.go b/pkg/cache/query_cache_test.go new file mode 100644 index 0000000..0920983 --- /dev/null +++ b/pkg/cache/query_cache_test.go @@ -0,0 +1,151 @@ +package cache + +import ( + "context" + "testing" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +func TestBuildQueryCacheKey(t *testing.T) { + filters := []common.FilterOption{ + {Column: "name", Operator: "eq", Value: "test"}, + {Column: "age", Operator: "gt", Value: 25}, + } + sorts := []common.SortOption{ + {Column: "name", Direction: "asc"}, + } + + // Generate cache key + key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "") + + // Same parameters should generate same key + key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "") + + if key1 != key2 { + t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2) + } + + // Different parameters should generate different key + key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "") + + if key1 == key3 { + t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3) + } +} + +func TestBuildExtendedQueryCacheKey(t *testing.T) { + filters := []common.FilterOption{ + {Column: "name", Operator: "eq", Value: "test"}, + } + sorts := []common.SortOption{ + {Column: "name", Direction: "asc"}, + } + expandOpts := []interface{}{ + map[string]interface{}{ + "relation": "posts", + "where": "status = 'published'", + }, + } + + // Generate cache key + key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "") + + // Same parameters should generate same key + key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "") + + if key1 != key2 { + t.Errorf("Expected same cache keys for identical parameters") + } + + // Different distinct value should generate different key + key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "") + + if key1 == key3 { + t.Errorf("Expected 
different cache keys for different distinct values") + } +} + +func TestGetQueryTotalCacheKey(t *testing.T) { + hash := "abc123" + key := GetQueryTotalCacheKey(hash) + + expected := "query_total:abc123" + if key != expected { + t.Errorf("Expected %s, got %s", expected, key) + } +} + +func TestCachedTotalIntegration(t *testing.T) { + // Initialize cache with memory provider for testing + UseMemory(&Options{ + DefaultTTL: 1 * time.Minute, + MaxSize: 100, + }) + + ctx := context.Background() + + // Create test data + filters := []common.FilterOption{ + {Column: "status", Operator: "eq", Value: "active"}, + } + sorts := []common.SortOption{ + {Column: "created_at", Direction: "desc"}, + } + + // Build cache key + cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "") + cacheKey := GetQueryTotalCacheKey(cacheKeyHash) + + // Store a total count in cache + totalToCache := CachedTotal{Total: 42} + err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute) + if err != nil { + t.Fatalf("Failed to set cache: %v", err) + } + + // Retrieve from cache + var cachedTotal CachedTotal + err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal) + if err != nil { + t.Fatalf("Failed to get from cache: %v", err) + } + + if cachedTotal.Total != 42 { + t.Errorf("Expected total 42, got %d", cachedTotal.Total) + } + + // Test cache miss + nonExistentKey := GetQueryTotalCacheKey("nonexistent") + var missedTotal CachedTotal + err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal) + if err == nil { + t.Errorf("Expected error for cache miss, got nil") + } +} + +func TestHashString(t *testing.T) { + input1 := "test string" + input2 := "test string" + input3 := "different string" + + hash1 := hashString(input1) + hash2 := hashString(input2) + hash3 := hashString(input3) + + // Same input should produce same hash + if hash1 != hash2 { + t.Errorf("Expected same hash for identical inputs") + } + + // Different input should produce different hash + if hash1 == hash3 { + t.Errorf("Expected different hash for different inputs") + } + + // Hash should be hex encoded SHA256 (64 characters) + if len(hash1) != 64 { + t.Errorf("Expected hash length of 64, got %d", len(hash1)) + } +} diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index e56d8ac..ace729e 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -20,16 +20,24 @@ import ( // Handler handles function-based SQL API requests type Handler struct { - db common.Database + db common.Database + hooks *HookRegistry } // NewHandler creates a new function API handler func NewHandler(db common.Database) *Handler { return &Handler{ - db: db, + db: db, + hooks: NewHookRegistry(), } } +// Hooks returns the hook registry for this handler +// Use this to register custom hooks for operations +func (h *Handler) Hooks() *HookRegistry { + return h.hooks +} + // HTTPFuncType is a function type for HTTP handlers type HTTPFuncType func(http.ResponseWriter, *http.Request) @@ -64,18 +72,77 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil w.Header().Set("Content-Type", "application/json") + // Initialize hook context + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Request: r, + Writer: w, + SQLQuery: sqlquery, + Variables: variables, + InputVars: inputvars, + MetaInfo: metainfo, + PropQry: propQry, + UserContext: userCtx, + NoCount: pNoCount, + BlankParams: pBlankparms, + AllowFilter: pAllowFilter, + ComplexAPI: complexAPI, + } + + // Execute BeforeQueryList hook + if 
err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil { + logger.Error("BeforeQueryList hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Check if hook aborted the operation + if hookCtx.Abort { + if hookCtx.AbortCode == 0 { + hookCtx.AbortCode = http.StatusBadRequest + } + sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) + return + } + + // Use potentially modified SQL query and variables from hooks + sqlquery = hookCtx.SQLQuery + variables = hookCtx.Variables + complexAPI = hookCtx.ComplexAPI + // Extract input variables from SQL query (placeholders like [variable]) sqlquery = h.extractInputVariables(sqlquery, &inputvars) // Merge URL path parameters sqlquery = h.mergePathParams(r, sqlquery, variables) + // Parse comprehensive parameters from headers and query string + reqParams := h.ParseParameters(r) + complexAPI = reqParams.ComplexAPI + // Merge query string parameters sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry) // Merge header parameters sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) + // Apply filters from parsed parameters (if not already applied by pAllowFilter) + if !pAllowFilter { + sqlquery = h.ApplyFilters(sqlquery, reqParams) + } + + // Apply field selection + sqlquery = h.ApplyFieldSelection(sqlquery, reqParams) + + // Apply DISTINCT if requested + sqlquery = h.ApplyDistinct(sqlquery, reqParams) + + // Override pNoCount if skipcount is specified + if reqParams.SkipCount { + pNoCount = true + } + // Build metainfo metainfo["ipaddress"] = getIPAddress(r) metainfo["url"] = r.RequestURI @@ -95,12 +162,32 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil } } + // Update hook context with latest SQL query and variables + hookCtx.SQLQuery = sqlquery + hookCtx.Variables = variables + hookCtx.InputVars = inputvars + // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { sqlqueryCnt := sqlquery // Parse sorting and pagination parameters sortcols, limit, offset := h.parsePaginationParams(r) + + // Override with parsed parameters if available + if reqParams.SortColumns != "" { + sortcols = reqParams.SortColumns + } + if reqParams.Limit > 0 { + limit = reqParams.Limit + } + if reqParams.Offset > 0 { + offset = reqParams.Offset + } + + hookCtx.SortColumns = sortcols + hookCtx.Limit = limit + hookCtx.Offset = offset fromPos := strings.Index(strings.ToLower(sqlquery), "from ") orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by") @@ -127,6 +214,16 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil total = countResult.Count } + // Execute BeforeSQLExec hook + hookCtx.SQLQuery = sqlquery + if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil { + logger.Error("BeforeSQLExec hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified SQL query from hook + sqlquery = hookCtx.SQLQuery + // Execute main query rows := make([]map[string]interface{}, 0) if err := tx.Query(ctx, &rows, sqlquery); err != nil { @@ -140,6 +237,20 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil total = int64(len(dbobjlist)) } + // Execute AfterSQLExec hook + hookCtx.Result = dbobjlist + hookCtx.Total = total + if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil { + 
logger.Error("AfterSQLExec hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { + dbobjlist = modifiedResult + } + total = hookCtx.Total + return nil }) @@ -148,6 +259,21 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil return } + // Execute AfterQueryList hook + hookCtx.Result = dbobjlist + hookCtx.Total = total + hookCtx.Error = err + if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil { + logger.Error("AfterQueryList hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { + dbobjlist = modifiedResult + } + total = hookCtx.Total + // Set response headers respOffset := 0 if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { @@ -159,12 +285,44 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total)) logger.Info("Serving: Records %d of %d", len(dbobjlist), total) + // Execute BeforeResponse hook + hookCtx.Result = dbobjlist + hookCtx.Total = total + if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil { + logger.Error("BeforeResponse hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { + dbobjlist = modifiedResult + } + if len(dbobjlist) == 0 { w.Write([]byte("[]")) return } - if complexAPI { + // Format response based on response format + switch reqParams.ResponseFormat { + case "syncfusion": + // Syncfusion format: { result: data, count: total } + response := map[string]interface{}{ + "result": dbobjlist, + "count": total, + } + data, err := json.Marshal(response) + if err != nil { + sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) + } else { + if int64(len(dbobjlist)) < total { + w.WriteHeader(http.StatusPartialContent) + } + w.Write(data) + } + + case "detail": + // Detail format: complex API with metadata metaobj := map[string]interface{}{ "items": dbobjlist, "count": fmt.Sprintf("%d", len(dbobjlist)), @@ -172,7 +330,6 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil "tablename": r.URL.Path, "tableprefix": "gsql", } - data, err := json.Marshal(metaobj) if err != nil { sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) @@ -182,15 +339,36 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil } w.Write(data) } - } else { - data, err := json.Marshal(dbobjlist) - if err != nil { - sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) - } else { - if int64(len(dbobjlist)) < total { - w.WriteHeader(http.StatusPartialContent) + + default: + // Simple format: just return the data array (or complex API if requested) + if complexAPI { + metaobj := map[string]interface{}{ + "items": dbobjlist, + "count": fmt.Sprintf("%d", len(dbobjlist)), + "total": fmt.Sprintf("%d", total), + "tablename": r.URL.Path, + "tableprefix": "gsql", + } + data, err 
:= json.Marshal(metaobj) + if err != nil { + sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) + } else { + if int64(len(dbobjlist)) < total { + w.WriteHeader(http.StatusPartialContent) + } + w.Write(data) + } + } else { + data, err := json.Marshal(dbobjlist) + if err != nil { + sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) + } else { + if int64(len(dbobjlist)) < total { + w.WriteHeader(http.StatusPartialContent) + } + w.Write(data) } - w.Write(data) } } } @@ -215,6 +393,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { metainfo := make(map[string]interface{}) variables := make(map[string]interface{}) dbobj := make(map[string]interface{}) + complexAPI := false // Get user context from security package userCtx, ok := security.GetUserContext(ctx) @@ -225,18 +404,67 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { w.Header().Set("Content-Type", "application/json") + // Initialize hook context + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Request: r, + Writer: w, + SQLQuery: sqlquery, + Variables: variables, + InputVars: inputvars, + MetaInfo: metainfo, + PropQry: propQry, + UserContext: userCtx, + BlankParams: pBlankparms, + ComplexAPI: complexAPI, + } + + // Execute BeforeQuery hook + if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil { + logger.Error("BeforeQuery hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Check if hook aborted the operation + if hookCtx.Abort { + if hookCtx.AbortCode == 0 { + hookCtx.AbortCode = http.StatusBadRequest + } + sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) + return + } + + // Use potentially modified SQL query and variables from hooks + sqlquery = hookCtx.SQLQuery + variables = hookCtx.Variables + // Extract input variables from SQL query sqlquery = h.extractInputVariables(sqlquery, &inputvars) // Merge URL path parameters sqlquery = h.mergePathParams(r, sqlquery, variables) + // Parse comprehensive parameters from headers and query string + reqParams := h.ParseParameters(r) + complexAPI = reqParams.ComplexAPI + // Merge query string parameters sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry) // Merge header parameters - complexAPI := false sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) + hookCtx.ComplexAPI = complexAPI + + // Apply filters from parsed parameters + sqlquery = h.ApplyFilters(sqlquery, reqParams) + + // Apply field selection + sqlquery = h.ApplyFieldSelection(sqlquery, reqParams) + + // Apply DISTINCT if requested + sqlquery = h.ApplyDistinct(sqlquery, reqParams) // Build metainfo metainfo["ipaddress"] = getIPAddress(r) @@ -272,8 +500,22 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { } } + // Update hook context with latest SQL query and variables + hookCtx.SQLQuery = sqlquery + hookCtx.Variables = variables + hookCtx.InputVars = inputvars + // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Execute BeforeSQLExec hook + if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil { + logger.Error("BeforeSQLExec hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified SQL query from hook + sqlquery = hookCtx.SQLQuery + // 
Execute main query rows := make([]map[string]interface{}, 0) if err := tx.Query(ctx, &rows, sqlquery); err != nil { @@ -285,6 +527,18 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { dbobj = rows[0] } + // Execute AfterSQLExec hook + hookCtx.Result = dbobj + if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil { + logger.Error("AfterSQLExec hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { + dbobj = modifiedResult + } + return nil }) @@ -293,6 +547,31 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { return } + // Execute AfterQuery hook + hookCtx.Result = dbobj + hookCtx.Error = err + if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil { + logger.Error("AfterQuery hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { + dbobj = modifiedResult + } + + // Execute BeforeResponse hook + hookCtx.Result = dbobj + if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil { + logger.Error("BeforeResponse hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { + dbobj = modifiedResult + } + // Check if response should be root-level data if val, ok := dbobj["root_as_data"]; ok { data, err := json.Marshal(val) diff --git a/pkg/funcspec/function_api_test.go b/pkg/funcspec/function_api_test.go new file mode 100644 index 0000000..c34c803 --- /dev/null +++ b/pkg/funcspec/function_api_test.go @@ -0,0 +1,837 @@ +package funcspec + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// MockDatabase implements common.Database interface for testing +type MockDatabase struct { + QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error + ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error) + RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error +} + +func (m *MockDatabase) NewSelect() common.SelectQuery { + return nil +} + +func (m *MockDatabase) NewInsert() common.InsertQuery { + return nil +} + +func (m *MockDatabase) NewUpdate() common.UpdateQuery { + return nil +} + +func (m *MockDatabase) NewDelete() common.DeleteQuery { + return nil +} + +func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { + if m.ExecFunc != nil { + return m.ExecFunc(ctx, query, args...) + } + return &MockResult{rows: 0}, nil +} + +func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + if m.QueryFunc != nil { + return m.QueryFunc(ctx, dest, query, args...) 
+ } + return nil +} + +func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) { + return m, nil +} + +func (m *MockDatabase) CommitTx(ctx context.Context) error { + return nil +} + +func (m *MockDatabase) RollbackTx(ctx context.Context) error { + return nil +} + +func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { + if m.RunInTransactionFunc != nil { + return m.RunInTransactionFunc(ctx, fn) + } + return fn(m) +} + +// MockResult implements common.Result interface for testing +type MockResult struct { + rows int64 + id int64 +} + +func (m *MockResult) RowsAffected() int64 { + return m.rows +} + +func (m *MockResult) LastInsertId() (int64, error) { + return m.id, nil +} + +// Helper function to create a test request with user context +func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request { + u, _ := url.Parse(path) + if queryParams != nil { + q := u.Query() + for k, v := range queryParams { + q.Set(k, v) + } + u.RawQuery = q.Encode() + } + + var bodyReader *bytes.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } else { + bodyReader = bytes.NewReader([]byte{}) + } + + req := httptest.NewRequest(method, u.String(), bodyReader) + + if headers != nil { + for k, v := range headers { + req.Header.Set(k, v) + } + } + + // Add user context + userCtx := &security.UserContext{ + UserID: 1, + UserName: "testuser", + SessionID: "test-session-123", + } + ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx) + req = req.WithContext(ctx) + + return req +} + +// TestNewHandler tests handler creation +func TestNewHandler(t *testing.T) { + db := &MockDatabase{} + handler := NewHandler(db) + + if handler == nil { + t.Fatal("Expected handler to be created, got nil") + } + + if handler.db != db { + t.Error("Expected handler to have the provided database") + } + + if handler.hooks == nil { + t.Error("Expected handler to have a hook registry") + } +} + +// TestHandlerHooks tests the Hooks method +func TestHandlerHooks(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + hooks := handler.Hooks() + + if hooks == nil { + t.Fatal("Expected hooks registry to be non-nil") + } + + // Should return the same instance + hooks2 := handler.Hooks() + if hooks != hooks2 { + t.Error("Expected Hooks() to return the same registry instance") + } +} + +// TestExtractInputVariables tests the extractInputVariables function +func TestExtractInputVariables(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + sqlQuery string + expectedVars []string + }{ + { + name: "No variables", + sqlQuery: "SELECT * FROM users", + expectedVars: []string{}, + }, + { + name: "Single variable", + sqlQuery: "SELECT * FROM users WHERE id = [user_id]", + expectedVars: []string{"[user_id]"}, + }, + { + name: "Multiple variables", + sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]", + expectedVars: []string{"[user_id]", "[user_name]"}, + }, + { + name: "Nested brackets", + sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]", + expectedVars: []string{"[field]", "[user_id]"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputvars := make([]string, 0) + result := handler.extractInputVariables(tt.sqlQuery, &inputvars) + + if result != tt.sqlQuery { + t.Errorf("Expected SQL query to be unchanged, got %s", result) + } + + if len(inputvars) 
!= len(tt.expectedVars) { + t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars) + return + } + + for i, expected := range tt.expectedVars { + if inputvars[i] != expected { + t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i]) + } + } + }) + } +} + +// TestValidSQL tests the SQL sanitization function +func TestValidSQL(t *testing.T) { + tests := []struct { + name string + input string + mode string + expected string + }{ + { + name: "Column name with valid characters", + input: "user_id", + mode: "colname", + expected: "user_id", + }, + { + name: "Column name with dots (table.column)", + input: "users.user_id", + mode: "colname", + expected: "users.user_id", + }, + { + name: "Column name with SQL injection attempt", + input: "id'; DROP TABLE users--", + mode: "colname", + expected: "idDROPTABLEusers", + }, + { + name: "Column value with single quotes", + input: "O'Brien", + mode: "colvalue", + expected: "O''Brien", + }, + { + name: "Select with dangerous keywords", + input: "name, email; DROP TABLE users", + mode: "select", + expected: "name, email TABLE users", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ValidSQL(tt.input, tt.mode) + if result != tt.expected { + t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected) + } + }) + } +} + +// TestIsNumeric tests the IsNumeric function +func TestIsNumeric(t *testing.T) { + tests := []struct { + input string + expected bool + }{ + {"123", true}, + {"123.45", true}, + {"-123", true}, + {"-123.45", true}, + {"0", true}, + {"abc", false}, + {"12.34.56", false}, + {"", false}, + {"123abc", false}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := IsNumeric(tt.input) + if result != tt.expected { + t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected) + } + }) + } +} + +// TestSqlQryWhere tests the WHERE clause manipulation +func TestSqlQryWhere(t *testing.T) { + tests := []struct { + name string + sqlQuery string + condition string + expected string + }{ + { + name: "Add WHERE to query without WHERE", + sqlQuery: "SELECT * FROM users", + condition: "status = 'active'", + expected: "SELECT * FROM users WHERE status = 'active' ", + }, + { + name: "Add AND to query with existing WHERE", + sqlQuery: "SELECT * FROM users WHERE id > 0", + condition: "status = 'active'", + expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ", + }, + { + name: "Add WHERE before ORDER BY", + sqlQuery: "SELECT * FROM users ORDER BY name", + condition: "status = 'active'", + expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name", + }, + { + name: "Add WHERE before GROUP BY", + sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department", + condition: "status = 'active'", + expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department", + }, + { + name: "Add WHERE before LIMIT", + sqlQuery: "SELECT * FROM users LIMIT 10", + condition: "status = 'active'", + expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sqlQryWhere(tt.sqlQuery, tt.condition) + if result != tt.expected { + t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected) + } + }) + } +} + +// TestGetIPAddress tests IP address extraction +func TestGetIPAddress(t *testing.T) { + tests := []struct { + name string + setupReq func() *http.Request + expected string + 
}{ + { + name: "X-Forwarded-For header", + setupReq: func() *http.Request { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1") + return req + }, + expected: "192.168.1.100", + }, + { + name: "X-Real-IP header", + setupReq: func() *http.Request { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Real-IP", "192.168.1.200") + return req + }, + expected: "192.168.1.200", + }, + { + name: "RemoteAddr fallback", + setupReq: func() *http.Request { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + return req + }, + expected: "192.168.1.1:12345", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := tt.setupReq() + result := getIPAddress(req) + if result != tt.expected { + t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected) + } + }) + } +} + +// TestParsePaginationParams tests pagination parameter parsing +func TestParsePaginationParams(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + queryParams map[string]string + expectedSort string + expectedLimit int + expectedOffset int + }{ + { + name: "No parameters - defaults", + queryParams: map[string]string{}, + expectedSort: "", + expectedLimit: 20, + expectedOffset: 0, + }, + { + name: "All parameters provided", + queryParams: map[string]string{ + "sort": "name,-created_at", + "limit": "100", + "offset": "50", + }, + expectedSort: "name,-created_at", + expectedLimit: 100, + expectedOffset: 50, + }, + { + name: "Invalid limit - use default", + queryParams: map[string]string{ + "limit": "invalid", + }, + expectedSort: "", + expectedLimit: 20, + expectedOffset: 0, + }, + { + name: "Negative offset - use default", + queryParams: map[string]string{ + "offset": "-10", + }, + expectedSort: "", + expectedLimit: 20, + expectedOffset: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := createTestRequest("GET", "/test", tt.queryParams, nil, nil) + sort, limit, offset := handler.parsePaginationParams(req) + + if sort != tt.expectedSort { + t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort) + } + if limit != tt.expectedLimit { + t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit) + } + if offset != tt.expectedOffset { + t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset) + } + }) + } +} + +// TestSqlQuery tests the SqlQuery handler for single record queries +func TestSqlQuery(t *testing.T) { + tests := []struct { + name string + sqlQuery string + blankParams bool + queryParams map[string]string + headers map[string]string + setupDB func() *MockDatabase + expectedStatus int + validateResp func(t *testing.T, body []byte) + }{ + { + name: "Basic query - returns single record", + sqlQuery: "SELECT * FROM users WHERE id = 1", + blankParams: false, + setupDB: func() *MockDatabase { + return &MockDatabase{ + RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error { + db := &MockDatabase{ + QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + rows := dest.(*[]map[string]interface{}) + *rows = []map[string]interface{}{ + {"id": float64(1), "name": "Test User", "email": "test@example.com"}, + } + return nil + }, + } + return fn(db) + }, + } + }, + expectedStatus: 200, + validateResp: func(t *testing.T, body []byte) { + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + 
t.Fatalf("Failed to unmarshal response: %v", err) + } + if result["name"] != "Test User" { + t.Errorf("Expected name='Test User', got %v", result["name"]) + } + }, + }, + { + name: "Query with no results", + sqlQuery: "SELECT * FROM users WHERE id = 999", + blankParams: false, + setupDB: func() *MockDatabase { + return &MockDatabase{ + RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error { + db := &MockDatabase{ + QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + // Return empty array + return nil + }, + } + return fn(db) + }, + } + }, + expectedStatus: 200, + validateResp: func(t *testing.T, body []byte) { + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + if len(result) != 0 { + t.Errorf("Expected empty result, got %v", result) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := tt.setupDB() + handler := NewHandler(db) + + req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil) + w := httptest.NewRecorder() + + handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams) + handlerFunc(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + + if tt.validateResp != nil { + tt.validateResp(t, w.Body.Bytes()) + } + }) + } +} + +// TestSqlQueryList tests the SqlQueryList handler for list queries +func TestSqlQueryList(t *testing.T) { + tests := []struct { + name string + sqlQuery string + noCount bool + blankParams bool + allowFilter bool + queryParams map[string]string + headers map[string]string + setupDB func() *MockDatabase + expectedStatus int + validateResp func(t *testing.T, w *httptest.ResponseRecorder) + }{ + { + name: "Basic list query", + sqlQuery: "SELECT * FROM users", + noCount: false, + blankParams: false, + allowFilter: false, + setupDB: func() *MockDatabase { + return &MockDatabase{ + RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error { + callCount := 0 + db := &MockDatabase{ + QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + callCount++ + if strings.Contains(query, "COUNT") { + // Count query + countResult := dest.(*struct{ Count int64 }) + countResult.Count = 2 + } else { + // Main query + rows := dest.(*[]map[string]interface{}) + *rows = []map[string]interface{}{ + {"id": float64(1), "name": "User 1"}, + {"id": float64(2), "name": "User 2"}, + } + } + return nil + }, + } + return fn(db) + }, + } + }, + expectedStatus: 200, + validateResp: func(t *testing.T, w *httptest.ResponseRecorder) { + var result []map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + if len(result) != 2 { + t.Errorf("Expected 2 results, got %d", len(result)) + } + + // Check Content-Range header + contentRange := w.Header().Get("Content-Range") + if !strings.Contains(contentRange, "2") { + t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange) + } + }, + }, + { + name: "List query with noCount", + sqlQuery: "SELECT * FROM users", + noCount: true, + blankParams: false, + allowFilter: false, + setupDB: func() *MockDatabase { + return &MockDatabase{ + RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error { + db := &MockDatabase{ + QueryFunc: func(ctx context.Context, 
dest interface{}, query string, args ...interface{}) error { + if strings.Contains(query, "COUNT") { + t.Error("Count query should not be executed when noCount is true") + } + rows := dest.(*[]map[string]interface{}) + *rows = []map[string]interface{}{ + {"id": float64(1), "name": "User 1"}, + } + return nil + }, + } + return fn(db) + }, + } + }, + expectedStatus: 200, + validateResp: func(t *testing.T, w *httptest.ResponseRecorder) { + var result []map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + if len(result) != 1 { + t.Errorf("Expected 1 result, got %d", len(result)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := tt.setupDB() + handler := NewHandler(db) + + req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil) + w := httptest.NewRecorder() + + handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter) + handlerFunc(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String()) + } + + if tt.validateResp != nil { + tt.validateResp(t, w) + } + }) + } +} + +// TestMergeQueryParams tests query parameter merging +func TestMergeQueryParams(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + sqlQuery string + queryParams map[string]string + allowFilter bool + expectedQuery string + checkVars func(t *testing.T, vars map[string]interface{}) + }{ + { + name: "Replace placeholder with parameter", + sqlQuery: "SELECT * FROM users WHERE id = [user_id]", + queryParams: map[string]string{"p-user_id": "123"}, + allowFilter: false, + checkVars: func(t *testing.T, vars map[string]interface{}) { + if vars["p-user_id"] != "123" { + t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"]) + } + }, + }, + { + name: "Add filter when allowed", + sqlQuery: "SELECT * FROM users", + queryParams: map[string]string{"status": "active"}, + allowFilter: true, + checkVars: func(t *testing.T, vars map[string]interface{}) { + if vars["status"] != "active" { + t.Errorf("Expected status=active, got %v", vars["status"]) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := createTestRequest("GET", "/test", tt.queryParams, nil, nil) + variables := make(map[string]interface{}) + propQry := make(map[string]string) + + result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry) + + if result == "" { + t.Error("Expected non-empty SQL query result") + } + + if tt.checkVars != nil { + tt.checkVars(t, variables) + } + }) + } +} + +// TestMergeHeaderParams tests header parameter merging +func TestMergeHeaderParams(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + sqlQuery string + headers map[string]string + expectedQuery string + checkVars func(t *testing.T, vars map[string]interface{}) + }{ + { + name: "Field filter header", + sqlQuery: "SELECT * FROM users", + headers: map[string]string{"X-FieldFilter-Status": "1"}, + checkVars: func(t *testing.T, vars map[string]interface{}) { + if vars["x-fieldfilter-status"] != "1" { + t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"]) + } + }, + }, + { + name: "Search filter header", + sqlQuery: "SELECT * FROM users", + headers: map[string]string{"X-SearchFilter-Name": "john"}, + checkVars: func(t *testing.T, vars 
map[string]interface{}) { + if vars["x-searchfilter-name"] != "john" { + t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"]) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := createTestRequest("GET", "/test", nil, tt.headers, nil) + variables := make(map[string]interface{}) + propQry := make(map[string]string) + complexAPI := false + + result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI) + + if result == "" { + t.Error("Expected non-empty SQL query result") + } + + if tt.checkVars != nil { + tt.checkVars(t, variables) + } + }) + } +} + +// TestReplaceMetaVariables tests meta variable replacement +func TestReplaceMetaVariables(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + userCtx := &security.UserContext{ + UserID: 123, + UserName: "testuser", + SessionID: "session-abc", + } + + metainfo := map[string]interface{}{ + "ipaddress": "192.168.1.1", + "url": "/api/test", + } + + variables := map[string]interface{}{ + "param1": "value1", + } + + tests := []struct { + name string + sqlQuery string + expectedCheck func(result string) bool + }{ + { + name: "Replace [rid_user]", + sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]", + expectedCheck: func(result string) bool { + return strings.Contains(result, "123") + }, + }, + { + name: "Replace [user]", + sqlQuery: "SELECT * FROM audit WHERE username = [user]", + expectedCheck: func(result string) bool { + return strings.Contains(result, "'testuser'") + }, + }, + { + name: "Replace [rid_session]", + sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]", + expectedCheck: func(result string) bool { + return strings.Contains(result, "'session-abc'") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := createTestRequest("GET", "/test", nil, nil, nil) + result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables) + + if !tt.expectedCheck(result) { + t.Errorf("Meta variable replacement failed. 
Query: %s", result) + } + }) + } +} diff --git a/pkg/funcspec/hooks.go b/pkg/funcspec/hooks.go new file mode 100644 index 0000000..c59fa35 --- /dev/null +++ b/pkg/funcspec/hooks.go @@ -0,0 +1,160 @@ +package funcspec + +import ( + "context" + "fmt" + "net/http" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// HookType defines the type of hook to execute +type HookType string + +const ( + // Query operation hooks (for SqlQuery - single record) + BeforeQuery HookType = "before_query" + AfterQuery HookType = "after_query" + + // Query list operation hooks (for SqlQueryList - multiple records) + BeforeQueryList HookType = "before_query_list" + AfterQueryList HookType = "after_query_list" + + // SQL execution hooks (just before SQL is executed) + BeforeSQLExec HookType = "before_sql_exec" + AfterSQLExec HookType = "after_sql_exec" + + // Response hooks (before response is sent) + BeforeResponse HookType = "before_response" +) + +// HookContext contains all the data available to a hook +type HookContext struct { + Context context.Context + Handler *Handler // Reference to the handler for accessing database + Request *http.Request + Writer http.ResponseWriter + + // SQL query and variables + SQLQuery string // The SQL query being executed (can be modified by hooks) + Variables map[string]interface{} // Variables extracted from request + InputVars []string // Input variable placeholders found in query + MetaInfo map[string]interface{} // Metadata about the request + PropQry map[string]string // Property query parameters + + // User context + UserContext *security.UserContext + + // Pagination and filtering (for list queries) + SortColumns string + Limit int + Offset int + + // Results + Result interface{} // Query result (single record or list) + Total int64 // Total count (for list queries) + Error error // Error if operation failed + ComplexAPI bool // Whether complex API response format is requested + NoCount bool // Whether count query should be skipped + BlankParams bool // Whether blank parameters should be removed + AllowFilter bool // Whether filtering is allowed + + // Allow hooks to abort the operation + Abort bool // If set to true, the operation will be aborted + AbortMessage string // Message to return if aborted + AbortCode int // HTTP status code if aborted +} + +// HookFunc is the signature for hook functions +// It receives a HookContext and can modify it or return an error +// If an error is returned, the operation will be aborted +type HookFunc func(*HookContext) error + +// HookRegistry manages all registered hooks +type HookRegistry struct { + hooks map[HookType][]HookFunc +} + +// NewHookRegistry creates a new hook registry +func NewHookRegistry() *HookRegistry { + return &HookRegistry{ + hooks: make(map[HookType][]HookFunc), + } +} + +// Register adds a new hook for the specified hook type +func (r *HookRegistry) Register(hookType HookType, hook HookFunc) { + if r.hooks == nil { + r.hooks = make(map[HookType][]HookFunc) + } + r.hooks[hookType] = append(r.hooks[hookType], hook) + logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType])) +} + +// RegisterMultiple registers a hook for multiple hook types +func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) { + for _, hookType := range hookTypes { + r.Register(hookType, hook) + } +} + +// Execute runs all hooks for the specified type in order +// If any hook returns an error, execution stops and the error is 
returned +func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error { + hooks, exists := r.hooks[hookType] + if !exists || len(hooks) == 0 { + return nil + } + + logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType) + + for i, hook := range hooks { + if err := hook(ctx); err != nil { + logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err) + return fmt.Errorf("hook execution failed: %w", err) + } + + // Check if hook requested abort + if ctx.Abort { + logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage) + return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage) + } + } + + return nil +} + +// Clear removes all hooks for the specified type +func (r *HookRegistry) Clear(hookType HookType) { + delete(r.hooks, hookType) + logger.Info("Cleared all funcspec hooks for %s", hookType) +} + +// ClearAll removes all registered hooks +func (r *HookRegistry) ClearAll() { + r.hooks = make(map[HookType][]HookFunc) + logger.Info("Cleared all funcspec hooks") +} + +// Count returns the number of hooks registered for a specific type +func (r *HookRegistry) Count(hookType HookType) int { + if hooks, exists := r.hooks[hookType]; exists { + return len(hooks) + } + return 0 +} + +// HasHooks returns true if there are any hooks registered for the specified type +func (r *HookRegistry) HasHooks(hookType HookType) bool { + return r.Count(hookType) > 0 +} + +// GetAllHookTypes returns all hook types that have registered hooks +func (r *HookRegistry) GetAllHookTypes() []HookType { + types := make([]HookType, 0, len(r.hooks)) + for hookType := range r.hooks { + types = append(types, hookType) + } + return types +} diff --git a/pkg/funcspec/hooks_example.go b/pkg/funcspec/hooks_example.go new file mode 100644 index 0000000..5f25d37 --- /dev/null +++ b/pkg/funcspec/hooks_example.go @@ -0,0 +1,137 @@ +package funcspec + +import ( + "fmt" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +// Example hook functions demonstrating various use cases + +// ExampleLoggingHook logs all SQL queries before execution +func ExampleLoggingHook(ctx *HookContext) error { + logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery) + return nil +} + +// ExampleSecurityHook validates user permissions before executing queries +func ExampleSecurityHook(ctx *HookContext) error { + // Example: Block queries that try to access sensitive tables + if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") { + if ctx.UserContext.UserID != 1 { // Only admin can access + ctx.Abort = true + ctx.AbortCode = 403 + ctx.AbortMessage = "Access denied: insufficient permissions" + return fmt.Errorf("access denied to sensitive_table") + } + } + return nil +} + +// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering +func ExampleQueryModificationHook(ctx *HookContext) error { + // Example: Automatically add user_id filter for non-admin users + if ctx.UserContext.UserID != 1 { // Not admin + // Add WHERE clause to filter by user_id + if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") { + ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID) + } else { + ctx.SQLQuery = strings.Replace( + ctx.SQLQuery, + "WHERE", + fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID), + 1, + ) + } + logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery) + } + return nil +} + +// ExampleResultFilterHook 
filters results after query execution +func ExampleResultFilterHook(ctx *HookContext) error { + // Example: Remove sensitive fields from results for non-admin users + if ctx.UserContext.UserID != 1 { // Not admin + switch result := ctx.Result.(type) { + case []map[string]interface{}: + // Filter list results + for i := range result { + delete(result[i], "password") + delete(result[i], "ssn") + delete(result[i], "credit_card") + } + case map[string]interface{}: + // Filter single record + delete(result, "password") + delete(result, "ssn") + delete(result, "credit_card") + } + } + return nil +} + +// ExampleAuditHook logs all queries and results for audit purposes +func ExampleAuditHook(ctx *HookContext) error { + // Log to audit table or external system + logger.Info("AUDIT: User %s (%d) executed query from %s", + ctx.UserContext.UserName, + ctx.UserContext.UserID, + ctx.Request.RemoteAddr, + ) + + // In a real implementation, you might: + // - Insert into an audit log table + // - Send to a logging service + // - Write to a file + + return nil +} + +// ExampleCacheHook implements simple response caching +func ExampleCacheHook(ctx *HookContext) error { + // This is a simplified example - real caching would use a proper cache store + // Check if we have a cached result for this query + // cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery) + // if cachedResult := checkCache(cacheKey); cachedResult != nil { + // ctx.Result = cachedResult + // ctx.Abort = true // Skip query execution + // ctx.AbortMessage = "Serving from cache" + // } + return nil +} + +// ExampleErrorHandlingHook provides custom error handling +func ExampleErrorHandlingHook(ctx *HookContext) error { + if ctx.Error != nil { + // Log error with context + logger.Error("Query failed for user %s: %v\nQuery: %s", + ctx.UserContext.UserName, + ctx.Error, + ctx.SQLQuery, + ) + + // You could send notifications, update metrics, etc. 
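+		// Returning nil keeps normal processing; returning an error here would make the handler respond with a hook_error instead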
+ } + return nil +} + +// Example of registering hooks: +// +// func SetupHooks(handler *Handler) { +// hooks := handler.Hooks() +// +// // Register security hook before query execution +// hooks.Register(BeforeQuery, ExampleSecurityHook) +// hooks.Register(BeforeQueryList, ExampleSecurityHook) +// +// // Register logging hook before SQL execution +// hooks.Register(BeforeSQLExec, ExampleLoggingHook) +// +// // Register result filtering after query +// hooks.Register(AfterQuery, ExampleResultFilterHook) +// hooks.Register(AfterQueryList, ExampleResultFilterHook) +// +// // Register audit hook after execution +// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook) +// } diff --git a/pkg/funcspec/hooks_test.go b/pkg/funcspec/hooks_test.go new file mode 100644 index 0000000..fb9055a --- /dev/null +++ b/pkg/funcspec/hooks_test.go @@ -0,0 +1,589 @@ +package funcspec + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// TestNewHookRegistry tests hook registry creation +func TestNewHookRegistry(t *testing.T) { + registry := NewHookRegistry() + + if registry == nil { + t.Fatal("Expected registry to be created, got nil") + } + + if registry.hooks == nil { + t.Error("Expected hooks map to be initialized") + } +} + +// TestRegisterHook tests registering a single hook +func TestRegisterHook(t *testing.T) { + registry := NewHookRegistry() + + hookCalled := false + testHook := func(ctx *HookContext) error { + hookCalled = true + return nil + } + + registry.Register(BeforeQuery, testHook) + + if !registry.HasHooks(BeforeQuery) { + t.Error("Expected hook to be registered") + } + + if registry.Count(BeforeQuery) != 1 { + t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery)) + } + + // Execute the hook + ctx := &HookContext{} + err := registry.Execute(BeforeQuery, ctx) + if err != nil { + t.Errorf("Hook execution failed: %v", err) + } + + if !hookCalled { + t.Error("Expected hook to be called") + } +} + +// TestRegisterMultipleHooks tests registering multiple hooks for same type +func TestRegisterMultipleHooks(t *testing.T) { + registry := NewHookRegistry() + + callOrder := []int{} + + hook1 := func(ctx *HookContext) error { + callOrder = append(callOrder, 1) + return nil + } + + hook2 := func(ctx *HookContext) error { + callOrder = append(callOrder, 2) + return nil + } + + hook3 := func(ctx *HookContext) error { + callOrder = append(callOrder, 3) + return nil + } + + registry.Register(BeforeQuery, hook1) + registry.Register(BeforeQuery, hook2) + registry.Register(BeforeQuery, hook3) + + if registry.Count(BeforeQuery) != 3 { + t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery)) + } + + // Execute hooks + ctx := &HookContext{} + err := registry.Execute(BeforeQuery, ctx) + if err != nil { + t.Errorf("Hook execution failed: %v", err) + } + + // Verify hooks were called in order + if len(callOrder) != 3 { + t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder)) + } + + for i, expected := range []int{1, 2, 3} { + if callOrder[i] != expected { + t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i]) + } + } +} + +// TestRegisterMultipleHookTypes tests registering a hook for multiple types +func TestRegisterMultipleHookTypes(t *testing.T) { + registry := NewHookRegistry() + + callCount := 0 + testHook := func(ctx *HookContext) error { + callCount++ + return nil + } + + hookTypes := []HookType{BeforeQuery, 
AfterQuery, BeforeSQLExec} + registry.RegisterMultiple(hookTypes, testHook) + + // Verify hook is registered for all types + for _, hookType := range hookTypes { + if !registry.HasHooks(hookType) { + t.Errorf("Expected hook to be registered for %s", hookType) + } + + if registry.Count(hookType) != 1 { + t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType)) + } + } + + // Execute each hook type + ctx := &HookContext{} + for _, hookType := range hookTypes { + if err := registry.Execute(hookType, ctx); err != nil { + t.Errorf("Hook execution failed for %s: %v", hookType, err) + } + } + + if callCount != 3 { + t.Errorf("Expected hook to be called 3 times, got %d", callCount) + } +} + +// TestHookError tests hook error handling +func TestHookError(t *testing.T) { + registry := NewHookRegistry() + + expectedError := fmt.Errorf("test error") + errorHook := func(ctx *HookContext) error { + return expectedError + } + + registry.Register(BeforeQuery, errorHook) + + ctx := &HookContext{} + err := registry.Execute(BeforeQuery, ctx) + + if err == nil { + t.Error("Expected error from hook, got nil") + } + + if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) { + t.Errorf("Expected error message to contain hook error, got: %v", err) + } +} + +// TestHookAbort tests hook abort functionality +func TestHookAbort(t *testing.T) { + registry := NewHookRegistry() + + abortHook := func(ctx *HookContext) error { + ctx.Abort = true + ctx.AbortMessage = "Operation aborted by hook" + ctx.AbortCode = 403 + return nil + } + + registry.Register(BeforeQuery, abortHook) + + ctx := &HookContext{} + err := registry.Execute(BeforeQuery, ctx) + + if err == nil { + t.Error("Expected error when hook aborts, got nil") + } + + if !ctx.Abort { + t.Error("Expected Abort to be true") + } + + if ctx.AbortMessage != "Operation aborted by hook" { + t.Errorf("Expected abort message, got: %s", ctx.AbortMessage) + } + + if ctx.AbortCode != 403 { + t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode) + } +} + +// TestHookChainWithError tests that hook chain stops on first error +func TestHookChainWithError(t *testing.T) { + registry := NewHookRegistry() + + callOrder := []int{} + + hook1 := func(ctx *HookContext) error { + callOrder = append(callOrder, 1) + return nil + } + + hook2 := func(ctx *HookContext) error { + callOrder = append(callOrder, 2) + return fmt.Errorf("error in hook 2") + } + + hook3 := func(ctx *HookContext) error { + callOrder = append(callOrder, 3) + return nil + } + + registry.Register(BeforeQuery, hook1) + registry.Register(BeforeQuery, hook2) + registry.Register(BeforeQuery, hook3) + + ctx := &HookContext{} + err := registry.Execute(BeforeQuery, ctx) + + if err == nil { + t.Error("Expected error from hook chain") + } + + // Only first two hooks should have been called + if len(callOrder) != 2 { + t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder)) + } + + if callOrder[0] != 1 || callOrder[1] != 2 { + t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder) + } +} + +// TestClearHooks tests clearing hooks +func TestClearHooks(t *testing.T) { + registry := NewHookRegistry() + + testHook := func(ctx *HookContext) error { + return nil + } + + registry.Register(BeforeQuery, testHook) + registry.Register(AfterQuery, testHook) + + if !registry.HasHooks(BeforeQuery) { + t.Error("Expected BeforeQuery hook to be registered") + } + + registry.Clear(BeforeQuery) + + if registry.HasHooks(BeforeQuery) { + t.Error("Expected BeforeQuery hooks to be 
cleared") + } + + if !registry.HasHooks(AfterQuery) { + t.Error("Expected AfterQuery hook to still be registered") + } +} + +// TestClearAllHooks tests clearing all hooks +func TestClearAllHooks(t *testing.T) { + registry := NewHookRegistry() + + testHook := func(ctx *HookContext) error { + return nil + } + + registry.Register(BeforeQuery, testHook) + registry.Register(AfterQuery, testHook) + registry.Register(BeforeSQLExec, testHook) + + registry.ClearAll() + + if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) { + t.Error("Expected all hooks to be cleared") + } +} + +// TestGetAllHookTypes tests getting all registered hook types +func TestGetAllHookTypes(t *testing.T) { + registry := NewHookRegistry() + + testHook := func(ctx *HookContext) error { + return nil + } + + registry.Register(BeforeQuery, testHook) + registry.Register(AfterQuery, testHook) + + types := registry.GetAllHookTypes() + + if len(types) != 2 { + t.Errorf("Expected 2 hook types, got %d", len(types)) + } + + // Verify the types are present + foundBefore := false + foundAfter := false + for _, hookType := range types { + if hookType == BeforeQuery { + foundBefore = true + } + if hookType == AfterQuery { + foundAfter = true + } + } + + if !foundBefore || !foundAfter { + t.Error("Expected both BeforeQuery and AfterQuery hook types") + } +} + +// TestHookContextModification tests that hooks can modify the context +func TestHookContextModification(t *testing.T) { + registry := NewHookRegistry() + + // Hook that modifies SQL query + modifyHook := func(ctx *HookContext) error { + ctx.SQLQuery = "SELECT * FROM modified_table" + ctx.Variables["new_var"] = "new_value" + return nil + } + + registry.Register(BeforeQuery, modifyHook) + + ctx := &HookContext{ + SQLQuery: "SELECT * FROM original_table", + Variables: make(map[string]interface{}), + } + + err := registry.Execute(BeforeQuery, ctx) + if err != nil { + t.Errorf("Hook execution failed: %v", err) + } + + if ctx.SQLQuery != "SELECT * FROM modified_table" { + t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery) + } + + if ctx.Variables["new_var"] != "new_value" { + t.Errorf("Expected variable to be added, got: %v", ctx.Variables) + } +} + +// TestExampleHooks tests the example hooks +func TestExampleLoggingHook(t *testing.T) { + ctx := &HookContext{ + Context: context.Background(), + SQLQuery: "SELECT * FROM test", + UserContext: &security.UserContext{ + UserName: "testuser", + }, + } + + err := ExampleLoggingHook(ctx) + if err != nil { + t.Errorf("ExampleLoggingHook failed: %v", err) + } +} + +func TestExampleSecurityHook(t *testing.T) { + tests := []struct { + name string + sqlQuery string + userID int + shouldAbort bool + }{ + { + name: "Admin accessing sensitive table", + sqlQuery: "SELECT * FROM sensitive_table", + userID: 1, + shouldAbort: false, + }, + { + name: "Non-admin accessing sensitive table", + sqlQuery: "SELECT * FROM sensitive_table", + userID: 2, + shouldAbort: true, + }, + { + name: "Non-admin accessing normal table", + sqlQuery: "SELECT * FROM users", + userID: 2, + shouldAbort: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &HookContext{ + Context: context.Background(), + SQLQuery: tt.sqlQuery, + UserContext: &security.UserContext{ + UserID: tt.userID, + }, + } + + _ = ExampleSecurityHook(ctx) + + if tt.shouldAbort { + if !ctx.Abort { + t.Error("Expected security hook to abort operation") + } + if ctx.AbortCode != 403 { + t.Errorf("Expected 
abort code 403, got %d", ctx.AbortCode) + } + } else { + if ctx.Abort { + t.Error("Expected security hook not to abort operation") + } + } + }) + } +} + +func TestExampleResultFilterHook(t *testing.T) { + tests := []struct { + name string + userID int + result interface{} + validate func(t *testing.T, result interface{}) + }{ + { + name: "Admin user - no filtering", + userID: 1, + result: map[string]interface{}{ + "id": 1, + "name": "Test", + "password": "secret", + }, + validate: func(t *testing.T, result interface{}) { + m := result.(map[string]interface{}) + if _, exists := m["password"]; !exists { + t.Error("Expected password field to remain for admin") + } + }, + }, + { + name: "Regular user - sensitive fields removed", + userID: 2, + result: map[string]interface{}{ + "id": 1, + "name": "Test", + "password": "secret", + "ssn": "123-45-6789", + }, + validate: func(t *testing.T, result interface{}) { + m := result.(map[string]interface{}) + if _, exists := m["password"]; exists { + t.Error("Expected password field to be removed") + } + if _, exists := m["ssn"]; exists { + t.Error("Expected ssn field to be removed") + } + if _, exists := m["name"]; !exists { + t.Error("Expected name field to remain") + } + }, + }, + { + name: "Regular user - list results filtered", + userID: 2, + result: []map[string]interface{}{ + {"id": 1, "name": "User 1", "password": "secret1"}, + {"id": 2, "name": "User 2", "password": "secret2"}, + }, + validate: func(t *testing.T, result interface{}) { + list := result.([]map[string]interface{}) + for _, m := range list { + if _, exists := m["password"]; exists { + t.Error("Expected password field to be removed from list") + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &HookContext{ + Context: context.Background(), + Result: tt.result, + UserContext: &security.UserContext{ + UserID: tt.userID, + }, + } + + err := ExampleResultFilterHook(ctx) + if err != nil { + t.Errorf("Hook failed: %v", err) + } + + if tt.validate != nil { + tt.validate(t, ctx.Result) + } + }) + } +} + +func TestExampleAuditHook(t *testing.T) { + req := httptest.NewRequest("GET", "/api/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + + ctx := &HookContext{ + Context: context.Background(), + Request: req, + UserContext: &security.UserContext{ + UserID: 123, + UserName: "testuser", + }, + } + + err := ExampleAuditHook(ctx) + if err != nil { + t.Errorf("ExampleAuditHook failed: %v", err) + } +} + +func TestExampleErrorHandlingHook(t *testing.T) { + ctx := &HookContext{ + Context: context.Background(), + SQLQuery: "SELECT * FROM test", + Error: fmt.Errorf("test error"), + UserContext: &security.UserContext{ + UserName: "testuser", + }, + } + + err := ExampleErrorHandlingHook(ctx) + if err != nil { + t.Errorf("ExampleErrorHandlingHook failed: %v", err) + } +} + +// TestHookIntegrationWithHandler tests hooks integrated with the handler +func TestHookIntegrationWithHandler(t *testing.T) { + db := &MockDatabase{ + RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error { + queryDB := &MockDatabase{ + QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + rows := dest.(*[]map[string]interface{}) + *rows = []map[string]interface{}{ + {"id": float64(1), "name": "Test User"}, + } + return nil + }, + } + return fn(queryDB) + }, + } + + handler := NewHandler(db) + + // Register a hook that modifies the SQL query + hookCalled := false + handler.Hooks().Register(BeforeSQLExec, 
func(ctx *HookContext) error { + hookCalled = true + // Verify we can access context data + if ctx.SQLQuery == "" { + t.Error("Expected SQL query to be set") + } + if ctx.UserContext == nil { + t.Error("Expected user context to be set") + } + return nil + }) + + // Execute a query + req := createTestRequest("GET", "/test", nil, nil, nil) + w := httptest.NewRecorder() + + handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false) + handlerFunc(w, req) + + if !hookCalled { + t.Error("Expected hook to be called during query execution") + } + + if w.Code != 200 { + t.Errorf("Expected status 200, got %d", w.Code) + } +} diff --git a/pkg/funcspec/parameters.go b/pkg/funcspec/parameters.go new file mode 100644 index 0000000..31584f9 --- /dev/null +++ b/pkg/funcspec/parameters.go @@ -0,0 +1,411 @@ +package funcspec + +import ( + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/restheadspec" +) + +// RequestParameters holds parsed parameters from headers and query string +type RequestParameters struct { + // Field selection + SelectFields []string + NotSelectFields []string + Distinct bool + + // Filtering + FieldFilters map[string]string // column -> value (exact match) + SearchFilters map[string]string // column -> value (ILIKE) + SearchOps map[string]FilterOperator // column -> {operator, value, logic} + CustomSQLWhere string + CustomSQLOr string + + // Sorting & Pagination + SortColumns string + Limit int + Offset int + + // Advanced features + SkipCount bool + SkipCache bool + + // Response format + ResponseFormat string // "simple", "detail", "syncfusion" + ComplexAPI bool // true if NOT simple API +} + +// FilterOperator represents a filter with operator +type FilterOperator struct { + Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc. 
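+	// Value is the raw filter value (already base64-decoded); for "in" and "between" it is a comma-separated list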
+ Value string + Logic string // AND or OR +} + +// ParseParameters parses all parameters from request headers and query string +func (h *Handler) ParseParameters(r *http.Request) *RequestParameters { + params := &RequestParameters{ + FieldFilters: make(map[string]string), + SearchFilters: make(map[string]string), + SearchOps: make(map[string]FilterOperator), + Limit: 20, // Default limit + Offset: 0, // Default offset + ResponseFormat: "simple", // Default format + ComplexAPI: false, // Default to simple API + } + + // Merge headers and query parameters + combined := make(map[string]string) + + // Add all headers (normalize to lowercase) + for key, values := range r.Header { + if len(values) > 0 { + combined[strings.ToLower(key)] = values[0] + } + } + + // Add all query parameters (override headers) + for key, values := range r.URL.Query() { + if len(values) > 0 { + combined[strings.ToLower(key)] = values[0] + } + } + + // Parse each parameter + for key, value := range combined { + // Decode value if base64 encoded + decodedValue := h.decodeValue(value) + + switch { + // Field Selection + case strings.HasPrefix(key, "x-select-fields"): + params.SelectFields = h.parseCommaSeparated(decodedValue) + case strings.HasPrefix(key, "x-not-select-fields"): + params.NotSelectFields = h.parseCommaSeparated(decodedValue) + case strings.HasPrefix(key, "x-distinct"): + params.Distinct = strings.EqualFold(decodedValue, "true") + + // Filtering + case strings.HasPrefix(key, "x-fieldfilter-"): + colName := strings.TrimPrefix(key, "x-fieldfilter-") + params.FieldFilters[colName] = decodedValue + case strings.HasPrefix(key, "x-searchfilter-"): + colName := strings.TrimPrefix(key, "x-searchfilter-") + params.SearchFilters[colName] = decodedValue + case strings.HasPrefix(key, "x-searchop-"): + h.parseSearchOp(params, key, decodedValue, "AND") + case strings.HasPrefix(key, "x-searchor-"): + h.parseSearchOp(params, key, decodedValue, "OR") + case strings.HasPrefix(key, "x-searchand-"): + h.parseSearchOp(params, key, decodedValue, "AND") + case strings.HasPrefix(key, "x-custom-sql-w"): + if params.CustomSQLWhere != "" { + params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue) + } else { + params.CustomSQLWhere = decodedValue + } + case strings.HasPrefix(key, "x-custom-sql-or"): + if params.CustomSQLOr != "" { + params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue) + } else { + params.CustomSQLOr = decodedValue + } + + // Sorting & Pagination + case key == "sort" || strings.HasPrefix(key, "x-sort"): + params.SortColumns = decodedValue + case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"): + // Handle sort(col1,-col2) syntax + sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")] + params.SortColumns = sortValue + case key == "limit" || strings.HasPrefix(key, "x-limit"): + if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 { + params.Limit = limit + } + case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"): + // Handle limit(offset,limit) or limit(limit) syntax + limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")] + parts := strings.Split(limitValue, ",") + if len(parts) > 1 { + if offset, err := strconv.Atoi(parts[0]); err == nil { + params.Offset = offset + } + if limit, err := strconv.Atoi(parts[1]); err == nil { + params.Limit = limit + } + } else { + if limit, err := strconv.Atoi(parts[0]); err == nil { + params.Limit = limit + } + } + case key == "offset" || 
strings.HasPrefix(key, "x-offset"): + if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 { + params.Offset = offset + } + + // Advanced features + case strings.HasPrefix(key, "x-skipcount"): + params.SkipCount = strings.EqualFold(decodedValue, "true") + case strings.HasPrefix(key, "x-skipcache"): + params.SkipCache = strings.EqualFold(decodedValue, "true") + + // Response Format + case strings.HasPrefix(key, "x-simpleapi"): + params.ResponseFormat = "simple" + params.ComplexAPI = !(decodedValue == "1" || strings.EqualFold(decodedValue, "true")) + case strings.HasPrefix(key, "x-detailapi"): + params.ResponseFormat = "detail" + params.ComplexAPI = true + case strings.HasPrefix(key, "x-syncfusion"): + params.ResponseFormat = "syncfusion" + params.ComplexAPI = true + } + } + + return params +} + +// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column} +func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) { + var prefix string + if logic == "OR" { + prefix = "x-searchor-" + } else { + prefix = "x-searchop-" + if strings.HasPrefix(headerKey, "x-searchand-") { + prefix = "x-searchand-" + } + } + + rest := strings.TrimPrefix(headerKey, prefix) + parts := strings.SplitN(rest, "-", 2) + if len(parts) != 2 { + logger.Warn("Invalid search operator header format: %s", headerKey) + return + } + + operator := parts[0] + colName := parts[1] + + params.SearchOps[colName] = FilterOperator{ + Operator: operator, + Value: value, + Logic: logic, + } + + logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value) +} + +// decodeValue decodes base64 encoded values (ZIP_ or __ prefix) +func (h *Handler) decodeValue(value string) string { + decoded, _ := restheadspec.DecodeParam(value) + return decoded +} + +// parseCommaSeparated parses comma-separated values +func (h *Handler) parseCommaSeparated(value string) []string { + if value == "" { + return nil + } + + parts := strings.Split(value, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + result = append(result, part) + } + } + return result +} + +// ApplyFieldSelection applies column selection to SQL query +func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string { + if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 { + return sqlQuery + } + + // This is a simplified implementation + // A full implementation would parse the SQL and replace the SELECT clause + // For now, we log a warning that this feature needs manual implementation + if len(params.SelectFields) > 0 { + logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields) + } + if len(params.NotSelectFields) > 0 { + logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields) + } + + return sqlQuery +} + +// ApplyFilters applies all filters to the SQL query +func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string { + // Apply field filters (exact match) + for colName, value := range params.FieldFilters { + condition := "" + if value == "" || value == "0" { + condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue")) + } else { + condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue")) + } + sqlQuery = sqlQryWhere(sqlQuery, condition) + logger.Debug("Applied field 
filter: %s", condition) + } + + // Apply search filters (ILIKE) + for colName, value := range params.SearchFilters { + sval := strings.ReplaceAll(value, "'", "") + if sval != "" { + condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue")) + sqlQuery = sqlQryWhere(sqlQuery, condition) + logger.Debug("Applied search filter: %s", condition) + } + } + + // Apply search operators + for colName, filterOp := range params.SearchOps { + condition := h.buildFilterCondition(colName, filterOp) + if condition != "" { + if filterOp.Logic == "OR" { + sqlQuery = sqlQryWhereOr(sqlQuery, condition) + } else { + sqlQuery = sqlQryWhere(sqlQuery, condition) + } + logger.Debug("Applied search operator: %s", condition) + } + } + + // Apply custom SQL WHERE + if params.CustomSQLWhere != "" { + colval := ValidSQL(params.CustomSQLWhere, "select") + if colval != "" { + sqlQuery = sqlQryWhere(sqlQuery, colval) + logger.Debug("Applied custom SQL WHERE: %s", colval) + } + } + + // Apply custom SQL OR + if params.CustomSQLOr != "" { + colval := ValidSQL(params.CustomSQLOr, "select") + if colval != "" { + sqlQuery = sqlQryWhereOr(sqlQuery, colval) + logger.Debug("Applied custom SQL OR: %s", colval) + } + } + + return sqlQuery +} + +// buildFilterCondition builds a SQL condition from a FilterOperator +func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string { + safCol := ValidSQL(colName, "colname") + operator := strings.ToLower(op.Operator) + value := op.Value + + switch operator { + case "contains", "contain", "like": + return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue")) + case "beginswith", "startswith": + return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue")) + case "endswith": + return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue")) + case "equals", "eq", "=": + if IsNumeric(value) { + return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue")) + } + return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue")) + case "notequals", "neq", "ne", "!=", "<>": + if IsNumeric(value) { + return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue")) + } + return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue")) + case "greaterthan", "gt", ">": + return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue")) + case "lessthan", "lt", "<": + return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue")) + case "greaterthanorequal", "gte", "ge", ">=": + return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue")) + case "lessthanorequal", "lte", "le", "<=": + return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue")) + case "between": + parts := strings.Split(value, ",") + if len(parts) == 2 { + return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue")) + } + case "betweeninclusive": + parts := strings.Split(value, ",") + if len(parts) == 2 { + return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue")) + } + case "in": + values := strings.Split(value, ",") + safeValues := make([]string, len(values)) + for i, v := range values { + safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue")) + } + return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", ")) + case "empty", "isnull", "null": + return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol) + case "notempty", "isnotnull", "notnull": + return 
fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol) + default: + logger.Warn("Unknown filter operator: %s, defaulting to equals", operator) + return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue")) + } + + return "" +} + +// ApplyDistinct adds DISTINCT to SQL query if requested +func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string { + if !params.Distinct { + return sqlQuery + } + + // Add DISTINCT after SELECT + selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT") + if selectPos >= 0 { + beforeSelect := sqlQuery[:selectPos+6] // "SELECT" + afterSelect := sqlQuery[selectPos+6:] + sqlQuery = beforeSelect + " DISTINCT" + afterSelect + logger.Debug("Applied DISTINCT to query") + } + + return sqlQuery +} + +// sqlQryWhereOr adds a WHERE clause with OR logic +func sqlQryWhereOr(sqlquery, condition string) string { + lowerQuery := strings.ToLower(sqlquery) + wherePos := strings.Index(lowerQuery, " where ") + groupPos := strings.Index(lowerQuery, " group by") + orderPos := strings.Index(lowerQuery, " order by") + limitPos := strings.Index(lowerQuery, " limit ") + + // Find the insertion point + insertPos := len(sqlquery) + if groupPos > 0 && groupPos < insertPos { + insertPos = groupPos + } + if orderPos > 0 && orderPos < insertPos { + insertPos = orderPos + } + if limitPos > 0 && limitPos < insertPos { + insertPos = limitPos + } + + if wherePos > 0 { + // WHERE exists, add OR condition + before := sqlquery[:insertPos] + after := sqlquery[insertPos:] + return fmt.Sprintf("%s OR (%s) %s", before, condition, after) + } else { + // No WHERE exists, add it + before := sqlquery[:insertPos] + after := sqlquery[insertPos:] + return fmt.Sprintf("%s WHERE %s %s", before, condition, after) + } +} diff --git a/pkg/funcspec/parameters_test.go b/pkg/funcspec/parameters_test.go new file mode 100644 index 0000000..f2fec6c --- /dev/null +++ b/pkg/funcspec/parameters_test.go @@ -0,0 +1,549 @@ +package funcspec + +import ( + "strings" + "testing" +) + +// TestParseParameters tests the comprehensive parameter parsing +func TestParseParameters(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + queryParams map[string]string + headers map[string]string + validate func(t *testing.T, params *RequestParameters) + }{ + { + name: "Parse field selection", + headers: map[string]string{ + "X-Select-Fields": "id,name,email", + "X-Not-Select-Fields": "password,ssn", + }, + validate: func(t *testing.T, params *RequestParameters) { + if len(params.SelectFields) != 3 { + t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields)) + } + if len(params.NotSelectFields) != 2 { + t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields)) + } + }, + }, + { + name: "Parse distinct flag", + headers: map[string]string{ + "X-Distinct": "true", + }, + validate: func(t *testing.T, params *RequestParameters) { + if !params.Distinct { + t.Error("Expected Distinct to be true") + } + }, + }, + { + name: "Parse field filters", + headers: map[string]string{ + "X-FieldFilter-Status": "active", + "X-FieldFilter-Type": "admin", + }, + validate: func(t *testing.T, params *RequestParameters) { + if len(params.FieldFilters) != 2 { + t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters)) + } + if params.FieldFilters["status"] != "active" { + t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"]) + } + }, + }, + { + name: "Parse search filters", + headers: map[string]string{ + 
"X-SearchFilter-Name": "john", + "X-SearchFilter-Email": "test", + }, + validate: func(t *testing.T, params *RequestParameters) { + if len(params.SearchFilters) != 2 { + t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters)) + } + }, + }, + { + name: "Parse sort columns", + queryParams: map[string]string{ + "sort": "-created_at,name", + }, + validate: func(t *testing.T, params *RequestParameters) { + if params.SortColumns != "-created_at,name" { + t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns) + } + }, + }, + { + name: "Parse limit and offset", + queryParams: map[string]string{ + "limit": "100", + "offset": "50", + }, + validate: func(t *testing.T, params *RequestParameters) { + if params.Limit != 100 { + t.Errorf("Expected limit=100, got %d", params.Limit) + } + if params.Offset != 50 { + t.Errorf("Expected offset=50, got %d", params.Offset) + } + }, + }, + { + name: "Parse skip count", + headers: map[string]string{ + "X-SkipCount": "true", + }, + validate: func(t *testing.T, params *RequestParameters) { + if !params.SkipCount { + t.Error("Expected SkipCount to be true") + } + }, + }, + { + name: "Parse response format - syncfusion", + headers: map[string]string{ + "X-Syncfusion": "true", + }, + validate: func(t *testing.T, params *RequestParameters) { + if params.ResponseFormat != "syncfusion" { + t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat) + } + if !params.ComplexAPI { + t.Error("Expected ComplexAPI to be true for syncfusion format") + } + }, + }, + { + name: "Parse response format - detail", + headers: map[string]string{ + "X-DetailAPI": "true", + }, + validate: func(t *testing.T, params *RequestParameters) { + if params.ResponseFormat != "detail" { + t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat) + } + }, + }, + { + name: "Parse simple API", + headers: map[string]string{ + "X-SimpleAPI": "true", + }, + validate: func(t *testing.T, params *RequestParameters) { + if params.ResponseFormat != "simple" { + t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat) + } + if params.ComplexAPI { + t.Error("Expected ComplexAPI to be false for simple API") + } + }, + }, + { + name: "Parse custom SQL WHERE", + headers: map[string]string{ + "X-Custom-SQL-W": "status = 'active' AND deleted = false", + }, + validate: func(t *testing.T, params *RequestParameters) { + if params.CustomSQLWhere == "" { + t.Error("Expected CustomSQLWhere to be set") + } + }, + }, + { + name: "Parse search operators - AND", + headers: map[string]string{ + "X-SearchOp-Eq-Name": "john", + "X-SearchOp-Gt-Age": "18", + }, + validate: func(t *testing.T, params *RequestParameters) { + if len(params.SearchOps) != 2 { + t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps)) + } + if op, exists := params.SearchOps["name"]; exists { + if op.Operator != "eq" { + t.Errorf("Expected operator=eq for name, got %s", op.Operator) + } + if op.Logic != "AND" { + t.Errorf("Expected logic=AND, got %s", op.Logic) + } + } else { + t.Error("Expected name search operator to exist") + } + }, + }, + { + name: "Parse search operators - OR", + headers: map[string]string{ + "X-SearchOr-Like-Description": "test", + }, + validate: func(t *testing.T, params *RequestParameters) { + if op, exists := params.SearchOps["description"]; exists { + if op.Logic != "OR" { + t.Errorf("Expected logic=OR, got %s", op.Logic) + } + } else { + t.Error("Expected description search operator to exist") + } + }, + }, + } + + for _, 
tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil) + params := handler.ParseParameters(req) + + if tt.validate != nil { + tt.validate(t, params) + } + }) + } +} + +// TestBuildFilterCondition tests the filter condition builder +func TestBuildFilterCondition(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + colName string + operator FilterOperator + expected string + }{ + { + name: "Equals operator - numeric", + colName: "age", + operator: FilterOperator{ + Operator: "eq", + Value: "25", + Logic: "AND", + }, + expected: "age = 25", + }, + { + name: "Equals operator - string", + colName: "name", + operator: FilterOperator{ + Operator: "eq", + Value: "john", + Logic: "AND", + }, + expected: "name = 'john'", + }, + { + name: "Not equals operator", + colName: "status", + operator: FilterOperator{ + Operator: "neq", + Value: "inactive", + Logic: "AND", + }, + expected: "status != 'inactive'", + }, + { + name: "Greater than operator", + colName: "age", + operator: FilterOperator{ + Operator: "gt", + Value: "18", + Logic: "AND", + }, + expected: "age > 18", + }, + { + name: "Less than operator", + colName: "price", + operator: FilterOperator{ + Operator: "lt", + Value: "100", + Logic: "AND", + }, + expected: "price < 100", + }, + { + name: "Contains operator", + colName: "description", + operator: FilterOperator{ + Operator: "contains", + Value: "test", + Logic: "AND", + }, + expected: "description ILIKE '%test%'", + }, + { + name: "Starts with operator", + colName: "name", + operator: FilterOperator{ + Operator: "startswith", + Value: "john", + Logic: "AND", + }, + expected: "name ILIKE 'john%'", + }, + { + name: "Ends with operator", + colName: "email", + operator: FilterOperator{ + Operator: "endswith", + Value: "@example.com", + Logic: "AND", + }, + expected: "email ILIKE '%@example.com'", + }, + { + name: "Between operator", + colName: "age", + operator: FilterOperator{ + Operator: "between", + Value: "18,65", + Logic: "AND", + }, + expected: "age > 18 AND age < 65", + }, + { + name: "IN operator", + colName: "status", + operator: FilterOperator{ + Operator: "in", + Value: "active,pending,approved", + Logic: "AND", + }, + expected: "status IN ('active', 'pending', 'approved')", + }, + { + name: "IS NULL operator", + colName: "deleted_at", + operator: FilterOperator{ + Operator: "null", + Value: "", + Logic: "AND", + }, + expected: "(deleted_at IS NULL OR deleted_at = '')", + }, + { + name: "IS NOT NULL operator", + colName: "created_at", + operator: FilterOperator{ + Operator: "notnull", + Value: "", + Logic: "AND", + }, + expected: "(created_at IS NOT NULL AND created_at != '')", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.buildFilterCondition(tt.colName, tt.operator) + if result != tt.expected { + t.Errorf("Expected: %s\nGot: %s", tt.expected, result) + } + }) + } +} + +// TestApplyFilters tests the filter application to SQL queries +func TestApplyFilters(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + sqlQuery string + params *RequestParameters + expectedSQL string + shouldContain []string + }{ + { + name: "Apply field filter", + sqlQuery: "SELECT * FROM users", + params: &RequestParameters{ + FieldFilters: map[string]string{ + "status": "active", + }, + }, + shouldContain: []string{"WHERE", "status"}, + }, + { + name: "Apply search filter", + sqlQuery: "SELECT 
* FROM users", + params: &RequestParameters{ + SearchFilters: map[string]string{ + "name": "john", + }, + }, + shouldContain: []string{"WHERE", "name", "ILIKE"}, + }, + { + name: "Apply search operators", + sqlQuery: "SELECT * FROM users", + params: &RequestParameters{ + SearchOps: map[string]FilterOperator{ + "age": { + Operator: "gt", + Value: "18", + Logic: "AND", + }, + }, + }, + shouldContain: []string{"WHERE", "age", ">", "18"}, + }, + { + name: "Apply custom SQL WHERE", + sqlQuery: "SELECT * FROM users", + params: &RequestParameters{ + CustomSQLWhere: "deleted = false", + }, + shouldContain: []string{"WHERE", "deleted"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.ApplyFilters(tt.sqlQuery, tt.params) + + for _, expected := range tt.shouldContain { + if !strings.Contains(result, expected) { + t.Errorf("Expected SQL to contain %q, got: %s", expected, result) + } + } + }) + } +} + +// TestApplyDistinct tests DISTINCT application +func TestApplyDistinct(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + sqlQuery string + distinct bool + shouldHave string + }{ + { + name: "Apply DISTINCT", + sqlQuery: "SELECT id, name FROM users", + distinct: true, + shouldHave: "SELECT DISTINCT", + }, + { + name: "Do not apply DISTINCT", + sqlQuery: "SELECT id, name FROM users", + distinct: false, + shouldHave: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := &RequestParameters{Distinct: tt.distinct} + result := handler.ApplyDistinct(tt.sqlQuery, params) + + if tt.shouldHave != "" { + if !strings.Contains(result, tt.shouldHave) { + t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result) + } + } else { + // Should not have DISTINCT when not requested + if strings.Contains(result, "DISTINCT") && !tt.distinct { + t.Errorf("SQL should not contain DISTINCT when not requested: %s", result) + } + } + }) + } +} + +// TestParseCommaSeparated tests comma-separated value parsing +func TestParseCommaSeparated(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + tests := []struct { + name string + input string + expected []string + }{ + { + name: "Simple comma-separated", + input: "id,name,email", + expected: []string{"id", "name", "email"}, + }, + { + name: "With spaces", + input: "id, name, email", + expected: []string{"id", "name", "email"}, + }, + { + name: "Empty string", + input: "", + expected: nil, + }, + { + name: "Single value", + input: "id", + expected: []string{"id"}, + }, + { + name: "With extra commas", + input: "id,,name,,email", + expected: []string{"id", "name", "email"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := handler.parseCommaSeparated(tt.input) + + if len(result) != len(tt.expected) { + t.Errorf("Expected %d values, got %d", len(tt.expected), len(result)) + return + } + + for i, expected := range tt.expected { + if result[i] != expected { + t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i]) + } + } + }) + } +} + +// TestSqlQryWhereOr tests OR WHERE clause manipulation +func TestSqlQryWhereOr(t *testing.T) { + tests := []struct { + name string + sqlQuery string + condition string + shouldContain []string + }{ + { + name: "Add WHERE with OR to query without WHERE", + sqlQuery: "SELECT * FROM users", + condition: "status = 'inactive'", + shouldContain: []string{"WHERE", "status = 'inactive'"}, + }, + { + name: "Add OR to query with existing WHERE", + sqlQuery: 
"SELECT * FROM users WHERE id > 0", + condition: "status = 'inactive'", + shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sqlQryWhereOr(tt.sqlQuery, tt.condition) + + for _, expected := range tt.shouldContain { + if !strings.Contains(result, expected) { + t.Errorf("Expected SQL to contain %q, got: %s", expected, result) + } + } + }) + } +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 93b421e..3cd7b89 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -8,7 +8,9 @@ import ( "reflect" "runtime/debug" "strings" + "time" + "github.com/bitechdev/ResolveSpec/pkg/cache" "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/reflection" @@ -233,13 +235,46 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } // Get total count before pagination - total, err := query.Count(ctx) - if err != nil { - logger.Error("Error counting records: %v", err) - h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) - return + var total int + + // Try to get from cache first + cacheKeyHash := cache.BuildQueryCacheKey( + tableName, + options.Filters, + options.Sort, + "", // No custom SQL WHERE in resolvespec + "", // No custom SQL OR in resolvespec + ) + cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash) + + // Try to retrieve from cache + var cachedTotal cache.CachedTotal + err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal) + if err == nil { + total = cachedTotal.Total + logger.Debug("Total records (from cache): %d", total) + } else { + // Cache miss - execute count query + logger.Debug("Cache miss for query total") + count, err := query.Count(ctx) + if err != nil { + logger.Error("Error counting records: %v", err) + h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) + return + } + total = count + logger.Debug("Total records (from query): %d", total) + + // Store in cache + cacheTTL := time.Minute * 2 // Default 2 minutes TTL + cacheData := cache.CachedTotal{Total: total} + if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil { + logger.Warn("Failed to cache query total: %v", err) + // Don't fail the request if caching fails + } else { + logger.Debug("Cached query total with key: %s", cacheKey) + } } - logger.Debug("Total records before filtering: %d", total) // Apply pagination if options.Limit != nil && *options.Limit > 0 { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 515c2e6..a58dcc5 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -9,7 +9,9 @@ import ( "runtime/debug" "strconv" "strings" + "time" + "github.com/bitechdev/ResolveSpec/pkg/cache" "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/reflection" @@ -436,14 +438,69 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Get total count before pagination (unless skip count is requested) var total int if !options.SkipCount { - count, err := query.Count(ctx) - if err != nil { - logger.Error("Error counting records: %v", err) - h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) - return + // Try to get from cache first (unless SkipCache is true) + var 
cachedTotal *cache.CachedTotal + var cacheKey string + + if !options.SkipCache { + // Build cache key from query parameters + // Convert expand options to interface slice for the cache key builder + expandOpts := make([]interface{}, len(options.Expand)) + for i, exp := range options.Expand { + expandOpts[i] = map[string]interface{}{ + "relation": exp.Relation, + "where": exp.Where, + } + } + + cacheKeyHash := cache.BuildExtendedQueryCacheKey( + tableName, + options.Filters, + options.Sort, + options.CustomSQLWhere, + options.CustomSQLOr, + expandOpts, + options.Distinct, + options.CursorForward, + options.CursorBackward, + ) + cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash) + + // Try to retrieve from cache + cachedTotal = &cache.CachedTotal{} + err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal) + if err == nil { + total = cachedTotal.Total + logger.Debug("Total records (from cache): %d", total) + } else { + logger.Debug("Cache miss for query total") + cachedTotal = nil + } + } + + // If not in cache or cache skip, execute count query + if cachedTotal == nil { + count, err := query.Count(ctx) + if err != nil { + logger.Error("Error counting records: %v", err) + h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) + return + } + total = count + logger.Debug("Total records (from query): %d", total) + + // Store in cache (if caching is enabled) + if !options.SkipCache && cacheKey != "" { + cacheTTL := time.Minute * 2 // Default 2 minutes TTL + cacheData := &cache.CachedTotal{Total: total} + if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil { + logger.Warn("Failed to cache query total: %v", err) + // Don't fail the request if caching fails + } else { + logger.Debug("Cached query total with key: %s", cacheKey) + } + } } - total = count - logger.Debug("Total records: %d", total) } else { logger.Debug("Skipping count as requested") total = -1 // Indicate count was skipped
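
Review note: both handlers now repeat the same count cache-aside sequence (build key, try `Get`, fall back to `query.Count`, then `Set` with a short TTL). A possible follow-up, not part of this diff, is to pull that sequence into one helper. The sketch below is only an illustration: the helper name `cachedCount` and its placement are hypothetical, and it relies solely on calls already used above (`cache.GetDefaultCache().Get`/`Set`, `cache.CachedTotal`, `logger.Warn`).

```go
// Hypothetical helper sketching the cache-aside pattern used for query totals
// in this diff: try the cache first, fall back to the supplied count function,
// then store the fresh value with the given TTL.
package restheadspec // assumed placement; adjust to wherever the handlers live

import (
	"context"
	"time"

	"github.com/bitechdev/ResolveSpec/pkg/cache"
	"github.com/bitechdev/ResolveSpec/pkg/logger"
)

func cachedCount(ctx context.Context, cacheKey string, ttl time.Duration, count func() (int, error)) (int, error) {
	// Cache hit: reuse the stored total.
	var cached cache.CachedTotal
	if err := cache.GetDefaultCache().Get(ctx, cacheKey, &cached); err == nil {
		return cached.Total, nil
	}

	// Cache miss: run the real count query.
	total, err := count()
	if err != nil {
		return 0, err
	}

	// Store the fresh total; a failed cache write is logged but never fails the request.
	if err := cache.GetDefaultCache().Set(ctx, cacheKey, cache.CachedTotal{Total: total}, ttl); err != nil {
		logger.Warn("Failed to cache query total: %v", err)
	}
	return total, nil
}
```

With something like that in place, the count block in each handleRead could reduce to a single call along the lines of `total, err = cachedCount(ctx, cacheKey, 2*time.Minute, func() (int, error) { return query.Count(ctx) })`, keeping the SkipCount/SkipCache branching in the callers.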