mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 16:24:26 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c26ea3cd61 | ||
|
|
a5d97cc07b | ||
|
|
0899ba5029 | ||
|
|
c84dd7dc91 | ||
|
|
f1c6b36374 | ||
|
|
abee5c942f | ||
|
|
2e9a0bd51a | ||
|
|
f518a3c73c | ||
|
|
07c239aaa1 | ||
|
|
1adca4c49b | ||
|
|
eefed23766 | ||
|
|
3b2d05465e | ||
|
|
e88018543e | ||
|
|
e7e5754a47 |
138
SCHEMA_TABLE_HANDLING.md
Normal file
138
SCHEMA_TABLE_HANDLING.md
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# Schema and Table Name Handling
|
||||||
|
|
||||||
|
This document explains how the handlers properly separate and handle schema and table names.
|
||||||
|
|
||||||
|
## Implementation
|
||||||
|
|
||||||
|
Both `resolvespec` and `restheadspec` handlers now properly handle schema and table name separation through the following functions:
|
||||||
|
|
||||||
|
- `parseTableName(fullTableName)` - Splits "schema.table" into separate components
|
||||||
|
- `getSchemaAndTable(defaultSchema, entity, model)` - Returns schema and table separately
|
||||||
|
- `getTableName(schema, entity, model)` - Returns the full "schema.table" format
|
||||||
|
|
||||||
|
## Priority Order
|
||||||
|
|
||||||
|
When determining the schema and table name, the following priority is used:
|
||||||
|
|
||||||
|
1. **If `TableName()` contains a schema** (e.g., "myschema.mytable"), that schema takes precedence
|
||||||
|
2. **If model implements `SchemaProvider`**, use that schema
|
||||||
|
3. **Otherwise**, use the `defaultSchema` parameter from the URL/request
|
||||||
|
|
||||||
|
## Scenarios
|
||||||
|
|
||||||
|
### Scenario 1: Simple table name, default schema
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "users"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- Request URL: `/api/public/users`
|
||||||
|
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
|
||||||
|
|
||||||
|
### Scenario 2: Table name includes schema
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "auth.users" // Schema included!
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- Request URL: `/api/public/users` (public is ignored)
|
||||||
|
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
|
||||||
|
- **Note**: The schema from `TableName()` takes precedence over the URL schema
|
||||||
|
|
||||||
|
### Scenario 3: Using SchemaProvider
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "users"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) SchemaName() string {
|
||||||
|
return "auth"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- Request URL: `/api/public/users` (public is ignored)
|
||||||
|
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
|
||||||
|
|
||||||
|
### Scenario 4: Table name includes schema AND SchemaProvider
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "core.users" // This wins!
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) SchemaName() string {
|
||||||
|
return "auth" // This is ignored
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- Request URL: `/api/public/users`
|
||||||
|
- Result: `schema="core"`, `table="users"`, `fullName="core.users"`
|
||||||
|
- **Note**: Schema from `TableName()` takes highest precedence
|
||||||
|
|
||||||
|
### Scenario 5: No providers at all
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
// No TableName() or SchemaName()
|
||||||
|
```
|
||||||
|
- Request URL: `/api/public/users`
|
||||||
|
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
|
||||||
|
- Uses URL schema and entity name
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
1. **Automatic detection**: The code automatically detects if `TableName()` includes a schema by checking for "."
|
||||||
|
2. **Backward compatible**: Existing code continues to work
|
||||||
|
3. **Flexible**: Supports multiple ways to specify schema and table
|
||||||
|
4. **Debug logging**: Logs when schema is detected in `TableName()` for debugging
|
||||||
|
|
||||||
|
## Code Locations
|
||||||
|
|
||||||
|
### Handlers
|
||||||
|
- `/pkg/resolvespec/handler.go:472-531`
|
||||||
|
- `/pkg/restheadspec/handler.go:534-593`
|
||||||
|
|
||||||
|
### Database Adapters
|
||||||
|
- `/pkg/common/adapters/database/utils.go` - Shared `parseTableName()` function
|
||||||
|
- `/pkg/common/adapters/database/bun.go` - Bun adapter with separated schema/table
|
||||||
|
- `/pkg/common/adapters/database/gorm.go` - GORM adapter with separated schema/table
|
||||||
|
|
||||||
|
## Adapter Implementation
|
||||||
|
|
||||||
|
Both Bun and GORM adapters now properly separate schema and table name:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// BunSelectQuery/GormSelectQuery now have separated fields:
|
||||||
|
type BunSelectQuery struct {
|
||||||
|
query *bun.SelectQuery
|
||||||
|
schema string // Separated schema name
|
||||||
|
tableName string // Just the table name, without schema
|
||||||
|
tableAlias string
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
When `Model()` or `Table()` is called:
|
||||||
|
1. The full table name (which may include schema) is parsed
|
||||||
|
2. Schema and table name are stored separately
|
||||||
|
3. When building joins, the already-separated table name is used directly
|
||||||
|
|
||||||
|
This ensures consistent handling of schema-qualified table names throughout the codebase.
|
||||||
@@ -22,7 +22,10 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||||
return &BunSelectQuery{query: b.db.NewSelect()}
|
return &BunSelectQuery{
|
||||||
|
query: b.db.NewSelect(),
|
||||||
|
db: b.db,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||||
@@ -78,16 +81,22 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
|||||||
// BunSelectQuery implements SelectQuery for Bun
|
// BunSelectQuery implements SelectQuery for Bun
|
||||||
type BunSelectQuery struct {
|
type BunSelectQuery struct {
|
||||||
query *bun.SelectQuery
|
query *bun.SelectQuery
|
||||||
tableName string
|
db bun.IDB // Store DB connection for count queries
|
||||||
|
hasModel bool // Track if Model() was called
|
||||||
|
schema string // Separated schema name
|
||||||
|
tableName string // Just the table name, without schema
|
||||||
tableAlias string
|
tableAlias string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
|
b.hasModel = true // Mark that we have a model
|
||||||
|
|
||||||
// Try to get table name from model if it implements TableNameProvider
|
// Try to get table name from model if it implements TableNameProvider
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
if provider, ok := model.(common.TableNameProvider); ok {
|
||||||
b.tableName = provider.TableName()
|
fullTableName := provider.TableName()
|
||||||
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
|
b.schema, b.tableName = parseTableName(fullTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return b
|
return b
|
||||||
@@ -95,7 +104,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
|||||||
|
|
||||||
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
b.tableName = table
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
|
b.schema, b.tableName = parseTableName(table)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,13 +138,9 @@ func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no prefix provided, use the table name as prefix
|
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||||
if prefix == "" && b.tableName != "" {
|
if prefix == "" && b.tableName != "" {
|
||||||
prefix = b.tableName
|
prefix = b.tableName
|
||||||
// Extract just the table name if it has schema
|
|
||||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
|
||||||
prefix = prefix[idx+1:]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If prefix is provided, add it as an alias in the join
|
// If prefix is provided, add it as an alias in the join
|
||||||
@@ -169,12 +175,9 @@ func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.Sele
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no prefix provided, use the table name as prefix
|
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||||
if prefix == "" && b.tableName != "" {
|
if prefix == "" && b.tableName != "" {
|
||||||
prefix = b.tableName
|
prefix = b.tableName
|
||||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
|
||||||
prefix = prefix[idx+1:]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct LEFT JOIN with prefix
|
// Construct LEFT JOIN with prefix
|
||||||
@@ -231,7 +234,19 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
||||||
count, err := b.query.Count(ctx)
|
// If Model() was set, use bun's native Count() which works properly
|
||||||
|
if b.hasModel {
|
||||||
|
count, err := b.query.Count(ctx)
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||||
|
// This is needed when only Table() is set without a model
|
||||||
|
var count int
|
||||||
|
err := b.db.NewSelect().
|
||||||
|
TableExpr("(?) AS subquery", b.query).
|
||||||
|
ColumnExpr("COUNT(*)").
|
||||||
|
Scan(ctx, &count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,7 +396,10 @@ type BunTxAdapter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
|
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
|
||||||
return &BunSelectQuery{query: b.tx.NewSelect()}
|
return &BunSelectQuery{
|
||||||
|
query: b.tx.NewSelect(),
|
||||||
|
db: b.tx,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
|
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
|
||||||
|
|||||||
@@ -70,7 +70,8 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
|||||||
// GormSelectQuery implements SelectQuery for GORM
|
// GormSelectQuery implements SelectQuery for GORM
|
||||||
type GormSelectQuery struct {
|
type GormSelectQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
tableName string
|
schema string // Separated schema name
|
||||||
|
tableName string // Just the table name, without schema
|
||||||
tableAlias string
|
tableAlias string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,7 +80,9 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
|||||||
|
|
||||||
// Try to get table name from model if it implements TableNameProvider
|
// Try to get table name from model if it implements TableNameProvider
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
if provider, ok := model.(common.TableNameProvider); ok {
|
||||||
g.tableName = provider.TableName()
|
fullTableName := provider.TableName()
|
||||||
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
|
g.schema, g.tableName = parseTableName(fullTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return g
|
return g
|
||||||
@@ -87,7 +90,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
|||||||
|
|
||||||
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
g.tableName = table
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
|
g.schema, g.tableName = parseTableName(table)
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,13 +124,9 @@ func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no prefix provided, use the table name as prefix
|
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||||
if prefix == "" && g.tableName != "" {
|
if prefix == "" && g.tableName != "" {
|
||||||
prefix = g.tableName
|
prefix = g.tableName
|
||||||
// Extract just the table name if it has schema
|
|
||||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
|
||||||
prefix = prefix[idx+1:]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If prefix is provided, add it as an alias in the join
|
// If prefix is provided, add it as an alias in the join
|
||||||
@@ -161,12 +161,9 @@ func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.Sel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no prefix provided, use the table name as prefix
|
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||||
if prefix == "" && g.tableName != "" {
|
if prefix == "" && g.tableName != "" {
|
||||||
prefix = g.tableName
|
prefix = g.tableName
|
||||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
|
||||||
prefix = prefix[idx+1:]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct LEFT JOIN with prefix
|
// Construct LEFT JOIN with prefix
|
||||||
|
|||||||
13
pkg/common/adapters/database/utils.go
Normal file
13
pkg/common/adapters/database/utils.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
|
// For example: "public.users" -> ("public", "users")
|
||||||
|
// "users" -> ("", "users")
|
||||||
|
func parseTableName(fullTableName string) (schema, table string) {
|
||||||
|
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||||
|
return fullTableName[:idx], fullTableName[idx+1:]
|
||||||
|
}
|
||||||
|
return "", fullTableName
|
||||||
|
}
|
||||||
@@ -37,9 +37,10 @@ type PreloadOption struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type FilterOption struct {
|
type FilterOption struct {
|
||||||
Column string `json:"column"`
|
Column string `json:"column"`
|
||||||
Operator string `json:"operator"`
|
Operator string `json:"operator"`
|
||||||
Value interface{} `json:"value"`
|
Value interface{} `json:"value"`
|
||||||
|
LogicOperator string `json:"logic_operator"` // "AND" or "OR" - how this filter combines with previous filters
|
||||||
}
|
}
|
||||||
|
|
||||||
type SortOption struct {
|
type SortOption struct {
|
||||||
|
|||||||
272
pkg/common/validation.go
Normal file
272
pkg/common/validation.go
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ColumnValidator validates column names against a model's fields
|
||||||
|
type ColumnValidator struct {
|
||||||
|
validColumns map[string]bool
|
||||||
|
model interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewColumnValidator creates a new column validator for a given model
|
||||||
|
func NewColumnValidator(model interface{}) *ColumnValidator {
|
||||||
|
validator := &ColumnValidator{
|
||||||
|
validColumns: make(map[string]bool),
|
||||||
|
model: model,
|
||||||
|
}
|
||||||
|
validator.buildValidColumns()
|
||||||
|
return validator
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildValidColumns extracts all valid column names from the model using reflection
|
||||||
|
func (v *ColumnValidator) buildValidColumns() {
|
||||||
|
modelType := reflect.TypeOf(v.model)
|
||||||
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract column names from struct fields
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get column name from bun, gorm, or json tag
|
||||||
|
columnName := v.getColumnName(field)
|
||||||
|
if columnName != "" && columnName != "-" {
|
||||||
|
v.validColumns[strings.ToLower(columnName)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getColumnName extracts the column name from a struct field's tags
|
||||||
|
// Supports both Bun and GORM tags
|
||||||
|
func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
||||||
|
// First check Bun tag for column name
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && bunTag != "-" {
|
||||||
|
parts := strings.Split(bunTag, ",")
|
||||||
|
// The first part is usually the column name
|
||||||
|
columnName := strings.TrimSpace(parts[0])
|
||||||
|
if columnName != "" && columnName != "-" {
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check GORM tag for column name
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "column:") {
|
||||||
|
parts := strings.Split(gormTag, ";")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "column:") {
|
||||||
|
return strings.TrimPrefix(part, "column:")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to JSON tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
// Extract just the name part (before any comma)
|
||||||
|
jsonName := strings.Split(jsonTag, ",")[0]
|
||||||
|
return jsonName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to field name in lowercase (snake_case conversion would be better)
|
||||||
|
return strings.ToLower(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateColumn validates a single column name
|
||||||
|
// Returns nil if valid, error if invalid
|
||||||
|
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||||
|
func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||||
|
// Allow empty columns
|
||||||
|
if column == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow columns prefixed with "cql" (case insensitive) for computed columns
|
||||||
|
if strings.HasPrefix(strings.ToLower(column), "cql") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if column exists in model
|
||||||
|
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
|
||||||
|
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidColumn checks if a column is valid
|
||||||
|
// Returns true if valid, false if invalid
|
||||||
|
func (v *ColumnValidator) IsValidColumn(column string) bool {
|
||||||
|
return v.ValidateColumn(column) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterValidColumns filters a list of columns, returning only valid ones
|
||||||
|
// Logs warnings for any invalid columns
|
||||||
|
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
||||||
|
if len(columns) == 0 {
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
validColumns := make([]string, 0, len(columns))
|
||||||
|
for _, col := range columns {
|
||||||
|
if v.IsValidColumn(col) {
|
||||||
|
validColumns = append(validColumns, col)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column '%s' filtered out: column does not exist in model", col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return validColumns
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateColumns validates multiple column names
|
||||||
|
// Returns error with details about all invalid columns
|
||||||
|
func (v *ColumnValidator) ValidateColumns(columns []string) error {
|
||||||
|
var invalidColumns []string
|
||||||
|
|
||||||
|
for _, column := range columns {
|
||||||
|
if err := v.ValidateColumn(column); err != nil {
|
||||||
|
invalidColumns = append(invalidColumns, column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(invalidColumns) > 0 {
|
||||||
|
return fmt.Errorf("invalid columns: %s", strings.Join(invalidColumns, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRequestOptions validates all column references in RequestOptions
|
||||||
|
func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
|
||||||
|
// Validate Columns
|
||||||
|
if err := v.ValidateColumns(options.Columns); err != nil {
|
||||||
|
return fmt.Errorf("in select columns: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate OmitColumns
|
||||||
|
if err := v.ValidateColumns(options.OmitColumns); err != nil {
|
||||||
|
return fmt.Errorf("in omit columns: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Filter columns
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if err := v.ValidateColumn(filter.Column); err != nil {
|
||||||
|
return fmt.Errorf("in filter: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Sort columns
|
||||||
|
for _, sort := range options.Sort {
|
||||||
|
if err := v.ValidateColumn(sort.Column); err != nil {
|
||||||
|
return fmt.Errorf("in sort: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Preload columns (if specified)
|
||||||
|
for _, preload := range options.Preload {
|
||||||
|
// Note: We don't validate the relation name itself, as it's a relationship
|
||||||
|
// Only validate columns if specified for the preload
|
||||||
|
if err := v.ValidateColumns(preload.Columns); err != nil {
|
||||||
|
return fmt.Errorf("in preload '%s' columns: %w", preload.Relation, err)
|
||||||
|
}
|
||||||
|
if err := v.ValidateColumns(preload.OmitColumns); err != nil {
|
||||||
|
return fmt.Errorf("in preload '%s' omit columns: %w", preload.Relation, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate filter columns in preload
|
||||||
|
for _, filter := range preload.Filters {
|
||||||
|
if err := v.ValidateColumn(filter.Column); err != nil {
|
||||||
|
return fmt.Errorf("in preload '%s' filter: %w", preload.Relation, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterRequestOptions filters all column references in RequestOptions
|
||||||
|
// Returns a new RequestOptions with only valid columns, logging warnings for invalid ones
|
||||||
|
func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOptions {
|
||||||
|
filtered := options
|
||||||
|
|
||||||
|
// Filter Columns
|
||||||
|
filtered.Columns = v.FilterValidColumns(options.Columns)
|
||||||
|
|
||||||
|
// Filter OmitColumns
|
||||||
|
filtered.OmitColumns = v.FilterValidColumns(options.OmitColumns)
|
||||||
|
|
||||||
|
// Filter Filter columns
|
||||||
|
validFilters := make([]FilterOption, 0, len(options.Filters))
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if v.IsValidColumn(filter.Column) {
|
||||||
|
validFilters = append(validFilters, filter)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered.Filters = validFilters
|
||||||
|
|
||||||
|
// Filter Sort columns
|
||||||
|
validSorts := make([]SortOption, 0, len(options.Sort))
|
||||||
|
for _, sort := range options.Sort {
|
||||||
|
if v.IsValidColumn(sort.Column) {
|
||||||
|
validSorts = append(validSorts, sort)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered.Sort = validSorts
|
||||||
|
|
||||||
|
// Filter Preload columns
|
||||||
|
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||||
|
for _, preload := range options.Preload {
|
||||||
|
filteredPreload := preload
|
||||||
|
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||||
|
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||||
|
|
||||||
|
// Filter preload filters
|
||||||
|
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||||
|
for _, filter := range preload.Filters {
|
||||||
|
if v.IsValidColumn(filter.Column) {
|
||||||
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filteredPreload.Filters = validPreloadFilters
|
||||||
|
|
||||||
|
validPreloads = append(validPreloads, filteredPreload)
|
||||||
|
}
|
||||||
|
filtered.Preload = validPreloads
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetValidColumns returns a list of all valid column names for debugging purposes
|
||||||
|
func (v *ColumnValidator) GetValidColumns() []string {
|
||||||
|
columns := make([]string, 0, len(v.validColumns))
|
||||||
|
for col := range v.validColumns {
|
||||||
|
columns = append(columns, col)
|
||||||
|
}
|
||||||
|
return columns
|
||||||
|
}
|
||||||
363
pkg/common/validation_test.go
Normal file
363
pkg/common/validation_test.go
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestModel represents a sample model for testing
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `json:"id" gorm:"primaryKey"`
|
||||||
|
Name string `json:"name" gorm:"column:name"`
|
||||||
|
Email string `json:"email" bun:"email"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
IsActive bool `json:"is_active"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewColumnValidator(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
if validator == nil {
|
||||||
|
t.Fatal("Expected validator to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validator.validColumns) == 0 {
|
||||||
|
t.Fatal("Expected validator to have valid columns")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that expected columns are present
|
||||||
|
expectedColumns := []string{"id", "name", "email", "age", "is_active", "created_at"}
|
||||||
|
for _, col := range expectedColumns {
|
||||||
|
if !validator.validColumns[col] {
|
||||||
|
t.Errorf("Expected column '%s' to be valid", col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumn(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
column string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{"Valid column - id", "id", false},
|
||||||
|
{"Valid column - name", "name", false},
|
||||||
|
{"Valid column - email", "email", false},
|
||||||
|
{"Valid column - uppercase", "ID", false}, // Case insensitive
|
||||||
|
{"Invalid column", "invalid_column", true},
|
||||||
|
{"CQL prefixed - should be valid", "cqlComputedField", false},
|
||||||
|
{"CQL prefixed uppercase - should be valid", "CQLComputedField", false},
|
||||||
|
{"Empty column", "", false}, // Empty columns are allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumn(tt.column)
|
||||||
|
if tt.shouldError && err == nil {
|
||||||
|
t.Errorf("Expected error for column '%s', got nil", tt.column)
|
||||||
|
}
|
||||||
|
if !tt.shouldError && err != nil {
|
||||||
|
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumns(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columns []string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{"All valid columns", []string{"id", "name", "email"}, false},
|
||||||
|
{"One invalid column", []string{"id", "invalid_col", "name"}, true},
|
||||||
|
{"All invalid columns", []string{"bad1", "bad2"}, true},
|
||||||
|
{"With CQL prefix", []string{"id", "cqlComputed", "name"}, false},
|
||||||
|
{"Empty list", []string{}, false},
|
||||||
|
{"Nil list", nil, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumns(tt.columns)
|
||||||
|
if tt.shouldError && err == nil {
|
||||||
|
t.Errorf("Expected error for columns %v, got nil", tt.columns)
|
||||||
|
}
|
||||||
|
if !tt.shouldError && err != nil {
|
||||||
|
t.Errorf("Expected no error for columns %v, got: %v", tt.columns, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRequestOptions(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
options RequestOptions
|
||||||
|
shouldError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid options with columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Columns: []string{"id", "name"},
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "name", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
Sort: []SortOption{
|
||||||
|
{Column: "id", Direction: "ASC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Columns: []string{"id", "invalid_column"},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "select columns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Filters",
|
||||||
|
options: RequestOptions{
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "invalid_col", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "filter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Sort",
|
||||||
|
options: RequestOptions{
|
||||||
|
Sort: []SortOption{
|
||||||
|
{Column: "invalid_col", Direction: "ASC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "sort",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid CQL prefixed columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Columns: []string{"id", "cqlComputedField"},
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "cqlCustomFilter", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Preload",
|
||||||
|
options: RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "SomeRelation",
|
||||||
|
Columns: []string{"id", "invalid_col"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "preload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid preload with valid columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "SomeRelation",
|
||||||
|
Columns: []string{"id", "name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateRequestOptions(tt.options)
|
||||||
|
if tt.shouldError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error, got nil")
|
||||||
|
} else if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
|
t.Errorf("Expected error to contain '%s', got: %v", tt.errorMsg, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValidColumns(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
columns := validator.GetValidColumns()
|
||||||
|
if len(columns) == 0 {
|
||||||
|
t.Error("Expected to get valid columns, got empty list")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have at least the columns from TestModel
|
||||||
|
if len(columns) < 6 {
|
||||||
|
t.Errorf("Expected at least 6 columns, got %d", len(columns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with Bun tags specifically
|
||||||
|
type BunModel struct {
|
||||||
|
ID int64 `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
Email string `bun:"user_email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunTagSupport(t *testing.T) {
|
||||||
|
model := BunModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
// Test that bun tags are properly recognized
|
||||||
|
tests := []struct {
|
||||||
|
column string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{"id", false},
|
||||||
|
{"name", false},
|
||||||
|
{"user_email", false}, // Bun tag specifies this name
|
||||||
|
{"email", true}, // JSON tag would be "email", but bun tag says "user_email"
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.column, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumn(tt.column)
|
||||||
|
if tt.shouldError && err == nil {
|
||||||
|
t.Errorf("Expected error for column '%s'", tt.column)
|
||||||
|
}
|
||||||
|
if !tt.shouldError && err != nil {
|
||||||
|
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterValidColumns(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []string
|
||||||
|
expectedOutput []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "All valid columns",
|
||||||
|
input: []string{"id", "name", "email"},
|
||||||
|
expectedOutput: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mix of valid and invalid",
|
||||||
|
input: []string{"id", "invalid_col", "name", "bad_col", "email"},
|
||||||
|
expectedOutput: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "All invalid columns",
|
||||||
|
input: []string{"bad1", "bad2"},
|
||||||
|
expectedOutput: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With CQL prefix (should pass)",
|
||||||
|
input: []string{"id", "cqlComputed", "name"},
|
||||||
|
expectedOutput: []string{"id", "cqlComputed", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty input",
|
||||||
|
input: []string{},
|
||||||
|
expectedOutput: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil input",
|
||||||
|
input: nil,
|
||||||
|
expectedOutput: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := validator.FilterValidColumns(tt.input)
|
||||||
|
if len(result) != len(tt.expectedOutput) {
|
||||||
|
t.Errorf("Expected %d columns, got %d", len(tt.expectedOutput), len(result))
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expectedOutput[i] {
|
||||||
|
t.Errorf("At index %d: expected %s, got %s", i, tt.expectedOutput[i], col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterRequestOptions(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Columns: []string{"id", "name", "invalid_col"},
|
||||||
|
OmitColumns: []string{"email", "bad_col"},
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "name", Operator: "eq", Value: "test"},
|
||||||
|
{Column: "invalid_col", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
Sort: []SortOption{
|
||||||
|
{Column: "id", Direction: "ASC"},
|
||||||
|
{Column: "bad_col", Direction: "DESC"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
// Check Columns
|
||||||
|
if len(filtered.Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
|
||||||
|
}
|
||||||
|
if filtered.Columns[0] != "id" || filtered.Columns[1] != "name" {
|
||||||
|
t.Errorf("Expected columns [id, name], got %v", filtered.Columns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check OmitColumns
|
||||||
|
if len(filtered.OmitColumns) != 1 {
|
||||||
|
t.Errorf("Expected 1 omit column, got %d", len(filtered.OmitColumns))
|
||||||
|
}
|
||||||
|
if filtered.OmitColumns[0] != "email" {
|
||||||
|
t.Errorf("Expected omit column [email], got %v", filtered.OmitColumns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Filters
|
||||||
|
if len(filtered.Filters) != 1 {
|
||||||
|
t.Errorf("Expected 1 filter, got %d", len(filtered.Filters))
|
||||||
|
}
|
||||||
|
if filtered.Filters[0].Column != "name" {
|
||||||
|
t.Errorf("Expected filter column 'name', got %s", filtered.Filters[0].Column)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Sort
|
||||||
|
if len(filtered.Sort) != 1 {
|
||||||
|
t.Errorf("Expected 1 sort, got %d", len(filtered.Sort))
|
||||||
|
}
|
||||||
|
if filtered.Sort[0].Column != "id" {
|
||||||
|
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -38,12 +38,28 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
|||||||
return fmt.Errorf("model cannot be nil")
|
return fmt.Errorf("model cannot be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType.Kind() == reflect.Ptr {
|
originalType := modelType
|
||||||
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind())
|
|
||||||
|
// Unwrap pointers, slices, and arrays to check the underlying type
|
||||||
|
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||||
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the underlying type is a struct
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
return fmt.Errorf("model must be a struct, got %s", modelType.Kind())
|
return fmt.Errorf("model must be a struct or pointer to struct, got %s", originalType.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a pointer/slice/array was passed, unwrap to the base struct
|
||||||
|
if originalType != modelType {
|
||||||
|
// Create a zero value of the struct type
|
||||||
|
model = reflect.New(modelType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional check: ensure model is not a pointer
|
||||||
|
finalType := reflect.TypeOf(model)
|
||||||
|
if finalType.Kind() == reflect.Ptr {
|
||||||
|
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
r.models[name] = model
|
r.models[name] = model
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
@@ -26,8 +27,22 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handlePanic is a helper function to handle panics with stack traces
|
||||||
|
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
||||||
|
stack := debug.Stack()
|
||||||
|
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
|
||||||
|
}
|
||||||
|
|
||||||
// Handle processes API requests through router-agnostic interface
|
// Handle processes API requests through router-agnostic interface
|
||||||
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
|
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "Handle", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
body, err := r.Body()
|
body, err := r.Body()
|
||||||
@@ -58,6 +73,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the model is a struct type (not a slice or pointer to slice)
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
originalType := modelType
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
|
||||||
|
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
|
||||||
|
fmt.Errorf("invalid model type: %v", originalType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the registered model was a pointer or slice, use the unwrapped struct type
|
||||||
|
if originalType != modelType {
|
||||||
|
model = reflect.New(modelType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
// Create a pointer to the model type for database operations
|
// Create a pointer to the model type for database operations
|
||||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
@@ -65,6 +100,10 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
// Add request-scoped data to context
|
// Add request-scoped data to context
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
||||||
|
|
||||||
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
|
validator := common.NewColumnValidator(model)
|
||||||
|
req.Options = validator.FilterRequestOptions(req.Options)
|
||||||
|
|
||||||
switch req.Operation {
|
switch req.Operation {
|
||||||
case "read":
|
case "read":
|
||||||
h.handleRead(ctx, w, id, req.Options)
|
h.handleRead(ctx, w, id, req.Options)
|
||||||
@@ -82,6 +121,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
|
|
||||||
// HandleGet processes GET requests for metadata
|
// HandleGet processes GET requests for metadata
|
||||||
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
|
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "HandleGet", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := params["schema"]
|
schema := params["schema"]
|
||||||
entity := params["entity"]
|
entity := params["entity"]
|
||||||
|
|
||||||
@@ -99,16 +145,46 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
|
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleRead", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
model := GetModel(ctx)
|
model := GetModel(ctx)
|
||||||
modelPtr := GetModelPtr(ctx)
|
|
||||||
|
// Validate and unwrap model type to get base struct
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("Reading records from %s.%s", schema, entity)
|
logger.Info("Reading records from %s.%s", schema, entity)
|
||||||
|
|
||||||
|
// Create the model pointer for Scan() operations
|
||||||
|
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||||
|
modelPtr := reflect.New(sliceType).Interface()
|
||||||
|
|
||||||
|
// Start with Model() using the slice pointer to avoid "Model(nil)" errors in Count()
|
||||||
|
// Bun's Model() accepts both single pointers and slice pointers
|
||||||
query := h.db.NewSelect().Model(modelPtr)
|
query := h.db.NewSelect().Model(modelPtr)
|
||||||
query = query.Table(tableName)
|
|
||||||
|
// Only set Table() if the model doesn't provide a table name via the underlying type
|
||||||
|
// Create a temporary instance to check for TableNameProvider
|
||||||
|
tempInstance := reflect.New(modelType).Interface()
|
||||||
|
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
|
query = query.Table(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
@@ -160,8 +236,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
var result interface{}
|
var result interface{}
|
||||||
if id != "" {
|
if id != "" {
|
||||||
logger.Debug("Querying single record with ID: %s", id)
|
logger.Debug("Querying single record with ID: %s", id)
|
||||||
// Create a pointer to the struct type for scanning
|
// For single record, create a new pointer to the struct type
|
||||||
singleResult := reflect.New(reflect.TypeOf(model)).Interface()
|
singleResult := reflect.New(modelType).Interface()
|
||||||
query = query.Where("id = ?", id)
|
query = query.Where("id = ?", id)
|
||||||
if err := query.Scan(ctx, singleResult); err != nil {
|
if err := query.Scan(ctx, singleResult); err != nil {
|
||||||
logger.Error("Error querying record: %v", err)
|
logger.Error("Error querying record: %v", err)
|
||||||
@@ -171,16 +247,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
result = singleResult
|
result = singleResult
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Querying multiple records")
|
logger.Debug("Querying multiple records")
|
||||||
// Create a slice of pointers to the model type
|
// Use the modelPtr already created and set on the query
|
||||||
sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))
|
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||||
results := reflect.New(sliceType).Interface()
|
|
||||||
|
|
||||||
if err := query.Scan(ctx, results); err != nil {
|
|
||||||
logger.Error("Error querying records: %v", err)
|
logger.Error("Error querying records: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
result = reflect.ValueOf(results).Elem().Interface()
|
result = reflect.ValueOf(modelPtr).Elem().Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully retrieved records")
|
logger.Info("Successfully retrieved records")
|
||||||
@@ -203,6 +276,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleCreate", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
@@ -279,6 +359,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
|
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleUpdate", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
@@ -329,6 +416,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
|
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleDelete", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
@@ -385,19 +479,86 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
|
||||||
return provider.TableName()
|
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||||
|
return fullTableName[:idx], fullTableName[idx+1:]
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s.%s", schema, entity)
|
return "", fullTableName
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSchemaAndTable returns the schema and table name separately
|
||||||
|
// It checks SchemaProvider and TableNameProvider interfaces and handles cases where
|
||||||
|
// the table name may already include the schema (e.g., "public.users")
|
||||||
|
//
|
||||||
|
// Priority order:
|
||||||
|
// 1. If TableName() contains a schema (e.g., "myschema.mytable"), that schema takes precedence
|
||||||
|
// 2. If model implements SchemaProvider, use that schema
|
||||||
|
// 3. Otherwise, use the defaultSchema parameter
|
||||||
|
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||||
|
// First check if model provides a table name
|
||||||
|
// We check this FIRST because the table name might already contain the schema
|
||||||
|
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
tableName := tableProvider.TableName()
|
||||||
|
|
||||||
|
// IMPORTANT: Check if the table name already contains a schema (e.g., "schema.table")
|
||||||
|
// This is common when models need to specify a different schema than the default
|
||||||
|
if tableSchema, tableOnly := h.parseTableName(tableName); tableSchema != "" {
|
||||||
|
// Table name includes schema - use it and ignore any other schema providers
|
||||||
|
logger.Debug("TableName() includes schema: %s.%s", tableSchema, tableOnly)
|
||||||
|
return tableSchema, tableOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
// Table name is just the table name without schema
|
||||||
|
// Now determine which schema to use
|
||||||
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
|
schema = schemaProvider.SchemaName()
|
||||||
|
} else {
|
||||||
|
schema = defaultSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema, tableName
|
||||||
|
}
|
||||||
|
|
||||||
|
// No TableNameProvider, so check for schema and use entity as table name
|
||||||
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
|
schema = schemaProvider.SchemaName()
|
||||||
|
} else {
|
||||||
|
schema = defaultSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to entity name as table
|
||||||
|
return schema, entity
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableName returns the full table name including schema (schema.table)
|
||||||
|
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||||
|
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||||
|
if schemaName != "" {
|
||||||
|
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||||
|
}
|
||||||
|
return tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model type must be a struct, got %v for %s.%s", modelType, schema, entity)
|
||||||
|
return &common.TableMetadata{
|
||||||
|
Schema: schema,
|
||||||
|
Table: entity,
|
||||||
|
Columns: make([]common.Column, 0),
|
||||||
|
Relations: make([]string, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
metadata := &common.TableMetadata{
|
metadata := &common.TableMetadata{
|
||||||
Schema: schema,
|
Schema: schema,
|
||||||
Table: entity,
|
Table: entity,
|
||||||
@@ -541,10 +702,18 @@ type relationshipInfo struct {
|
|||||||
|
|
||||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Warn("Cannot apply preloads to non-struct type: %v", modelType)
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
for _, preload := range preloads {
|
for _, preload := range preloads {
|
||||||
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
||||||
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
||||||
@@ -568,6 +737,12 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
||||||
|
// Ensure we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
field := modelType.Field(i)
|
field := modelType.Field(i)
|
||||||
jsonTag := field.Tag.Get("json")
|
jsonTag := field.Tag.Get("json")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
@@ -27,9 +28,23 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handlePanic is a helper function to handle panics with stack traces
|
||||||
|
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
||||||
|
stack := debug.Stack()
|
||||||
|
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
|
||||||
|
}
|
||||||
|
|
||||||
// Handle processes API requests through router-agnostic interface
|
// Handle processes API requests through router-agnostic interface
|
||||||
// Options are read from HTTP headers instead of request body
|
// Options are read from HTTP headers instead of request body
|
||||||
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
|
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "Handle", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
schema := params["schema"]
|
schema := params["schema"]
|
||||||
@@ -52,12 +67,36 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the model is a struct type (not a slice or pointer to slice)
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
originalType := modelType
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
|
||||||
|
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
|
||||||
|
fmt.Errorf("invalid model type: %v", originalType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the registered model was a pointer or slice, use the unwrapped struct type
|
||||||
|
if originalType != modelType {
|
||||||
|
model = reflect.New(modelType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
|
||||||
// Add request-scoped data to context
|
// Add request-scoped data to context
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
||||||
|
|
||||||
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
|
validator := common.NewColumnValidator(model)
|
||||||
|
options = filterExtendedOptions(validator, options)
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
if id != "" {
|
if id != "" {
|
||||||
@@ -107,6 +146,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
|
|
||||||
// HandleGet processes GET requests for metadata
|
// HandleGet processes GET requests for metadata
|
||||||
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
|
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "HandleGet", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := params["schema"]
|
schema := params["schema"]
|
||||||
entity := params["entity"]
|
entity := params["entity"]
|
||||||
|
|
||||||
@@ -126,15 +172,45 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
// parseOptionsFromHeaders is now implemented in headers.go
|
// parseOptionsFromHeaders is now implemented in headers.go
|
||||||
|
|
||||||
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleRead", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
modelPtr := GetModelPtr(ctx)
|
model := GetModel(ctx)
|
||||||
|
|
||||||
|
// Validate and unwrap model type to get base struct
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pointer to a slice of pointers to the model type for query results
|
||||||
|
modelPtr := reflect.New(reflect.SliceOf(reflect.PointerTo(modelType))).Interface()
|
||||||
|
|
||||||
logger.Info("Reading records from %s.%s", schema, entity)
|
logger.Info("Reading records from %s.%s", schema, entity)
|
||||||
|
|
||||||
|
// Start with Model() using the slice pointer to avoid "Model(nil)" errors in Count()
|
||||||
|
// Bun's Model() accepts both single pointers and slice pointers
|
||||||
query := h.db.NewSelect().Model(modelPtr)
|
query := h.db.NewSelect().Model(modelPtr)
|
||||||
query = query.Table(tableName)
|
|
||||||
|
// Only set Table() if the model doesn't provide a table name via the underlying type
|
||||||
|
// Create a temporary instance to check for TableNameProvider
|
||||||
|
tempInstance := reflect.New(modelType).Interface()
|
||||||
|
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
|
query = query.Table(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
@@ -163,10 +239,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// This may need to be handled differently per database adapter
|
// This may need to be handled differently per database adapter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply filters
|
// Apply filters - validate and adjust for column types first
|
||||||
for _, filter := range options.Filters {
|
for i := range options.Filters {
|
||||||
logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
|
filter := &options.Filters[i]
|
||||||
query = h.applyFilter(query, filter)
|
|
||||||
|
// Validate and adjust filter based on column type
|
||||||
|
castInfo := h.ValidateAndAdjustFilterForColumnType(filter, model)
|
||||||
|
|
||||||
|
// Default to AND if LogicOperator is not set
|
||||||
|
logicOp := filter.LogicOperator
|
||||||
|
if logicOp == "" {
|
||||||
|
logicOp = "AND"
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
|
||||||
|
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply custom SQL WHERE clause (AND condition)
|
// Apply custom SQL WHERE clause (AND condition)
|
||||||
@@ -223,10 +310,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Offset(*options.Offset)
|
query = query.Offset(*options.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute query - create a slice of pointers to the model type
|
// Execute query - modelPtr was already created earlier
|
||||||
model := GetModel(ctx)
|
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||||
resultSlice := reflect.New(reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))).Interface()
|
|
||||||
if err := query.Scan(ctx, resultSlice); err != nil {
|
|
||||||
logger.Error("Error executing query: %v", err)
|
logger.Error("Error executing query: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||||
return
|
return
|
||||||
@@ -248,10 +333,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
Offset: offset,
|
Offset: offset,
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendFormattedResponse(w, resultSlice, metadata, options)
|
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleCreate", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
@@ -322,6 +414,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleUpdate", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
@@ -369,6 +468,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
|
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
|
||||||
|
// Capture panics and return error response
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
h.handlePanic(w, "handleDelete", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
@@ -396,80 +502,178 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}, nil)
|
}, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
||||||
|
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
|
||||||
|
// Check if column already has a table/schema prefix (contains a dot)
|
||||||
|
if strings.Contains(columnName, ".") {
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no table name provided, return column as-is
|
||||||
|
if fullTableName == "" {
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract just the table name from "schema.table" format
|
||||||
|
// Only use the table name part, not the schema
|
||||||
|
tableOnly := fullTableName
|
||||||
|
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||||
|
tableOnly = fullTableName[idx+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return column qualified with just the table name
|
||||||
|
return fmt.Sprintf("%s.%s", tableOnly, columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption, tableName string, needsCast bool, logicOp string) common.SelectQuery {
|
||||||
|
// Qualify the column name with table name if not already qualified
|
||||||
|
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
|
||||||
|
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||||
|
if needsCast {
|
||||||
|
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to apply the correct Where method based on logic operator
|
||||||
|
applyWhere := func(condition string, args ...interface{}) common.SelectQuery {
|
||||||
|
if logicOp == "OR" {
|
||||||
|
return query.WhereOr(condition, args...)
|
||||||
|
}
|
||||||
|
return query.Where(condition, args...)
|
||||||
|
}
|
||||||
|
|
||||||
switch strings.ToLower(filter.Operator) {
|
switch strings.ToLower(filter.Operator) {
|
||||||
case "eq", "equals":
|
case "eq", "equals":
|
||||||
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s = ?", qualifiedColumn), filter.Value)
|
||||||
case "neq", "not_equals", "ne":
|
case "neq", "not_equals", "ne":
|
||||||
return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s != ?", qualifiedColumn), filter.Value)
|
||||||
case "gt", "greater_than":
|
case "gt", "greater_than":
|
||||||
return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s > ?", qualifiedColumn), filter.Value)
|
||||||
case "gte", "greater_than_equals", "ge":
|
case "gte", "greater_than_equals", "ge":
|
||||||
return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s >= ?", qualifiedColumn), filter.Value)
|
||||||
case "lt", "less_than":
|
case "lt", "less_than":
|
||||||
return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s < ?", qualifiedColumn), filter.Value)
|
||||||
case "lte", "less_than_equals", "le":
|
case "lte", "less_than_equals", "le":
|
||||||
return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s <= ?", qualifiedColumn), filter.Value)
|
||||||
case "like":
|
case "like":
|
||||||
return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s LIKE ?", qualifiedColumn), filter.Value)
|
||||||
case "ilike":
|
case "ilike":
|
||||||
// Use ILIKE for case-insensitive search (PostgreSQL)
|
// Use ILIKE for case-insensitive search (PostgreSQL)
|
||||||
// For other databases, cast to citext or use LOWER()
|
// Column is already cast to TEXT if needed
|
||||||
return query.Where(fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
||||||
case "in":
|
case "in":
|
||||||
return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
|
||||||
case "between":
|
case "between":
|
||||||
// Handle between operator - exclusive (> val1 AND < val2)
|
// Handle between operator - exclusive (> val1 AND < val2)
|
||||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1])
|
return applyWhere(fmt.Sprintf("%s > ? AND %s < ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
|
||||||
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||||
return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1])
|
return applyWhere(fmt.Sprintf("%s > ? AND %s < ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
|
||||||
}
|
}
|
||||||
logger.Warn("Invalid BETWEEN filter value format")
|
logger.Warn("Invalid BETWEEN filter value format")
|
||||||
return query
|
return query
|
||||||
case "between_inclusive":
|
case "between_inclusive":
|
||||||
// Handle between inclusive operator - inclusive (>= val1 AND <= val2)
|
// Handle between inclusive operator - inclusive (>= val1 AND <= val2)
|
||||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1])
|
return applyWhere(fmt.Sprintf("%s >= ? AND %s <= ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
|
||||||
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||||
return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1])
|
return applyWhere(fmt.Sprintf("%s >= ? AND %s <= ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
|
||||||
}
|
}
|
||||||
logger.Warn("Invalid BETWEEN INCLUSIVE filter value format")
|
logger.Warn("Invalid BETWEEN INCLUSIVE filter value format")
|
||||||
return query
|
return query
|
||||||
case "is_null", "isnull":
|
case "is_null", "isnull":
|
||||||
// Check for NULL values
|
// Check for NULL values - don't use cast for NULL checks
|
||||||
return query.Where(fmt.Sprintf("(%s IS NULL OR %s = '')", filter.Column, filter.Column))
|
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
return applyWhere(fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName))
|
||||||
case "is_not_null", "isnotnull":
|
case "is_not_null", "isnotnull":
|
||||||
// Check for NOT NULL values
|
// Check for NOT NULL values - don't use cast for NULL checks
|
||||||
return query.Where(fmt.Sprintf("(%s IS NOT NULL AND %s != '')", filter.Column, filter.Column))
|
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
return applyWhere(fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName))
|
||||||
default:
|
default:
|
||||||
logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator)
|
logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator)
|
||||||
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
|
return applyWhere(fmt.Sprintf("%s = ?", qualifiedColumn), filter.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
// Check if model implements TableNameProvider
|
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||||
tableName := provider.TableName()
|
return fullTableName[:idx], fullTableName[idx+1:]
|
||||||
if tableName != "" {
|
}
|
||||||
return tableName
|
return "", fullTableName
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSchemaAndTable returns the schema and table name separately
|
||||||
|
// It checks SchemaProvider and TableNameProvider interfaces and handles cases where
|
||||||
|
// the table name may already include the schema (e.g., "public.users")
|
||||||
|
//
|
||||||
|
// Priority order:
|
||||||
|
// 1. If TableName() contains a schema (e.g., "myschema.mytable"), that schema takes precedence
|
||||||
|
// 2. If model implements SchemaProvider, use that schema
|
||||||
|
// 3. Otherwise, use the defaultSchema parameter
|
||||||
|
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||||
|
// First check if model provides a table name
|
||||||
|
// We check this FIRST because the table name might already contain the schema
|
||||||
|
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
tableName := tableProvider.TableName()
|
||||||
|
|
||||||
|
// IMPORTANT: Check if the table name already contains a schema (e.g., "schema.table")
|
||||||
|
// This is common when models need to specify a different schema than the default
|
||||||
|
if tableSchema, tableOnly := h.parseTableName(tableName); tableSchema != "" {
|
||||||
|
// Table name includes schema - use it and ignore any other schema providers
|
||||||
|
logger.Debug("TableName() includes schema: %s.%s", tableSchema, tableOnly)
|
||||||
|
return tableSchema, tableOnly
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Table name is just the table name without schema
|
||||||
|
// Now determine which schema to use
|
||||||
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
|
schema = schemaProvider.SchemaName()
|
||||||
|
} else {
|
||||||
|
schema = defaultSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema, tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to schema.entity
|
// No TableNameProvider, so check for schema and use entity as table name
|
||||||
if schema != "" {
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
return fmt.Sprintf("%s.%s", schema, entity)
|
schema = schemaProvider.SchemaName()
|
||||||
|
} else {
|
||||||
|
schema = defaultSchema
|
||||||
}
|
}
|
||||||
return entity
|
|
||||||
|
// Default to entity name as table
|
||||||
|
return schema, entity
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableName returns the full table name including schema (schema.table)
|
||||||
|
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||||
|
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||||
|
if schemaName != "" {
|
||||||
|
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||||
|
}
|
||||||
|
return tableName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity)
|
||||||
|
return &common.TableMetadata{
|
||||||
|
Schema: schema,
|
||||||
|
Table: h.getTableName(schema, entity, model),
|
||||||
|
Columns: []common.Column{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
|
||||||
metadata := &common.TableMetadata{
|
metadata := &common.TableMetadata{
|
||||||
@@ -555,7 +759,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
if options.CleanJSON {
|
if options.CleanJSON {
|
||||||
data = h.cleanJSON(data)
|
data = h.cleanJSON(data)
|
||||||
}
|
}
|
||||||
|
w.SetHeader("Content-Type", "application/json")
|
||||||
// Format response based on response format option
|
// Format response based on response format option
|
||||||
switch options.ResponseFormat {
|
switch options.ResponseFormat {
|
||||||
case "simple":
|
case "simple":
|
||||||
@@ -610,3 +814,41 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
|
|||||||
w.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
w.WriteJSON(response)
|
w.WriteJSON(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
||||||
|
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
||||||
|
filtered := options
|
||||||
|
|
||||||
|
// Filter base RequestOptions
|
||||||
|
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
||||||
|
|
||||||
|
// Filter SearchColumns
|
||||||
|
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
||||||
|
|
||||||
|
// Filter AdvancedSQL column keys
|
||||||
|
filteredAdvSQL := make(map[string]string)
|
||||||
|
for colName, sqlExpr := range options.AdvancedSQL {
|
||||||
|
if validator.IsValidColumn(colName) {
|
||||||
|
filteredAdvSQL[colName] = sqlExpr
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in advanced SQL removed: %s", colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered.AdvancedSQL = filteredAdvSQL
|
||||||
|
|
||||||
|
// ComputedQL columns are allowed to be any name since they're computed
|
||||||
|
// No filtering needed for ComputedQL keys
|
||||||
|
filtered.ComputedQL = options.ComputedQL
|
||||||
|
|
||||||
|
// Filter Expand columns
|
||||||
|
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
||||||
|
for _, expand := range options.Expand {
|
||||||
|
filteredExpand := expand
|
||||||
|
// Don't validate relation name, only columns
|
||||||
|
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
|
||||||
|
filteredExpands = append(filteredExpands, filteredExpand)
|
||||||
|
}
|
||||||
|
filtered.Expand = filteredExpands
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -19,21 +20,21 @@ type ExtendedRequestOptions struct {
|
|||||||
CleanJSON bool
|
CleanJSON bool
|
||||||
|
|
||||||
// Advanced filtering
|
// Advanced filtering
|
||||||
SearchColumns []string
|
SearchColumns []string
|
||||||
CustomSQLWhere string
|
CustomSQLWhere string
|
||||||
CustomSQLOr string
|
CustomSQLOr string
|
||||||
|
|
||||||
// Joins
|
// Joins
|
||||||
Expand []ExpandOption
|
Expand []ExpandOption
|
||||||
|
|
||||||
// Advanced features
|
// Advanced features
|
||||||
AdvancedSQL map[string]string // Column -> SQL expression
|
AdvancedSQL map[string]string // Column -> SQL expression
|
||||||
ComputedQL map[string]string // Column -> CQL expression
|
ComputedQL map[string]string // Column -> CQL expression
|
||||||
Distinct bool
|
Distinct bool
|
||||||
SkipCount bool
|
SkipCount bool
|
||||||
SkipCache bool
|
SkipCache bool
|
||||||
FetchRowNumber *string
|
FetchRowNumber *string
|
||||||
PKRow *string
|
PKRow *string
|
||||||
|
|
||||||
// Response format
|
// Response format
|
||||||
ResponseFormat string // "simple", "detail", "syncfusion"
|
ResponseFormat string // "simple", "detail", "syncfusion"
|
||||||
@@ -42,42 +43,58 @@ type ExtendedRequestOptions struct {
|
|||||||
AtomicTransaction bool
|
AtomicTransaction bool
|
||||||
|
|
||||||
// Cursor pagination
|
// Cursor pagination
|
||||||
CursorForward string
|
CursorForward string
|
||||||
CursorBackward string
|
CursorBackward string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpandOption represents a relation expansion configuration
|
// ExpandOption represents a relation expansion configuration
|
||||||
type ExpandOption struct {
|
type ExpandOption struct {
|
||||||
Relation string
|
Relation string
|
||||||
Columns []string
|
Columns []string
|
||||||
Where string
|
Where string
|
||||||
Sort string
|
Sort string
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeHeaderValue decodes base64 encoded header values
|
// decodeHeaderValue decodes base64 encoded header values
|
||||||
// Supports ZIP_ and __ prefixes for base64 encoding
|
// Supports ZIP_ and __ prefixes for base64 encoding
|
||||||
func decodeHeaderValue(value string) string {
|
func decodeHeaderValue(value string) string {
|
||||||
// Check for ZIP_ prefix
|
str, _ := DecodeParam(value)
|
||||||
if strings.HasPrefix(value, "ZIP_") {
|
return str
|
||||||
decoded, err := base64.StdEncoding.DecodeString(value[4:])
|
}
|
||||||
if err == nil {
|
|
||||||
return string(decoded)
|
// DecodeParam - Decodes parameter string and returns unencoded string
|
||||||
|
func DecodeParam(pStr string) (string, error) {
|
||||||
|
var code string = pStr
|
||||||
|
if strings.HasPrefix(pStr, "ZIP_") {
|
||||||
|
code = strings.ReplaceAll(pStr, "ZIP_", "")
|
||||||
|
code = strings.ReplaceAll(code, "\n", "")
|
||||||
|
code = strings.ReplaceAll(code, "\r", "")
|
||||||
|
code = strings.ReplaceAll(code, " ", "")
|
||||||
|
strDat, err := base64.StdEncoding.DecodeString(code)
|
||||||
|
if err != nil {
|
||||||
|
return code, fmt.Errorf("failed to read parameter base64: %v", err)
|
||||||
|
} else {
|
||||||
|
code = string(strDat)
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(pStr, "__") {
|
||||||
|
code = strings.ReplaceAll(pStr, "__", "")
|
||||||
|
code = strings.ReplaceAll(code, "\n", "")
|
||||||
|
code = strings.ReplaceAll(code, "\r", "")
|
||||||
|
code = strings.ReplaceAll(code, " ", "")
|
||||||
|
|
||||||
|
strDat, err := base64.StdEncoding.DecodeString(code)
|
||||||
|
if err != nil {
|
||||||
|
return code, fmt.Errorf("failed to read parameter base64: %v", err)
|
||||||
|
} else {
|
||||||
|
code = string(strDat)
|
||||||
}
|
}
|
||||||
logger.Warn("Failed to decode ZIP_ prefixed value: %v", err)
|
|
||||||
return value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for __ prefix
|
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
|
||||||
if strings.HasPrefix(value, "__") {
|
code, _ = DecodeParam(code)
|
||||||
decoded, err := base64.StdEncoding.DecodeString(value[2:])
|
|
||||||
if err == nil {
|
|
||||||
return string(decoded)
|
|
||||||
}
|
|
||||||
logger.Warn("Failed to decode __ prefixed value: %v", err)
|
|
||||||
return value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
return code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseOptionsFromHeaders parses all request options from HTTP headers
|
// parseOptionsFromHeaders parses all request options from HTTP headers
|
||||||
@@ -85,12 +102,13 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
|||||||
options := ExtendedRequestOptions{
|
options := ExtendedRequestOptions{
|
||||||
RequestOptions: common.RequestOptions{
|
RequestOptions: common.RequestOptions{
|
||||||
Filters: make([]common.FilterOption, 0),
|
Filters: make([]common.FilterOption, 0),
|
||||||
Sort: make([]common.SortOption, 0),
|
Sort: make([]common.SortOption, 0),
|
||||||
Preload: make([]common.PreloadOption, 0),
|
Preload: make([]common.PreloadOption, 0),
|
||||||
},
|
},
|
||||||
AdvancedSQL: make(map[string]string),
|
AdvancedSQL: make(map[string]string),
|
||||||
ComputedQL: make(map[string]string),
|
ComputedQL: make(map[string]string),
|
||||||
Expand: make([]ExpandOption, 0),
|
Expand: make([]ExpandOption, 0),
|
||||||
|
ResponseFormat: "simple", // Default response format
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all headers
|
// Get all headers
|
||||||
@@ -198,6 +216,9 @@ func (h *Handler) parseSelectFields(options *ExtendedRequestOptions, value strin
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
options.Columns = h.parseCommaSeparated(value)
|
options.Columns = h.parseCommaSeparated(value)
|
||||||
|
if len(options.Columns) > 1 {
|
||||||
|
options.CleanJSON = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseNotSelectFields parses x-not-select-fields header
|
// parseNotSelectFields parses x-not-select-fields header
|
||||||
@@ -206,15 +227,19 @@ func (h *Handler) parseNotSelectFields(options *ExtendedRequestOptions, value st
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
options.OmitColumns = h.parseCommaSeparated(value)
|
options.OmitColumns = h.parseCommaSeparated(value)
|
||||||
|
if len(options.OmitColumns) > 1 {
|
||||||
|
options.CleanJSON = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseFieldFilter parses x-fieldfilter-{colname} header (exact match)
|
// parseFieldFilter parses x-fieldfilter-{colname} header (exact match)
|
||||||
func (h *Handler) parseFieldFilter(options *ExtendedRequestOptions, headerKey, value string) {
|
func (h *Handler) parseFieldFilter(options *ExtendedRequestOptions, headerKey, value string) {
|
||||||
colName := strings.TrimPrefix(headerKey, "x-fieldfilter-")
|
colName := strings.TrimPrefix(headerKey, "x-fieldfilter-")
|
||||||
options.Filters = append(options.Filters, common.FilterOption{
|
options.Filters = append(options.Filters, common.FilterOption{
|
||||||
Column: colName,
|
Column: colName,
|
||||||
Operator: "eq",
|
Operator: "eq",
|
||||||
Value: value,
|
Value: value,
|
||||||
|
LogicOperator: "AND", // Default to AND
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,9 +248,10 @@ func (h *Handler) parseSearchFilter(options *ExtendedRequestOptions, headerKey,
|
|||||||
colName := strings.TrimPrefix(headerKey, "x-searchfilter-")
|
colName := strings.TrimPrefix(headerKey, "x-searchfilter-")
|
||||||
// Use ILIKE for fuzzy search
|
// Use ILIKE for fuzzy search
|
||||||
options.Filters = append(options.Filters, common.FilterOption{
|
options.Filters = append(options.Filters, common.FilterOption{
|
||||||
Column: colName,
|
Column: colName,
|
||||||
Operator: "ilike",
|
Operator: "ilike",
|
||||||
Value: "%" + value + "%",
|
Value: "%" + value + "%",
|
||||||
|
LogicOperator: "AND", // Default to AND
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,70 +280,68 @@ func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu
|
|||||||
colName := parts[1]
|
colName := parts[1]
|
||||||
|
|
||||||
// Map operator names to filter operators
|
// Map operator names to filter operators
|
||||||
filterOp := h.mapSearchOperator(operator, value)
|
filterOp := h.mapSearchOperator(colName, operator, value)
|
||||||
|
|
||||||
|
// Set the logic operator (AND or OR)
|
||||||
|
filterOp.LogicOperator = logicOp
|
||||||
|
|
||||||
options.Filters = append(options.Filters, filterOp)
|
options.Filters = append(options.Filters, filterOp)
|
||||||
|
|
||||||
// Note: OR logic would need special handling in query builder
|
logger.Debug("%s logic filter: %s %s %v", logicOp, colName, filterOp.Operator, filterOp.Value)
|
||||||
// For now, we'll add a comment to indicate OR logic
|
|
||||||
if logicOp == "OR" {
|
|
||||||
// TODO: Implement OR logic in query builder
|
|
||||||
logger.Debug("OR logic filter: %s %s %v", colName, filterOp.Operator, filterOp.Value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapSearchOperator maps search operator names to filter operators
|
// mapSearchOperator maps search operator names to filter operators
|
||||||
func (h *Handler) mapSearchOperator(operator, value string) common.FilterOption {
|
func (h *Handler) mapSearchOperator(colName, operator, value string) common.FilterOption {
|
||||||
operator = strings.ToLower(operator)
|
operator = strings.ToLower(operator)
|
||||||
|
|
||||||
switch operator {
|
switch operator {
|
||||||
case "contains":
|
case "contains", "contain", "like":
|
||||||
return common.FilterOption{Operator: "ilike", Value: "%" + value + "%"}
|
return common.FilterOption{Column: colName, Operator: "ilike", Value: "%" + value + "%"}
|
||||||
case "beginswith", "startswith":
|
case "beginswith", "startswith":
|
||||||
return common.FilterOption{Operator: "ilike", Value: value + "%"}
|
return common.FilterOption{Column: colName, Operator: "ilike", Value: value + "%"}
|
||||||
case "endswith":
|
case "endswith":
|
||||||
return common.FilterOption{Operator: "ilike", Value: "%" + value}
|
return common.FilterOption{Column: colName, Operator: "ilike", Value: "%" + value}
|
||||||
case "equals", "eq":
|
case "equals", "eq", "=":
|
||||||
return common.FilterOption{Operator: "eq", Value: value}
|
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
|
||||||
case "notequals", "neq", "ne":
|
case "notequals", "neq", "ne", "!=", "<>":
|
||||||
return common.FilterOption{Operator: "neq", Value: value}
|
return common.FilterOption{Column: colName, Operator: "neq", Value: value}
|
||||||
case "greaterthan", "gt":
|
case "greaterthan", "gt", ">":
|
||||||
return common.FilterOption{Operator: "gt", Value: value}
|
return common.FilterOption{Column: colName, Operator: "gt", Value: value}
|
||||||
case "lessthan", "lt":
|
case "lessthan", "lt", "<":
|
||||||
return common.FilterOption{Operator: "lt", Value: value}
|
return common.FilterOption{Column: colName, Operator: "lt", Value: value}
|
||||||
case "greaterthanorequal", "gte", "ge":
|
case "greaterthanorequal", "gte", "ge", ">=":
|
||||||
return common.FilterOption{Operator: "gte", Value: value}
|
return common.FilterOption{Column: colName, Operator: "gte", Value: value}
|
||||||
case "lessthanorequal", "lte", "le":
|
case "lessthanorequal", "lte", "le", "<=":
|
||||||
return common.FilterOption{Operator: "lte", Value: value}
|
return common.FilterOption{Column: colName, Operator: "lte", Value: value}
|
||||||
case "between":
|
case "between":
|
||||||
// Parse between values (format: "value1,value2")
|
// Parse between values (format: "value1,value2")
|
||||||
// Between is exclusive (> value1 AND < value2)
|
// Between is exclusive (> value1 AND < value2)
|
||||||
parts := strings.Split(value, ",")
|
parts := strings.Split(value, ",")
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
return common.FilterOption{Operator: "between", Value: parts}
|
return common.FilterOption{Column: colName, Operator: "between", Value: parts}
|
||||||
}
|
}
|
||||||
return common.FilterOption{Operator: "eq", Value: value}
|
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
|
||||||
case "betweeninclusive":
|
case "betweeninclusive":
|
||||||
// Parse between values (format: "value1,value2")
|
// Parse between values (format: "value1,value2")
|
||||||
// Between inclusive is >= value1 AND <= value2
|
// Between inclusive is >= value1 AND <= value2
|
||||||
parts := strings.Split(value, ",")
|
parts := strings.Split(value, ",")
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
return common.FilterOption{Operator: "between_inclusive", Value: parts}
|
return common.FilterOption{Column: colName, Operator: "between_inclusive", Value: parts}
|
||||||
}
|
}
|
||||||
return common.FilterOption{Operator: "eq", Value: value}
|
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
|
||||||
case "in":
|
case "in":
|
||||||
// Parse IN values (format: "value1,value2,value3")
|
// Parse IN values (format: "value1,value2,value3")
|
||||||
values := strings.Split(value, ",")
|
values := strings.Split(value, ",")
|
||||||
return common.FilterOption{Operator: "in", Value: values}
|
return common.FilterOption{Column: colName, Operator: "in", Value: values}
|
||||||
case "empty", "isnull", "null":
|
case "empty", "isnull", "null":
|
||||||
// Check for NULL or empty string
|
// Check for NULL or empty string
|
||||||
return common.FilterOption{Operator: "is_null", Value: nil}
|
return common.FilterOption{Column: colName, Operator: "is_null", Value: nil}
|
||||||
case "notempty", "isnotnull", "notnull":
|
case "notempty", "isnotnull", "notnull":
|
||||||
// Check for NOT NULL
|
// Check for NOT NULL
|
||||||
return common.FilterOption{Operator: "is_not_null", Value: nil}
|
return common.FilterOption{Column: colName, Operator: "is_not_null", Value: nil}
|
||||||
default:
|
default:
|
||||||
logger.Warn("Unknown search operator: %s, defaulting to equals", operator)
|
logger.Warn("Unknown search operator: %s, defaulting to equals", operator)
|
||||||
return common.FilterOption{Operator: "eq", Value: value}
|
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -404,10 +428,16 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
|||||||
} else if strings.HasPrefix(field, "+") {
|
} else if strings.HasPrefix(field, "+") {
|
||||||
direction = "ASC"
|
direction = "ASC"
|
||||||
colName = strings.TrimPrefix(field, "+")
|
colName = strings.TrimPrefix(field, "+")
|
||||||
|
} else if strings.HasSuffix(field, " desc") {
|
||||||
|
direction = "DESC"
|
||||||
|
colName = strings.TrimSuffix(field, "desc")
|
||||||
|
} else if strings.HasSuffix(field, " asc") {
|
||||||
|
direction = "ASC"
|
||||||
|
colName = strings.TrimSuffix(field, "asc")
|
||||||
}
|
}
|
||||||
|
|
||||||
options.Sort = append(options.Sort, common.SortOption{
|
options.Sort = append(options.Sort, common.SortOption{
|
||||||
Column: colName,
|
Column: strings.Trim(colName, " "),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -439,3 +469,235 @@ func (h *Handler) parseJSONHeader(value string) (map[string]interface{}, error)
|
|||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||||
|
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||||
|
if model == nil {
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
// Dereference pointer if needed
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a struct
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field by JSON tag or field name
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Check JSON tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" {
|
||||||
|
// Parse JSON tag (format: "name,omitempty")
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if parts[0] == colName {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check field name (case-insensitive)
|
||||||
|
if strings.EqualFold(field.Name, colName) {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check snake_case conversion
|
||||||
|
snakeCaseName := toSnakeCase(field.Name)
|
||||||
|
if snakeCaseName == colName {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// toSnakeCase converts a string from CamelCase to snake_case
|
||||||
|
func toSnakeCase(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
for i, r := range s {
|
||||||
|
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||||
|
result.WriteRune('_')
|
||||||
|
}
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
return strings.ToLower(result.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNumericType checks if a reflect.Kind is a numeric type
|
||||||
|
func isNumericType(kind reflect.Kind) bool {
|
||||||
|
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
||||||
|
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
||||||
|
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
||||||
|
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// isStringType checks if a reflect.Kind is a string type
|
||||||
|
func isStringType(kind reflect.Kind) bool {
|
||||||
|
return kind == reflect.String
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBoolType checks if a reflect.Kind is a boolean type
|
||||||
|
func isBoolType(kind reflect.Kind) bool {
|
||||||
|
return kind == reflect.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToNumericType converts a string value to the appropriate numeric type
|
||||||
|
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
// Parse as integer
|
||||||
|
bitSize := 64
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int8:
|
||||||
|
bitSize = 8
|
||||||
|
case reflect.Int16:
|
||||||
|
bitSize = 16
|
||||||
|
case reflect.Int32:
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid integer value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the appropriate type
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int:
|
||||||
|
return int(intVal), nil
|
||||||
|
case reflect.Int8:
|
||||||
|
return int8(intVal), nil
|
||||||
|
case reflect.Int16:
|
||||||
|
return int16(intVal), nil
|
||||||
|
case reflect.Int32:
|
||||||
|
return int32(intVal), nil
|
||||||
|
case reflect.Int64:
|
||||||
|
return intVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
// Parse as unsigned integer
|
||||||
|
bitSize := 64
|
||||||
|
switch kind {
|
||||||
|
case reflect.Uint8:
|
||||||
|
bitSize = 8
|
||||||
|
case reflect.Uint16:
|
||||||
|
bitSize = 16
|
||||||
|
case reflect.Uint32:
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the appropriate type
|
||||||
|
switch kind {
|
||||||
|
case reflect.Uint:
|
||||||
|
return uint(uintVal), nil
|
||||||
|
case reflect.Uint8:
|
||||||
|
return uint8(uintVal), nil
|
||||||
|
case reflect.Uint16:
|
||||||
|
return uint16(uintVal), nil
|
||||||
|
case reflect.Uint32:
|
||||||
|
return uint32(uintVal), nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
return uintVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
// Parse as float
|
||||||
|
bitSize := 64
|
||||||
|
if kind == reflect.Float32 {
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
floatVal, err := strconv.ParseFloat(value, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid float value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind == reflect.Float32 {
|
||||||
|
return float32(floatVal), nil
|
||||||
|
}
|
||||||
|
return floatVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNumericValue checks if a string value can be parsed as a number
|
||||||
|
func isNumericValue(value string) bool {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
_, err := strconv.ParseFloat(value, 64)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnCastInfo holds information about whether a column needs casting
|
||||||
|
type ColumnCastInfo struct {
|
||||||
|
NeedsCast bool
|
||||||
|
IsNumericType bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateAndAdjustFilterForColumnType validates and adjusts a filter based on column type
|
||||||
|
// Returns ColumnCastInfo indicating whether the column should be cast to text in SQL
|
||||||
|
func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOption, model interface{}) ColumnCastInfo {
|
||||||
|
if filter == nil || model == nil {
|
||||||
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
colType := h.getColumnTypeFromModel(model, filter.Column)
|
||||||
|
if colType == reflect.Invalid {
|
||||||
|
// Column not found in model, no casting needed
|
||||||
|
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
|
||||||
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the input value is numeric
|
||||||
|
valueIsNumeric := false
|
||||||
|
if strVal, ok := filter.Value.(string); ok {
|
||||||
|
strVal = strings.Trim(strVal, "%")
|
||||||
|
valueIsNumeric = isNumericValue(strVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust based on column type
|
||||||
|
switch {
|
||||||
|
case isNumericType(colType):
|
||||||
|
// Column is numeric
|
||||||
|
if valueIsNumeric {
|
||||||
|
// Value is numeric - try to convert it
|
||||||
|
if strVal, ok := filter.Value.(string); ok {
|
||||||
|
strVal = strings.Trim(strVal, "%")
|
||||||
|
numericVal, err := convertToNumericType(strVal, colType)
|
||||||
|
if err != nil {
|
||||||
|
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
|
||||||
|
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||||
|
}
|
||||||
|
filter.Value = numericVal
|
||||||
|
}
|
||||||
|
// No cast needed - numeric column with numeric value
|
||||||
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: true}
|
||||||
|
} else {
|
||||||
|
// Value is not numeric - cast column to text for comparison
|
||||||
|
logger.Debug("Non-numeric value for numeric column %s, will cast to text", filter.Column)
|
||||||
|
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
case isStringType(colType):
|
||||||
|
// String columns don't need casting
|
||||||
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// For bool, time.Time, and other complex types - cast to text
|
||||||
|
logger.Debug("Complex type column %s, will cast to text", filter.Column)
|
||||||
|
return ColumnCastInfo{NeedsCast: true, IsNumericType: false}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
152
todo.md
Normal file
152
todo.md
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
# ResolveSpec - TODO List
|
||||||
|
|
||||||
|
This document tracks incomplete features and improvements for the ResolveSpec project.
|
||||||
|
|
||||||
|
## Core Features to Implement
|
||||||
|
|
||||||
|
### 1. Column Selection and Filtering for Preloads
|
||||||
|
**Location:** `pkg/resolvespec/handler.go:730`
|
||||||
|
**Status:** Not Implemented
|
||||||
|
**Description:** Currently, preloads are applied without any column selection or filtering. This feature would allow clients to:
|
||||||
|
- Select specific columns for preloaded relationships
|
||||||
|
- Apply filters to preloaded data
|
||||||
|
- Reduce payload size and improve performance
|
||||||
|
|
||||||
|
**Current Limitation:**
|
||||||
|
```go
|
||||||
|
// For now, we'll preload without conditions
|
||||||
|
// TODO: Implement column selection and filtering for preloads
|
||||||
|
// This requires a more sophisticated approach with callbacks or query builders
|
||||||
|
query = query.Preload(relationFieldName)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required Implementation:**
|
||||||
|
- Add support for column selection in preloaded relationships
|
||||||
|
- Implement filtering conditions for preloaded data
|
||||||
|
- Design a callback or query builder approach that works across different ORMs
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. Recursive JSON Cleaning
|
||||||
|
**Location:** `pkg/restheadspec/handler.go:796`
|
||||||
|
**Status:** Partially Implemented (Simplified)
|
||||||
|
**Description:** The current `cleanJSON` function returns data as-is without recursively removing null and empty fields from nested structures.
|
||||||
|
|
||||||
|
**Current Limitation:**
|
||||||
|
```go
|
||||||
|
// This is a simplified implementation
|
||||||
|
// A full implementation would recursively clean nested structures
|
||||||
|
// For now, we'll return the data as-is
|
||||||
|
// TODO: Implement recursive cleaning
|
||||||
|
return data
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required Implementation:**
|
||||||
|
- Recursively traverse nested structures (maps, slices, structs)
|
||||||
|
- Remove null values
|
||||||
|
- Remove empty objects and arrays
|
||||||
|
- Handle edge cases (circular references, pointers, etc.)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Custom SQL Join Support
|
||||||
|
**Location:** `pkg/restheadspec/headers.go:159`
|
||||||
|
**Status:** Not Implemented
|
||||||
|
**Description:** Support for custom SQL joins via the `X-Custom-SQL-Join` header is currently logged but not executed.
|
||||||
|
|
||||||
|
**Current Limitation:**
|
||||||
|
```go
|
||||||
|
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
|
||||||
|
// TODO: Implement custom SQL join
|
||||||
|
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required Implementation:**
|
||||||
|
- Parse custom SQL join expressions from headers
|
||||||
|
- Apply joins to the query builder
|
||||||
|
- Ensure security (SQL injection prevention)
|
||||||
|
- Support for different join types (INNER, LEFT, RIGHT, FULL)
|
||||||
|
- Works across different database adapters (GORM, Bun)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. Proper Condition Handling for Bun Preloads
|
||||||
|
**Location:** `pkg/common/adapters/database/bun.go:202`
|
||||||
|
**Status:** Partially Implemented
|
||||||
|
**Description:** The Bun adapter's `Preload` method currently ignores conditions passed to it.
|
||||||
|
|
||||||
|
**Current Limitation:**
|
||||||
|
```go
|
||||||
|
func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
||||||
|
// Bun uses Relation() method for preloading
|
||||||
|
// For now, we'll just pass the relation name without conditions
|
||||||
|
// TODO: Implement proper condition handling for Bun
|
||||||
|
b.query = b.query.Relation(relation)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Required Implementation:**
|
||||||
|
- Properly handle condition parameters in Bun's Relation() method
|
||||||
|
- Support filtering on preloaded relationships
|
||||||
|
- Ensure compatibility with GORM's condition syntax where possible
|
||||||
|
- Test with various condition types
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Code Quality Improvements
|
||||||
|
|
||||||
|
### 5. Modernize Go Type Declarations
|
||||||
|
**Location:** `pkg/common/types.go:5, 42, 64, 79`
|
||||||
|
**Status:** Pending
|
||||||
|
**Priority:** Low
|
||||||
|
**Description:** Replace legacy `interface{}` with modern `any` type alias (Go 1.18+).
|
||||||
|
|
||||||
|
**Affected Lines:**
|
||||||
|
- Line 5: Function parameter or return type
|
||||||
|
- Line 42: Function parameter or return type
|
||||||
|
- Line 64: Function parameter or return type
|
||||||
|
- Line 79: Function parameter or return type
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
- More modern and idiomatic Go code
|
||||||
|
- Better readability
|
||||||
|
- Aligns with current Go best practices
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Additional Considerations
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- Ensure all new features are documented in README.md
|
||||||
|
- Update examples to showcase new functionality
|
||||||
|
- Add migration notes if any breaking changes are introduced
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
- Add unit tests for each new feature
|
||||||
|
- Add integration tests for database adapter compatibility
|
||||||
|
- Ensure backward compatibility is maintained
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
- Profile preload performance with column selection and filtering
|
||||||
|
- Optimize recursive JSON cleaning for large payloads
|
||||||
|
- Benchmark custom SQL join performance
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Priority Ranking
|
||||||
|
|
||||||
|
1. **High Priority**
|
||||||
|
- Column Selection and Filtering for Preloads (#1)
|
||||||
|
- Proper Condition Handling for Bun Preloads (#4)
|
||||||
|
|
||||||
|
2. **Medium Priority**
|
||||||
|
- Custom SQL Join Support (#3)
|
||||||
|
- Recursive JSON Cleaning (#2)
|
||||||
|
|
||||||
|
3. **Low Priority**
|
||||||
|
- Modernize Go Type Declarations (#5)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Last Updated:** 2025-11-07
|
||||||
Reference in New Issue
Block a user