mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca4e53969b | ||
|
|
db2b7e878e | ||
|
|
9572bfc7b8 | ||
|
|
f0962ea1ec | ||
|
|
8fcb065b42 | ||
|
|
dc3b621380 | ||
|
|
a4dd2a7086 | ||
|
|
3ec2e5f15a | ||
|
|
c52afe2825 | ||
|
|
76e98d02c3 | ||
|
|
23e2db1496 | ||
|
|
d188f49126 |
218
pkg/common/adapters/database/RELATION_LOADING.md
Normal file
218
pkg/common/adapters/database/RELATION_LOADING.md
Normal file
@@ -0,0 +1,218 @@
|
||||
# Automatic Relation Loading Strategies
|
||||
|
||||
## Overview
|
||||
|
||||
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
|
||||
|
||||
Simply use `PreloadRelation()` and the system automatically:
|
||||
- Detects relationship type from Bun/GORM tags
|
||||
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
|
||||
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
|
||||
|
||||
## How It Works
|
||||
|
||||
```go
|
||||
// Just write this - the system handles the rest!
|
||||
db.NewSelect().
|
||||
Model(&links).
|
||||
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
|
||||
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
|
||||
Scan(ctx, &links)
|
||||
```
|
||||
|
||||
### Detection Logic
|
||||
|
||||
The system inspects your model's struct tags:
|
||||
|
||||
**Bun models:**
|
||||
```go
|
||||
type Link struct {
|
||||
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
|
||||
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
|
||||
}
|
||||
```
|
||||
|
||||
**GORM models:**
|
||||
```go
|
||||
type Link struct {
|
||||
ProviderID int
|
||||
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
|
||||
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
|
||||
}
|
||||
```
|
||||
|
||||
**Type inference (fallback):**
|
||||
- `[]Type` (slice) → has-many → Separate query
|
||||
- `*Type` (pointer) → belongs-to → JOIN
|
||||
- `Type` (struct) → belongs-to → JOIN
|
||||
|
||||
### What Gets Logged
|
||||
|
||||
Enable debug logging to see strategy selection:
|
||||
|
||||
```go
|
||||
bunAdapter.EnableQueryDebug()
|
||||
```
|
||||
|
||||
**Output:**
|
||||
```
|
||||
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
|
||||
INFO: Using JOIN strategy for belongs-to relation 'Provider'
|
||||
DEBUG: PreloadRelation 'Links' detected as: has-many
|
||||
DEBUG: Using separate query for has-many relation 'Links'
|
||||
```
|
||||
|
||||
## Relationship Types
|
||||
|
||||
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|
||||
|---------|--------------|------------|----------|-----|
|
||||
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
|
||||
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
|
||||
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
|
||||
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
|
||||
|
||||
## Manual Override
|
||||
|
||||
If you need to force a specific strategy, use `JoinRelation()`:
|
||||
|
||||
```go
|
||||
// Force JOIN even for has-many (not recommended)
|
||||
db.NewSelect().
|
||||
Model(&providers).
|
||||
JoinRelation("Links"). // Explicitly use JOIN
|
||||
Scan(ctx, &providers)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Automatic Strategy Selection (Recommended)
|
||||
|
||||
```go
|
||||
// Example 1: Loading parent provider for each link
|
||||
// System detects belongs-to → uses JOIN automatically
|
||||
db.NewSelect().
|
||||
Model(&links).
|
||||
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("active = ?", true)
|
||||
}).
|
||||
Scan(ctx, &links)
|
||||
|
||||
// Generated SQL: Single query with JOIN
|
||||
// SELECT links.*, providers.*
|
||||
// FROM links
|
||||
// LEFT JOIN providers ON links.provider_id = providers.id
|
||||
// WHERE providers.active = true
|
||||
|
||||
// Example 2: Loading child links for each provider
|
||||
// System detects has-many → uses separate query automatically
|
||||
db.NewSelect().
|
||||
Model(&providers).
|
||||
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("active = ?", true)
|
||||
}).
|
||||
Scan(ctx, &providers)
|
||||
|
||||
// Generated SQL: Two queries
|
||||
// Query 1: SELECT * FROM providers
|
||||
// Query 2: SELECT * FROM links
|
||||
// WHERE provider_id IN (1, 2, 3, ...)
|
||||
// AND active = true
|
||||
```
|
||||
|
||||
### Mixed Relationships
|
||||
|
||||
```go
|
||||
type Order struct {
|
||||
ID int
|
||||
CustomerID int
|
||||
Customer *Customer `bun:"rel:belongs-to"` // JOIN
|
||||
Items []Item `bun:"rel:has-many"` // Separate
|
||||
Invoice *Invoice `bun:"rel:has-one"` // JOIN
|
||||
}
|
||||
|
||||
// All three handled optimally!
|
||||
db.NewSelect().
|
||||
Model(&orders).
|
||||
PreloadRelation("Customer"). // → JOIN (many-to-one)
|
||||
PreloadRelation("Items"). // → Separate (one-to-many)
|
||||
PreloadRelation("Invoice"). // → JOIN (one-to-one)
|
||||
Scan(ctx, &orders)
|
||||
```
|
||||
|
||||
## Performance Benefits
|
||||
|
||||
### Before (Manual Strategy Selection)
|
||||
|
||||
```go
|
||||
// You had to remember which to use:
|
||||
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
|
||||
.PreloadRelation("Links") // Which is more efficient here?
|
||||
```
|
||||
|
||||
### After (Automatic Selection)
|
||||
|
||||
```go
|
||||
// Just use PreloadRelation everywhere:
|
||||
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
|
||||
.PreloadRelation("Links") // ✓ System uses separate query automatically
|
||||
```
|
||||
|
||||
## Migration Guide
|
||||
|
||||
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
|
||||
|
||||
```go
|
||||
// Before: Always used separate query
|
||||
.PreloadRelation("Provider") // Inefficient: extra round trip
|
||||
|
||||
// After: Automatic optimization
|
||||
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Supported Bun Tags
|
||||
- `rel:has-many` → Separate query
|
||||
- `rel:belongs-to` → JOIN
|
||||
- `rel:has-one` → JOIN
|
||||
- `rel:many-to-many` or `rel:m2m` → Separate query
|
||||
|
||||
### Supported GORM Patterns
|
||||
- `many2many:` tag → Separate query
|
||||
- `foreignKey:` tag → JOIN (belongs-to)
|
||||
- `[]Type` slice without many2many → Separate query (has-many)
|
||||
- `*Type` pointer with foreignKey → JOIN (belongs-to)
|
||||
- `*Type` pointer without foreignKey → JOIN (has-one)
|
||||
|
||||
### Fallback Behavior
|
||||
- `[]Type` (slice) → Separate query (safe default for collections)
|
||||
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
|
||||
- Unknown → Separate query (safest default)
|
||||
|
||||
## Debugging
|
||||
|
||||
To see strategy selection in action:
|
||||
|
||||
```go
|
||||
// Enable debug logging
|
||||
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
|
||||
|
||||
// Run your query
|
||||
db.NewSelect().
|
||||
Model(&records).
|
||||
PreloadRelation("RelationName").
|
||||
Scan(ctx, &records)
|
||||
|
||||
// Check logs for:
|
||||
// - "PreloadRelation 'X' detected as: belongs-to"
|
||||
// - "Using JOIN strategy for belongs-to relation 'X'"
|
||||
// - Actual SQL queries executed
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use PreloadRelation() for everything** - Let the system optimize
|
||||
2. **Define proper relationship tags** - Ensures correct detection
|
||||
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
|
||||
4. **Enable debug logging during development** - Verify optimal strategies are chosen
|
||||
5. **Trust the system** - It's designed to choose correctly based on relationship type
|
||||
81
pkg/common/adapters/database/alias_test.go
Normal file
81
pkg/common/adapters/database/alias_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeTableAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedAlias string
|
||||
tableName string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "strips plausible alias from simple condition",
|
||||
query: "APIL.rid_hub = 2576",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = 2576",
|
||||
},
|
||||
{
|
||||
name: "keeps correct alias",
|
||||
query: "apiproviderlink.rid_hub = 2576",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "apiproviderlink.rid_hub = 2576",
|
||||
},
|
||||
{
|
||||
name: "strips plausible alias with multiple conditions",
|
||||
query: "APIL.rid_hub = ? AND APIL.active = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ? AND active = ?",
|
||||
},
|
||||
{
|
||||
name: "handles mixed correct and plausible aliases",
|
||||
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ? AND apiproviderlink.active = ?",
|
||||
},
|
||||
{
|
||||
name: "handles parentheses",
|
||||
query: "(APIL.rid_hub = ?)",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "(rid_hub = ?)",
|
||||
},
|
||||
{
|
||||
name: "no alias in query",
|
||||
query: "rid_hub = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ?",
|
||||
},
|
||||
{
|
||||
name: "keeps reference to different table (not in current table name)",
|
||||
query: "APIL.rid_hub = ?",
|
||||
expectedAlias: "apiprovider",
|
||||
tableName: "apiprovider",
|
||||
want: "APIL.rid_hub = ?",
|
||||
},
|
||||
{
|
||||
name: "keeps reference with short prefix that might be ambiguous",
|
||||
query: "AP.rid = ?",
|
||||
expectedAlias: "apiprovider",
|
||||
tableName: "apiprovider",
|
||||
want: "AP.rid = ?",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTableAlias(tt.query, tt.expectedAlias, tt.tableName)
|
||||
if got != tt.want {
|
||||
t.Errorf("normalizeTableAlias() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
@@ -15,6 +16,24 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// QueryDebugHook is a Bun query hook that logs all SQL queries including preloads
|
||||
type QueryDebugHook struct{}
|
||||
|
||||
func (h *QueryDebugHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) {
|
||||
query := event.Query
|
||||
duration := time.Since(event.StartTime)
|
||||
|
||||
if event.Err != nil {
|
||||
logger.Error("SQL Query Failed [%s]: %s. Error: %v", duration, query, event.Err)
|
||||
} else {
|
||||
logger.Debug("SQL Query Success [%s]: %s", duration, query)
|
||||
}
|
||||
}
|
||||
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
@@ -26,6 +45,20 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
return &BunAdapter{db: db}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (b *BunAdapter) EnableQueryDebug() {
|
||||
b.db.AddQueryHook(&QueryDebugHook{})
|
||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||
}
|
||||
|
||||
// DisableQueryDebug removes all query hooks
|
||||
func (b *BunAdapter) DisableQueryDebug() {
|
||||
// Create a new DB without hooks
|
||||
// Note: Bun doesn't have a RemoveQueryHook, so we'd need to track hooks manually
|
||||
logger.Info("To disable query debug, recreate the BunAdapter without adding the hook")
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{
|
||||
query: b.db.NewSelect(),
|
||||
@@ -107,6 +140,8 @@ type BunSelectQuery struct {
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
}
|
||||
|
||||
// deferredPreload represents a preload that will be executed as a separate query
|
||||
@@ -156,10 +191,147 @@ func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.Se
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
// If we're in a JOIN context, add table prefix to unqualified columns
|
||||
if b.inJoinContext && b.joinTableAlias != "" {
|
||||
query = addTablePrefix(query, b.joinTableAlias)
|
||||
} else if b.tableAlias != "" && b.tableName != "" {
|
||||
// If we have a table alias defined, check if the query references a different alias
|
||||
// This can happen in preloads where the user expects a certain alias but Bun generates another
|
||||
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
|
||||
}
|
||||
b.query = b.query.Where(query, args...)
|
||||
return b
|
||||
}
|
||||
|
||||
// addTablePrefix adds a table prefix to unqualified column references
|
||||
// This is used in JOIN contexts where conditions must reference the joined table
|
||||
func addTablePrefix(query, tableAlias string) string {
|
||||
if tableAlias == "" || query == "" {
|
||||
return query
|
||||
}
|
||||
|
||||
// Split on spaces and parentheses to find column references
|
||||
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||
})
|
||||
|
||||
modified := query
|
||||
for _, part := range parts {
|
||||
// Check if this looks like an unqualified column reference
|
||||
// (no dot, and likely a column name before an operator)
|
||||
if !strings.Contains(part, ".") {
|
||||
// Extract potential column name (before = or other operators)
|
||||
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
|
||||
if strings.Contains(part, op) {
|
||||
colName := strings.Split(part, op)[0]
|
||||
colName = strings.TrimSpace(colName)
|
||||
if colName != "" && !isOperatorOrKeyword(colName) {
|
||||
// Add table prefix
|
||||
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
|
||||
modified = strings.ReplaceAll(modified, part, prefixed)
|
||||
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// isOperatorOrKeyword checks if a string is likely an operator or SQL keyword
|
||||
func isOperatorOrKeyword(s string) bool {
|
||||
s = strings.ToUpper(strings.TrimSpace(s))
|
||||
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
|
||||
for _, kw := range keywords {
|
||||
if s == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isAcronymMatch checks if prefix is an acronym of tableName
|
||||
// For example, "apil" matches "apiproviderlink" because each letter appears in sequence
|
||||
func isAcronymMatch(prefix, tableName string) bool {
|
||||
if len(prefix) == 0 || len(tableName) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
prefixIdx := 0
|
||||
for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ {
|
||||
if tableName[i] == prefix[prefixIdx] {
|
||||
prefixIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// All characters of prefix were found in sequence in tableName
|
||||
return prefixIdx == len(prefix)
|
||||
}
|
||||
|
||||
// normalizeTableAlias replaces table alias prefixes in SQL conditions
|
||||
// This handles cases where a user references a table alias that doesn't match
|
||||
// what Bun generates (common in preload contexts)
|
||||
func normalizeTableAlias(query, expectedAlias, tableName string) string {
|
||||
// Pattern: <word>.<column> where <word> might be an incorrect alias
|
||||
// We'll look for patterns like "APIL.column" and either:
|
||||
// 1. Remove the alias prefix if it's clearly meant for this table
|
||||
// 2. Leave it alone if it might be referring to another table (JOIN/preload)
|
||||
|
||||
// Split on spaces and parentheses to find qualified references
|
||||
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||
})
|
||||
|
||||
modified := query
|
||||
for _, part := range parts {
|
||||
// Check if this looks like a qualified column reference
|
||||
if dotIndex := strings.Index(part, "."); dotIndex > 0 {
|
||||
prefix := part[:dotIndex]
|
||||
column := part[dotIndex+1:]
|
||||
|
||||
// Check if the prefix matches our expected alias or table name (case-insensitive)
|
||||
if strings.EqualFold(prefix, expectedAlias) ||
|
||||
strings.EqualFold(prefix, tableName) ||
|
||||
strings.EqualFold(prefix, strings.ToLower(tableName)) {
|
||||
// Prefix matches current table, it's safe but redundant - leave it
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the prefix could plausibly be an alias/acronym for this table
|
||||
// Only strip if we're confident it's meant for this table
|
||||
// For example: "APIL" could be an acronym for "apiproviderlink"
|
||||
prefixLower := strings.ToLower(prefix)
|
||||
tableNameLower := strings.ToLower(tableName)
|
||||
|
||||
// Check if prefix is a substring of table name
|
||||
isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2
|
||||
|
||||
// Check if prefix is an acronym of table name
|
||||
// e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink)
|
||||
isAcronym := false
|
||||
if !isSubstring && len(prefixLower) > 2 {
|
||||
isAcronym = isAcronymMatch(prefixLower, tableNameLower)
|
||||
}
|
||||
|
||||
if isSubstring || isAcronym {
|
||||
// This looks like it could be an alias for this table - strip it
|
||||
logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
|
||||
// Replace the qualified reference with just the column name
|
||||
modified = strings.ReplaceAll(modified, part, column)
|
||||
} else {
|
||||
// Prefix doesn't match the current table at all
|
||||
// It's likely referring to a different table (JOIN/preload)
|
||||
// DON'T strip it - leave the qualified reference as-is
|
||||
logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.WhereOr(query, args...)
|
||||
return b
|
||||
@@ -288,6 +460,27 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
// }
|
||||
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// Auto-detect relationship type and choose optimal loading strategy
|
||||
// Get the model from the query if available
|
||||
model := b.query.GetModel()
|
||||
if model != nil && model.Value() != nil {
|
||||
relType := reflection.GetRelationType(model.Value(), relation)
|
||||
|
||||
// Log the detected relationship type
|
||||
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||
|
||||
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||
if relType.ShouldUseJoin() {
|
||||
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||
return b.JoinRelation(relation, apply...)
|
||||
}
|
||||
|
||||
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this relation chain would create problematic long aliases
|
||||
relationParts := strings.Split(relation, ".")
|
||||
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||
@@ -350,6 +543,28 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
||||
db: b.db,
|
||||
}
|
||||
|
||||
// Try to extract table name and alias from the preload model
|
||||
if model := sq.GetModel(); model != nil && model.Value() != nil {
|
||||
modelValue := model.Value()
|
||||
|
||||
// Extract table name if model implements TableNameProvider
|
||||
if provider, ok := modelValue.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
|
||||
}
|
||||
|
||||
// Extract table alias if model implements TableAliasProvider
|
||||
if provider, ok := modelValue.(common.TableAliasProvider); ok {
|
||||
wrapper.tableAlias = provider.TableAlias()
|
||||
// Apply the alias to the Bun query so conditions can reference it
|
||||
if wrapper.tableAlias != "" {
|
||||
// Note: Bun's Relation() already sets up the table, but we can add
|
||||
// the alias explicitly if needed
|
||||
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start with the interface value (not pointer)
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
@@ -372,6 +587,36 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// JoinRelation uses a LEFT JOIN instead of a separate query
|
||||
// This is more efficient for many-to-one or one-to-one relationships
|
||||
|
||||
logger.Debug("JoinRelation '%s' - Using JOIN strategy with automatic WHERE prefix addition", relation)
|
||||
|
||||
// Wrap the apply functions to automatically add table prefix to WHERE conditions
|
||||
wrappedApply := make([]func(common.SelectQuery) common.SelectQuery, 0, len(apply))
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
wrappedFn := func(originalFn func(common.SelectQuery) common.SelectQuery) func(common.SelectQuery) common.SelectQuery {
|
||||
return func(q common.SelectQuery) common.SelectQuery {
|
||||
// Create a special wrapper that adds prefixes to WHERE conditions
|
||||
if bunQuery, ok := q.(*BunSelectQuery); ok {
|
||||
// Mark this query as being in JOIN context
|
||||
bunQuery.inJoinContext = true
|
||||
bunQuery.joinTableAlias = strings.ToLower(relation)
|
||||
}
|
||||
return originalFn(q)
|
||||
}
|
||||
}(fn)
|
||||
wrappedApply = append(wrappedApply, wrappedFn)
|
||||
}
|
||||
}
|
||||
|
||||
// Use PreloadRelation with the wrapped functions
|
||||
// Bun's Relation() will use JOIN for belongs-to and has-one relations
|
||||
return b.PreloadRelation(relation, wrappedApply...)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||
b.query = b.query.Order(order)
|
||||
return b
|
||||
@@ -410,6 +655,9 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx, dest)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -438,6 +686,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -573,15 +824,25 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
// If Model() was set, use bun's native Count() which works properly
|
||||
if b.hasModel {
|
||||
count, err := b.query.Count(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||
// This is needed when only Table() is set without a model
|
||||
err = b.db.NewSelect().
|
||||
countQuery := b.db.NewSelect().
|
||||
TableExpr("(?) AS subquery", b.query).
|
||||
ColumnExpr("COUNT(*)").
|
||||
Scan(ctx, &count)
|
||||
ColumnExpr("COUNT(*)")
|
||||
err = countQuery.Scan(ctx, &count)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := countQuery.String()
|
||||
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -592,7 +853,13 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
return b.query.Exists(ctx)
|
||||
exists, err = b.query.Exists(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// BunInsertQuery implements InsertQuery for Bun
|
||||
@@ -729,6 +996,11 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -759,6 +1031,11 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,22 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
return &GormAdapter{db: db}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
// This is useful for debugging preload queries that may be failing
|
||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||
g.db = g.db.Debug()
|
||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||
return g
|
||||
}
|
||||
|
||||
// DisableQueryDebug disables query debugging
|
||||
func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
// GORM's Debug() creates a new session, so we need to get the base DB
|
||||
// This is a simplified implementation
|
||||
logger.Info("GORM debug mode - create a new adapter without Debug() to disable")
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db}
|
||||
}
|
||||
@@ -88,10 +104,12 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
db *gorm.DB
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
@@ -135,10 +153,61 @@ func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.S
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
// If we're in a JOIN context, add table prefix to unqualified columns
|
||||
if g.inJoinContext && g.joinTableAlias != "" {
|
||||
query = addTablePrefixGorm(query, g.joinTableAlias)
|
||||
}
|
||||
g.db = g.db.Where(query, args...)
|
||||
return g
|
||||
}
|
||||
|
||||
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
|
||||
func addTablePrefixGorm(query, tableAlias string) string {
|
||||
if tableAlias == "" || query == "" {
|
||||
return query
|
||||
}
|
||||
|
||||
// Split on spaces and parentheses to find column references
|
||||
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||
})
|
||||
|
||||
modified := query
|
||||
for _, part := range parts {
|
||||
// Check if this looks like an unqualified column reference
|
||||
if !strings.Contains(part, ".") {
|
||||
// Extract potential column name (before = or other operators)
|
||||
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
|
||||
if strings.Contains(part, op) {
|
||||
colName := strings.Split(part, op)[0]
|
||||
colName = strings.TrimSpace(colName)
|
||||
if colName != "" && !isOperatorOrKeywordGorm(colName) {
|
||||
// Add table prefix
|
||||
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
|
||||
modified = strings.ReplaceAll(modified, part, prefixed)
|
||||
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
|
||||
func isOperatorOrKeywordGorm(s string) bool {
|
||||
s = strings.ToUpper(strings.TrimSpace(s))
|
||||
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
|
||||
for _, kw := range keywords {
|
||||
if s == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
g.db = g.db.Or(query, args...)
|
||||
return g
|
||||
@@ -222,6 +291,27 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// Auto-detect relationship type and choose optimal loading strategy
|
||||
// Get the model from GORM's statement if available
|
||||
if g.db.Statement != nil && g.db.Statement.Model != nil {
|
||||
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
|
||||
|
||||
// Log the detected relationship type
|
||||
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
|
||||
|
||||
// If this is a belongs-to or has-one relation, use JOIN for better performance
|
||||
if relType.ShouldUseJoin() {
|
||||
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
|
||||
return g.JoinRelation(relation, apply...)
|
||||
}
|
||||
|
||||
// For has-many, many-to-many, or unknown: use separate query (safer default)
|
||||
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
|
||||
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Use GORM's Preload (separate query strategy)
|
||||
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
|
||||
if len(apply) == 0 {
|
||||
return db
|
||||
@@ -251,6 +341,42 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// JoinRelation uses a JOIN instead of a separate preload query
|
||||
// This is more efficient for many-to-one or one-to-one relationships
|
||||
// as it avoids additional round trips to the database
|
||||
|
||||
// GORM's Joins() method forces a JOIN for the preload
|
||||
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
|
||||
|
||||
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
|
||||
if len(apply) == 0 {
|
||||
return db
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
}
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
current = fn(current)
|
||||
}
|
||||
}
|
||||
|
||||
if finalGorm, ok := current.(*GormSelectQuery); ok {
|
||||
return finalGorm.db
|
||||
}
|
||||
|
||||
return db
|
||||
})
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||
g.db = g.db.Order(order)
|
||||
return g
|
||||
@@ -282,7 +408,15 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
err = g.db.WithContext(ctx).Find(dest).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Find(dest)
|
||||
})
|
||||
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
@@ -294,7 +428,15 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Find(g.db.Statement.Model)
|
||||
})
|
||||
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
@@ -306,6 +448,13 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
}()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Count(&count64)
|
||||
})
|
||||
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return int(count64), err
|
||||
}
|
||||
|
||||
@@ -318,6 +467,13 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
}()
|
||||
var count int64
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
if err != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Limit(1).Count(&count)
|
||||
})
|
||||
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||
}
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
@@ -456,6 +612,13 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Updates(g.updates)
|
||||
})
|
||||
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
@@ -488,6 +651,13 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
if result.Error != nil {
|
||||
// Log SQL string for debugging
|
||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Delete(g.model)
|
||||
})
|
||||
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ type SelectQuery interface {
|
||||
LeftJoin(query string, args ...interface{}) SelectQuery
|
||||
Preload(relation string, conditions ...interface{}) SelectQuery
|
||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
Order(order string) SelectQuery
|
||||
Limit(n int) SelectQuery
|
||||
Offset(n int) SelectQuery
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -9,81 +8,40 @@ import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||
// ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
|
||||
//
|
||||
// NOTE: For preload queries, table aliases from the parent query are not valid since
|
||||
// the preload executes as a separate query with its own table alias. This function
|
||||
// now simply validates basic syntax without requiring or adding prefixes.
|
||||
// The actual alias normalization happens in the database adapter layer.
|
||||
//
|
||||
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
|
||||
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||
if where == "" {
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// Check if the relation name is already present in the WHERE clause
|
||||
lowerWhere := strings.ToLower(where)
|
||||
lowerRelation := strings.ToLower(relationName)
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||
// Relation prefix is already present
|
||||
// Just do basic validation - don't require or add prefixes
|
||||
// The database adapter will handle alias normalization
|
||||
|
||||
// Check if the WHERE clause contains any qualified column references
|
||||
// If it does, log a debug message but don't fail - let the adapter handle it
|
||||
if strings.Contains(where, ".") {
|
||||
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
|
||||
"Note: In preload context, table aliases from parent query are not available. "+
|
||||
"The database adapter will normalize aliases automatically.", relationName, where)
|
||||
}
|
||||
|
||||
// Validate that it's not empty or just whitespace
|
||||
if where == "" {
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||
// we can't safely auto-fix it - require explicit prefix
|
||||
if strings.Contains(lowerWhere, " or ") ||
|
||||
strings.Contains(where, "(") ||
|
||||
strings.Contains(where, ")") {
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||
}
|
||||
|
||||
// Try to add the relation prefix to simple column references
|
||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||
// Split by AND to handle multiple conditions (case-insensitive)
|
||||
originalConditions := strings.Split(where, " AND ")
|
||||
|
||||
// If uppercase split didn't work, try lowercase
|
||||
if len(originalConditions) == 1 {
|
||||
originalConditions = strings.Split(where, " and ")
|
||||
}
|
||||
|
||||
fixedConditions := make([]string, 0, len(originalConditions))
|
||||
|
||||
for _, cond := range originalConditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this condition already has a table prefix (contains a dot)
|
||||
if strings.Contains(cond, ".") {
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||
if IsSQLExpression(lowerCond) {
|
||||
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the column name (first identifier before operator)
|
||||
columnName := ExtractColumnName(cond)
|
||||
if columnName == "" {
|
||||
// Can't identify column name, require explicit prefix
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||
}
|
||||
|
||||
// Add relation prefix to the column name only
|
||||
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||
fixedConditions = append(fixedConditions, fixedCond)
|
||||
}
|
||||
|
||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||
return fixedWhere, nil
|
||||
// Return the WHERE clause as-is
|
||||
// The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,18 +9,18 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestSqlInt16 tests SqlInt16 type
|
||||
func TestSqlInt16(t *testing.T) {
|
||||
// TestNewSqlInt16 tests NewSqlInt16 type
|
||||
func TestNewSqlInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt16
|
||||
}{
|
||||
{"int", 42, SqlInt16(42)},
|
||||
{"int32", int32(100), SqlInt16(100)},
|
||||
{"int64", int64(200), SqlInt16(200)},
|
||||
{"string", "123", SqlInt16(123)},
|
||||
{"nil", nil, SqlInt16(0)},
|
||||
{"int", 42, Null(int16(42), true)},
|
||||
{"int32", int32(100), NewSqlInt16(100)},
|
||||
{"int64", int64(200), NewSqlInt16(200)},
|
||||
{"string", "123", NewSqlInt16(123)},
|
||||
{"nil", nil, Null(int16(0), false)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -36,15 +36,15 @@ func TestSqlInt16(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_Value(t *testing.T) {
|
||||
func TestNewSqlInt16_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlInt16
|
||||
expected driver.Value
|
||||
}{
|
||||
{"zero", SqlInt16(0), nil},
|
||||
{"positive", SqlInt16(42), int64(42)},
|
||||
{"negative", SqlInt16(-10), int64(-10)},
|
||||
{"zero", Null(int16(0), false), nil},
|
||||
{"positive", NewSqlInt16(42), int16(42)},
|
||||
{"negative", NewSqlInt16(-10), int16(-10)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -60,8 +60,8 @@ func TestSqlInt16_Value(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_JSON(t *testing.T) {
|
||||
n := SqlInt16(42)
|
||||
func TestNewSqlInt16_JSON(t *testing.T) {
|
||||
n := NewSqlInt16(42)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(n)
|
||||
@@ -78,24 +78,24 @@ func TestSqlInt16_JSON(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if n2 != 123 {
|
||||
t.Errorf("expected 123, got %d", n2)
|
||||
if n2.Int64() != 123 {
|
||||
t.Errorf("expected 123, got %d", n2.Int64())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlInt64 tests SqlInt64 type
|
||||
func TestSqlInt64(t *testing.T) {
|
||||
// TestNewSqlInt64 tests NewSqlInt64 type
|
||||
func TestNewSqlInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt64
|
||||
}{
|
||||
{"int", 42, SqlInt64(42)},
|
||||
{"int32", int32(100), SqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), SqlInt64(100)},
|
||||
{"uint64", uint64(200), SqlInt64(200)},
|
||||
{"nil", nil, SqlInt64(0)},
|
||||
{"int", 42, NewSqlInt64(42)},
|
||||
{"int32", int32(100), NewSqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), NewSqlInt64(100)},
|
||||
{"uint64", uint64(200), NewSqlInt64(200)},
|
||||
{"nil", nil, SqlInt64{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -135,8 +135,8 @@ func TestSqlFloat64(t *testing.T) {
|
||||
if n.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||
}
|
||||
if tt.valid && n.Float64 != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
|
||||
if tt.valid && n.Float64() != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func TestSqlTimeStamp(t *testing.T) {
|
||||
if err := ts.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if ts.GetTime().IsZero() {
|
||||
if ts.Time().IsZero() {
|
||||
t.Error("expected non-zero time")
|
||||
}
|
||||
})
|
||||
@@ -171,7 +171,7 @@ func TestSqlTimeStamp(t *testing.T) {
|
||||
|
||||
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||
ts := SqlTimeStamp(now)
|
||||
ts := NewSqlTimeStamp(now)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(ts)
|
||||
@@ -188,8 +188,8 @@ func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if ts2.GetTime().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
|
||||
if ts2.Time().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
|
||||
}
|
||||
|
||||
// Test null
|
||||
@@ -226,7 +226,7 @@ func TestSqlDate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlDate_JSON(t *testing.T) {
|
||||
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(date)
|
||||
@@ -471,8 +471,8 @@ func TestSqlUUID_Scan(t *testing.T) {
|
||||
if u.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||
}
|
||||
if tt.valid && u.String != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
||||
if tt.valid && u.String() != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -480,13 +480,13 @@ func TestSqlUUID_Scan(t *testing.T) {
|
||||
|
||||
func TestSqlUUID_Value(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
u := NewSqlUUID(testUUID)
|
||||
|
||||
val, err := u.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != testUUID.String() {
|
||||
if val != testUUID {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||
}
|
||||
|
||||
@@ -503,7 +503,7 @@ func TestSqlUUID_Value(t *testing.T) {
|
||||
|
||||
func TestSqlUUID_JSON(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
u := NewSqlUUID(testUUID)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(u)
|
||||
@@ -520,8 +520,8 @@ func TestSqlUUID_JSON(t *testing.T) {
|
||||
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if u2.String != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
||||
if u2.String() != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
|
||||
}
|
||||
|
||||
// Test null
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
|
||||
// MockDatabase implements common.Database interface for testing
|
||||
type MockDatabase struct {
|
||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
|
||||
}
|
||||
|
||||
@@ -161,9 +161,9 @@ func TestExtractInputVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedVars []string
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedVars []string
|
||||
}{
|
||||
{
|
||||
name: "No variables",
|
||||
@@ -340,9 +340,9 @@ func TestSqlQryWhere(t *testing.T) {
|
||||
// TestGetIPAddress tests IP address extraction
|
||||
func TestGetIPAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For header",
|
||||
@@ -782,9 +782,10 @@ func TestReplaceMetaVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
userCtx := &security.UserContext{
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
SessionID: "456",
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
SessionID: "ABC456",
|
||||
SessionRID: 456,
|
||||
}
|
||||
|
||||
metainfo := map[string]interface{}{
|
||||
@@ -821,6 +822,12 @@ func TestReplaceMetaVariables(t *testing.T) {
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "456")
|
||||
},
|
||||
}, {
|
||||
name: "Replace [id_session]",
|
||||
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "ABC456")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,10 @@ func NewModelRegistry() *DefaultModelRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
func GetDefaultRegistry() *DefaultModelRegistry {
|
||||
return defaultRegistry
|
||||
}
|
||||
|
||||
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||
registriesMutex.Lock()
|
||||
defer registriesMutex.Unlock()
|
||||
|
||||
321
pkg/openapi/README.md
Normal file
321
pkg/openapi/README.md
Normal file
@@ -0,0 +1,321 @@
|
||||
# OpenAPI Generator for ResolveSpec
|
||||
|
||||
This package provides automatic OpenAPI 3.0 specification generation for ResolveSpec, RestheadSpec, and FuncSpec API frameworks.
|
||||
|
||||
## Features
|
||||
|
||||
- **Automatic Schema Generation**: Generates OpenAPI schemas from Go struct models
|
||||
- **Multiple Framework Support**: Works with RestheadSpec, ResolveSpec, and FuncSpec
|
||||
- **Dynamic Endpoint Discovery**: Automatically discovers all registered models and generates paths
|
||||
- **Query Parameter Access**: Access spec via `?openapi` on any endpoint or via `/openapi`
|
||||
- **Comprehensive Documentation**: Includes all request/response schemas, parameters, and security schemes
|
||||
|
||||
## Quick Start
|
||||
|
||||
### RestheadSpec Example
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/openapi"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 1. Create handler
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// 2. Register models
|
||||
handler.registry.RegisterModel("public.users", User{})
|
||||
handler.registry.RegisterModel("public.products", Product{})
|
||||
|
||||
// 3. Configure OpenAPI generator
|
||||
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||
generator := openapi.NewGenerator(openapi.GeneratorConfig{
|
||||
Title: "My API",
|
||||
Description: "API documentation",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "http://localhost:8080",
|
||||
Registry: handler.registry.(*modelregistry.DefaultModelRegistry),
|
||||
IncludeRestheadSpec: true,
|
||||
IncludeResolveSpec: false,
|
||||
IncludeFuncSpec: false,
|
||||
})
|
||||
return generator.GenerateJSON()
|
||||
})
|
||||
|
||||
// 4. Setup routes (automatically includes /openapi endpoint)
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Start server
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
```
|
||||
|
||||
### ResolveSpec Example
|
||||
|
||||
```go
|
||||
func main() {
|
||||
// 1. Create handler
|
||||
handler := resolvespec.NewHandlerWithGORM(db)
|
||||
|
||||
// 2. Register models
|
||||
handler.RegisterModel("public", "users", User{})
|
||||
handler.RegisterModel("public", "products", Product{})
|
||||
|
||||
// 3. Configure OpenAPI generator
|
||||
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||
generator := openapi.NewGenerator(openapi.GeneratorConfig{
|
||||
Title: "My API",
|
||||
Version: "1.0.0",
|
||||
Registry: handler.registry.(*modelregistry.DefaultModelRegistry),
|
||||
IncludeResolveSpec: true,
|
||||
})
|
||||
return generator.GenerateJSON()
|
||||
})
|
||||
|
||||
// 4. Setup routes
|
||||
router := mux.NewRouter()
|
||||
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
```
|
||||
|
||||
## Accessing the OpenAPI Specification
|
||||
|
||||
Once configured, the OpenAPI spec is available in two ways:
|
||||
|
||||
### 1. Global `/openapi` Endpoint
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/openapi
|
||||
```
|
||||
|
||||
Returns the complete OpenAPI specification for all registered models.
|
||||
|
||||
### 2. Query Parameter on Any Endpoint
|
||||
|
||||
```bash
|
||||
# RestheadSpec
|
||||
curl http://localhost:8080/public/users?openapi
|
||||
|
||||
# ResolveSpec
|
||||
curl http://localhost:8080/resolve/public/users?openapi
|
||||
```
|
||||
|
||||
Returns the same OpenAPI specification as `/openapi`.
|
||||
|
||||
## Generated Endpoints
|
||||
|
||||
### RestheadSpec
|
||||
|
||||
For each registered model (e.g., `public.users`), the following paths are generated:
|
||||
|
||||
- `GET /public/users` - List records with header-based filtering
|
||||
- `POST /public/users` - Create a new record
|
||||
- `GET /public/users/{id}` - Get a single record
|
||||
- `PUT /public/users/{id}` - Update a record
|
||||
- `PATCH /public/users/{id}` - Partially update a record
|
||||
- `DELETE /public/users/{id}` - Delete a record
|
||||
- `GET /public/users/metadata` - Get table metadata
|
||||
- `OPTIONS /public/users` - CORS preflight
|
||||
|
||||
### ResolveSpec
|
||||
|
||||
For each registered model (e.g., `public.users`), the following paths are generated:
|
||||
|
||||
- `POST /resolve/public/users` - Execute operations (read, create, meta)
|
||||
- `POST /resolve/public/users/{id}` - Execute operations (update, delete)
|
||||
- `GET /resolve/public/users` - Get metadata
|
||||
- `OPTIONS /resolve/public/users` - CORS preflight
|
||||
|
||||
## Schema Generation
|
||||
|
||||
The generator automatically extracts information from your Go struct tags:
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
ID int `json:"id" gorm:"primaryKey" description:"User ID"`
|
||||
Name string `json:"name" gorm:"not null" description:"User's full name"`
|
||||
Email string `json:"email" gorm:"unique" description:"Email address"`
|
||||
CreatedAt time.Time `json:"created_at" description:"Creation timestamp"`
|
||||
Roles []string `json:"roles" description:"User roles"`
|
||||
}
|
||||
```
|
||||
|
||||
This generates an OpenAPI schema with:
|
||||
- Property names from `json` tags
|
||||
- Required fields from `gorm:"not null"` and non-pointer types
|
||||
- Descriptions from `description` tags
|
||||
- Proper type mappings (int → integer, time.Time → string with format: date-time, etc.)
|
||||
|
||||
## RestheadSpec Headers
|
||||
|
||||
The generator documents all RestheadSpec HTTP headers:
|
||||
|
||||
- `X-Filters` - JSON array of filter conditions
|
||||
- `X-Columns` - Comma-separated columns to select
|
||||
- `X-Sort` - JSON array of sort specifications
|
||||
- `X-Limit` - Maximum records to return
|
||||
- `X-Offset` - Records to skip
|
||||
- `X-Preload` - Relations to eager load
|
||||
- `X-Expand` - Relations to expand (LEFT JOIN)
|
||||
- `X-Distinct` - Enable DISTINCT queries
|
||||
- `X-Response-Format` - Response format (detail, simple, syncfusion)
|
||||
- `X-Clean-JSON` - Remove null/empty fields
|
||||
- `X-Custom-SQL-Where` - Custom WHERE clause (AND)
|
||||
- `X-Custom-SQL-Or` - Custom WHERE clause (OR)
|
||||
|
||||
## ResolveSpec Request Body
|
||||
|
||||
The generator documents the ResolveSpec request body structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"data": {},
|
||||
"id": 123,
|
||||
"options": {
|
||||
"limit": 10,
|
||||
"offset": 0,
|
||||
"filters": [
|
||||
{"column": "status", "operator": "eq", "value": "active"}
|
||||
],
|
||||
"sort": [
|
||||
{"column": "created_at", "direction": "desc"}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Security Schemes
|
||||
|
||||
The generator automatically includes common security schemes:
|
||||
|
||||
- **BearerAuth**: JWT Bearer token authentication
|
||||
- **SessionToken**: Session token in Authorization header
|
||||
- **CookieAuth**: Cookie-based session authentication
|
||||
- **HeaderAuth**: Header-based user authentication (X-User-ID)
|
||||
|
||||
## FuncSpec Custom Endpoints
|
||||
|
||||
For FuncSpec, you can manually register custom SQL endpoints:
|
||||
|
||||
```go
|
||||
funcSpecEndpoints := map[string]openapi.FuncSpecEndpoint{
|
||||
"/api/reports/sales": {
|
||||
Path: "/api/reports/sales",
|
||||
Method: "GET",
|
||||
Summary: "Get sales report",
|
||||
Description: "Returns sales data for specified date range",
|
||||
SQLQuery: "SELECT * FROM sales WHERE date BETWEEN [start_date] AND [end_date]",
|
||||
Parameters: []string{"start_date", "end_date"},
|
||||
},
|
||||
}
|
||||
|
||||
generator := openapi.NewGenerator(openapi.GeneratorConfig{
|
||||
// ... other config
|
||||
IncludeFuncSpec: true,
|
||||
FuncSpecEndpoints: funcSpecEndpoints,
|
||||
})
|
||||
```
|
||||
|
||||
## Combining Multiple Frameworks
|
||||
|
||||
You can generate a unified OpenAPI spec that includes multiple frameworks:
|
||||
|
||||
```go
|
||||
generator := openapi.NewGenerator(openapi.GeneratorConfig{
|
||||
Title: "Unified API",
|
||||
Version: "1.0.0",
|
||||
Registry: sharedRegistry,
|
||||
IncludeRestheadSpec: true,
|
||||
IncludeResolveSpec: true,
|
||||
IncludeFuncSpec: true,
|
||||
FuncSpecEndpoints: funcSpecEndpoints,
|
||||
})
|
||||
```
|
||||
|
||||
This will generate a complete spec with all endpoints from all frameworks.
|
||||
|
||||
## Advanced Customization
|
||||
|
||||
You can customize the generated spec further:
|
||||
|
||||
```go
|
||||
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||
generator := openapi.NewGenerator(config)
|
||||
|
||||
// Generate initial spec
|
||||
spec, err := generator.Generate()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Add contact information
|
||||
spec.Info.Contact = &openapi.Contact{
|
||||
Name: "API Support",
|
||||
Email: "support@example.com",
|
||||
URL: "https://example.com/support",
|
||||
}
|
||||
|
||||
// Add additional servers
|
||||
spec.Servers = append(spec.Servers, openapi.Server{
|
||||
URL: "https://staging.example.com",
|
||||
Description: "Staging Server",
|
||||
})
|
||||
|
||||
// Convert back to JSON
|
||||
data, _ := json.MarshalIndent(spec, "", " ")
|
||||
return string(data), nil
|
||||
})
|
||||
```
|
||||
|
||||
## Using with Swagger UI
|
||||
|
||||
You can serve the generated OpenAPI spec with Swagger UI:
|
||||
|
||||
1. Get the spec from `/openapi`
|
||||
2. Load it in Swagger UI at `https://petstore.swagger.io/`
|
||||
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
|
||||
|
||||
Example with self-hosted Swagger UI:
|
||||
|
||||
```go
|
||||
// Serve Swagger UI static files
|
||||
router.PathPrefix("/swagger/").Handler(
|
||||
http.StripPrefix("/swagger/", http.FileServer(http.Dir("./swagger-ui"))),
|
||||
)
|
||||
|
||||
// Configure Swagger UI to use /openapi
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
You can test the OpenAPI endpoint:
|
||||
|
||||
```bash
|
||||
# Get the full spec
|
||||
curl http://localhost:8080/openapi | jq
|
||||
|
||||
# Validate with openapi-generator
|
||||
openapi-generator validate -i http://localhost:8080/openapi
|
||||
|
||||
# Generate client SDKs
|
||||
openapi-generator generate -i http://localhost:8080/openapi -g typescript-fetch -o ./client
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
See `example.go` in this package for complete, runnable examples including:
|
||||
- Basic RestheadSpec setup
|
||||
- Basic ResolveSpec setup
|
||||
- Combining both frameworks
|
||||
- Adding FuncSpec endpoints
|
||||
- Advanced customization
|
||||
|
||||
## License
|
||||
|
||||
Part of the ResolveSpec project.
|
||||
236
pkg/openapi/example.go
Normal file
236
pkg/openapi/example.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
)
|
||||
|
||||
// ExampleRestheadSpec shows how to configure OpenAPI generation for RestheadSpec
|
||||
func ExampleRestheadSpec(db *gorm.DB) {
|
||||
// 1. Create registry and register models
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", User{})
|
||||
// registry.RegisterModel("public.products", Product{})
|
||||
|
||||
// 2. Create handler with custom registry
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
// gormAdapter := database.NewGormAdapter(db)
|
||||
// handler := restheadspec.NewHandler(gormAdapter, registry)
|
||||
// Or use the convenience function (creates its own registry):
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// 3. Configure OpenAPI generator
|
||||
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||
generator := NewGenerator(GeneratorConfig{
|
||||
Title: "My API",
|
||||
Description: "API documentation for my application",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "http://localhost:8080",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
IncludeResolveSpec: false,
|
||||
IncludeFuncSpec: false,
|
||||
})
|
||||
return generator.GenerateJSON()
|
||||
})
|
||||
|
||||
// 4. Setup routes (includes /openapi endpoint)
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Now the following endpoints are available:
|
||||
// GET /openapi - Full OpenAPI spec
|
||||
// GET /public/users?openapi - OpenAPI spec
|
||||
// GET /public/products?openapi - OpenAPI spec
|
||||
// etc.
|
||||
}
|
||||
|
||||
// ExampleResolveSpec shows how to configure OpenAPI generation for ResolveSpec
|
||||
func ExampleResolveSpec(db *gorm.DB) {
|
||||
// 1. Create registry and register models
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", User{})
|
||||
// registry.RegisterModel("public.products", Product{})
|
||||
|
||||
// 2. Create handler with custom registry
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
// gormAdapter := database.NewGormAdapter(db)
|
||||
// handler := resolvespec.NewHandler(gormAdapter, registry)
|
||||
// Or use the convenience function (creates its own registry):
|
||||
handler := resolvespec.NewHandlerWithGORM(db)
|
||||
// Note: handler.RegisterModel("schema", "entity", model) can be used
|
||||
|
||||
// 3. Configure OpenAPI generator
|
||||
handler.SetOpenAPIGenerator(func() (string, error) {
|
||||
generator := NewGenerator(GeneratorConfig{
|
||||
Title: "My API",
|
||||
Description: "API documentation for my application",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "http://localhost:8080",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: false,
|
||||
IncludeResolveSpec: true,
|
||||
IncludeFuncSpec: false,
|
||||
})
|
||||
return generator.GenerateJSON()
|
||||
})
|
||||
|
||||
// 4. Setup routes (includes /openapi endpoint)
|
||||
router := mux.NewRouter()
|
||||
resolvespec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
// Now the following endpoints are available:
|
||||
// GET /openapi - Full OpenAPI spec
|
||||
// POST /resolve/public/users?openapi - OpenAPI spec
|
||||
// POST /resolve/public/products?openapi - OpenAPI spec
|
||||
// etc.
|
||||
}
|
||||
|
||||
// ExampleBothSpecs shows how to combine both RestheadSpec and ResolveSpec
|
||||
func ExampleBothSpecs(db *gorm.DB) {
|
||||
// Create shared registry
|
||||
sharedRegistry := modelregistry.NewModelRegistry()
|
||||
// Register models once
|
||||
// sharedRegistry.RegisterModel("public.users", User{})
|
||||
// sharedRegistry.RegisterModel("public.products", Product{})
|
||||
|
||||
// Create handlers - they will have separate registries initially
|
||||
restheadHandler := restheadspec.NewHandlerWithGORM(db)
|
||||
resolveHandler := resolvespec.NewHandlerWithGORM(db)
|
||||
|
||||
// Note: If you want to use a shared registry, create handlers manually:
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
// gormAdapter := database.NewGormAdapter(db)
|
||||
// restheadHandler := restheadspec.NewHandler(gormAdapter, sharedRegistry)
|
||||
// resolveHandler := resolvespec.NewHandler(gormAdapter, sharedRegistry)
|
||||
|
||||
// Configure OpenAPI generator for both
|
||||
generatorFunc := func() (string, error) {
|
||||
generator := NewGenerator(GeneratorConfig{
|
||||
Title: "My Unified API",
|
||||
Description: "Complete API documentation with both RestheadSpec and ResolveSpec endpoints",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "http://localhost:8080",
|
||||
Registry: sharedRegistry,
|
||||
IncludeRestheadSpec: true,
|
||||
IncludeResolveSpec: true,
|
||||
IncludeFuncSpec: false,
|
||||
})
|
||||
return generator.GenerateJSON()
|
||||
}
|
||||
|
||||
restheadHandler.SetOpenAPIGenerator(generatorFunc)
|
||||
resolveHandler.SetOpenAPIGenerator(generatorFunc)
|
||||
|
||||
// Setup routes
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, restheadHandler, nil)
|
||||
|
||||
// Add ResolveSpec routes under /resolve prefix
|
||||
resolveRouter := router.PathPrefix("/resolve").Subrouter()
|
||||
resolvespec.SetupMuxRoutes(resolveRouter, resolveHandler, nil)
|
||||
|
||||
// Now you have both styles of API available:
|
||||
// GET /openapi - Full OpenAPI spec (both styles)
|
||||
// GET /public/users - RestheadSpec list endpoint
|
||||
// POST /resolve/public/users - ResolveSpec operation endpoint
|
||||
// GET /public/users?openapi - OpenAPI spec
|
||||
// POST /resolve/public/users?openapi - OpenAPI spec
|
||||
}
|
||||
|
||||
// ExampleWithFuncSpec shows how to add FuncSpec endpoints to OpenAPI
|
||||
func ExampleWithFuncSpec() {
|
||||
// FuncSpec endpoints need to be registered manually since they don't use model registry
|
||||
generatorFunc := func() (string, error) {
|
||||
funcSpecEndpoints := map[string]FuncSpecEndpoint{
|
||||
"/api/reports/sales": {
|
||||
Path: "/api/reports/sales",
|
||||
Method: "GET",
|
||||
Summary: "Get sales report",
|
||||
Description: "Returns sales data for the specified date range",
|
||||
SQLQuery: "SELECT * FROM sales WHERE date BETWEEN [start_date] AND [end_date]",
|
||||
Parameters: []string{"start_date", "end_date"},
|
||||
},
|
||||
"/api/analytics/users": {
|
||||
Path: "/api/analytics/users",
|
||||
Method: "GET",
|
||||
Summary: "Get user analytics",
|
||||
Description: "Returns user activity analytics",
|
||||
SQLQuery: "SELECT * FROM user_analytics WHERE user_id = [user_id]",
|
||||
Parameters: []string{"user_id"},
|
||||
},
|
||||
}
|
||||
|
||||
generator := NewGenerator(GeneratorConfig{
|
||||
Title: "My API with Custom Queries",
|
||||
Description: "API with FuncSpec custom SQL endpoints",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "http://localhost:8080",
|
||||
Registry: modelregistry.NewModelRegistry(),
|
||||
IncludeRestheadSpec: false,
|
||||
IncludeResolveSpec: false,
|
||||
IncludeFuncSpec: true,
|
||||
FuncSpecEndpoints: funcSpecEndpoints,
|
||||
})
|
||||
return generator.GenerateJSON()
|
||||
}
|
||||
|
||||
// Use this generator function with your handlers
|
||||
_ = generatorFunc
|
||||
}
|
||||
|
||||
// ExampleCustomization shows advanced customization options
|
||||
func ExampleCustomization() {
|
||||
// Create registry and register models with descriptions using struct tags
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
|
||||
// type User struct {
|
||||
// ID int `json:"id" gorm:"primaryKey" description:"Unique user identifier"`
|
||||
// Name string `json:"name" description:"User's full name"`
|
||||
// Email string `json:"email" gorm:"unique" description:"User's email address"`
|
||||
// }
|
||||
// registry.RegisterModel("public.users", User{})
|
||||
|
||||
// Advanced configuration - create generator function
|
||||
generatorFunc := func() (string, error) {
|
||||
generator := NewGenerator(GeneratorConfig{
|
||||
Title: "My Advanced API",
|
||||
Description: "Comprehensive API documentation with custom configuration",
|
||||
Version: "2.1.0",
|
||||
BaseURL: "https://api.myapp.com",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
IncludeResolveSpec: true,
|
||||
IncludeFuncSpec: false,
|
||||
})
|
||||
|
||||
// Generate the spec
|
||||
// spec, err := generator.Generate()
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
|
||||
// Customize the spec further if needed
|
||||
// spec.Info.Contact = &Contact{
|
||||
// Name: "API Support",
|
||||
// Email: "support@myapp.com",
|
||||
// URL: "https://myapp.com/support",
|
||||
// }
|
||||
|
||||
// Add additional servers
|
||||
// spec.Servers = append(spec.Servers, Server{
|
||||
// URL: "https://staging-api.myapp.com",
|
||||
// Description: "Staging Server",
|
||||
// })
|
||||
|
||||
// Convert back to JSON - or use GenerateJSON() for simple cases
|
||||
return generator.GenerateJSON()
|
||||
}
|
||||
|
||||
// Use this generator function with your handlers
|
||||
_ = generatorFunc
|
||||
}
|
||||
513
pkg/openapi/generator.go
Normal file
513
pkg/openapi/generator.go
Normal file
@@ -0,0 +1,513 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// OpenAPISpec represents the OpenAPI 3.0 specification structure
|
||||
type OpenAPISpec struct {
|
||||
OpenAPI string `json:"openapi"`
|
||||
Info Info `json:"info"`
|
||||
Servers []Server `json:"servers,omitempty"`
|
||||
Paths map[string]PathItem `json:"paths"`
|
||||
Components Components `json:"components"`
|
||||
Security []map[string][]string `json:"security,omitempty"`
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Version string `json:"version"`
|
||||
Contact *Contact `json:"contact,omitempty"`
|
||||
}
|
||||
|
||||
type Contact struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type PathItem struct {
|
||||
Get *Operation `json:"get,omitempty"`
|
||||
Post *Operation `json:"post,omitempty"`
|
||||
Put *Operation `json:"put,omitempty"`
|
||||
Patch *Operation `json:"patch,omitempty"`
|
||||
Delete *Operation `json:"delete,omitempty"`
|
||||
Options *Operation `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type Operation struct {
|
||||
Summary string `json:"summary,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
OperationID string `json:"operationId,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Parameters []Parameter `json:"parameters,omitempty"`
|
||||
RequestBody *RequestBody `json:"requestBody,omitempty"`
|
||||
Responses map[string]Response `json:"responses"`
|
||||
Security []map[string][]string `json:"security,omitempty"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
Name string `json:"name"`
|
||||
In string `json:"in"` // "query", "header", "path", "cookie"
|
||||
Description string `json:"description,omitempty"`
|
||||
Required bool `json:"required,omitempty"`
|
||||
Schema *Schema `json:"schema,omitempty"`
|
||||
Example interface{} `json:"example,omitempty"`
|
||||
}
|
||||
|
||||
type RequestBody struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Required bool `json:"required,omitempty"`
|
||||
Content map[string]MediaType `json:"content"`
|
||||
}
|
||||
|
||||
type MediaType struct {
|
||||
Schema *Schema `json:"schema,omitempty"`
|
||||
Example interface{} `json:"example,omitempty"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Description string `json:"description"`
|
||||
Content map[string]MediaType `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type Components struct {
|
||||
Schemas map[string]Schema `json:"schemas,omitempty"`
|
||||
SecuritySchemes map[string]SecurityScheme `json:"securitySchemes,omitempty"`
|
||||
}
|
||||
|
||||
type Schema struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Format string `json:"format,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Properties map[string]*Schema `json:"properties,omitempty"`
|
||||
Items *Schema `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Ref string `json:"$ref,omitempty"`
|
||||
Enum []interface{} `json:"enum,omitempty"`
|
||||
Example interface{} `json:"example,omitempty"`
|
||||
AdditionalProperties interface{} `json:"additionalProperties,omitempty"`
|
||||
OneOf []*Schema `json:"oneOf,omitempty"`
|
||||
AnyOf []*Schema `json:"anyOf,omitempty"`
|
||||
}
|
||||
|
||||
type SecurityScheme struct {
|
||||
Type string `json:"type"` // "apiKey", "http", "oauth2", "openIdConnect"
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name,omitempty"` // For apiKey
|
||||
In string `json:"in,omitempty"` // For apiKey: "query", "header", "cookie"
|
||||
Scheme string `json:"scheme,omitempty"` // For http: "basic", "bearer"
|
||||
BearerFormat string `json:"bearerFormat,omitempty"` // For http bearer
|
||||
}
|
||||
|
||||
// GeneratorConfig holds configuration for OpenAPI spec generation
|
||||
type GeneratorConfig struct {
|
||||
Title string
|
||||
Description string
|
||||
Version string
|
||||
BaseURL string
|
||||
Registry *modelregistry.DefaultModelRegistry
|
||||
IncludeRestheadSpec bool
|
||||
IncludeResolveSpec bool
|
||||
IncludeFuncSpec bool
|
||||
FuncSpecEndpoints map[string]FuncSpecEndpoint // path -> endpoint info
|
||||
}
|
||||
|
||||
// FuncSpecEndpoint represents a FuncSpec endpoint for OpenAPI generation
|
||||
type FuncSpecEndpoint struct {
|
||||
Path string
|
||||
Method string
|
||||
Summary string
|
||||
Description string
|
||||
SQLQuery string
|
||||
Parameters []string // Parameter names extracted from SQL
|
||||
}
|
||||
|
||||
// Generator creates OpenAPI specifications
|
||||
type Generator struct {
|
||||
config GeneratorConfig
|
||||
}
|
||||
|
||||
// NewGenerator creates a new OpenAPI generator
|
||||
func NewGenerator(config GeneratorConfig) *Generator {
|
||||
if config.Title == "" {
|
||||
config.Title = "ResolveSpec API"
|
||||
}
|
||||
if config.Version == "" {
|
||||
config.Version = "1.0.0"
|
||||
}
|
||||
return &Generator{config: config}
|
||||
}
|
||||
|
||||
// Generate creates the complete OpenAPI specification
|
||||
func (g *Generator) Generate() (*OpenAPISpec, error) {
|
||||
spec := &OpenAPISpec{
|
||||
OpenAPI: "3.0.0",
|
||||
Info: Info{
|
||||
Title: g.config.Title,
|
||||
Description: g.config.Description,
|
||||
Version: g.config.Version,
|
||||
},
|
||||
Paths: make(map[string]PathItem),
|
||||
Components: Components{
|
||||
Schemas: make(map[string]Schema),
|
||||
SecuritySchemes: g.generateSecuritySchemes(),
|
||||
},
|
||||
}
|
||||
|
||||
if g.config.BaseURL != "" {
|
||||
spec.Servers = []Server{
|
||||
{URL: g.config.BaseURL, Description: "API Server"},
|
||||
}
|
||||
}
|
||||
|
||||
// Add common schemas
|
||||
g.addCommonSchemas(spec)
|
||||
|
||||
// Generate paths and schemas from registered models
|
||||
if err := g.generateFromModels(spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
// GenerateJSON generates OpenAPI spec as JSON string
|
||||
func (g *Generator) GenerateJSON() (string, error) {
|
||||
spec, err := g.Generate()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(spec, "", " ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal spec: %w", err)
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// generateSecuritySchemes creates security scheme definitions
|
||||
func (g *Generator) generateSecuritySchemes() map[string]SecurityScheme {
|
||||
return map[string]SecurityScheme{
|
||||
"BearerAuth": {
|
||||
Type: "http",
|
||||
Scheme: "bearer",
|
||||
BearerFormat: "JWT",
|
||||
Description: "JWT Bearer token authentication",
|
||||
},
|
||||
"SessionToken": {
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "Authorization",
|
||||
Description: "Session token authentication",
|
||||
},
|
||||
"CookieAuth": {
|
||||
Type: "apiKey",
|
||||
In: "cookie",
|
||||
Name: "session_token",
|
||||
Description: "Cookie-based session authentication",
|
||||
},
|
||||
"HeaderAuth": {
|
||||
Type: "apiKey",
|
||||
In: "header",
|
||||
Name: "X-User-ID",
|
||||
Description: "Header-based user authentication",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// addCommonSchemas adds common reusable schemas
|
||||
func (g *Generator) addCommonSchemas(spec *OpenAPISpec) {
|
||||
// Response wrapper schema
|
||||
spec.Components.Schemas["Response"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean", Description: "Indicates if the operation was successful"},
|
||||
"data": {Description: "The response data"},
|
||||
"metadata": {Ref: "#/components/schemas/Metadata"},
|
||||
"error": {Ref: "#/components/schemas/APIError"},
|
||||
},
|
||||
}
|
||||
|
||||
// Metadata schema
|
||||
spec.Components.Schemas["Metadata"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"total": {Type: "integer", Description: "Total number of records"},
|
||||
"count": {Type: "integer", Description: "Number of records in this response"},
|
||||
"filtered": {Type: "integer", Description: "Number of records after filtering"},
|
||||
"limit": {Type: "integer", Description: "Limit applied"},
|
||||
"offset": {Type: "integer", Description: "Offset applied"},
|
||||
"rowNumber": {Type: "integer", Description: "Row number for cursor pagination"},
|
||||
},
|
||||
}
|
||||
|
||||
// APIError schema
|
||||
spec.Components.Schemas["APIError"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"code": {Type: "string", Description: "Error code"},
|
||||
"message": {Type: "string", Description: "Error message"},
|
||||
"details": {Type: "string", Description: "Detailed error information"},
|
||||
},
|
||||
}
|
||||
|
||||
// RequestOptions schema
|
||||
spec.Components.Schemas["RequestOptions"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"preload": {
|
||||
Type: "array",
|
||||
Description: "Relations to eager load",
|
||||
Items: &Schema{Ref: "#/components/schemas/PreloadOption"},
|
||||
},
|
||||
"columns": {
|
||||
Type: "array",
|
||||
Description: "Columns to select",
|
||||
Items: &Schema{Type: "string"},
|
||||
},
|
||||
"omitColumns": {
|
||||
Type: "array",
|
||||
Description: "Columns to exclude",
|
||||
Items: &Schema{Type: "string"},
|
||||
},
|
||||
"filters": {
|
||||
Type: "array",
|
||||
Description: "Filter conditions",
|
||||
Items: &Schema{Ref: "#/components/schemas/FilterOption"},
|
||||
},
|
||||
"sort": {
|
||||
Type: "array",
|
||||
Description: "Sort specifications",
|
||||
Items: &Schema{Ref: "#/components/schemas/SortOption"},
|
||||
},
|
||||
"limit": {Type: "integer", Description: "Maximum number of records"},
|
||||
"offset": {Type: "integer", Description: "Number of records to skip"},
|
||||
},
|
||||
}
|
||||
|
||||
// FilterOption schema
|
||||
spec.Components.Schemas["FilterOption"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"column": {Type: "string", Description: "Column name"},
|
||||
"operator": {Type: "string", Description: "Comparison operator", Enum: []interface{}{"eq", "neq", "gt", "lt", "gte", "lte", "like", "ilike", "in", "not_in", "between", "is_null", "is_not_null"}},
|
||||
"value": {Description: "Filter value"},
|
||||
"logicOperator": {Type: "string", Description: "Logic operator", Enum: []interface{}{"AND", "OR"}},
|
||||
},
|
||||
}
|
||||
|
||||
// SortOption schema
|
||||
spec.Components.Schemas["SortOption"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"column": {Type: "string", Description: "Column name"},
|
||||
"direction": {Type: "string", Description: "Sort direction", Enum: []interface{}{"asc", "desc"}},
|
||||
},
|
||||
}
|
||||
|
||||
// PreloadOption schema
|
||||
spec.Components.Schemas["PreloadOption"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"relation": {Type: "string", Description: "Relation name"},
|
||||
"columns": {
|
||||
Type: "array",
|
||||
Description: "Columns to select from related table",
|
||||
Items: &Schema{Type: "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// ResolveSpec RequestBody schema
|
||||
spec.Components.Schemas["ResolveSpecRequest"] = Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"operation": {Type: "string", Description: "Operation type", Enum: []interface{}{"read", "create", "update", "delete", "meta"}},
|
||||
"data": {Description: "Payload data (object or array)"},
|
||||
"id": {Type: "integer", Description: "Record ID for single operations"},
|
||||
"options": {Ref: "#/components/schemas/RequestOptions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateFromModels generates paths and schemas from registered models
|
||||
func (g *Generator) generateFromModels(spec *OpenAPISpec) error {
|
||||
if g.config.Registry == nil {
|
||||
return fmt.Errorf("model registry is required")
|
||||
}
|
||||
|
||||
models := g.config.Registry.GetAllModels()
|
||||
|
||||
for name, model := range models {
|
||||
// Parse schema.entity from model name
|
||||
schema, entity := parseModelName(name)
|
||||
|
||||
// Generate schema for this model
|
||||
modelSchema := g.generateModelSchema(model)
|
||||
schemaName := formatSchemaName(schema, entity)
|
||||
spec.Components.Schemas[schemaName] = modelSchema
|
||||
|
||||
// Generate paths for different frameworks
|
||||
if g.config.IncludeRestheadSpec {
|
||||
g.generateRestheadSpecPaths(spec, schema, entity, schemaName)
|
||||
}
|
||||
|
||||
if g.config.IncludeResolveSpec {
|
||||
g.generateResolveSpecPaths(spec, schema, entity, schemaName)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate FuncSpec paths if configured
|
||||
if g.config.IncludeFuncSpec && len(g.config.FuncSpecEndpoints) > 0 {
|
||||
g.generateFuncSpecPaths(spec)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateModelSchema creates an OpenAPI schema from a Go struct
|
||||
func (g *Generator) generateModelSchema(model interface{}) Schema {
|
||||
schema := Schema{
|
||||
Type: "object",
|
||||
Properties: make(map[string]*Schema),
|
||||
Required: []string{},
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
return schema
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get JSON tag name
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := strings.Split(jsonTag, ",")[0]
|
||||
if fieldName == "" {
|
||||
fieldName = field.Name
|
||||
}
|
||||
|
||||
// Generate property schema
|
||||
propSchema := g.generatePropertySchema(field)
|
||||
schema.Properties[fieldName] = propSchema
|
||||
|
||||
// Check if field is required (not a pointer and no omitempty)
|
||||
if field.Type.Kind() != reflect.Ptr && !strings.Contains(jsonTag, "omitempty") {
|
||||
schema.Required = append(schema.Required, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
// generatePropertySchema creates a schema for a struct field
|
||||
func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
|
||||
schema := &Schema{}
|
||||
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Get description from tag
|
||||
if desc := field.Tag.Get("description"); desc != "" {
|
||||
schema.Description = desc
|
||||
}
|
||||
|
||||
switch fieldType.Kind() {
|
||||
case reflect.String:
|
||||
schema.Type = "string"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
schema.Type = "integer"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
schema.Type = "number"
|
||||
case reflect.Bool:
|
||||
schema.Type = "boolean"
|
||||
case reflect.Slice, reflect.Array:
|
||||
schema.Type = "array"
|
||||
elemType := fieldType.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
// Complex type - would need recursive handling
|
||||
schema.Items = &Schema{Type: "object"}
|
||||
} else {
|
||||
schema.Items = g.generatePropertySchema(reflect.StructField{Type: elemType})
|
||||
}
|
||||
case reflect.Struct:
|
||||
// Check for time.Time
|
||||
if fieldType.String() == "time.Time" {
|
||||
schema.Type = "string"
|
||||
schema.Format = "date-time"
|
||||
} else {
|
||||
schema.Type = "object"
|
||||
}
|
||||
default:
|
||||
schema.Type = "string"
|
||||
}
|
||||
|
||||
// Check for custom format from gorm/bun tags
|
||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" {
|
||||
if strings.Contains(gormTag, "type:uuid") {
|
||||
schema.Format = "uuid"
|
||||
}
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
// parseModelName splits "schema.entity" or returns "public" and entity
|
||||
func parseModelName(name string) (schema, entity string) {
|
||||
parts := strings.Split(name, ".")
|
||||
if len(parts) == 2 {
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
return "public", name
|
||||
}
|
||||
|
||||
// formatSchemaName creates a component schema name
|
||||
func formatSchemaName(schema, entity string) string {
|
||||
if schema == "public" {
|
||||
return toTitleCase(entity)
|
||||
}
|
||||
return toTitleCase(schema) + toTitleCase(entity)
|
||||
}
|
||||
|
||||
// toTitleCase converts a string to title case (first letter uppercase)
|
||||
func toTitleCase(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
if len(s) == 1 {
|
||||
return strings.ToUpper(s)
|
||||
}
|
||||
return strings.ToUpper(s[:1]) + s[1:]
|
||||
}
|
||||
714
pkg/openapi/generator_test.go
Normal file
714
pkg/openapi/generator_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// Test models
|
||||
type TestUser struct {
|
||||
ID int `json:"id" gorm:"primaryKey" description:"User ID"`
|
||||
Name string `json:"name" gorm:"not null" description:"User's full name"`
|
||||
Email string `json:"email" gorm:"unique" description:"Email address"`
|
||||
Age int `json:"age" description:"User age"`
|
||||
IsActive bool `json:"is_active" description:"Active status"`
|
||||
CreatedAt time.Time `json:"created_at" description:"Creation timestamp"`
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty" description:"Last update timestamp"`
|
||||
Roles []string `json:"roles,omitempty" description:"User roles"`
|
||||
}
|
||||
|
||||
type TestProduct struct {
|
||||
ID int `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"not null"`
|
||||
Description string `json:"description"`
|
||||
Price float64 `json:"price"`
|
||||
InStock bool `json:"in_stock"`
|
||||
}
|
||||
|
||||
type TestOrder struct {
|
||||
ID int `json:"id" gorm:"primaryKey"`
|
||||
UserID int `json:"user_id" gorm:"not null"`
|
||||
ProductID int `json:"product_id" gorm:"not null"`
|
||||
Quantity int `json:"quantity"`
|
||||
TotalPrice float64 `json:"total_price"`
|
||||
}
|
||||
|
||||
func TestNewGenerator(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config GeneratorConfig
|
||||
want string // expected title
|
||||
}{
|
||||
{
|
||||
name: "with all fields",
|
||||
config: GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Description: "Test Description",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "http://localhost:8080",
|
||||
Registry: registry,
|
||||
},
|
||||
want: "Test API",
|
||||
},
|
||||
{
|
||||
name: "with defaults",
|
||||
config: GeneratorConfig{
|
||||
Registry: registry,
|
||||
},
|
||||
want: "ResolveSpec API",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gen := NewGenerator(tt.config)
|
||||
if gen == nil {
|
||||
t.Fatal("NewGenerator returned nil")
|
||||
}
|
||||
if gen.config.Title != tt.want {
|
||||
t.Errorf("Title = %v, want %v", gen.config.Title, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateBasicSpec(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
err := registry.RegisterModel("public.users", TestUser{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register model: %v", err)
|
||||
}
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test basic spec structure
|
||||
if spec.OpenAPI != "3.0.0" {
|
||||
t.Errorf("OpenAPI version = %v, want 3.0.0", spec.OpenAPI)
|
||||
}
|
||||
if spec.Info.Title != "Test API" {
|
||||
t.Errorf("Title = %v, want Test API", spec.Info.Title)
|
||||
}
|
||||
if spec.Info.Version != "1.0.0" {
|
||||
t.Errorf("Version = %v, want 1.0.0", spec.Info.Version)
|
||||
}
|
||||
|
||||
// Test that common schemas are added
|
||||
if spec.Components.Schemas["Response"].Type != "object" {
|
||||
t.Error("Response schema not found or invalid")
|
||||
}
|
||||
if spec.Components.Schemas["Metadata"].Type != "object" {
|
||||
t.Error("Metadata schema not found or invalid")
|
||||
}
|
||||
|
||||
// Test that model schema is added
|
||||
if _, exists := spec.Components.Schemas["Users"]; !exists {
|
||||
t.Error("Users schema not found")
|
||||
}
|
||||
|
||||
// Test that security schemes are added
|
||||
if len(spec.Components.SecuritySchemes) == 0 {
|
||||
t.Error("Security schemes not added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateModelSchema(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
gen := NewGenerator(GeneratorConfig{Registry: registry})
|
||||
|
||||
schema := gen.generateModelSchema(TestUser{})
|
||||
|
||||
// Test basic properties
|
||||
if schema.Type != "object" {
|
||||
t.Errorf("Schema type = %v, want object", schema.Type)
|
||||
}
|
||||
|
||||
// Test that properties are generated
|
||||
expectedProps := []string{"id", "name", "email", "age", "is_active", "created_at", "updated_at", "roles"}
|
||||
for _, prop := range expectedProps {
|
||||
if _, exists := schema.Properties[prop]; !exists {
|
||||
t.Errorf("Property %s not found in schema", prop)
|
||||
}
|
||||
}
|
||||
|
||||
// Test property types
|
||||
if schema.Properties["id"].Type != "integer" {
|
||||
t.Errorf("id type = %v, want integer", schema.Properties["id"].Type)
|
||||
}
|
||||
if schema.Properties["name"].Type != "string" {
|
||||
t.Errorf("name type = %v, want string", schema.Properties["name"].Type)
|
||||
}
|
||||
if schema.Properties["is_active"].Type != "boolean" {
|
||||
t.Errorf("is_active type = %v, want boolean", schema.Properties["is_active"].Type)
|
||||
}
|
||||
|
||||
// Test array type
|
||||
if schema.Properties["roles"].Type != "array" {
|
||||
t.Errorf("roles type = %v, want array", schema.Properties["roles"].Type)
|
||||
}
|
||||
if schema.Properties["roles"].Items.Type != "string" {
|
||||
t.Errorf("roles items type = %v, want string", schema.Properties["roles"].Items.Type)
|
||||
}
|
||||
|
||||
// Test time.Time format
|
||||
if schema.Properties["created_at"].Type != "string" {
|
||||
t.Errorf("created_at type = %v, want string", schema.Properties["created_at"].Type)
|
||||
}
|
||||
if schema.Properties["created_at"].Format != "date-time" {
|
||||
t.Errorf("created_at format = %v, want date-time", schema.Properties["created_at"].Format)
|
||||
}
|
||||
|
||||
// Test required fields (non-pointer, no omitempty)
|
||||
requiredFields := map[string]bool{}
|
||||
for _, field := range schema.Required {
|
||||
requiredFields[field] = true
|
||||
}
|
||||
if !requiredFields["id"] {
|
||||
t.Error("id should be required")
|
||||
}
|
||||
if !requiredFields["name"] {
|
||||
t.Error("name should be required")
|
||||
}
|
||||
if requiredFields["updated_at"] {
|
||||
t.Error("updated_at should not be required (pointer + omitempty)")
|
||||
}
|
||||
if requiredFields["roles"] {
|
||||
t.Error("roles should not be required (omitempty)")
|
||||
}
|
||||
|
||||
// Test descriptions
|
||||
if schema.Properties["id"].Description != "User ID" {
|
||||
t.Errorf("id description = %v, want 'User ID'", schema.Properties["id"].Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRestheadSpecPaths(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
err := registry.RegisterModel("public.users", TestUser{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register model: %v", err)
|
||||
}
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that paths are generated
|
||||
expectedPaths := []string{
|
||||
"/public/users",
|
||||
"/public/users/{id}",
|
||||
"/public/users/metadata",
|
||||
}
|
||||
|
||||
for _, path := range expectedPaths {
|
||||
if _, exists := spec.Paths[path]; !exists {
|
||||
t.Errorf("Path %s not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Test collection endpoint methods
|
||||
usersPath := spec.Paths["/public/users"]
|
||||
if usersPath.Get == nil {
|
||||
t.Error("GET method not found for /public/users")
|
||||
}
|
||||
if usersPath.Post == nil {
|
||||
t.Error("POST method not found for /public/users")
|
||||
}
|
||||
if usersPath.Options == nil {
|
||||
t.Error("OPTIONS method not found for /public/users")
|
||||
}
|
||||
|
||||
// Test single record endpoint methods
|
||||
userIDPath := spec.Paths["/public/users/{id}"]
|
||||
if userIDPath.Get == nil {
|
||||
t.Error("GET method not found for /public/users/{id}")
|
||||
}
|
||||
if userIDPath.Put == nil {
|
||||
t.Error("PUT method not found for /public/users/{id}")
|
||||
}
|
||||
if userIDPath.Patch == nil {
|
||||
t.Error("PATCH method not found for /public/users/{id}")
|
||||
}
|
||||
if userIDPath.Delete == nil {
|
||||
t.Error("DELETE method not found for /public/users/{id}")
|
||||
}
|
||||
|
||||
// Test metadata endpoint
|
||||
metadataPath := spec.Paths["/public/users/metadata"]
|
||||
if metadataPath.Get == nil {
|
||||
t.Error("GET method not found for /public/users/metadata")
|
||||
}
|
||||
|
||||
// Test operation details
|
||||
getOp := usersPath.Get
|
||||
if getOp.Summary == "" {
|
||||
t.Error("GET operation summary is empty")
|
||||
}
|
||||
if getOp.OperationID == "" {
|
||||
t.Error("GET operation ID is empty")
|
||||
}
|
||||
if len(getOp.Tags) == 0 {
|
||||
t.Error("GET operation has no tags")
|
||||
}
|
||||
if len(getOp.Parameters) == 0 {
|
||||
t.Error("GET operation has no parameters")
|
||||
}
|
||||
|
||||
// Test RestheadSpec headers
|
||||
hasFiltersHeader := false
|
||||
for _, param := range getOp.Parameters {
|
||||
if param.Name == "X-Filters" && param.In == "header" {
|
||||
hasFiltersHeader = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasFiltersHeader {
|
||||
t.Error("X-Filters header parameter not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateResolveSpecPaths(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
err := registry.RegisterModel("public.products", TestProduct{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register model: %v", err)
|
||||
}
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
Registry: registry,
|
||||
IncludeResolveSpec: true,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that paths are generated
|
||||
expectedPaths := []string{
|
||||
"/resolve/public/products",
|
||||
"/resolve/public/products/{id}",
|
||||
}
|
||||
|
||||
for _, path := range expectedPaths {
|
||||
if _, exists := spec.Paths[path]; !exists {
|
||||
t.Errorf("Path %s not found", path)
|
||||
}
|
||||
}
|
||||
|
||||
// Test collection endpoint methods
|
||||
productsPath := spec.Paths["/resolve/public/products"]
|
||||
if productsPath.Post == nil {
|
||||
t.Error("POST method not found for /resolve/public/products")
|
||||
}
|
||||
if productsPath.Get == nil {
|
||||
t.Error("GET method not found for /resolve/public/products")
|
||||
}
|
||||
if productsPath.Options == nil {
|
||||
t.Error("OPTIONS method not found for /resolve/public/products")
|
||||
}
|
||||
|
||||
// Test POST operation has request body
|
||||
postOp := productsPath.Post
|
||||
if postOp.RequestBody == nil {
|
||||
t.Error("POST operation has no request body")
|
||||
}
|
||||
if _, exists := postOp.RequestBody.Content["application/json"]; !exists {
|
||||
t.Error("POST operation request body has no application/json content")
|
||||
}
|
||||
|
||||
// Test request body schema references ResolveSpecRequest
|
||||
reqBodySchema := postOp.RequestBody.Content["application/json"].Schema
|
||||
if reqBodySchema.Ref != "#/components/schemas/ResolveSpecRequest" {
|
||||
t.Errorf("Request body schema ref = %v, want #/components/schemas/ResolveSpecRequest", reqBodySchema.Ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFuncSpecPaths(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
|
||||
funcSpecEndpoints := map[string]FuncSpecEndpoint{
|
||||
"/api/reports/sales": {
|
||||
Path: "/api/reports/sales",
|
||||
Method: "GET",
|
||||
Summary: "Get sales report",
|
||||
Description: "Returns sales data",
|
||||
Parameters: []string{"start_date", "end_date"},
|
||||
},
|
||||
"/api/analytics/users": {
|
||||
Path: "/api/analytics/users",
|
||||
Method: "POST",
|
||||
Summary: "Get user analytics",
|
||||
Description: "Returns user activity",
|
||||
Parameters: []string{"user_id"},
|
||||
},
|
||||
}
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
Registry: registry,
|
||||
IncludeFuncSpec: true,
|
||||
FuncSpecEndpoints: funcSpecEndpoints,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that FuncSpec paths are generated
|
||||
salesPath := spec.Paths["/api/reports/sales"]
|
||||
if salesPath.Get == nil {
|
||||
t.Error("GET method not found for /api/reports/sales")
|
||||
}
|
||||
if salesPath.Get.Summary != "Get sales report" {
|
||||
t.Errorf("GET summary = %v, want 'Get sales report'", salesPath.Get.Summary)
|
||||
}
|
||||
if len(salesPath.Get.Parameters) != 2 {
|
||||
t.Errorf("GET has %d parameters, want 2", len(salesPath.Get.Parameters))
|
||||
}
|
||||
|
||||
analyticsPath := spec.Paths["/api/analytics/users"]
|
||||
if analyticsPath.Post == nil {
|
||||
t.Error("POST method not found for /api/analytics/users")
|
||||
}
|
||||
if len(analyticsPath.Post.Parameters) != 1 {
|
||||
t.Errorf("POST has %d parameters, want 1", len(analyticsPath.Post.Parameters))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateJSON(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
err := registry.RegisterModel("public.users", TestUser{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register model: %v", err)
|
||||
}
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
jsonStr, err := gen.GenerateJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJSON failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that it's valid JSON
|
||||
var spec OpenAPISpec
|
||||
if err := json.Unmarshal([]byte(jsonStr), &spec); err != nil {
|
||||
t.Fatalf("Generated JSON is invalid: %v", err)
|
||||
}
|
||||
|
||||
// Test basic structure
|
||||
if spec.OpenAPI != "3.0.0" {
|
||||
t.Errorf("OpenAPI version = %v, want 3.0.0", spec.OpenAPI)
|
||||
}
|
||||
if spec.Info.Title != "Test API" {
|
||||
t.Errorf("Title = %v, want Test API", spec.Info.Title)
|
||||
}
|
||||
|
||||
// Test that JSON contains expected fields
|
||||
if !strings.Contains(jsonStr, `"openapi"`) {
|
||||
t.Error("JSON doesn't contain 'openapi' field")
|
||||
}
|
||||
if !strings.Contains(jsonStr, `"paths"`) {
|
||||
t.Error("JSON doesn't contain 'paths' field")
|
||||
}
|
||||
if !strings.Contains(jsonStr, `"components"`) {
|
||||
t.Error("JSON doesn't contain 'components' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleModels(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
registry.RegisterModel("public.users", TestUser{})
|
||||
registry.RegisterModel("public.products", TestProduct{})
|
||||
registry.RegisterModel("public.orders", TestOrder{})
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that all model schemas are generated
|
||||
expectedSchemas := []string{"Users", "Products", "Orders"}
|
||||
for _, schemaName := range expectedSchemas {
|
||||
if _, exists := spec.Components.Schemas[schemaName]; !exists {
|
||||
t.Errorf("Schema %s not found", schemaName)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that all paths are generated
|
||||
expectedPaths := []string{
|
||||
"/public/users",
|
||||
"/public/products",
|
||||
"/public/orders",
|
||||
}
|
||||
for _, path := range expectedPaths {
|
||||
if _, exists := spec.Paths[path]; !exists {
|
||||
t.Errorf("Path %s not found", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelNameParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fullName string
|
||||
wantSchema string
|
||||
wantEntity string
|
||||
}{
|
||||
{
|
||||
name: "with schema",
|
||||
fullName: "public.users",
|
||||
wantSchema: "public",
|
||||
wantEntity: "users",
|
||||
},
|
||||
{
|
||||
name: "without schema",
|
||||
fullName: "users",
|
||||
wantSchema: "public",
|
||||
wantEntity: "users",
|
||||
},
|
||||
{
|
||||
name: "custom schema",
|
||||
fullName: "custom.products",
|
||||
wantSchema: "custom",
|
||||
wantEntity: "products",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
schema, entity := parseModelName(tt.fullName)
|
||||
if schema != tt.wantSchema {
|
||||
t.Errorf("schema = %v, want %v", schema, tt.wantSchema)
|
||||
}
|
||||
if entity != tt.wantEntity {
|
||||
t.Errorf("entity = %v, want %v", entity, tt.wantEntity)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchemaNameFormatting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
entity string
|
||||
wantName string
|
||||
}{
|
||||
{
|
||||
name: "public schema",
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
wantName: "Users",
|
||||
},
|
||||
{
|
||||
name: "custom schema",
|
||||
schema: "custom",
|
||||
entity: "products",
|
||||
wantName: "CustomProducts",
|
||||
},
|
||||
{
|
||||
name: "multi-word entity",
|
||||
schema: "public",
|
||||
entity: "user_profiles",
|
||||
wantName: "User_profiles",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
name := formatSchemaName(tt.schema, tt.entity)
|
||||
if name != tt.wantName {
|
||||
t.Errorf("formatSchemaName() = %v, want %v", name, tt.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToTitleCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"users", "Users"},
|
||||
{"products", "Products"},
|
||||
{"userProfiles", "UserProfiles"},
|
||||
{"a", "A"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := toTitleCase(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("toTitleCase(%v) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateWithBaseURL(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
registry.RegisterModel("public.users", TestUser{})
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
BaseURL: "https://api.example.com",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that server is added
|
||||
if len(spec.Servers) == 0 {
|
||||
t.Fatal("No servers added")
|
||||
}
|
||||
if spec.Servers[0].URL != "https://api.example.com" {
|
||||
t.Errorf("Server URL = %v, want https://api.example.com", spec.Servers[0].URL)
|
||||
}
|
||||
if spec.Servers[0].Description != "API Server" {
|
||||
t.Errorf("Server description = %v, want 'API Server'", spec.Servers[0].Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCombinedFrameworks(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
registry.RegisterModel("public.users", TestUser{})
|
||||
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
Registry: registry,
|
||||
IncludeRestheadSpec: true,
|
||||
IncludeResolveSpec: true,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that both RestheadSpec and ResolveSpec paths are generated
|
||||
restheadPath := "/public/users"
|
||||
resolveSpecPath := "/resolve/public/users"
|
||||
|
||||
if _, exists := spec.Paths[restheadPath]; !exists {
|
||||
t.Errorf("RestheadSpec path %s not found", restheadPath)
|
||||
}
|
||||
if _, exists := spec.Paths[resolveSpecPath]; !exists {
|
||||
t.Errorf("ResolveSpec path %s not found", resolveSpecPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilRegistry(t *testing.T) {
|
||||
config := GeneratorConfig{
|
||||
Title: "Test API",
|
||||
Version: "1.0.0",
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
_, err := gen.Generate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil registry, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "registry") {
|
||||
t.Errorf("Error message should mention registry, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecuritySchemes(t *testing.T) {
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
config := GeneratorConfig{
|
||||
Registry: registry,
|
||||
}
|
||||
|
||||
gen := NewGenerator(config)
|
||||
spec, err := gen.Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("Generate failed: %v", err)
|
||||
}
|
||||
|
||||
// Test that all security schemes are present
|
||||
expectedSchemes := []string{"BearerAuth", "SessionToken", "CookieAuth", "HeaderAuth"}
|
||||
for _, scheme := range expectedSchemes {
|
||||
if _, exists := spec.Components.SecuritySchemes[scheme]; !exists {
|
||||
t.Errorf("Security scheme %s not found", scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// Test BearerAuth scheme details
|
||||
bearerAuth := spec.Components.SecuritySchemes["BearerAuth"]
|
||||
if bearerAuth.Type != "http" {
|
||||
t.Errorf("BearerAuth type = %v, want http", bearerAuth.Type)
|
||||
}
|
||||
if bearerAuth.Scheme != "bearer" {
|
||||
t.Errorf("BearerAuth scheme = %v, want bearer", bearerAuth.Scheme)
|
||||
}
|
||||
if bearerAuth.BearerFormat != "JWT" {
|
||||
t.Errorf("BearerAuth format = %v, want JWT", bearerAuth.BearerFormat)
|
||||
}
|
||||
|
||||
// Test HeaderAuth scheme details
|
||||
headerAuth := spec.Components.SecuritySchemes["HeaderAuth"]
|
||||
if headerAuth.Type != "apiKey" {
|
||||
t.Errorf("HeaderAuth type = %v, want apiKey", headerAuth.Type)
|
||||
}
|
||||
if headerAuth.In != "header" {
|
||||
t.Errorf("HeaderAuth in = %v, want header", headerAuth.In)
|
||||
}
|
||||
if headerAuth.Name != "X-User-ID" {
|
||||
t.Errorf("HeaderAuth name = %v, want X-User-ID", headerAuth.Name)
|
||||
}
|
||||
}
|
||||
499
pkg/openapi/paths.go
Normal file
499
pkg/openapi/paths.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// generateRestheadSpecPaths generates OpenAPI paths for RestheadSpec endpoints
|
||||
func (g *Generator) generateRestheadSpecPaths(spec *OpenAPISpec, schema, entity, schemaName string) {
|
||||
basePath := fmt.Sprintf("/%s/%s", schema, entity)
|
||||
idPath := fmt.Sprintf("/%s/%s/{id}", schema, entity)
|
||||
metaPath := fmt.Sprintf("/%s/%s/metadata", schema, entity)
|
||||
|
||||
// Collection endpoint: GET (list), POST (create)
|
||||
spec.Paths[basePath] = PathItem{
|
||||
Get: &Operation{
|
||||
Summary: fmt.Sprintf("List %s records", entity),
|
||||
Description: fmt.Sprintf("Retrieve a list of %s records with optional filtering, sorting, and pagination via headers", entity),
|
||||
OperationID: fmt.Sprintf("listRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
Parameters: g.getRestheadSpecHeaders(),
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Successful response",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {Type: "array", Items: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)}},
|
||||
"metadata": {Ref: "#/components/schemas/Metadata"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
Post: &Operation{
|
||||
Summary: fmt.Sprintf("Create %s record", entity),
|
||||
Description: fmt.Sprintf("Create a new %s record", entity),
|
||||
OperationID: fmt.Sprintf("createRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
RequestBody: &RequestBody{
|
||||
Required: true,
|
||||
Description: fmt.Sprintf("%s object to create", entity),
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
Responses: map[string]Response{
|
||||
"201": {
|
||||
Description: "Record created successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": g.errorResponse("Bad request"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
Options: &Operation{
|
||||
Summary: "CORS preflight",
|
||||
Description: "Handle CORS preflight requests",
|
||||
OperationID: fmt.Sprintf("optionsRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
Responses: map[string]Response{
|
||||
"204": {Description: "No content"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Single record endpoint: GET (read), PUT/PATCH (update), DELETE
|
||||
spec.Paths[idPath] = PathItem{
|
||||
Get: &Operation{
|
||||
Summary: fmt.Sprintf("Get %s record by ID", entity),
|
||||
Description: fmt.Sprintf("Retrieve a single %s record by its ID", entity),
|
||||
OperationID: fmt.Sprintf("getRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
Parameters: []Parameter{
|
||||
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
|
||||
},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Successful response",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"404": g.errorResponse("Record not found"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
Put: &Operation{
|
||||
Summary: fmt.Sprintf("Update %s record", entity),
|
||||
Description: fmt.Sprintf("Update an existing %s record by ID", entity),
|
||||
OperationID: fmt.Sprintf("updateRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
Parameters: []Parameter{
|
||||
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
|
||||
},
|
||||
RequestBody: &RequestBody{
|
||||
Required: true,
|
||||
Description: fmt.Sprintf("Updated %s object", entity),
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Record updated successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": g.errorResponse("Bad request"),
|
||||
"404": g.errorResponse("Record not found"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
Patch: &Operation{
|
||||
Summary: fmt.Sprintf("Partially update %s record", entity),
|
||||
Description: fmt.Sprintf("Partially update an existing %s record by ID", entity),
|
||||
OperationID: fmt.Sprintf("patchRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
Parameters: []Parameter{
|
||||
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
|
||||
},
|
||||
RequestBody: &RequestBody{
|
||||
Required: true,
|
||||
Description: fmt.Sprintf("Partial %s object", entity),
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Record updated successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": g.errorResponse("Bad request"),
|
||||
"404": g.errorResponse("Record not found"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
Delete: &Operation{
|
||||
Summary: fmt.Sprintf("Delete %s record", entity),
|
||||
Description: fmt.Sprintf("Delete a %s record by ID", entity),
|
||||
OperationID: fmt.Sprintf("deleteRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
Parameters: []Parameter{
|
||||
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
|
||||
},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Record deleted successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"404": g.errorResponse("Record not found"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
}
|
||||
|
||||
// Metadata endpoint
|
||||
spec.Paths[metaPath] = PathItem{
|
||||
Get: &Operation{
|
||||
Summary: fmt.Sprintf("Get %s metadata", entity),
|
||||
Description: fmt.Sprintf("Retrieve metadata information for %s table", entity),
|
||||
OperationID: fmt.Sprintf("metadataRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Metadata retrieved successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"schema": {Type: "string"},
|
||||
"table": {Type: "string"},
|
||||
"columns": {Type: "array", Items: &Schema{Type: "object"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateResolveSpecPaths generates OpenAPI paths for ResolveSpec endpoints
|
||||
func (g *Generator) generateResolveSpecPaths(spec *OpenAPISpec, schema, entity, schemaName string) {
|
||||
basePath := fmt.Sprintf("/resolve/%s/%s", schema, entity)
|
||||
idPath := fmt.Sprintf("/resolve/%s/%s/{id}", schema, entity)
|
||||
|
||||
// Collection endpoint: POST (operations)
|
||||
spec.Paths[basePath] = PathItem{
|
||||
Post: &Operation{
|
||||
Summary: fmt.Sprintf("Perform operation on %s", entity),
|
||||
Description: fmt.Sprintf("Execute read, create, or meta operations on %s records", entity),
|
||||
OperationID: fmt.Sprintf("operateResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
|
||||
RequestBody: &RequestBody{
|
||||
Required: true,
|
||||
Description: "Operation request with operation type and options",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: "#/components/schemas/ResolveSpecRequest"},
|
||||
Example: map[string]interface{}{
|
||||
"operation": "read",
|
||||
"options": map[string]interface{}{
|
||||
"limit": 10,
|
||||
"filters": []map[string]interface{}{
|
||||
{"column": "status", "operator": "eq", "value": "active"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Operation completed successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {Type: "array", Items: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)}},
|
||||
"metadata": {Ref: "#/components/schemas/Metadata"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": g.errorResponse("Bad request"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
Get: &Operation{
|
||||
Summary: fmt.Sprintf("Get %s metadata", entity),
|
||||
Description: fmt.Sprintf("Retrieve metadata for %s", entity),
|
||||
OperationID: fmt.Sprintf("metadataResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Metadata retrieved successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: "#/components/schemas/Response"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
Options: &Operation{
|
||||
Summary: "CORS preflight",
|
||||
Description: "Handle CORS preflight requests",
|
||||
OperationID: fmt.Sprintf("optionsResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
|
||||
Responses: map[string]Response{
|
||||
"204": {Description: "No content"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Single record endpoint: POST (update/delete)
|
||||
spec.Paths[idPath] = PathItem{
|
||||
Post: &Operation{
|
||||
Summary: fmt.Sprintf("Update or delete %s record", entity),
|
||||
Description: fmt.Sprintf("Execute update or delete operation on a specific %s record", entity),
|
||||
OperationID: fmt.Sprintf("modifyResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
|
||||
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
|
||||
Parameters: []Parameter{
|
||||
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
|
||||
},
|
||||
RequestBody: &RequestBody{
|
||||
Required: true,
|
||||
Description: "Operation request (update or delete)",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: "#/components/schemas/ResolveSpecRequest"},
|
||||
Example: map[string]interface{}{
|
||||
"operation": "update",
|
||||
"data": map[string]interface{}{
|
||||
"status": "inactive",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Operation completed successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{
|
||||
Type: "object",
|
||||
Properties: map[string]*Schema{
|
||||
"success": {Type: "boolean"},
|
||||
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": g.errorResponse("Bad request"),
|
||||
"404": g.errorResponse("Record not found"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateFuncSpecPaths generates OpenAPI paths for FuncSpec endpoints
|
||||
func (g *Generator) generateFuncSpecPaths(spec *OpenAPISpec) {
|
||||
for path, endpoint := range g.config.FuncSpecEndpoints {
|
||||
operation := &Operation{
|
||||
Summary: endpoint.Summary,
|
||||
Description: endpoint.Description,
|
||||
OperationID: fmt.Sprintf("funcSpec%s", sanitizeOperationID(path)),
|
||||
Tags: []string{"FuncSpec"},
|
||||
Parameters: g.extractFuncSpecParameters(endpoint.Parameters),
|
||||
Responses: map[string]Response{
|
||||
"200": {
|
||||
Description: "Query executed successfully",
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: "#/components/schemas/Response"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": g.errorResponse("Bad request"),
|
||||
"401": g.errorResponse("Unauthorized"),
|
||||
"500": g.errorResponse("Internal server error"),
|
||||
},
|
||||
Security: g.securityRequirements(),
|
||||
}
|
||||
|
||||
pathItem := spec.Paths[path]
|
||||
switch endpoint.Method {
|
||||
case "GET":
|
||||
pathItem.Get = operation
|
||||
case "POST":
|
||||
pathItem.Post = operation
|
||||
case "PUT":
|
||||
pathItem.Put = operation
|
||||
case "DELETE":
|
||||
pathItem.Delete = operation
|
||||
}
|
||||
spec.Paths[path] = pathItem
|
||||
}
|
||||
}
|
||||
|
||||
// getRestheadSpecHeaders returns all RestheadSpec header parameters
|
||||
func (g *Generator) getRestheadSpecHeaders() []Parameter {
|
||||
return []Parameter{
|
||||
{Name: "X-Filters", In: "header", Description: "JSON array of filter conditions", Schema: &Schema{Type: "string"}},
|
||||
{Name: "X-Columns", In: "header", Description: "Comma-separated list of columns to select", Schema: &Schema{Type: "string"}},
|
||||
{Name: "X-Sort", In: "header", Description: "JSON array of sort specifications", Schema: &Schema{Type: "string"}},
|
||||
{Name: "X-Limit", In: "header", Description: "Maximum number of records to return", Schema: &Schema{Type: "integer"}},
|
||||
{Name: "X-Offset", In: "header", Description: "Number of records to skip", Schema: &Schema{Type: "integer"}},
|
||||
{Name: "X-Preload", In: "header", Description: "Relations to eager load (comma-separated)", Schema: &Schema{Type: "string"}},
|
||||
{Name: "X-Expand", In: "header", Description: "Relations to expand with LEFT JOIN (comma-separated)", Schema: &Schema{Type: "string"}},
|
||||
{Name: "X-Distinct", In: "header", Description: "Enable DISTINCT query (true/false)", Schema: &Schema{Type: "boolean"}},
|
||||
{Name: "X-Response-Format", In: "header", Description: "Response format", Schema: &Schema{Type: "string", Enum: []interface{}{"detail", "simple", "syncfusion"}}},
|
||||
{Name: "X-Clean-JSON", In: "header", Description: "Remove null/empty fields from response (true/false)", Schema: &Schema{Type: "boolean"}},
|
||||
{Name: "X-Custom-SQL-Where", In: "header", Description: "Custom SQL WHERE clause (AND)", Schema: &Schema{Type: "string"}},
|
||||
{Name: "X-Custom-SQL-Or", In: "header", Description: "Custom SQL WHERE clause (OR)", Schema: &Schema{Type: "string"}},
|
||||
}
|
||||
}
|
||||
|
||||
// extractFuncSpecParameters creates OpenAPI parameters from parameter names
|
||||
func (g *Generator) extractFuncSpecParameters(paramNames []string) []Parameter {
|
||||
params := []Parameter{}
|
||||
for _, name := range paramNames {
|
||||
params = append(params, Parameter{
|
||||
Name: name,
|
||||
In: "query",
|
||||
Description: fmt.Sprintf("Parameter: %s", name),
|
||||
Schema: &Schema{Type: "string"},
|
||||
})
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// errorResponse creates a standard error response
|
||||
func (g *Generator) errorResponse(description string) Response {
|
||||
return Response{
|
||||
Description: description,
|
||||
Content: map[string]MediaType{
|
||||
"application/json": {
|
||||
Schema: &Schema{Ref: "#/components/schemas/APIError"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// securityRequirements returns all security options (user can use any)
|
||||
func (g *Generator) securityRequirements() []map[string][]string {
|
||||
return []map[string][]string{
|
||||
{"BearerAuth": {}},
|
||||
{"SessionToken": {}},
|
||||
{"CookieAuth": {}},
|
||||
{"HeaderAuth": {}},
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeOperationID removes invalid characters from operation IDs
|
||||
func sanitizeOperationID(path string) string {
|
||||
result := ""
|
||||
for _, char := range path {
|
||||
if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9') {
|
||||
result += string(char)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
331
pkg/reflection/generic_model_test.go
Normal file
331
pkg/reflection/generic_model_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test models for GetModelColumnDetail
|
||||
type TestModelForColumnDetail struct {
|
||||
ID int `gorm:"column:rid_test;primaryKey;type:bigserial;not null" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(255);not null" json:"name"`
|
||||
Email string `gorm:"column:email;type:varchar(255);unique;nullable" json:"email"`
|
||||
Description string `gorm:"column:description;type:text;null" json:"description"`
|
||||
ForeignKey int `gorm:"foreignKey:parent_id" json:"foreign_key"`
|
||||
}
|
||||
|
||||
type EmbeddedBase struct {
|
||||
ID int `gorm:"column:rid_base;primaryKey;identity" json:"id"`
|
||||
CreatedAt string `gorm:"column:created_at;type:timestamp" json:"created_at"`
|
||||
}
|
||||
|
||||
type ModelWithEmbeddedForDetail struct {
|
||||
EmbeddedBase
|
||||
Title string `gorm:"column:title;type:varchar(100);not null" json:"title"`
|
||||
Content string `gorm:"column:content;type:text" json:"content"`
|
||||
}
|
||||
|
||||
// Model with nil embedded pointer
|
||||
type ModelWithNilEmbedded struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
*EmbeddedBase
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
}
|
||||
|
||||
func TestGetModelColumnDetail(t *testing.T) {
|
||||
t.Run("simple struct", func(t *testing.T) {
|
||||
model := TestModelForColumnDetail{
|
||||
ID: 1,
|
||||
Name: "Test",
|
||||
Email: "test@example.com",
|
||||
Description: "Test description",
|
||||
ForeignKey: 100,
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
if len(details) != 5 {
|
||||
t.Errorf("Expected 5 fields, got %d", len(details))
|
||||
}
|
||||
|
||||
// Check ID field
|
||||
found := false
|
||||
for _, detail := range details {
|
||||
if detail.Name == "ID" {
|
||||
found = true
|
||||
if detail.SQLName != "rid_test" {
|
||||
t.Errorf("Expected SQLName 'rid_test', got '%s'", detail.SQLName)
|
||||
}
|
||||
// Note: primaryKey (without underscore) is not detected as primary_key
|
||||
// The function looks for "identity" or "primary_key" (with underscore)
|
||||
if detail.SQLDataType != "bigserial" {
|
||||
t.Errorf("Expected SQLDataType 'bigserial', got '%s'", detail.SQLDataType)
|
||||
}
|
||||
if detail.Nullable {
|
||||
t.Errorf("Expected Nullable false, got true")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("ID field not found in details")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("struct with embedded fields", func(t *testing.T) {
|
||||
model := ModelWithEmbeddedForDetail{
|
||||
EmbeddedBase: EmbeddedBase{
|
||||
ID: 1,
|
||||
CreatedAt: "2024-01-01",
|
||||
},
|
||||
Title: "Test Title",
|
||||
Content: "Test Content",
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
// Should have 4 fields: ID, CreatedAt from embedded, Title, Content from main
|
||||
if len(details) != 4 {
|
||||
t.Errorf("Expected 4 fields, got %d", len(details))
|
||||
}
|
||||
|
||||
// Check that embedded field is included
|
||||
foundID := false
|
||||
foundCreatedAt := false
|
||||
for _, detail := range details {
|
||||
if detail.Name == "ID" {
|
||||
foundID = true
|
||||
if detail.SQLKey != "primary_key" {
|
||||
t.Errorf("Expected SQLKey 'primary_key' for embedded ID, got '%s'", detail.SQLKey)
|
||||
}
|
||||
}
|
||||
if detail.Name == "CreatedAt" {
|
||||
foundCreatedAt = true
|
||||
}
|
||||
}
|
||||
if !foundID {
|
||||
t.Errorf("Embedded ID field not found")
|
||||
}
|
||||
if !foundCreatedAt {
|
||||
t.Errorf("Embedded CreatedAt field not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil embedded pointer is skipped", func(t *testing.T) {
|
||||
model := ModelWithNilEmbedded{
|
||||
ID: 1,
|
||||
Name: "Test",
|
||||
EmbeddedBase: nil, // nil embedded pointer
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
// Should have 2 fields: ID and Name (embedded is nil, so skipped)
|
||||
if len(details) != 2 {
|
||||
t.Errorf("Expected 2 fields (nil embedded skipped), got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pointer to struct", func(t *testing.T) {
|
||||
model := &TestModelForColumnDetail{
|
||||
ID: 1,
|
||||
Name: "Test",
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
if len(details) != 5 {
|
||||
t.Errorf("Expected 5 fields, got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid value", func(t *testing.T) {
|
||||
var invalid reflect.Value
|
||||
details := GetModelColumnDetail(invalid)
|
||||
|
||||
if len(details) != 0 {
|
||||
t.Errorf("Expected 0 fields for invalid value, got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-struct type", func(t *testing.T) {
|
||||
details := GetModelColumnDetail(reflect.ValueOf(123))
|
||||
|
||||
if len(details) != 0 {
|
||||
t.Errorf("Expected 0 fields for non-struct, got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nullable and not null detection", func(t *testing.T) {
|
||||
model := TestModelForColumnDetail{}
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
for _, detail := range details {
|
||||
switch detail.Name {
|
||||
case "ID":
|
||||
if detail.Nullable {
|
||||
t.Errorf("ID should not be nullable (has 'not null')")
|
||||
}
|
||||
case "Name":
|
||||
if detail.Nullable {
|
||||
t.Errorf("Name should not be nullable (has 'not null')")
|
||||
}
|
||||
case "Email":
|
||||
if !detail.Nullable {
|
||||
t.Errorf("Email should be nullable (has 'nullable')")
|
||||
}
|
||||
case "Description":
|
||||
if !detail.Nullable {
|
||||
t.Errorf("Description should be nullable (has 'null')")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unique and uniqueindex detection", func(t *testing.T) {
|
||||
type UniqueTestModel struct {
|
||||
ID int `gorm:"column:id;primary_key"`
|
||||
Username string `gorm:"column:username;unique"`
|
||||
Email string `gorm:"column:email;uniqueindex"`
|
||||
}
|
||||
|
||||
model := UniqueTestModel{}
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
for _, detail := range details {
|
||||
switch detail.Name {
|
||||
case "ID":
|
||||
if detail.SQLKey != "primary_key" {
|
||||
t.Errorf("ID should have SQLKey 'primary_key', got '%s'", detail.SQLKey)
|
||||
}
|
||||
case "Username":
|
||||
if detail.SQLKey != "unique" {
|
||||
t.Errorf("Username should have SQLKey 'unique', got '%s'", detail.SQLKey)
|
||||
}
|
||||
case "Email":
|
||||
// The function checks for "unique" first, so uniqueindex is also detected as "unique"
|
||||
// This is expected behavior based on the code logic
|
||||
if detail.SQLKey != "unique" {
|
||||
t.Errorf("Email should have SQLKey 'unique' (uniqueindex contains 'unique'), got '%s'", detail.SQLKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("foreign key detection", func(t *testing.T) {
|
||||
// Note: The foreignkey extraction in generic_model.go has a bug where
|
||||
// it requires ik > 0, so foreignkey at the start won't extract the value
|
||||
type FKTestModel struct {
|
||||
ParentID int `gorm:"column:parent_id;foreignkey:rid_parent;association_foreignkey:id_atevent"`
|
||||
}
|
||||
|
||||
model := FKTestModel{}
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
if len(details) == 0 {
|
||||
t.Fatal("Expected at least 1 field")
|
||||
}
|
||||
|
||||
detail := details[0]
|
||||
if detail.SQLKey != "foreign_key" {
|
||||
t.Errorf("Expected SQLKey 'foreign_key', got '%s'", detail.SQLKey)
|
||||
}
|
||||
// Due to the bug in the code (requires ik > 0), the SQLName will be extracted
|
||||
// when foreignkey is not at the beginning of the string
|
||||
if detail.SQLName != "rid_parent" {
|
||||
t.Errorf("Expected SQLName 'rid_parent', got '%s'", detail.SQLName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFnFindKeyVal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "find column",
|
||||
src: "column:user_id;primaryKey;type:bigint",
|
||||
key: "column:",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "find type",
|
||||
src: "column:name;type:varchar(255);not null",
|
||||
key: "type:",
|
||||
expected: "varchar(255)",
|
||||
},
|
||||
{
|
||||
name: "key not found",
|
||||
src: "primaryKey;autoIncrement",
|
||||
key: "column:",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "key at end without semicolon",
|
||||
src: "primaryKey;column:id",
|
||||
key: "column:",
|
||||
expected: "id",
|
||||
},
|
||||
{
|
||||
name: "case insensitive search",
|
||||
src: "Column:user_id;primaryKey",
|
||||
key: "column:",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "empty src",
|
||||
src: "",
|
||||
key: "column:",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "multiple occurrences (returns first)",
|
||||
src: "column:first;column:second",
|
||||
key: "column:",
|
||||
expected: "first",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := fnFindKeyVal(tt.src, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("fnFindKeyVal(%q, %q) = %q, want %q", tt.src, tt.key, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelColumnDetail_FieldValue(t *testing.T) {
|
||||
model := TestModelForColumnDetail{
|
||||
ID: 123,
|
||||
Name: "TestName",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
for _, detail := range details {
|
||||
if !detail.FieldValue.IsValid() {
|
||||
t.Errorf("Field %s has invalid FieldValue", detail.Name)
|
||||
}
|
||||
|
||||
// Check that FieldValue matches the actual value
|
||||
switch detail.Name {
|
||||
case "ID":
|
||||
if detail.FieldValue.Int() != 123 {
|
||||
t.Errorf("Expected ID FieldValue 123, got %v", detail.FieldValue.Int())
|
||||
}
|
||||
case "Name":
|
||||
if detail.FieldValue.String() != "TestName" {
|
||||
t.Errorf("Expected Name FieldValue 'TestName', got %v", detail.FieldValue.String())
|
||||
}
|
||||
case "Email":
|
||||
if detail.FieldValue.String() != "test@example.com" {
|
||||
t.Errorf("Expected Email FieldValue 'test@example.com', got %v", detail.FieldValue.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -750,6 +750,118 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
|
||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||
}
|
||||
|
||||
// RelationType represents the type of database relationship
|
||||
type RelationType string
|
||||
|
||||
const (
|
||||
RelationHasMany RelationType = "has-many" // 1:N - use separate query
|
||||
RelationBelongsTo RelationType = "belongs-to" // N:1 - use JOIN
|
||||
RelationHasOne RelationType = "has-one" // 1:1 - use JOIN
|
||||
RelationManyToMany RelationType = "many-to-many" // M:N - use separate query
|
||||
RelationUnknown RelationType = "unknown"
|
||||
)
|
||||
|
||||
// ShouldUseJoin returns true if the relation type should use a JOIN instead of separate query
|
||||
func (rt RelationType) ShouldUseJoin() bool {
|
||||
return rt == RelationBelongsTo || rt == RelationHasOne
|
||||
}
|
||||
|
||||
// GetRelationType inspects the model's struct tags to determine the relationship type
|
||||
// It checks both Bun and GORM tags to identify the relationship cardinality
|
||||
func GetRelationType(model interface{}, fieldName string) RelationType {
|
||||
if model == nil || fieldName == "" {
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
// Find the field
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Check if field name matches (case-insensitive)
|
||||
if !strings.EqualFold(field.Name, fieldName) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check Bun tags first
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && strings.Contains(bunTag, "rel:") {
|
||||
// Parse bun relation tag: rel:has-many, rel:belongs-to, rel:has-one, rel:many-to-many
|
||||
parts := strings.Split(bunTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "rel:") {
|
||||
relType := strings.TrimPrefix(part, "rel:")
|
||||
switch relType {
|
||||
case "has-many":
|
||||
return RelationHasMany
|
||||
case "belongs-to":
|
||||
return RelationBelongsTo
|
||||
case "has-one":
|
||||
return RelationHasOne
|
||||
case "many-to-many", "m2m":
|
||||
return RelationManyToMany
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check GORM tags
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
// GORM uses different patterns:
|
||||
// - foreignKey: usually indicates belongs-to or has-one
|
||||
// - many2many: indicates many-to-many
|
||||
// - Field type (slice vs pointer) helps determine cardinality
|
||||
|
||||
if strings.Contains(gormTag, "many2many:") {
|
||||
return RelationManyToMany
|
||||
}
|
||||
|
||||
// Check field type for cardinality hints
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
// Slice indicates has-many or many-to-many
|
||||
return RelationHasMany
|
||||
}
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
// Pointer to single struct usually indicates belongs-to or has-one
|
||||
// Check if it has foreignKey (belongs-to) or references (has-one)
|
||||
if strings.Contains(gormTag, "foreignKey:") {
|
||||
return RelationBelongsTo
|
||||
}
|
||||
return RelationHasOne
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to field type inference
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
// Slice of structs → has-many
|
||||
return RelationHasMany
|
||||
}
|
||||
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
|
||||
// Single struct → belongs-to (default assumption for safety)
|
||||
// Using belongs-to as default ensures we use JOIN, which is safer
|
||||
return RelationBelongsTo
|
||||
}
|
||||
}
|
||||
|
||||
return RelationUnknown
|
||||
}
|
||||
|
||||
// GetRelationModel gets the model type for a relation field
|
||||
// It searches for the field by name in the following order (case-insensitive):
|
||||
// 1. Actual field name
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,11 +22,12 @@ type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[
|
||||
|
||||
// Handler handles API requests using database and model abstractions
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
nestedProcessor *common.NestedCUDProcessor
|
||||
hooks *HookRegistry
|
||||
fallbackHandler FallbackHandler
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
nestedProcessor *common.NestedCUDProcessor
|
||||
hooks *HookRegistry
|
||||
fallbackHandler FallbackHandler
|
||||
openAPIGenerator func() (string, error)
|
||||
}
|
||||
|
||||
// NewHandler creates a new API handler with database and registry abstractions
|
||||
@@ -75,6 +76,12 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
}
|
||||
}()
|
||||
|
||||
// Check for ?openapi query parameter
|
||||
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
|
||||
h.HandleOpenAPI(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.UnderlyingRequest().Context()
|
||||
|
||||
body, err := r.Body()
|
||||
@@ -156,6 +163,12 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
||||
}
|
||||
}()
|
||||
|
||||
// Check for ?openapi query parameter
|
||||
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
|
||||
h.HandleOpenAPI(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
schema := params["schema"]
|
||||
entity := params["entity"]
|
||||
|
||||
@@ -1433,3 +1446,31 @@ func toSnakeCase(s string) string {
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// HandleOpenAPI generates and returns the OpenAPI specification
|
||||
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
||||
if h.openAPIGenerator == nil {
|
||||
logger.Error("OpenAPI generator not configured")
|
||||
h.sendError(w, http.StatusInternalServerError, "openapi_not_configured", "OpenAPI generation not configured", nil)
|
||||
return
|
||||
}
|
||||
|
||||
spec, err := h.openAPIGenerator()
|
||||
if err != nil {
|
||||
logger.Error("Failed to generate OpenAPI spec: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "openapi_generation_error", "Failed to generate OpenAPI specification", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.SetHeader("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err = w.Write([]byte(spec))
|
||||
if err != nil {
|
||||
logger.Error("Error sending OpenAPI spec response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAPIGenerator sets the OpenAPI generator function
|
||||
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
|
||||
h.openAPIGenerator = generator
|
||||
}
|
||||
|
||||
@@ -46,6 +46,16 @@ type MiddlewareFunc func(http.Handler) http.Handler
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
// Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
// Add global /openapi route
|
||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
})
|
||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
@@ -201,12 +211,27 @@ func ExampleWithBun(bunDB *bun.DB) {
|
||||
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
||||
r := bunRouter.GetBunRouter()
|
||||
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
|
||||
// Add global /openapi route
|
||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
// Loop through each registered model and create explicit routes
|
||||
for fullName := range allModels {
|
||||
// Parse the full name (e.g., "public.users" or just "users")
|
||||
|
||||
@@ -24,11 +24,12 @@ type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[
|
||||
// Handler handles API requests using database and model abstractions
|
||||
// This handler reads filters, columns, and options from HTTP headers
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
nestedProcessor *common.NestedCUDProcessor
|
||||
fallbackHandler FallbackHandler
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
nestedProcessor *common.NestedCUDProcessor
|
||||
fallbackHandler FallbackHandler
|
||||
openAPIGenerator func() (string, error)
|
||||
}
|
||||
|
||||
// NewHandler creates a new API handler with database and registry abstractions
|
||||
@@ -78,6 +79,12 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
}
|
||||
}()
|
||||
|
||||
// Check for ?openapi query parameter
|
||||
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
|
||||
h.HandleOpenAPI(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.UnderlyingRequest().Context()
|
||||
|
||||
schema := params["schema"]
|
||||
@@ -208,6 +215,12 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
||||
}
|
||||
}()
|
||||
|
||||
// Check for ?openapi query parameter
|
||||
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
|
||||
h.HandleOpenAPI(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
schema := params["schema"]
|
||||
entity := params["entity"]
|
||||
|
||||
@@ -2379,3 +2392,35 @@ func (h *Handler) extractTagValue(tag, key string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// HandleOpenAPI generates and returns the OpenAPI specification
|
||||
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
||||
// Import needed here to avoid circular dependency
|
||||
// The import is done inline
|
||||
// We'll use a factory function approach instead
|
||||
if h.openAPIGenerator == nil {
|
||||
logger.Error("OpenAPI generator not configured")
|
||||
h.sendError(w, http.StatusInternalServerError, "openapi_not_configured", "OpenAPI generation not configured", nil)
|
||||
return
|
||||
}
|
||||
|
||||
spec, err := h.openAPIGenerator()
|
||||
if err != nil {
|
||||
logger.Error("Failed to generate OpenAPI spec: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "openapi_generation_error", "Failed to generate OpenAPI specification", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.SetHeader("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err = w.Write([]byte(spec))
|
||||
if err != nil {
|
||||
logger.Error("Error sending OpenAPI spec response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAPIGenerator sets the OpenAPI generator function
|
||||
// This allows avoiding circular dependencies
|
||||
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
|
||||
h.openAPIGenerator = generator
|
||||
}
|
||||
|
||||
@@ -99,6 +99,16 @@ type MiddlewareFunc func(http.Handler) http.Handler
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
// Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
// Add global /openapi route
|
||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
})
|
||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
@@ -264,12 +274,27 @@ func ExampleWithBun(bunDB *bun.DB) {
|
||||
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
||||
r := bunRouter.GetBunRouter()
|
||||
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
|
||||
// Add global /openapi route
|
||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
// Loop through each registered model and create explicit routes
|
||||
for fullName := range allModels {
|
||||
// Parse the full name (e.g., "public.users" or just "users")
|
||||
|
||||
Reference in New Issue
Block a user