Compare commits

..

9 Commits

Author SHA1 Message Date
Hein
abee5c942f Count Fixes 2025-11-07 13:54:24 +02:00
Hein
2e9a0bd51a Better model pointers 2025-11-07 13:45:08 +02:00
Hein
f518a3c73c Some validation and header decoding 2025-11-07 13:31:48 +02:00
Hein
07c239aaa1 Make sure to enable Clean JSON when select fields given 2025-11-07 11:00:56 +02:00
Hein
1adca4c49b Content Types and Respose fixes for restheadpsec 2025-11-07 10:55:42 +02:00
Hein
eefed23766 COUNT queries were generating incorrect SQL with the table appearing twice 2025-11-07 10:37:53 +02:00
Hein
3b2d05465e Fixed tablename and schema lookups 2025-11-07 10:28:14 +02:00
Hein
e88018543e Reflect safty 2025-11-07 09:47:12 +02:00
Hein
e7e5754a47 Added panic catches 2025-11-07 09:32:37 +02:00
10 changed files with 1275 additions and 102 deletions

138
SCHEMA_TABLE_HANDLING.md Normal file
View File

@@ -0,0 +1,138 @@
# Schema and Table Name Handling
This document explains how the handlers properly separate and handle schema and table names.
## Implementation
Both `resolvespec` and `restheadspec` handlers now properly handle schema and table name separation through the following functions:
- `parseTableName(fullTableName)` - Splits "schema.table" into separate components
- `getSchemaAndTable(defaultSchema, entity, model)` - Returns schema and table separately
- `getTableName(schema, entity, model)` - Returns the full "schema.table" format
## Priority Order
When determining the schema and table name, the following priority is used:
1. **If `TableName()` contains a schema** (e.g., "myschema.mytable"), that schema takes precedence
2. **If model implements `SchemaProvider`**, use that schema
3. **Otherwise**, use the `defaultSchema` parameter from the URL/request
## Scenarios
### Scenario 1: Simple table name, default schema
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "users"
}
```
- Request URL: `/api/public/users`
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
### Scenario 2: Table name includes schema
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "auth.users" // Schema included!
}
```
- Request URL: `/api/public/users` (public is ignored)
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
- **Note**: The schema from `TableName()` takes precedence over the URL schema
### Scenario 3: Using SchemaProvider
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "users"
}
func (User) SchemaName() string {
return "auth"
}
```
- Request URL: `/api/public/users` (public is ignored)
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
### Scenario 4: Table name includes schema AND SchemaProvider
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "core.users" // This wins!
}
func (User) SchemaName() string {
return "auth" // This is ignored
}
```
- Request URL: `/api/public/users`
- Result: `schema="core"`, `table="users"`, `fullName="core.users"`
- **Note**: Schema from `TableName()` takes highest precedence
### Scenario 5: No providers at all
```go
type User struct {
ID string
Name string
}
// No TableName() or SchemaName()
```
- Request URL: `/api/public/users`
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
- Uses URL schema and entity name
## Key Features
1. **Automatic detection**: The code automatically detects if `TableName()` includes a schema by checking for "."
2. **Backward compatible**: Existing code continues to work
3. **Flexible**: Supports multiple ways to specify schema and table
4. **Debug logging**: Logs when schema is detected in `TableName()` for debugging
## Code Locations
### Handlers
- `/pkg/resolvespec/handler.go:472-531`
- `/pkg/restheadspec/handler.go:534-593`
### Database Adapters
- `/pkg/common/adapters/database/utils.go` - Shared `parseTableName()` function
- `/pkg/common/adapters/database/bun.go` - Bun adapter with separated schema/table
- `/pkg/common/adapters/database/gorm.go` - GORM adapter with separated schema/table
## Adapter Implementation
Both Bun and GORM adapters now properly separate schema and table name:
```go
// BunSelectQuery/GormSelectQuery now have separated fields:
type BunSelectQuery struct {
query *bun.SelectQuery
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
}
```
When `Model()` or `Table()` is called:
1. The full table name (which may include schema) is parsed
2. Schema and table name are stored separately
3. When building joins, the already-separated table name is used directly
This ensures consistent handling of schema-qualified table names throughout the codebase.

View File

@@ -78,7 +78,8 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
// BunSelectQuery implements SelectQuery for Bun
type BunSelectQuery struct {
query *bun.SelectQuery
tableName string
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
}
@@ -87,7 +88,9 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
b.tableName = provider.TableName()
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
b.schema, b.tableName = parseTableName(fullTableName)
}
return b
@@ -95,7 +98,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
b.query = b.query.Table(table)
b.tableName = table
// Check if the table name contains schema (e.g., "schema.table")
b.schema, b.tableName = parseTableName(table)
return b
}
@@ -128,13 +132,9 @@ func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQu
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && b.tableName != "" {
prefix = b.tableName
// Extract just the table name if it has schema
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// If prefix is provided, add it as an alias in the join
@@ -169,12 +169,9 @@ func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.Sele
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && b.tableName != "" {
prefix = b.tableName
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// Construct LEFT JOIN with prefix
@@ -231,7 +228,10 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
}
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
count, err := b.query.Count(ctx)
// Use ColumnExpr with Scan instead of Count() to avoid requiring a model
// This works with just Table() set and avoids "Model(nil)" error
var count int
err := b.query.ColumnExpr("count(*)").Scan(ctx, &count)
return count, err
}

View File

@@ -70,7 +70,8 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
// GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct {
db *gorm.DB
tableName string
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
}
@@ -79,7 +80,9 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
g.tableName = provider.TableName()
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(fullTableName)
}
return g
@@ -87,7 +90,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table)
g.tableName = table
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(table)
return g
}
@@ -120,13 +124,9 @@ func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQ
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && g.tableName != "" {
prefix = g.tableName
// Extract just the table name if it has schema
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// If prefix is provided, add it as an alias in the join
@@ -161,12 +161,9 @@ func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.Sel
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && g.tableName != "" {
prefix = g.tableName
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// Construct LEFT JOIN with prefix

View File

@@ -0,0 +1,13 @@
package database
import "strings"
// parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users")
// "users" -> ("", "users")
func parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
}
return "", fullTableName
}

272
pkg/common/validation.go Normal file
View File

@@ -0,0 +1,272 @@
package common
import (
"fmt"
"reflect"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
)
// ColumnValidator validates column names against a model's fields
type ColumnValidator struct {
validColumns map[string]bool
model interface{}
}
// NewColumnValidator creates a new column validator for a given model
func NewColumnValidator(model interface{}) *ColumnValidator {
validator := &ColumnValidator{
validColumns: make(map[string]bool),
model: model,
}
validator.buildValidColumns()
return validator
}
// buildValidColumns extracts all valid column names from the model using reflection
func (v *ColumnValidator) buildValidColumns() {
modelType := reflect.TypeOf(v.model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return
}
// Extract column names from struct fields
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if !field.IsExported() {
continue
}
// Get column name from bun, gorm, or json tag
columnName := v.getColumnName(field)
if columnName != "" && columnName != "-" {
v.validColumns[strings.ToLower(columnName)] = true
}
}
}
// getColumnName extracts the column name from a struct field's tags
// Supports both Bun and GORM tags
func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
// First check Bun tag for column name
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
parts := strings.Split(bunTag, ",")
// The first part is usually the column name
columnName := strings.TrimSpace(parts[0])
if columnName != "" && columnName != "-" {
return columnName
}
}
// Check GORM tag for column name
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "column:") {
parts := strings.Split(gormTag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "column:") {
return strings.TrimPrefix(part, "column:")
}
}
}
// Fall back to JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract just the name part (before any comma)
jsonName := strings.Split(jsonTag, ",")[0]
return jsonName
}
// Fall back to field name in lowercase (snake_case conversion would be better)
return strings.ToLower(field.Name)
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
func (v *ColumnValidator) ValidateColumn(column string) error {
// Allow empty columns
if column == "" {
return nil
}
// Allow columns prefixed with "cql" (case insensitive) for computed columns
if strings.HasPrefix(strings.ToLower(column), "cql") {
return nil
}
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
}
return nil
}
// IsValidColumn checks if a column is valid
// Returns true if valid, false if invalid
func (v *ColumnValidator) IsValidColumn(column string) bool {
return v.ValidateColumn(column) == nil
}
// FilterValidColumns filters a list of columns, returning only valid ones
// Logs warnings for any invalid columns
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
if len(columns) == 0 {
return columns
}
validColumns := make([]string, 0, len(columns))
for _, col := range columns {
if v.IsValidColumn(col) {
validColumns = append(validColumns, col)
} else {
logger.Warn("Invalid column '%s' filtered out: column does not exist in model", col)
}
}
return validColumns
}
// ValidateColumns validates multiple column names
// Returns error with details about all invalid columns
func (v *ColumnValidator) ValidateColumns(columns []string) error {
var invalidColumns []string
for _, column := range columns {
if err := v.ValidateColumn(column); err != nil {
invalidColumns = append(invalidColumns, column)
}
}
if len(invalidColumns) > 0 {
return fmt.Errorf("invalid columns: %s", strings.Join(invalidColumns, ", "))
}
return nil
}
// ValidateRequestOptions validates all column references in RequestOptions
func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
// Validate Columns
if err := v.ValidateColumns(options.Columns); err != nil {
return fmt.Errorf("in select columns: %w", err)
}
// Validate OmitColumns
if err := v.ValidateColumns(options.OmitColumns); err != nil {
return fmt.Errorf("in omit columns: %w", err)
}
// Validate Filter columns
for _, filter := range options.Filters {
if err := v.ValidateColumn(filter.Column); err != nil {
return fmt.Errorf("in filter: %w", err)
}
}
// Validate Sort columns
for _, sort := range options.Sort {
if err := v.ValidateColumn(sort.Column); err != nil {
return fmt.Errorf("in sort: %w", err)
}
}
// Validate Preload columns (if specified)
for _, preload := range options.Preload {
// Note: We don't validate the relation name itself, as it's a relationship
// Only validate columns if specified for the preload
if err := v.ValidateColumns(preload.Columns); err != nil {
return fmt.Errorf("in preload '%s' columns: %w", preload.Relation, err)
}
if err := v.ValidateColumns(preload.OmitColumns); err != nil {
return fmt.Errorf("in preload '%s' omit columns: %w", preload.Relation, err)
}
// Validate filter columns in preload
for _, filter := range preload.Filters {
if err := v.ValidateColumn(filter.Column); err != nil {
return fmt.Errorf("in preload '%s' filter: %w", preload.Relation, err)
}
}
}
return nil
}
// FilterRequestOptions filters all column references in RequestOptions
// Returns a new RequestOptions with only valid columns, logging warnings for invalid ones
func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOptions {
filtered := options
// Filter Columns
filtered.Columns = v.FilterValidColumns(options.Columns)
// Filter OmitColumns
filtered.OmitColumns = v.FilterValidColumns(options.OmitColumns)
// Filter Filter columns
validFilters := make([]FilterOption, 0, len(options.Filters))
for _, filter := range options.Filters {
if v.IsValidColumn(filter.Column) {
validFilters = append(validFilters, filter)
} else {
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
}
}
filtered.Filters = validFilters
// Filter Sort columns
validSorts := make([]SortOption, 0, len(options.Sort))
for _, sort := range options.Sort {
if v.IsValidColumn(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
}
}
filtered.Sort = validSorts
// Filter Preload columns
validPreloads := make([]PreloadOption, 0, len(options.Preload))
for _, preload := range options.Preload {
filteredPreload := preload
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
// Filter preload filters
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
for _, filter := range preload.Filters {
if v.IsValidColumn(filter.Column) {
validPreloadFilters = append(validPreloadFilters, filter)
} else {
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
}
}
filteredPreload.Filters = validPreloadFilters
validPreloads = append(validPreloads, filteredPreload)
}
filtered.Preload = validPreloads
return filtered
}
// GetValidColumns returns a list of all valid column names for debugging purposes
func (v *ColumnValidator) GetValidColumns() []string {
columns := make([]string, 0, len(v.validColumns))
for col := range v.validColumns {
columns = append(columns, col)
}
return columns
}

View File

@@ -0,0 +1,363 @@
package common
import (
"strings"
"testing"
)
// TestModel represents a sample model for testing
type TestModel struct {
ID int64 `json:"id" gorm:"primaryKey"`
Name string `json:"name" gorm:"column:name"`
Email string `json:"email" bun:"email"`
Age int `json:"age"`
IsActive bool `json:"is_active"`
CreatedAt string `json:"created_at"`
}
func TestNewColumnValidator(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
if validator == nil {
t.Fatal("Expected validator to be created")
}
if len(validator.validColumns) == 0 {
t.Fatal("Expected validator to have valid columns")
}
// Check that expected columns are present
expectedColumns := []string{"id", "name", "email", "age", "is_active", "created_at"}
for _, col := range expectedColumns {
if !validator.validColumns[col] {
t.Errorf("Expected column '%s' to be valid", col)
}
}
}
func TestValidateColumn(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
column string
shouldError bool
}{
{"Valid column - id", "id", false},
{"Valid column - name", "name", false},
{"Valid column - email", "email", false},
{"Valid column - uppercase", "ID", false}, // Case insensitive
{"Invalid column", "invalid_column", true},
{"CQL prefixed - should be valid", "cqlComputedField", false},
{"CQL prefixed uppercase - should be valid", "CQLComputedField", false},
{"Empty column", "", false}, // Empty columns are allowed
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateColumn(tt.column)
if tt.shouldError && err == nil {
t.Errorf("Expected error for column '%s', got nil", tt.column)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
}
})
}
}
func TestValidateColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
columns []string
shouldError bool
}{
{"All valid columns", []string{"id", "name", "email"}, false},
{"One invalid column", []string{"id", "invalid_col", "name"}, true},
{"All invalid columns", []string{"bad1", "bad2"}, true},
{"With CQL prefix", []string{"id", "cqlComputed", "name"}, false},
{"Empty list", []string{}, false},
{"Nil list", nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateColumns(tt.columns)
if tt.shouldError && err == nil {
t.Errorf("Expected error for columns %v, got nil", tt.columns)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for columns %v, got: %v", tt.columns, err)
}
})
}
}
func TestValidateRequestOptions(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
options RequestOptions
shouldError bool
errorMsg string
}{
{
name: "Valid options with columns",
options: RequestOptions{
Columns: []string{"id", "name"},
Filters: []FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
},
Sort: []SortOption{
{Column: "id", Direction: "ASC"},
},
},
shouldError: false,
},
{
name: "Invalid column in Columns",
options: RequestOptions{
Columns: []string{"id", "invalid_column"},
},
shouldError: true,
errorMsg: "select columns",
},
{
name: "Invalid column in Filters",
options: RequestOptions{
Filters: []FilterOption{
{Column: "invalid_col", Operator: "eq", Value: "test"},
},
},
shouldError: true,
errorMsg: "filter",
},
{
name: "Invalid column in Sort",
options: RequestOptions{
Sort: []SortOption{
{Column: "invalid_col", Direction: "ASC"},
},
},
shouldError: true,
errorMsg: "sort",
},
{
name: "Valid CQL prefixed columns",
options: RequestOptions{
Columns: []string{"id", "cqlComputedField"},
Filters: []FilterOption{
{Column: "cqlCustomFilter", Operator: "eq", Value: "test"},
},
},
shouldError: false,
},
{
name: "Invalid column in Preload",
options: RequestOptions{
Preload: []PreloadOption{
{
Relation: "SomeRelation",
Columns: []string{"id", "invalid_col"},
},
},
},
shouldError: true,
errorMsg: "preload",
},
{
name: "Valid preload with valid columns",
options: RequestOptions{
Preload: []PreloadOption{
{
Relation: "SomeRelation",
Columns: []string{"id", "name"},
},
},
},
shouldError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateRequestOptions(tt.options)
if tt.shouldError {
if err == nil {
t.Errorf("Expected error, got nil")
} else if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error to contain '%s', got: %v", tt.errorMsg, err)
}
} else {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
}
})
}
}
func TestGetValidColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
columns := validator.GetValidColumns()
if len(columns) == 0 {
t.Error("Expected to get valid columns, got empty list")
}
// Should have at least the columns from TestModel
if len(columns) < 6 {
t.Errorf("Expected at least 6 columns, got %d", len(columns))
}
}
// Test with Bun tags specifically
type BunModel struct {
ID int64 `bun:"id,pk"`
Name string `bun:"name"`
Email string `bun:"user_email"`
}
func TestBunTagSupport(t *testing.T) {
model := BunModel{}
validator := NewColumnValidator(model)
// Test that bun tags are properly recognized
tests := []struct {
column string
shouldError bool
}{
{"id", false},
{"name", false},
{"user_email", false}, // Bun tag specifies this name
{"email", true}, // JSON tag would be "email", but bun tag says "user_email"
}
for _, tt := range tests {
t.Run(tt.column, func(t *testing.T) {
err := validator.ValidateColumn(tt.column)
if tt.shouldError && err == nil {
t.Errorf("Expected error for column '%s'", tt.column)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
}
})
}
}
func TestFilterValidColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
input []string
expectedOutput []string
}{
{
name: "All valid columns",
input: []string{"id", "name", "email"},
expectedOutput: []string{"id", "name", "email"},
},
{
name: "Mix of valid and invalid",
input: []string{"id", "invalid_col", "name", "bad_col", "email"},
expectedOutput: []string{"id", "name", "email"},
},
{
name: "All invalid columns",
input: []string{"bad1", "bad2"},
expectedOutput: []string{},
},
{
name: "With CQL prefix (should pass)",
input: []string{"id", "cqlComputed", "name"},
expectedOutput: []string{"id", "cqlComputed", "name"},
},
{
name: "Empty input",
input: []string{},
expectedOutput: []string{},
},
{
name: "Nil input",
input: nil,
expectedOutput: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.FilterValidColumns(tt.input)
if len(result) != len(tt.expectedOutput) {
t.Errorf("Expected %d columns, got %d", len(tt.expectedOutput), len(result))
}
for i, col := range result {
if col != tt.expectedOutput[i] {
t.Errorf("At index %d: expected %s, got %s", i, tt.expectedOutput[i], col)
}
}
})
}
}
func TestFilterRequestOptions(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
options := RequestOptions{
Columns: []string{"id", "name", "invalid_col"},
OmitColumns: []string{"email", "bad_col"},
Filters: []FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "invalid_col", Operator: "eq", Value: "test"},
},
Sort: []SortOption{
{Column: "id", Direction: "ASC"},
{Column: "bad_col", Direction: "DESC"},
},
}
filtered := validator.FilterRequestOptions(options)
// Check Columns
if len(filtered.Columns) != 2 {
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
}
if filtered.Columns[0] != "id" || filtered.Columns[1] != "name" {
t.Errorf("Expected columns [id, name], got %v", filtered.Columns)
}
// Check OmitColumns
if len(filtered.OmitColumns) != 1 {
t.Errorf("Expected 1 omit column, got %d", len(filtered.OmitColumns))
}
if filtered.OmitColumns[0] != "email" {
t.Errorf("Expected omit column [email], got %v", filtered.OmitColumns)
}
// Check Filters
if len(filtered.Filters) != 1 {
t.Errorf("Expected 1 filter, got %d", len(filtered.Filters))
}
if filtered.Filters[0].Column != "name" {
t.Errorf("Expected filter column 'name', got %s", filtered.Filters[0].Column)
}
// Check Sort
if len(filtered.Sort) != 1 {
t.Errorf("Expected 1 sort, got %d", len(filtered.Sort))
}
if filtered.Sort[0].Column != "id" {
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
}
}

View File

@@ -38,12 +38,28 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
return fmt.Errorf("model cannot be nil")
}
if modelType.Kind() == reflect.Ptr {
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind())
originalType := modelType
// Unwrap pointers, slices, and arrays to check the underlying type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem()
}
// Validate that the underlying type is a struct
if modelType.Kind() != reflect.Struct {
return fmt.Errorf("model must be a struct, got %s", modelType.Kind())
return fmt.Errorf("model must be a struct or pointer to struct, got %s", originalType.String())
}
// If a pointer/slice/array was passed, unwrap to the base struct
if originalType != modelType {
// Create a zero value of the struct type
model = reflect.New(modelType).Elem().Interface()
}
// Additional check: ensure model is not a pointer
finalType := reflect.TypeOf(model)
if finalType.Kind() == reflect.Ptr {
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
}
r.models[name] = model

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"reflect"
"runtime/debug"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
@@ -26,8 +27,22 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
}
}
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
}
// Handle processes API requests through router-agnostic interface
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "Handle", err)
}
}()
ctx := context.Background()
body, err := r.Body()
@@ -58,6 +73,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
return
}
// Validate that the model is a struct type (not a slice or pointer to slice)
modelType := reflect.TypeOf(model)
originalType := modelType
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
@@ -65,6 +100,10 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Add request-scoped data to context
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
req.Options = validator.FilterRequestOptions(req.Options)
switch req.Operation {
case "read":
h.handleRead(ctx, w, id, req.Options)
@@ -82,6 +121,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// HandleGet processes GET requests for metadata
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "HandleGet", err)
}
}()
schema := params["schema"]
entity := params["entity"]
@@ -99,16 +145,39 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
}
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleRead", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
model := GetModel(ctx)
modelPtr := GetModelPtr(ctx)
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
return
}
logger.Info("Reading records from %s.%s", schema, entity)
query := h.db.NewSelect().Model(modelPtr)
query = query.Table(tableName)
// Create the model pointer for Scan() operations
// We don't set it on the query to avoid table duplication in FROM clause
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
modelPtr := reflect.New(sliceType).Interface()
// Use only Table() - model will be provided to Scan() directly
query := h.db.NewSelect().Table(tableName)
// Apply column selection
if len(options.Columns) > 0 {
@@ -160,8 +229,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
var result interface{}
if id != "" {
logger.Debug("Querying single record with ID: %s", id)
// Create a pointer to the struct type for scanning
singleResult := reflect.New(reflect.TypeOf(model)).Interface()
// For single record, create a new pointer to the struct type
singleResult := reflect.New(modelType).Interface()
query = query.Where("id = ?", id)
if err := query.Scan(ctx, singleResult); err != nil {
logger.Error("Error querying record: %v", err)
@@ -171,16 +240,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
result = singleResult
} else {
logger.Debug("Querying multiple records")
// Create a slice of pointers to the model type
sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))
results := reflect.New(sliceType).Interface()
if err := query.Scan(ctx, results); err != nil {
// Use the modelPtr already created and set on the query
if err := query.Scan(ctx, modelPtr); err != nil {
logger.Error("Error querying records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
return
}
result = reflect.ValueOf(results).Elem().Interface()
result = reflect.ValueOf(modelPtr).Elem().Interface()
}
logger.Info("Successfully retrieved records")
@@ -203,6 +269,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleCreate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -279,6 +352,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
}
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleUpdate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -329,6 +409,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
}
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleDelete", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -385,19 +472,86 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
}
}
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
if provider, ok := model.(common.TableNameProvider); ok {
return provider.TableName()
// parseTableName splits a table name that may contain schema into separate schema and table
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
}
return fmt.Sprintf("%s.%s", schema, entity)
return "", fullTableName
}
// getSchemaAndTable returns the schema and table name separately
// It checks SchemaProvider and TableNameProvider interfaces and handles cases where
// the table name may already include the schema (e.g., "public.users")
//
// Priority order:
// 1. If TableName() contains a schema (e.g., "myschema.mytable"), that schema takes precedence
// 2. If model implements SchemaProvider, use that schema
// 3. Otherwise, use the defaultSchema parameter
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
// First check if model provides a table name
// We check this FIRST because the table name might already contain the schema
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
// IMPORTANT: Check if the table name already contains a schema (e.g., "schema.table")
// This is common when models need to specify a different schema than the default
if tableSchema, tableOnly := h.parseTableName(tableName); tableSchema != "" {
// Table name includes schema - use it and ignore any other schema providers
logger.Debug("TableName() includes schema: %s.%s", tableSchema, tableOnly)
return tableSchema, tableOnly
}
// Table name is just the table name without schema
// Now determine which schema to use
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
return schema, tableName
}
// No TableNameProvider, so check for schema and use entity as table name
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
// Default to entity name as table
return schema, entity
}
// getTableName returns the full table name including schema (schema.table)
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %v for %s.%s", modelType, schema, entity)
return &common.TableMetadata{
Schema: schema,
Table: entity,
Columns: make([]common.Column, 0),
Relations: make([]string, 0),
}
}
metadata := &common.TableMetadata{
Schema: schema,
Table: entity,
@@ -541,10 +695,18 @@ type relationshipInfo struct {
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot apply preloads to non-struct type: %v", modelType)
return query
}
for _, preload := range preloads {
logger.Debug("Processing preload for relation: %s", preload.Relation)
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
@@ -568,6 +730,12 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
}
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
// Ensure we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
return nil
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
jsonTag := field.Tag.Get("json")

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"reflect"
"runtime/debug"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
@@ -27,9 +28,23 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
}
}
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
}
// Handle processes API requests through router-agnostic interface
// Options are read from HTTP headers instead of request body
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "Handle", err)
}
}()
ctx := context.Background()
schema := params["schema"]
@@ -52,12 +67,36 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
return
}
// Validate that the model is a struct type (not a slice or pointer to slice)
modelType := reflect.TypeOf(model)
originalType := modelType
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
// Add request-scoped data to context
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
options = filterExtendedOptions(validator, options)
switch method {
case "GET":
if id != "" {
@@ -107,6 +146,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// HandleGet processes GET requests for metadata
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "HandleGet", err)
}
}()
schema := params["schema"]
entity := params["entity"]
@@ -126,15 +172,38 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
// parseOptionsFromHeaders is now implemented in headers.go
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleRead", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
modelPtr := GetModelPtr(ctx)
model := GetModel(ctx)
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
return
}
// Create a pointer to a slice of pointers to the model type for query results
modelPtr := reflect.New(reflect.SliceOf(reflect.PointerTo(modelType))).Interface()
logger.Info("Reading records from %s.%s", schema, entity)
query := h.db.NewSelect().Model(modelPtr)
query = query.Table(tableName)
// Use Table() with the resolved table name
// Model will be provided to Scan() directly to avoid table duplication in FROM clause
query := h.db.NewSelect().Table(tableName)
// Apply column selection
if len(options.Columns) > 0 {
@@ -223,10 +292,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Offset(*options.Offset)
}
// Execute query - create a slice of pointers to the model type
model := GetModel(ctx)
resultSlice := reflect.New(reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))).Interface()
if err := query.Scan(ctx, resultSlice); err != nil {
// Execute query - modelPtr was already created earlier
if err := query.Scan(ctx, modelPtr); err != nil {
logger.Error("Error executing query: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
return
@@ -248,10 +315,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset,
}
h.sendFormattedResponse(w, resultSlice, metadata, options)
h.sendFormattedResponse(w, modelPtr, metadata, options)
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleCreate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -322,6 +396,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
}
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleUpdate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -369,6 +450,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
}
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleDelete", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -448,28 +536,85 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
}
}
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
// Check if model implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
tableName := provider.TableName()
if tableName != "" {
return tableName
// parseTableName splits a table name that may contain schema into separate schema and table
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
}
return "", fullTableName
}
// getSchemaAndTable returns the schema and table name separately
// It checks SchemaProvider and TableNameProvider interfaces and handles cases where
// the table name may already include the schema (e.g., "public.users")
//
// Priority order:
// 1. If TableName() contains a schema (e.g., "myschema.mytable"), that schema takes precedence
// 2. If model implements SchemaProvider, use that schema
// 3. Otherwise, use the defaultSchema parameter
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
// First check if model provides a table name
// We check this FIRST because the table name might already contain the schema
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
// IMPORTANT: Check if the table name already contains a schema (e.g., "schema.table")
// This is common when models need to specify a different schema than the default
if tableSchema, tableOnly := h.parseTableName(tableName); tableSchema != "" {
// Table name includes schema - use it and ignore any other schema providers
logger.Debug("TableName() includes schema: %s.%s", tableSchema, tableOnly)
return tableSchema, tableOnly
}
// Table name is just the table name without schema
// Now determine which schema to use
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
return schema, tableName
}
// Default to schema.entity
if schema != "" {
return fmt.Sprintf("%s.%s", schema, entity)
// No TableNameProvider, so check for schema and use entity as table name
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
return entity
// Default to entity name as table
return schema, entity
}
// getTableName returns the full table name including schema (schema.table)
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity)
return &common.TableMetadata{
Schema: schema,
Table: h.getTableName(schema, entity, model),
Columns: []common.Column{},
}
}
tableName := h.getTableName(schema, entity, model)
metadata := &common.TableMetadata{
@@ -555,7 +700,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
if options.CleanJSON {
data = h.cleanJSON(data)
}
w.SetHeader("Content-Type", "application/json")
// Format response based on response format option
switch options.ResponseFormat {
case "simple":
@@ -610,3 +755,41 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
w.WriteHeader(statusCode)
w.WriteJSON(response)
}
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
filtered := options
// Filter base RequestOptions
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
// Filter SearchColumns
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
// Filter AdvancedSQL column keys
filteredAdvSQL := make(map[string]string)
for colName, sqlExpr := range options.AdvancedSQL {
if validator.IsValidColumn(colName) {
filteredAdvSQL[colName] = sqlExpr
} else {
logger.Warn("Invalid column in advanced SQL removed: %s", colName)
}
}
filtered.AdvancedSQL = filteredAdvSQL
// ComputedQL columns are allowed to be any name since they're computed
// No filtering needed for ComputedQL keys
filtered.ComputedQL = options.ComputedQL
// Filter Expand columns
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
for _, expand := range options.Expand {
filteredExpand := expand
// Don't validate relation name, only columns
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
filteredExpands = append(filteredExpands, filteredExpand)
}
filtered.Expand = filteredExpands
return filtered
}

View File

@@ -19,21 +19,21 @@ type ExtendedRequestOptions struct {
CleanJSON bool
// Advanced filtering
SearchColumns []string
SearchColumns []string
CustomSQLWhere string
CustomSQLOr string
CustomSQLOr string
// Joins
Expand []ExpandOption
// Advanced features
AdvancedSQL map[string]string // Column -> SQL expression
ComputedQL map[string]string // Column -> CQL expression
Distinct bool
SkipCount bool
SkipCache bool
AdvancedSQL map[string]string // Column -> SQL expression
ComputedQL map[string]string // Column -> CQL expression
Distinct bool
SkipCount bool
SkipCache bool
FetchRowNumber *string
PKRow *string
PKRow *string
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
@@ -42,42 +42,58 @@ type ExtendedRequestOptions struct {
AtomicTransaction bool
// Cursor pagination
CursorForward string
CursorForward string
CursorBackward string
}
// ExpandOption represents a relation expansion configuration
type ExpandOption struct {
Relation string
Columns []string
Where string
Sort string
Columns []string
Where string
Sort string
}
// decodeHeaderValue decodes base64 encoded header values
// Supports ZIP_ and __ prefixes for base64 encoding
func decodeHeaderValue(value string) string {
// Check for ZIP_ prefix
if strings.HasPrefix(value, "ZIP_") {
decoded, err := base64.StdEncoding.DecodeString(value[4:])
if err == nil {
return string(decoded)
str, _ := DecodeParam(value)
return str
}
// DecodeParam - Decodes parameter string and returns unencoded string
func DecodeParam(pStr string) (string, error) {
var code string = pStr
if strings.HasPrefix(pStr, "ZIP_") {
code = strings.ReplaceAll(pStr, "ZIP_", "")
code = strings.ReplaceAll(code, "\n", "")
code = strings.ReplaceAll(code, "\r", "")
code = strings.ReplaceAll(code, " ", "")
strDat, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return code, fmt.Errorf("failed to read parameter base64: %v", err)
} else {
code = string(strDat)
}
} else if strings.HasPrefix(pStr, "__") {
code = strings.ReplaceAll(pStr, "__", "")
code = strings.ReplaceAll(code, "\n", "")
code = strings.ReplaceAll(code, "\r", "")
code = strings.ReplaceAll(code, " ", "")
strDat, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return code, fmt.Errorf("failed to read parameter base64: %v", err)
} else {
code = string(strDat)
}
logger.Warn("Failed to decode ZIP_ prefixed value: %v", err)
return value
}
// Check for __ prefix
if strings.HasPrefix(value, "__") {
decoded, err := base64.StdEncoding.DecodeString(value[2:])
if err == nil {
return string(decoded)
}
logger.Warn("Failed to decode __ prefixed value: %v", err)
return value
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
code, _ = DecodeParam(code)
}
return value
return code, nil
}
// parseOptionsFromHeaders parses all request options from HTTP headers
@@ -85,12 +101,13 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Filters: make([]common.FilterOption, 0),
Sort: make([]common.SortOption, 0),
Sort: make([]common.SortOption, 0),
Preload: make([]common.PreloadOption, 0),
},
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
ResponseFormat: "simple", // Default response format
}
// Get all headers
@@ -198,6 +215,9 @@ func (h *Handler) parseSelectFields(options *ExtendedRequestOptions, value strin
return
}
options.Columns = h.parseCommaSeparated(value)
if len(options.Columns) > 1 {
options.CleanJSON = true
}
}
// parseNotSelectFields parses x-not-select-fields header
@@ -206,15 +226,18 @@ func (h *Handler) parseNotSelectFields(options *ExtendedRequestOptions, value st
return
}
options.OmitColumns = h.parseCommaSeparated(value)
if len(options.OmitColumns) > 1 {
options.CleanJSON = true
}
}
// parseFieldFilter parses x-fieldfilter-{colname} header (exact match)
func (h *Handler) parseFieldFilter(options *ExtendedRequestOptions, headerKey, value string) {
colName := strings.TrimPrefix(headerKey, "x-fieldfilter-")
options.Filters = append(options.Filters, common.FilterOption{
Column: colName,
Column: colName,
Operator: "eq",
Value: value,
Value: value,
})
}
@@ -223,9 +246,9 @@ func (h *Handler) parseSearchFilter(options *ExtendedRequestOptions, headerKey,
colName := strings.TrimPrefix(headerKey, "x-searchfilter-")
// Use ILIKE for fuzzy search
options.Filters = append(options.Filters, common.FilterOption{
Column: colName,
Column: colName,
Operator: "ilike",
Value: "%" + value + "%",
Value: "%" + value + "%",
})
}
@@ -407,7 +430,7 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
}
options.Sort = append(options.Sort, common.SortOption{
Column: colName,
Column: colName,
Direction: direction,
})
}