mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed67caf055 | ||
|
|
63ed62a9a3 | ||
|
|
0525323a47 | ||
|
|
c3443f702e | ||
|
|
45c463c117 | ||
|
|
84d673ce14 | ||
|
|
02fbdbd651 |
6
Makefile
6
Makefile
@@ -16,7 +16,7 @@ test: test-unit test-integration
|
|||||||
# Start PostgreSQL for integration tests
|
# Start PostgreSQL for integration tests
|
||||||
docker-up:
|
docker-up:
|
||||||
@echo "Starting PostgreSQL container..."
|
@echo "Starting PostgreSQL container..."
|
||||||
@docker-compose up -d postgres-test
|
@podman compose up -d postgres-test
|
||||||
@echo "Waiting for PostgreSQL to be ready..."
|
@echo "Waiting for PostgreSQL to be ready..."
|
||||||
@sleep 5
|
@sleep 5
|
||||||
@echo "PostgreSQL is ready!"
|
@echo "PostgreSQL is ready!"
|
||||||
@@ -24,12 +24,12 @@ docker-up:
|
|||||||
# Stop PostgreSQL container
|
# Stop PostgreSQL container
|
||||||
docker-down:
|
docker-down:
|
||||||
@echo "Stopping PostgreSQL container..."
|
@echo "Stopping PostgreSQL container..."
|
||||||
@docker-compose down
|
@podman compose down
|
||||||
|
|
||||||
# Clean up Docker volumes and test data
|
# Clean up Docker volumes and test data
|
||||||
clean:
|
clean:
|
||||||
@echo "Cleaning up..."
|
@echo "Cleaning up..."
|
||||||
@docker-compose down -v
|
@podman compose down -v
|
||||||
@echo "Cleanup complete!"
|
@echo "Cleanup complete!"
|
||||||
|
|
||||||
# Run integration tests with Docker (full workflow)
|
# Run integration tests with Docker (full workflow)
|
||||||
|
|||||||
20
pkg/cache/cache_manager.go
vendored
20
pkg/cache/cache_manager.go
vendored
@@ -57,11 +57,31 @@ func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time
|
|||||||
return c.provider.Set(ctx, key, value, ttl)
|
return c.provider.Set(ctx, key, value, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags serializes and stores a value in the cache with the specified TTL and tags.
|
||||||
|
func (c *Cache) SetWithTags(ctx context.Context, key string, value interface{}, ttl time.Duration, tags []string) error {
|
||||||
|
data, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to serialize: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.provider.SetWithTags(ctx, key, data, ttl, tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBytesWithTags stores raw bytes in the cache with the specified TTL and tags.
|
||||||
|
func (c *Cache) SetBytesWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
return c.provider.SetWithTags(ctx, key, value, ttl, tags)
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (c *Cache) Delete(ctx context.Context, key string) error {
|
func (c *Cache) Delete(ctx context.Context, key string) error {
|
||||||
return c.provider.Delete(ctx, key)
|
return c.provider.Delete(ctx, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (c *Cache) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
return c.provider.DeleteByTag(ctx, tag)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
return c.provider.DeleteByPattern(ctx, pattern)
|
return c.provider.DeleteByPattern(ctx, pattern)
|
||||||
|
|||||||
8
pkg/cache/provider.go
vendored
8
pkg/cache/provider.go
vendored
@@ -15,9 +15,17 @@ type Provider interface {
|
|||||||
// If ttl is 0, the item never expires.
|
// If ttl is 0, the item never expires.
|
||||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
// Tags can be used to invalidate groups of related keys.
|
||||||
|
// If ttl is 0, the item never expires.
|
||||||
|
SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
Delete(ctx context.Context, key string) error
|
Delete(ctx context.Context, key string) error
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
DeleteByTag(ctx context.Context, tag string) error
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
// Pattern syntax depends on the provider implementation.
|
// Pattern syntax depends on the provider implementation.
|
||||||
DeleteByPattern(ctx context.Context, pattern string) error
|
DeleteByPattern(ctx context.Context, pattern string) error
|
||||||
|
|||||||
140
pkg/cache/provider_memcache.go
vendored
140
pkg/cache/provider_memcache.go
vendored
@@ -2,6 +2,7 @@ package cache
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -97,8 +98,115 @@ func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, tt
|
|||||||
return m.client.Set(item)
|
return m.client.Set(item)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
// Note: Tag support in Memcache is limited and less efficient than Redis.
|
||||||
|
func (m *MemcacheProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = m.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
expiration := int32(ttl.Seconds())
|
||||||
|
|
||||||
|
// Set the main value
|
||||||
|
item := &memcache.Item{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
if err := m.client.Set(item); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store tags for this key
|
||||||
|
if len(tags) > 0 {
|
||||||
|
tagsData, err := json.Marshal(tags)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal tags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tagsItem := &memcache.Item{
|
||||||
|
Key: fmt.Sprintf("cache:tags:%s", key),
|
||||||
|
Value: tagsData,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
if err := m.client.Set(tagsItem); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add key to each tag's key list
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
|
||||||
|
// Get existing keys for this tag
|
||||||
|
var keys []string
|
||||||
|
if item, err := m.client.Get(tagKey); err == nil {
|
||||||
|
_ = json.Unmarshal(item.Value, &keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add current key if not already present
|
||||||
|
found := false
|
||||||
|
for _, k := range keys {
|
||||||
|
if k == key {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store updated key list
|
||||||
|
keysData, err := json.Marshal(keys)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tagItem := &memcache.Item{
|
||||||
|
Key: tagKey,
|
||||||
|
Value: keysData,
|
||||||
|
Expiration: expiration + 3600, // Give tag lists longer TTL
|
||||||
|
}
|
||||||
|
_ = m.client.Set(tagItem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||||
|
// Get tags for this key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
if item, err := m.client.Get(tagsKey); err == nil {
|
||||||
|
var tags []string
|
||||||
|
if err := json.Unmarshal(item.Value, &tags); err == nil {
|
||||||
|
// Remove key from each tag's key list
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
if tagItem, err := m.client.Get(tagKey); err == nil {
|
||||||
|
var keys []string
|
||||||
|
if err := json.Unmarshal(tagItem.Value, &keys); err == nil {
|
||||||
|
// Remove current key from the list
|
||||||
|
newKeys := make([]string, 0, len(keys))
|
||||||
|
for _, k := range keys {
|
||||||
|
if k != key {
|
||||||
|
newKeys = append(newKeys, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update the tag's key list
|
||||||
|
if keysData, err := json.Marshal(newKeys); err == nil {
|
||||||
|
tagItem.Value = keysData
|
||||||
|
_ = m.client.Set(tagItem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Delete the tags key
|
||||||
|
_ = m.client.Delete(tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the actual key
|
||||||
err := m.client.Delete(key)
|
err := m.client.Delete(key)
|
||||||
if err == memcache.ErrCacheMiss {
|
if err == memcache.ErrCacheMiss {
|
||||||
return nil
|
return nil
|
||||||
@@ -106,6 +214,38 @@ func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (m *MemcacheProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
|
||||||
|
// Get all keys associated with this tag
|
||||||
|
item, err := m.client.Get(tagKey)
|
||||||
|
if err == memcache.ErrCacheMiss {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
if err := json.Unmarshal(item.Value, &keys); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal tag keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete all keys
|
||||||
|
for _, key := range keys {
|
||||||
|
_ = m.client.Delete(key)
|
||||||
|
// Also delete the tags key for this cache key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
_ = m.client.Delete(tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the tag key itself
|
||||||
|
_ = m.client.Delete(tagKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
// Note: Memcache does not support pattern-based deletion natively.
|
// Note: Memcache does not support pattern-based deletion natively.
|
||||||
// This is a no-op for memcache and returns an error.
|
// This is a no-op for memcache and returns an error.
|
||||||
|
|||||||
118
pkg/cache/provider_memory.go
vendored
118
pkg/cache/provider_memory.go
vendored
@@ -15,6 +15,7 @@ type memoryItem struct {
|
|||||||
Expiration time.Time
|
Expiration time.Time
|
||||||
LastAccess time.Time
|
LastAccess time.Time
|
||||||
HitCount int64
|
HitCount int64
|
||||||
|
Tags []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// isExpired checks if the item has expired.
|
// isExpired checks if the item has expired.
|
||||||
@@ -27,11 +28,12 @@ func (m *memoryItem) isExpired() bool {
|
|||||||
|
|
||||||
// MemoryProvider is an in-memory implementation of the Provider interface.
|
// MemoryProvider is an in-memory implementation of the Provider interface.
|
||||||
type MemoryProvider struct {
|
type MemoryProvider struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
items map[string]*memoryItem
|
items map[string]*memoryItem
|
||||||
options *Options
|
tagToKeys map[string]map[string]struct{} // tag -> set of keys
|
||||||
hits atomic.Int64
|
options *Options
|
||||||
misses atomic.Int64
|
hits atomic.Int64
|
||||||
|
misses atomic.Int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMemoryProvider creates a new in-memory cache provider.
|
// NewMemoryProvider creates a new in-memory cache provider.
|
||||||
@@ -44,8 +46,9 @@ func NewMemoryProvider(opts *Options) *MemoryProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &MemoryProvider{
|
return &MemoryProvider{
|
||||||
items: make(map[string]*memoryItem),
|
items: make(map[string]*memoryItem),
|
||||||
options: opts,
|
tagToKeys: make(map[string]map[string]struct{}),
|
||||||
|
options: opts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,15 +117,116 @@ func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
func (m *MemoryProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = m.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiration time.Time
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration = time.Now().Add(ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check max size and evict if necessary
|
||||||
|
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
|
||||||
|
if _, exists := m.items[key]; !exists {
|
||||||
|
m.evictOne()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old tag associations if key exists
|
||||||
|
if oldItem, exists := m.items[key]; exists {
|
||||||
|
for _, tag := range oldItem.Tags {
|
||||||
|
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||||
|
delete(keySet, key)
|
||||||
|
if len(keySet) == 0 {
|
||||||
|
delete(m.tagToKeys, tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the item
|
||||||
|
m.items[key] = &memoryItem{
|
||||||
|
Value: value,
|
||||||
|
Expiration: expiration,
|
||||||
|
LastAccess: time.Now(),
|
||||||
|
Tags: tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new tag associations
|
||||||
|
for _, tag := range tags {
|
||||||
|
if m.tagToKeys[tag] == nil {
|
||||||
|
m.tagToKeys[tag] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
m.tagToKeys[tag][key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove tag associations
|
||||||
|
if item, exists := m.items[key]; exists {
|
||||||
|
for _, tag := range item.Tags {
|
||||||
|
if keySet, ok := m.tagToKeys[tag]; ok {
|
||||||
|
delete(keySet, key)
|
||||||
|
if len(keySet) == 0 {
|
||||||
|
delete(m.tagToKeys, tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
delete(m.items, key)
|
delete(m.items, key)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (m *MemoryProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Get all keys associated with this tag
|
||||||
|
keySet, exists := m.tagToKeys[tag]
|
||||||
|
if !exists {
|
||||||
|
return nil // No keys with this tag
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete all items with this tag
|
||||||
|
for key := range keySet {
|
||||||
|
if item, ok := m.items[key]; ok {
|
||||||
|
// Remove this tag from the item's tag list
|
||||||
|
newTags := make([]string, 0, len(item.Tags))
|
||||||
|
for _, t := range item.Tags {
|
||||||
|
if t != tag {
|
||||||
|
newTags = append(newTags, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If item has no more tags, delete it
|
||||||
|
// Otherwise update its tags
|
||||||
|
if len(newTags) == 0 {
|
||||||
|
delete(m.items, key)
|
||||||
|
} else {
|
||||||
|
item.Tags = newTags
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the tag mapping
|
||||||
|
delete(m.tagToKeys, tag)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
|
|||||||
86
pkg/cache/provider_redis.go
vendored
86
pkg/cache/provider_redis.go
vendored
@@ -103,9 +103,93 @@ func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl t
|
|||||||
return r.client.Set(ctx, key, value, ttl).Err()
|
return r.client.Set(ctx, key, value, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWithTags stores a value in the cache with the specified TTL and tags.
|
||||||
|
func (r *RedisProvider) SetWithTags(ctx context.Context, key string, value []byte, ttl time.Duration, tags []string) error {
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = r.options.DefaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := r.client.Pipeline()
|
||||||
|
|
||||||
|
// Set the value
|
||||||
|
pipe.Set(ctx, key, value, ttl)
|
||||||
|
|
||||||
|
// Add key to each tag's set
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
pipe.SAdd(ctx, tagKey, key)
|
||||||
|
// Set expiration on tag set (longer than cache items to ensure cleanup)
|
||||||
|
if ttl > 0 {
|
||||||
|
pipe.Expire(ctx, tagKey, ttl+time.Hour)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store tags for this key for later cleanup
|
||||||
|
if len(tags) > 0 {
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
pipe.SAdd(ctx, tagsKey, tags)
|
||||||
|
if ttl > 0 {
|
||||||
|
pipe.Expire(ctx, tagsKey, ttl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache.
|
// Delete removes a key from the cache.
|
||||||
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
||||||
return r.client.Del(ctx, key).Err()
|
pipe := r.client.Pipeline()
|
||||||
|
|
||||||
|
// Get tags for this key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
tags, err := r.client.SMembers(ctx, tagsKey).Result()
|
||||||
|
if err == nil && len(tags) > 0 {
|
||||||
|
// Remove key from each tag set
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
pipe.SRem(ctx, tagKey, key)
|
||||||
|
}
|
||||||
|
// Delete the tags key
|
||||||
|
pipe.Del(ctx, tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the actual key
|
||||||
|
pipe.Del(ctx, key)
|
||||||
|
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByTag removes all keys associated with the given tag.
|
||||||
|
func (r *RedisProvider) DeleteByTag(ctx context.Context, tag string) error {
|
||||||
|
tagKey := fmt.Sprintf("cache:tag:%s", tag)
|
||||||
|
|
||||||
|
// Get all keys associated with this tag
|
||||||
|
keys, err := r.client.SMembers(ctx, tagKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := r.client.Pipeline()
|
||||||
|
|
||||||
|
// Delete all keys and their tag associations
|
||||||
|
for _, key := range keys {
|
||||||
|
pipe.Del(ctx, key)
|
||||||
|
// Also delete the tags key for this cache key
|
||||||
|
tagsKey := fmt.Sprintf("cache:tags:%s", key)
|
||||||
|
pipe.Del(ctx, tagsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the tag set itself
|
||||||
|
pipe.Del(ctx, tagKey)
|
||||||
|
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteByPattern removes all keys matching the pattern.
|
// DeleteByPattern removes all keys matching the pattern.
|
||||||
|
|||||||
151
pkg/cache/query_cache_test.go
vendored
151
pkg/cache/query_cache_test.go
vendored
@@ -1,151 +0,0 @@
|
|||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBuildQueryCacheKey(t *testing.T) {
|
|
||||||
filters := []common.FilterOption{
|
|
||||||
{Column: "name", Operator: "eq", Value: "test"},
|
|
||||||
{Column: "age", Operator: "gt", Value: 25},
|
|
||||||
}
|
|
||||||
sorts := []common.SortOption{
|
|
||||||
{Column: "name", Direction: "asc"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate cache key
|
|
||||||
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
|
||||||
|
|
||||||
// Same parameters should generate same key
|
|
||||||
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
|
||||||
|
|
||||||
if key1 != key2 {
|
|
||||||
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Different parameters should generate different key
|
|
||||||
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
|
|
||||||
|
|
||||||
if key1 == key3 {
|
|
||||||
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildExtendedQueryCacheKey(t *testing.T) {
|
|
||||||
filters := []common.FilterOption{
|
|
||||||
{Column: "name", Operator: "eq", Value: "test"},
|
|
||||||
}
|
|
||||||
sorts := []common.SortOption{
|
|
||||||
{Column: "name", Direction: "asc"},
|
|
||||||
}
|
|
||||||
expandOpts := []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"relation": "posts",
|
|
||||||
"where": "status = 'published'",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate cache key
|
|
||||||
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
|
||||||
|
|
||||||
// Same parameters should generate same key
|
|
||||||
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
|
||||||
|
|
||||||
if key1 != key2 {
|
|
||||||
t.Errorf("Expected same cache keys for identical parameters")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Different distinct value should generate different key
|
|
||||||
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
|
|
||||||
|
|
||||||
if key1 == key3 {
|
|
||||||
t.Errorf("Expected different cache keys for different distinct values")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetQueryTotalCacheKey(t *testing.T) {
|
|
||||||
hash := "abc123"
|
|
||||||
key := GetQueryTotalCacheKey(hash)
|
|
||||||
|
|
||||||
expected := "query_total:abc123"
|
|
||||||
if key != expected {
|
|
||||||
t.Errorf("Expected %s, got %s", expected, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCachedTotalIntegration(t *testing.T) {
|
|
||||||
// Initialize cache with memory provider for testing
|
|
||||||
UseMemory(&Options{
|
|
||||||
DefaultTTL: 1 * time.Minute,
|
|
||||||
MaxSize: 100,
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Create test data
|
|
||||||
filters := []common.FilterOption{
|
|
||||||
{Column: "status", Operator: "eq", Value: "active"},
|
|
||||||
}
|
|
||||||
sorts := []common.SortOption{
|
|
||||||
{Column: "created_at", Direction: "desc"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build cache key
|
|
||||||
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
|
|
||||||
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
|
|
||||||
|
|
||||||
// Store a total count in cache
|
|
||||||
totalToCache := CachedTotal{Total: 42}
|
|
||||||
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to set cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve from cache
|
|
||||||
var cachedTotal CachedTotal
|
|
||||||
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to get from cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cachedTotal.Total != 42 {
|
|
||||||
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test cache miss
|
|
||||||
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
|
|
||||||
var missedTotal CachedTotal
|
|
||||||
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error for cache miss, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHashString(t *testing.T) {
|
|
||||||
input1 := "test string"
|
|
||||||
input2 := "test string"
|
|
||||||
input3 := "different string"
|
|
||||||
|
|
||||||
hash1 := hashString(input1)
|
|
||||||
hash2 := hashString(input2)
|
|
||||||
hash3 := hashString(input3)
|
|
||||||
|
|
||||||
// Same input should produce same hash
|
|
||||||
if hash1 != hash2 {
|
|
||||||
t.Errorf("Expected same hash for identical inputs")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Different input should produce different hash
|
|
||||||
if hash1 == hash3 {
|
|
||||||
t.Errorf("Expected different hash for different inputs")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hash should be hex encoded SHA256 (64 characters)
|
|
||||||
if len(hash1) != 64 {
|
|
||||||
t.Errorf("Expected hash length of 64, got %d", len(hash1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -208,21 +208,9 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if tableName != "" && !hasTablePrefix(condToCheck) {
|
|
||||||
// If tableName is provided and the condition DOESN'T have a table prefix,
|
|
||||||
// qualify unambiguous column references to prevent "ambiguous column" errors
|
|
||||||
// when there are multiple joins on the same table (e.g., recursive preloads)
|
|
||||||
columnName := extractUnqualifiedColumnName(condToCheck)
|
|
||||||
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
|
|
||||||
// Qualify the column with the table name
|
|
||||||
// Be careful to only replace the column name, not other occurrences of the string
|
|
||||||
oldRef := columnName
|
|
||||||
newRef := tableName + "." + columnName
|
|
||||||
// Use word boundary matching to avoid replacing partial matches
|
|
||||||
cond = qualifyColumnInCondition(cond, oldRef, newRef)
|
|
||||||
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// Note: We no longer add prefixes to unqualified columns here.
|
||||||
|
// Use AddTablePrefixToColumns() separately if you need to add prefixes.
|
||||||
|
|
||||||
validConditions = append(validConditions, cond)
|
validConditions = append(validConditions, cond)
|
||||||
}
|
}
|
||||||
@@ -633,3 +621,145 @@ func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
|||||||
}
|
}
|
||||||
return validColumns[strings.ToLower(columnName)]
|
return validColumns[strings.ToLower(columnName)]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddTablePrefixToColumns adds table prefix to unqualified column references in a WHERE clause.
|
||||||
|
// This function only prefixes simple column references and skips:
|
||||||
|
// - Columns already having a table prefix (containing a dot)
|
||||||
|
// - Columns inside function calls or expressions (inside parentheses)
|
||||||
|
// - Columns inside subqueries
|
||||||
|
// - Columns that don't exist in the table (validation via model registry)
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - "status = 'active'" -> "users.status = 'active'" (if status exists in users table)
|
||||||
|
// - "COALESCE(status, 'default') = 'active'" -> unchanged (status inside function)
|
||||||
|
// - "users.status = 'active'" -> unchanged (already has prefix)
|
||||||
|
// - "(status = 'active')" -> "(users.status = 'active')" (grouping parens are OK)
|
||||||
|
// - "invalid_col = 'value'" -> unchanged (if invalid_col doesn't exist in table)
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - where: The WHERE clause to process
|
||||||
|
// - tableName: The table name to use as prefix
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - The WHERE clause with table prefixes added to appropriate and valid columns
|
||||||
|
func AddTablePrefixToColumns(where string, tableName string) string {
|
||||||
|
if where == "" || tableName == "" {
|
||||||
|
return where
|
||||||
|
}
|
||||||
|
|
||||||
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
|
// Get valid columns from the model registry for validation
|
||||||
|
validColumns := getValidColumnsForTable(tableName)
|
||||||
|
|
||||||
|
// Split by AND to handle multiple conditions (parenthesis-aware)
|
||||||
|
conditions := splitByAND(where)
|
||||||
|
prefixedConditions := make([]string, 0, len(conditions))
|
||||||
|
|
||||||
|
for _, cond := range conditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process this condition to add table prefix if appropriate
|
||||||
|
processedCond := addPrefixToSingleCondition(cond, tableName, validColumns)
|
||||||
|
prefixedConditions = append(prefixedConditions, processedCond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(prefixedConditions) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(prefixedConditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPrefixToSingleCondition adds table prefix to a single condition if appropriate
|
||||||
|
// Returns the condition unchanged if:
|
||||||
|
// - The condition is a SQL literal/expression (true, false, null, 1=1, etc.)
|
||||||
|
// - The column reference is inside a function call
|
||||||
|
// - The column already has a table prefix
|
||||||
|
// - No valid column reference is found
|
||||||
|
// - The column doesn't exist in the table (when validColumns is provided)
|
||||||
|
func addPrefixToSingleCondition(cond string, tableName string, validColumns map[string]bool) string {
|
||||||
|
// Strip outer grouping parentheses to get to the actual condition
|
||||||
|
strippedCond := stripOuterParentheses(cond)
|
||||||
|
|
||||||
|
// Skip SQL literals and trivial conditions (true, false, null, 1=1, etc.)
|
||||||
|
if IsSQLExpression(strippedCond) || IsTrivialCondition(strippedCond) {
|
||||||
|
logger.Debug("Skipping SQL literal/trivial condition: '%s'", strippedCond)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the left side of the comparison (before the operator)
|
||||||
|
columnRef := extractLeftSideOfComparison(strippedCond)
|
||||||
|
if columnRef == "" {
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if it already has a prefix (contains a dot)
|
||||||
|
if strings.Contains(columnRef, ".") {
|
||||||
|
logger.Debug("Skipping column '%s' - already has table prefix", columnRef)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if it's a function call or expression (contains parentheses)
|
||||||
|
if strings.Contains(columnRef, "(") {
|
||||||
|
logger.Debug("Skipping column reference '%s' - inside function or expression", columnRef)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that the column exists in the table (if we have column info)
|
||||||
|
if !isValidColumn(columnRef, validColumns) {
|
||||||
|
logger.Debug("Skipping column '%s' - not found in table '%s'", columnRef, tableName)
|
||||||
|
return cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's a simple unqualified column reference that exists in the table - add the table prefix
|
||||||
|
newRef := tableName + "." + columnRef
|
||||||
|
result := qualifyColumnInCondition(cond, columnRef, newRef)
|
||||||
|
logger.Debug("Added table prefix to column: '%s' -> '%s'", columnRef, newRef)
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractLeftSideOfComparison extracts the left side of a comparison operator from a condition.
|
||||||
|
// This is used to identify the column reference that may need a table prefix.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// - "status = 'active'" returns "status"
|
||||||
|
// - "COALESCE(status, 'default') = 'active'" returns "COALESCE(status, 'default')"
|
||||||
|
// - "priority > 5" returns "priority"
|
||||||
|
//
|
||||||
|
// Returns empty string if no operator is found.
|
||||||
|
func extractLeftSideOfComparison(cond string) string {
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
|
||||||
|
|
||||||
|
// Find the first operator outside of parentheses and quotes
|
||||||
|
minIdx := -1
|
||||||
|
for _, op := range operators {
|
||||||
|
idx := findOperatorOutsideParentheses(cond, op)
|
||||||
|
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||||
|
minIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if minIdx > 0 {
|
||||||
|
leftSide := strings.TrimSpace(cond[:minIdx])
|
||||||
|
// Remove any surrounding quotes
|
||||||
|
leftSide = strings.Trim(leftSide, "`\"'")
|
||||||
|
return leftSide
|
||||||
|
}
|
||||||
|
|
||||||
|
// No operator found - might be a boolean column
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnRef := strings.Trim(parts[0], "`\"'")
|
||||||
|
// Make sure it's not a SQL keyword
|
||||||
|
if !IsSQLKeyword(strings.ToLower(columnRef)) {
|
||||||
|
return columnRef
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -273,25 +273,151 @@ handler.SetOpenAPIGenerator(func() (string, error) {
|
|||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
## Using with Swagger UI
|
## Using the Built-in UI Handler
|
||||||
|
|
||||||
You can serve the generated OpenAPI spec with Swagger UI:
|
The package includes a built-in UI handler that serves popular OpenAPI visualization tools. No need to download or manage static files - everything is served from CDN.
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/openapi"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Setup your API routes and OpenAPI generator...
|
||||||
|
// (see examples above)
|
||||||
|
|
||||||
|
// Add the UI handler - defaults to Swagger UI
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API Documentation",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Now visit http://localhost:8080/docs
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported UI Frameworks
|
||||||
|
|
||||||
|
The handler supports four popular OpenAPI UI frameworks:
|
||||||
|
|
||||||
|
#### 1. Swagger UI (Default)
|
||||||
|
The most widely used OpenAPI UI with excellent compatibility and features.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
Theme: "dark", // optional: "light" or "dark"
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. RapiDoc
|
||||||
|
Modern, customizable, and feature-rich OpenAPI UI.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.RapiDoc,
|
||||||
|
Theme: "dark",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Redoc
|
||||||
|
Clean, responsive documentation with great UX.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.Redoc,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. Scalar
|
||||||
|
Modern and sleek OpenAPI documentation.
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.Scalar,
|
||||||
|
Theme: "dark",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Options
|
||||||
|
|
||||||
|
```go
|
||||||
|
type UIConfig struct {
|
||||||
|
UIType UIType // SwaggerUI, RapiDoc, Redoc, or Scalar
|
||||||
|
SpecURL string // URL to OpenAPI spec (default: "/openapi")
|
||||||
|
Title string // Page title (default: "API Documentation")
|
||||||
|
FaviconURL string // Custom favicon URL (optional)
|
||||||
|
CustomCSS string // Custom CSS to inject (optional)
|
||||||
|
Theme string // "light" or "dark" (support varies by UI)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Styling Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
Title: "Acme Corp API",
|
||||||
|
CustomCSS: `
|
||||||
|
.swagger-ui .topbar {
|
||||||
|
background-color: #1976d2;
|
||||||
|
}
|
||||||
|
.swagger-ui .info .title {
|
||||||
|
color: #1976d2;
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Multiple UIs
|
||||||
|
|
||||||
|
You can serve different UIs at different paths:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Swagger UI at /docs
|
||||||
|
openapi.SetupUIRoute(router, "/docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Redoc at /redoc
|
||||||
|
openapi.SetupUIRoute(router, "/redoc", openapi.UIConfig{
|
||||||
|
UIType: openapi.Redoc,
|
||||||
|
})
|
||||||
|
|
||||||
|
// RapiDoc at /api-docs
|
||||||
|
openapi.SetupUIRoute(router, "/api-docs", openapi.UIConfig{
|
||||||
|
UIType: openapi.RapiDoc,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manual Handler Usage
|
||||||
|
|
||||||
|
If you need more control, use the handler directly:
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler := openapi.UIHandler(openapi.UIConfig{
|
||||||
|
UIType: openapi.SwaggerUI,
|
||||||
|
SpecURL: "/api/openapi.json",
|
||||||
|
})
|
||||||
|
|
||||||
|
router.Handle("/documentation", handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using with External Swagger UI
|
||||||
|
|
||||||
|
Alternatively, you can use an external Swagger UI instance:
|
||||||
|
|
||||||
1. Get the spec from `/openapi`
|
1. Get the spec from `/openapi`
|
||||||
2. Load it in Swagger UI at `https://petstore.swagger.io/`
|
2. Load it in Swagger UI at `https://petstore.swagger.io/`
|
||||||
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
|
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
|
||||||
|
|
||||||
Example with self-hosted Swagger UI:
|
|
||||||
|
|
||||||
```go
|
|
||||||
// Serve Swagger UI static files
|
|
||||||
router.PathPrefix("/swagger/").Handler(
|
|
||||||
http.StripPrefix("/swagger/", http.FileServer(http.Dir("./swagger-ui"))),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Configure Swagger UI to use /openapi
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
You can test the OpenAPI endpoint:
|
You can test the OpenAPI endpoint:
|
||||||
|
|||||||
@@ -183,6 +183,69 @@ func ExampleWithFuncSpec() {
|
|||||||
_ = generatorFunc
|
_ = generatorFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExampleWithUIHandler shows how to serve OpenAPI documentation with a web UI
|
||||||
|
func ExampleWithUIHandler(db *gorm.DB) {
|
||||||
|
// Create handler and configure OpenAPI generator
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
registry := modelregistry.NewModelRegistry()
|
||||||
|
|
||||||
|
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||||
|
generator := NewGenerator(GeneratorConfig{
|
||||||
|
Title: "My API",
|
||||||
|
Description: "API documentation with interactive UI",
|
||||||
|
Version: "1.0.0",
|
||||||
|
BaseURL: "http://localhost:8080",
|
||||||
|
Registry: registry,
|
||||||
|
IncludeRestheadSpec: true,
|
||||||
|
})
|
||||||
|
return generator.GenerateJSON()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup routes
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||||
|
|
||||||
|
// Add UI handlers for different frameworks
|
||||||
|
// Swagger UI at /docs (most popular)
|
||||||
|
SetupUIRoute(router, "/docs", UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - Swagger UI",
|
||||||
|
Theme: "light",
|
||||||
|
})
|
||||||
|
|
||||||
|
// RapiDoc at /rapidoc (modern alternative)
|
||||||
|
SetupUIRoute(router, "/rapidoc", UIConfig{
|
||||||
|
UIType: RapiDoc,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - RapiDoc",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Redoc at /redoc (clean and responsive)
|
||||||
|
SetupUIRoute(router, "/redoc", UIConfig{
|
||||||
|
UIType: Redoc,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - Redoc",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Scalar at /scalar (modern and sleek)
|
||||||
|
SetupUIRoute(router, "/scalar", UIConfig{
|
||||||
|
UIType: Scalar,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "My API - Scalar",
|
||||||
|
Theme: "dark",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Now you can access:
|
||||||
|
// http://localhost:8080/docs - Swagger UI
|
||||||
|
// http://localhost:8080/rapidoc - RapiDoc
|
||||||
|
// http://localhost:8080/redoc - Redoc
|
||||||
|
// http://localhost:8080/scalar - Scalar
|
||||||
|
// http://localhost:8080/openapi - Raw OpenAPI JSON
|
||||||
|
|
||||||
|
_ = router
|
||||||
|
}
|
||||||
|
|
||||||
// ExampleCustomization shows advanced customization options
|
// ExampleCustomization shows advanced customization options
|
||||||
func ExampleCustomization() {
|
func ExampleCustomization() {
|
||||||
// Create registry and register models with descriptions using struct tags
|
// Create registry and register models with descriptions using struct tags
|
||||||
|
|||||||
294
pkg/openapi/ui_handler.go
Normal file
294
pkg/openapi/ui_handler.go
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
package openapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UIType represents the type of OpenAPI UI to serve
|
||||||
|
type UIType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SwaggerUI is the most popular OpenAPI UI
|
||||||
|
SwaggerUI UIType = "swagger-ui"
|
||||||
|
// RapiDoc is a modern, customizable OpenAPI UI
|
||||||
|
RapiDoc UIType = "rapidoc"
|
||||||
|
// Redoc is a clean, responsive OpenAPI UI
|
||||||
|
Redoc UIType = "redoc"
|
||||||
|
// Scalar is a modern and sleek OpenAPI UI
|
||||||
|
Scalar UIType = "scalar"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UIConfig holds configuration for the OpenAPI UI handler
|
||||||
|
type UIConfig struct {
|
||||||
|
// UIType specifies which UI framework to use (default: SwaggerUI)
|
||||||
|
UIType UIType
|
||||||
|
// SpecURL is the URL to the OpenAPI spec JSON (default: "/openapi")
|
||||||
|
SpecURL string
|
||||||
|
// Title is the page title (default: "API Documentation")
|
||||||
|
Title string
|
||||||
|
// FaviconURL is the URL to the favicon (optional)
|
||||||
|
FaviconURL string
|
||||||
|
// CustomCSS allows injecting custom CSS (optional)
|
||||||
|
CustomCSS string
|
||||||
|
// Theme for the UI (light/dark, depends on UI type)
|
||||||
|
Theme string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UIHandler creates an HTTP handler that serves an OpenAPI UI
|
||||||
|
func UIHandler(config UIConfig) http.HandlerFunc {
|
||||||
|
// Set defaults
|
||||||
|
if config.UIType == "" {
|
||||||
|
config.UIType = SwaggerUI
|
||||||
|
}
|
||||||
|
if config.SpecURL == "" {
|
||||||
|
config.SpecURL = "/openapi"
|
||||||
|
}
|
||||||
|
if config.Title == "" {
|
||||||
|
config.Title = "API Documentation"
|
||||||
|
}
|
||||||
|
if config.Theme == "" {
|
||||||
|
config.Theme = "light"
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var htmlContent string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch config.UIType {
|
||||||
|
case SwaggerUI:
|
||||||
|
htmlContent, err = generateSwaggerUI(config)
|
||||||
|
case RapiDoc:
|
||||||
|
htmlContent, err = generateRapiDoc(config)
|
||||||
|
case Redoc:
|
||||||
|
htmlContent, err = generateRedoc(config)
|
||||||
|
case Scalar:
|
||||||
|
htmlContent, err = generateScalar(config)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Unsupported UI type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to generate UI: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, err = w.Write([]byte(htmlContent))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to write response: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// templateData wraps UIConfig to properly handle CSS in templates
|
||||||
|
type templateData struct {
|
||||||
|
UIConfig
|
||||||
|
SafeCustomCSS template.CSS
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSwaggerUI generates the HTML for Swagger UI
|
||||||
|
func generateSwaggerUI(config UIConfig) (string, error) {
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css">
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
<style>
|
||||||
|
html { box-sizing: border-box; overflow: -moz-scrollbars-vertical; overflow-y: scroll; }
|
||||||
|
*, *:before, *:after { box-sizing: inherit; }
|
||||||
|
body { margin: 0; padding: 0; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="swagger-ui"></div>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-standalone-preset.js"></script>
|
||||||
|
<script>
|
||||||
|
window.onload = function() {
|
||||||
|
const ui = SwaggerUIBundle({
|
||||||
|
url: "{{.SpecURL}}",
|
||||||
|
dom_id: '#swagger-ui',
|
||||||
|
deepLinking: true,
|
||||||
|
presets: [
|
||||||
|
SwaggerUIBundle.presets.apis,
|
||||||
|
SwaggerUIStandalonePreset
|
||||||
|
],
|
||||||
|
plugins: [
|
||||||
|
SwaggerUIBundle.plugins.DownloadUrl
|
||||||
|
],
|
||||||
|
layout: "StandaloneLayout",
|
||||||
|
{{if eq .Theme "dark"}}
|
||||||
|
syntaxHighlight: {
|
||||||
|
activate: true,
|
||||||
|
theme: "monokai"
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
});
|
||||||
|
window.ui = ui;
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("swagger").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRapiDoc generates the HTML for RapiDoc
|
||||||
|
func generateRapiDoc(config UIConfig) (string, error) {
|
||||||
|
theme := "light"
|
||||||
|
if config.Theme == "dark" {
|
||||||
|
theme = "dark"
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
<script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<rapi-doc
|
||||||
|
spec-url="{{.SpecURL}}"
|
||||||
|
theme="` + theme + `"
|
||||||
|
render-style="read"
|
||||||
|
show-header="true"
|
||||||
|
show-info="true"
|
||||||
|
allow-try="true"
|
||||||
|
allow-server-selection="true"
|
||||||
|
allow-authentication="true"
|
||||||
|
api-key-name="Authorization"
|
||||||
|
api-key-location="header"
|
||||||
|
></rapi-doc>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("rapidoc").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRedoc generates the HTML for Redoc
|
||||||
|
func generateRedoc(config UIConfig) (string, error) {
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
<style>
|
||||||
|
body { margin: 0; padding: 0; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<redoc spec-url="{{.SpecURL}}" {{if eq .Theme "dark"}}theme='{"colors": {"primary": {"main": "#dd5522"}}}'{{end}}></redoc>
|
||||||
|
<script src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("redoc").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateScalar generates the HTML for Scalar
|
||||||
|
func generateScalar(config UIConfig) (string, error) {
|
||||||
|
tmpl := `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>{{.Title}}</title>
|
||||||
|
{{if .FaviconURL}}<link rel="icon" type="image/png" href="{{.FaviconURL}}">{{end}}
|
||||||
|
{{if .SafeCustomCSS}}<style>{{.SafeCustomCSS}}</style>{{end}}
|
||||||
|
<style>
|
||||||
|
body { margin: 0; padding: 0; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<script id="api-reference" data-url="{{.SpecURL}}" {{if eq .Theme "dark"}}data-theme="dark"{{end}}></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
t, err := template.New("scalar").Parse(tmpl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
data := templateData{
|
||||||
|
UIConfig: config,
|
||||||
|
SafeCustomCSS: template.CSS(config.CustomCSS),
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
if err := t.Execute(&buf, data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupUIRoute adds the OpenAPI UI route to a mux router
|
||||||
|
// This is a convenience function for the most common use case
|
||||||
|
func SetupUIRoute(router *mux.Router, path string, config UIConfig) {
|
||||||
|
router.Handle(path, UIHandler(config))
|
||||||
|
}
|
||||||
308
pkg/openapi/ui_handler_test.go
Normal file
308
pkg/openapi/ui_handler_test.go
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
package openapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUIHandler_SwaggerUI(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
SpecURL: "/openapi",
|
||||||
|
Title: "Test API Docs",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for Swagger UI specific content
|
||||||
|
if !strings.Contains(body, "swagger-ui") {
|
||||||
|
t.Error("Expected Swagger UI content")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "SwaggerUIBundle") {
|
||||||
|
t.Error("Expected SwaggerUIBundle script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "swagger-ui-dist") {
|
||||||
|
t.Error("Expected Swagger UI CDN link")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_RapiDoc(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: RapiDoc,
|
||||||
|
SpecURL: "/api/spec",
|
||||||
|
Title: "RapiDoc Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for RapiDoc specific content
|
||||||
|
if !strings.Contains(body, "rapi-doc") {
|
||||||
|
t.Error("Expected rapi-doc element")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "rapidoc-min.js") {
|
||||||
|
t.Error("Expected RapiDoc script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_Redoc(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: Redoc,
|
||||||
|
SpecURL: "/spec.json",
|
||||||
|
Title: "Redoc Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for Redoc specific content
|
||||||
|
if !strings.Contains(body, "<redoc") {
|
||||||
|
t.Error("Expected redoc element")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "redoc.standalone.js") {
|
||||||
|
t.Error("Expected Redoc script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_Scalar(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: Scalar,
|
||||||
|
SpecURL: "/openapi.json",
|
||||||
|
Title: "Scalar Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Check for Scalar specific content
|
||||||
|
if !strings.Contains(body, "api-reference") {
|
||||||
|
t.Error("Expected api-reference element")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "@scalar/api-reference") {
|
||||||
|
t.Error("Expected Scalar script")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.Title) {
|
||||||
|
t.Errorf("Expected title '%s' in HTML", config.Title)
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, config.SpecURL) {
|
||||||
|
t.Errorf("Expected spec URL '%s' in HTML", config.SpecURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_DefaultValues(t *testing.T) {
|
||||||
|
// Test with empty config to check defaults
|
||||||
|
config := UIConfig{}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// Should default to Swagger UI
|
||||||
|
if !strings.Contains(body, "swagger-ui") {
|
||||||
|
t.Error("Expected default to Swagger UI")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should default to /openapi spec URL
|
||||||
|
if !strings.Contains(body, "/openapi") {
|
||||||
|
t.Error("Expected default spec URL '/openapi'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should default to "API Documentation" title
|
||||||
|
if !strings.Contains(body, "API Documentation") {
|
||||||
|
t.Error("Expected default title 'API Documentation'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_CustomCSS(t *testing.T) {
|
||||||
|
customCSS := ".custom-class { color: red; }"
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
CustomCSS: customCSS,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
if !strings.Contains(body, customCSS) {
|
||||||
|
t.Errorf("Expected custom CSS to be included. Body:\n%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_Favicon(t *testing.T) {
|
||||||
|
faviconURL := "https://example.com/favicon.ico"
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
FaviconURL: faviconURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
if !strings.Contains(body, faviconURL) {
|
||||||
|
t.Error("Expected favicon URL to be included")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_DarkTheme(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
Theme: "dark",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// SwaggerUI uses monokai theme for dark mode
|
||||||
|
if !strings.Contains(body, "monokai") {
|
||||||
|
t.Error("Expected dark theme configuration for Swagger UI")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_InvalidUIType(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: "invalid-ui-type",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("Expected status 400 for invalid UI type, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIHandler_ContentType(t *testing.T) {
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := UIHandler(config)
|
||||||
|
req := httptest.NewRequest("GET", "/docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
contentType := w.Header().Get("Content-Type")
|
||||||
|
if !strings.Contains(contentType, "text/html") {
|
||||||
|
t.Errorf("Expected Content-Type to contain 'text/html', got '%s'", contentType)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentType, "charset=utf-8") {
|
||||||
|
t.Errorf("Expected Content-Type to contain 'charset=utf-8', got '%s'", contentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupUIRoute(t *testing.T) {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
config := UIConfig{
|
||||||
|
UIType: SwaggerUI,
|
||||||
|
}
|
||||||
|
|
||||||
|
SetupUIRoute(router, "/api-docs", config)
|
||||||
|
|
||||||
|
// Test that the route was added and works
|
||||||
|
req := httptest.NewRequest("GET", "/api-docs", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it returns HTML
|
||||||
|
body := w.Body.String()
|
||||||
|
if !strings.Contains(body, "swagger-ui") {
|
||||||
|
t.Error("Expected Swagger UI content")
|
||||||
|
}
|
||||||
|
}
|
||||||
118
pkg/resolvespec/cache_helpers.go
Normal file
118
pkg/resolvespec/cache_helpers.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package resolvespec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// queryCacheKey represents the components used to build a cache key for query total count
|
||||||
|
type queryCacheKey struct {
|
||||||
|
TableName string `json:"table_name"`
|
||||||
|
Filters []common.FilterOption `json:"filters"`
|
||||||
|
Sort []common.SortOption `json:"sort"`
|
||||||
|
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||||
|
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||||
|
CursorForward string `json:"cursor_forward,omitempty"`
|
||||||
|
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// cachedTotal represents a cached total count
|
||||||
|
type cachedTotal struct {
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildQueryCacheKey builds a cache key from query parameters for total count caching
|
||||||
|
func buildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
||||||
|
key := queryCacheKey{
|
||||||
|
TableName: tableName,
|
||||||
|
Filters: filters,
|
||||||
|
Sort: sort,
|
||||||
|
CustomSQLWhere: customWhere,
|
||||||
|
CustomSQLOr: customOr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize to JSON for consistent hashing
|
||||||
|
jsonData, err := json.Marshal(key)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to simple string concatenation if JSON fails
|
||||||
|
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashString(string(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildExtendedQueryCacheKey builds a cache key for extended query options with cursor pagination
|
||||||
|
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||||
|
customWhere, customOr string, cursorFwd, cursorBwd string) string {
|
||||||
|
|
||||||
|
key := queryCacheKey{
|
||||||
|
TableName: tableName,
|
||||||
|
Filters: filters,
|
||||||
|
Sort: sort,
|
||||||
|
CustomSQLWhere: customWhere,
|
||||||
|
CustomSQLOr: customOr,
|
||||||
|
CursorForward: cursorFwd,
|
||||||
|
CursorBackward: cursorBwd,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize to JSON for consistent hashing
|
||||||
|
jsonData, err := json.Marshal(key)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to simple string concatenation if JSON fails
|
||||||
|
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%s_%s",
|
||||||
|
tableName, filters, sort, customWhere, customOr, cursorFwd, cursorBwd))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashString(string(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashString computes SHA256 hash of a string
|
||||||
|
func hashString(s string) string {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(s))
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||||
|
func getQueryTotalCacheKey(hash string) string {
|
||||||
|
return fmt.Sprintf("query_total:%s", hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildCacheTags creates cache tags from schema and table name
|
||||||
|
func buildCacheTags(schema, tableName string) []string {
|
||||||
|
return []string{
|
||||||
|
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
|
||||||
|
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setQueryTotalCache stores a query total in the cache with schema and table tags
|
||||||
|
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
|
||||||
|
c := cache.GetDefaultCache()
|
||||||
|
cacheData := cachedTotal{Total: total}
|
||||||
|
tags := buildCacheTags(schema, tableName)
|
||||||
|
|
||||||
|
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidateCacheForTags removes all cached items matching the specified tags
|
||||||
|
func invalidateCacheForTags(ctx context.Context, tags []string) error {
|
||||||
|
c := cache.GetDefaultCache()
|
||||||
|
|
||||||
|
// Invalidate for each tag
|
||||||
|
for _, tag := range tags {
|
||||||
|
if err := c.DeleteByTag(ctx, tag); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -331,19 +331,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Use extended cache key if cursors are present
|
// Use extended cache key if cursors are present
|
||||||
var cacheKeyHash string
|
var cacheKeyHash string
|
||||||
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
||||||
cacheKeyHash = cache.BuildExtendedQueryCacheKey(
|
cacheKeyHash = buildExtendedQueryCacheKey(
|
||||||
tableName,
|
tableName,
|
||||||
options.Filters,
|
options.Filters,
|
||||||
options.Sort,
|
options.Sort,
|
||||||
"", // No custom SQL WHERE in resolvespec
|
"", // No custom SQL WHERE in resolvespec
|
||||||
"", // No custom SQL OR in resolvespec
|
"", // No custom SQL OR in resolvespec
|
||||||
nil, // No expand options in resolvespec
|
|
||||||
false, // distinct not used here
|
|
||||||
options.CursorForward,
|
options.CursorForward,
|
||||||
options.CursorBackward,
|
options.CursorBackward,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
cacheKeyHash = cache.BuildQueryCacheKey(
|
cacheKeyHash = buildQueryCacheKey(
|
||||||
tableName,
|
tableName,
|
||||||
options.Filters,
|
options.Filters,
|
||||||
options.Sort,
|
options.Sort,
|
||||||
@@ -351,10 +349,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
"", // No custom SQL OR in resolvespec
|
"", // No custom SQL OR in resolvespec
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
|
cacheKey := getQueryTotalCacheKey(cacheKeyHash)
|
||||||
|
|
||||||
// Try to retrieve from cache
|
// Try to retrieve from cache
|
||||||
var cachedTotal cache.CachedTotal
|
var cachedTotal cachedTotal
|
||||||
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total = cachedTotal.Total
|
total = cachedTotal.Total
|
||||||
@@ -371,10 +369,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
total = count
|
total = count
|
||||||
logger.Debug("Total records (from query): %d", total)
|
logger.Debug("Total records (from query): %d", total)
|
||||||
|
|
||||||
// Store in cache
|
// Store in cache with schema and table tags
|
||||||
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||||
cacheData := cache.CachedTotal{Total: total}
|
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
|
||||||
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
|
||||||
logger.Warn("Failed to cache query total: %v", err)
|
logger.Warn("Failed to cache query total: %v", err)
|
||||||
// Don't fail the request if caching fails
|
// Don't fail the request if caching fails
|
||||||
} else {
|
} else {
|
||||||
@@ -464,6 +461,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
|
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, result.Data, nil)
|
h.sendResponse(w, result.Data, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -480,6 +482,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, v, nil)
|
h.sendResponse(w, v, nil)
|
||||||
|
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
@@ -518,6 +525,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records with nested data", len(results))
|
logger.Info("Successfully created %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -541,6 +553,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records", len(v))
|
logger.Info("Successfully created %d records", len(v))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, v, nil)
|
h.sendResponse(w, v, nil)
|
||||||
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
@@ -584,6 +601,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records with nested data", len(results))
|
logger.Info("Successfully created %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -611,6 +633,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created %d records", len(v))
|
logger.Info("Successfully created %d records", len(v))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, list, nil)
|
h.sendResponse(w, list, nil)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -661,6 +688,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, result.Data, nil)
|
h.sendResponse(w, result.Data, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -697,6 +729,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated %d records", result.RowsAffected())
|
logger.Info("Successfully updated %d records", result.RowsAffected())
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, data, nil)
|
h.sendResponse(w, data, nil)
|
||||||
|
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
@@ -735,6 +772,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records with nested data", len(results))
|
logger.Info("Successfully updated %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -758,6 +800,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records", len(updates))
|
logger.Info("Successfully updated %d records", len(updates))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, updates, nil)
|
h.sendResponse(w, updates, nil)
|
||||||
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
@@ -800,6 +847,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records with nested data", len(results))
|
logger.Info("Successfully updated %d records with nested data", len(results))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, results, nil)
|
h.sendResponse(w, results, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -827,6 +879,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records", len(list))
|
logger.Info("Successfully updated %d records", len(list))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, list, nil)
|
h.sendResponse(w, list, nil)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -873,6 +930,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", len(v))
|
logger.Info("Successfully deleted %d records", len(v))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": len(v)}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": len(v)}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -914,6 +976,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -940,6 +1007,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -998,6 +1070,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
logger.Info("Successfully deleted record with ID: %s", id)
|
logger.Info("Successfully deleted record with ID: %s", id)
|
||||||
// Return the deleted record data
|
// Return the deleted record data
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, recordToDelete, nil)
|
h.sendResponse(w, recordToDelete, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package cache
|
package restheadspec
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,56 +7,42 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QueryCacheKey represents the components used to build a cache key for query total count
|
// expandOptionKey represents expand options for cache key
|
||||||
type QueryCacheKey struct {
|
type expandOptionKey struct {
|
||||||
|
Relation string `json:"relation"`
|
||||||
|
Where string `json:"where,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryCacheKey represents the components used to build a cache key for query total count
|
||||||
|
type queryCacheKey struct {
|
||||||
TableName string `json:"table_name"`
|
TableName string `json:"table_name"`
|
||||||
Filters []common.FilterOption `json:"filters"`
|
Filters []common.FilterOption `json:"filters"`
|
||||||
Sort []common.SortOption `json:"sort"`
|
Sort []common.SortOption `json:"sort"`
|
||||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||||
Expand []ExpandOptionKey `json:"expand,omitempty"`
|
Expand []expandOptionKey `json:"expand,omitempty"`
|
||||||
Distinct bool `json:"distinct,omitempty"`
|
Distinct bool `json:"distinct,omitempty"`
|
||||||
CursorForward string `json:"cursor_forward,omitempty"`
|
CursorForward string `json:"cursor_forward,omitempty"`
|
||||||
CursorBackward string `json:"cursor_backward,omitempty"`
|
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpandOptionKey represents expand options for cache key
|
// cachedTotal represents a cached total count
|
||||||
type ExpandOptionKey struct {
|
type cachedTotal struct {
|
||||||
Relation string `json:"relation"`
|
Total int `json:"total"`
|
||||||
Where string `json:"where,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
|
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||||
// This is used to cache the total count of records matching a query
|
|
||||||
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
|
||||||
key := QueryCacheKey{
|
|
||||||
TableName: tableName,
|
|
||||||
Filters: filters,
|
|
||||||
Sort: sort,
|
|
||||||
CustomSQLWhere: customWhere,
|
|
||||||
CustomSQLOr: customOr,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize to JSON for consistent hashing
|
|
||||||
jsonData, err := json.Marshal(key)
|
|
||||||
if err != nil {
|
|
||||||
// Fallback to simple string concatenation if JSON fails
|
|
||||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
|
||||||
}
|
|
||||||
|
|
||||||
return hashString(string(jsonData))
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
|
||||||
// Includes expand, distinct, and cursor pagination options
|
// Includes expand, distinct, and cursor pagination options
|
||||||
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||||
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||||
|
|
||||||
key := QueryCacheKey{
|
key := queryCacheKey{
|
||||||
TableName: tableName,
|
TableName: tableName,
|
||||||
Filters: filters,
|
Filters: filters,
|
||||||
Sort: sort,
|
Sort: sort,
|
||||||
@@ -69,11 +55,11 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
|||||||
|
|
||||||
// Convert expand options to cache key format
|
// Convert expand options to cache key format
|
||||||
if len(expandOpts) > 0 {
|
if len(expandOpts) > 0 {
|
||||||
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
|
key.Expand = make([]expandOptionKey, 0, len(expandOpts))
|
||||||
for _, exp := range expandOpts {
|
for _, exp := range expandOpts {
|
||||||
// Type assert to get the expand option fields we care about for caching
|
// Type assert to get the expand option fields we care about for caching
|
||||||
if expMap, ok := exp.(map[string]interface{}); ok {
|
if expMap, ok := exp.(map[string]interface{}); ok {
|
||||||
expKey := ExpandOptionKey{}
|
expKey := expandOptionKey{}
|
||||||
if rel, ok := expMap["relation"].(string); ok {
|
if rel, ok := expMap["relation"].(string); ok {
|
||||||
expKey.Relation = rel
|
expKey.Relation = rel
|
||||||
}
|
}
|
||||||
@@ -83,7 +69,6 @@ func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
|||||||
key.Expand = append(key.Expand, expKey)
|
key.Expand = append(key.Expand, expKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Sort expand options for consistent hashing (already sorted by relation name above)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize to JSON for consistent hashing
|
// Serialize to JSON for consistent hashing
|
||||||
@@ -104,24 +89,38 @@ func hashString(s string) string {
|
|||||||
return hex.EncodeToString(h.Sum(nil))
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
// getQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||||
func GetQueryTotalCacheKey(hash string) string {
|
func getQueryTotalCacheKey(hash string) string {
|
||||||
return fmt.Sprintf("query_total:%s", hash)
|
return fmt.Sprintf("query_total:%s", hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CachedTotal represents a cached total count
|
// buildCacheTags creates cache tags from schema and table name
|
||||||
type CachedTotal struct {
|
func buildCacheTags(schema, tableName string) []string {
|
||||||
Total int `json:"total"`
|
return []string{
|
||||||
|
fmt.Sprintf("schema:%s", strings.ToLower(schema)),
|
||||||
|
fmt.Sprintf("table:%s", strings.ToLower(tableName)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateCacheForTable removes all cached totals for a specific table
|
// setQueryTotalCache stores a query total in the cache with schema and table tags
|
||||||
// This should be called when data in the table changes (insert/update/delete)
|
func setQueryTotalCache(ctx context.Context, cacheKey string, total int, schema, tableName string, ttl time.Duration) error {
|
||||||
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
|
c := cache.GetDefaultCache()
|
||||||
cache := GetDefaultCache()
|
cacheData := cachedTotal{Total: total}
|
||||||
|
tags := buildCacheTags(schema, tableName)
|
||||||
|
|
||||||
// Build a pattern to match all query totals for this table
|
return c.SetWithTags(ctx, cacheKey, cacheData, ttl, tags)
|
||||||
// Note: This requires pattern matching support in the provider
|
}
|
||||||
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
|
|
||||||
|
// invalidateCacheForTags removes all cached items matching the specified tags
|
||||||
return cache.DeleteByPattern(ctx, pattern)
|
func invalidateCacheForTags(ctx context.Context, tags []string) error {
|
||||||
|
c := cache.GetDefaultCache()
|
||||||
|
|
||||||
|
// Invalidate for each tag
|
||||||
|
for _, tag := range tags {
|
||||||
|
if err := c.DeleteByTag(ctx, tag); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
@@ -482,8 +482,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (AND condition)
|
// Apply custom SQL WHERE clause (AND condition)
|
||||||
if options.CustomSQLWhere != "" {
|
if options.CustomSQLWhere != "" {
|
||||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
// First add table prefixes to unqualified columns (but skip columns inside function calls)
|
||||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||||
|
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
if sanitizedWhere != "" {
|
if sanitizedWhere != "" {
|
||||||
query = query.Where(sanitizedWhere)
|
query = query.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -492,8 +494,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (OR condition)
|
// Apply custom SQL WHERE clause (OR condition)
|
||||||
if options.CustomSQLOr != "" {
|
if options.CustomSQLOr != "" {
|
||||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||||
|
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||||
if sanitizedOr != "" {
|
if sanitizedOr != "" {
|
||||||
query = query.WhereOr(sanitizedOr)
|
query = query.WhereOr(sanitizedOr)
|
||||||
}
|
}
|
||||||
@@ -529,7 +532,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
var total int
|
var total int
|
||||||
if !options.SkipCount {
|
if !options.SkipCount {
|
||||||
// Try to get from cache first (unless SkipCache is true)
|
// Try to get from cache first (unless SkipCache is true)
|
||||||
var cachedTotal *cache.CachedTotal
|
var cachedTotalData *cachedTotal
|
||||||
var cacheKey string
|
var cacheKey string
|
||||||
|
|
||||||
if !options.SkipCache {
|
if !options.SkipCache {
|
||||||
@@ -543,7 +546,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
|
cacheKeyHash := buildExtendedQueryCacheKey(
|
||||||
tableName,
|
tableName,
|
||||||
options.Filters,
|
options.Filters,
|
||||||
options.Sort,
|
options.Sort,
|
||||||
@@ -554,22 +557,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
options.CursorForward,
|
options.CursorForward,
|
||||||
options.CursorBackward,
|
options.CursorBackward,
|
||||||
)
|
)
|
||||||
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
|
cacheKey = getQueryTotalCacheKey(cacheKeyHash)
|
||||||
|
|
||||||
// Try to retrieve from cache
|
// Try to retrieve from cache
|
||||||
cachedTotal = &cache.CachedTotal{}
|
cachedTotalData = &cachedTotal{}
|
||||||
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
|
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotalData)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total = cachedTotal.Total
|
total = cachedTotalData.Total
|
||||||
logger.Debug("Total records (from cache): %d", total)
|
logger.Debug("Total records (from cache): %d", total)
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Cache miss for query total")
|
logger.Debug("Cache miss for query total")
|
||||||
cachedTotal = nil
|
cachedTotalData = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If not in cache or cache skip, execute count query
|
// If not in cache or cache skip, execute count query
|
||||||
if cachedTotal == nil {
|
if cachedTotalData == nil {
|
||||||
count, err := query.Count(ctx)
|
count, err := query.Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error counting records: %v", err)
|
logger.Error("Error counting records: %v", err)
|
||||||
@@ -579,11 +582,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
total = count
|
total = count
|
||||||
logger.Debug("Total records (from query): %d", total)
|
logger.Debug("Total records (from query): %d", total)
|
||||||
|
|
||||||
// Store in cache (if caching is enabled)
|
// Store in cache with schema and table tags (if caching is enabled)
|
||||||
if !options.SkipCache && cacheKey != "" {
|
if !options.SkipCache && cacheKey != "" {
|
||||||
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
|
||||||
cacheData := &cache.CachedTotal{Total: total}
|
if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
|
||||||
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
|
|
||||||
logger.Warn("Failed to cache query total: %v", err)
|
logger.Warn("Failed to cache query total: %v", err)
|
||||||
// Don't fail the request if caching fails
|
// Don't fail the request if caching fails
|
||||||
} else {
|
} else {
|
||||||
@@ -1149,6 +1151,11 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponseWithOptions(w, responseData, nil, &options)
|
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1320,6 +1327,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated record with ID: %v", targetID)
|
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1388,6 +1400,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1456,6 +1473,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1510,6 +1532,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully deleted %d records", deletedCount)
|
logger.Info("Successfully deleted %d records", deletedCount)
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1611,6 +1638,11 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the deleted record data
|
// Return the deleted record data
|
||||||
|
// Invalidate cache for this table
|
||||||
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
h.sendResponse(w, recordToDelete, nil)
|
h.sendResponse(w, recordToDelete, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
//go:build integration
|
||||||
// +build integration
|
// +build integration
|
||||||
|
|
||||||
package restheadspec
|
package restheadspec
|
||||||
@@ -21,12 +22,12 @@ import (
|
|||||||
|
|
||||||
// Test models
|
// Test models
|
||||||
type TestUser struct {
|
type TestUser struct {
|
||||||
ID uint `gorm:"primaryKey" json:"id"`
|
ID uint `gorm:"primaryKey" json:"id"`
|
||||||
Name string `gorm:"not null" json:"name"`
|
Name string `gorm:"not null" json:"name"`
|
||||||
Email string `gorm:"uniqueIndex;not null" json:"email"`
|
Email string `gorm:"uniqueIndex;not null" json:"email"`
|
||||||
Age int `json:"age"`
|
Age int `json:"age"`
|
||||||
Active bool `gorm:"default:true" json:"active"`
|
Active bool `gorm:"default:true" json:"active"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
|
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,13 +36,13 @@ func (TestUser) TableName() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TestPost struct {
|
type TestPost struct {
|
||||||
ID uint `gorm:"primaryKey" json:"id"`
|
ID uint `gorm:"primaryKey" json:"id"`
|
||||||
UserID uint `gorm:"not null" json:"user_id"`
|
UserID uint `gorm:"not null" json:"user_id"`
|
||||||
Title string `gorm:"not null" json:"title"`
|
Title string `gorm:"not null" json:"title"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Published bool `gorm:"default:false" json:"published"`
|
Published bool `gorm:"default:false" json:"published"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||||
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
|
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +55,7 @@ type TestComment struct {
|
|||||||
PostID uint `gorm:"not null" json:"post_id"`
|
PostID uint `gorm:"not null" json:"post_id"`
|
||||||
Content string `gorm:"not null" json:"content"`
|
Content string `gorm:"not null" json:"content"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
|
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (TestComment) TableName() string {
|
func (TestComment) TableName() string {
|
||||||
@@ -401,7 +402,7 @@ func TestIntegration_GetMetadata(t *testing.T) {
|
|||||||
|
|
||||||
muxRouter.ServeHTTP(w, req)
|
muxRouter.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
|
||||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -492,7 +493,7 @@ func TestIntegration_QueryParamsOverHeaders(t *testing.T) {
|
|||||||
|
|
||||||
muxRouter.ServeHTTP(w, req)
|
muxRouter.ServeHTTP(w, req)
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if !(w.Code == http.StatusOK || w.Code == http.StatusPartialContent) {
|
||||||
t.Errorf("Expected status 200, got %d", w.Code)
|
t.Errorf("Expected status 200, got %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -465,7 +465,7 @@ func processRequest(ctx context.Context) {
|
|||||||
|
|
||||||
1. **Check collector is running:**
|
1. **Check collector is running:**
|
||||||
```bash
|
```bash
|
||||||
docker-compose ps
|
podman compose ps
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Verify endpoint:**
|
2. **Verify endpoint:**
|
||||||
@@ -476,7 +476,7 @@ func processRequest(ctx context.Context) {
|
|||||||
|
|
||||||
3. **Check logs:**
|
3. **Check logs:**
|
||||||
```bash
|
```bash
|
||||||
docker-compose logs otel-collector
|
podman compose logs otel-collector
|
||||||
```
|
```
|
||||||
|
|
||||||
### Disable Tracing
|
### Disable Tracing
|
||||||
|
|||||||
@@ -14,33 +14,33 @@ NC='\033[0m' # No Color
|
|||||||
|
|
||||||
echo -e "${GREEN}=== ResolveSpec Integration Tests ===${NC}\n"
|
echo -e "${GREEN}=== ResolveSpec Integration Tests ===${NC}\n"
|
||||||
|
|
||||||
# Check if docker-compose is available
|
# Check if podman compose is available
|
||||||
if ! command -v docker-compose &> /dev/null; then
|
if ! command -v podman &> /dev/null; then
|
||||||
echo -e "${RED}Error: docker-compose is not installed${NC}"
|
echo -e "${RED}Error: podman is not installed${NC}"
|
||||||
echo "Please install docker-compose or run PostgreSQL manually"
|
echo "Please install podman or run PostgreSQL manually"
|
||||||
echo "See INTEGRATION_TESTS.md for details"
|
echo "See INTEGRATION_TESTS.md for details"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Clean up any existing containers and networks from previous runs
|
# Clean up any existing containers and networks from previous runs
|
||||||
echo -e "${YELLOW}Cleaning up existing containers and networks...${NC}"
|
echo -e "${YELLOW}Cleaning up existing containers and networks...${NC}"
|
||||||
docker-compose down -v 2>/dev/null || true
|
podman compose down -v 2>/dev/null || true
|
||||||
|
|
||||||
# Start PostgreSQL
|
# Start PostgreSQL
|
||||||
echo -e "${YELLOW}Starting PostgreSQL...${NC}"
|
echo -e "${YELLOW}Starting PostgreSQL...${NC}"
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
|
|
||||||
# Wait for PostgreSQL to be ready
|
# Wait for PostgreSQL to be ready
|
||||||
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
|
echo -e "${YELLOW}Waiting for PostgreSQL to be ready...${NC}"
|
||||||
max_attempts=30
|
max_attempts=30
|
||||||
attempt=0
|
attempt=0
|
||||||
|
|
||||||
while ! docker-compose exec -T postgres-test pg_isready -U postgres > /dev/null 2>&1; do
|
while ! podman compose exec -T postgres-test pg_isready -U postgres > /dev/null 2>&1; do
|
||||||
attempt=$((attempt + 1))
|
attempt=$((attempt + 1))
|
||||||
if [ $attempt -ge $max_attempts ]; then
|
if [ $attempt -ge $max_attempts ]; then
|
||||||
echo -e "${RED}Error: PostgreSQL failed to start after ${max_attempts} seconds${NC}"
|
echo -e "${RED}Error: PostgreSQL failed to start after ${max_attempts} seconds${NC}"
|
||||||
docker-compose logs postgres-test
|
podman compose logs postgres-test
|
||||||
docker-compose down
|
podman compose down
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
sleep 1
|
sleep 1
|
||||||
@@ -51,8 +51,8 @@ echo -e "\n${GREEN}PostgreSQL is ready!${NC}\n"
|
|||||||
|
|
||||||
# Create test databases
|
# Create test databases
|
||||||
echo -e "${YELLOW}Creating test databases...${NC}"
|
echo -e "${YELLOW}Creating test databases...${NC}"
|
||||||
docker-compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE resolvespec_test;" 2>/dev/null || echo " resolvespec_test already exists"
|
podman compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE resolvespec_test;" 2>/dev/null || echo " resolvespec_test already exists"
|
||||||
docker-compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE restheadspec_test;" 2>/dev/null || echo " restheadspec_test already exists"
|
podman compose exec -T postgres-test psql -U postgres -c "CREATE DATABASE restheadspec_test;" 2>/dev/null || echo " restheadspec_test already exists"
|
||||||
echo -e "${GREEN}Test databases ready!${NC}\n"
|
echo -e "${GREEN}Test databases ready!${NC}\n"
|
||||||
|
|
||||||
# Determine which tests to run
|
# Determine which tests to run
|
||||||
@@ -79,6 +79,6 @@ fi
|
|||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
echo -e "\n${YELLOW}Stopping PostgreSQL...${NC}"
|
echo -e "\n${YELLOW}Stopping PostgreSQL...${NC}"
|
||||||
docker-compose down
|
podman compose down
|
||||||
|
|
||||||
exit $EXIT_CODE
|
exit $EXIT_CODE
|
||||||
|
|||||||
@@ -19,14 +19,14 @@ Integration tests validate the full functionality of both `pkg/resolvespec` and
|
|||||||
|
|
||||||
- Go 1.19 or later
|
- Go 1.19 or later
|
||||||
- PostgreSQL 12 or later
|
- PostgreSQL 12 or later
|
||||||
- Docker and Docker Compose (optional, for easy setup)
|
- Podman and Podman Compose (optional, for easy setup)
|
||||||
|
|
||||||
## Quick Start with Docker
|
## Quick Start with Podman
|
||||||
|
|
||||||
### 1. Start PostgreSQL with Docker Compose
|
### 1. Start PostgreSQL with Podman Compose
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
```
|
```
|
||||||
|
|
||||||
This starts a PostgreSQL container with the following default settings:
|
This starts a PostgreSQL container with the following default settings:
|
||||||
@@ -52,7 +52,7 @@ go test -tags=integration ./pkg/restheadspec -v
|
|||||||
### 3. Stop PostgreSQL
|
### 3. Stop PostgreSQL
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker-compose down
|
podman compose down
|
||||||
```
|
```
|
||||||
|
|
||||||
## Manual PostgreSQL Setup
|
## Manual PostgreSQL Setup
|
||||||
@@ -161,7 +161,7 @@ If you see "connection refused" errors:
|
|||||||
|
|
||||||
1. Check that PostgreSQL is running:
|
1. Check that PostgreSQL is running:
|
||||||
```bash
|
```bash
|
||||||
docker-compose ps
|
podman compose ps
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Verify connection parameters:
|
2. Verify connection parameters:
|
||||||
@@ -194,10 +194,10 @@ Each test automatically cleans up its data using `TRUNCATE`. If you need a fresh
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Stop and remove containers (removes data)
|
# Stop and remove containers (removes data)
|
||||||
docker-compose down -v
|
podman compose down -v
|
||||||
|
|
||||||
# Restart
|
# Restart
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
```
|
```
|
||||||
|
|
||||||
## CI/CD Integration
|
## CI/CD Integration
|
||||||
|
|||||||
@@ -119,13 +119,13 @@ Integration tests require a PostgreSQL database and use the `// +build integrati
|
|||||||
- PostgreSQL 12+ installed and running
|
- PostgreSQL 12+ installed and running
|
||||||
- Create test databases manually (see below)
|
- Create test databases manually (see below)
|
||||||
|
|
||||||
### Setup with Docker
|
### Setup with Podman
|
||||||
|
|
||||||
1. **Start PostgreSQL**:
|
1. **Start PostgreSQL**:
|
||||||
```bash
|
```bash
|
||||||
make docker-up
|
make docker-up
|
||||||
# or
|
# or
|
||||||
docker-compose up -d postgres-test
|
podman compose up -d postgres-test
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Run Tests**:
|
2. **Run Tests**:
|
||||||
@@ -141,10 +141,10 @@ Integration tests require a PostgreSQL database and use the `// +build integrati
|
|||||||
```bash
|
```bash
|
||||||
make docker-down
|
make docker-down
|
||||||
# or
|
# or
|
||||||
docker-compose down
|
podman compose down
|
||||||
```
|
```
|
||||||
|
|
||||||
### Setup without Docker
|
### Setup without Podman
|
||||||
|
|
||||||
1. **Create Databases**:
|
1. **Create Databases**:
|
||||||
```sql
|
```sql
|
||||||
@@ -289,8 +289,8 @@ go test -tags=integration ./pkg/resolvespec -v
|
|||||||
**Problem**: "connection refused" or "database does not exist"
|
**Problem**: "connection refused" or "database does not exist"
|
||||||
|
|
||||||
**Solutions**:
|
**Solutions**:
|
||||||
1. Check PostgreSQL is running: `docker-compose ps`
|
1. Check PostgreSQL is running: `podman compose ps`
|
||||||
2. Verify databases exist: `docker-compose exec postgres-test psql -U postgres -l`
|
2. Verify databases exist: `podman compose exec postgres-test psql -U postgres -l`
|
||||||
3. Check environment variable: `echo $TEST_DATABASE_URL`
|
3. Check environment variable: `echo $TEST_DATABASE_URL`
|
||||||
4. Recreate databases: `make clean && make docker-up`
|
4. Recreate databases: `make clean && make docker-up`
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user