Compare commits

...

17 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Hein | 3eb17666bf | Migration to Bitech | 2025-11-10 11:43:15 +02:00 |
| Hein | c8704c07dd | Added cursor filters and hooks | 2025-11-10 10:22:55 +02:00 |
| Hein | fc82a9bc50 | todo | 2025-11-07 16:30:02 +02:00 |
| Hein | c26ea3cd61 | todo | 2025-11-07 16:12:09 +02:00 |
| Hein | a5d97cc07b | Fixed the filters | 2025-11-07 15:58:24 +02:00 |
| Hein | 0899ba5029 | Pointer Fixes | 2025-11-07 14:22:58 +02:00 |
| Hein | c84dd7dc91 | Lets try the model approach again | 2025-11-07 14:18:15 +02:00 |
| Hein | f1c6b36374 | Bun Adaptor updates | 2025-11-07 14:03:40 +02:00 |
| Hein | abee5c942f | Count Fixes | 2025-11-07 13:54:24 +02:00 |
| Hein | 2e9a0bd51a | Better model pointers | 2025-11-07 13:45:08 +02:00 |
| Hein | f518a3c73c | Some validation and header decoding | 2025-11-07 13:31:48 +02:00 |
| Hein | 07c239aaa1 | Make sure to enable Clean JSON when select fields given | 2025-11-07 11:00:56 +02:00 |
| Hein | 1adca4c49b | Content Types and Respose fixes for restheadpsec | 2025-11-07 10:55:42 +02:00 |
| Hein | eefed23766 | COUNT queries were generating incorrect SQL with the table appearing twice | 2025-11-07 10:37:53 +02:00 |
| Hein | 3b2d05465e | Fixed tablename and schema lookups | 2025-11-07 10:28:14 +02:00 |
| Hein | e88018543e | Reflect safty | 2025-11-07 09:47:12 +02:00 |
| Hein | e7e5754a47 | Added panic catches | 2025-11-07 09:32:37 +02:00 |
29 changed files with 3774 additions and 283 deletions

View File

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2025 Warky Devs Pty Ltd
Copyright (c) 2025
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

390
README.md
View File

@@ -1,9 +1,16 @@
# 📜 ResolveSpec 📜
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It allows for dynamic data querying, relationship preloading, and complex filtering through a clean, URL-based interface.
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
1. **ResolveSpec** - Body-based API with JSON request options
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
**🆕 New in v2.0**: Database-agnostic architecture with support for GORM, Bun, and other ORMs. Router-flexible design works with Gorilla Mux, Gin, Echo, and more.
**🆕 New in v2.1**: RestHeadSpec (HeaderSpec) - Header-based REST API with lifecycle hooks, cursor pagination, and advanced filtering.
![slogan](./generated_slogan.webp)
## Table of Contents
@@ -11,30 +18,46 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
- [Features](#features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [ResolveSpec (Body-Based API)](#resolvespec-body-based-api)
- [RestHeadSpec (Header-Based API)](#restheadspec-header-based-api)
- [Existing Code (Backward Compatible)](#option-1-existing-code-backward-compatible)
- [New Database-Agnostic API](#option-2-new-database-agnostic-api)
- [Router Integration](#router-integration)
- [Migration from v1.x](#migration-from-v1x)
- [Architecture](#architecture)
- [API Structure](#api-structure)
- [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
- [Lifecycle Hooks](#lifecycle-hooks)
- [Cursor Pagination](#cursor-pagination)
- [Example Usage](#example-usage)
- [Testing](#testing)
- [What's New in v2.0](#whats-new-in-v20)
## Features
### Core Features
- **Dynamic Data Querying**: Select specific columns and relationships to return
- **Relationship Preloading**: Load related entities with custom column selection and filters
- **Complex Filtering**: Apply multiple filters with various operators
- **Sorting**: Multi-column sort support
- **Pagination**: Built-in limit and offset support
- **Pagination**: Built-in limit/offset and cursor-based pagination
- **Computed Columns**: Define virtual columns for complex calculations
- **Custom Operators**: Add custom SQL conditions when needed
### Architecture (v2.0+)
- **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
- **🆕 Router Flexible**: Integrates with Gorilla Mux, Gin, Echo, or custom routers
- **🆕 Backward Compatible**: Existing code works without changes
- **🆕 Better Testing**: Mockable interfaces for easy unit testing
### RestHeadSpec (v2.1+)
- **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
- **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
- **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
- **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
- **🆕 Base64 Encoding**: Support for base64-encoded header values
## API Structure
### URL Patterns
@@ -66,6 +89,216 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
}
```
## RestHeadSpec: Header-Based API
RestHeadSpec provides an alternative REST API approach where all query options are passed via HTTP headers instead of the request body, giving a cleaner separation between the request payload and query metadata.
### Quick Example
```http
GET /public/users HTTP/1.1
Host: api.example.com
X-Select-Fields: id,name,email,department_id
X-Preload: department:id,name
X-FieldFilter-Status: active
X-SearchOp-Gte-Age: 18
X-Sort: -created_at,+name
X-Limit: 50
X-DetailApi: true
```
### Setup with GORM
```go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/gorilla/mux"
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// Register models using schema.table format
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Setup routes
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
// Start server
http.ListenAndServe(":8080", router)
```
### Setup with Bun ORM
```go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/uptrace/bun"
// Create handler with Bun
handler := restheadspec.NewHandlerWithBun(bunDB)
// Register models
handler.Registry.RegisterModel("public.users", &User{})
// Setup routes (same as GORM)
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
```
### Common Headers
| Header | Description | Example |
|--------|-------------|---------|
| `X-Select-Fields` | Columns to include | `id,name,email` |
| `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` |
| `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` |
| `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` |
| `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` |
| `X-Preload` | Preload relations | `posts:id,title` |
| `X-Sort` | Sort columns | `-created_at,+name` |
| `X-Limit` | Limit results | `50` |
| `X-Offset` | Offset for pagination | `100` |
| `X-Clean-JSON` | Remove null/empty fields | `true` |
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
For complete header documentation, see [pkg/restheadspec/HEADERS.md](pkg/restheadspec/HEADERS.md).
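The headers compose freely. As a rough client-side sketch (the endpoint URL and filter values are illustrative, not taken from this repository), a plain `net/http` request might look like:
```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	// Illustrative endpoint; substitute your own host and registered schema/entity path.
	req, err := http.NewRequest("GET", "https://api.example.com/public/users", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("X-Select-Fields", "id,name,email")
	req.Header.Set("X-FieldFilter-Status", "active")
	req.Header.Set("X-SearchOp-Gte-Age", "18")
	req.Header.Set("X-Sort", "-created_at,+name")
	req.Header.Set("X-Limit", "50")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```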
### Lifecycle Hooks
RestHeadSpec supports lifecycle hooks for all CRUD operations:
```go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// Register a before-read hook (e.g., for authorization)
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
// Check permissions
if !userHasPermission(ctx.Context, ctx.Entity) {
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
}
// Modify query options
ctx.Options.Limit = ptr(100) // Enforce max limit
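// ptr is assumed to be a small generic helper, e.g. func ptr[T any](v T) *T { return &v }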
return nil
})
// Register an after-read hook (e.g., for data transformation)
handler.Hooks.Register(restheadspec.AfterRead, func(ctx *restheadspec.HookContext) error {
// Transform or filter results
if users, ok := ctx.Result.([]User); ok {
for i := range users {
users[i].Email = maskEmail(users[i].Email)
}
}
return nil
})
// Register a before-create hook (e.g., for validation)
handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookContext) error {
// Validate data
if user, ok := ctx.Data.(*User); ok {
if user.Email == "" {
return fmt.Errorf("email is required")
}
// Add timestamps
user.CreatedAt = time.Now()
}
return nil
})
```
**Available Hook Types**:
- `BeforeRead`, `AfterRead`
- `BeforeCreate`, `AfterCreate`
- `BeforeUpdate`, `AfterUpdate`
- `BeforeDelete`, `AfterDelete`
**HookContext** provides:
- `Context`: Request context
- `Handler`: Access to handler, database, and registry
- `Schema`, `Entity`, `TableName`: Request info
- `Model`: The registered model type
- `Options`: Parsed request options (filters, sorting, etc.)
- `ID`: Record ID (for single-record operations)
- `Data`: Request data (for create/update)
- `Result`: Operation result (for after hooks)
- `Writer`: Response writer (allows hooks to modify response)
### Cursor Pagination
RestHeadSpec supports efficient cursor-based pagination for large datasets:
```http
GET /public/posts HTTP/1.1
X-Sort: -created_at,+id
X-Limit: 50
X-Cursor-Forward: <cursor_token>
```
**How it works** (a client sketch follows this list):
1. First request returns results + cursor token in response
2. Subsequent requests use `X-Cursor-Forward` or `X-Cursor-Backward`
3. Cursor maintains consistent ordering even with data changes
4. Supports complex multi-column sorting
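A minimal Go sketch of this flow, assuming the cursor token has already been read from the first response (the exact response field is covered in HEADERS.md and not reproduced here; the endpoint is illustrative):
```go
package main

import (
	"fmt"
	"net/http"
)

// fetchNextPage requests the next page using a cursor token taken from a previous response.
func fetchNextPage(cursorToken string) (*http.Response, error) {
	req, err := http.NewRequest("GET", "https://api.example.com/public/posts", nil)
	if err != nil {
		return nil, err
	}
	// Repeat the sort and limit used for the first page, as in the header example above.
	req.Header.Set("X-Sort", "-created_at,+id")
	req.Header.Set("X-Limit", "50")
	req.Header.Set("X-Cursor-Forward", cursorToken)
	return http.DefaultClient.Do(req)
}

func main() {
	resp, err := fetchNextPage("<cursor_token_from_previous_response>")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```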
**Benefits over offset pagination**:
- Consistent results when data changes
- Better performance for large offsets
- Prevents "skipped" or duplicate records
- Works with complex sort expressions
**Example with hooks**:
```go
// Enable cursor pagination in a hook
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
// For large tables, enforce cursor pagination
if ctx.Entity == "posts" && ctx.Options.Offset != nil && *ctx.Options.Offset > 1000 {
return fmt.Errorf("use cursor pagination for large offsets")
}
return nil
})
```
### Response Formats
RestHeadSpec supports multiple response formats:
**1. Simple Format** (`X-SimpleApi: true`):
```json
[
{ "id": 1, "name": "John" },
{ "id": 2, "name": "Jane" }
]
```
**2. Detail Format** (`X-DetailApi: true`, default):
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 100,
"limit": 50,
"offset": 0
}
}
```
**3. Syncfusion Format** (`X-Syncfusion: true`):
```json
{
"result": [...],
"count": 100
}
```
## Example Usage
### Reading Data with Related Entities
@@ -110,17 +343,73 @@ POST /core/users
## Installation
```bash
go get github.com/Warky-Devs/ResolveSpec
go get github.com/bitechdev/ResolveSpec
```
## Quick Start
### ResolveSpec (Body-Based API)
ResolveSpec uses JSON request bodies to specify query options:
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// Create handler
handler := resolvespec.NewAPIHandler(gormDB)
handler.RegisterModel("core", "users", &User{})
// Setup routes
router := mux.NewRouter()
resolvespec.SetupRoutes(router, handler)
// Client makes POST request with body:
// POST /core/users
// {
// "operation": "read",
// "options": {
// "columns": ["id", "name", "email"],
// "filters": [{"column": "status", "operator": "eq", "value": "active"}],
// "limit": 10
// }
// }
```
### RestHeadSpec (Header-Based API)
RestHeadSpec uses HTTP headers for query options instead of the request body:
```go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
// Create handler with GORM
handler := restheadspec.NewHandlerWithGORM(db)
// Register models (schema.table format)
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Setup routes with Mux
muxRouter := mux.NewRouter()
restheadspec.SetupMuxRoutes(muxRouter, handler)
// Client makes GET request with headers:
// GET /public/users
// X-Select-Fields: id,name,email
// X-FieldFilter-Status: active
// X-Limit: 10
// X-Sort: -created_at
// X-Preload: posts:id,title
```
See [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1) for complete header documentation.
### Option 1: Existing Code (Backward Compatible)
Your existing code continues to work without any changes:
```go
import "github.com/Warky-Devs/ResolveSpec/pkg/resolvespec"
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// This still works exactly as before
handler := resolvespec.NewAPIHandler(gormDB)
@@ -131,12 +420,36 @@ handler.RegisterModel("core", "users", &User{})
ResolveSpec v2.0 introduces a new database and router abstraction layer while maintaining **100% backward compatibility**. Your existing code will continue to work without any changes.
### Repository Path Migration
**IMPORTANT**: The repository has moved from `github.com/Warky-Devs/ResolveSpec` to `github.com/bitechdev/ResolveSpec`.
To update your imports:
```bash
# Update go.mod
go mod edit -replace github.com/Warky-Devs/ResolveSpec=github.com/bitechdev/ResolveSpec@latest
go mod tidy
# Or update imports manually in your code
# Old: import "github.com/Warky-Devs/ResolveSpec/pkg/resolvespec"
# New: import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
```
Alternatively, use find and replace in your project:
```bash
find . -type f -name "*.go" -exec sed -i 's|github.com/Warky-Devs/ResolveSpec|github.com/bitechdev/ResolveSpec|g' {} +
go mod tidy
```
### Migration Timeline
1. **Phase 1**: Continue using existing API (no changes needed)
2. **Phase 2**: Gradually adopt new constructors when convenient
3. **Phase 3**: Switch to interface-based approach for new features
4. **Phase 4**: Optionally switch database backends
1. **Phase 1**: Update repository path (see above)
2. **Phase 2**: Continue using existing API (no changes needed)
3. **Phase 3**: Gradually adopt new constructors when convenient
4. **Phase 4**: Switch to interface-based approach for new features
5. **Phase 5**: Optionally switch database backends or try RestHeadSpec
### Detailed Migration Guide
@@ -144,12 +457,34 @@ For detailed migration instructions, examples, and best practices, see [MIGRATIO
## Architecture
### Two Complementary APIs
```
┌─────────────────────────────────────────────────────┐
│ ResolveSpec Framework │
├─────────────────────┬───────────────────────────────┤
│ ResolveSpec │ RestHeadSpec │
│ (Body-based) │ (Header-based) │
├─────────────────────┴───────────────────────────────┤
│ Common Core Components │
│ • Model Registry • Filters • Preloading │
│ • Sorting • Pagination • Type System │
└──────────────────────┬──────────────────────────────┘
┌──────────────────────────────┐
│ Database Abstraction │
│ [GORM] [Bun] [Custom] │
└──────────────────────────────┘
```
### Database Abstraction Layer
```
Your Application Code
Handler (Business Logic)
Handler (Business Logic)
[Hooks & Middleware] (RestHeadSpec only)
Database Interface
@@ -176,7 +511,7 @@ Your Application Code
#### With GORM (Recommended Migration Path)
```go
import "github.com/Warky-Devs/ResolveSpec/pkg/resolvespec"
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// Create database adapter
dbAdapter := resolvespec.NewGormAdapter(gormDB)
@@ -192,7 +527,7 @@ handler := resolvespec.NewHandler(dbAdapter, registry)
#### With Bun ORM
```go
import "github.com/Warky-Devs/ResolveSpec/pkg/resolvespec"
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
import "github.com/uptrace/bun"
// Create Bun adapter (Bun dependency already included)
@@ -415,12 +750,32 @@ func TestHandler(t *testing.T) {
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## What's New in v2.0
## What's New
### Breaking Changes
### v2.1 (Latest)
**RestHeadSpec - Header-Based REST API**:
- **Header-Based Querying**: All query options via HTTP headers instead of request body
- **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
- **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
- **Advanced Filtering**: Field filters, search operators, AND/OR logic
- **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
- **Base64 Support**: Base64-encoded header values for complex queries
- **Type-Aware Filtering**: Automatic type detection and conversion for filters
**Core Improvements**:
- Better model registry with schema.table format support
- Enhanced validation and error handling
- Improved reflection safety
- Fixed COUNT query issues with table aliasing
- Better pointer handling throughout the codebase
### v2.0
**Breaking Changes**:
- **None!** Full backward compatibility maintained
### New Features
**New Features**:
- **Database Abstraction**: Support for GORM, Bun, and custom ORMs
- **Router Flexibility**: Works with any HTTP router through adapters
- **BunRouter Integration**: Built-in support for uptrace/bunrouter
@@ -428,7 +783,7 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
- **Enhanced Testing**: Mockable interfaces for comprehensive testing
- **Migration Guide**: Step-by-step migration instructions
### Performance Improvements
**Performance Improvements**:
- More efficient query building through interface design
- Reduced coupling between components
- Better memory management with interface boundaries
@@ -436,8 +791,9 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## Acknowledgments
- Inspired by REST, OData, and GraphQL's flexibility
- **Header-based approach**: Inspired by REST best practices and clean API design
- **Database Support**: [GORM](https://gorm.io) and [Bun](https://bun.uptrace.dev/)
- **Router Support**: Gorilla Mux (built-in), Gin, Echo, and others through adapters
- **Router Support**: Gorilla Mux (built-in), BunRouter, Gin, Echo, and others through adapters
- Slogan generated using DALL-E
- AI used for documentation checking and correction
- Community feedback and contributions that made v2.0 possible
- Community feedback and contributions that made v2.0 and v2.1 possible

138
SCHEMA_TABLE_HANDLING.md Normal file
View File

@@ -0,0 +1,138 @@
# Schema and Table Name Handling
This document explains how the handlers properly separate and handle schema and table names.
## Implementation
Both `resolvespec` and `restheadspec` handlers now properly handle schema and table name separation through the following functions:
- `parseTableName(fullTableName)` - Splits "schema.table" into separate components
- `getSchemaAndTable(defaultSchema, entity, model)` - Returns schema and table separately
- `getTableName(schema, entity, model)` - Returns the full "schema.table" format
## Priority Order
When determining the schema and table name, the following priority is used (a sketch follows this list):
1. **If `TableName()` contains a schema** (e.g., "myschema.mytable"), that schema takes precedence
2. **If model implements `SchemaProvider`**, use that schema
3. **Otherwise**, use the `defaultSchema` parameter from the URL/request
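The following is a condensed sketch of this priority order. It is not the actual handler code (see the Code Locations section below for the real implementations), but it shows how the pieces described in this document fit together:
```go
package main

import (
	"fmt"
	"strings"
)

// Interfaces as declared in pkg/common in this changeset.
type TableNameProvider interface{ TableName() string }
type SchemaProvider interface{ SchemaName() string }

// parseTableName mirrors the shared helper in pkg/common/adapters/database/utils.go.
func parseTableName(full string) (schema, table string) {
	if idx := strings.LastIndex(full, "."); idx != -1 {
		return full[:idx], full[idx+1:]
	}
	return "", full
}

// getSchemaAndTable applies the priority order described above.
func getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
	schema, table = defaultSchema, entity // priority 3: URL/request defaults
	if p, ok := model.(SchemaProvider); ok {
		schema = p.SchemaName() // priority 2: SchemaProvider
	}
	if p, ok := model.(TableNameProvider); ok {
		s, t := parseTableName(p.TableName())
		table = t
		if s != "" {
			schema = s // priority 1: schema embedded in TableName() wins
		}
	}
	return schema, table
}

// User declares a schema inside TableName(), so the URL schema is ignored (Scenario 2).
type User struct{}

func (User) TableName() string { return "auth.users" }

func main() {
	schema, table := getSchemaAndTable("public", "users", User{})
	fmt.Println(schema, table) // auth users
}
```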
## Scenarios
### Scenario 1: Simple table name, default schema
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "users"
}
```
- Request URL: `/api/public/users`
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
### Scenario 2: Table name includes schema
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "auth.users" // Schema included!
}
```
- Request URL: `/api/public/users` (public is ignored)
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
- **Note**: The schema from `TableName()` takes precedence over the URL schema
### Scenario 3: Using SchemaProvider
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "users"
}
func (User) SchemaName() string {
return "auth"
}
```
- Request URL: `/api/public/users` (public is ignored)
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
### Scenario 4: Table name includes schema AND SchemaProvider
```go
type User struct {
ID string
Name string
}
func (User) TableName() string {
return "core.users" // This wins!
}
func (User) SchemaName() string {
return "auth" // This is ignored
}
```
- Request URL: `/api/public/users`
- Result: `schema="core"`, `table="users"`, `fullName="core.users"`
- **Note**: Schema from `TableName()` takes highest precedence
### Scenario 5: No providers at all
```go
type User struct {
ID string
Name string
}
// No TableName() or SchemaName()
```
- Request URL: `/api/public/users`
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
- Uses URL schema and entity name
## Key Features
1. **Automatic detection**: The code automatically detects if `TableName()` includes a schema by checking for "."
2. **Backward compatible**: Existing code continues to work
3. **Flexible**: Supports multiple ways to specify schema and table
4. **Debug logging**: Logs when a schema is detected in `TableName()`
## Code Locations
### Handlers
- `/pkg/resolvespec/handler.go:472-531`
- `/pkg/restheadspec/handler.go:534-593`
### Database Adapters
- `/pkg/common/adapters/database/utils.go` - Shared `parseTableName()` function
- `/pkg/common/adapters/database/bun.go` - Bun adapter with separated schema/table
- `/pkg/common/adapters/database/gorm.go` - GORM adapter with separated schema/table
## Adapter Implementation
Both the Bun and GORM adapters now store the schema and the table name separately:
```go
// BunSelectQuery/GormSelectQuery now have separated fields:
type BunSelectQuery struct {
query *bun.SelectQuery
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
}
```
When `Model()` or `Table()` is called:
1. The full table name (which may include schema) is parsed
2. Schema and table name are stored separately
3. When building joins, the already-separated table name is used directly
This ensures consistent handling of schema-qualified table names throughout the codebase.

View File

@@ -7,11 +7,11 @@ import (
"os"
"time"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry"
"github.com/Warky-Devs/ResolveSpec/pkg/testmodels"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/testmodels"
"github.com/Warky-Devs/ResolveSpec/pkg/resolvespec"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/gorilla/mux"
"github.com/glebarez/sqlite"

2
go.mod
View File

@@ -1,4 +1,4 @@
module github.com/Warky-Devs/ResolveSpec
module github.com/bitechdev/ResolveSpec
go 1.23.0

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/uptrace/bun"
)
@@ -22,7 +22,10 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
}
func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{query: b.db.NewSelect()}
return &BunSelectQuery{
query: b.db.NewSelect(),
db: b.db,
}
}
func (b *BunAdapter) NewInsert() common.InsertQuery {
@@ -78,16 +81,22 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
// BunSelectQuery implements SelectQuery for Bun
type BunSelectQuery struct {
query *bun.SelectQuery
tableName string
db bun.IDB // Store DB connection for count queries
hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
}
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
b.query = b.query.Model(model)
b.hasModel = true // Mark that we have a model
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
b.tableName = provider.TableName()
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
b.schema, b.tableName = parseTableName(fullTableName)
}
return b
@@ -95,7 +104,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
b.query = b.query.Table(table)
b.tableName = table
// Check if the table name contains schema (e.g., "schema.table")
b.schema, b.tableName = parseTableName(table)
return b
}
@@ -128,13 +138,9 @@ func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQu
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && b.tableName != "" {
prefix = b.tableName
// Extract just the table name if it has schema
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// If prefix is provided, add it as an alias in the join
@@ -169,12 +175,9 @@ func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.Sele
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && b.tableName != "" {
prefix = b.tableName
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// Construct LEFT JOIN with prefix
@@ -189,7 +192,7 @@ func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.Sele
}
}
b.query = b.query.Join("LEFT JOIN " + joinClause, sqlArgs...)
b.query = b.query.Join("LEFT JOIN "+joinClause, sqlArgs...)
return b
}
@@ -231,7 +234,19 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
}
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
count, err := b.query.Count(ctx)
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
return count, err
}
// Otherwise, wrap as subquery to avoid "Model(nil)" error
// This is needed when only Table() is set without a model
var count int
err := b.db.NewSelect().
TableExpr("(?) AS subquery", b.query).
ColumnExpr("COUNT(*)").
Scan(ctx, &count)
return count, err
}
@@ -381,7 +396,10 @@ type BunTxAdapter struct {
}
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{query: b.tx.NewSelect()}
return &BunSelectQuery{
query: b.tx.NewSelect(),
db: b.tx,
}
}
func (b *BunTxAdapter) NewInsert() common.InsertQuery {

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common"
"gorm.io/gorm"
)
@@ -70,7 +70,8 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
// GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct {
db *gorm.DB
tableName string
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
}
@@ -79,7 +80,9 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
g.tableName = provider.TableName()
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(fullTableName)
}
return g
@@ -87,7 +90,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table)
g.tableName = table
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(table)
return g
}
@@ -120,13 +124,9 @@ func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQ
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && g.tableName != "" {
prefix = g.tableName
// Extract just the table name if it has schema
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// If prefix is provided, add it as an alias in the join
@@ -161,12 +161,9 @@ func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.Sel
}
}
// If no prefix provided, use the table name as prefix
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && g.tableName != "" {
prefix = g.tableName
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// Construct LEFT JOIN with prefix

View File

@@ -0,0 +1,173 @@
package database
import (
"reflect"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users")
//
// "users" -> ("", "users")
func parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
}
return "", fullTableName
}
// GetPrimaryKeyName extracts the primary key column name from a model
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
func GetPrimaryKeyName(model any) string {
// Check if model implements PrimaryKeyNameProvider
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
return provider.GetIDName()
}
// Try Bun tag first
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
return pkName
}
// Fall back to GORM tag
return getPrimaryKeyFromReflection(model, "gorm")
}
// GetModelColumns extracts all column names from a model using reflection
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
func GetModelColumns(model any) []string {
var columns []string
modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return columns
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Get column name using the same logic as primary key extraction
columnName := getColumnNameFromField(field)
if columnName != "" {
columns = append(columns, columnName)
}
}
return columns
}
// getColumnNameFromField extracts the column name from a struct field
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
func getColumnNameFromField(field reflect.StructField) string {
// Try bun tag first
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
if colName := extractColumnFromBunTag(bunTag); colName != "" {
return colName
}
}
// Try gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" && gormTag != "-" {
if colName := extractColumnFromGormTag(gormTag); colName != "" {
return colName
}
}
// Fall back to json tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract just the field name before any options
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
}
// Last resort: use field name in lowercase
return strings.ToLower(field.Name)
}
// getPrimaryKeyFromReflection uses reflection to find the primary key field
func getPrimaryKeyFromReflection(model any, ormType string) string {
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return ""
}
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
switch ormType {
case "gorm":
// Check for gorm tag with primaryKey
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") {
// Try to extract column name from gorm tag
if colName := extractColumnFromGormTag(gormTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
case "bun":
// Check for bun tag with pk flag
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") {
// Extract column name from bun tag
if colName := extractColumnFromBunTag(bunTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
}
}
return ""
}
// extractColumnFromGormTag extracts the column name from a gorm tag
// Example: "column:id;primaryKey" -> "id"
func extractColumnFromGormTag(tag string) string {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if colName, found := strings.CutPrefix(part, "column:"); found {
return colName
}
}
return ""
}
// extractColumnFromBunTag extracts the column name from a bun tag
// Example: "id,pk" -> "id"
// Example: ",pk" -> "" (will fall back to json tag)
func extractColumnFromBunTag(tag string) string {
parts := strings.Split(tag, ",")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}

View File

@@ -0,0 +1,233 @@
package database
import (
"testing"
)
// Test models for GORM
type GormModelWithGetIDName struct {
ID int `gorm:"column:rid_test;primaryKey" json:"id"`
Name string `json:"name"`
}
func (m GormModelWithGetIDName) GetIDName() string {
return "rid_test"
}
type GormModelWithColumnTag struct {
ID int `gorm:"column:custom_id;primaryKey" json:"id"`
Name string `json:"name"`
}
type GormModelWithJSONFallback struct {
ID int `gorm:"primaryKey" json:"user_id"`
Name string `json:"name"`
}
// Test models for Bun
type BunModelWithGetIDName struct {
ID int `bun:"rid_test,pk" json:"id"`
Name string `json:"name"`
}
func (m BunModelWithGetIDName) GetIDName() string {
return "rid_test"
}
type BunModelWithColumnTag struct {
ID int `bun:"custom_id,pk" json:"id"`
Name string `json:"name"`
}
type BunModelWithJSONFallback struct {
ID int `bun:",pk" json:"user_id"`
Name string `json:"name"`
}
func TestGetPrimaryKeyName(t *testing.T) {
tests := []struct {
name string
model any
expected string
}{
{
name: "GORM model with GetIDName method",
model: GormModelWithGetIDName{},
expected: "rid_test",
},
{
name: "GORM model with column tag",
model: GormModelWithColumnTag{},
expected: "custom_id",
},
{
name: "GORM model with JSON fallback",
model: GormModelWithJSONFallback{},
expected: "user_id",
},
{
name: "GORM model pointer with GetIDName",
model: &GormModelWithGetIDName{},
expected: "rid_test",
},
{
name: "GORM model pointer with column tag",
model: &GormModelWithColumnTag{},
expected: "custom_id",
},
{
name: "Bun model with GetIDName method",
model: BunModelWithGetIDName{},
expected: "rid_test",
},
{
name: "Bun model with column tag",
model: BunModelWithColumnTag{},
expected: "custom_id",
},
{
name: "Bun model with JSON fallback",
model: BunModelWithJSONFallback{},
expected: "user_id",
},
{
name: "Bun model pointer with GetIDName",
model: &BunModelWithGetIDName{},
expected: "rid_test",
},
{
name: "Bun model pointer with column tag",
model: &BunModelWithColumnTag{},
expected: "custom_id",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyName(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
}
})
}
}
func TestExtractColumnFromGormTag(t *testing.T) {
tests := []struct {
name string
tag string
expected string
}{
{
name: "column tag with primaryKey",
tag: "column:rid_test;primaryKey",
expected: "rid_test",
},
{
name: "column tag with spaces",
tag: "column:user_id ; primaryKey ; autoIncrement",
expected: "user_id",
},
{
name: "no column tag",
tag: "primaryKey;autoIncrement",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractColumnFromGormTag(tt.tag)
if result != tt.expected {
t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected)
}
})
}
}
func TestExtractColumnFromBunTag(t *testing.T) {
tests := []struct {
name string
tag string
expected string
}{
{
name: "column name with pk flag",
tag: "rid_test,pk",
expected: "rid_test",
},
{
name: "only pk flag",
tag: ",pk",
expected: "",
},
{
name: "column with multiple flags",
tag: "user_id,pk,autoincrement",
expected: "user_id",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractColumnFromBunTag(tt.tag)
if result != tt.expected {
t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetModelColumns(t *testing.T) {
tests := []struct {
name string
model any
expected []string
}{
{
name: "Bun model with multiple columns",
model: BunModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "GORM model with multiple columns",
model: GormModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "Bun model pointer",
model: &BunModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "GORM model pointer",
model: &GormModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "Bun model with JSON fallback",
model: BunModelWithJSONFallback{},
expected: []string{"user_id", "name"},
},
{
name: "GORM model with JSON fallback",
model: GormModelWithJSONFallback{},
expected: []string{"user_id", "name"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetModelColumns(tt.model)
if len(result) != len(tt.expected) {
t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected))
return
}
for i, col := range result {
if col != tt.expected[i] {
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
}
}
})
}
}

View File

@@ -3,7 +3,7 @@ package router
import (
"net/http"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/uptrace/bunrouter"
)
@@ -190,4 +190,3 @@ func DefaultBunRouterConfig() *BunRouterConfig {
HandleOPTIONS: true,
}
}

View File

@@ -5,7 +5,7 @@ import (
"io"
"net/http"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/gorilla/mux"
)

View File

@@ -131,6 +131,11 @@ type TableNameProvider interface {
TableName() string
}
// PrimaryKeyNameProvider interface for models that provide primary key column names
type PrimaryKeyNameProvider interface {
GetIDName() string
}
// SchemaProvider interface for models that provide schema names
type SchemaProvider interface {
SchemaName() string

View File

@@ -37,9 +37,10 @@ type PreloadOption struct {
}
type FilterOption struct {
Column string `json:"column"`
Operator string `json:"operator"`
Value interface{} `json:"value"`
Column string `json:"column"`
Operator string `json:"operator"`
Value interface{} `json:"value"`
LogicOperator string `json:"logic_operator"` // "AND" or "OR" - how this filter combines with previous filters
}
type SortOption struct {

272
pkg/common/validation.go Normal file
View File

@@ -0,0 +1,272 @@
package common
import (
"fmt"
"reflect"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// ColumnValidator validates column names against a model's fields
type ColumnValidator struct {
validColumns map[string]bool
model interface{}
}
// NewColumnValidator creates a new column validator for a given model
func NewColumnValidator(model interface{}) *ColumnValidator {
validator := &ColumnValidator{
validColumns: make(map[string]bool),
model: model,
}
validator.buildValidColumns()
return validator
}
// buildValidColumns extracts all valid column names from the model using reflection
func (v *ColumnValidator) buildValidColumns() {
modelType := reflect.TypeOf(v.model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return
}
// Extract column names from struct fields
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if !field.IsExported() {
continue
}
// Get column name from bun, gorm, or json tag
columnName := v.getColumnName(field)
if columnName != "" && columnName != "-" {
v.validColumns[strings.ToLower(columnName)] = true
}
}
}
// getColumnName extracts the column name from a struct field's tags
// Supports both Bun and GORM tags
func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
// First check Bun tag for column name
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
parts := strings.Split(bunTag, ",")
// The first part is usually the column name
columnName := strings.TrimSpace(parts[0])
if columnName != "" && columnName != "-" {
return columnName
}
}
// Check GORM tag for column name
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "column:") {
parts := strings.Split(gormTag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "column:") {
return strings.TrimPrefix(part, "column:")
}
}
}
// Fall back to JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract just the name part (before any comma)
jsonName := strings.Split(jsonTag, ",")[0]
return jsonName
}
// Fall back to field name in lowercase (snake_case conversion would be better)
return strings.ToLower(field.Name)
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
func (v *ColumnValidator) ValidateColumn(column string) error {
// Allow empty columns
if column == "" {
return nil
}
// Allow columns prefixed with "cql" (case insensitive) for computed columns
if strings.HasPrefix(strings.ToLower(column), "cql") {
return nil
}
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
}
return nil
}
// IsValidColumn checks if a column is valid
// Returns true if valid, false if invalid
func (v *ColumnValidator) IsValidColumn(column string) bool {
return v.ValidateColumn(column) == nil
}
// FilterValidColumns filters a list of columns, returning only valid ones
// Logs warnings for any invalid columns
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
if len(columns) == 0 {
return columns
}
validColumns := make([]string, 0, len(columns))
for _, col := range columns {
if v.IsValidColumn(col) {
validColumns = append(validColumns, col)
} else {
logger.Warn("Invalid column '%s' filtered out: column does not exist in model", col)
}
}
return validColumns
}
// ValidateColumns validates multiple column names
// Returns error with details about all invalid columns
func (v *ColumnValidator) ValidateColumns(columns []string) error {
var invalidColumns []string
for _, column := range columns {
if err := v.ValidateColumn(column); err != nil {
invalidColumns = append(invalidColumns, column)
}
}
if len(invalidColumns) > 0 {
return fmt.Errorf("invalid columns: %s", strings.Join(invalidColumns, ", "))
}
return nil
}
// ValidateRequestOptions validates all column references in RequestOptions
func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
// Validate Columns
if err := v.ValidateColumns(options.Columns); err != nil {
return fmt.Errorf("in select columns: %w", err)
}
// Validate OmitColumns
if err := v.ValidateColumns(options.OmitColumns); err != nil {
return fmt.Errorf("in omit columns: %w", err)
}
// Validate Filter columns
for _, filter := range options.Filters {
if err := v.ValidateColumn(filter.Column); err != nil {
return fmt.Errorf("in filter: %w", err)
}
}
// Validate Sort columns
for _, sort := range options.Sort {
if err := v.ValidateColumn(sort.Column); err != nil {
return fmt.Errorf("in sort: %w", err)
}
}
// Validate Preload columns (if specified)
for _, preload := range options.Preload {
// Note: We don't validate the relation name itself, as it's a relationship
// Only validate columns if specified for the preload
if err := v.ValidateColumns(preload.Columns); err != nil {
return fmt.Errorf("in preload '%s' columns: %w", preload.Relation, err)
}
if err := v.ValidateColumns(preload.OmitColumns); err != nil {
return fmt.Errorf("in preload '%s' omit columns: %w", preload.Relation, err)
}
// Validate filter columns in preload
for _, filter := range preload.Filters {
if err := v.ValidateColumn(filter.Column); err != nil {
return fmt.Errorf("in preload '%s' filter: %w", preload.Relation, err)
}
}
}
return nil
}
// FilterRequestOptions filters all column references in RequestOptions
// Returns a new RequestOptions with only valid columns, logging warnings for invalid ones
func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOptions {
filtered := options
// Filter Columns
filtered.Columns = v.FilterValidColumns(options.Columns)
// Filter OmitColumns
filtered.OmitColumns = v.FilterValidColumns(options.OmitColumns)
// Filter Filter columns
validFilters := make([]FilterOption, 0, len(options.Filters))
for _, filter := range options.Filters {
if v.IsValidColumn(filter.Column) {
validFilters = append(validFilters, filter)
} else {
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
}
}
filtered.Filters = validFilters
// Filter Sort columns
validSorts := make([]SortOption, 0, len(options.Sort))
for _, sort := range options.Sort {
if v.IsValidColumn(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
}
}
filtered.Sort = validSorts
// Filter Preload columns
validPreloads := make([]PreloadOption, 0, len(options.Preload))
for _, preload := range options.Preload {
filteredPreload := preload
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
// Filter preload filters
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
for _, filter := range preload.Filters {
if v.IsValidColumn(filter.Column) {
validPreloadFilters = append(validPreloadFilters, filter)
} else {
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
}
}
filteredPreload.Filters = validPreloadFilters
validPreloads = append(validPreloads, filteredPreload)
}
filtered.Preload = validPreloads
return filtered
}
// GetValidColumns returns a list of all valid column names for debugging purposes
func (v *ColumnValidator) GetValidColumns() []string {
columns := make([]string, 0, len(v.validColumns))
for col := range v.validColumns {
columns = append(columns, col)
}
return columns
}

View File

@@ -0,0 +1,363 @@
package common
import (
"strings"
"testing"
)
// TestModel represents a sample model for testing
type TestModel struct {
ID int64 `json:"id" gorm:"primaryKey"`
Name string `json:"name" gorm:"column:name"`
Email string `json:"email" bun:"email"`
Age int `json:"age"`
IsActive bool `json:"is_active"`
CreatedAt string `json:"created_at"`
}
func TestNewColumnValidator(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
if validator == nil {
t.Fatal("Expected validator to be created")
}
if len(validator.validColumns) == 0 {
t.Fatal("Expected validator to have valid columns")
}
// Check that expected columns are present
expectedColumns := []string{"id", "name", "email", "age", "is_active", "created_at"}
for _, col := range expectedColumns {
if !validator.validColumns[col] {
t.Errorf("Expected column '%s' to be valid", col)
}
}
}
func TestValidateColumn(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
column string
shouldError bool
}{
{"Valid column - id", "id", false},
{"Valid column - name", "name", false},
{"Valid column - email", "email", false},
{"Valid column - uppercase", "ID", false}, // Case insensitive
{"Invalid column", "invalid_column", true},
{"CQL prefixed - should be valid", "cqlComputedField", false},
{"CQL prefixed uppercase - should be valid", "CQLComputedField", false},
{"Empty column", "", false}, // Empty columns are allowed
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateColumn(tt.column)
if tt.shouldError && err == nil {
t.Errorf("Expected error for column '%s', got nil", tt.column)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
}
})
}
}
func TestValidateColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
columns []string
shouldError bool
}{
{"All valid columns", []string{"id", "name", "email"}, false},
{"One invalid column", []string{"id", "invalid_col", "name"}, true},
{"All invalid columns", []string{"bad1", "bad2"}, true},
{"With CQL prefix", []string{"id", "cqlComputed", "name"}, false},
{"Empty list", []string{}, false},
{"Nil list", nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateColumns(tt.columns)
if tt.shouldError && err == nil {
t.Errorf("Expected error for columns %v, got nil", tt.columns)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for columns %v, got: %v", tt.columns, err)
}
})
}
}
func TestValidateRequestOptions(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
options RequestOptions
shouldError bool
errorMsg string
}{
{
name: "Valid options with columns",
options: RequestOptions{
Columns: []string{"id", "name"},
Filters: []FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
},
Sort: []SortOption{
{Column: "id", Direction: "ASC"},
},
},
shouldError: false,
},
{
name: "Invalid column in Columns",
options: RequestOptions{
Columns: []string{"id", "invalid_column"},
},
shouldError: true,
errorMsg: "select columns",
},
{
name: "Invalid column in Filters",
options: RequestOptions{
Filters: []FilterOption{
{Column: "invalid_col", Operator: "eq", Value: "test"},
},
},
shouldError: true,
errorMsg: "filter",
},
{
name: "Invalid column in Sort",
options: RequestOptions{
Sort: []SortOption{
{Column: "invalid_col", Direction: "ASC"},
},
},
shouldError: true,
errorMsg: "sort",
},
{
name: "Valid CQL prefixed columns",
options: RequestOptions{
Columns: []string{"id", "cqlComputedField"},
Filters: []FilterOption{
{Column: "cqlCustomFilter", Operator: "eq", Value: "test"},
},
},
shouldError: false,
},
{
name: "Invalid column in Preload",
options: RequestOptions{
Preload: []PreloadOption{
{
Relation: "SomeRelation",
Columns: []string{"id", "invalid_col"},
},
},
},
shouldError: true,
errorMsg: "preload",
},
{
name: "Valid preload with valid columns",
options: RequestOptions{
Preload: []PreloadOption{
{
Relation: "SomeRelation",
Columns: []string{"id", "name"},
},
},
},
shouldError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateRequestOptions(tt.options)
if tt.shouldError {
if err == nil {
t.Errorf("Expected error, got nil")
} else if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error to contain '%s', got: %v", tt.errorMsg, err)
}
} else {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
}
})
}
}
func TestGetValidColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
columns := validator.GetValidColumns()
if len(columns) == 0 {
t.Error("Expected to get valid columns, got empty list")
}
// Should have at least the columns from TestModel
if len(columns) < 6 {
t.Errorf("Expected at least 6 columns, got %d", len(columns))
}
}
// Test with Bun tags specifically
type BunModel struct {
ID int64 `bun:"id,pk"`
Name string `bun:"name"`
Email string `bun:"user_email"`
}
func TestBunTagSupport(t *testing.T) {
model := BunModel{}
validator := NewColumnValidator(model)
// Test that bun tags are properly recognized
tests := []struct {
column string
shouldError bool
}{
{"id", false},
{"name", false},
{"user_email", false}, // Bun tag specifies this name
{"email", true}, // JSON tag would be "email", but bun tag says "user_email"
}
for _, tt := range tests {
t.Run(tt.column, func(t *testing.T) {
err := validator.ValidateColumn(tt.column)
if tt.shouldError && err == nil {
t.Errorf("Expected error for column '%s'", tt.column)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
}
})
}
}
func TestFilterValidColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
input []string
expectedOutput []string
}{
{
name: "All valid columns",
input: []string{"id", "name", "email"},
expectedOutput: []string{"id", "name", "email"},
},
{
name: "Mix of valid and invalid",
input: []string{"id", "invalid_col", "name", "bad_col", "email"},
expectedOutput: []string{"id", "name", "email"},
},
{
name: "All invalid columns",
input: []string{"bad1", "bad2"},
expectedOutput: []string{},
},
{
name: "With CQL prefix (should pass)",
input: []string{"id", "cqlComputed", "name"},
expectedOutput: []string{"id", "cqlComputed", "name"},
},
{
name: "Empty input",
input: []string{},
expectedOutput: []string{},
},
{
name: "Nil input",
input: nil,
expectedOutput: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.FilterValidColumns(tt.input)
if len(result) != len(tt.expectedOutput) {
t.Errorf("Expected %d columns, got %d", len(tt.expectedOutput), len(result))
}
for i, col := range result {
if col != tt.expectedOutput[i] {
t.Errorf("At index %d: expected %s, got %s", i, tt.expectedOutput[i], col)
}
}
})
}
}
func TestFilterRequestOptions(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
options := RequestOptions{
Columns: []string{"id", "name", "invalid_col"},
OmitColumns: []string{"email", "bad_col"},
Filters: []FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "invalid_col", Operator: "eq", Value: "test"},
},
Sort: []SortOption{
{Column: "id", Direction: "ASC"},
{Column: "bad_col", Direction: "DESC"},
},
}
filtered := validator.FilterRequestOptions(options)
// Check Columns
if len(filtered.Columns) != 2 {
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
}
if filtered.Columns[0] != "id" || filtered.Columns[1] != "name" {
t.Errorf("Expected columns [id, name], got %v", filtered.Columns)
}
// Check OmitColumns
if len(filtered.OmitColumns) != 1 {
t.Errorf("Expected 1 omit column, got %d", len(filtered.OmitColumns))
}
if filtered.OmitColumns[0] != "email" {
t.Errorf("Expected omit column [email], got %v", filtered.OmitColumns)
}
// Check Filters
if len(filtered.Filters) != 1 {
t.Errorf("Expected 1 filter, got %d", len(filtered.Filters))
}
if filtered.Filters[0].Column != "name" {
t.Errorf("Expected filter column 'name', got %s", filtered.Filters[0].Column)
}
// Check Sort
if len(filtered.Sort) != 1 {
t.Errorf("Expected 1 sort, got %d", len(filtered.Sort))
}
if filtered.Sort[0].Column != "id" {
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
}
}

View File

@@ -38,12 +38,28 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
return fmt.Errorf("model cannot be nil")
}
if modelType.Kind() == reflect.Ptr {
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind())
originalType := modelType
// Unwrap pointers, slices, and arrays to check the underlying type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem()
}
// Validate that the underlying type is a struct
if modelType.Kind() != reflect.Struct {
return fmt.Errorf("model must be a struct, got %s", modelType.Kind())
return fmt.Errorf("model must be a struct or pointer to struct, got %s", originalType.String())
}
// If a pointer/slice/array was passed, unwrap to the base struct
if originalType != modelType {
// Create a zero value of the struct type
model = reflect.New(modelType).Elem().Interface()
}
// Additional check: ensure model is not a pointer
finalType := reflect.TypeOf(model)
if finalType.Kind() == reflect.Ptr {
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
}
r.models[name] = model

View File

@@ -6,10 +6,11 @@ import (
"fmt"
"net/http"
"reflect"
"runtime/debug"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Handler handles API requests using database and model abstractions
@@ -26,8 +27,22 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
}
}
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
}
// Handle processes API requests through router-agnostic interface
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "Handle", err)
}
}()
ctx := context.Background()
body, err := r.Body()
@@ -58,6 +73,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
return
}
// Validate that the model is a struct type (not a slice or pointer to slice)
modelType := reflect.TypeOf(model)
originalType := modelType
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
@@ -65,6 +100,10 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Add request-scoped data to context
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
req.Options = validator.FilterRequestOptions(req.Options)
switch req.Operation {
case "read":
h.handleRead(ctx, w, id, req.Options)
@@ -82,6 +121,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// HandleGet processes GET requests for metadata
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "HandleGet", err)
}
}()
schema := params["schema"]
entity := params["entity"]
@@ -99,16 +145,46 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
}
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleRead", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
model := GetModel(ctx)
modelPtr := GetModelPtr(ctx)
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
return
}
logger.Info("Reading records from %s.%s", schema, entity)
// Create the model pointer for Scan() operations
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
modelPtr := reflect.New(sliceType).Interface()
// Start with Model() using the slice pointer to avoid "Model(nil)" errors in Count()
// Bun's Model() accepts both single pointers and slice pointers
query := h.db.NewSelect().Model(modelPtr)
query = query.Table(tableName)
// Only set Table() if the model doesn't provide a table name via the underlying type
// Create a temporary instance to check for TableNameProvider
tempInstance := reflect.New(modelType).Interface()
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
// Apply column selection
if len(options.Columns) > 0 {
@@ -160,8 +236,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
var result interface{}
if id != "" {
logger.Debug("Querying single record with ID: %s", id)
// Create a pointer to the struct type for scanning
singleResult := reflect.New(reflect.TypeOf(model)).Interface()
// For single record, create a new pointer to the struct type
singleResult := reflect.New(modelType).Interface()
query = query.Where("id = ?", id)
if err := query.Scan(ctx, singleResult); err != nil {
logger.Error("Error querying record: %v", err)
@@ -171,16 +247,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
result = singleResult
} else {
logger.Debug("Querying multiple records")
// Create a slice of pointers to the model type
sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))
results := reflect.New(sliceType).Interface()
if err := query.Scan(ctx, results); err != nil {
// Use the modelPtr already created and set on the query
if err := query.Scan(ctx, modelPtr); err != nil {
logger.Error("Error querying records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
return
}
result = reflect.ValueOf(results).Elem().Interface()
result = reflect.ValueOf(modelPtr).Elem().Interface()
}
logger.Info("Successfully retrieved records")
@@ -203,6 +276,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleCreate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -279,6 +359,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
}
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleUpdate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -329,6 +416,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
}
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleDelete", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -385,19 +479,86 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
}
}
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
if provider, ok := model.(common.TableNameProvider); ok {
return provider.TableName()
// parseTableName splits a table name that may contain schema into separate schema and table
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
}
return fmt.Sprintf("%s.%s", schema, entity)
return "", fullTableName
}
// getSchemaAndTable returns the schema and table name separately
// It checks SchemaProvider and TableNameProvider interfaces and handles cases where
// the table name may already include the schema (e.g., "public.users")
//
// Priority order:
// 1. If TableName() contains a schema (e.g., "myschema.mytable"), that schema takes precedence
// 2. If model implements SchemaProvider, use that schema
// 3. Otherwise, use the defaultSchema parameter
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
// First check if model provides a table name
// We check this FIRST because the table name might already contain the schema
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
// IMPORTANT: Check if the table name already contains a schema (e.g., "schema.table")
// This is common when models need to specify a different schema than the default
if tableSchema, tableOnly := h.parseTableName(tableName); tableSchema != "" {
// Table name includes schema - use it and ignore any other schema providers
logger.Debug("TableName() includes schema: %s.%s", tableSchema, tableOnly)
return tableSchema, tableOnly
}
// Table name is just the table name without schema
// Now determine which schema to use
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
return schema, tableName
}
// No TableNameProvider, so check for schema and use entity as table name
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
// Default to entity name as table
return schema, entity
}
// getTableName returns the full table name including schema (schema.table)
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
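
A sketch of how the priority order above plays out for a model that implements both provider interfaces; the method names follow `common.TableNameProvider` and `common.SchemaProvider` as used in this file, while the model and values are illustrative.

```go
type Invoice struct {
	ID int64 `json:"id"`
}

// TableName embeds a schema, so it wins over SchemaName and the request default.
func (Invoice) TableName() string  { return "billing.invoice" }
func (Invoice) SchemaName() string { return "archive" } // ignored in this case

// With the functions above:
//   getSchemaAndTable("public", "invoice", Invoice{}) -> ("billing", "invoice")
//   getTableName("public", "invoice", Invoice{})      -> "billing.invoice"
// If TableName returned just "invoice", the schema would come from
// SchemaName() ("archive"), and only then from the request default ("public").
```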
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %v for %s.%s", modelType, schema, entity)
return &common.TableMetadata{
Schema: schema,
Table: entity,
Columns: make([]common.Column, 0),
Relations: make([]string, 0),
}
}
metadata := &common.TableMetadata{
Schema: schema,
Table: entity,
@@ -541,10 +702,18 @@ type relationshipInfo struct {
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot apply preloads to non-struct type: %v", modelType)
return query
}
for _, preload := range preloads {
logger.Debug("Processing preload for relation: %s", preload.Relation)
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
@@ -568,6 +737,12 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
}
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
// Ensure we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
return nil
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
jsonTag := field.Tag.Get("json")

View File

@@ -3,9 +3,9 @@ package resolvespec
import (
"net/http"
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database"
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/router"
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
"github.com/uptrace/bunrouter"

pkg/restheadspec/cursor.go (new file, +223 lines)
View File

@@ -0,0 +1,223 @@
package restheadspec
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// CursorDirection defines pagination direction
type CursorDirection int
const (
CursorForward CursorDirection = 1
CursorBackward CursorDirection = -1
)
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
// It uses the current request's sort, cursor, joins (via Expand), and CQL (via ComputedQL).
//
// Parameters:
// - tableName: name of the main table (e.g. "post")
// - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - expandJoins: optional map[alias]string of JOIN clauses (e.g. "user": "LEFT JOIN user ON ...")
//
// Returns SQL snippet to embed in WHERE clause.
func (opts *ExtendedRequestOptions) GetCursorFilter(
tableName string,
pkName string,
modelColumns []string, // optional: for validation
expandJoins map[string]string, // optional: alias → JOIN SQL
) (string, error) {
// --------------------------------------------------------------------- //
// 1. Determine active cursor
// --------------------------------------------------------------------- //
cursorID, direction := opts.getActiveCursor()
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
// --------------------------------------------------------------------- //
// 2. Extract sort columns
// --------------------------------------------------------------------- //
sortItems := opts.getSortColumns()
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
// --------------------------------------------------------------------- //
// 3. Prepare
// --------------------------------------------------------------------- //
var whereClauses []string
joinSQL := ""
reverse := direction < 0
// --------------------------------------------------------------------- //
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
if col == "" {
continue
}
// Parse: "user.name desc nulls last"
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
// Direction from struct or string
desc := strings.EqualFold(s.Direction, "desc") ||
strings.Contains(strings.ToLower(field), "desc")
field = opts.cleanSortField(field)
if reverse {
desc = !desc
}
// Resolve column
cursorCol, targetCol, isJoin, err := opts.resolveColumn(
field, prefix, tableName, modelColumns,
)
if err != nil {
fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err)
continue
}
// Handle joins
if isJoin && expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
// Build inequality
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
// --------------------------------------------------------------------- //
// 5. Build priority OR-AND chain
// --------------------------------------------------------------------- //
orSQL := buildPriorityChain(whereClauses)
// --------------------------------------------------------------------- //
// 6. Final EXISTS subquery
// --------------------------------------------------------------------- //
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
joinSQL,
pkName,
cursorID,
orSQL,
)
return query, nil
}
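
A hedged usage sketch for GetCursorFilter: a forward cursor over a hypothetical `post` table sorted by `created_at` descending, then `id` ascending. Table, column, and cursor values are illustrative only.

```go
opts := &ExtendedRequestOptions{CursorForward: "42"}
opts.Sort = []common.SortOption{
	{Column: "created_at", Direction: "desc"},
	{Column: "id", Direction: "asc"},
}

// nil modelColumns skips column validation; nil expandJoins means no joined sorts.
cursorSQL, err := opts.GetCursorFilter("post", "id", nil, nil)
if err == nil {
	// cursorSQL is an EXISTS (...) subquery that pins the cursor row via
	// cursor_select.id = 42 and compares it against the outer post row.
	// It can be passed straight to query.Where(cursorSQL).
	_ = cursorSQL
}
```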
// ------------------------------------------------------------------------- //
// Helper: get active cursor (forward or backward)
func (opts *ExtendedRequestOptions) getActiveCursor() (id string, direction CursorDirection) {
if opts.CursorForward != "" {
return opts.CursorForward, CursorForward
}
if opts.CursorBackward != "" {
return opts.CursorBackward, CursorBackward
}
return "", 0
}
// Helper: extract sort columns
func (opts *ExtendedRequestOptions) getSortColumns() []common.SortOption {
if opts.RequestOptions.Sort != nil {
return opts.RequestOptions.Sort
}
return nil
}
// Helper: clean sort field (remove desc, asc, nulls)
func (opts *ExtendedRequestOptions) cleanSortField(field string) string {
f := strings.ToLower(field)
for _, token := range []string{"desc", "asc", "nulls last", "nulls first"} {
f = strings.ReplaceAll(f, token, "")
}
return strings.TrimSpace(f)
}
// Helper: resolve column (main, JSON, CQL, join)
func (opts *ExtendedRequestOptions) resolveColumn(
field, prefix, tableName string,
modelColumns []string,
) (cursorCol, targetCol string, isJoin bool, err error) {
// JSON field
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, false, nil
}
// CQL via ComputedQL
if strings.Contains(strings.ToLower(field), "cql") && opts.ComputedQL != nil {
if expr, ok := opts.ComputedQL[field]; ok {
return "cursor_select." + expr, expr, false, nil
}
}
// Main table column
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
// No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Joined column
if prefix != "" && prefix != tableName {
return "", "", true, nil
}
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
// ------------------------------------------------------------------------- //
// Helper: rewrite JOIN clause for cursor subquery
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
// ------------------------------------------------------------------------- //
// Helper: build OR-AND priority chain
func buildPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}
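
For reference, the OR-AND chain produced by buildPriorityChain for two sort clauses looks roughly like this (whitespace simplified):

```go
clauses := []string{
	"cursor_select.created_at > post.created_at",
	"cursor_select.id < post.id",
}
sql := buildPriorityChain(clauses)
// sql (whitespace simplified):
//   (cursor_select.created_at > post.created_at)
//   OR (cursor_select.created_at > post.created_at
//       AND cursor_select.id < post.id)
_ = sql
```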

View File

@@ -6,10 +6,12 @@ import (
"fmt"
"net/http"
"reflect"
"runtime/debug"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Handler handles API requests using database and model abstractions
@@ -17,6 +19,7 @@ import (
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
}
// NewHandler creates a new API handler with database and registry abstractions
@@ -24,12 +27,33 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
return &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
}
}
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
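
A hypothetical hook registration against the registry returned by Hooks(). Only Execute and the hook/context types are visible in this diff, so the Register method name and signature are assumptions.

```go
// handler is a *Handler created with NewHandler(db, registry).
// The Register method name below is an assumption, not shown in this diff.
handler.Hooks().Register(BeforeRead, func(hctx *HookContext) error {
	// Returning an error aborts the request with a hook_error response.
	if hctx.Entity == "audit_log" {
		return fmt.Errorf("entity %q is not readable via the API", hctx.Entity)
	}
	return nil
})
```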
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
}
// Handle processes API requests through router-agnostic interface
// Options are read from HTTP headers instead of request body
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "Handle", err)
}
}()
ctx := context.Background()
schema := params["schema"]
@@ -52,12 +76,36 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
return
}
// Validate that the model is a struct type (not a slice or pointer to slice)
modelType := reflect.TypeOf(model)
originalType := modelType
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
// Add request-scoped data to context
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
options = filterExtendedOptions(validator, options)
switch method {
case "GET":
if id != "" {
@@ -107,6 +155,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// HandleGet processes GET requests for metadata
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "HandleGet", err)
}
}()
schema := params["schema"]
entity := params["entity"]
@@ -126,15 +181,64 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
// parseOptionsFromHeaders is now implemented in headers.go
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleRead", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
modelPtr := GetModelPtr(ctx)
model := GetModel(ctx)
// Execute BeforeRead hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
ID: id,
Writer: w,
}
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
logger.Error("BeforeRead hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
return
}
// Create a pointer to a slice of pointers to the model type for query results
modelPtr := reflect.New(reflect.SliceOf(reflect.PointerTo(modelType))).Interface()
logger.Info("Reading records from %s.%s", schema, entity)
// Start with Model() using the slice pointer to avoid "Model(nil)" errors in Count()
// Bun's Model() accepts both single pointers and slice pointers
query := h.db.NewSelect().Model(modelPtr)
query = query.Table(tableName)
// Only set Table() if the model doesn't provide a table name via the underlying type
// Create a temporary instance to check for TableNameProvider
tempInstance := reflect.New(modelType).Interface()
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
// Apply column selection
if len(options.Columns) > 0 {
@@ -163,10 +267,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// This may need to be handled differently per database adapter
}
// Apply filters
for _, filter := range options.Filters {
logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
query = h.applyFilter(query, filter)
// Apply filters - validate and adjust for column types first
for i := range options.Filters {
filter := &options.Filters[i]
// Validate and adjust filter based on column type
castInfo := h.ValidateAndAdjustFilterForColumnType(filter, model)
// Default to AND if LogicOperator is not set
logicOp := filter.LogicOperator
if logicOp == "" {
logicOp = "AND"
}
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
}
// Apply custom SQL WHERE clause (AND condition)
@@ -223,10 +338,41 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Offset(*options.Offset)
}
// Execute query - create a slice of pointers to the model type
model := GetModel(ctx)
resultSlice := reflect.New(reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))).Interface()
if err := query.Scan(ctx, resultSlice); err != nil {
// Apply cursor-based pagination
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
logger.Debug("Applying cursor pagination")
// Get primary key name
pkName := database.GetPrimaryKeyName(model)
// Extract model columns for validation using the generic database function
modelColumns := database.GetModelColumns(model)
// Build expand joins map (if needed in future)
var expandJoins map[string]string
if len(options.Expand) > 0 {
expandJoins = make(map[string]string)
// TODO: Build actual JOIN SQL for each expand relation
// For now, pass empty map as joins are handled via Preload
}
// Get cursor filter SQL
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
return
}
// Apply cursor filter to query
if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter)
query = query.Where(cursorFilter)
}
}
// Execute query - modelPtr was already created earlier
if err := query.Scan(ctx, modelPtr); err != nil {
logger.Error("Error executing query: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
return
@@ -248,10 +394,27 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset,
}
h.sendFormattedResponse(w, resultSlice, metadata, options)
// Execute AfterRead hooks
hookCtx.Result = modelPtr
hookCtx.Error = nil
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
logger.Error("AfterRead hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendFormattedResponse(w, modelPtr, metadata, options)
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleCreate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
@@ -259,6 +422,28 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
logger.Info("Creating record in %s.%s", schema, entity)
// Execute BeforeCreate hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
Data: data,
Writer: w,
}
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
logger.Error("BeforeCreate hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified data from hook context
data = hookCtx.Data
// Handle batch creation
dataValue := reflect.ValueOf(data)
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
@@ -293,6 +478,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
// Execute AfterCreate hooks for batch creation
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
return
}
@@ -318,16 +513,57 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
// Execute AfterCreate hooks for single record creation
hookCtx.Result = modelValue
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, modelValue, nil)
}
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleUpdate", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
model := GetModel(ctx)
logger.Info("Updating record in %s.%s", schema, entity)
// Execute BeforeUpdate hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
ID: id,
Data: data,
Writer: w,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
logger.Error("BeforeUpdate hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified data from hook context
data = hookCtx.Data
// Convert data to map
dataMap, ok := data.(map[string]interface{})
if !ok {
@@ -363,18 +599,55 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
return
}
h.sendResponse(w, map[string]interface{}{
// Execute AfterUpdate hooks
responseData := map[string]interface{}{
"updated": result.RowsAffected(),
}, nil)
}
hookCtx.Result = responseData
hookCtx.Error = nil
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
logger.Error("AfterUpdate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, responseData, nil)
}
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleDelete", err)
}
}()
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
model := GetModel(ctx)
logger.Info("Deleting record from %s.%s", schema, entity)
// Execute BeforeDelete hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
ID: id,
Writer: w,
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Error("BeforeDelete hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
query := h.db.NewDelete().Table(tableName)
if id == "" {
@@ -391,85 +664,194 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
h.sendResponse(w, map[string]interface{}{
// Execute AfterDelete hooks
responseData := map[string]interface{}{
"deleted": result.RowsAffected(),
}, nil)
}
hookCtx.Result = responseData
hookCtx.Error = nil
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
logger.Error("AfterDelete hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, responseData, nil)
}
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
// qualifyColumnName ensures column name is fully qualified with table name if not already
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
// Check if column already has a table/schema prefix (contains a dot)
if strings.Contains(columnName, ".") {
return columnName
}
// If no table name provided, return column as-is
if fullTableName == "" {
return columnName
}
// Extract just the table name from "schema.table" format
// Only use the table name part, not the schema
tableOnly := fullTableName
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
tableOnly = fullTableName[idx+1:]
}
// Return column qualified with just the table name
return fmt.Sprintf("%s.%s", tableOnly, columnName)
}
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption, tableName string, needsCast bool, logicOp string) common.SelectQuery {
// Qualify the column name with table name if not already qualified
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
// Apply casting to text if needed for non-numeric columns or non-numeric values
if needsCast {
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
}
// Helper function to apply the correct Where method based on logic operator
applyWhere := func(condition string, args ...interface{}) common.SelectQuery {
if logicOp == "OR" {
return query.WhereOr(condition, args...)
}
return query.Where(condition, args...)
}
switch strings.ToLower(filter.Operator) {
case "eq", "equals":
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s = ?", qualifiedColumn), filter.Value)
case "neq", "not_equals", "ne":
return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s != ?", qualifiedColumn), filter.Value)
case "gt", "greater_than":
return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s > ?", qualifiedColumn), filter.Value)
case "gte", "greater_than_equals", "ge":
return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s >= ?", qualifiedColumn), filter.Value)
case "lt", "less_than":
return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s < ?", qualifiedColumn), filter.Value)
case "lte", "less_than_equals", "le":
return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s <= ?", qualifiedColumn), filter.Value)
case "like":
return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s LIKE ?", qualifiedColumn), filter.Value)
case "ilike":
// Use ILIKE for case-insensitive search (PostgreSQL)
// For other databases, cast to citext or use LOWER()
return query.Where(fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column), filter.Value)
// Column is already cast to TEXT if needed
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
case "in":
return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
case "between":
// Handle between operator - exclusive (> val1 AND < val2)
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1])
return applyWhere(fmt.Sprintf("%s > ? AND %s < ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1])
return applyWhere(fmt.Sprintf("%s > ? AND %s < ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
}
logger.Warn("Invalid BETWEEN filter value format")
return query
case "between_inclusive":
// Handle between inclusive operator - inclusive (>= val1 AND <= val2)
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1])
return applyWhere(fmt.Sprintf("%s >= ? AND %s <= ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1])
return applyWhere(fmt.Sprintf("%s >= ? AND %s <= ?", qualifiedColumn, qualifiedColumn), values[0], values[1])
}
logger.Warn("Invalid BETWEEN INCLUSIVE filter value format")
return query
case "is_null", "isnull":
// Check for NULL values
return query.Where(fmt.Sprintf("(%s IS NULL OR %s = '')", filter.Column, filter.Column))
// Check for NULL values - don't use cast for NULL checks
colName := h.qualifyColumnName(filter.Column, tableName)
return applyWhere(fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName))
case "is_not_null", "isnotnull":
// Check for NOT NULL values
return query.Where(fmt.Sprintf("(%s IS NOT NULL AND %s != '')", filter.Column, filter.Column))
// Check for NOT NULL values - don't use cast for NULL checks
colName := h.qualifyColumnName(filter.Column, tableName)
return applyWhere(fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName))
default:
logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator)
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
return applyWhere(fmt.Sprintf("%s = ?", qualifiedColumn), filter.Value)
}
}
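
To make the new signature concrete, here is roughly how a single filter is rendered; `query` and `h` are assumed to be in scope, and the table and column names are illustrative.

```go
filter := common.FilterOption{
	Column:        "account_nr",
	Operator:      "ilike",
	Value:         "%42%",
	LogicOperator: "OR",
}

// qualifyColumnName("account_nr", "public.users") -> "users.account_nr"
// With needsCast=true the generated condition is:
//   CAST(users.account_nr AS TEXT) ILIKE ?    -- bound to "%42%"
// and, because LogicOperator is "OR", it is attached via query.WhereOr.
query = h.applyFilter(query, filter, "public.users", true, filter.LogicOperator)
```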
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
// Check if model implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
tableName := provider.TableName()
if tableName != "" {
return tableName
// parseTableName splits a table name that may contain schema into separate schema and table
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
}
return "", fullTableName
}
// getSchemaAndTable returns the schema and table name separately
// It checks SchemaProvider and TableNameProvider interfaces and handles cases where
// the table name may already include the schema (e.g., "public.users")
//
// Priority order:
// 1. If TableName() contains a schema (e.g., "myschema.mytable"), that schema takes precedence
// 2. If model implements SchemaProvider, use that schema
// 3. Otherwise, use the defaultSchema parameter
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
// First check if model provides a table name
// We check this FIRST because the table name might already contain the schema
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
// IMPORTANT: Check if the table name already contains a schema (e.g., "schema.table")
// This is common when models need to specify a different schema than the default
if tableSchema, tableOnly := h.parseTableName(tableName); tableSchema != "" {
// Table name includes schema - use it and ignore any other schema providers
logger.Debug("TableName() includes schema: %s.%s", tableSchema, tableOnly)
return tableSchema, tableOnly
}
// Table name is just the table name without schema
// Now determine which schema to use
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
return schema, tableName
}
// Default to schema.entity
if schema != "" {
return fmt.Sprintf("%s.%s", schema, entity)
// No TableNameProvider, so check for schema and use entity as table name
if schemaProvider, ok := model.(common.SchemaProvider); ok {
schema = schemaProvider.SchemaName()
} else {
schema = defaultSchema
}
return entity
// Default to entity name as table
return schema, entity
}
// getTableName returns the full table name including schema (schema.table)
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity)
return &common.TableMetadata{
Schema: schema,
Table: h.getTableName(schema, entity, model),
Columns: []common.Column{},
}
}
tableName := h.getTableName(schema, entity, model)
metadata := &common.TableMetadata{
@@ -555,7 +937,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
if options.CleanJSON {
data = h.cleanJSON(data)
}
w.SetHeader("Content-Type", "application/json")
// Format response based on response format option
switch options.ResponseFormat {
case "simple":
@@ -610,3 +992,41 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
w.WriteHeader(statusCode)
w.WriteJSON(response)
}
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
filtered := options
// Filter base RequestOptions
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
// Filter SearchColumns
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
// Filter AdvancedSQL column keys
filteredAdvSQL := make(map[string]string)
for colName, sqlExpr := range options.AdvancedSQL {
if validator.IsValidColumn(colName) {
filteredAdvSQL[colName] = sqlExpr
} else {
logger.Warn("Invalid column in advanced SQL removed: %s", colName)
}
}
filtered.AdvancedSQL = filteredAdvSQL
// ComputedQL columns are allowed to be any name since they're computed
// No filtering needed for ComputedQL keys
filtered.ComputedQL = options.ComputedQL
// Filter Expand columns
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
for _, expand := range options.Expand {
filteredExpand := expand
// Don't validate relation name, only columns
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
filteredExpands = append(filteredExpands, filteredExpand)
}
filtered.Expand = filteredExpands
return filtered
}
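
A small sketch of what the validator keeps and drops, assuming a model that only exposes `id` and `name`; the expressions are illustrative.

```go
type Widget struct {
	ID   int64  `json:"id"`
	Name string `json:"name"`
}

validator := common.NewColumnValidator(Widget{})
opts := ExtendedRequestOptions{
	SearchColumns: []string{"name", "ghost_col"}, // "ghost_col" is dropped
	AdvancedSQL: map[string]string{
		"name":    "UPPER(name)", // kept: key is a real column
		"bad_col": "1=1",         // dropped with a warning
	},
	ComputedQL: map[string]string{
		"cqlTotal": "price * qty", // kept as-is: CQL keys are not validated
	},
}
filtered := filterExtendedOptions(validator, opts)
_ = filtered
```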

View File

@@ -4,11 +4,12 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// ExtendedRequestOptions extends common.RequestOptions with additional features
@@ -19,21 +20,21 @@ type ExtendedRequestOptions struct {
CleanJSON bool
// Advanced filtering
SearchColumns []string
SearchColumns []string
CustomSQLWhere string
CustomSQLOr string
CustomSQLOr string
// Joins
Expand []ExpandOption
// Advanced features
AdvancedSQL map[string]string // Column -> SQL expression
ComputedQL map[string]string // Column -> CQL expression
Distinct bool
SkipCount bool
SkipCache bool
AdvancedSQL map[string]string // Column -> SQL expression
ComputedQL map[string]string // Column -> CQL expression
Distinct bool
SkipCount bool
SkipCache bool
FetchRowNumber *string
PKRow *string
PKRow *string
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
@@ -42,42 +43,58 @@ type ExtendedRequestOptions struct {
AtomicTransaction bool
// Cursor pagination
CursorForward string
CursorForward string
CursorBackward string
}
// ExpandOption represents a relation expansion configuration
type ExpandOption struct {
Relation string
Columns []string
Where string
Sort string
Columns []string
Where string
Sort string
}
// decodeHeaderValue decodes base64 encoded header values
// Supports ZIP_ and __ prefixes for base64 encoding
func decodeHeaderValue(value string) string {
// Check for ZIP_ prefix
if strings.HasPrefix(value, "ZIP_") {
decoded, err := base64.StdEncoding.DecodeString(value[4:])
if err == nil {
return string(decoded)
str, _ := DecodeParam(value)
return str
}
// DecodeParam - Decodes parameter string and returns unencoded string
func DecodeParam(pStr string) (string, error) {
var code string = pStr
if strings.HasPrefix(pStr, "ZIP_") {
code = strings.ReplaceAll(pStr, "ZIP_", "")
code = strings.ReplaceAll(code, "\n", "")
code = strings.ReplaceAll(code, "\r", "")
code = strings.ReplaceAll(code, " ", "")
strDat, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return code, fmt.Errorf("failed to read parameter base64: %v", err)
} else {
code = string(strDat)
}
} else if strings.HasPrefix(pStr, "__") {
code = strings.ReplaceAll(pStr, "__", "")
code = strings.ReplaceAll(code, "\n", "")
code = strings.ReplaceAll(code, "\r", "")
code = strings.ReplaceAll(code, " ", "")
strDat, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return code, fmt.Errorf("failed to read parameter base64: %v", err)
} else {
code = string(strDat)
}
logger.Warn("Failed to decode ZIP_ prefixed value: %v", err)
return value
}
// Check for __ prefix
if strings.HasPrefix(value, "__") {
decoded, err := base64.StdEncoding.DecodeString(value[2:])
if err == nil {
return string(decoded)
}
logger.Warn("Failed to decode __ prefixed value: %v", err)
return value
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
code, _ = DecodeParam(code)
}
return value
return code, nil
}
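
A minimal round-trip sketch for the ZIP_ convention handled above:

```go
raw := `{"status":"active"}`

// Clients send long or non-ASCII header values base64 encoded with a ZIP_
// (or __) prefix; DecodeParam strips the prefix and whitespace, then decodes.
encoded := "ZIP_" + base64.StdEncoding.EncodeToString([]byte(raw))

decoded, err := DecodeParam(encoded)
// err == nil, decoded == `{"status":"active"}`
_, _ = decoded, err
```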
// parseOptionsFromHeaders parses all request options from HTTP headers
@@ -85,12 +102,13 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Filters: make([]common.FilterOption, 0),
Sort: make([]common.SortOption, 0),
Sort: make([]common.SortOption, 0),
Preload: make([]common.PreloadOption, 0),
},
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
ResponseFormat: "simple", // Default response format
}
// Get all headers
@@ -198,6 +216,9 @@ func (h *Handler) parseSelectFields(options *ExtendedRequestOptions, value strin
return
}
options.Columns = h.parseCommaSeparated(value)
if len(options.Columns) > 1 {
options.CleanJSON = true
}
}
// parseNotSelectFields parses x-not-select-fields header
@@ -206,15 +227,19 @@ func (h *Handler) parseNotSelectFields(options *ExtendedRequestOptions, value st
return
}
options.OmitColumns = h.parseCommaSeparated(value)
if len(options.OmitColumns) > 1 {
options.CleanJSON = true
}
}
// parseFieldFilter parses x-fieldfilter-{colname} header (exact match)
func (h *Handler) parseFieldFilter(options *ExtendedRequestOptions, headerKey, value string) {
colName := strings.TrimPrefix(headerKey, "x-fieldfilter-")
options.Filters = append(options.Filters, common.FilterOption{
Column: colName,
Operator: "eq",
Value: value,
Column: colName,
Operator: "eq",
Value: value,
LogicOperator: "AND", // Default to AND
})
}
@@ -223,9 +248,10 @@ func (h *Handler) parseSearchFilter(options *ExtendedRequestOptions, headerKey,
colName := strings.TrimPrefix(headerKey, "x-searchfilter-")
// Use ILIKE for fuzzy search
options.Filters = append(options.Filters, common.FilterOption{
Column: colName,
Operator: "ilike",
Value: "%" + value + "%",
Column: colName,
Operator: "ilike",
Value: "%" + value + "%",
LogicOperator: "AND", // Default to AND
})
}
@@ -254,70 +280,68 @@ func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu
colName := parts[1]
// Map operator names to filter operators
filterOp := h.mapSearchOperator(operator, value)
filterOp := h.mapSearchOperator(colName, operator, value)
// Set the logic operator (AND or OR)
filterOp.LogicOperator = logicOp
options.Filters = append(options.Filters, filterOp)
// Note: OR logic would need special handling in query builder
// For now, we'll add a comment to indicate OR logic
if logicOp == "OR" {
// TODO: Implement OR logic in query builder
logger.Debug("OR logic filter: %s %s %v", colName, filterOp.Operator, filterOp.Value)
}
logger.Debug("%s logic filter: %s %s %v", logicOp, colName, filterOp.Operator, filterOp.Value)
}
// mapSearchOperator maps search operator names to filter operators
func (h *Handler) mapSearchOperator(operator, value string) common.FilterOption {
func (h *Handler) mapSearchOperator(colName, operator, value string) common.FilterOption {
operator = strings.ToLower(operator)
switch operator {
case "contains":
return common.FilterOption{Operator: "ilike", Value: "%" + value + "%"}
case "contains", "contain", "like":
return common.FilterOption{Column: colName, Operator: "ilike", Value: "%" + value + "%"}
case "beginswith", "startswith":
return common.FilterOption{Operator: "ilike", Value: value + "%"}
return common.FilterOption{Column: colName, Operator: "ilike", Value: value + "%"}
case "endswith":
return common.FilterOption{Operator: "ilike", Value: "%" + value}
case "equals", "eq":
return common.FilterOption{Operator: "eq", Value: value}
case "notequals", "neq", "ne":
return common.FilterOption{Operator: "neq", Value: value}
case "greaterthan", "gt":
return common.FilterOption{Operator: "gt", Value: value}
case "lessthan", "lt":
return common.FilterOption{Operator: "lt", Value: value}
case "greaterthanorequal", "gte", "ge":
return common.FilterOption{Operator: "gte", Value: value}
case "lessthanorequal", "lte", "le":
return common.FilterOption{Operator: "lte", Value: value}
return common.FilterOption{Column: colName, Operator: "ilike", Value: "%" + value}
case "equals", "eq", "=":
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
case "notequals", "neq", "ne", "!=", "<>":
return common.FilterOption{Column: colName, Operator: "neq", Value: value}
case "greaterthan", "gt", ">":
return common.FilterOption{Column: colName, Operator: "gt", Value: value}
case "lessthan", "lt", "<":
return common.FilterOption{Column: colName, Operator: "lt", Value: value}
case "greaterthanorequal", "gte", "ge", ">=":
return common.FilterOption{Column: colName, Operator: "gte", Value: value}
case "lessthanorequal", "lte", "le", "<=":
return common.FilterOption{Column: colName, Operator: "lte", Value: value}
case "between":
// Parse between values (format: "value1,value2")
// Between is exclusive (> value1 AND < value2)
parts := strings.Split(value, ",")
if len(parts) == 2 {
return common.FilterOption{Operator: "between", Value: parts}
return common.FilterOption{Column: colName, Operator: "between", Value: parts}
}
return common.FilterOption{Operator: "eq", Value: value}
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
case "betweeninclusive":
// Parse between values (format: "value1,value2")
// Between inclusive is >= value1 AND <= value2
parts := strings.Split(value, ",")
if len(parts) == 2 {
return common.FilterOption{Operator: "between_inclusive", Value: parts}
return common.FilterOption{Column: colName, Operator: "between_inclusive", Value: parts}
}
return common.FilterOption{Operator: "eq", Value: value}
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
case "in":
// Parse IN values (format: "value1,value2,value3")
values := strings.Split(value, ",")
return common.FilterOption{Operator: "in", Value: values}
return common.FilterOption{Column: colName, Operator: "in", Value: values}
case "empty", "isnull", "null":
// Check for NULL or empty string
return common.FilterOption{Operator: "is_null", Value: nil}
return common.FilterOption{Column: colName, Operator: "is_null", Value: nil}
case "notempty", "isnotnull", "notnull":
// Check for NOT NULL
return common.FilterOption{Operator: "is_not_null", Value: nil}
return common.FilterOption{Column: colName, Operator: "is_not_null", Value: nil}
default:
logger.Warn("Unknown search operator: %s, defaulting to equals", operator)
return common.FilterOption{Operator: "eq", Value: value}
return common.FilterOption{Column: colName, Operator: "eq", Value: value}
}
}
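
With the column name now carried on the returned FilterOption, the mapping behaves like this (illustrative calls on a handler `h`):

```go
f1 := h.mapSearchOperator("age", ">=", "21")
// f1 == common.FilterOption{Column: "age", Operator: "gte", Value: "21"}

f2 := h.mapSearchOperator("name", "contains", "smith")
// f2 == common.FilterOption{Column: "name", Operator: "ilike", Value: "%smith%"}

f3 := h.mapSearchOperator("created_at", "between", "2025-01-01,2025-12-31")
// f3.Operator == "between", f3.Value == []string{"2025-01-01", "2025-12-31"}
```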
@@ -404,10 +428,16 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
} else if strings.HasPrefix(field, "+") {
direction = "ASC"
colName = strings.TrimPrefix(field, "+")
} else if strings.HasSuffix(field, " desc") {
direction = "DESC"
colName = strings.TrimSuffix(field, "desc")
} else if strings.HasSuffix(field, " asc") {
direction = "ASC"
colName = strings.TrimSuffix(field, "asc")
}
options.Sort = append(options.Sort, common.SortOption{
Column: colName,
Column: strings.Trim(colName, " "),
Direction: direction,
})
}
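
The sort syntaxes accepted after this change, shown as an input header value and the SortOptions it produces (the header name itself is not shown in this hunk):

```go
value := "name desc, -created_at, +id, email asc"

// parseSorting now yields (column names are trimmed of surrounding spaces):
//   {Column: "name",       Direction: "DESC"}
//   {Column: "created_at", Direction: "DESC"}
//   {Column: "id",         Direction: "ASC"}
//   {Column: "email",      Direction: "ASC"}
_ = value
```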
@@ -439,3 +469,235 @@ func (h *Handler) parseJSONHeader(value string) (map[string]interface{}, error)
}
return result, nil
}
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return reflect.Invalid
}
// Find the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == colName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, colName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := toSnakeCase(field.Name)
if snakeCaseName == colName {
return field.Type.Kind()
}
}
return reflect.Invalid
}
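
How the lookup above resolves a column name, shown against a hypothetical model:

```go
type Account struct {
	AccountID int64  `json:"account_id"`
	FullName  string `json:"full_name,omitempty"`
}

// getColumnTypeFromModel(Account{}, "account_id") -> reflect.Int64  (json tag match)
// getColumnTypeFromModel(Account{}, "FullName")   -> reflect.String (field-name match)
// getColumnTypeFromModel(Account{}, "full_name")  -> reflect.String (json tag, before the snake_case fallback)
// getColumnTypeFromModel(Account{}, "missing")    -> reflect.Invalid
```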
// toSnakeCase converts a string from CamelCase to snake_case
func toSnakeCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// isNumericType checks if a reflect.Kind is a numeric type
func isNumericType(kind reflect.Kind) bool {
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
}
// isStringType checks if a reflect.Kind is a string type
func isStringType(kind reflect.Kind) bool {
return kind == reflect.String
}
// isBoolType checks if a reflect.Kind is a boolean type
func isBoolType(kind reflect.Kind) bool {
return kind == reflect.Bool
}
// convertToNumericType converts a string value to the appropriate numeric type
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
value = strings.TrimSpace(value)
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Parse as integer
bitSize := 64
switch kind {
case reflect.Int8:
bitSize = 8
case reflect.Int16:
bitSize = 16
case reflect.Int32:
bitSize = 32
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Int:
return int(intVal), nil
case reflect.Int8:
return int8(intVal), nil
case reflect.Int16:
return int16(intVal), nil
case reflect.Int32:
return int32(intVal), nil
case reflect.Int64:
return intVal, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as unsigned integer
bitSize := 64
switch kind {
case reflect.Uint8:
bitSize = 8
case reflect.Uint16:
bitSize = 16
case reflect.Uint32:
bitSize = 32
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Uint:
return uint(uintVal), nil
case reflect.Uint8:
return uint8(uintVal), nil
case reflect.Uint16:
return uint16(uintVal), nil
case reflect.Uint32:
return uint32(uintVal), nil
case reflect.Uint64:
return uintVal, nil
}
case reflect.Float32, reflect.Float64:
// Parse as float
bitSize := 64
if kind == reflect.Float32 {
bitSize = 32
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid float value: %w", err)
}
if kind == reflect.Float32 {
return float32(floatVal), nil
}
return floatVal, nil
}
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
}
// isNumericValue checks if a string value can be parsed as a number
func isNumericValue(value string) bool {
value = strings.TrimSpace(value)
_, err := strconv.ParseFloat(value, 64)
return err == nil
}
// ColumnCastInfo holds information about whether a column needs casting
type ColumnCastInfo struct {
NeedsCast bool
IsNumericType bool
}
// ValidateAndAdjustFilterForColumnType validates and adjusts a filter based on column type
// Returns ColumnCastInfo indicating whether the column should be cast to text in SQL
func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOption, model interface{}) ColumnCastInfo {
if filter == nil || model == nil {
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
}
colType := h.getColumnTypeFromModel(model, filter.Column)
if colType == reflect.Invalid {
// Column not found in model, no casting needed
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
}
// Check if the input value is numeric
valueIsNumeric := false
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
valueIsNumeric = isNumericValue(strVal)
}
// Adjust based on column type
switch {
case isNumericType(colType):
// Column is numeric
if valueIsNumeric {
// Value is numeric - try to convert it
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
numericVal, err := convertToNumericType(strVal, colType)
if err != nil {
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
}
filter.Value = numericVal
}
// No cast needed - numeric column with numeric value
return ColumnCastInfo{NeedsCast: false, IsNumericType: true}
} else {
// Value is not numeric - cast column to text for comparison
logger.Debug("Non-numeric value for numeric column %s, will cast to text", filter.Column)
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
}
case isStringType(colType):
// String columns don't need casting
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
default:
// For bool, time.Time, and other complex types - cast to text
logger.Debug("Complex type column %s, will cast to text", filter.Column)
return ColumnCastInfo{NeedsCast: true, IsNumericType: false}
}
}
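
For illustration, a hypothetical snippet (the `Task` model and `exampleFilterAdjustment` helper are not part of the codebase) showing the intended flow: a string value aimed at a numeric column is converted in place, so no text cast is needed.

```go
package restheadspec

import "github.com/bitechdev/ResolveSpec/pkg/common"

// Task is a hypothetical model used only for this illustration.
type Task struct {
	ID   int    `json:"id"`
	Name string `json:"name"`
}

// exampleFilterAdjustment shows the expected outcome: "42" targeting the int
// column "id" is converted to int(42) and no cast to text is required.
func exampleFilterAdjustment(h *Handler) {
	f := &common.FilterOption{Column: "id", Operator: "eq", Value: "42"}
	info := h.ValidateAndAdjustFilterForColumnType(f, &Task{})
	_ = info // info.NeedsCast == false, info.IsNumericType == true, f.Value == 42
}
```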

140
pkg/restheadspec/hooks.go Normal file
View File

@@ -0,0 +1,140 @@
package restheadspec
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
// Create operation hooks
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
// Update operation hooks
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
// Delete operation hooks
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database, registry, etc.
Schema string
Entity string
TableName string
Model interface{}
Options ExtendedRequestOptions
// Operation-specific fields
ID string
Data interface{} // For create/update operations
Result interface{} // For after hooks
Error error // For after hooks
QueryFilter string // For read operations
// Response writer - allows hooks to modify response
Writer common.ResponseWriter
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
logger.Debug("No hooks registered for %s", hookType)
return nil
}
logger.Debug("Executing %d hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
}
logger.Debug("All hooks for %s executed successfully", hookType)
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@@ -0,0 +1,197 @@
package restheadspec
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// This file contains example implementations showing how to use hooks
// These are just examples - you can implement hooks as needed for your application
// ExampleLoggingHook logs before and after operations
func ExampleLoggingHook(hookType HookType) HookFunc {
return func(ctx *HookContext) error {
logger.Info("[%s] Operation: %s.%s, ID: %s", hookType, ctx.Schema, ctx.Entity, ctx.ID)
if ctx.Data != nil {
logger.Debug("[%s] Data: %+v", hookType, ctx.Data)
}
if ctx.Result != nil {
logger.Debug("[%s] Result: %+v", hookType, ctx.Result)
}
return nil
}
}
// ExampleValidationHook validates data before create/update operations
func ExampleValidationHook(ctx *HookContext) error {
// Example: Ensure certain fields are present
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
// Check for required fields
requiredFields := []string{"name"} // Add your required fields here
for _, field := range requiredFields {
if _, exists := dataMap[field]; !exists {
return fmt.Errorf("required field missing: %s", field)
}
}
}
return nil
}
// ExampleAuthorizationHook checks if the user has permission to perform the operation
func ExampleAuthorizationHook(ctx *HookContext) error {
// Example: Check user permissions from context
// userID, ok := ctx.Context.Value("user_id").(string)
// if !ok {
// return fmt.Errorf("unauthorized: no user in context")
// }
// You can access the handler's database or registry if needed
// For example, to check permissions in the database:
// query := ctx.Handler.db.NewSelect().Table("permissions")...
// Add your authorization logic here
logger.Debug("Authorization check for %s.%s", ctx.Schema, ctx.Entity)
return nil
}
// ExampleDataTransformHook modifies data before create/update
func ExampleDataTransformHook(ctx *HookContext) error {
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
// Example: Add a timestamp or user ID
// dataMap["updated_at"] = time.Now()
// dataMap["updated_by"] = ctx.Context.Value("user_id")
// Update the context with modified data
ctx.Data = dataMap
logger.Debug("Data transformed for %s.%s", ctx.Schema, ctx.Entity)
}
return nil
}
// ExampleAuditLogHook creates audit log entries for operations
func ExampleAuditLogHook(hookType HookType) HookFunc {
return func(ctx *HookContext) error {
// Example: Log to audit system
auditEntry := map[string]interface{}{
"operation": hookType,
"schema": ctx.Schema,
"entity": ctx.Entity,
"table_name": ctx.TableName,
"id": ctx.ID,
}
if ctx.Error != nil {
auditEntry["error"] = ctx.Error.Error()
}
logger.Info("Audit log: %+v", auditEntry)
// In a real application, you would save this to a database using the handler
// Example:
// query := ctx.Handler.db.NewInsert().Table("audit_logs").Model(&auditEntry)
// if _, err := query.Exec(ctx.Context); err != nil {
// logger.Error("Failed to save audit log: %v", err)
// }
return nil
}
}
// ExampleCacheInvalidationHook invalidates cache after create/update/delete
func ExampleCacheInvalidationHook(ctx *HookContext) error {
// Example: Invalidate cache for the entity
cacheKey := fmt.Sprintf("%s.%s", ctx.Schema, ctx.Entity)
logger.Info("Invalidating cache for: %s", cacheKey)
// Add your cache invalidation logic here
// cache.Delete(cacheKey)
return nil
}
// ExampleFilterSensitiveDataHook removes sensitive data from responses
func ExampleFilterSensitiveDataHook(ctx *HookContext) error {
// Example: Remove password fields from results
// This would be called in AfterRead hooks
logger.Debug("Filtering sensitive data for %s.%s", ctx.Schema, ctx.Entity)
// Add your data filtering logic here
// You would iterate through ctx.Result and remove sensitive fields
return nil
}
// ExampleRelatedDataHook fetches related data using the handler's database
func ExampleRelatedDataHook(ctx *HookContext) error {
// Example: Fetch related data after reading the main entity
// This hook demonstrates using ctx.Handler to access the database
if ctx.Entity == "users" && ctx.Result != nil {
// Example: Fetch user's recent activity
// userID := ... extract from ctx.Result
// Use the handler's database to query related data
// query := ctx.Handler.db.NewSelect().Table("user_activity").Where("user_id = ?", userID)
// var activities []Activity
// if err := query.Scan(ctx.Context, &activities); err != nil {
// logger.Error("Failed to fetch user activities: %v", err)
// return err
// }
// Optionally modify the result to include the related data
// if resultMap, ok := ctx.Result.(map[string]interface{}); ok {
// resultMap["recent_activities"] = activities
// }
logger.Debug("Fetched related data for user entity")
}
return nil
}
// SetupExampleHooks demonstrates how to register hooks on a handler
func SetupExampleHooks(handler *Handler) {
hooks := handler.Hooks()
// Register logging hooks for all operations
hooks.Register(BeforeRead, ExampleLoggingHook(BeforeRead))
hooks.Register(AfterRead, ExampleLoggingHook(AfterRead))
hooks.Register(BeforeCreate, ExampleLoggingHook(BeforeCreate))
hooks.Register(AfterCreate, ExampleLoggingHook(AfterCreate))
hooks.Register(BeforeUpdate, ExampleLoggingHook(BeforeUpdate))
hooks.Register(AfterUpdate, ExampleLoggingHook(AfterUpdate))
hooks.Register(BeforeDelete, ExampleLoggingHook(BeforeDelete))
hooks.Register(AfterDelete, ExampleLoggingHook(AfterDelete))
// Register validation hooks for create/update
hooks.Register(BeforeCreate, ExampleValidationHook)
hooks.Register(BeforeUpdate, ExampleValidationHook)
// Register authorization hooks for all operations
hooks.RegisterMultiple([]HookType{
BeforeRead, BeforeCreate, BeforeUpdate, BeforeDelete,
}, ExampleAuthorizationHook)
// Register data transform hook for create/update
hooks.Register(BeforeCreate, ExampleDataTransformHook)
hooks.Register(BeforeUpdate, ExampleDataTransformHook)
// Register audit log hooks for after operations
hooks.Register(AfterCreate, ExampleAuditLogHook(AfterCreate))
hooks.Register(AfterUpdate, ExampleAuditLogHook(AfterUpdate))
hooks.Register(AfterDelete, ExampleAuditLogHook(AfterDelete))
// Register cache invalidation for after operations
hooks.Register(AfterCreate, ExampleCacheInvalidationHook)
hooks.Register(AfterUpdate, ExampleCacheInvalidationHook)
hooks.Register(AfterDelete, ExampleCacheInvalidationHook)
// Register sensitive data filtering for read operations
hooks.Register(AfterRead, ExampleFilterSensitiveDataHook)
// Register related data fetching for read operations
hooks.Register(AfterRead, ExampleRelatedDataHook)
logger.Info("Example hooks registered successfully")
}

View File

@@ -0,0 +1,347 @@
package restheadspec
import (
"context"
"fmt"
"testing"
)
// TestHookRegistry tests the hook registry functionality
func TestHookRegistry(t *testing.T) {
registry := NewHookRegistry()
// Test registering a hook
called := false
hook := func(ctx *HookContext) error {
called = true
return nil
}
registry.Register(BeforeRead, hook)
if registry.Count(BeforeRead) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead))
}
// Test executing a hook
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !called {
t.Error("Hook was not called")
}
}
// TestHookExecutionOrder tests hook execution order
func TestHookExecutionOrder(t *testing.T) {
registry := NewHookRegistry()
order := []int{}
hook1 := func(ctx *HookContext) error {
order = append(order, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
order = append(order, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
order = append(order, 3)
return nil
}
registry.Register(BeforeCreate, hook1)
registry.Register(BeforeCreate, hook2)
registry.Register(BeforeCreate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if len(order) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(order))
}
if order[0] != 1 || order[1] != 2 || order[2] != 3 {
t.Errorf("Hooks executed in wrong order: %v", order)
}
}
// TestHookError tests hook error handling
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
executed := []string{}
hook1 := func(ctx *HookContext) error {
executed = append(executed, "hook1")
return nil
}
hook2 := func(ctx *HookContext) error {
executed = append(executed, "hook2")
return fmt.Errorf("hook2 error")
}
hook3 := func(ctx *HookContext) error {
executed = append(executed, "hook3")
return nil
}
registry.Register(BeforeUpdate, hook1)
registry.Register(BeforeUpdate, hook2)
registry.Register(BeforeUpdate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeUpdate, ctx)
if err == nil {
t.Error("Expected error from hook execution")
}
if len(executed) != 2 {
t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed))
}
if executed[0] != "hook1" || executed[1] != "hook2" {
t.Errorf("Unexpected execution order: %v", executed)
}
}
// TestHookDataModification tests modifying data in hooks
func TestHookDataModification(t *testing.T) {
registry := NewHookRegistry()
modifyHook := func(ctx *HookContext) error {
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
dataMap["modified"] = true
ctx.Data = dataMap
}
return nil
}
registry.Register(BeforeCreate, modifyHook)
data := map[string]interface{}{
"name": "test",
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
Data: data,
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
modifiedData := ctx.Data.(map[string]interface{})
if !modifiedData["modified"].(bool) {
t.Error("Data was not modified by hook")
}
}
// TestRegisterMultiple tests registering a hook for multiple types
func TestRegisterMultiple(t *testing.T) {
registry := NewHookRegistry()
called := 0
hook := func(ctx *HookContext) error {
called++
return nil
}
registry.RegisterMultiple([]HookType{
BeforeRead,
BeforeCreate,
BeforeUpdate,
}, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered for BeforeRead")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Hook not registered for BeforeCreate")
}
if registry.Count(BeforeUpdate) != 1 {
t.Error("Hook not registered for BeforeUpdate")
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
registry.Execute(BeforeRead, ctx)
registry.Execute(BeforeCreate, ctx)
registry.Execute(BeforeUpdate, ctx)
if called != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", called)
}
}
// TestClearHooks tests clearing hooks
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered")
}
registry.Clear(BeforeRead)
if registry.Count(BeforeRead) != 0 {
t.Error("Hook not cleared")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Wrong hook was cleared")
}
}
// TestClearAllHooks tests clearing all hooks
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(BeforeUpdate, hook)
registry.ClearAll()
if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 {
t.Error("Not all hooks were cleared")
}
}
// TestHasHooks tests checking if hooks exist
func TestHasHooks(t *testing.T) {
registry := NewHookRegistry()
if registry.HasHooks(BeforeRead) {
t.Error("Should not have hooks initially")
}
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
if !registry.HasHooks(BeforeRead) {
t.Error("Should have hooks after registration")
}
}
// TestGetAllHookTypes tests getting all registered hook types
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(AfterUpdate, hook)
types := registry.GetAllHookTypes()
if len(types) != 3 {
t.Errorf("Expected 3 hook types, got %d", len(types))
}
// Verify all expected types are present
expectedTypes := map[HookType]bool{
BeforeRead: true,
BeforeCreate: true,
AfterUpdate: true,
}
for _, hookType := range types {
if !expectedTypes[hookType] {
t.Errorf("Unexpected hook type: %s", hookType)
}
}
}
// TestHookContextHandler tests that hooks can access the handler
func TestHookContextHandler(t *testing.T) {
registry := NewHookRegistry()
var capturedHandler *Handler
hook := func(ctx *HookContext) error {
// Verify that the handler is accessible from the context
if ctx.Handler == nil {
return fmt.Errorf("handler is nil in hook context")
}
capturedHandler = ctx.Handler
return nil
}
registry.Register(BeforeRead, hook)
// Create a mock handler
handler := &Handler{
hooks: registry,
}
ctx := &HookContext{
Context: context.Background(),
Handler: handler,
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if capturedHandler == nil {
t.Error("Handler was not captured from hook context")
}
if capturedHandler != handler {
t.Error("Captured handler does not match original handler")
}
}

View File

@@ -55,9 +55,9 @@ package restheadspec
import (
"net/http"
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database"
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/router"
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
"github.com/uptrace/bunrouter"

View File

@@ -3,7 +3,7 @@ package testmodels
import (
"time"
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// Department represents a company department

View File

@@ -1,71 +1,70 @@
{
"name": "@warkypublic/resolvespec-js",
"version": "1.0.0",
"description": "Client side library for the ResolveSpec API",
"type": "module",
"main": "./src/index.ts",
"module": "./src/index.ts",
"types": "./src/index.ts",
"publishConfig": {
"access": "public",
"main": "./dist/index.js",
"module": "./dist/index.js",
"types": "./dist/index.d.ts"
},
"files": [
"dist",
"bin",
"README.md"
],
"scripts": {
"build": "vite build",
"clean": "rm -rf dist",
"prepublishOnly": "npm run build",
"test": "vitest run",
"lint": "eslint src"
},
"keywords": [
"string",
"blob",
"dependencies",
"workspace",
"package",
"cli",
"tools",
"npm",
"yarn",
"pnpm"
],
"author": "Hein (Warkanum) Puth",
"license": "MIT",
"dependencies": {
"semver": "^7.6.3",
"uuid": "^11.0.3"
},
"devDependencies": {
"@changesets/cli": "^2.27.10",
"@eslint/js": "^9.16.0",
"@types/jsdom": "^21.1.7",
"eslint": "^9.16.0",
"globals": "^15.13.0",
"jsdom": "^25.0.1",
"typescript": "^5.7.2",
"typescript-eslint": "^8.17.0",
"vite": "^6.0.2",
"vite-plugin-dts": "^4.3.0",
"vitest": "^2.1.8"
},
"engines": {
"node": ">=14.16"
},
"repository": {
"type": "git",
"url": "git+https://github.com/Warky-Devs/ResolveSpec"
},
"bugs": {
"url": "https://github.com/Warky-Devs/ResolveSpec/issues"
},
"homepage": "https://github.com/Warky-Devs/ResolveSpec#readme",
"packageManager": "pnpm@9.6.0+sha512.38dc6fba8dba35b39340b9700112c2fe1e12f10b17134715a4aa98ccf7bb035e76fd981cf0bb384dfa98f8d6af5481c2bef2f4266a24bfa20c34eb7147ce0b5e"
}
"name": "@warkypublic/resolvespec-js",
"version": "1.0.0",
"description": "Client side library for the ResolveSpec API",
"type": "module",
"main": "./src/index.ts",
"module": "./src/index.ts",
"types": "./src/index.ts",
"publishConfig": {
"access": "public",
"main": "./dist/index.js",
"module": "./dist/index.js",
"types": "./dist/index.d.ts"
},
"files": [
"dist",
"bin",
"README.md"
],
"scripts": {
"build": "vite build",
"clean": "rm -rf dist",
"prepublishOnly": "npm run build",
"test": "vitest run",
"lint": "eslint src"
},
"keywords": [
"string",
"blob",
"dependencies",
"workspace",
"package",
"cli",
"tools",
"npm",
"yarn",
"pnpm"
],
"author": "Hein (Warkanum) Puth",
"license": "MIT",
"dependencies": {
"semver": "^7.6.3",
"uuid": "^11.0.3"
},
"devDependencies": {
"@changesets/cli": "^2.27.10",
"@eslint/js": "^9.16.0",
"@types/jsdom": "^21.1.7",
"eslint": "^9.16.0",
"globals": "^15.13.0",
"jsdom": "^25.0.1",
"typescript": "^5.7.2",
"typescript-eslint": "^8.17.0",
"vite": "^6.0.2",
"vite-plugin-dts": "^4.3.0",
"vitest": "^2.1.8"
},
"engines": {
"node": ">=14.16"
},
"repository": {
"type": "git",
"url": "git+https://github.com/bitechdev/ResolveSpec"
},
"bugs": {
"url": "https://github.com/bitechdev/ResolveSpec/issues"
},
"homepage": "https://github.com/bitechdev/ResolveSpec#readme",
"packageManager": "pnpm@9.6.0+sha512.38dc6fba8dba35b39340b9700112c2fe1e12f10b17134715a4aa98ccf7bb035e76fd981cf0bb384dfa98f8d6af5481c2bef2f4266a24bfa20c34eb7147ce0b5e"
}

View File

@@ -10,10 +10,10 @@ import (
"os"
"testing"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry"
"github.com/Warky-Devs/ResolveSpec/pkg/resolvespec"
"github.com/Warky-Devs/ResolveSpec/pkg/testmodels"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/bitechdev/ResolveSpec/pkg/testmodels"
"github.com/glebarez/sqlite"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"

157
todo.md Normal file
View File

@@ -0,0 +1,157 @@
# ResolveSpec - TODO List
This document tracks incomplete features and improvements for the ResolveSpec project.
## Core Features to Implement
### 1. Column Selection and Filtering for Preloads
**Location:** `pkg/resolvespec/handler.go:730`
**Status:** Not Implemented
**Description:** Currently, preloads are applied without any column selection or filtering. This feature would allow clients to:
- Select specific columns for preloaded relationships
- Apply filters to preloaded data
- Reduce payload size and improve performance
**Current Limitation:**
```go
// For now, we'll preload without conditions
// TODO: Implement column selection and filtering for preloads
// This requires a more sophisticated approach with callbacks or query builders
query = query.Preload(relationFieldName)
```
**Required Implementation:**
- Add support for column selection in preloaded relationships
- Implement filtering conditions for preloaded data
- Design a callback or query builder approach that works across different ORMs (a minimal GORM-style sketch follows below)
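A minimal sketch of the callback approach, assuming GORM as the underlying ORM; the `applyPreload` helper and its `columns`/`filters` inputs are hypothetical and would be parsed from the request elsewhere:
```go
package example

import (
	"fmt"

	"gorm.io/gorm"
)

// applyPreload preloads a relation with optional column selection and simple
// equality filters, using GORM's conditional Preload callback.
func applyPreload(q *gorm.DB, relation string, columns []string, filters map[string]interface{}) *gorm.DB {
	return q.Preload(relation, func(db *gorm.DB) *gorm.DB {
		if len(columns) > 0 {
			db = db.Select(columns)
		}
		for col, val := range filters {
			// col must be validated against the model's columns before use
			// to avoid SQL injection via column names.
			db = db.Where(fmt.Sprintf("%s = ?", col), val)
		}
		return db
	})
}
```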
---
### 2. Recursive JSON Cleaning
**Location:** `pkg/restheadspec/handler.go:796`
**Status:** Partially Implemented (Simplified)
**Description:** The current `cleanJSON` function returns data as-is without recursively removing null and empty fields from nested structures.
**Current Limitation:**
```go
// This is a simplified implementation
// A full implementation would recursively clean nested structures
// For now, we'll return the data as-is
// TODO: Implement recursive cleaning
return data
```
**Required Implementation:**
- Recursively traverse nested structures (maps, slices, structs)
- Remove null values
- Remove empty objects and arrays
- Handle edge cases (circular references, pointers, etc.); a minimal sketch of the traversal follows below
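A minimal sketch of the recursive cleaning, assuming the payload has already been decoded into generic `map[string]interface{}` / `[]interface{}` values; struct, pointer, and cycle handling are intentionally left out:
```go
// cleanJSONRecursive walks maps and slices, dropping nil values and empty
// containers so they never reach the response payload.
func cleanJSONRecursive(data interface{}) interface{} {
	switch v := data.(type) {
	case map[string]interface{}:
		out := make(map[string]interface{}, len(v))
		for k, val := range v {
			cleaned := cleanJSONRecursive(val)
			if cleaned == nil {
				continue
			}
			if m, ok := cleaned.(map[string]interface{}); ok && len(m) == 0 {
				continue
			}
			if s, ok := cleaned.([]interface{}); ok && len(s) == 0 {
				continue
			}
			out[k] = cleaned
		}
		return out
	case []interface{}:
		out := make([]interface{}, 0, len(v))
		for _, val := range v {
			if cleaned := cleanJSONRecursive(val); cleaned != nil {
				out = append(out, cleaned)
			}
		}
		return out
	default:
		return v
	}
}
```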
---
### 3. Custom SQL Join Support
**Location:** `pkg/restheadspec/headers.go:159`
**Status:** Not Implemented
**Description:** Support for custom SQL joins via the `X-Custom-SQL-Join` header is currently logged but not executed.
**Current Limitation:**
```go
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
// TODO: Implement custom SQL join
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
```
**Required Implementation:**
- Parse custom SQL join expressions from headers
- Apply joins to the query builder
- Ensure security (SQL injection prevention)
- Support for different join types (INNER, LEFT, RIGHT, FULL)
- Work across different database adapters (GORM, Bun); a parsing/whitelisting sketch follows below
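A minimal sketch of the parsing and whitelisting step only; how the resulting clause is handed to each query builder is left open, since neither adapter exposes such an API yet:
```go
package example

import (
	"fmt"
	"strings"
)

// parseCustomJoin validates a header value such as
// "LEFT JOIN orders o ON o.user_id = users.id" against a join-type whitelist
// and returns the normalized join type plus the remainder of the expression.
func parseCustomJoin(decodedValue string) (joinType, rest string, err error) {
	allowed := map[string]bool{"INNER": true, "LEFT": true, "RIGHT": true, "FULL": true}
	parts := strings.SplitN(strings.TrimSpace(decodedValue), " ", 2)
	if len(parts) < 2 {
		return "", "", fmt.Errorf("malformed join expression: %q", decodedValue)
	}
	joinType = strings.ToUpper(parts[0])
	if !allowed[joinType] {
		return "", "", fmt.Errorf("unsupported join type: %s", joinType)
	}
	// The table and ON clause in parts[1] still need identifier validation
	// (or parameterization) before being handed to any query builder.
	return joinType, parts[1], nil
}
```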
---
### 4. Proper Condition Handling for Bun Preloads
**Location:** `pkg/common/adapters/database/bun.go:202`
**Status:** Partially Implemented
**Description:** The Bun adapter's `Preload` method currently ignores conditions passed to it.
**Current Limitation:**
```go
func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
// Bun uses Relation() method for preloading
// For now, we'll just pass the relation name without conditions
// TODO: Implement proper condition handling for Bun
b.query = b.query.Relation(relation)
return b
}
```
**Required Implementation:**
- Properly handle condition parameters in Bun's Relation() method
- Support filtering on preloaded relationships
- Ensure compatibility with GORM's condition syntax where possible
- Test with various condition types (a sketch of the apply-function approach follows below)
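A minimal sketch, assuming conditions arrive as alternating `"expr = ?"` strings and values; that pairing convention is an assumption, not the project's current contract. Bun's `Relation` accepts an apply function that can carry the conditions onto the related query:
```go
func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
	b.query = b.query.Relation(relation, func(q *bun.SelectQuery) *bun.SelectQuery {
		// Forward "expr = ?" / value pairs to the relation's query.
		for i := 0; i+1 < len(conditions); i += 2 {
			if expr, ok := conditions[i].(string); ok {
				q = q.Where(expr, conditions[i+1])
			}
		}
		return q
	})
	return b
}
```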
---
## Code Quality Improvements
### 5. Modernize Go Type Declarations
**Location:** `pkg/common/types.go:5, 42, 64, 79`
**Status:** Pending
**Priority:** Low
**Description:** Replace legacy `interface{}` with modern `any` type alias (Go 1.18+).
**Affected Lines:**
- Line 5: Function parameter or return type
- Line 42: Function parameter or return type
- Line 64: Function parameter or return type
- Line 79: Function parameter or return type
**Benefits:**
- More modern and idiomatic Go code
- Better readability
- Aligns with current Go best practices
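For illustration, since Go 1.18 `any` is an alias for `interface{}`, so the change is purely mechanical (the `Describe` function below is hypothetical, not taken from `pkg/common/types.go`):
```go
// Before
func Describe(v interface{}) string { return fmt.Sprintf("%v", v) }

// After
func Describe(v any) string { return fmt.Sprintf("%v", v) }
```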
---
### 6. Pre/Post Select/Update/Delete Queries in a Transaction
- This would allow setting the current user on the database session before running a select.
- When making changes, database triggers would then fire with the correct user.
- Consider wrapping the handleRead, handleUpdate, handleCreate, and handleDelete handlers in a transaction whose context aborts when the request is cancelled or a configurable timeout is reached; a rough sketch follows below.
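A rough sketch of the wrapping idea, assuming Bun and PostgreSQL; `RunInTx` is Bun's transaction helper, while the `app.current_user` setting and the callback shape are hypothetical:
```go
package example

import (
	"context"
	"time"

	"github.com/uptrace/bun"
)

// withUserTx runs fn inside a transaction, first recording the acting user on
// the session so database triggers can see it, and aborting when the request
// context is cancelled or the configurable timeout expires.
func withUserTx(ctx context.Context, db *bun.DB, userID string, timeout time.Duration, fn func(ctx context.Context, tx bun.Tx) error) error {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
		// set_config with is_local=true scopes the setting to this transaction.
		if _, err := tx.ExecContext(ctx, "SELECT set_config('app.current_user', ?, true)", userID); err != nil {
			return err
		}
		return fn(ctx, tx)
	})
}
```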
## Additional Considerations
### Documentation
- Ensure all new features are documented in README.md
- Update examples to showcase new functionality
- Add migration notes if any breaking changes are introduced
### Testing
- Add unit tests for each new feature
- Add integration tests for database adapter compatibility
- Ensure backward compatibility is maintained
### Performance
- Profile preload performance with column selection and filtering
- Optimize recursive JSON cleaning for large payloads
- Benchmark custom SQL join performance
---
## Priority Ranking
1. **High Priority**
- Column Selection and Filtering for Preloads (#1)
- Proper Condition Handling for Bun Preloads (#4)
2. **Medium Priority**
- Custom SQL Join Support (#3)
- Recursive JSON Cleaning (#2)
3. **Low Priority**
- Modernize Go Type Declarations (#5)
---
**Last Updated:** 2025-11-07