mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 14:26:22 +00:00
Added cache, funcspec and implemented total cache
This commit is contained in:
parent
6bbe0ec8b0
commit
1643a5e920
4
go.mod
4
go.mod
@ -19,7 +19,10 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
@ -30,6 +33,7 @@ require (
|
|||||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||||
|
github.com/redis/go-redis/v9 v9.17.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
|
|||||||
8
go.sum
8
go.sum
@ -1,6 +1,12 @@
|
|||||||
|
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
||||||
|
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||||
@ -31,6 +37,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||||
|
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
|
||||||
|
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
|||||||
340
pkg/cache/README.md
vendored
Normal file
340
pkg/cache/README.md
vendored
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
# Cache Package
|
||||||
|
|
||||||
|
A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends
|
||||||
|
- **Pluggable Architecture**: Easy to add custom cache providers
|
||||||
|
- **Type-Safe API**: Automatic JSON serialization/deserialization
|
||||||
|
- **TTL Support**: Configurable time-to-live for cache entries
|
||||||
|
- **Context-Aware**: All operations support Go contexts
|
||||||
|
- **Statistics**: Built-in cache statistics and monitoring
|
||||||
|
- **Pattern Deletion**: Delete keys by pattern (Redis)
|
||||||
|
- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get github.com/bitechdev/ResolveSpec/pkg/cache
|
||||||
|
```
|
||||||
|
|
||||||
|
For Redis support:
|
||||||
|
```bash
|
||||||
|
go get github.com/redis/go-redis/v9
|
||||||
|
```
|
||||||
|
|
||||||
|
For Memcache support:
|
||||||
|
```bash
|
||||||
|
go get github.com/bradfitz/gomemcache/memcache
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### In-Memory Cache
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Initialize with in-memory provider
|
||||||
|
cache.UseMemory(&cache.Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
MaxSize: 10000,
|
||||||
|
})
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
c := cache.GetDefaultCache()
|
||||||
|
|
||||||
|
// Store a value
|
||||||
|
type User struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
user := User{ID: 1, Name: "John"}
|
||||||
|
c.Set(ctx, "user:1", user, 10*time.Minute)
|
||||||
|
|
||||||
|
// Retrieve a value
|
||||||
|
var retrieved User
|
||||||
|
c.Get(ctx, "user:1", &retrieved)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Redis Cache
|
||||||
|
|
||||||
|
```go
|
||||||
|
cache.UseRedis(&cache.RedisConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 6379,
|
||||||
|
Password: "",
|
||||||
|
DB: 0,
|
||||||
|
Options: &cache.Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer cache.Close()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memcache
|
||||||
|
|
||||||
|
```go
|
||||||
|
cache.UseMemcache(&cache.MemcacheConfig{
|
||||||
|
Servers: []string{"localhost:11211"},
|
||||||
|
Options: &cache.Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer cache.Close()
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Core Methods
|
||||||
|
|
||||||
|
#### Set
|
||||||
|
```go
|
||||||
|
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
|
||||||
|
```
|
||||||
|
Stores a value in the cache with automatic JSON serialization.
|
||||||
|
|
||||||
|
#### Get
|
||||||
|
```go
|
||||||
|
Get(ctx context.Context, key string, dest interface{}) error
|
||||||
|
```
|
||||||
|
Retrieves and deserializes a value from the cache.
|
||||||
|
|
||||||
|
#### SetBytes / GetBytes
|
||||||
|
```go
|
||||||
|
SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||||
|
GetBytes(ctx context.Context, key string) ([]byte, error)
|
||||||
|
```
|
||||||
|
Store and retrieve raw bytes without serialization.
|
||||||
|
|
||||||
|
#### Delete
|
||||||
|
```go
|
||||||
|
Delete(ctx context.Context, key string) error
|
||||||
|
```
|
||||||
|
Removes a key from the cache.
|
||||||
|
|
||||||
|
#### DeleteByPattern
|
||||||
|
```go
|
||||||
|
DeleteByPattern(ctx context.Context, pattern string) error
|
||||||
|
```
|
||||||
|
Removes all keys matching a pattern (Redis only).
|
||||||
|
|
||||||
|
#### Clear
|
||||||
|
```go
|
||||||
|
Clear(ctx context.Context) error
|
||||||
|
```
|
||||||
|
Removes all items from the cache.
|
||||||
|
|
||||||
|
#### Exists
|
||||||
|
```go
|
||||||
|
Exists(ctx context.Context, key string) bool
|
||||||
|
```
|
||||||
|
Checks if a key exists in the cache.
|
||||||
|
|
||||||
|
#### GetOrSet
|
||||||
|
```go
|
||||||
|
GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration,
|
||||||
|
loader func() (interface{}, error)) error
|
||||||
|
```
|
||||||
|
Retrieves a value from cache, or loads and caches it if not found (lazy loading).
|
||||||
|
|
||||||
|
#### Stats
|
||||||
|
```go
|
||||||
|
Stats(ctx context.Context) (*CacheStats, error)
|
||||||
|
```
|
||||||
|
Returns cache statistics including hits, misses, and key counts.
|
||||||
|
|
||||||
|
## Provider Configuration
|
||||||
|
|
||||||
|
### In-Memory Options
|
||||||
|
|
||||||
|
```go
|
||||||
|
&cache.Options{
|
||||||
|
DefaultTTL: 5 * time.Minute, // Default expiration time
|
||||||
|
MaxSize: 10000, // Maximum number of items
|
||||||
|
EvictionPolicy: "LRU", // Eviction strategy (future)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Redis Configuration
|
||||||
|
|
||||||
|
```go
|
||||||
|
&cache.RedisConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 6379,
|
||||||
|
Password: "", // Optional authentication
|
||||||
|
DB: 0, // Database number
|
||||||
|
PoolSize: 10, // Connection pool size
|
||||||
|
Options: &cache.Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Memcache Configuration
|
||||||
|
|
||||||
|
```go
|
||||||
|
&cache.MemcacheConfig{
|
||||||
|
Servers: []string{"localhost:11211"},
|
||||||
|
MaxIdleConns: 2,
|
||||||
|
Timeout: 1 * time.Second,
|
||||||
|
Options: &cache.Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Custom Provider
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Create a custom provider instance
|
||||||
|
memProvider := cache.NewMemoryProvider(&cache.Options{
|
||||||
|
DefaultTTL: 10 * time.Minute,
|
||||||
|
MaxSize: 500,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Initialize with custom provider
|
||||||
|
cache.Initialize(memProvider)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lazy Loading Pattern
|
||||||
|
|
||||||
|
```go
|
||||||
|
var data ExpensiveData
|
||||||
|
err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) {
|
||||||
|
// This expensive operation only runs if key is not in cache
|
||||||
|
return computeExpensiveData(), nil
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Query API Cache
|
||||||
|
|
||||||
|
The package includes specialized functions for caching query results:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Cache a query result
|
||||||
|
api := "GetUsers"
|
||||||
|
query := "SELECT * FROM users WHERE active = true"
|
||||||
|
tablenames := "users"
|
||||||
|
total := int64(150)
|
||||||
|
|
||||||
|
cache.PutQueryAPICache(ctx, api, query, tablenames, total)
|
||||||
|
|
||||||
|
// Retrieve cached query
|
||||||
|
hash := cache.HashQueryAPICache(api, query)
|
||||||
|
cachedQuery, err := cache.FetchQueryAPICache(ctx, hash)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Provider Comparison
|
||||||
|
|
||||||
|
| Feature | In-Memory | Redis | Memcache |
|
||||||
|
|---------|-----------|-------|----------|
|
||||||
|
| Persistence | No | Yes | No |
|
||||||
|
| Distributed | No | Yes | Yes |
|
||||||
|
| Pattern Delete | No | Yes | No |
|
||||||
|
| Statistics | Full | Full | Limited |
|
||||||
|
| Atomic Operations | Yes | Yes | Yes |
|
||||||
|
| Max Item Size | Memory | 512MB | 1MB |
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use contexts**: Always pass context for cancellation and timeout control
|
||||||
|
2. **Set appropriate TTLs**: Balance between freshness and performance
|
||||||
|
3. **Handle errors**: Cache misses and errors should be handled gracefully
|
||||||
|
4. **Monitor statistics**: Use Stats() to monitor cache performance
|
||||||
|
5. **Clean up**: Always call Close() when shutting down
|
||||||
|
6. **Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field")
|
||||||
|
|
||||||
|
## Example: Complete Application
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserService struct {
|
||||||
|
cache *cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserService() *UserService {
|
||||||
|
// Initialize with Redis in production, memory for testing
|
||||||
|
cache.UseRedis(&cache.RedisConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 6379,
|
||||||
|
Options: &cache.Options{
|
||||||
|
DefaultTTL: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return &UserService{
|
||||||
|
cache: cache.GetDefaultCache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) {
|
||||||
|
var user User
|
||||||
|
cacheKey := fmt.Sprintf("user:%d", userID)
|
||||||
|
|
||||||
|
// Try to get from cache first
|
||||||
|
err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) {
|
||||||
|
// Load from database if not in cache
|
||||||
|
return s.loadUserFromDB(userID)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) InvalidateUser(ctx context.Context, userID int) error {
|
||||||
|
cacheKey := fmt.Sprintf("user:%d", userID)
|
||||||
|
return s.cache.Delete(ctx, cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
service := NewUserService()
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
user, err := service.GetUser(ctx, 123)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("User: %+v", user)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- **In-Memory**: Fastest but limited by RAM and not distributed
|
||||||
|
- **Redis**: Great for distributed systems, persistent, but network overhead
|
||||||
|
- **Memcache**: Good for distributed caching, simpler than Redis but less features
|
||||||
|
|
||||||
|
Choose based on your needs:
|
||||||
|
- Single instance? Use in-memory
|
||||||
|
- Need persistence or advanced features? Use Redis
|
||||||
|
- Simple distributed cache? Use Memcache
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
See repository license.
|
||||||
76
pkg/cache/cache.go
vendored
Normal file
76
pkg/cache/cache.go
vendored
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultCache *Cache
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initialize initializes the cache with a provider.
|
||||||
|
// If not called, the package will use an in-memory provider by default.
|
||||||
|
func Initialize(provider Provider) {
|
||||||
|
defaultCache = NewCache(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseMemory configures the cache to use in-memory storage.
|
||||||
|
func UseMemory(opts *Options) error {
|
||||||
|
provider := NewMemoryProvider(opts)
|
||||||
|
defaultCache = NewCache(provider)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseRedis configures the cache to use Redis storage.
|
||||||
|
func UseRedis(config *RedisConfig) error {
|
||||||
|
provider, err := NewRedisProvider(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize Redis provider: %w", err)
|
||||||
|
}
|
||||||
|
defaultCache = NewCache(provider)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseMemcache configures the cache to use Memcache storage.
|
||||||
|
func UseMemcache(config *MemcacheConfig) error {
|
||||||
|
provider, err := NewMemcacheProvider(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize Memcache provider: %w", err)
|
||||||
|
}
|
||||||
|
defaultCache = NewCache(provider)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultCache returns the default cache instance.
|
||||||
|
// Initializes with in-memory provider if not already initialized.
|
||||||
|
func GetDefaultCache() *Cache {
|
||||||
|
if defaultCache == nil {
|
||||||
|
UseMemory(&Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
MaxSize: 10000,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return defaultCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultCache sets a custom cache instance as the default cache.
|
||||||
|
// This is useful for testing or when you want to use a pre-configured cache instance.
|
||||||
|
func SetDefaultCache(cache *Cache) {
|
||||||
|
defaultCache = cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats returns cache statistics.
|
||||||
|
func GetStats(ctx context.Context) (*CacheStats, error) {
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
return cache.Stats(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the cache and releases resources.
|
||||||
|
func Close() error {
|
||||||
|
if defaultCache != nil {
|
||||||
|
return defaultCache.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
147
pkg/cache/cache_manager.go
vendored
Normal file
147
pkg/cache/cache_manager.go
vendored
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache is the main cache manager that wraps a Provider.
|
||||||
|
type Cache struct {
|
||||||
|
provider Provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCache creates a new cache manager with the specified provider.
|
||||||
|
func NewCache(provider Provider) *Cache {
|
||||||
|
return &Cache{
|
||||||
|
provider: provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves and deserializes a value from the cache.
|
||||||
|
func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||||
|
data, exists := c.provider.Get(ctx, key)
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("key not found: %s", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, dest); err != nil {
|
||||||
|
return fmt.Errorf("failed to deserialize: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBytes retrieves raw bytes from the cache.
|
||||||
|
func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
data, exists := c.provider.Get(ctx, key)
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("key not found: %s", key)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set serializes and stores a value in the cache with the specified TTL.
|
||||||
|
func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||||
|
data, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to serialize: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.provider.Set(ctx, key, data, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBytes stores raw bytes in the cache with the specified TTL.
|
||||||
|
func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||||
|
return c.provider.Set(ctx, key, value, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a key from the cache.
|
||||||
|
func (c *Cache) Delete(ctx context.Context, key string) error {
|
||||||
|
return c.provider.Delete(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
|
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
|
return c.provider.DeleteByPattern(ctx, pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all items from the cache.
|
||||||
|
func (c *Cache) Clear(ctx context.Context) error {
|
||||||
|
return c.provider.Clear(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists checks if a key exists in the cache.
|
||||||
|
func (c *Cache) Exists(ctx context.Context, key string) bool {
|
||||||
|
return c.provider.Exists(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics about the cache.
|
||||||
|
func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) {
|
||||||
|
return c.provider.Stats(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the cache and releases any resources.
|
||||||
|
func (c *Cache) Close() error {
|
||||||
|
return c.provider.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrSet retrieves a value from cache, or sets it if it doesn't exist.
|
||||||
|
// The loader function is called only if the key is not found in cache.
|
||||||
|
func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error {
|
||||||
|
// Try to get from cache first
|
||||||
|
err := c.Get(ctx, key, dest)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the value
|
||||||
|
value, err := loader()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("loader failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store in cache
|
||||||
|
if err := c.Set(ctx, key, value, ttl); err != nil {
|
||||||
|
return fmt.Errorf("failed to cache value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate dest with the loaded value
|
||||||
|
data, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to serialize loaded value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, dest); err != nil {
|
||||||
|
return fmt.Errorf("failed to deserialize loaded value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remember is a convenience function that caches the result of a function call.
|
||||||
|
// It's similar to GetOrSet but returns the value directly.
|
||||||
|
func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) {
|
||||||
|
// Try to get from cache first as bytes
|
||||||
|
data, err := c.GetBytes(ctx, key)
|
||||||
|
if err == nil {
|
||||||
|
var result interface{}
|
||||||
|
if err := json.Unmarshal(data, &result); err == nil {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the value
|
||||||
|
value, err := loader()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loader failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store in cache
|
||||||
|
if err := c.Set(ctx, key, value, ttl); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to cache value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
69
pkg/cache/cache_test.go
vendored
Normal file
69
pkg/cache/cache_test.go
vendored
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetDefaultCache(t *testing.T) {
|
||||||
|
// Create a custom cache instance
|
||||||
|
provider := NewMemoryProvider(&Options{
|
||||||
|
DefaultTTL: 1 * time.Minute,
|
||||||
|
MaxSize: 50,
|
||||||
|
})
|
||||||
|
customCache := NewCache(provider)
|
||||||
|
|
||||||
|
// Set it as the default
|
||||||
|
SetDefaultCache(customCache)
|
||||||
|
|
||||||
|
// Verify it's now the default
|
||||||
|
retrievedCache := GetDefaultCache()
|
||||||
|
if retrievedCache != customCache {
|
||||||
|
t.Error("SetDefaultCache did not set the cache correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that we can use it
|
||||||
|
ctx := context.Background()
|
||||||
|
testKey := "test_key"
|
||||||
|
testValue := "test_value"
|
||||||
|
|
||||||
|
err := retrievedCache.Set(ctx, testKey, testValue, time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to set value: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result string
|
||||||
|
err = retrievedCache.Get(ctx, testKey, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get value: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != testValue {
|
||||||
|
t.Errorf("Expected %s, got %s", testValue, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up - reset to default
|
||||||
|
SetDefaultCache(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDefaultCacheInitialization(t *testing.T) {
|
||||||
|
// Reset to nil first
|
||||||
|
SetDefaultCache(nil)
|
||||||
|
|
||||||
|
// GetDefaultCache should auto-initialize
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
if cache == nil {
|
||||||
|
t.Error("GetDefaultCache should auto-initialize, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be usable
|
||||||
|
ctx := context.Background()
|
||||||
|
err := cache.Set(ctx, "test", "value", time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to use auto-initialized cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
SetDefaultCache(nil)
|
||||||
|
}
|
||||||
252
pkg/cache/example_usage.go
vendored
Normal file
252
pkg/cache/example_usage.go
vendored
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExampleInMemoryCache demonstrates using the in-memory cache provider.
|
||||||
|
func ExampleInMemoryCache() {
|
||||||
|
// Initialize with in-memory provider
|
||||||
|
err := UseMemory(&Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
MaxSize: 1000,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get the cache instance
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
|
||||||
|
// Store a value
|
||||||
|
type User struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
user := User{ID: 1, Name: "John Doe"}
|
||||||
|
err = cache.Set(ctx, "user:1", user, 10*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve a value
|
||||||
|
var retrieved User
|
||||||
|
err = cache.Get(ctx, "user:1", &retrieved)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Retrieved user: %+v\n", retrieved)
|
||||||
|
|
||||||
|
// Check if key exists
|
||||||
|
exists := cache.Exists(ctx, "user:1")
|
||||||
|
fmt.Printf("Key exists: %v\n", exists)
|
||||||
|
|
||||||
|
// Delete a key
|
||||||
|
err = cache.Delete(ctx, "user:1")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get statistics
|
||||||
|
stats, err := cache.Stats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Cache stats: %+v\n", stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleRedisCache demonstrates using the Redis cache provider.
|
||||||
|
func ExampleRedisCache() {
|
||||||
|
// Initialize with Redis provider
|
||||||
|
err := UseRedis(&RedisConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 6379,
|
||||||
|
Password: "", // Set if Redis requires authentication
|
||||||
|
DB: 0,
|
||||||
|
Options: &Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get the cache instance
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
|
||||||
|
// Store raw bytes
|
||||||
|
data := []byte("Hello, Redis!")
|
||||||
|
err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve raw bytes
|
||||||
|
retrieved, err := cache.GetBytes(ctx, "greeting")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Retrieved data: %s\n", string(retrieved))
|
||||||
|
|
||||||
|
// Clear all cache
|
||||||
|
err = cache.Clear(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleMemcacheCache demonstrates using the Memcache cache provider.
|
||||||
|
func ExampleMemcacheCache() {
|
||||||
|
// Initialize with Memcache provider
|
||||||
|
err := UseMemcache(&MemcacheConfig{
|
||||||
|
Servers: []string{"localhost:11211"},
|
||||||
|
Options: &Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get the cache instance
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
|
||||||
|
// Store a value
|
||||||
|
type Product struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
Price float64
|
||||||
|
}
|
||||||
|
|
||||||
|
product := Product{ID: 100, Name: "Widget", Price: 29.99}
|
||||||
|
err = cache.Set(ctx, "product:100", product, 30*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve a value
|
||||||
|
var retrieved Product
|
||||||
|
err = cache.Get(ctx, "product:100", &retrieved)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Retrieved product: %+v\n", retrieved)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading.
|
||||||
|
func ExampleGetOrSet() {
|
||||||
|
err := UseMemory(&Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
MaxSize: 1000,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
|
||||||
|
type ExpensiveData struct {
|
||||||
|
Result string
|
||||||
|
}
|
||||||
|
|
||||||
|
var data ExpensiveData
|
||||||
|
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
|
||||||
|
// This expensive operation only runs if the key is not in cache
|
||||||
|
fmt.Println("Computing expensive result...")
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
return ExpensiveData{Result: "computed value"}, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Data: %+v\n", data)
|
||||||
|
|
||||||
|
// Second call will use cached value
|
||||||
|
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
|
||||||
|
fmt.Println("This won't be called!")
|
||||||
|
return ExpensiveData{Result: "new value"}, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Cached data: %+v\n", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleCustomProvider demonstrates using a custom provider.
|
||||||
|
func ExampleCustomProvider() {
|
||||||
|
// Create a custom provider
|
||||||
|
memProvider := NewMemoryProvider(&Options{
|
||||||
|
DefaultTTL: 10 * time.Minute,
|
||||||
|
MaxSize: 500,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Initialize with custom provider
|
||||||
|
Initialize(memProvider)
|
||||||
|
defer Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
|
||||||
|
// Use the cache
|
||||||
|
err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean expired items (memory provider specific)
|
||||||
|
if mp, ok := cache.provider.(*MemoryProvider); ok {
|
||||||
|
count := mp.CleanExpired(ctx)
|
||||||
|
fmt.Printf("Cleaned %d expired items\n", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only).
|
||||||
|
func ExampleDeleteByPattern() {
|
||||||
|
err := UseRedis(&RedisConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 6379,
|
||||||
|
Options: &Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
|
||||||
|
// Store multiple keys with a pattern
|
||||||
|
cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute)
|
||||||
|
cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute)
|
||||||
|
cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute)
|
||||||
|
|
||||||
|
// Delete all keys matching pattern (Redis glob pattern)
|
||||||
|
err = cache.DeleteByPattern(ctx, "user:*:profile")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Deleted all user profile keys")
|
||||||
|
}
|
||||||
57
pkg/cache/provider.go
vendored
Normal file
57
pkg/cache/provider.go
vendored
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider defines the interface that all cache providers must implement.
|
||||||
|
type Provider interface {
|
||||||
|
// Get retrieves a value from the cache by key.
|
||||||
|
// Returns nil, false if key doesn't exist or is expired.
|
||||||
|
Get(ctx context.Context, key string) ([]byte, bool)
|
||||||
|
|
||||||
|
// Set stores a value in the cache with the specified TTL.
|
||||||
|
// If ttl is 0, the item never expires.
|
||||||
|
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||||
|
|
||||||
|
// Delete removes a key from the cache.
|
||||||
|
Delete(ctx context.Context, key string) error
|
||||||
|
|
||||||
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
|
// Pattern syntax depends on the provider implementation.
|
||||||
|
DeleteByPattern(ctx context.Context, pattern string) error
|
||||||
|
|
||||||
|
// Clear removes all items from the cache.
|
||||||
|
Clear(ctx context.Context) error
|
||||||
|
|
||||||
|
// Exists checks if a key exists in the cache.
|
||||||
|
Exists(ctx context.Context, key string) bool
|
||||||
|
|
||||||
|
// Close closes the provider and releases any resources.
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// Stats returns statistics about the cache provider.
|
||||||
|
Stats(ctx context.Context) (*CacheStats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheStats contains cache statistics.
|
||||||
|
type CacheStats struct {
|
||||||
|
Hits int64 `json:"hits"`
|
||||||
|
Misses int64 `json:"misses"`
|
||||||
|
Keys int64 `json:"keys"`
|
||||||
|
ProviderType string `json:"provider_type"`
|
||||||
|
ProviderStats map[string]any `json:"provider_stats,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options contains configuration options for cache providers.
|
||||||
|
type Options struct {
|
||||||
|
// DefaultTTL is the default time-to-live for cache items.
|
||||||
|
DefaultTTL time.Duration
|
||||||
|
|
||||||
|
// MaxSize is the maximum number of items (for in-memory provider).
|
||||||
|
MaxSize int
|
||||||
|
|
||||||
|
// EvictionPolicy determines how items are evicted (LRU, LFU, etc).
|
||||||
|
EvictionPolicy string
|
||||||
|
}
|
||||||
144
pkg/cache/provider_memcache.go
vendored
Normal file
144
pkg/cache/provider_memcache.go
vendored
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bradfitz/gomemcache/memcache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemcacheProvider is a Memcache implementation of the Provider interface.
|
||||||
|
type MemcacheProvider struct {
|
||||||
|
client *memcache.Client
|
||||||
|
options *Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemcacheConfig contains Memcache-specific configuration.
|
||||||
|
type MemcacheConfig struct {
|
||||||
|
// Servers is a list of memcache server addresses (e.g., "localhost:11211")
|
||||||
|
Servers []string
|
||||||
|
|
||||||
|
// MaxIdleConns is the maximum number of idle connections (default: 2)
|
||||||
|
MaxIdleConns int
|
||||||
|
|
||||||
|
// Timeout for connection operations (default: 1 second)
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
|
// Options contains general cache options
|
||||||
|
Options *Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemcacheProvider creates a new Memcache cache provider.
|
||||||
|
func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = &MemcacheConfig{
|
||||||
|
Servers: []string{"localhost:11211"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Servers) == 0 {
|
||||||
|
config.Servers = []string{"localhost:11211"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.MaxIdleConns == 0 {
|
||||||
|
config.MaxIdleConns = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Timeout == 0 {
|
||||||
|
config.Timeout = 1 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Options == nil {
|
||||||
|
config.Options = &Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := memcache.New(config.Servers...)
|
||||||
|
client.MaxIdleConns = config.MaxIdleConns
|
||||||
|
client.Timeout = config.Timeout
|
||||||
|
|
||||||
|
// Test connection
|
||||||
|
if err := client.Ping(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to Memcache: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MemcacheProvider{
|
||||||
|
client: client,
|
||||||
|
options: config.Options,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value from the cache by key.
|
||||||
|
func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) {
|
||||||
|
item, err := m.client.Get(key)
|
||||||
|
if err == memcache.ErrCacheMiss {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return item.Value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores a value in the cache with the specified TTL.
|
||||||
|
func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = m.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
item := &memcache.Item{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
Expiration: int32(ttl.Seconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.client.Set(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a key from the cache.
|
||||||
|
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||||
|
err := m.client.Delete(key)
|
||||||
|
if err == memcache.ErrCacheMiss {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
|
// Note: Memcache does not support pattern-based deletion natively.
|
||||||
|
// This is a no-op for memcache and returns an error.
|
||||||
|
func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
|
return fmt.Errorf("pattern-based deletion is not supported by Memcache")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all items from the cache.
|
||||||
|
func (m *MemcacheProvider) Clear(ctx context.Context) error {
|
||||||
|
return m.client.FlushAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists checks if a key exists in the cache.
|
||||||
|
func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool {
|
||||||
|
_, err := m.client.Get(key)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the provider and releases any resources.
|
||||||
|
func (m *MemcacheProvider) Close() error {
|
||||||
|
// Memcache client doesn't have a close method
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics about the cache provider.
|
||||||
|
// Note: Memcache provider returns limited statistics.
|
||||||
|
func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) {
|
||||||
|
stats := &CacheStats{
|
||||||
|
ProviderType: "memcache",
|
||||||
|
ProviderStats: map[string]any{
|
||||||
|
"note": "Memcache does not provide detailed statistics through the standard client",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
226
pkg/cache/provider_memory.go
vendored
Normal file
226
pkg/cache/provider_memory.go
vendored
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// memoryItem represents a cached item in memory.
|
||||||
|
type memoryItem struct {
|
||||||
|
Value []byte
|
||||||
|
Expiration time.Time
|
||||||
|
LastAccess time.Time
|
||||||
|
HitCount int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// isExpired checks if the item has expired.
|
||||||
|
func (m *memoryItem) isExpired() bool {
|
||||||
|
if m.Expiration.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().After(m.Expiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryProvider is an in-memory implementation of the Provider interface.
|
||||||
|
type MemoryProvider struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
items map[string]*memoryItem
|
||||||
|
options *Options
|
||||||
|
hits int64
|
||||||
|
misses int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryProvider creates a new in-memory cache provider.
|
||||||
|
func NewMemoryProvider(opts *Options) *MemoryProvider {
|
||||||
|
if opts == nil {
|
||||||
|
opts = &Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
MaxSize: 10000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MemoryProvider{
|
||||||
|
items: make(map[string]*memoryItem),
|
||||||
|
options: opts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value from the cache by key.
|
||||||
|
func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
item, exists := m.items[key]
|
||||||
|
if !exists {
|
||||||
|
m.misses++
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.isExpired() {
|
||||||
|
delete(m.items, key)
|
||||||
|
m.misses++
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
item.LastAccess = time.Now()
|
||||||
|
item.HitCount++
|
||||||
|
m.hits++
|
||||||
|
|
||||||
|
return item.Value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores a value in the cache with the specified TTL.
|
||||||
|
func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = m.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiration time.Time
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration = time.Now().Add(ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check max size and evict if necessary
|
||||||
|
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
|
||||||
|
if _, exists := m.items[key]; !exists {
|
||||||
|
m.evictOne()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.items[key] = &memoryItem{
|
||||||
|
Value: value,
|
||||||
|
Expiration: expiration,
|
||||||
|
LastAccess: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a key from the cache.
|
||||||
|
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
delete(m.items, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
|
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid pattern: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range m.items {
|
||||||
|
if re.MatchString(key) {
|
||||||
|
delete(m.items, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all items from the cache.
|
||||||
|
func (m *MemoryProvider) Clear(ctx context.Context) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.items = make(map[string]*memoryItem)
|
||||||
|
m.hits = 0
|
||||||
|
m.misses = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists checks if a key exists in the cache.
|
||||||
|
func (m *MemoryProvider) Exists(ctx context.Context, key string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
item, exists := m.items[key]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return !item.isExpired()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the provider and releases any resources.
|
||||||
|
func (m *MemoryProvider) Close() error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.items = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics about the cache provider.
|
||||||
|
func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
// Clean expired items first
|
||||||
|
validKeys := 0
|
||||||
|
for _, item := range m.items {
|
||||||
|
if !item.isExpired() {
|
||||||
|
validKeys++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CacheStats{
|
||||||
|
Hits: m.hits,
|
||||||
|
Misses: m.misses,
|
||||||
|
Keys: int64(validKeys),
|
||||||
|
ProviderType: "memory",
|
||||||
|
ProviderStats: map[string]any{
|
||||||
|
"capacity": m.options.MaxSize,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictOne removes one item from the cache using LRU strategy.
|
||||||
|
func (m *MemoryProvider) evictOne() {
|
||||||
|
var oldestKey string
|
||||||
|
var oldestTime time.Time
|
||||||
|
|
||||||
|
for key, item := range m.items {
|
||||||
|
if item.isExpired() {
|
||||||
|
delete(m.items, key)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldestKey == "" || item.LastAccess.Before(oldestTime) {
|
||||||
|
oldestKey = key
|
||||||
|
oldestTime = item.LastAccess
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldestKey != "" {
|
||||||
|
delete(m.items, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanExpired removes all expired items from the cache.
|
||||||
|
func (m *MemoryProvider) CleanExpired(ctx context.Context) int {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for key, item := range m.items {
|
||||||
|
if item.isExpired() {
|
||||||
|
delete(m.items, key)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
185
pkg/cache/provider_redis.go
vendored
Normal file
185
pkg/cache/provider_redis.go
vendored
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RedisProvider is a Redis implementation of the Provider interface.
|
||||||
|
type RedisProvider struct {
|
||||||
|
client *redis.Client
|
||||||
|
options *Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedisConfig contains Redis-specific configuration.
|
||||||
|
type RedisConfig struct {
|
||||||
|
// Host is the Redis server host (default: localhost)
|
||||||
|
Host string
|
||||||
|
|
||||||
|
// Port is the Redis server port (default: 6379)
|
||||||
|
Port int
|
||||||
|
|
||||||
|
// Password for Redis authentication (optional)
|
||||||
|
Password string
|
||||||
|
|
||||||
|
// DB is the Redis database number (default: 0)
|
||||||
|
DB int
|
||||||
|
|
||||||
|
// PoolSize is the maximum number of connections (default: 10)
|
||||||
|
PoolSize int
|
||||||
|
|
||||||
|
// Options contains general cache options
|
||||||
|
Options *Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRedisProvider creates a new Redis cache provider.
|
||||||
|
func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = &RedisConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 6379,
|
||||||
|
DB: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Host == "" {
|
||||||
|
config.Host = "localhost"
|
||||||
|
}
|
||||||
|
if config.Port == 0 {
|
||||||
|
config.Port = 6379
|
||||||
|
}
|
||||||
|
if config.PoolSize == 0 {
|
||||||
|
config.PoolSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Options == nil {
|
||||||
|
config.Options = &Options{
|
||||||
|
DefaultTTL: 5 * time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := redis.NewClient(&redis.Options{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
|
||||||
|
Password: config.Password,
|
||||||
|
DB: config.DB,
|
||||||
|
PoolSize: config.PoolSize,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test connection
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := client.Ping(ctx).Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RedisProvider{
|
||||||
|
client: client,
|
||||||
|
options: config.Options,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value from the cache by key.
|
||||||
|
func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) {
|
||||||
|
val, err := r.client.Get(ctx, key).Bytes()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores a value in the cache with the specified TTL.
|
||||||
|
func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = r.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.client.Set(ctx, key, value, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a key from the cache.
|
||||||
|
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
||||||
|
return r.client.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
|
func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
|
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
|
||||||
|
pipe := r.client.Pipeline()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for iter.Next(ctx) {
|
||||||
|
pipe.Del(ctx, iter.Val())
|
||||||
|
count++
|
||||||
|
|
||||||
|
// Execute pipeline in batches of 100
|
||||||
|
if count%100 == 0 {
|
||||||
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pipe = r.client.Pipeline()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := iter.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute remaining commands
|
||||||
|
if count%100 != 0 {
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all items from the cache.
|
||||||
|
func (r *RedisProvider) Clear(ctx context.Context) error {
|
||||||
|
return r.client.FlushDB(ctx).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists checks if a key exists in the cache.
|
||||||
|
func (r *RedisProvider) Exists(ctx context.Context, key string) bool {
|
||||||
|
result, err := r.client.Exists(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return result > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the provider and releases any resources.
|
||||||
|
func (r *RedisProvider) Close() error {
|
||||||
|
return r.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics about the cache provider.
|
||||||
|
func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) {
|
||||||
|
info, err := r.client.Info(ctx, "stats", "keyspace").Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get Redis stats: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbSize, err := r.client.DBSize(ctx).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get DB size: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse stats from INFO command
|
||||||
|
// This is a simplified version - you may want to parse more detailed stats
|
||||||
|
stats := &CacheStats{
|
||||||
|
Keys: dbSize,
|
||||||
|
ProviderType: "redis",
|
||||||
|
ProviderStats: map[string]any{
|
||||||
|
"info": info,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
127
pkg/cache/query_cache.go
vendored
Normal file
127
pkg/cache/query_cache.go
vendored
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QueryCacheKey represents the components used to build a cache key for query total count
|
||||||
|
type QueryCacheKey struct {
|
||||||
|
TableName string `json:"table_name"`
|
||||||
|
Filters []common.FilterOption `json:"filters"`
|
||||||
|
Sort []common.SortOption `json:"sort"`
|
||||||
|
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||||
|
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||||
|
Expand []ExpandOptionKey `json:"expand,omitempty"`
|
||||||
|
Distinct bool `json:"distinct,omitempty"`
|
||||||
|
CursorForward string `json:"cursor_forward,omitempty"`
|
||||||
|
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpandOptionKey represents expand options for cache key
|
||||||
|
type ExpandOptionKey struct {
|
||||||
|
Relation string `json:"relation"`
|
||||||
|
Where string `json:"where,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
|
||||||
|
// This is used to cache the total count of records matching a query
|
||||||
|
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
||||||
|
key := QueryCacheKey{
|
||||||
|
TableName: tableName,
|
||||||
|
Filters: filters,
|
||||||
|
Sort: sort,
|
||||||
|
CustomSQLWhere: customWhere,
|
||||||
|
CustomSQLOr: customOr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize to JSON for consistent hashing
|
||||||
|
jsonData, err := json.Marshal(key)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to simple string concatenation if JSON fails
|
||||||
|
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashString(string(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||||
|
// Includes expand, distinct, and cursor pagination options
|
||||||
|
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||||
|
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||||
|
|
||||||
|
key := QueryCacheKey{
|
||||||
|
TableName: tableName,
|
||||||
|
Filters: filters,
|
||||||
|
Sort: sort,
|
||||||
|
CustomSQLWhere: customWhere,
|
||||||
|
CustomSQLOr: customOr,
|
||||||
|
Distinct: distinct,
|
||||||
|
CursorForward: cursorFwd,
|
||||||
|
CursorBackward: cursorBwd,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert expand options to cache key format
|
||||||
|
if len(expandOpts) > 0 {
|
||||||
|
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
|
||||||
|
for _, exp := range expandOpts {
|
||||||
|
// Type assert to get the expand option fields we care about for caching
|
||||||
|
if expMap, ok := exp.(map[string]interface{}); ok {
|
||||||
|
expKey := ExpandOptionKey{}
|
||||||
|
if rel, ok := expMap["relation"].(string); ok {
|
||||||
|
expKey.Relation = rel
|
||||||
|
}
|
||||||
|
if where, ok := expMap["where"].(string); ok {
|
||||||
|
expKey.Where = where
|
||||||
|
}
|
||||||
|
key.Expand = append(key.Expand, expKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Sort expand options for consistent hashing (already sorted by relation name above)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize to JSON for consistent hashing
|
||||||
|
jsonData, err := json.Marshal(key)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to simple string concatenation if JSON fails
|
||||||
|
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
|
||||||
|
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashString(string(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashString computes SHA256 hash of a string
|
||||||
|
func hashString(s string) string {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(s))
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||||
|
func GetQueryTotalCacheKey(hash string) string {
|
||||||
|
return fmt.Sprintf("query_total:%s", hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedTotal represents a cached total count
|
||||||
|
type CachedTotal struct {
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateCacheForTable removes all cached totals for a specific table
|
||||||
|
// This should be called when data in the table changes (insert/update/delete)
|
||||||
|
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
|
||||||
|
cache := GetDefaultCache()
|
||||||
|
|
||||||
|
// Build a pattern to match all query totals for this table
|
||||||
|
// Note: This requires pattern matching support in the provider
|
||||||
|
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
|
||||||
|
|
||||||
|
return cache.DeleteByPattern(ctx, pattern)
|
||||||
|
}
|
||||||
151
pkg/cache/query_cache_test.go
vendored
Normal file
151
pkg/cache/query_cache_test.go
vendored
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildQueryCacheKey(t *testing.T) {
|
||||||
|
filters := []common.FilterOption{
|
||||||
|
{Column: "name", Operator: "eq", Value: "test"},
|
||||||
|
{Column: "age", Operator: "gt", Value: 25},
|
||||||
|
}
|
||||||
|
sorts := []common.SortOption{
|
||||||
|
{Column: "name", Direction: "asc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate cache key
|
||||||
|
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||||
|
|
||||||
|
// Same parameters should generate same key
|
||||||
|
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||||
|
|
||||||
|
if key1 != key2 {
|
||||||
|
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different parameters should generate different key
|
||||||
|
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
|
||||||
|
|
||||||
|
if key1 == key3 {
|
||||||
|
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildExtendedQueryCacheKey(t *testing.T) {
|
||||||
|
filters := []common.FilterOption{
|
||||||
|
{Column: "name", Operator: "eq", Value: "test"},
|
||||||
|
}
|
||||||
|
sorts := []common.SortOption{
|
||||||
|
{Column: "name", Direction: "asc"},
|
||||||
|
}
|
||||||
|
expandOpts := []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"relation": "posts",
|
||||||
|
"where": "status = 'published'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate cache key
|
||||||
|
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||||
|
|
||||||
|
// Same parameters should generate same key
|
||||||
|
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||||
|
|
||||||
|
if key1 != key2 {
|
||||||
|
t.Errorf("Expected same cache keys for identical parameters")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different distinct value should generate different key
|
||||||
|
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
|
||||||
|
|
||||||
|
if key1 == key3 {
|
||||||
|
t.Errorf("Expected different cache keys for different distinct values")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetQueryTotalCacheKey(t *testing.T) {
|
||||||
|
hash := "abc123"
|
||||||
|
key := GetQueryTotalCacheKey(hash)
|
||||||
|
|
||||||
|
expected := "query_total:abc123"
|
||||||
|
if key != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedTotalIntegration(t *testing.T) {
|
||||||
|
// Initialize cache with memory provider for testing
|
||||||
|
UseMemory(&Options{
|
||||||
|
DefaultTTL: 1 * time.Minute,
|
||||||
|
MaxSize: 100,
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
filters := []common.FilterOption{
|
||||||
|
{Column: "status", Operator: "eq", Value: "active"},
|
||||||
|
}
|
||||||
|
sorts := []common.SortOption{
|
||||||
|
{Column: "created_at", Direction: "desc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build cache key
|
||||||
|
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
|
||||||
|
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
|
||||||
|
|
||||||
|
// Store a total count in cache
|
||||||
|
totalToCache := CachedTotal{Total: 42}
|
||||||
|
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to set cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve from cache
|
||||||
|
var cachedTotal CachedTotal
|
||||||
|
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get from cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cachedTotal.Total != 42 {
|
||||||
|
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cache miss
|
||||||
|
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
|
||||||
|
var missedTotal CachedTotal
|
||||||
|
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error for cache miss, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashString(t *testing.T) {
|
||||||
|
input1 := "test string"
|
||||||
|
input2 := "test string"
|
||||||
|
input3 := "different string"
|
||||||
|
|
||||||
|
hash1 := hashString(input1)
|
||||||
|
hash2 := hashString(input2)
|
||||||
|
hash3 := hashString(input3)
|
||||||
|
|
||||||
|
// Same input should produce same hash
|
||||||
|
if hash1 != hash2 {
|
||||||
|
t.Errorf("Expected same hash for identical inputs")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different input should produce different hash
|
||||||
|
if hash1 == hash3 {
|
||||||
|
t.Errorf("Expected different hash for different inputs")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash should be hex encoded SHA256 (64 characters)
|
||||||
|
if len(hash1) != 64 {
|
||||||
|
t.Errorf("Expected hash length of 64, got %d", len(hash1))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -20,16 +20,24 @@ import (
|
|||||||
|
|
||||||
// Handler handles function-based SQL API requests
|
// Handler handles function-based SQL API requests
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
db common.Database
|
db common.Database
|
||||||
|
hooks *HookRegistry
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new function API handler
|
// NewHandler creates a new function API handler
|
||||||
func NewHandler(db common.Database) *Handler {
|
func NewHandler(db common.Database) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
db: db,
|
db: db,
|
||||||
|
hooks: NewHookRegistry(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hooks returns the hook registry for this handler
|
||||||
|
// Use this to register custom hooks for operations
|
||||||
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
|
return h.hooks
|
||||||
|
}
|
||||||
|
|
||||||
// HTTPFuncType is a function type for HTTP handlers
|
// HTTPFuncType is a function type for HTTP handlers
|
||||||
type HTTPFuncType func(http.ResponseWriter, *http.Request)
|
type HTTPFuncType func(http.ResponseWriter, *http.Request)
|
||||||
|
|
||||||
@ -64,18 +72,77 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Initialize hook context
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Request: r,
|
||||||
|
Writer: w,
|
||||||
|
SQLQuery: sqlquery,
|
||||||
|
Variables: variables,
|
||||||
|
InputVars: inputvars,
|
||||||
|
MetaInfo: metainfo,
|
||||||
|
PropQry: propQry,
|
||||||
|
UserContext: userCtx,
|
||||||
|
NoCount: pNoCount,
|
||||||
|
BlankParams: pBlankparms,
|
||||||
|
AllowFilter: pAllowFilter,
|
||||||
|
ComplexAPI: complexAPI,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeQueryList hook
|
||||||
|
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeQueryList hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hook aborted the operation
|
||||||
|
if hookCtx.Abort {
|
||||||
|
if hookCtx.AbortCode == 0 {
|
||||||
|
hookCtx.AbortCode = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified SQL query and variables from hooks
|
||||||
|
sqlquery = hookCtx.SQLQuery
|
||||||
|
variables = hookCtx.Variables
|
||||||
|
complexAPI = hookCtx.ComplexAPI
|
||||||
|
|
||||||
// Extract input variables from SQL query (placeholders like [variable])
|
// Extract input variables from SQL query (placeholders like [variable])
|
||||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||||
|
|
||||||
// Merge URL path parameters
|
// Merge URL path parameters
|
||||||
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
||||||
|
|
||||||
|
// Parse comprehensive parameters from headers and query string
|
||||||
|
reqParams := h.ParseParameters(r)
|
||||||
|
complexAPI = reqParams.ComplexAPI
|
||||||
|
|
||||||
// Merge query string parameters
|
// Merge query string parameters
|
||||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
|
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
|
||||||
|
|
||||||
// Merge header parameters
|
// Merge header parameters
|
||||||
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
||||||
|
|
||||||
|
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
|
||||||
|
if !pAllowFilter {
|
||||||
|
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply field selection
|
||||||
|
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
|
||||||
|
|
||||||
|
// Apply DISTINCT if requested
|
||||||
|
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
|
||||||
|
|
||||||
|
// Override pNoCount if skipcount is specified
|
||||||
|
if reqParams.SkipCount {
|
||||||
|
pNoCount = true
|
||||||
|
}
|
||||||
|
|
||||||
// Build metainfo
|
// Build metainfo
|
||||||
metainfo["ipaddress"] = getIPAddress(r)
|
metainfo["ipaddress"] = getIPAddress(r)
|
||||||
metainfo["url"] = r.RequestURI
|
metainfo["url"] = r.RequestURI
|
||||||
@ -95,12 +162,32 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update hook context with latest SQL query and variables
|
||||||
|
hookCtx.SQLQuery = sqlquery
|
||||||
|
hookCtx.Variables = variables
|
||||||
|
hookCtx.InputVars = inputvars
|
||||||
|
|
||||||
// Execute query within transaction
|
// Execute query within transaction
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
sqlqueryCnt := sqlquery
|
sqlqueryCnt := sqlquery
|
||||||
|
|
||||||
// Parse sorting and pagination parameters
|
// Parse sorting and pagination parameters
|
||||||
sortcols, limit, offset := h.parsePaginationParams(r)
|
sortcols, limit, offset := h.parsePaginationParams(r)
|
||||||
|
|
||||||
|
// Override with parsed parameters if available
|
||||||
|
if reqParams.SortColumns != "" {
|
||||||
|
sortcols = reqParams.SortColumns
|
||||||
|
}
|
||||||
|
if reqParams.Limit > 0 {
|
||||||
|
limit = reqParams.Limit
|
||||||
|
}
|
||||||
|
if reqParams.Offset > 0 {
|
||||||
|
offset = reqParams.Offset
|
||||||
|
}
|
||||||
|
|
||||||
|
hookCtx.SortColumns = sortcols
|
||||||
|
hookCtx.Limit = limit
|
||||||
|
hookCtx.Offset = offset
|
||||||
fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
|
fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
|
||||||
orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
|
orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
|
||||||
|
|
||||||
@ -127,6 +214,16 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
total = countResult.Count
|
total = countResult.Count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute BeforeSQLExec hook
|
||||||
|
hookCtx.SQLQuery = sqlquery
|
||||||
|
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeSQLExec hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Use potentially modified SQL query from hook
|
||||||
|
sqlquery = hookCtx.SQLQuery
|
||||||
|
|
||||||
// Execute main query
|
// Execute main query
|
||||||
rows := make([]map[string]interface{}, 0)
|
rows := make([]map[string]interface{}, 0)
|
||||||
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
||||||
@ -140,6 +237,20 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
total = int64(len(dbobjlist))
|
total = int64(len(dbobjlist))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterSQLExec hook
|
||||||
|
hookCtx.Result = dbobjlist
|
||||||
|
hookCtx.Total = total
|
||||||
|
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterSQLExec hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
||||||
|
dbobjlist = modifiedResult
|
||||||
|
}
|
||||||
|
total = hookCtx.Total
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -148,6 +259,21 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterQueryList hook
|
||||||
|
hookCtx.Result = dbobjlist
|
||||||
|
hookCtx.Total = total
|
||||||
|
hookCtx.Error = err
|
||||||
|
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterQueryList hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
||||||
|
dbobjlist = modifiedResult
|
||||||
|
}
|
||||||
|
total = hookCtx.Total
|
||||||
|
|
||||||
// Set response headers
|
// Set response headers
|
||||||
respOffset := 0
|
respOffset := 0
|
||||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||||
@ -159,12 +285,44 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
|
w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
|
||||||
logger.Info("Serving: Records %d of %d", len(dbobjlist), total)
|
logger.Info("Serving: Records %d of %d", len(dbobjlist), total)
|
||||||
|
|
||||||
|
// Execute BeforeResponse hook
|
||||||
|
hookCtx.Result = dbobjlist
|
||||||
|
hookCtx.Total = total
|
||||||
|
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeResponse hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
||||||
|
dbobjlist = modifiedResult
|
||||||
|
}
|
||||||
|
|
||||||
if len(dbobjlist) == 0 {
|
if len(dbobjlist) == 0 {
|
||||||
w.Write([]byte("[]"))
|
w.Write([]byte("[]"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if complexAPI {
|
// Format response based on response format
|
||||||
|
switch reqParams.ResponseFormat {
|
||||||
|
case "syncfusion":
|
||||||
|
// Syncfusion format: { result: data, count: total }
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"result": dbobjlist,
|
||||||
|
"count": total,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||||
|
} else {
|
||||||
|
if int64(len(dbobjlist)) < total {
|
||||||
|
w.WriteHeader(http.StatusPartialContent)
|
||||||
|
}
|
||||||
|
w.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "detail":
|
||||||
|
// Detail format: complex API with metadata
|
||||||
metaobj := map[string]interface{}{
|
metaobj := map[string]interface{}{
|
||||||
"items": dbobjlist,
|
"items": dbobjlist,
|
||||||
"count": fmt.Sprintf("%d", len(dbobjlist)),
|
"count": fmt.Sprintf("%d", len(dbobjlist)),
|
||||||
@ -172,7 +330,6 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
"tablename": r.URL.Path,
|
"tablename": r.URL.Path,
|
||||||
"tableprefix": "gsql",
|
"tableprefix": "gsql",
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(metaobj)
|
data, err := json.Marshal(metaobj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||||
@ -182,15 +339,36 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
}
|
}
|
||||||
w.Write(data)
|
w.Write(data)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
data, err := json.Marshal(dbobjlist)
|
default:
|
||||||
if err != nil {
|
// Simple format: just return the data array (or complex API if requested)
|
||||||
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
if complexAPI {
|
||||||
} else {
|
metaobj := map[string]interface{}{
|
||||||
if int64(len(dbobjlist)) < total {
|
"items": dbobjlist,
|
||||||
w.WriteHeader(http.StatusPartialContent)
|
"count": fmt.Sprintf("%d", len(dbobjlist)),
|
||||||
|
"total": fmt.Sprintf("%d", total),
|
||||||
|
"tablename": r.URL.Path,
|
||||||
|
"tableprefix": "gsql",
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(metaobj)
|
||||||
|
if err != nil {
|
||||||
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||||
|
} else {
|
||||||
|
if int64(len(dbobjlist)) < total {
|
||||||
|
w.WriteHeader(http.StatusPartialContent)
|
||||||
|
}
|
||||||
|
w.Write(data)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
data, err := json.Marshal(dbobjlist)
|
||||||
|
if err != nil {
|
||||||
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||||
|
} else {
|
||||||
|
if int64(len(dbobjlist)) < total {
|
||||||
|
w.WriteHeader(http.StatusPartialContent)
|
||||||
|
}
|
||||||
|
w.Write(data)
|
||||||
}
|
}
|
||||||
w.Write(data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -215,6 +393,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
metainfo := make(map[string]interface{})
|
metainfo := make(map[string]interface{})
|
||||||
variables := make(map[string]interface{})
|
variables := make(map[string]interface{})
|
||||||
dbobj := make(map[string]interface{})
|
dbobj := make(map[string]interface{})
|
||||||
|
complexAPI := false
|
||||||
|
|
||||||
// Get user context from security package
|
// Get user context from security package
|
||||||
userCtx, ok := security.GetUserContext(ctx)
|
userCtx, ok := security.GetUserContext(ctx)
|
||||||
@ -225,18 +404,67 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Initialize hook context
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Request: r,
|
||||||
|
Writer: w,
|
||||||
|
SQLQuery: sqlquery,
|
||||||
|
Variables: variables,
|
||||||
|
InputVars: inputvars,
|
||||||
|
MetaInfo: metainfo,
|
||||||
|
PropQry: propQry,
|
||||||
|
UserContext: userCtx,
|
||||||
|
BlankParams: pBlankparms,
|
||||||
|
ComplexAPI: complexAPI,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeQuery hook
|
||||||
|
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeQuery hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hook aborted the operation
|
||||||
|
if hookCtx.Abort {
|
||||||
|
if hookCtx.AbortCode == 0 {
|
||||||
|
hookCtx.AbortCode = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified SQL query and variables from hooks
|
||||||
|
sqlquery = hookCtx.SQLQuery
|
||||||
|
variables = hookCtx.Variables
|
||||||
|
|
||||||
// Extract input variables from SQL query
|
// Extract input variables from SQL query
|
||||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||||
|
|
||||||
// Merge URL path parameters
|
// Merge URL path parameters
|
||||||
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
||||||
|
|
||||||
|
// Parse comprehensive parameters from headers and query string
|
||||||
|
reqParams := h.ParseParameters(r)
|
||||||
|
complexAPI = reqParams.ComplexAPI
|
||||||
|
|
||||||
// Merge query string parameters
|
// Merge query string parameters
|
||||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)
|
sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)
|
||||||
|
|
||||||
// Merge header parameters
|
// Merge header parameters
|
||||||
complexAPI := false
|
|
||||||
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
||||||
|
hookCtx.ComplexAPI = complexAPI
|
||||||
|
|
||||||
|
// Apply filters from parsed parameters
|
||||||
|
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
||||||
|
|
||||||
|
// Apply field selection
|
||||||
|
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
|
||||||
|
|
||||||
|
// Apply DISTINCT if requested
|
||||||
|
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
|
||||||
|
|
||||||
// Build metainfo
|
// Build metainfo
|
||||||
metainfo["ipaddress"] = getIPAddress(r)
|
metainfo["ipaddress"] = getIPAddress(r)
|
||||||
@ -272,8 +500,22 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update hook context with latest SQL query and variables
|
||||||
|
hookCtx.SQLQuery = sqlquery
|
||||||
|
hookCtx.Variables = variables
|
||||||
|
hookCtx.InputVars = inputvars
|
||||||
|
|
||||||
// Execute query within transaction
|
// Execute query within transaction
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
// Execute BeforeSQLExec hook
|
||||||
|
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeSQLExec hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Use potentially modified SQL query from hook
|
||||||
|
sqlquery = hookCtx.SQLQuery
|
||||||
|
|
||||||
// Execute main query
|
// Execute main query
|
||||||
rows := make([]map[string]interface{}, 0)
|
rows := make([]map[string]interface{}, 0)
|
||||||
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
||||||
@ -285,6 +527,18 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
dbobj = rows[0]
|
dbobj = rows[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterSQLExec hook
|
||||||
|
hookCtx.Result = dbobj
|
||||||
|
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterSQLExec hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
||||||
|
dbobj = modifiedResult
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -293,6 +547,31 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterQuery hook
|
||||||
|
hookCtx.Result = dbobj
|
||||||
|
hookCtx.Error = err
|
||||||
|
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterQuery hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
||||||
|
dbobj = modifiedResult
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeResponse hook
|
||||||
|
hookCtx.Result = dbobj
|
||||||
|
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeResponse hook failed: %v", err)
|
||||||
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Use potentially modified result from hook
|
||||||
|
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
||||||
|
dbobj = modifiedResult
|
||||||
|
}
|
||||||
|
|
||||||
// Check if response should be root-level data
|
// Check if response should be root-level data
|
||||||
if val, ok := dbobj["root_as_data"]; ok {
|
if val, ok := dbobj["root_as_data"]; ok {
|
||||||
data, err := json.Marshal(val)
|
data, err := json.Marshal(val)
|
||||||
|
|||||||
837
pkg/funcspec/function_api_test.go
Normal file
837
pkg/funcspec/function_api_test.go
Normal file
@ -0,0 +1,837 @@
|
|||||||
|
package funcspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockDatabase implements common.Database interface for testing
|
||||||
|
type MockDatabase struct {
|
||||||
|
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||||
|
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||||
|
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) NewSelect() common.SelectQuery {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) NewInsert() common.InsertQuery {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) NewDelete() common.DeleteQuery {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||||
|
if m.ExecFunc != nil {
|
||||||
|
return m.ExecFunc(ctx, query, args...)
|
||||||
|
}
|
||||||
|
return &MockResult{rows: 0}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
if m.QueryFunc != nil {
|
||||||
|
return m.QueryFunc(ctx, dest, query, args...)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) CommitTx(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||||
|
if m.RunInTransactionFunc != nil {
|
||||||
|
return m.RunInTransactionFunc(ctx, fn)
|
||||||
|
}
|
||||||
|
return fn(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockResult implements common.Result interface for testing
|
||||||
|
type MockResult struct {
|
||||||
|
rows int64
|
||||||
|
id int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResult) RowsAffected() int64 {
|
||||||
|
return m.rows
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResult) LastInsertId() (int64, error) {
|
||||||
|
return m.id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create a test request with user context
|
||||||
|
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
|
||||||
|
u, _ := url.Parse(path)
|
||||||
|
if queryParams != nil {
|
||||||
|
q := u.Query()
|
||||||
|
for k, v := range queryParams {
|
||||||
|
q.Set(k, v)
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
var bodyReader *bytes.Reader
|
||||||
|
if body != nil {
|
||||||
|
bodyReader = bytes.NewReader(body)
|
||||||
|
} else {
|
||||||
|
bodyReader = bytes.NewReader([]byte{})
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(method, u.String(), bodyReader)
|
||||||
|
|
||||||
|
if headers != nil {
|
||||||
|
for k, v := range headers {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add user context
|
||||||
|
userCtx := &security.UserContext{
|
||||||
|
UserID: 1,
|
||||||
|
UserName: "testuser",
|
||||||
|
SessionID: "test-session-123",
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewHandler tests handler creation
|
||||||
|
func TestNewHandler(t *testing.T) {
|
||||||
|
db := &MockDatabase{}
|
||||||
|
handler := NewHandler(db)
|
||||||
|
|
||||||
|
if handler == nil {
|
||||||
|
t.Fatal("Expected handler to be created, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler.db != db {
|
||||||
|
t.Error("Expected handler to have the provided database")
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler.hooks == nil {
|
||||||
|
t.Error("Expected handler to have a hook registry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerHooks tests the Hooks method
|
||||||
|
func TestHandlerHooks(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
hooks := handler.Hooks()
|
||||||
|
|
||||||
|
if hooks == nil {
|
||||||
|
t.Fatal("Expected hooks registry to be non-nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return the same instance
|
||||||
|
hooks2 := handler.Hooks()
|
||||||
|
if hooks != hooks2 {
|
||||||
|
t.Error("Expected Hooks() to return the same registry instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractInputVariables tests the extractInputVariables function
|
||||||
|
func TestExtractInputVariables(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
expectedVars []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No variables",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
expectedVars: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single variable",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
|
||||||
|
expectedVars: []string{"[user_id]"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple variables",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]",
|
||||||
|
expectedVars: []string{"[user_id]", "[user_name]"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nested brackets",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]",
|
||||||
|
expectedVars: []string{"[field]", "[user_id]"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
inputvars := make([]string, 0)
|
||||||
|
result := handler.extractInputVariables(tt.sqlQuery, &inputvars)
|
||||||
|
|
||||||
|
if result != tt.sqlQuery {
|
||||||
|
t.Errorf("Expected SQL query to be unchanged, got %s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputvars) != len(tt.expectedVars) {
|
||||||
|
t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expected := range tt.expectedVars {
|
||||||
|
if inputvars[i] != expected {
|
||||||
|
t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidSQL tests the SQL sanitization function
|
||||||
|
func TestValidSQL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
mode string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Column name with valid characters",
|
||||||
|
input: "user_id",
|
||||||
|
mode: "colname",
|
||||||
|
expected: "user_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Column name with dots (table.column)",
|
||||||
|
input: "users.user_id",
|
||||||
|
mode: "colname",
|
||||||
|
expected: "users.user_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Column name with SQL injection attempt",
|
||||||
|
input: "id'; DROP TABLE users--",
|
||||||
|
mode: "colname",
|
||||||
|
expected: "idDROPTABLEusers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Column value with single quotes",
|
||||||
|
input: "O'Brien",
|
||||||
|
mode: "colvalue",
|
||||||
|
expected: "O''Brien",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select with dangerous keywords",
|
||||||
|
input: "name, email; DROP TABLE users",
|
||||||
|
mode: "select",
|
||||||
|
expected: "name, email TABLE users",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ValidSQL(tt.input, tt.mode)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsNumeric tests the IsNumeric function
|
||||||
|
func TestIsNumeric(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"123", true},
|
||||||
|
{"123.45", true},
|
||||||
|
{"-123", true},
|
||||||
|
{"-123.45", true},
|
||||||
|
{"0", true},
|
||||||
|
{"abc", false},
|
||||||
|
{"12.34.56", false},
|
||||||
|
{"", false},
|
||||||
|
{"123abc", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
result := IsNumeric(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlQryWhere tests the WHERE clause manipulation
|
||||||
|
func TestSqlQryWhere(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
condition string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Add WHERE to query without WHERE",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
condition: "status = 'active'",
|
||||||
|
expected: "SELECT * FROM users WHERE status = 'active' ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add AND to query with existing WHERE",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id > 0",
|
||||||
|
condition: "status = 'active'",
|
||||||
|
expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add WHERE before ORDER BY",
|
||||||
|
sqlQuery: "SELECT * FROM users ORDER BY name",
|
||||||
|
condition: "status = 'active'",
|
||||||
|
expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add WHERE before GROUP BY",
|
||||||
|
sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department",
|
||||||
|
condition: "status = 'active'",
|
||||||
|
expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add WHERE before LIMIT",
|
||||||
|
sqlQuery: "SELECT * FROM users LIMIT 10",
|
||||||
|
condition: "status = 'active'",
|
||||||
|
expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := sqlQryWhere(tt.sqlQuery, tt.condition)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetIPAddress tests IP address extraction
|
||||||
|
func TestGetIPAddress(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupReq func() *http.Request
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For header",
|
||||||
|
setupReq: func() *http.Request {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
expected: "192.168.1.100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Real-IP header",
|
||||||
|
setupReq: func() *http.Request {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("X-Real-IP", "192.168.1.200")
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
expected: "192.168.1.200",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RemoteAddr fallback",
|
||||||
|
setupReq: func() *http.Request {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
expected: "192.168.1.1:12345",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := tt.setupReq()
|
||||||
|
result := getIPAddress(req)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParsePaginationParams tests pagination parameter parsing
|
||||||
|
func TestParsePaginationParams(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
expectedSort string
|
||||||
|
expectedLimit int
|
||||||
|
expectedOffset int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No parameters - defaults",
|
||||||
|
queryParams: map[string]string{},
|
||||||
|
expectedSort: "",
|
||||||
|
expectedLimit: 20,
|
||||||
|
expectedOffset: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "All parameters provided",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"sort": "name,-created_at",
|
||||||
|
"limit": "100",
|
||||||
|
"offset": "50",
|
||||||
|
},
|
||||||
|
expectedSort: "name,-created_at",
|
||||||
|
expectedLimit: 100,
|
||||||
|
expectedOffset: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid limit - use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"limit": "invalid",
|
||||||
|
},
|
||||||
|
expectedSort: "",
|
||||||
|
expectedLimit: 20,
|
||||||
|
expectedOffset: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Negative offset - use default",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"offset": "-10",
|
||||||
|
},
|
||||||
|
expectedSort: "",
|
||||||
|
expectedLimit: 20,
|
||||||
|
expectedOffset: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
|
||||||
|
sort, limit, offset := handler.parsePaginationParams(req)
|
||||||
|
|
||||||
|
if sort != tt.expectedSort {
|
||||||
|
t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort)
|
||||||
|
}
|
||||||
|
if limit != tt.expectedLimit {
|
||||||
|
t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit)
|
||||||
|
}
|
||||||
|
if offset != tt.expectedOffset {
|
||||||
|
t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlQuery tests the SqlQuery handler for single record queries
|
||||||
|
func TestSqlQuery(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
blankParams bool
|
||||||
|
queryParams map[string]string
|
||||||
|
headers map[string]string
|
||||||
|
setupDB func() *MockDatabase
|
||||||
|
expectedStatus int
|
||||||
|
validateResp func(t *testing.T, body []byte)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic query - returns single record",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = 1",
|
||||||
|
blankParams: false,
|
||||||
|
setupDB: func() *MockDatabase {
|
||||||
|
return &MockDatabase{
|
||||||
|
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||||
|
db := &MockDatabase{
|
||||||
|
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
rows := dest.(*[]map[string]interface{})
|
||||||
|
*rows = []map[string]interface{}{
|
||||||
|
{"id": float64(1), "name": "Test User", "email": "test@example.com"},
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return fn(db)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
validateResp: func(t *testing.T, body []byte) {
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if result["name"] != "Test User" {
|
||||||
|
t.Errorf("Expected name='Test User', got %v", result["name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query with no results",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = 999",
|
||||||
|
blankParams: false,
|
||||||
|
setupDB: func() *MockDatabase {
|
||||||
|
return &MockDatabase{
|
||||||
|
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||||
|
db := &MockDatabase{
|
||||||
|
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
// Return empty array
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return fn(db)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
validateResp: func(t *testing.T, body []byte) {
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if len(result) != 0 {
|
||||||
|
t.Errorf("Expected empty result, got %v", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db := tt.setupDB()
|
||||||
|
handler := NewHandler(db)
|
||||||
|
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
|
||||||
|
handlerFunc(w, req)
|
||||||
|
|
||||||
|
if w.Code != tt.expectedStatus {
|
||||||
|
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.validateResp != nil {
|
||||||
|
tt.validateResp(t, w.Body.Bytes())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlQueryList tests the SqlQueryList handler for list queries
|
||||||
|
func TestSqlQueryList(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
noCount bool
|
||||||
|
blankParams bool
|
||||||
|
allowFilter bool
|
||||||
|
queryParams map[string]string
|
||||||
|
headers map[string]string
|
||||||
|
setupDB func() *MockDatabase
|
||||||
|
expectedStatus int
|
||||||
|
validateResp func(t *testing.T, w *httptest.ResponseRecorder)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic list query",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
noCount: false,
|
||||||
|
blankParams: false,
|
||||||
|
allowFilter: false,
|
||||||
|
setupDB: func() *MockDatabase {
|
||||||
|
return &MockDatabase{
|
||||||
|
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||||
|
callCount := 0
|
||||||
|
db := &MockDatabase{
|
||||||
|
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
callCount++
|
||||||
|
if strings.Contains(query, "COUNT") {
|
||||||
|
// Count query
|
||||||
|
countResult := dest.(*struct{ Count int64 })
|
||||||
|
countResult.Count = 2
|
||||||
|
} else {
|
||||||
|
// Main query
|
||||||
|
rows := dest.(*[]map[string]interface{})
|
||||||
|
*rows = []map[string]interface{}{
|
||||||
|
{"id": float64(1), "name": "User 1"},
|
||||||
|
{"id": float64(2), "name": "User 2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return fn(db)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
|
||||||
|
var result []map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Errorf("Expected 2 results, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Content-Range header
|
||||||
|
contentRange := w.Header().Get("Content-Range")
|
||||||
|
if !strings.Contains(contentRange, "2") {
|
||||||
|
t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "List query with noCount",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
noCount: true,
|
||||||
|
blankParams: false,
|
||||||
|
allowFilter: false,
|
||||||
|
setupDB: func() *MockDatabase {
|
||||||
|
return &MockDatabase{
|
||||||
|
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||||
|
db := &MockDatabase{
|
||||||
|
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
if strings.Contains(query, "COUNT") {
|
||||||
|
t.Error("Count query should not be executed when noCount is true")
|
||||||
|
}
|
||||||
|
rows := dest.(*[]map[string]interface{})
|
||||||
|
*rows = []map[string]interface{}{
|
||||||
|
{"id": float64(1), "name": "User 1"},
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return fn(db)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedStatus: 200,
|
||||||
|
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
|
||||||
|
var result []map[string]interface{}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Errorf("Expected 1 result, got %d", len(result))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db := tt.setupDB()
|
||||||
|
handler := NewHandler(db)
|
||||||
|
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
|
||||||
|
handlerFunc(w, req)
|
||||||
|
|
||||||
|
if w.Code != tt.expectedStatus {
|
||||||
|
t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.validateResp != nil {
|
||||||
|
tt.validateResp(t, w)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMergeQueryParams tests query parameter merging
|
||||||
|
func TestMergeQueryParams(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
queryParams map[string]string
|
||||||
|
allowFilter bool
|
||||||
|
expectedQuery string
|
||||||
|
checkVars func(t *testing.T, vars map[string]interface{})
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Replace placeholder with parameter",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
|
||||||
|
queryParams: map[string]string{"p-user_id": "123"},
|
||||||
|
allowFilter: false,
|
||||||
|
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||||
|
if vars["p-user_id"] != "123" {
|
||||||
|
t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add filter when allowed",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
queryParams: map[string]string{"status": "active"},
|
||||||
|
allowFilter: true,
|
||||||
|
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||||
|
if vars["status"] != "active" {
|
||||||
|
t.Errorf("Expected status=active, got %v", vars["status"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
|
||||||
|
variables := make(map[string]interface{})
|
||||||
|
propQry := make(map[string]string)
|
||||||
|
|
||||||
|
result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
|
||||||
|
|
||||||
|
if result == "" {
|
||||||
|
t.Error("Expected non-empty SQL query result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.checkVars != nil {
|
||||||
|
tt.checkVars(t, variables)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMergeHeaderParams tests header parameter merging
|
||||||
|
func TestMergeHeaderParams(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
headers map[string]string
|
||||||
|
expectedQuery string
|
||||||
|
checkVars func(t *testing.T, vars map[string]interface{})
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Field filter header",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
headers: map[string]string{"X-FieldFilter-Status": "1"},
|
||||||
|
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||||
|
if vars["x-fieldfilter-status"] != "1" {
|
||||||
|
t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Search filter header",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
headers: map[string]string{"X-SearchFilter-Name": "john"},
|
||||||
|
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||||
|
if vars["x-searchfilter-name"] != "john" {
|
||||||
|
t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := createTestRequest("GET", "/test", nil, tt.headers, nil)
|
||||||
|
variables := make(map[string]interface{})
|
||||||
|
propQry := make(map[string]string)
|
||||||
|
complexAPI := false
|
||||||
|
|
||||||
|
result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
|
||||||
|
|
||||||
|
if result == "" {
|
||||||
|
t.Error("Expected non-empty SQL query result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.checkVars != nil {
|
||||||
|
tt.checkVars(t, variables)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestReplaceMetaVariables tests meta variable replacement
|
||||||
|
func TestReplaceMetaVariables(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
userCtx := &security.UserContext{
|
||||||
|
UserID: 123,
|
||||||
|
UserName: "testuser",
|
||||||
|
SessionID: "session-abc",
|
||||||
|
}
|
||||||
|
|
||||||
|
metainfo := map[string]interface{}{
|
||||||
|
"ipaddress": "192.168.1.1",
|
||||||
|
"url": "/api/test",
|
||||||
|
}
|
||||||
|
|
||||||
|
variables := map[string]interface{}{
|
||||||
|
"param1": "value1",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
expectedCheck func(result string) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Replace [rid_user]",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]",
|
||||||
|
expectedCheck: func(result string) bool {
|
||||||
|
return strings.Contains(result, "123")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Replace [user]",
|
||||||
|
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
|
||||||
|
expectedCheck: func(result string) bool {
|
||||||
|
return strings.Contains(result, "'testuser'")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Replace [rid_session]",
|
||||||
|
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
|
||||||
|
expectedCheck: func(result string) bool {
|
||||||
|
return strings.Contains(result, "'session-abc'")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := createTestRequest("GET", "/test", nil, nil, nil)
|
||||||
|
result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
|
||||||
|
|
||||||
|
if !tt.expectedCheck(result) {
|
||||||
|
t.Errorf("Meta variable replacement failed. Query: %s", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
160
pkg/funcspec/hooks.go
Normal file
160
pkg/funcspec/hooks.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package funcspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookType defines the type of hook to execute
|
||||||
|
type HookType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Query operation hooks (for SqlQuery - single record)
|
||||||
|
BeforeQuery HookType = "before_query"
|
||||||
|
AfterQuery HookType = "after_query"
|
||||||
|
|
||||||
|
// Query list operation hooks (for SqlQueryList - multiple records)
|
||||||
|
BeforeQueryList HookType = "before_query_list"
|
||||||
|
AfterQueryList HookType = "after_query_list"
|
||||||
|
|
||||||
|
// SQL execution hooks (just before SQL is executed)
|
||||||
|
BeforeSQLExec HookType = "before_sql_exec"
|
||||||
|
AfterSQLExec HookType = "after_sql_exec"
|
||||||
|
|
||||||
|
// Response hooks (before response is sent)
|
||||||
|
BeforeResponse HookType = "before_response"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookContext contains all the data available to a hook
|
||||||
|
type HookContext struct {
|
||||||
|
Context context.Context
|
||||||
|
Handler *Handler // Reference to the handler for accessing database
|
||||||
|
Request *http.Request
|
||||||
|
Writer http.ResponseWriter
|
||||||
|
|
||||||
|
// SQL query and variables
|
||||||
|
SQLQuery string // The SQL query being executed (can be modified by hooks)
|
||||||
|
Variables map[string]interface{} // Variables extracted from request
|
||||||
|
InputVars []string // Input variable placeholders found in query
|
||||||
|
MetaInfo map[string]interface{} // Metadata about the request
|
||||||
|
PropQry map[string]string // Property query parameters
|
||||||
|
|
||||||
|
// User context
|
||||||
|
UserContext *security.UserContext
|
||||||
|
|
||||||
|
// Pagination and filtering (for list queries)
|
||||||
|
SortColumns string
|
||||||
|
Limit int
|
||||||
|
Offset int
|
||||||
|
|
||||||
|
// Results
|
||||||
|
Result interface{} // Query result (single record or list)
|
||||||
|
Total int64 // Total count (for list queries)
|
||||||
|
Error error // Error if operation failed
|
||||||
|
ComplexAPI bool // Whether complex API response format is requested
|
||||||
|
NoCount bool // Whether count query should be skipped
|
||||||
|
BlankParams bool // Whether blank parameters should be removed
|
||||||
|
AllowFilter bool // Whether filtering is allowed
|
||||||
|
|
||||||
|
// Allow hooks to abort the operation
|
||||||
|
Abort bool // If set to true, the operation will be aborted
|
||||||
|
AbortMessage string // Message to return if aborted
|
||||||
|
AbortCode int // HTTP status code if aborted
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookFunc is the signature for hook functions
|
||||||
|
// It receives a HookContext and can modify it or return an error
|
||||||
|
// If an error is returned, the operation will be aborted
|
||||||
|
type HookFunc func(*HookContext) error
|
||||||
|
|
||||||
|
// HookRegistry manages all registered hooks
|
||||||
|
type HookRegistry struct {
|
||||||
|
hooks map[HookType][]HookFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHookRegistry creates a new hook registry
|
||||||
|
func NewHookRegistry() *HookRegistry {
|
||||||
|
return &HookRegistry{
|
||||||
|
hooks: make(map[HookType][]HookFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a new hook for the specified hook type
|
||||||
|
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||||
|
if r.hooks == nil {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
}
|
||||||
|
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||||
|
logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterMultiple registers a hook for multiple hook types
|
||||||
|
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||||
|
for _, hookType := range hookTypes {
|
||||||
|
r.Register(hookType, hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs all hooks for the specified type in order
|
||||||
|
// If any hook returns an error, execution stops and the error is returned
|
||||||
|
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||||
|
hooks, exists := r.hooks[hookType]
|
||||||
|
if !exists || len(hooks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType)
|
||||||
|
|
||||||
|
for i, hook := range hooks {
|
||||||
|
if err := hook(ctx); err != nil {
|
||||||
|
logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err)
|
||||||
|
return fmt.Errorf("hook execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if hook requested abort
|
||||||
|
if ctx.Abort {
|
||||||
|
logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||||
|
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all hooks for the specified type
|
||||||
|
func (r *HookRegistry) Clear(hookType HookType) {
|
||||||
|
delete(r.hooks, hookType)
|
||||||
|
logger.Info("Cleared all funcspec hooks for %s", hookType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAll removes all registered hooks
|
||||||
|
func (r *HookRegistry) ClearAll() {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
logger.Info("Cleared all funcspec hooks")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of hooks registered for a specific type
|
||||||
|
func (r *HookRegistry) Count(hookType HookType) int {
|
||||||
|
if hooks, exists := r.hooks[hookType]; exists {
|
||||||
|
return len(hooks)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasHooks returns true if there are any hooks registered for the specified type
|
||||||
|
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||||
|
return r.Count(hookType) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllHookTypes returns all hook types that have registered hooks
|
||||||
|
func (r *HookRegistry) GetAllHookTypes() []HookType {
|
||||||
|
types := make([]HookType, 0, len(r.hooks))
|
||||||
|
for hookType := range r.hooks {
|
||||||
|
types = append(types, hookType)
|
||||||
|
}
|
||||||
|
return types
|
||||||
|
}
|
||||||
137
pkg/funcspec/hooks_example.go
Normal file
137
pkg/funcspec/hooks_example.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package funcspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Example hook functions demonstrating various use cases
|
||||||
|
|
||||||
|
// ExampleLoggingHook logs all SQL queries before execution
|
||||||
|
func ExampleLoggingHook(ctx *HookContext) error {
|
||||||
|
logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleSecurityHook validates user permissions before executing queries
|
||||||
|
func ExampleSecurityHook(ctx *HookContext) error {
|
||||||
|
// Example: Block queries that try to access sensitive tables
|
||||||
|
if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") {
|
||||||
|
if ctx.UserContext.UserID != 1 { // Only admin can access
|
||||||
|
ctx.Abort = true
|
||||||
|
ctx.AbortCode = 403
|
||||||
|
ctx.AbortMessage = "Access denied: insufficient permissions"
|
||||||
|
return fmt.Errorf("access denied to sensitive_table")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering
|
||||||
|
func ExampleQueryModificationHook(ctx *HookContext) error {
|
||||||
|
// Example: Automatically add user_id filter for non-admin users
|
||||||
|
if ctx.UserContext.UserID != 1 { // Not admin
|
||||||
|
// Add WHERE clause to filter by user_id
|
||||||
|
if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") {
|
||||||
|
ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID)
|
||||||
|
} else {
|
||||||
|
ctx.SQLQuery = strings.Replace(
|
||||||
|
ctx.SQLQuery,
|
||||||
|
"WHERE",
|
||||||
|
fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleResultFilterHook filters results after query execution
|
||||||
|
func ExampleResultFilterHook(ctx *HookContext) error {
|
||||||
|
// Example: Remove sensitive fields from results for non-admin users
|
||||||
|
if ctx.UserContext.UserID != 1 { // Not admin
|
||||||
|
switch result := ctx.Result.(type) {
|
||||||
|
case []map[string]interface{}:
|
||||||
|
// Filter list results
|
||||||
|
for i := range result {
|
||||||
|
delete(result[i], "password")
|
||||||
|
delete(result[i], "ssn")
|
||||||
|
delete(result[i], "credit_card")
|
||||||
|
}
|
||||||
|
case map[string]interface{}:
|
||||||
|
// Filter single record
|
||||||
|
delete(result, "password")
|
||||||
|
delete(result, "ssn")
|
||||||
|
delete(result, "credit_card")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleAuditHook logs all queries and results for audit purposes
|
||||||
|
func ExampleAuditHook(ctx *HookContext) error {
|
||||||
|
// Log to audit table or external system
|
||||||
|
logger.Info("AUDIT: User %s (%d) executed query from %s",
|
||||||
|
ctx.UserContext.UserName,
|
||||||
|
ctx.UserContext.UserID,
|
||||||
|
ctx.Request.RemoteAddr,
|
||||||
|
)
|
||||||
|
|
||||||
|
// In a real implementation, you might:
|
||||||
|
// - Insert into an audit log table
|
||||||
|
// - Send to a logging service
|
||||||
|
// - Write to a file
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleCacheHook implements simple response caching
|
||||||
|
func ExampleCacheHook(ctx *HookContext) error {
|
||||||
|
// This is a simplified example - real caching would use a proper cache store
|
||||||
|
// Check if we have a cached result for this query
|
||||||
|
// cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery)
|
||||||
|
// if cachedResult := checkCache(cacheKey); cachedResult != nil {
|
||||||
|
// ctx.Result = cachedResult
|
||||||
|
// ctx.Abort = true // Skip query execution
|
||||||
|
// ctx.AbortMessage = "Serving from cache"
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleErrorHandlingHook provides custom error handling
|
||||||
|
func ExampleErrorHandlingHook(ctx *HookContext) error {
|
||||||
|
if ctx.Error != nil {
|
||||||
|
// Log error with context
|
||||||
|
logger.Error("Query failed for user %s: %v\nQuery: %s",
|
||||||
|
ctx.UserContext.UserName,
|
||||||
|
ctx.Error,
|
||||||
|
ctx.SQLQuery,
|
||||||
|
)
|
||||||
|
|
||||||
|
// You could send notifications, update metrics, etc.
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example of registering hooks:
|
||||||
|
//
|
||||||
|
// func SetupHooks(handler *Handler) {
|
||||||
|
// hooks := handler.Hooks()
|
||||||
|
//
|
||||||
|
// // Register security hook before query execution
|
||||||
|
// hooks.Register(BeforeQuery, ExampleSecurityHook)
|
||||||
|
// hooks.Register(BeforeQueryList, ExampleSecurityHook)
|
||||||
|
//
|
||||||
|
// // Register logging hook before SQL execution
|
||||||
|
// hooks.Register(BeforeSQLExec, ExampleLoggingHook)
|
||||||
|
//
|
||||||
|
// // Register result filtering after query
|
||||||
|
// hooks.Register(AfterQuery, ExampleResultFilterHook)
|
||||||
|
// hooks.Register(AfterQueryList, ExampleResultFilterHook)
|
||||||
|
//
|
||||||
|
// // Register audit hook after execution
|
||||||
|
// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook)
|
||||||
|
// }
|
||||||
589
pkg/funcspec/hooks_test.go
Normal file
589
pkg/funcspec/hooks_test.go
Normal file
@ -0,0 +1,589 @@
|
|||||||
|
package funcspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewHookRegistry tests hook registry creation
|
||||||
|
func TestNewHookRegistry(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
if registry == nil {
|
||||||
|
t.Fatal("Expected registry to be created, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if registry.hooks == nil {
|
||||||
|
t.Error("Expected hooks map to be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRegisterHook tests registering a single hook
|
||||||
|
func TestRegisterHook(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
hookCalled := false
|
||||||
|
testHook := func(ctx *HookContext) error {
|
||||||
|
hookCalled = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, testHook)
|
||||||
|
|
||||||
|
if !registry.HasHooks(BeforeQuery) {
|
||||||
|
t.Error("Expected hook to be registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
if registry.Count(BeforeQuery) != 1 {
|
||||||
|
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the hook
|
||||||
|
ctx := &HookContext{}
|
||||||
|
err := registry.Execute(BeforeQuery, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hookCalled {
|
||||||
|
t.Error("Expected hook to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRegisterMultipleHooks tests registering multiple hooks for same type
|
||||||
|
func TestRegisterMultipleHooks(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
callOrder := []int{}
|
||||||
|
|
||||||
|
hook1 := func(ctx *HookContext) error {
|
||||||
|
callOrder = append(callOrder, 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hook2 := func(ctx *HookContext) error {
|
||||||
|
callOrder = append(callOrder, 2)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hook3 := func(ctx *HookContext) error {
|
||||||
|
callOrder = append(callOrder, 3)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, hook1)
|
||||||
|
registry.Register(BeforeQuery, hook2)
|
||||||
|
registry.Register(BeforeQuery, hook3)
|
||||||
|
|
||||||
|
if registry.Count(BeforeQuery) != 3 {
|
||||||
|
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute hooks
|
||||||
|
ctx := &HookContext{}
|
||||||
|
err := registry.Execute(BeforeQuery, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify hooks were called in order
|
||||||
|
if len(callOrder) != 3 {
|
||||||
|
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expected := range []int{1, 2, 3} {
|
||||||
|
if callOrder[i] != expected {
|
||||||
|
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
|
||||||
|
func TestRegisterMultipleHookTypes(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
callCount := 0
|
||||||
|
testHook := func(ctx *HookContext) error {
|
||||||
|
callCount++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
|
||||||
|
registry.RegisterMultiple(hookTypes, testHook)
|
||||||
|
|
||||||
|
// Verify hook is registered for all types
|
||||||
|
for _, hookType := range hookTypes {
|
||||||
|
if !registry.HasHooks(hookType) {
|
||||||
|
t.Errorf("Expected hook to be registered for %s", hookType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if registry.Count(hookType) != 1 {
|
||||||
|
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute each hook type
|
||||||
|
ctx := &HookContext{}
|
||||||
|
for _, hookType := range hookTypes {
|
||||||
|
if err := registry.Execute(hookType, ctx); err != nil {
|
||||||
|
t.Errorf("Hook execution failed for %s: %v", hookType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if callCount != 3 {
|
||||||
|
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookError tests hook error handling
|
||||||
|
func TestHookError(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
expectedError := fmt.Errorf("test error")
|
||||||
|
errorHook := func(ctx *HookContext) error {
|
||||||
|
return expectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, errorHook)
|
||||||
|
|
||||||
|
ctx := &HookContext{}
|
||||||
|
err := registry.Execute(BeforeQuery, ctx)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error from hook, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
|
||||||
|
t.Errorf("Expected error message to contain hook error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookAbort tests hook abort functionality
|
||||||
|
func TestHookAbort(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
abortHook := func(ctx *HookContext) error {
|
||||||
|
ctx.Abort = true
|
||||||
|
ctx.AbortMessage = "Operation aborted by hook"
|
||||||
|
ctx.AbortCode = 403
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, abortHook)
|
||||||
|
|
||||||
|
ctx := &HookContext{}
|
||||||
|
err := registry.Execute(BeforeQuery, ctx)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when hook aborts, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.Abort {
|
||||||
|
t.Error("Expected Abort to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.AbortMessage != "Operation aborted by hook" {
|
||||||
|
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.AbortCode != 403 {
|
||||||
|
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookChainWithError tests that hook chain stops on first error
|
||||||
|
func TestHookChainWithError(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
callOrder := []int{}
|
||||||
|
|
||||||
|
hook1 := func(ctx *HookContext) error {
|
||||||
|
callOrder = append(callOrder, 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hook2 := func(ctx *HookContext) error {
|
||||||
|
callOrder = append(callOrder, 2)
|
||||||
|
return fmt.Errorf("error in hook 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
hook3 := func(ctx *HookContext) error {
|
||||||
|
callOrder = append(callOrder, 3)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, hook1)
|
||||||
|
registry.Register(BeforeQuery, hook2)
|
||||||
|
registry.Register(BeforeQuery, hook3)
|
||||||
|
|
||||||
|
ctx := &HookContext{}
|
||||||
|
err := registry.Execute(BeforeQuery, ctx)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error from hook chain")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only first two hooks should have been called
|
||||||
|
if len(callOrder) != 2 {
|
||||||
|
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
|
||||||
|
}
|
||||||
|
|
||||||
|
if callOrder[0] != 1 || callOrder[1] != 2 {
|
||||||
|
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClearHooks tests clearing hooks
|
||||||
|
func TestClearHooks(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
testHook := func(ctx *HookContext) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, testHook)
|
||||||
|
registry.Register(AfterQuery, testHook)
|
||||||
|
|
||||||
|
if !registry.HasHooks(BeforeQuery) {
|
||||||
|
t.Error("Expected BeforeQuery hook to be registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Clear(BeforeQuery)
|
||||||
|
|
||||||
|
if registry.HasHooks(BeforeQuery) {
|
||||||
|
t.Error("Expected BeforeQuery hooks to be cleared")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !registry.HasHooks(AfterQuery) {
|
||||||
|
t.Error("Expected AfterQuery hook to still be registered")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClearAllHooks tests clearing all hooks
|
||||||
|
func TestClearAllHooks(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
testHook := func(ctx *HookContext) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, testHook)
|
||||||
|
registry.Register(AfterQuery, testHook)
|
||||||
|
registry.Register(BeforeSQLExec, testHook)
|
||||||
|
|
||||||
|
registry.ClearAll()
|
||||||
|
|
||||||
|
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
|
||||||
|
t.Error("Expected all hooks to be cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAllHookTypes tests getting all registered hook types
|
||||||
|
func TestGetAllHookTypes(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
testHook := func(ctx *HookContext) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, testHook)
|
||||||
|
registry.Register(AfterQuery, testHook)
|
||||||
|
|
||||||
|
types := registry.GetAllHookTypes()
|
||||||
|
|
||||||
|
if len(types) != 2 {
|
||||||
|
t.Errorf("Expected 2 hook types, got %d", len(types))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the types are present
|
||||||
|
foundBefore := false
|
||||||
|
foundAfter := false
|
||||||
|
for _, hookType := range types {
|
||||||
|
if hookType == BeforeQuery {
|
||||||
|
foundBefore = true
|
||||||
|
}
|
||||||
|
if hookType == AfterQuery {
|
||||||
|
foundAfter = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundBefore || !foundAfter {
|
||||||
|
t.Error("Expected both BeforeQuery and AfterQuery hook types")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookContextModification tests that hooks can modify the context
|
||||||
|
func TestHookContextModification(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
// Hook that modifies SQL query
|
||||||
|
modifyHook := func(ctx *HookContext) error {
|
||||||
|
ctx.SQLQuery = "SELECT * FROM modified_table"
|
||||||
|
ctx.Variables["new_var"] = "new_value"
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeQuery, modifyHook)
|
||||||
|
|
||||||
|
ctx := &HookContext{
|
||||||
|
SQLQuery: "SELECT * FROM original_table",
|
||||||
|
Variables: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := registry.Execute(BeforeQuery, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.SQLQuery != "SELECT * FROM modified_table" {
|
||||||
|
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Variables["new_var"] != "new_value" {
|
||||||
|
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExampleHooks tests the example hooks
|
||||||
|
func TestExampleLoggingHook(t *testing.T) {
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
SQLQuery: "SELECT * FROM test",
|
||||||
|
UserContext: &security.UserContext{
|
||||||
|
UserName: "testuser",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ExampleLoggingHook(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ExampleLoggingHook failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExampleSecurityHook(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
userID int
|
||||||
|
shouldAbort bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Admin accessing sensitive table",
|
||||||
|
sqlQuery: "SELECT * FROM sensitive_table",
|
||||||
|
userID: 1,
|
||||||
|
shouldAbort: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-admin accessing sensitive table",
|
||||||
|
sqlQuery: "SELECT * FROM sensitive_table",
|
||||||
|
userID: 2,
|
||||||
|
shouldAbort: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-admin accessing normal table",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
userID: 2,
|
||||||
|
shouldAbort: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
SQLQuery: tt.sqlQuery,
|
||||||
|
UserContext: &security.UserContext{
|
||||||
|
UserID: tt.userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = ExampleSecurityHook(ctx)
|
||||||
|
|
||||||
|
if tt.shouldAbort {
|
||||||
|
if !ctx.Abort {
|
||||||
|
t.Error("Expected security hook to abort operation")
|
||||||
|
}
|
||||||
|
if ctx.AbortCode != 403 {
|
||||||
|
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ctx.Abort {
|
||||||
|
t.Error("Expected security hook not to abort operation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExampleResultFilterHook(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID int
|
||||||
|
result interface{}
|
||||||
|
validate func(t *testing.T, result interface{})
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Admin user - no filtering",
|
||||||
|
userID: 1,
|
||||||
|
result: map[string]interface{}{
|
||||||
|
"id": 1,
|
||||||
|
"name": "Test",
|
||||||
|
"password": "secret",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, result interface{}) {
|
||||||
|
m := result.(map[string]interface{})
|
||||||
|
if _, exists := m["password"]; !exists {
|
||||||
|
t.Error("Expected password field to remain for admin")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Regular user - sensitive fields removed",
|
||||||
|
userID: 2,
|
||||||
|
result: map[string]interface{}{
|
||||||
|
"id": 1,
|
||||||
|
"name": "Test",
|
||||||
|
"password": "secret",
|
||||||
|
"ssn": "123-45-6789",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, result interface{}) {
|
||||||
|
m := result.(map[string]interface{})
|
||||||
|
if _, exists := m["password"]; exists {
|
||||||
|
t.Error("Expected password field to be removed")
|
||||||
|
}
|
||||||
|
if _, exists := m["ssn"]; exists {
|
||||||
|
t.Error("Expected ssn field to be removed")
|
||||||
|
}
|
||||||
|
if _, exists := m["name"]; !exists {
|
||||||
|
t.Error("Expected name field to remain")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Regular user - list results filtered",
|
||||||
|
userID: 2,
|
||||||
|
result: []map[string]interface{}{
|
||||||
|
{"id": 1, "name": "User 1", "password": "secret1"},
|
||||||
|
{"id": 2, "name": "User 2", "password": "secret2"},
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, result interface{}) {
|
||||||
|
list := result.([]map[string]interface{})
|
||||||
|
for _, m := range list {
|
||||||
|
if _, exists := m["password"]; exists {
|
||||||
|
t.Error("Expected password field to be removed from list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Result: tt.result,
|
||||||
|
UserContext: &security.UserContext{
|
||||||
|
UserID: tt.userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ExampleResultFilterHook(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, ctx.Result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExampleAuditHook(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Request: req,
|
||||||
|
UserContext: &security.UserContext{
|
||||||
|
UserID: 123,
|
||||||
|
UserName: "testuser",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ExampleAuditHook(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ExampleAuditHook failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExampleErrorHandlingHook(t *testing.T) {
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
SQLQuery: "SELECT * FROM test",
|
||||||
|
Error: fmt.Errorf("test error"),
|
||||||
|
UserContext: &security.UserContext{
|
||||||
|
UserName: "testuser",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ExampleErrorHandlingHook(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookIntegrationWithHandler tests hooks integrated with the handler
|
||||||
|
func TestHookIntegrationWithHandler(t *testing.T) {
|
||||||
|
db := &MockDatabase{
|
||||||
|
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||||
|
queryDB := &MockDatabase{
|
||||||
|
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
rows := dest.(*[]map[string]interface{})
|
||||||
|
*rows = []map[string]interface{}{
|
||||||
|
{"id": float64(1), "name": "Test User"},
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return fn(queryDB)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewHandler(db)
|
||||||
|
|
||||||
|
// Register a hook that modifies the SQL query
|
||||||
|
hookCalled := false
|
||||||
|
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
|
||||||
|
hookCalled = true
|
||||||
|
// Verify we can access context data
|
||||||
|
if ctx.SQLQuery == "" {
|
||||||
|
t.Error("Expected SQL query to be set")
|
||||||
|
}
|
||||||
|
if ctx.UserContext == nil {
|
||||||
|
t.Error("Expected user context to be set")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute a query
|
||||||
|
req := createTestRequest("GET", "/test", nil, nil, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
|
||||||
|
handlerFunc(w, req)
|
||||||
|
|
||||||
|
if !hookCalled {
|
||||||
|
t.Error("Expected hook to be called during query execution")
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Code != 200 {
|
||||||
|
t.Errorf("Expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
411
pkg/funcspec/parameters.go
Normal file
411
pkg/funcspec/parameters.go
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
package funcspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestParameters holds parsed parameters from headers and query string
|
||||||
|
type RequestParameters struct {
|
||||||
|
// Field selection
|
||||||
|
SelectFields []string
|
||||||
|
NotSelectFields []string
|
||||||
|
Distinct bool
|
||||||
|
|
||||||
|
// Filtering
|
||||||
|
FieldFilters map[string]string // column -> value (exact match)
|
||||||
|
SearchFilters map[string]string // column -> value (ILIKE)
|
||||||
|
SearchOps map[string]FilterOperator // column -> {operator, value, logic}
|
||||||
|
CustomSQLWhere string
|
||||||
|
CustomSQLOr string
|
||||||
|
|
||||||
|
// Sorting & Pagination
|
||||||
|
SortColumns string
|
||||||
|
Limit int
|
||||||
|
Offset int
|
||||||
|
|
||||||
|
// Advanced features
|
||||||
|
SkipCount bool
|
||||||
|
SkipCache bool
|
||||||
|
|
||||||
|
// Response format
|
||||||
|
ResponseFormat string // "simple", "detail", "syncfusion"
|
||||||
|
ComplexAPI bool // true if NOT simple API
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterOperator represents a filter with operator
|
||||||
|
type FilterOperator struct {
|
||||||
|
Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc.
|
||||||
|
Value string
|
||||||
|
Logic string // AND or OR
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseParameters parses all parameters from request headers and query string
|
||||||
|
func (h *Handler) ParseParameters(r *http.Request) *RequestParameters {
|
||||||
|
params := &RequestParameters{
|
||||||
|
FieldFilters: make(map[string]string),
|
||||||
|
SearchFilters: make(map[string]string),
|
||||||
|
SearchOps: make(map[string]FilterOperator),
|
||||||
|
Limit: 20, // Default limit
|
||||||
|
Offset: 0, // Default offset
|
||||||
|
ResponseFormat: "simple", // Default format
|
||||||
|
ComplexAPI: false, // Default to simple API
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge headers and query parameters
|
||||||
|
combined := make(map[string]string)
|
||||||
|
|
||||||
|
// Add all headers (normalize to lowercase)
|
||||||
|
for key, values := range r.Header {
|
||||||
|
if len(values) > 0 {
|
||||||
|
combined[strings.ToLower(key)] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all query parameters (override headers)
|
||||||
|
for key, values := range r.URL.Query() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
combined[strings.ToLower(key)] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse each parameter
|
||||||
|
for key, value := range combined {
|
||||||
|
// Decode value if base64 encoded
|
||||||
|
decodedValue := h.decodeValue(value)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
// Field Selection
|
||||||
|
case strings.HasPrefix(key, "x-select-fields"):
|
||||||
|
params.SelectFields = h.parseCommaSeparated(decodedValue)
|
||||||
|
case strings.HasPrefix(key, "x-not-select-fields"):
|
||||||
|
params.NotSelectFields = h.parseCommaSeparated(decodedValue)
|
||||||
|
case strings.HasPrefix(key, "x-distinct"):
|
||||||
|
params.Distinct = strings.EqualFold(decodedValue, "true")
|
||||||
|
|
||||||
|
// Filtering
|
||||||
|
case strings.HasPrefix(key, "x-fieldfilter-"):
|
||||||
|
colName := strings.TrimPrefix(key, "x-fieldfilter-")
|
||||||
|
params.FieldFilters[colName] = decodedValue
|
||||||
|
case strings.HasPrefix(key, "x-searchfilter-"):
|
||||||
|
colName := strings.TrimPrefix(key, "x-searchfilter-")
|
||||||
|
params.SearchFilters[colName] = decodedValue
|
||||||
|
case strings.HasPrefix(key, "x-searchop-"):
|
||||||
|
h.parseSearchOp(params, key, decodedValue, "AND")
|
||||||
|
case strings.HasPrefix(key, "x-searchor-"):
|
||||||
|
h.parseSearchOp(params, key, decodedValue, "OR")
|
||||||
|
case strings.HasPrefix(key, "x-searchand-"):
|
||||||
|
h.parseSearchOp(params, key, decodedValue, "AND")
|
||||||
|
case strings.HasPrefix(key, "x-custom-sql-w"):
|
||||||
|
if params.CustomSQLWhere != "" {
|
||||||
|
params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue)
|
||||||
|
} else {
|
||||||
|
params.CustomSQLWhere = decodedValue
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(key, "x-custom-sql-or"):
|
||||||
|
if params.CustomSQLOr != "" {
|
||||||
|
params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue)
|
||||||
|
} else {
|
||||||
|
params.CustomSQLOr = decodedValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorting & Pagination
|
||||||
|
case key == "sort" || strings.HasPrefix(key, "x-sort"):
|
||||||
|
params.SortColumns = decodedValue
|
||||||
|
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
|
||||||
|
// Handle sort(col1,-col2) syntax
|
||||||
|
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||||
|
params.SortColumns = sortValue
|
||||||
|
case key == "limit" || strings.HasPrefix(key, "x-limit"):
|
||||||
|
if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 {
|
||||||
|
params.Limit = limit
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
|
||||||
|
// Handle limit(offset,limit) or limit(limit) syntax
|
||||||
|
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||||
|
parts := strings.Split(limitValue, ",")
|
||||||
|
if len(parts) > 1 {
|
||||||
|
if offset, err := strconv.Atoi(parts[0]); err == nil {
|
||||||
|
params.Offset = offset
|
||||||
|
}
|
||||||
|
if limit, err := strconv.Atoi(parts[1]); err == nil {
|
||||||
|
params.Limit = limit
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if limit, err := strconv.Atoi(parts[0]); err == nil {
|
||||||
|
params.Limit = limit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case key == "offset" || strings.HasPrefix(key, "x-offset"):
|
||||||
|
if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 {
|
||||||
|
params.Offset = offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advanced features
|
||||||
|
case strings.HasPrefix(key, "x-skipcount"):
|
||||||
|
params.SkipCount = strings.EqualFold(decodedValue, "true")
|
||||||
|
case strings.HasPrefix(key, "x-skipcache"):
|
||||||
|
params.SkipCache = strings.EqualFold(decodedValue, "true")
|
||||||
|
|
||||||
|
// Response Format
|
||||||
|
case strings.HasPrefix(key, "x-simpleapi"):
|
||||||
|
params.ResponseFormat = "simple"
|
||||||
|
params.ComplexAPI = !(decodedValue == "1" || strings.EqualFold(decodedValue, "true"))
|
||||||
|
case strings.HasPrefix(key, "x-detailapi"):
|
||||||
|
params.ResponseFormat = "detail"
|
||||||
|
params.ComplexAPI = true
|
||||||
|
case strings.HasPrefix(key, "x-syncfusion"):
|
||||||
|
params.ResponseFormat = "syncfusion"
|
||||||
|
params.ComplexAPI = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column}
|
||||||
|
func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) {
|
||||||
|
var prefix string
|
||||||
|
if logic == "OR" {
|
||||||
|
prefix = "x-searchor-"
|
||||||
|
} else {
|
||||||
|
prefix = "x-searchop-"
|
||||||
|
if strings.HasPrefix(headerKey, "x-searchand-") {
|
||||||
|
prefix = "x-searchand-"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rest := strings.TrimPrefix(headerKey, prefix)
|
||||||
|
parts := strings.SplitN(rest, "-", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
logger.Warn("Invalid search operator header format: %s", headerKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
operator := parts[0]
|
||||||
|
colName := parts[1]
|
||||||
|
|
||||||
|
params.SearchOps[colName] = FilterOperator{
|
||||||
|
Operator: operator,
|
||||||
|
Value: value,
|
||||||
|
Logic: logic,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeValue decodes base64 encoded values (ZIP_ or __ prefix)
|
||||||
|
func (h *Handler) decodeValue(value string) string {
|
||||||
|
decoded, _ := restheadspec.DecodeParam(value)
|
||||||
|
return decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCommaSeparated parses comma-separated values
|
||||||
|
func (h *Handler) parseCommaSeparated(value string) []string {
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(value, ",")
|
||||||
|
result := make([]string, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if part != "" {
|
||||||
|
result = append(result, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyFieldSelection applies column selection to SQL query
|
||||||
|
func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string {
|
||||||
|
if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 {
|
||||||
|
return sqlQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is a simplified implementation
|
||||||
|
// A full implementation would parse the SQL and replace the SELECT clause
|
||||||
|
// For now, we log a warning that this feature needs manual implementation
|
||||||
|
if len(params.SelectFields) > 0 {
|
||||||
|
logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields)
|
||||||
|
}
|
||||||
|
if len(params.NotSelectFields) > 0 {
|
||||||
|
logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sqlQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyFilters applies all filters to the SQL query
|
||||||
|
func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string {
|
||||||
|
// Apply field filters (exact match)
|
||||||
|
for colName, value := range params.FieldFilters {
|
||||||
|
condition := ""
|
||||||
|
if value == "" || value == "0" {
|
||||||
|
condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
|
||||||
|
} else {
|
||||||
|
condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
|
||||||
|
}
|
||||||
|
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||||
|
logger.Debug("Applied field filter: %s", condition)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply search filters (ILIKE)
|
||||||
|
for colName, value := range params.SearchFilters {
|
||||||
|
sval := strings.ReplaceAll(value, "'", "")
|
||||||
|
if sval != "" {
|
||||||
|
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
|
||||||
|
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||||
|
logger.Debug("Applied search filter: %s", condition)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply search operators
|
||||||
|
for colName, filterOp := range params.SearchOps {
|
||||||
|
condition := h.buildFilterCondition(colName, filterOp)
|
||||||
|
if condition != "" {
|
||||||
|
if filterOp.Logic == "OR" {
|
||||||
|
sqlQuery = sqlQryWhereOr(sqlQuery, condition)
|
||||||
|
} else {
|
||||||
|
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||||
|
}
|
||||||
|
logger.Debug("Applied search operator: %s", condition)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply custom SQL WHERE
|
||||||
|
if params.CustomSQLWhere != "" {
|
||||||
|
colval := ValidSQL(params.CustomSQLWhere, "select")
|
||||||
|
if colval != "" {
|
||||||
|
sqlQuery = sqlQryWhere(sqlQuery, colval)
|
||||||
|
logger.Debug("Applied custom SQL WHERE: %s", colval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply custom SQL OR
|
||||||
|
if params.CustomSQLOr != "" {
|
||||||
|
colval := ValidSQL(params.CustomSQLOr, "select")
|
||||||
|
if colval != "" {
|
||||||
|
sqlQuery = sqlQryWhereOr(sqlQuery, colval)
|
||||||
|
logger.Debug("Applied custom SQL OR: %s", colval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sqlQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildFilterCondition builds a SQL condition from a FilterOperator
|
||||||
|
func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string {
|
||||||
|
safCol := ValidSQL(colName, "colname")
|
||||||
|
operator := strings.ToLower(op.Operator)
|
||||||
|
value := op.Value
|
||||||
|
|
||||||
|
switch operator {
|
||||||
|
case "contains", "contain", "like":
|
||||||
|
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "beginswith", "startswith":
|
||||||
|
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "endswith":
|
||||||
|
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "equals", "eq", "=":
|
||||||
|
if IsNumeric(value) {
|
||||||
|
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "notequals", "neq", "ne", "!=", "<>":
|
||||||
|
if IsNumeric(value) {
|
||||||
|
return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "greaterthan", "gt", ">":
|
||||||
|
return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "lessthan", "lt", "<":
|
||||||
|
return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "greaterthanorequal", "gte", "ge", ">=":
|
||||||
|
return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "lessthanorequal", "lte", "le", "<=":
|
||||||
|
return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
case "between":
|
||||||
|
parts := strings.Split(value, ",")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
|
||||||
|
}
|
||||||
|
case "betweeninclusive":
|
||||||
|
parts := strings.Split(value, ",")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
|
||||||
|
}
|
||||||
|
case "in":
|
||||||
|
values := strings.Split(value, ",")
|
||||||
|
safeValues := make([]string, len(values))
|
||||||
|
for i, v := range values {
|
||||||
|
safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue"))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", "))
|
||||||
|
case "empty", "isnull", "null":
|
||||||
|
return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol)
|
||||||
|
case "notempty", "isnotnull", "notnull":
|
||||||
|
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol)
|
||||||
|
default:
|
||||||
|
logger.Warn("Unknown filter operator: %s, defaulting to equals", operator)
|
||||||
|
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDistinct adds DISTINCT to SQL query if requested
|
||||||
|
func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string {
|
||||||
|
if !params.Distinct {
|
||||||
|
return sqlQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add DISTINCT after SELECT
|
||||||
|
selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT")
|
||||||
|
if selectPos >= 0 {
|
||||||
|
beforeSelect := sqlQuery[:selectPos+6] // "SELECT"
|
||||||
|
afterSelect := sqlQuery[selectPos+6:]
|
||||||
|
sqlQuery = beforeSelect + " DISTINCT" + afterSelect
|
||||||
|
logger.Debug("Applied DISTINCT to query")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sqlQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlQryWhereOr adds a WHERE clause with OR logic
|
||||||
|
func sqlQryWhereOr(sqlquery, condition string) string {
|
||||||
|
lowerQuery := strings.ToLower(sqlquery)
|
||||||
|
wherePos := strings.Index(lowerQuery, " where ")
|
||||||
|
groupPos := strings.Index(lowerQuery, " group by")
|
||||||
|
orderPos := strings.Index(lowerQuery, " order by")
|
||||||
|
limitPos := strings.Index(lowerQuery, " limit ")
|
||||||
|
|
||||||
|
// Find the insertion point
|
||||||
|
insertPos := len(sqlquery)
|
||||||
|
if groupPos > 0 && groupPos < insertPos {
|
||||||
|
insertPos = groupPos
|
||||||
|
}
|
||||||
|
if orderPos > 0 && orderPos < insertPos {
|
||||||
|
insertPos = orderPos
|
||||||
|
}
|
||||||
|
if limitPos > 0 && limitPos < insertPos {
|
||||||
|
insertPos = limitPos
|
||||||
|
}
|
||||||
|
|
||||||
|
if wherePos > 0 {
|
||||||
|
// WHERE exists, add OR condition
|
||||||
|
before := sqlquery[:insertPos]
|
||||||
|
after := sqlquery[insertPos:]
|
||||||
|
return fmt.Sprintf("%s OR (%s) %s", before, condition, after)
|
||||||
|
} else {
|
||||||
|
// No WHERE exists, add it
|
||||||
|
before := sqlquery[:insertPos]
|
||||||
|
after := sqlquery[insertPos:]
|
||||||
|
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
|
||||||
|
}
|
||||||
|
}
|
||||||
549
pkg/funcspec/parameters_test.go
Normal file
549
pkg/funcspec/parameters_test.go
Normal file
@ -0,0 +1,549 @@
|
|||||||
|
package funcspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseParameters tests the comprehensive parameter parsing
|
||||||
|
func TestParseParameters(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
headers map[string]string
|
||||||
|
validate func(t *testing.T, params *RequestParameters)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Parse field selection",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Select-Fields": "id,name,email",
|
||||||
|
"X-Not-Select-Fields": "password,ssn",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if len(params.SelectFields) != 3 {
|
||||||
|
t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields))
|
||||||
|
}
|
||||||
|
if len(params.NotSelectFields) != 2 {
|
||||||
|
t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse distinct flag",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Distinct": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if !params.Distinct {
|
||||||
|
t.Error("Expected Distinct to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse field filters",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-FieldFilter-Status": "active",
|
||||||
|
"X-FieldFilter-Type": "admin",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if len(params.FieldFilters) != 2 {
|
||||||
|
t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters))
|
||||||
|
}
|
||||||
|
if params.FieldFilters["status"] != "active" {
|
||||||
|
t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse search filters",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-SearchFilter-Name": "john",
|
||||||
|
"X-SearchFilter-Email": "test",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if len(params.SearchFilters) != 2 {
|
||||||
|
t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse sort columns",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"sort": "-created_at,name",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if params.SortColumns != "-created_at,name" {
|
||||||
|
t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse limit and offset",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"limit": "100",
|
||||||
|
"offset": "50",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if params.Limit != 100 {
|
||||||
|
t.Errorf("Expected limit=100, got %d", params.Limit)
|
||||||
|
}
|
||||||
|
if params.Offset != 50 {
|
||||||
|
t.Errorf("Expected offset=50, got %d", params.Offset)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse skip count",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-SkipCount": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if !params.SkipCount {
|
||||||
|
t.Error("Expected SkipCount to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse response format - syncfusion",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Syncfusion": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if params.ResponseFormat != "syncfusion" {
|
||||||
|
t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat)
|
||||||
|
}
|
||||||
|
if !params.ComplexAPI {
|
||||||
|
t.Error("Expected ComplexAPI to be true for syncfusion format")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse response format - detail",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-DetailAPI": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if params.ResponseFormat != "detail" {
|
||||||
|
t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse simple API",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-SimpleAPI": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if params.ResponseFormat != "simple" {
|
||||||
|
t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat)
|
||||||
|
}
|
||||||
|
if params.ComplexAPI {
|
||||||
|
t.Error("Expected ComplexAPI to be false for simple API")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse custom SQL WHERE",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Custom-SQL-W": "status = 'active' AND deleted = false",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if params.CustomSQLWhere == "" {
|
||||||
|
t.Error("Expected CustomSQLWhere to be set")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse search operators - AND",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-SearchOp-Eq-Name": "john",
|
||||||
|
"X-SearchOp-Gt-Age": "18",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if len(params.SearchOps) != 2 {
|
||||||
|
t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps))
|
||||||
|
}
|
||||||
|
if op, exists := params.SearchOps["name"]; exists {
|
||||||
|
if op.Operator != "eq" {
|
||||||
|
t.Errorf("Expected operator=eq for name, got %s", op.Operator)
|
||||||
|
}
|
||||||
|
if op.Logic != "AND" {
|
||||||
|
t.Errorf("Expected logic=AND, got %s", op.Logic)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("Expected name search operator to exist")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse search operators - OR",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-SearchOr-Like-Description": "test",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, params *RequestParameters) {
|
||||||
|
if op, exists := params.SearchOps["description"]; exists {
|
||||||
|
if op.Logic != "OR" {
|
||||||
|
t.Errorf("Expected logic=OR, got %s", op.Logic)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("Expected description search operator to exist")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||||
|
params := handler.ParseParameters(req)
|
||||||
|
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(t, params)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildFilterCondition tests the filter condition builder
|
||||||
|
func TestBuildFilterCondition(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
colName string
|
||||||
|
operator FilterOperator
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equals operator - numeric",
|
||||||
|
colName: "age",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "eq",
|
||||||
|
Value: "25",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "age = 25",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Equals operator - string",
|
||||||
|
colName: "name",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "eq",
|
||||||
|
Value: "john",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "name = 'john'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not equals operator",
|
||||||
|
colName: "status",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "neq",
|
||||||
|
Value: "inactive",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "status != 'inactive'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Greater than operator",
|
||||||
|
colName: "age",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "gt",
|
||||||
|
Value: "18",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "age > 18",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Less than operator",
|
||||||
|
colName: "price",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "lt",
|
||||||
|
Value: "100",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "price < 100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Contains operator",
|
||||||
|
colName: "description",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "contains",
|
||||||
|
Value: "test",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "description ILIKE '%test%'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Starts with operator",
|
||||||
|
colName: "name",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "startswith",
|
||||||
|
Value: "john",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "name ILIKE 'john%'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Ends with operator",
|
||||||
|
colName: "email",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "endswith",
|
||||||
|
Value: "@example.com",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "email ILIKE '%@example.com'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Between operator",
|
||||||
|
colName: "age",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "between",
|
||||||
|
Value: "18,65",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "age > 18 AND age < 65",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IN operator",
|
||||||
|
colName: "status",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "in",
|
||||||
|
Value: "active,pending,approved",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "status IN ('active', 'pending', 'approved')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IS NULL operator",
|
||||||
|
colName: "deleted_at",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "null",
|
||||||
|
Value: "",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "(deleted_at IS NULL OR deleted_at = '')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IS NOT NULL operator",
|
||||||
|
colName: "created_at",
|
||||||
|
operator: FilterOperator{
|
||||||
|
Operator: "notnull",
|
||||||
|
Value: "",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
expected: "(created_at IS NOT NULL AND created_at != '')",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.buildFilterCondition(tt.colName, tt.operator)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected: %s\nGot: %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestApplyFilters tests the filter application to SQL queries
|
||||||
|
func TestApplyFilters(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
params *RequestParameters
|
||||||
|
expectedSQL string
|
||||||
|
shouldContain []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Apply field filter",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
params: &RequestParameters{
|
||||||
|
FieldFilters: map[string]string{
|
||||||
|
"status": "active",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldContain: []string{"WHERE", "status"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Apply search filter",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
params: &RequestParameters{
|
||||||
|
SearchFilters: map[string]string{
|
||||||
|
"name": "john",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldContain: []string{"WHERE", "name", "ILIKE"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Apply search operators",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
params: &RequestParameters{
|
||||||
|
SearchOps: map[string]FilterOperator{
|
||||||
|
"age": {
|
||||||
|
Operator: "gt",
|
||||||
|
Value: "18",
|
||||||
|
Logic: "AND",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldContain: []string{"WHERE", "age", ">", "18"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Apply custom SQL WHERE",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
params: &RequestParameters{
|
||||||
|
CustomSQLWhere: "deleted = false",
|
||||||
|
},
|
||||||
|
shouldContain: []string{"WHERE", "deleted"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.ApplyFilters(tt.sqlQuery, tt.params)
|
||||||
|
|
||||||
|
for _, expected := range tt.shouldContain {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestApplyDistinct tests DISTINCT application
|
||||||
|
func TestApplyDistinct(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
distinct bool
|
||||||
|
shouldHave string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Apply DISTINCT",
|
||||||
|
sqlQuery: "SELECT id, name FROM users",
|
||||||
|
distinct: true,
|
||||||
|
shouldHave: "SELECT DISTINCT",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Do not apply DISTINCT",
|
||||||
|
sqlQuery: "SELECT id, name FROM users",
|
||||||
|
distinct: false,
|
||||||
|
shouldHave: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
params := &RequestParameters{Distinct: tt.distinct}
|
||||||
|
result := handler.ApplyDistinct(tt.sqlQuery, params)
|
||||||
|
|
||||||
|
if tt.shouldHave != "" {
|
||||||
|
if !strings.Contains(result, tt.shouldHave) {
|
||||||
|
t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Should not have DISTINCT when not requested
|
||||||
|
if strings.Contains(result, "DISTINCT") && !tt.distinct {
|
||||||
|
t.Errorf("SQL should not contain DISTINCT when not requested: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseCommaSeparated tests comma-separated value parsing
|
||||||
|
func TestParseCommaSeparated(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Simple comma-separated",
|
||||||
|
input: "id,name,email",
|
||||||
|
expected: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With spaces",
|
||||||
|
input: "id, name, email",
|
||||||
|
expected: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty string",
|
||||||
|
input: "",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single value",
|
||||||
|
input: "id",
|
||||||
|
expected: []string{"id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With extra commas",
|
||||||
|
input: "id,,name,,email",
|
||||||
|
expected: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.parseCommaSeparated(tt.input)
|
||||||
|
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("Expected %d values, got %d", len(tt.expected), len(result))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expected := range tt.expected {
|
||||||
|
if result[i] != expected {
|
||||||
|
t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlQryWhereOr tests OR WHERE clause manipulation
|
||||||
|
func TestSqlQryWhereOr(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
condition string
|
||||||
|
shouldContain []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Add WHERE with OR to query without WHERE",
|
||||||
|
sqlQuery: "SELECT * FROM users",
|
||||||
|
condition: "status = 'inactive'",
|
||||||
|
shouldContain: []string{"WHERE", "status = 'inactive'"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Add OR to query with existing WHERE",
|
||||||
|
sqlQuery: "SELECT * FROM users WHERE id > 0",
|
||||||
|
condition: "status = 'inactive'",
|
||||||
|
shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := sqlQryWhereOr(tt.sqlQuery, tt.condition)
|
||||||
|
|
||||||
|
for _, expected := range tt.shouldContain {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -8,7 +8,9 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
@ -233,13 +235,46 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get total count before pagination
|
// Get total count before pagination
|
||||||
total, err := query.Count(ctx)
|
var total int
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error counting records: %v", err)
|
// Try to get from cache first
|
||||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
|
cacheKeyHash := cache.BuildQueryCacheKey(
|
||||||
return
|
tableName,
|
||||||
|
options.Filters,
|
||||||
|
options.Sort,
|
||||||
|
"", // No custom SQL WHERE in resolvespec
|
||||||
|
"", // No custom SQL OR in resolvespec
|
||||||
|
)
|
||||||
|
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
|
||||||
|
|
||||||
|
// Try to retrieve from cache
|
||||||
|
var cachedTotal cache.CachedTotal
|
||||||
|
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||||
|
if err == nil {
|
||||||
|
total = cachedTotal.Total
|
||||||
|
logger.Debug("Total records (from cache): %d", total)
|
||||||
|
} else {
|
||||||
|
// Cache miss - execute count query
|
||||||
|
logger.Debug("Cache miss for query total")
|
||||||
|
count, err := query.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error counting records: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
total = count
|
||||||
|
logger.Debug("Total records (from query): %d", total)
|
||||||
|
|
||||||
|
// Store in cache
|
||||||
|
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||||
|
cacheData := cache.CachedTotal{Total: total}
|
||||||
|
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
||||||
|
logger.Warn("Failed to cache query total: %v", err)
|
||||||
|
// Don't fail the request if caching fails
|
||||||
|
} else {
|
||||||
|
logger.Debug("Cached query total with key: %s", cacheKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
logger.Debug("Total records before filtering: %d", total)
|
|
||||||
|
|
||||||
// Apply pagination
|
// Apply pagination
|
||||||
if options.Limit != nil && *options.Limit > 0 {
|
if options.Limit != nil && *options.Limit > 0 {
|
||||||
|
|||||||
@ -9,7 +9,9 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
@ -436,14 +438,69 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Get total count before pagination (unless skip count is requested)
|
// Get total count before pagination (unless skip count is requested)
|
||||||
var total int
|
var total int
|
||||||
if !options.SkipCount {
|
if !options.SkipCount {
|
||||||
count, err := query.Count(ctx)
|
// Try to get from cache first (unless SkipCache is true)
|
||||||
if err != nil {
|
var cachedTotal *cache.CachedTotal
|
||||||
logger.Error("Error counting records: %v", err)
|
var cacheKey string
|
||||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
|
|
||||||
return
|
if !options.SkipCache {
|
||||||
|
// Build cache key from query parameters
|
||||||
|
// Convert expand options to interface slice for the cache key builder
|
||||||
|
expandOpts := make([]interface{}, len(options.Expand))
|
||||||
|
for i, exp := range options.Expand {
|
||||||
|
expandOpts[i] = map[string]interface{}{
|
||||||
|
"relation": exp.Relation,
|
||||||
|
"where": exp.Where,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
|
||||||
|
tableName,
|
||||||
|
options.Filters,
|
||||||
|
options.Sort,
|
||||||
|
options.CustomSQLWhere,
|
||||||
|
options.CustomSQLOr,
|
||||||
|
expandOpts,
|
||||||
|
options.Distinct,
|
||||||
|
options.CursorForward,
|
||||||
|
options.CursorBackward,
|
||||||
|
)
|
||||||
|
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
|
||||||
|
|
||||||
|
// Try to retrieve from cache
|
||||||
|
cachedTotal = &cache.CachedTotal{}
|
||||||
|
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
|
||||||
|
if err == nil {
|
||||||
|
total = cachedTotal.Total
|
||||||
|
logger.Debug("Total records (from cache): %d", total)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Cache miss for query total")
|
||||||
|
cachedTotal = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not in cache or cache skip, execute count query
|
||||||
|
if cachedTotal == nil {
|
||||||
|
count, err := query.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error counting records: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
total = count
|
||||||
|
logger.Debug("Total records (from query): %d", total)
|
||||||
|
|
||||||
|
// Store in cache (if caching is enabled)
|
||||||
|
if !options.SkipCache && cacheKey != "" {
|
||||||
|
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||||
|
cacheData := &cache.CachedTotal{Total: total}
|
||||||
|
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
||||||
|
logger.Warn("Failed to cache query total: %v", err)
|
||||||
|
// Don't fail the request if caching fails
|
||||||
|
} else {
|
||||||
|
logger.Debug("Cached query total with key: %s", cacheKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
total = count
|
|
||||||
logger.Debug("Total records: %d", total)
|
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Skipping count as requested")
|
logger.Debug("Skipping count as requested")
|
||||||
total = -1 // Indicate count was skipped
|
total = -1 // Indicate count was skipped
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user