mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 08:14:25 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
850d7b546c | ||
|
|
a44ef90d7c | ||
|
|
8b7db5b31a | ||
|
|
14daea3b05 | ||
|
|
35f23b6d9e | ||
|
|
53a4e67f70 | ||
|
|
1289c3af88 | ||
|
|
cdfb7a67fd | ||
|
|
7f5b851669 | ||
|
|
f0e26b1c0d | ||
|
|
1db1b924ef | ||
|
|
d9cf23b1dc | ||
|
|
94f013c872 | ||
|
|
c52fcff61d | ||
|
|
ce106fa940 | ||
|
|
37b4b75175 |
110
.golangci.bck.yml
Normal file
110
.golangci.bck.yml
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
run:
|
||||||
|
timeout: 5m
|
||||||
|
tests: true
|
||||||
|
skip-dirs:
|
||||||
|
- vendor
|
||||||
|
- .github
|
||||||
|
|
||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
- errcheck
|
||||||
|
- gosimple
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
|
- staticcheck
|
||||||
|
- unused
|
||||||
|
- gofmt
|
||||||
|
- goimports
|
||||||
|
- misspell
|
||||||
|
- gocritic
|
||||||
|
- revive
|
||||||
|
- stylecheck
|
||||||
|
disable:
|
||||||
|
- typecheck # Can cause issues with generics in some cases
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
errcheck:
|
||||||
|
check-type-assertions: false
|
||||||
|
check-blank: false
|
||||||
|
|
||||||
|
govet:
|
||||||
|
check-shadowing: false
|
||||||
|
|
||||||
|
gofmt:
|
||||||
|
simplify: true
|
||||||
|
|
||||||
|
goimports:
|
||||||
|
local-prefixes: github.com/bitechdev/ResolveSpec
|
||||||
|
|
||||||
|
gocritic:
|
||||||
|
enabled-checks:
|
||||||
|
- appendAssign
|
||||||
|
- assignOp
|
||||||
|
- boolExprSimplify
|
||||||
|
- builtinShadow
|
||||||
|
- captLocal
|
||||||
|
- caseOrder
|
||||||
|
- defaultCaseOrder
|
||||||
|
- dupArg
|
||||||
|
- dupBranchBody
|
||||||
|
- dupCase
|
||||||
|
- dupSubExpr
|
||||||
|
- elseif
|
||||||
|
- emptyFallthrough
|
||||||
|
- equalFold
|
||||||
|
- flagName
|
||||||
|
- ifElseChain
|
||||||
|
- indexAlloc
|
||||||
|
- initClause
|
||||||
|
- methodExprCall
|
||||||
|
- nilValReturn
|
||||||
|
- rangeExprCopy
|
||||||
|
- rangeValCopy
|
||||||
|
- regexpMust
|
||||||
|
- singleCaseSwitch
|
||||||
|
- sloppyLen
|
||||||
|
- stringXbytes
|
||||||
|
- switchTrue
|
||||||
|
- typeAssertChain
|
||||||
|
- typeSwitchVar
|
||||||
|
- underef
|
||||||
|
- unlabelStmt
|
||||||
|
- unnamedResult
|
||||||
|
- unnecessaryBlock
|
||||||
|
- weakCond
|
||||||
|
- yodaStyleExpr
|
||||||
|
|
||||||
|
revive:
|
||||||
|
rules:
|
||||||
|
- name: exported
|
||||||
|
disabled: true
|
||||||
|
- name: package-comments
|
||||||
|
disabled: true
|
||||||
|
|
||||||
|
issues:
|
||||||
|
exclude-use-default: false
|
||||||
|
max-issues-per-linter: 0
|
||||||
|
max-same-issues: 0
|
||||||
|
|
||||||
|
# Exclude some linters from running on tests files
|
||||||
|
exclude-rules:
|
||||||
|
- path: _test\.go
|
||||||
|
linters:
|
||||||
|
- errcheck
|
||||||
|
- dupl
|
||||||
|
- gosec
|
||||||
|
- gocritic
|
||||||
|
|
||||||
|
# Ignore "error return value not checked" for defer statements
|
||||||
|
- linters:
|
||||||
|
- errcheck
|
||||||
|
text: "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*print(f|ln)?|os\\.(Un)?Setenv). is not checked"
|
||||||
|
|
||||||
|
# Ignore complexity in test files
|
||||||
|
- path: _test\.go
|
||||||
|
text: "cognitive complexity|cyclomatic complexity"
|
||||||
|
|
||||||
|
output:
|
||||||
|
format: colored-line-number
|
||||||
|
print-issued-lines: true
|
||||||
|
print-linter-name: true
|
||||||
58
.vscode/tasks.json
vendored
58
.vscode/tasks.json
vendored
@@ -24,21 +24,63 @@
|
|||||||
"type": "go",
|
"type": "go",
|
||||||
"label": "go: test workspace",
|
"label": "go: test workspace",
|
||||||
"command": "test",
|
"command": "test",
|
||||||
|
|
||||||
"options": {
|
"options": {
|
||||||
"env": {
|
"cwd": "${workspaceFolder}"
|
||||||
"CGO_ENABLED": "0"
|
|
||||||
},
|
|
||||||
"cwd": "${workspaceFolder}/bin",
|
|
||||||
},
|
},
|
||||||
"args": [
|
"args": [
|
||||||
"../..."
|
"-v",
|
||||||
|
"-race",
|
||||||
|
"-coverprofile=coverage.out",
|
||||||
|
"-covermode=atomic",
|
||||||
|
"./..."
|
||||||
],
|
],
|
||||||
"problemMatcher": [
|
"problemMatcher": [
|
||||||
"$go"
|
"$go"
|
||||||
],
|
],
|
||||||
"group": "build",
|
"group": {
|
||||||
|
"kind": "test",
|
||||||
|
"isDefault": true
|
||||||
|
},
|
||||||
|
"presentation": {
|
||||||
|
"reveal": "always",
|
||||||
|
"panel": "new"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"type": "shell",
|
||||||
|
"label": "go: vet workspace",
|
||||||
|
"command": "go vet ./...",
|
||||||
|
"options": {
|
||||||
|
"cwd": "${workspaceFolder}"
|
||||||
|
},
|
||||||
|
"problemMatcher": [
|
||||||
|
"$go"
|
||||||
|
],
|
||||||
|
"group": "test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "shell",
|
||||||
|
"label": "go: lint workspace",
|
||||||
|
"command": "golangci-lint run --timeout=5m",
|
||||||
|
"options": {
|
||||||
|
"cwd": "${workspaceFolder}"
|
||||||
|
},
|
||||||
|
"problemMatcher": [],
|
||||||
|
"group": "test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "shell",
|
||||||
|
"label": "go: full test suite",
|
||||||
|
"dependsOrder": "sequence",
|
||||||
|
"dependsOn": [
|
||||||
|
"go: vet workspace",
|
||||||
|
"go: test workspace"
|
||||||
|
],
|
||||||
|
"problemMatcher": [],
|
||||||
|
"group": {
|
||||||
|
"kind": "test",
|
||||||
|
"isDefault": false
|
||||||
|
}
|
||||||
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
164
README.md
164
README.md
@@ -31,9 +31,12 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
|||||||
- [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
|
- [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
|
||||||
- [Lifecycle Hooks](#lifecycle-hooks)
|
- [Lifecycle Hooks](#lifecycle-hooks)
|
||||||
- [Cursor Pagination](#cursor-pagination)
|
- [Cursor Pagination](#cursor-pagination)
|
||||||
|
- [Response Formats](#response-formats)
|
||||||
|
- [Single Record as Object](#single-record-as-object-default-behavior)
|
||||||
- [Example Usage](#example-usage)
|
- [Example Usage](#example-usage)
|
||||||
|
- [Recursive CRUD Operations](#recursive-crud-operations-)
|
||||||
- [Testing](#testing)
|
- [Testing](#testing)
|
||||||
- [What's New in v2.0](#whats-new-in-v20)
|
- [What's New](#whats-new)
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@@ -45,6 +48,7 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
|||||||
- **Pagination**: Built-in limit/offset and cursor-based pagination
|
- **Pagination**: Built-in limit/offset and cursor-based pagination
|
||||||
- **Computed Columns**: Define virtual columns for complex calculations
|
- **Computed Columns**: Define virtual columns for complex calculations
|
||||||
- **Custom Operators**: Add custom SQL conditions when needed
|
- **Custom Operators**: Add custom SQL conditions when needed
|
||||||
|
- **🆕 Recursive CRUD Handler**: Automatically handle nested object graphs with foreign key resolution and per-record operation control via `_request` field
|
||||||
|
|
||||||
### Architecture (v2.0+)
|
### Architecture (v2.0+)
|
||||||
- **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
|
- **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
|
||||||
@@ -57,6 +61,7 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
|||||||
- **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
|
- **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
|
||||||
- **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
|
- **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
|
||||||
- **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
|
- **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
|
||||||
|
- **🆕 Single Record as Object**: Automatically normalize single-element arrays to objects (enabled by default)
|
||||||
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
|
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
|
||||||
- **🆕 Base64 Encoding**: Support for base64-encoded header values
|
- **🆕 Base64 Encoding**: Support for base64-encoded header values
|
||||||
|
|
||||||
@@ -161,6 +166,7 @@ restheadspec.SetupMuxRoutes(router, handler)
|
|||||||
| `X-Limit` | Limit results | `50` |
|
| `X-Limit` | Limit results | `50` |
|
||||||
| `X-Offset` | Offset for pagination | `100` |
|
| `X-Offset` | Offset for pagination | `100` |
|
||||||
| `X-Clean-JSON` | Remove null/empty fields | `true` |
|
| `X-Clean-JSON` | Remove null/empty fields | `true` |
|
||||||
|
| `X-Single-Record-As-Object` | Return single records as objects (default: `true`) | `false` |
|
||||||
|
|
||||||
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
|
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
|
||||||
|
|
||||||
@@ -301,6 +307,55 @@ RestHeadSpec supports multiple response formats:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Single Record as Object (Default Behavior)
|
||||||
|
|
||||||
|
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses. This provides a better developer experience when fetching individual records.
|
||||||
|
|
||||||
|
**Default behavior (enabled)**:
|
||||||
|
```http
|
||||||
|
GET /public/users/123
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": { "id": 123, "name": "John", "email": "john@example.com" }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Instead of:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**To disable** (force arrays for consistency):
|
||||||
|
```http
|
||||||
|
GET /public/users/123
|
||||||
|
X-Single-Record-As-Object: false
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it works**:
|
||||||
|
- When a query returns exactly **one record**, it's returned as an object
|
||||||
|
- When a query returns **multiple records**, they're returned as an array
|
||||||
|
- Set `X-Single-Record-As-Object: false` to always receive arrays
|
||||||
|
- Works with all response formats (simple, detail, syncfusion)
|
||||||
|
- Applies to both read operations and create/update returning clauses
|
||||||
|
|
||||||
|
**Benefits**:
|
||||||
|
- Cleaner API responses for single-record queries
|
||||||
|
- No need to unwrap single-element arrays on the client side
|
||||||
|
- Better TypeScript/type inference support
|
||||||
|
- Consistent with common REST API patterns
|
||||||
|
- Backward compatible via header opt-out
|
||||||
|
|
||||||
## Example Usage
|
## Example Usage
|
||||||
|
|
||||||
### Reading Data with Related Entities
|
### Reading Data with Related Entities
|
||||||
@@ -342,6 +397,92 @@ POST /core/users
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Recursive CRUD Operations (🆕)
|
||||||
|
|
||||||
|
ResolveSpec now supports automatic handling of nested object graphs with intelligent foreign key resolution. This allows you to create, update, or delete entire object hierarchies in a single request.
|
||||||
|
|
||||||
|
#### Creating Nested Objects
|
||||||
|
|
||||||
|
```json
|
||||||
|
POST /core/users
|
||||||
|
{
|
||||||
|
"operation": "create",
|
||||||
|
"data": {
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
"posts": [
|
||||||
|
{
|
||||||
|
"title": "My First Post",
|
||||||
|
"content": "Hello World",
|
||||||
|
"tags": [
|
||||||
|
{"name": "tech"},
|
||||||
|
{"name": "programming"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Second Post",
|
||||||
|
"content": "More content"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"profile": {
|
||||||
|
"bio": "Software Developer",
|
||||||
|
"website": "https://example.com"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Per-Record Operation Control with `_request`
|
||||||
|
|
||||||
|
Control individual operations for each nested record using the special `_request` field:
|
||||||
|
|
||||||
|
```json
|
||||||
|
POST /core/users/123
|
||||||
|
{
|
||||||
|
"operation": "update",
|
||||||
|
"data": {
|
||||||
|
"name": "John Updated",
|
||||||
|
"posts": [
|
||||||
|
{
|
||||||
|
"_request": "insert",
|
||||||
|
"title": "New Post",
|
||||||
|
"content": "Fresh content"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_request": "update",
|
||||||
|
"id": 456,
|
||||||
|
"title": "Updated Post Title"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_request": "delete",
|
||||||
|
"id": 789
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported `_request` values**:
|
||||||
|
- `insert` - Create a new related record
|
||||||
|
- `update` - Update an existing related record
|
||||||
|
- `delete` - Delete a related record
|
||||||
|
- `upsert` - Create if doesn't exist, update if exists
|
||||||
|
|
||||||
|
#### How It Works
|
||||||
|
|
||||||
|
1. **Automatic Foreign Key Resolution**: Parent IDs are automatically propagated to child records
|
||||||
|
2. **Recursive Processing**: Handles nested relationships at any depth
|
||||||
|
3. **Transaction Safety**: All operations execute within database transactions
|
||||||
|
4. **Relationship Detection**: Automatically detects belongsTo, hasMany, hasOne, and many2many relationships
|
||||||
|
5. **Flexible Operations**: Mix create, update, and delete operations in a single request
|
||||||
|
|
||||||
|
#### Benefits
|
||||||
|
|
||||||
|
- Reduce API round trips for complex object graphs
|
||||||
|
- Maintain referential integrity automatically
|
||||||
|
- Simplify client-side code
|
||||||
|
- Atomic operations with automatic rollback on errors
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -811,12 +952,32 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
|
|
||||||
### v2.1 (Latest)
|
### v2.1 (Latest)
|
||||||
|
|
||||||
|
**Recursive CRUD Handler (🆕 Nov 11, 2025)**:
|
||||||
|
- **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
|
||||||
|
- **Foreign Key Resolution**: Automatic propagation of parent IDs to child records
|
||||||
|
- **Per-Record Operations**: Control create/update/delete operations per record via `_request` field
|
||||||
|
- **Transaction Safety**: All nested operations execute atomically within database transactions
|
||||||
|
- **Relationship Detection**: Automatic detection of belongsTo, hasMany, hasOne, and many2many relationships
|
||||||
|
- **Deep Nesting Support**: Handle relationships at any depth level
|
||||||
|
- **Mixed Operations**: Combine insert, update, and delete operations in a single request
|
||||||
|
|
||||||
|
**Primary Key Improvements (Nov 11, 2025)**:
|
||||||
|
- **GetPrimaryKeyName**: Enhanced primary key detection for better preload and ID field handling
|
||||||
|
- **Better GORM/Bun Support**: Improved compatibility with both ORMs for primary key operations
|
||||||
|
- **Computed Column Support**: Fixed computed columns functionality across handlers
|
||||||
|
|
||||||
|
**Database Adapter Enhancements (Nov 11, 2025)**:
|
||||||
|
- **Bun ORM Relations**: Using Scan model method for better has-many and many-to-many relationship handling
|
||||||
|
- **Model Method Support**: Enhanced query building with proper model registration
|
||||||
|
- **Improved Type Safety**: Better handling of relationship queries with type-aware scanning
|
||||||
|
|
||||||
**RestHeadSpec - Header-Based REST API**:
|
**RestHeadSpec - Header-Based REST API**:
|
||||||
- **Header-Based Querying**: All query options via HTTP headers instead of request body
|
- **Header-Based Querying**: All query options via HTTP headers instead of request body
|
||||||
- **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
|
- **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
|
||||||
- **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
|
- **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
|
||||||
- **Advanced Filtering**: Field filters, search operators, AND/OR logic
|
- **Advanced Filtering**: Field filters, search operators, AND/OR logic
|
||||||
- **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
|
- **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
|
||||||
|
- **Single Record as Object**: Automatically return single-element arrays as objects (default, toggleable via header)
|
||||||
- **Base64 Support**: Base64-encoded header values for complex queries
|
- **Base64 Support**: Base64-encoded header values for complex queries
|
||||||
- **Type-Aware Filtering**: Automatic type detection and conversion for filters
|
- **Type-Aware Filtering**: Automatic type detection and conversion for filters
|
||||||
|
|
||||||
@@ -826,6 +987,7 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
- Improved reflection safety
|
- Improved reflection safety
|
||||||
- Fixed COUNT query issues with table aliasing
|
- Fixed COUNT query issues with table aliasing
|
||||||
- Better pointer handling throughout the codebase
|
- Better pointer handling throughout the codebase
|
||||||
|
- **Comprehensive Test Coverage**: Added standalone CRUD tests for both ResolveSpec and RestHeadSpec
|
||||||
|
|
||||||
### v2.0
|
### v2.0
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BunAdapter adapts Bun to work with our Database interface
|
// BunAdapter adapts Bun to work with our Database interface
|
||||||
@@ -215,6 +217,40 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
|
if len(apply) == 0 {
|
||||||
|
return sq
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the incoming *bun.SelectQuery in our adapter
|
||||||
|
wrapper := &BunSelectQuery{
|
||||||
|
query: sq,
|
||||||
|
db: b.db,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start with the interface value (not pointer)
|
||||||
|
current := common.SelectQuery(wrapper)
|
||||||
|
|
||||||
|
// Apply each function in sequence
|
||||||
|
for _, fn := range apply {
|
||||||
|
if fn != nil {
|
||||||
|
// Pass ¤t (pointer to interface variable), fn modifies and returns new interface value
|
||||||
|
modified := fn(current)
|
||||||
|
current = modified
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the final *bun.SelectQuery
|
||||||
|
if finalBun, ok := current.(*BunSelectQuery); ok {
|
||||||
|
return finalBun.query
|
||||||
|
}
|
||||||
|
|
||||||
|
return sq // fallback
|
||||||
|
})
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||||
b.query = b.query.Order(order)
|
b.query = b.query.Order(order)
|
||||||
return b
|
return b
|
||||||
@@ -319,25 +355,45 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
|||||||
// BunUpdateQuery implements UpdateQuery for Bun
|
// BunUpdateQuery implements UpdateQuery for Bun
|
||||||
type BunUpdateQuery struct {
|
type BunUpdateQuery struct {
|
||||||
query *bun.UpdateQuery
|
query *bun.UpdateQuery
|
||||||
|
model interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
|
b.model = model
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
|
if b.model == nil {
|
||||||
|
// Try to get table name from table string if model is not set
|
||||||
|
|
||||||
|
model, err := modelregistry.GetModelByName(table)
|
||||||
|
if err == nil {
|
||||||
|
b.model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
||||||
|
// Validate column is writable if model is set
|
||||||
|
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||||
|
// Skip scan-only columns
|
||||||
|
return b
|
||||||
|
}
|
||||||
b.query = b.query.Set(column+" = ?", value)
|
b.query = b.query.Set(column+" = ?", value)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||||
for column, value := range values {
|
for column, value := range values {
|
||||||
|
// Validate column is writable if model is set
|
||||||
|
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||||
|
// Skip scan-only columns
|
||||||
|
continue
|
||||||
|
}
|
||||||
b.query = b.query.Set(column+" = ?", value)
|
b.query = b.query.Set(column+" = ?", value)
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GormAdapter adapts GORM to work with our Database interface
|
// GormAdapter adapts GORM to work with our Database interface
|
||||||
@@ -97,6 +99,7 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
|||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
// Check if the table name contains schema (e.g., "schema.table")
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
g.schema, g.tableName = parseTableName(table)
|
g.schema, g.tableName = parseTableName(table)
|
||||||
|
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -197,6 +200,36 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
|
||||||
|
if len(apply) == 0 {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapper := &GormSelectQuery{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
|
||||||
|
current := common.SelectQuery(wrapper)
|
||||||
|
|
||||||
|
for _, fn := range apply {
|
||||||
|
if fn != nil {
|
||||||
|
|
||||||
|
modified := fn(current)
|
||||||
|
current = modified
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalBun, ok := current.(*GormSelectQuery); ok {
|
||||||
|
return finalBun.db
|
||||||
|
}
|
||||||
|
|
||||||
|
return db // fallback
|
||||||
|
})
|
||||||
|
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||||
g.db = g.db.Order(order)
|
g.db = g.db.Order(order)
|
||||||
return g
|
return g
|
||||||
@@ -309,10 +342,23 @@ func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
|||||||
|
|
||||||
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
|
if g.model == nil {
|
||||||
|
// Try to get table name from table string if model is not set
|
||||||
|
model, err := modelregistry.GetModelByName(table)
|
||||||
|
if err == nil {
|
||||||
|
g.model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
||||||
|
// Validate column is writable if model is set
|
||||||
|
if g.model != nil && !reflection.IsColumnWritable(g.model, column) {
|
||||||
|
// Skip read-only columns
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
if g.updates == nil {
|
if g.updates == nil {
|
||||||
g.updates = make(map[string]interface{})
|
g.updates = make(map[string]interface{})
|
||||||
}
|
}
|
||||||
@@ -323,7 +369,18 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||||
g.updates = values
|
// Filter out read-only columns if model is set
|
||||||
|
if g.model != nil {
|
||||||
|
filteredValues := make(map[string]interface{})
|
||||||
|
for column, value := range values {
|
||||||
|
if reflection.IsColumnWritable(g.model, column) {
|
||||||
|
filteredValues[column] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.updates = filteredValues
|
||||||
|
} else {
|
||||||
|
g.updates = values
|
||||||
|
}
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
161
pkg/common/adapters/database/update_validation_test.go
Normal file
161
pkg/common/adapters/database/update_validation_test.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test models for bun
|
||||||
|
type BunTestModel struct {
|
||||||
|
ID int `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
Email string `bun:"email"`
|
||||||
|
ComputedCol string `bun:"computed_col,scanonly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test models for gorm
|
||||||
|
type GormTestModel struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey"`
|
||||||
|
Name string `gorm:"column:name"`
|
||||||
|
Email string `gorm:"column:email"`
|
||||||
|
ReadOnlyCol string `gorm:"column:readonly_col;->"`
|
||||||
|
NoWriteCol string `gorm:"column:nowrite_col;<-:false"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsColumnWritable_Bun(t *testing.T) {
|
||||||
|
model := &BunTestModel{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columnName string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "writable column - id",
|
||||||
|
columnName: "id",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "writable column - name",
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "writable column - email",
|
||||||
|
columnName: "email",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "scanonly column should not be writable",
|
||||||
|
columnName: "computed_col",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existent column should be writable (dynamic)",
|
||||||
|
columnName: "nonexistent",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := reflection.IsColumnWritable(model, tt.columnName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsColumnWritable_Gorm(t *testing.T) {
|
||||||
|
model := &GormTestModel{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columnName string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "writable column - id",
|
||||||
|
columnName: "id",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "writable column - name",
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "writable column - email",
|
||||||
|
columnName: "email",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "read-only column with -> should not be writable",
|
||||||
|
columnName: "readonly_col",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with <-:false should not be writable",
|
||||||
|
columnName: "nowrite_col",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existent column should be writable (dynamic)",
|
||||||
|
columnName: "nonexistent",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := reflection.IsColumnWritable(model, tt.columnName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) {
|
||||||
|
// Note: This is a unit test for the validation logic only.
|
||||||
|
// We can't fully test the bun query without a database connection,
|
||||||
|
// but we've verified the validation logic in TestIsColumnWritable_Bun
|
||||||
|
t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) {
|
||||||
|
model := &GormTestModel{}
|
||||||
|
query := &GormUpdateQuery{
|
||||||
|
model: model,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMap should filter out read-only columns
|
||||||
|
values := map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"email": "john@example.com",
|
||||||
|
"readonly_col": "should_be_filtered",
|
||||||
|
"nowrite_col": "should_also_be_filtered",
|
||||||
|
}
|
||||||
|
|
||||||
|
query.SetMap(values)
|
||||||
|
|
||||||
|
// Check that the updates map only contains writable columns
|
||||||
|
if updates, ok := query.updates.(map[string]interface{}); ok {
|
||||||
|
if _, exists := updates["readonly_col"]; exists {
|
||||||
|
t.Error("readonly_col should have been filtered out")
|
||||||
|
}
|
||||||
|
if _, exists := updates["nowrite_col"]; exists {
|
||||||
|
t.Error("nowrite_col should have been filtered out")
|
||||||
|
}
|
||||||
|
if _, exists := updates["name"]; !exists {
|
||||||
|
t.Error("name should be in updates")
|
||||||
|
}
|
||||||
|
if _, exists := updates["email"]; !exists {
|
||||||
|
t.Error("email should be in updates")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("updates should be a map[string]interface{}")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,6 +32,7 @@ type SelectQuery interface {
|
|||||||
Join(query string, args ...interface{}) SelectQuery
|
Join(query string, args ...interface{}) SelectQuery
|
||||||
LeftJoin(query string, args ...interface{}) SelectQuery
|
LeftJoin(query string, args ...interface{}) SelectQuery
|
||||||
Preload(relation string, conditions ...interface{}) SelectQuery
|
Preload(relation string, conditions ...interface{}) SelectQuery
|
||||||
|
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||||
Order(order string) SelectQuery
|
Order(order string) SelectQuery
|
||||||
Limit(n int) SelectQuery
|
Limit(n int) SelectQuery
|
||||||
Offset(n int) SelectQuery
|
Offset(n int) SelectQuery
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CRUDRequestProvider interface for models that provide CRUD request strings
|
// CRUDRequestProvider interface for models that provide CRUD request strings
|
||||||
@@ -248,7 +249,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
|||||||
|
|
||||||
logger.Debug("Updating %s with ID %v, data: %+v", tableName, id, data)
|
logger.Debug("Updating %s with ID %v, data: %+v", tableName, id, data)
|
||||||
|
|
||||||
query := p.db.NewUpdate().Table(tableName).SetMap(data).Where("id = ?", id)
|
query := p.db.NewUpdate().Table(tableName).SetMap(data).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -268,7 +269,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string
|
|||||||
|
|
||||||
logger.Debug("Deleting from %s with ID %v", tableName, id)
|
logger.Debug("Deleting from %s with ID %v", tableName, id)
|
||||||
|
|
||||||
query := p.db.NewDelete().Table(tableName).Where("id = ?", id)
|
query := p.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -377,8 +378,16 @@ func (p *NestedCUDProcessor) getTableNameForModel(model interface{}, defaultName
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ShouldUseNestedProcessor determines if we should use nested CUD processing
|
// ShouldUseNestedProcessor determines if we should use nested CUD processing
|
||||||
// It checks if the data contains nested relations or a _request field
|
// It recursively checks if the data contains:
|
||||||
|
// 1. A _request field at any level, OR
|
||||||
|
// 2. Nested relations that themselves contain further nested relations or _request fields
|
||||||
|
// This ensures nested processing is only used when there are deeply nested operations
|
||||||
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
|
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
|
||||||
|
return shouldUseNestedProcessorDepth(data, model, relationshipHelper, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldUseNestedProcessorDepth is the internal recursive implementation with depth tracking
|
||||||
|
func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider, depth int) bool {
|
||||||
// Check for _request field
|
// Check for _request field
|
||||||
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
|
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
|
||||||
return true
|
return true
|
||||||
@@ -405,10 +414,34 @@ func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, re
|
|||||||
relInfo := relationshipHelper.GetRelationshipInfo(modelType, key)
|
relInfo := relationshipHelper.GetRelationshipInfo(modelType, key)
|
||||||
if relInfo != nil {
|
if relInfo != nil {
|
||||||
// Check if the value is actually nested data (object or array)
|
// Check if the value is actually nested data (object or array)
|
||||||
switch value.(type) {
|
switch v := value.(type) {
|
||||||
case map[string]interface{}, []interface{}, []map[string]interface{}:
|
case map[string]interface{}, []interface{}, []map[string]interface{}:
|
||||||
logger.Debug("Found nested relation field: %s", key)
|
// If we're already at a nested level (depth > 0) and found a relation,
|
||||||
return true
|
// that means we have multi-level nesting, so return true
|
||||||
|
if depth > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// At depth 0, recurse to check if the nested data has further nesting
|
||||||
|
switch typedValue := v.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
if shouldUseNestedProcessorDepth(typedValue, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
for _, item := range typedValue {
|
||||||
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []map[string]interface{}:
|
||||||
|
for _, itemMap := range typedValue {
|
||||||
|
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,9 @@ type PreloadOption struct {
|
|||||||
Relation string `json:"relation"`
|
Relation string `json:"relation"`
|
||||||
Columns []string `json:"columns"`
|
Columns []string `json:"columns"`
|
||||||
OmitColumns []string `json:"omit_columns"`
|
OmitColumns []string `json:"omit_columns"`
|
||||||
|
Sort []SortOption `json:"sort"`
|
||||||
Filters []FilterOption `json:"filters"`
|
Filters []FilterOption `json:"filters"`
|
||||||
|
Where string `json:"where"`
|
||||||
Limit *int `json:"limit"`
|
Limit *int `json:"limit"`
|
||||||
Offset *int `json:"offset"`
|
Offset *int `json:"offset"`
|
||||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||||
|
|||||||
@@ -183,7 +183,8 @@ func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate Preload columns (if specified)
|
// Validate Preload columns (if specified)
|
||||||
for _, preload := range options.Preload {
|
for idx := range options.Preload {
|
||||||
|
preload := options.Preload[idx]
|
||||||
// Note: We don't validate the relation name itself, as it's a relationship
|
// Note: We don't validate the relation name itself, as it's a relationship
|
||||||
// Only validate columns if specified for the preload
|
// Only validate columns if specified for the preload
|
||||||
if err := v.ValidateColumns(preload.Columns); err != nil {
|
if err := v.ValidateColumns(preload.Columns); err != nil {
|
||||||
@@ -239,7 +240,8 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
|
|
||||||
// Filter Preload columns
|
// Filter Preload columns
|
||||||
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||||
for _, preload := range options.Preload {
|
for idx := range options.Preload {
|
||||||
|
preload := options.Preload[idx]
|
||||||
filteredPreload := preload
|
filteredPreload := preload
|
||||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||||
@@ -270,3 +272,11 @@ func (v *ColumnValidator) GetValidColumns() []string {
|
|||||||
}
|
}
|
||||||
return columns
|
return columns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func QuoteIdent(qualifier string) string {
|
||||||
|
return `"` + strings.ReplaceAll(qualifier, `"`, `""`) + `"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func QuoteLiteral(value string) string {
|
||||||
|
return `'` + strings.ReplaceAll(value, `'`, `''`) + `'`
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{
|
|||||||
models: make(map[string]interface{}),
|
models: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Global list of registries (searched in order)
|
||||||
|
var registries = []*DefaultModelRegistry{defaultRegistry}
|
||||||
|
var registriesMutex sync.RWMutex
|
||||||
|
|
||||||
// NewModelRegistry creates a new model registry
|
// NewModelRegistry creates a new model registry
|
||||||
func NewModelRegistry() *DefaultModelRegistry {
|
func NewModelRegistry() *DefaultModelRegistry {
|
||||||
return &DefaultModelRegistry{
|
return &DefaultModelRegistry{
|
||||||
@@ -24,6 +28,14 @@ func NewModelRegistry() *DefaultModelRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddRegistry adds a registry to the global list of registries
|
||||||
|
// Registries are searched in the order they were added
|
||||||
|
func AddRegistry(registry *DefaultModelRegistry) {
|
||||||
|
registriesMutex.Lock()
|
||||||
|
defer registriesMutex.Unlock()
|
||||||
|
registries = append(registries, registry)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
|
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
@@ -107,9 +119,19 @@ func RegisterModel(model interface{}, name string) error {
|
|||||||
return defaultRegistry.RegisterModel(name, model)
|
return defaultRegistry.RegisterModel(name, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelByName retrieves a model from the default global registry by name
|
// GetModelByName retrieves a model by searching through all registries in order
|
||||||
|
// Returns the first match found
|
||||||
func GetModelByName(name string) (interface{}, error) {
|
func GetModelByName(name string) (interface{}, error) {
|
||||||
return defaultRegistry.GetModel(name)
|
registriesMutex.RLock()
|
||||||
|
defer registriesMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, registry := range registries {
|
||||||
|
if model, err := registry.GetModel(name); err == nil {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("model %s not found in any registry", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IterateModels iterates over all models in the default global registry
|
// IterateModels iterates over all models in the default global registry
|
||||||
@@ -122,14 +144,26 @@ func IterateModels(fn func(name string, model interface{})) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModels returns a list of all models in the default global registry
|
// GetModels returns a list of all models from all registries
|
||||||
|
// Models are collected in registry order, with duplicates included
|
||||||
func GetModels() []interface{} {
|
func GetModels() []interface{} {
|
||||||
defaultRegistry.mutex.RLock()
|
registriesMutex.RLock()
|
||||||
defer defaultRegistry.mutex.RUnlock()
|
defer registriesMutex.RUnlock()
|
||||||
|
|
||||||
models := make([]interface{}, 0, len(defaultRegistry.models))
|
var models []interface{}
|
||||||
for _, model := range defaultRegistry.models {
|
seen := make(map[string]bool)
|
||||||
models = append(models, model)
|
|
||||||
|
for _, registry := range registries {
|
||||||
|
registry.mutex.RLock()
|
||||||
|
for name, model := range registry.models {
|
||||||
|
// Only add the first occurrence of each model name
|
||||||
|
if !seen[name] {
|
||||||
|
models = append(models, model)
|
||||||
|
seen[name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
registry.mutex.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,15 @@
|
|||||||
package common
|
package reflection
|
||||||
|
|
||||||
import "reflect"
|
import "reflect"
|
||||||
|
|
||||||
func Len(v any) int {
|
func Len(v any) int {
|
||||||
val := reflect.ValueOf(v)
|
val := reflect.ValueOf(v)
|
||||||
|
valKind := val.Kind()
|
||||||
|
|
||||||
|
if valKind == reflect.Ptr {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
switch val.Kind() {
|
switch val.Kind() {
|
||||||
case reflect.Slice, reflect.Array, reflect.Map, reflect.String, reflect.Chan:
|
case reflect.Slice, reflect.Array, reflect.Map, reflect.String, reflect.Chan:
|
||||||
return val.Len()
|
return val.Len()
|
||||||
@@ -4,15 +4,31 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PrimaryKeyNameProvider interface {
|
||||||
|
GetIDName() string
|
||||||
|
}
|
||||||
|
|
||||||
// GetPrimaryKeyName extracts the primary key column name from a model
|
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||||
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||||
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||||
func GetPrimaryKeyName(model any) string {
|
func GetPrimaryKeyName(model any) string {
|
||||||
|
if reflect.TypeOf(model) == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// If we are given a string model name, look up the model
|
||||||
|
if reflect.TypeOf(model).Kind() == reflect.String {
|
||||||
|
name := model.(string)
|
||||||
|
m, err := modelregistry.GetModelByName(name)
|
||||||
|
if err == nil {
|
||||||
|
model = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if model implements PrimaryKeyNameProvider
|
// Check if model implements PrimaryKeyNameProvider
|
||||||
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
if provider, ok := model.(PrimaryKeyNameProvider); ok {
|
||||||
return provider.GetIDName()
|
return provider.GetIDName()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,7 +38,67 @@ func GetPrimaryKeyName(model any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to GORM tag
|
// Fall back to GORM tag
|
||||||
return getPrimaryKeyFromReflection(model, "gorm")
|
if pkName := getPrimaryKeyFromReflection(model, "gorm"); pkName != "" {
|
||||||
|
return pkName
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||||
|
// Returns the value of the primary key field
|
||||||
|
func GetPrimaryKeyValue(model any) interface{} {
|
||||||
|
if model == nil || reflect.TypeOf(model) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
val := reflect.ValueOf(model)
|
||||||
|
if val.Kind() == reflect.Pointer {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if val.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
// Try Bun tag first
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.Contains(bunTag, "pk") {
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
if fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to GORM tag
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "primaryKey") {
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
if fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last resort: look for field named "ID" or "Id"
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
if strings.ToLower(field.Name) == "id" {
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
if fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelColumns extracts all column names from a model using reflection
|
// GetModelColumns extracts all column names from a model using reflection
|
||||||
@@ -160,3 +236,90 @@ func ExtractColumnFromBunTag(tag string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsColumnWritable checks if a column can be written to in the database
|
||||||
|
// For bun: returns false if the field has "scanonly" tag
|
||||||
|
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||||
|
func IsColumnWritable(model any, columnName string) bool {
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers to get to the base struct type
|
||||||
|
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Check if this field matches the column name
|
||||||
|
fieldColumnName := getColumnNameFromField(field)
|
||||||
|
if fieldColumnName != columnName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check bun tag for scanonly
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" {
|
||||||
|
if isBunFieldScanOnly(bunTag) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check gorm tag for write restrictions
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if gormTag != "" {
|
||||||
|
if isGormFieldReadOnly(gormTag) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column is writable
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column not found in model, allow it (might be a dynamic column)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
||||||
|
// Example: "column_name,scanonly" -> true
|
||||||
|
func isBunFieldScanOnly(tag string) bool {
|
||||||
|
parts := strings.Split(tag, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
if strings.TrimSpace(part) == "scanonly" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only
|
||||||
|
// Examples:
|
||||||
|
// - "<-:false" -> true (no writes allowed)
|
||||||
|
// - "->" -> true (read-only, common pattern)
|
||||||
|
// - "column:name;->" -> true
|
||||||
|
// - "<-:create" -> false (writes allowed on create)
|
||||||
|
func isGormFieldReadOnly(tag string) bool {
|
||||||
|
parts := strings.Split(tag, ";")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
|
||||||
|
// Check for read-only marker
|
||||||
|
if part == "->" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for write restrictions
|
||||||
|
if value, found := strings.CutPrefix(part, "<-:"); found {
|
||||||
|
if value == "false" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler handles API requests using database and model abstractions
|
// Handler handles API requests using database and model abstractions
|
||||||
@@ -199,7 +200,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if len(options.ComputedColumns) > 0 {
|
if len(options.ComputedColumns) > 0 {
|
||||||
for _, cu := range options.ComputedColumns {
|
for _, cu := range options.ComputedColumns {
|
||||||
logger.Debug("Applying computed column: %s", cu.Name)
|
logger.Debug("Applying computed column: %s", cu.Name)
|
||||||
query = query.ColumnExpr("(?) AS "+cu.Name, cu.Expression)
|
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -249,7 +250,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
logger.Debug("Querying single record with ID: %s", id)
|
logger.Debug("Querying single record with ID: %s", id)
|
||||||
// For single record, create a new pointer to the struct type
|
// For single record, create a new pointer to the struct type
|
||||||
singleResult := reflect.New(modelType).Interface()
|
singleResult := reflect.New(modelType).Interface()
|
||||||
query = query.Where("id = ?", id)
|
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id)
|
||||||
if err := query.Scan(ctx, singleResult); err != nil {
|
if err := query.Scan(ctx, singleResult); err != nil {
|
||||||
logger.Error("Error querying record: %v", err)
|
logger.Error("Error querying record: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||||
@@ -521,15 +523,15 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
// Apply conditions
|
// Apply conditions
|
||||||
if urlID != "" {
|
if urlID != "" {
|
||||||
logger.Debug("Updating by URL ID: %s", urlID)
|
logger.Debug("Updating by URL ID: %s", urlID)
|
||||||
query = query.Where("id = ?", urlID)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), urlID)
|
||||||
} else if reqID != nil {
|
} else if reqID != nil {
|
||||||
switch id := reqID.(type) {
|
switch id := reqID.(type) {
|
||||||
case string:
|
case string:
|
||||||
logger.Debug("Updating by request ID: %s", id)
|
logger.Debug("Updating by request ID: %s", id)
|
||||||
query = query.Where("id = ?", id)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||||
case []string:
|
case []string:
|
||||||
logger.Debug("Updating by multiple IDs: %v", id)
|
logger.Debug("Updating by multiple IDs: %v", id)
|
||||||
query = query.Where("id IN (?)", id)
|
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -593,7 +595,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range updates {
|
for _, item := range updates {
|
||||||
if itemID, ok := item["id"]; ok {
|
if itemID, ok := item["id"]; ok {
|
||||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where("id = ?", itemID)
|
|
||||||
|
txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -659,7 +662,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
for _, item := range updates {
|
for _, item := range updates {
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
if itemID, ok := itemMap["id"]; ok {
|
if itemID, ok := itemMap["id"]; ok {
|
||||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where("id = ?", itemID)
|
|
||||||
|
txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -695,6 +699,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
|
model := GetModel(ctx)
|
||||||
|
|
||||||
logger.Info("Deleting records from %s.%s", schema, entity)
|
logger.Info("Deleting records from %s.%s", schema, entity)
|
||||||
|
|
||||||
@@ -706,7 +711,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
logger.Info("Batch delete with %d IDs ([]string)", len(v))
|
logger.Info("Batch delete with %d IDs ([]string)", len(v))
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, itemID := range v {
|
for _, itemID := range v {
|
||||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
|
||||||
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to delete record %s: %w", itemID, err)
|
return fmt.Errorf("failed to delete record %s: %w", itemID, err)
|
||||||
}
|
}
|
||||||
@@ -745,7 +751,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
continue // Skip items without ID
|
continue // Skip items without ID
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
||||||
@@ -770,7 +776,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if itemID, ok := item["id"]; ok && itemID != nil {
|
if itemID, ok := item["id"]; ok && itemID != nil {
|
||||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
||||||
@@ -804,7 +810,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewDelete().Table(tableName).Where("id = ?", id)
|
query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1113,7 +1119,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, preload := range preloads {
|
for idx := range preloads {
|
||||||
|
preload := preloads[idx]
|
||||||
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
||||||
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
||||||
if relInfo == nil {
|
if relInfo == nil {
|
||||||
@@ -1128,7 +1135,75 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
// For now, we'll preload without conditions
|
// For now, we'll preload without conditions
|
||||||
// TODO: Implement column selection and filtering for preloads
|
// TODO: Implement column selection and filtering for preloads
|
||||||
// This requires a more sophisticated approach with callbacks or query builders
|
// This requires a more sophisticated approach with callbacks or query builders
|
||||||
query = query.Preload(relationFieldName)
|
// Apply preloading
|
||||||
|
|
||||||
|
logger.Debug("Applying preload: %s", relationFieldName)
|
||||||
|
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
|
if len(preload.OmitColumns) > 0 {
|
||||||
|
allCols := reflection.GetModelColumns(model)
|
||||||
|
// Remove omitted columns
|
||||||
|
preload.Columns = []string{}
|
||||||
|
for _, col := range allCols {
|
||||||
|
addCols := true
|
||||||
|
for _, omitCol := range preload.OmitColumns {
|
||||||
|
if col == omitCol {
|
||||||
|
addCols = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if addCols {
|
||||||
|
preload.Columns = append(preload.Columns, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(preload.Columns) > 0 {
|
||||||
|
// Ensure foreign key is included in column selection for GORM to establish the relationship
|
||||||
|
columns := make([]string, len(preload.Columns))
|
||||||
|
copy(columns, preload.Columns)
|
||||||
|
|
||||||
|
// Add foreign key if not already present
|
||||||
|
if relInfo.foreignKey != "" {
|
||||||
|
// Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id)
|
||||||
|
foreignKeyColumn := toSnakeCase(relInfo.foreignKey)
|
||||||
|
|
||||||
|
hasForeignKey := false
|
||||||
|
for _, col := range columns {
|
||||||
|
if col == foreignKeyColumn || col == relInfo.foreignKey {
|
||||||
|
hasForeignKey = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasForeignKey {
|
||||||
|
columns = append(columns, foreignKeyColumn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sq = sq.Column(columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(preload.Filters) > 0 {
|
||||||
|
for _, filter := range preload.Filters {
|
||||||
|
sq = h.applyFilter(sq, filter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(preload.Sort) > 0 {
|
||||||
|
for _, sort := range preload.Sort {
|
||||||
|
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(preload.Where) > 0 {
|
||||||
|
sq = sq.Where(preload.Where)
|
||||||
|
}
|
||||||
|
|
||||||
|
if preload.Limit != nil && *preload.Limit > 0 {
|
||||||
|
sq = sq.Limit(*preload.Limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sq
|
||||||
|
})
|
||||||
|
|
||||||
logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName)
|
logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1186,3 +1261,28 @@ func (h *Handler) extractTagValue(tag, key string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toSnakeCase converts a PascalCase or camelCase string to snake_case
|
||||||
|
func toSnakeCase(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
runes := []rune(s)
|
||||||
|
|
||||||
|
for i := 0; i < len(runes); i++ {
|
||||||
|
r := runes[i]
|
||||||
|
|
||||||
|
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||||
|
// Check if previous character is lowercase or if next character is lowercase
|
||||||
|
prevIsLower := runes[i-1] >= 'a' && runes[i-1] <= 'z'
|
||||||
|
nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z'
|
||||||
|
|
||||||
|
// Add underscore if this is the start of a new word
|
||||||
|
// (previous was lowercase OR this is followed by lowercase)
|
||||||
|
if prevIsLower || nextIsLower {
|
||||||
|
result.WriteByte('_')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
return strings.ToLower(result.String())
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
@@ -133,9 +134,15 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err)
|
h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.handleCreate(ctx, w, data, options)
|
validId, _ := strconv.ParseInt(id, 10, 64)
|
||||||
|
if validId > 0 {
|
||||||
|
h.handleUpdate(ctx, w, id, nil, data, options)
|
||||||
|
} else {
|
||||||
|
h.handleCreate(ctx, w, data, options)
|
||||||
|
}
|
||||||
case "PUT", "PATCH":
|
case "PUT", "PATCH":
|
||||||
// Update operation
|
// Update operation
|
||||||
|
|
||||||
body, err := r.Body()
|
body, err := r.Body()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to read request body: %v", err)
|
logger.Error("Failed to read request body: %v", err)
|
||||||
@@ -257,14 +264,28 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if len(options.ComputedQL) > 0 {
|
if len(options.ComputedQL) > 0 {
|
||||||
for colName, colExpr := range options.ComputedQL {
|
for colName, colExpr := range options.ComputedQL {
|
||||||
logger.Debug("Applying computed column: %s", colName)
|
logger.Debug("Applying computed column: %s", colName)
|
||||||
query = query.ColumnExpr("(?) AS "+colName, colExpr)
|
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
||||||
|
for colIndex := range options.Columns {
|
||||||
|
if options.Columns[colIndex] == colName {
|
||||||
|
// Remove the computed column from the selected columns to avoid duplication
|
||||||
|
options.Columns = append(options.Columns[:colIndex], options.Columns[colIndex+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.ComputedColumns) > 0 {
|
if len(options.ComputedColumns) > 0 {
|
||||||
for _, cu := range options.ComputedColumns {
|
for _, cu := range options.ComputedColumns {
|
||||||
logger.Debug("Applying computed column: %s", cu.Name)
|
logger.Debug("Applying computed column: %s", cu.Name)
|
||||||
query = query.ColumnExpr("(?) AS "+cu.Name, cu.Expression)
|
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||||
|
for colIndex := range options.Columns {
|
||||||
|
if options.Columns[colIndex] == cu.Name {
|
||||||
|
// Remove the computed column from the selected columns to avoid duplication
|
||||||
|
options.Columns = append(options.Columns[:colIndex], options.Columns[colIndex+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -274,18 +295,95 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Column(options.Columns...)
|
query = query.Column(options.Columns...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply preloading
|
// Apply expand (Just expand to Preload for now)
|
||||||
for _, preload := range options.Preload {
|
|
||||||
logger.Debug("Applying preload: %s", preload.Relation)
|
|
||||||
query = query.Preload(preload.Relation)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply expand (LEFT JOIN)
|
|
||||||
for _, expand := range options.Expand {
|
for _, expand := range options.Expand {
|
||||||
logger.Debug("Applying expand: %s", expand.Relation)
|
logger.Debug("Applying expand: %s", expand.Relation)
|
||||||
|
sorts := make([]common.SortOption, 0)
|
||||||
|
for _, s := range strings.Split(expand.Sort, ",") {
|
||||||
|
if s == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dir := "ASC"
|
||||||
|
if strings.HasPrefix(s, "-") || strings.HasSuffix(strings.ToUpper(s), " DESC") {
|
||||||
|
dir = "DESC"
|
||||||
|
s = strings.TrimPrefix(s, "-")
|
||||||
|
s = strings.TrimSuffix(strings.ToLower(s), " desc")
|
||||||
|
}
|
||||||
|
sorts = append(sorts, common.SortOption{
|
||||||
|
Column: s, Direction: dir,
|
||||||
|
})
|
||||||
|
}
|
||||||
// Note: Expand would require JOIN implementation
|
// Note: Expand would require JOIN implementation
|
||||||
// For now, we'll use Preload as a fallback
|
// For now, we'll use Preload as a fallback
|
||||||
query = query.Preload(expand.Relation)
|
// query = query.Preload(expand.Relation)
|
||||||
|
if options.Preload == nil {
|
||||||
|
options.Preload = make([]common.PreloadOption, 0)
|
||||||
|
}
|
||||||
|
skip := false
|
||||||
|
for idx := range options.Preload {
|
||||||
|
if options.Preload[idx].Relation == expand.Relation {
|
||||||
|
skip = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !skip {
|
||||||
|
options.Preload = append(options.Preload, common.PreloadOption{
|
||||||
|
Relation: expand.Relation,
|
||||||
|
Columns: expand.Columns,
|
||||||
|
Sort: sorts,
|
||||||
|
Where: expand.Where,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply preloading
|
||||||
|
for idx := range options.Preload {
|
||||||
|
preload := options.Preload[idx]
|
||||||
|
logger.Debug("Applying preload: %s", preload.Relation)
|
||||||
|
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
|
if len(preload.OmitColumns) > 0 {
|
||||||
|
allCols := reflection.GetModelColumns(model)
|
||||||
|
// Remove omitted columns
|
||||||
|
preload.Columns = []string{}
|
||||||
|
for _, col := range allCols {
|
||||||
|
addCols := true
|
||||||
|
for _, omitCol := range preload.OmitColumns {
|
||||||
|
if col == omitCol {
|
||||||
|
addCols = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if addCols {
|
||||||
|
preload.Columns = append(preload.Columns, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(preload.Columns) > 0 {
|
||||||
|
sq = sq.Column(preload.Columns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(preload.Filters) > 0 {
|
||||||
|
for _, filter := range preload.Filters {
|
||||||
|
sq = h.applyFilter(sq, filter, "", false, "AND")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(preload.Sort) > 0 {
|
||||||
|
for _, sort := range preload.Sort {
|
||||||
|
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(preload.Where) > 0 {
|
||||||
|
sq = sq.Where(preload.Where)
|
||||||
|
}
|
||||||
|
|
||||||
|
if preload.Limit != nil && *preload.Limit > 0 {
|
||||||
|
sq = sq.Limit(*preload.Limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sq
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply DISTINCT if requested
|
// Apply DISTINCT if requested
|
||||||
@@ -326,8 +424,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
// If ID is provided, filter by ID
|
// If ID is provided, filter by ID
|
||||||
if id != "" {
|
if id != "" {
|
||||||
logger.Debug("Filtering by ID: %s", id)
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
query = query.Where("id = ?", id)
|
logger.Debug("Filtering by ID=%s: %s", pkName, id)
|
||||||
|
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply sorting
|
// Apply sorting
|
||||||
@@ -433,7 +533,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
metadata := &common.Metadata{
|
metadata := &common.Metadata{
|
||||||
Total: int64(total),
|
Total: int64(total),
|
||||||
Count: int64(common.Len(modelPtr)),
|
Count: int64(reflection.Len(modelPtr)),
|
||||||
Filtered: int64(total),
|
Filtered: int64(total),
|
||||||
Limit: limit,
|
Limit: limit,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
@@ -484,22 +584,6 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
|
|
||||||
logger.Info("Creating record in %s.%s", schema, entity)
|
logger.Info("Creating record in %s.%s", schema, entity)
|
||||||
|
|
||||||
// Check if data is a single map with nested relations
|
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
|
||||||
if h.shouldUseNestedProcessor(dataMap, model) {
|
|
||||||
logger.Info("Using nested CUD processor for create operation")
|
|
||||||
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", dataMap, model, make(map[string]interface{}), tableName)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error in nested create: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record with nested data", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
|
|
||||||
h.sendResponse(w, result.Data, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute BeforeCreate hooks
|
// Execute BeforeCreate hooks
|
||||||
hookCtx := &HookContext{
|
hookCtx := &HookContext{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
@@ -522,172 +606,113 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
// Use potentially modified data from hook context
|
// Use potentially modified data from hook context
|
||||||
data = hookCtx.Data
|
data = hookCtx.Data
|
||||||
|
|
||||||
// Handle batch creation
|
// Normalize data to slice for unified processing
|
||||||
dataValue := reflect.ValueOf(data)
|
dataSlice := h.normalizeToSlice(data)
|
||||||
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
|
||||||
logger.Debug("Batch creation detected, count: %d", dataValue.Len())
|
|
||||||
|
|
||||||
// Check if any item needs nested processing
|
// Process all items in a transaction
|
||||||
hasNestedData := false
|
results := make([]interface{}, 0, len(dataSlice))
|
||||||
for i := 0; i < dataValue.Len(); i++ {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
item := dataValue.Index(i).Interface()
|
// Create temporary nested processor with transaction
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
|
||||||
if h.shouldUseNestedProcessor(itemMap, model) {
|
|
||||||
hasNestedData = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasNestedData {
|
for i, item := range dataSlice {
|
||||||
logger.Info("Using nested CUD processor for batch create with nested data")
|
itemMap, ok := item.(map[string]interface{})
|
||||||
results := make([]interface{}, 0, dataValue.Len())
|
if !ok {
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
// Convert to map if needed
|
||||||
// Temporarily swap the database to use transaction
|
|
||||||
originalDB := h.nestedProcessor
|
|
||||||
h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
|
|
||||||
defer func() {
|
|
||||||
h.nestedProcessor = originalDB
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i := 0; i < dataValue.Len(); i++ {
|
|
||||||
item := dataValue.Index(i).Interface()
|
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
||||||
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", itemMap, model, make(map[string]interface{}), tableName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to process item: %w", err)
|
|
||||||
}
|
|
||||||
results = append(results, result.Data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error creating records with nested data: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records with nested data", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute AfterCreate hooks
|
|
||||||
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
|
|
||||||
hookCtx.Error = nil
|
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
|
||||||
logger.Error("AfterCreate hook failed: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Successfully created %d records with nested data", len(results))
|
|
||||||
h.sendResponse(w, results, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Standard batch insert without nested relations
|
|
||||||
// Use transaction for batch insert
|
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
||||||
for i := 0; i < dataValue.Len(); i++ {
|
|
||||||
item := dataValue.Index(i).Interface()
|
|
||||||
|
|
||||||
// Convert item to model type - create a pointer to the model
|
|
||||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
|
||||||
jsonData, err := json.Marshal(item)
|
jsonData, err := json.Marshal(item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal item: %w", err)
|
return fmt.Errorf("failed to marshal item %d: %w", i, err)
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(jsonData, modelValue); err != nil {
|
itemMap = make(map[string]interface{})
|
||||||
return fmt.Errorf("failed to unmarshal item: %w", err)
|
if err := json.Unmarshal(jsonData, &itemMap); err != nil {
|
||||||
}
|
return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
|
||||||
|
|
||||||
query := tx.NewInsert().Model(modelValue).Table(tableName)
|
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
|
||||||
batchHookCtx := &HookContext{
|
|
||||||
Context: ctx,
|
|
||||||
Handler: h,
|
|
||||||
Schema: schema,
|
|
||||||
Entity: entity,
|
|
||||||
TableName: tableName,
|
|
||||||
Model: model,
|
|
||||||
Options: options,
|
|
||||||
Data: modelValue,
|
|
||||||
Writer: w,
|
|
||||||
Query: query,
|
|
||||||
}
|
|
||||||
if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil {
|
|
||||||
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use potentially modified query from hook context
|
|
||||||
if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok {
|
|
||||||
query = modifiedQuery
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
|
||||||
return fmt.Errorf("failed to insert record: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
// Extract nested relations if present (but don't process them yet)
|
||||||
logger.Error("Error creating records: %v", err)
|
var nestedRelations map[string]interface{}
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
|
if h.shouldUseNestedProcessor(itemMap, model) {
|
||||||
return
|
logger.Debug("Extracting nested relations for item %d", i)
|
||||||
|
cleanedData, relations, err := h.extractNestedRelations(itemMap, model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to extract nested relations for item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
itemMap = cleanedData
|
||||||
|
nestedRelations = relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert item to model type - create a pointer to the model
|
||||||
|
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
|
jsonData, err := json.Marshal(itemMap)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, modelValue); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create insert query
|
||||||
|
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
itemHookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
Data: modelValue,
|
||||||
|
Writer: w,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := itemHookCtx.Query.(common.InsertQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute insert and get the ID
|
||||||
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to insert item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the inserted ID
|
||||||
|
insertedID := reflection.GetPrimaryKeyValue(modelValue)
|
||||||
|
|
||||||
|
// Now process nested relations with the parent ID
|
||||||
|
if len(nestedRelations) > 0 {
|
||||||
|
logger.Debug("Processing nested relations for item %d with parent ID: %v", i, insertedID)
|
||||||
|
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "insert", nestedRelations, model, insertedID); err != nil {
|
||||||
|
return fmt.Errorf("failed to process nested relations for item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, modelValue)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Execute AfterCreate hooks for batch creation
|
|
||||||
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
|
|
||||||
hookCtx.Error = nil
|
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
|
||||||
logger.Error("AfterCreate hook failed: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Single record creation - create a pointer to the model
|
|
||||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error creating records: %v", err)
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(jsonData, modelValue); err != nil {
|
|
||||||
logger.Error("Error unmarshaling data: %v", err)
|
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewInsert().Model(modelValue).Table(tableName)
|
// Execute AfterCreate hooks
|
||||||
|
var responseData interface{}
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
if len(results) == 1 {
|
||||||
hookCtx.Data = modelValue
|
responseData = results[0]
|
||||||
hookCtx.Query = query
|
hookCtx.Result = results[0]
|
||||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
} else {
|
||||||
logger.Error("BeforeScan hook failed: %v", err)
|
responseData = results
|
||||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use potentially modified query from hook context
|
|
||||||
if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok {
|
|
||||||
query = modifiedQuery
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
|
||||||
logger.Error("Error creating record: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute AfterCreate hooks for single record creation
|
|
||||||
hookCtx.Result = modelValue
|
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
@@ -696,7 +721,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendResponse(w, modelValue, nil)
|
logger.Info("Successfully created %d record(s)", len(results))
|
||||||
|
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
|
||||||
@@ -714,46 +740,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
logger.Info("Updating record in %s.%s", schema, entity)
|
logger.Info("Updating record in %s.%s", schema, entity)
|
||||||
|
|
||||||
// Convert data to map first for nested processor check
|
|
||||||
dataMap, ok := data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error marshaling data: %v", err)
|
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(jsonData, &dataMap); err != nil {
|
|
||||||
logger.Error("Error unmarshaling data: %v", err)
|
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we should use nested processing
|
|
||||||
if h.shouldUseNestedProcessor(dataMap, model) {
|
|
||||||
logger.Info("Using nested CUD processor for update operation")
|
|
||||||
// Ensure ID is in the data map
|
|
||||||
var targetID interface{}
|
|
||||||
if id != "" {
|
|
||||||
targetID = id
|
|
||||||
} else if idPtr != nil {
|
|
||||||
targetID = *idPtr
|
|
||||||
}
|
|
||||||
if targetID != nil {
|
|
||||||
dataMap["id"] = targetID
|
|
||||||
}
|
|
||||||
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", dataMap, model, make(map[string]interface{}), tableName)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error in nested update: %v", err)
|
|
||||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record with nested data", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
|
||||||
h.sendResponse(w, result.Data, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute BeforeUpdate hooks
|
// Execute BeforeUpdate hooks
|
||||||
hookCtx := &HookContext{
|
hookCtx := &HookContext{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
@@ -777,8 +763,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Use potentially modified data from hook context
|
// Use potentially modified data from hook context
|
||||||
data = hookCtx.Data
|
data = hookCtx.Data
|
||||||
|
|
||||||
// Convert data to map (again if modified by hooks)
|
// Convert data to map
|
||||||
dataMap, ok = data.(map[string]interface{})
|
dataMap, ok := data.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
jsonData, err := json.Marshal(data)
|
jsonData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -793,33 +779,74 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewUpdate().Table(tableName).SetMap(dataMap)
|
// Determine target ID
|
||||||
|
var targetID interface{}
|
||||||
// Apply ID filter
|
if id != "" {
|
||||||
switch {
|
targetID = id
|
||||||
case id != "":
|
} else if idPtr != nil {
|
||||||
query = query.Where("id = ?", id)
|
targetID = *idPtr
|
||||||
case idPtr != nil:
|
} else {
|
||||||
query = query.Where("id = ?", *idPtr)
|
|
||||||
default:
|
|
||||||
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil)
|
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Process nested relations if present
|
||||||
hookCtx.Query = query
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
// Create temporary nested processor with transaction
|
||||||
logger.Error("BeforeScan hook failed: %v", err)
|
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
|
||||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use potentially modified query from hook context
|
// Extract nested relations if present (but don't process them yet)
|
||||||
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
|
var nestedRelations map[string]interface{}
|
||||||
query = modifiedQuery
|
if h.shouldUseNestedProcessor(dataMap, model) {
|
||||||
}
|
logger.Debug("Extracting nested relations for update")
|
||||||
|
cleanedData, relations, err := h.extractNestedRelations(dataMap, model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to extract nested relations: %w", err)
|
||||||
|
}
|
||||||
|
dataMap = cleanedData
|
||||||
|
nestedRelations = relations
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure ID is in the data map for the update
|
||||||
|
dataMap["id"] = targetID
|
||||||
|
|
||||||
|
// Create update query
|
||||||
|
query := tx.NewUpdate().Model(model).Table(tableName).SetMap(dataMap)
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute update
|
||||||
|
result, err := query.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now process nested relations with the parent ID
|
||||||
|
if len(nestedRelations) > 0 {
|
||||||
|
logger.Debug("Processing nested relations for update with parent ID: %v", targetID)
|
||||||
|
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "update", nestedRelations, model, targetID); err != nil {
|
||||||
|
return fmt.Errorf("failed to process nested relations: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store result for hooks
|
||||||
|
hookCtx.Result = map[string]interface{}{
|
||||||
|
"updated": result.RowsAffected(),
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error updating record: %v", err)
|
logger.Error("Error updating record: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err)
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err)
|
||||||
@@ -827,19 +854,15 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Execute AfterUpdate hooks
|
// Execute AfterUpdate hooks
|
||||||
responseData := map[string]interface{}{
|
|
||||||
"updated": result.RowsAffected(),
|
|
||||||
}
|
|
||||||
hookCtx.Result = responseData
|
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
|
|
||||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
logger.Error("AfterUpdate hook failed: %v", err)
|
logger.Error("AfterUpdate hook failed: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendResponse(w, responseData, nil)
|
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||||
|
h.sendResponseWithOptions(w, hookCtx.Result, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
||||||
@@ -883,7 +906,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -950,7 +973,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
||||||
@@ -1001,7 +1024,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
||||||
@@ -1061,7 +1084,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query = query.Where("id = ?", id)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
hookCtx.Query = query
|
hookCtx.Query = query
|
||||||
@@ -1099,6 +1122,196 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
h.sendResponse(w, responseData, nil)
|
h.sendResponse(w, responseData, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
|
||||||
|
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
||||||
|
if data == nil {
|
||||||
|
return []interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
dataValue := reflect.ValueOf(data)
|
||||||
|
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
||||||
|
result := make([]interface{}, dataValue.Len())
|
||||||
|
for i := 0; i < dataValue.Len(); i++ {
|
||||||
|
result[i] = dataValue.Index(i).Interface()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single item - return as 1-item slice
|
||||||
|
return []interface{}{data}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractNestedRelations extracts nested relations from data, returning cleaned data and relations
|
||||||
|
// This does NOT process the relations, just separates them for later processing
|
||||||
|
func (h *Handler) extractNestedRelations(
|
||||||
|
data map[string]interface{},
|
||||||
|
model interface{},
|
||||||
|
) (map[string]interface{}, map[string]interface{}, error) {
|
||||||
|
// Get model type for reflection
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return data, nil, fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Separate relation fields from regular fields
|
||||||
|
cleanedData := make(map[string]interface{})
|
||||||
|
relations := make(map[string]interface{})
|
||||||
|
|
||||||
|
for key, value := range data {
|
||||||
|
// Skip _request field
|
||||||
|
if key == "_request" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this field is a relation
|
||||||
|
relInfo := h.GetRelationshipInfo(modelType, key)
|
||||||
|
if relInfo != nil {
|
||||||
|
logger.Debug("Found nested relation field: %s (type: %s)", key, relInfo.RelationType)
|
||||||
|
relations[key] = value
|
||||||
|
} else {
|
||||||
|
cleanedData[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleanedData, relations, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processChildRelationsWithParentID processes nested relations with a parent ID
|
||||||
|
func (h *Handler) processChildRelationsWithParentID(
|
||||||
|
ctx context.Context,
|
||||||
|
processor *common.NestedCUDProcessor,
|
||||||
|
operation string,
|
||||||
|
relations map[string]interface{},
|
||||||
|
parentModel interface{},
|
||||||
|
parentID interface{},
|
||||||
|
) error {
|
||||||
|
// Get model type for reflection
|
||||||
|
modelType := reflect.TypeOf(parentModel)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each relation
|
||||||
|
for relationName, relationValue := range relations {
|
||||||
|
if relationValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get relationship info
|
||||||
|
relInfo := h.GetRelationshipInfo(modelType, relationName)
|
||||||
|
if relInfo == nil {
|
||||||
|
logger.Warn("No relationship info found for %s, skipping", relationName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process this relation with parent ID
|
||||||
|
if err := h.processChildRelationsForField(ctx, processor, operation, relationName, relationValue, relInfo, modelType, parentID); err != nil {
|
||||||
|
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processChildRelationsForField processes a single nested relation field
|
||||||
|
func (h *Handler) processChildRelationsForField(
|
||||||
|
ctx context.Context,
|
||||||
|
processor *common.NestedCUDProcessor,
|
||||||
|
operation string,
|
||||||
|
relationName string,
|
||||||
|
relationValue interface{},
|
||||||
|
relInfo *common.RelationshipInfo,
|
||||||
|
parentModelType reflect.Type,
|
||||||
|
parentID interface{},
|
||||||
|
) error {
|
||||||
|
if relationValue == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the related model
|
||||||
|
field, found := parentModelType.FieldByName(relInfo.FieldName)
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("field %s not found in model", relInfo.FieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the model type for the relation
|
||||||
|
relatedModelType := field.Type
|
||||||
|
if relatedModelType.Kind() == reflect.Slice {
|
||||||
|
relatedModelType = relatedModelType.Elem()
|
||||||
|
}
|
||||||
|
if relatedModelType.Kind() == reflect.Ptr {
|
||||||
|
relatedModelType = relatedModelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an instance of the related model
|
||||||
|
relatedModel := reflect.New(relatedModelType).Elem().Interface()
|
||||||
|
|
||||||
|
// Get table name for related model
|
||||||
|
relatedTableName := h.getTableNameForRelatedModel(relatedModel, relInfo.JSONName)
|
||||||
|
|
||||||
|
// Prepare parent IDs for foreign key injection
|
||||||
|
parentIDs := make(map[string]interface{})
|
||||||
|
if relInfo.ForeignKey != "" && parentID != nil {
|
||||||
|
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
|
||||||
|
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||||
|
parentIDs[baseName] = parentID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process based on relation type and data structure
|
||||||
|
switch v := relationValue.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
// Single related object
|
||||||
|
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to process single relation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case []interface{}:
|
||||||
|
// Multiple related objects
|
||||||
|
for i, item := range v {
|
||||||
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []map[string]interface{}:
|
||||||
|
// Multiple related objects (typed slice)
|
||||||
|
for i, itemMap := range v {
|
||||||
|
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported relation data type: %T", relationValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableNameForRelatedModel gets the table name for a related model
|
||||||
|
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
|
||||||
|
if provider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
tableName := provider.TableName()
|
||||||
|
if tableName != "" {
|
||||||
|
return tableName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultName
|
||||||
|
}
|
||||||
|
|
||||||
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
||||||
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
|
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
|
||||||
// Check if column already has a table/schema prefix (contains a dot)
|
// Check if column already has a table/schema prefix (contains a dot)
|
||||||
@@ -1341,6 +1554,16 @@ func (h *Handler) isNullable(field reflect.StructField) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
|
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
|
||||||
|
h.sendResponseWithOptions(w, data, metadata, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendResponseWithOptions sends a response with optional formatting
|
||||||
|
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
|
||||||
|
// Normalize single-record arrays to objects if requested
|
||||||
|
if options != nil && options.SingleRecordAsObject {
|
||||||
|
data = h.normalizeResultArray(data)
|
||||||
|
}
|
||||||
|
|
||||||
response := common.Response{
|
response := common.Response{
|
||||||
Success: true,
|
Success: true,
|
||||||
Data: data,
|
Data: data,
|
||||||
@@ -1352,8 +1575,35 @@ func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metada
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeResultArray converts a single-element array to an object if requested
|
||||||
|
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||||
|
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||||
|
if data == nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use reflection to check if data is a slice or array
|
||||||
|
dataValue := reflect.ValueOf(data)
|
||||||
|
if dataValue.Kind() == reflect.Ptr {
|
||||||
|
dataValue = dataValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a slice or array with exactly one element
|
||||||
|
if (dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array) && dataValue.Len() == 1 {
|
||||||
|
// Return the single element
|
||||||
|
return dataValue.Index(0).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// sendFormattedResponse sends response with formatting options
|
// sendFormattedResponse sends response with formatting options
|
||||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
||||||
|
// Normalize single-record arrays to objects if requested
|
||||||
|
if options.SingleRecordAsObject {
|
||||||
|
data = h.normalizeResultArray(data)
|
||||||
|
}
|
||||||
|
|
||||||
// Clean JSON if requested (remove null/empty fields)
|
// Clean JSON if requested (remove null/empty fields)
|
||||||
if options.CleanJSON {
|
if options.CleanJSON {
|
||||||
data = h.cleanJSON(data)
|
data = h.cleanJSON(data)
|
||||||
@@ -1441,6 +1691,9 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
|||||||
if len(options.Sort) > 0 {
|
if len(options.Sort) > 0 {
|
||||||
sortParts := make([]string, 0, len(options.Sort))
|
sortParts := make([]string, 0, len(options.Sort))
|
||||||
for _, sort := range options.Sort {
|
for _, sort := range options.Sort {
|
||||||
|
if sort.Column == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
direction := "ASC"
|
direction := "ASC"
|
||||||
if strings.EqualFold(sort.Direction, "desc") {
|
if strings.EqualFold(sort.Direction, "desc") {
|
||||||
direction = "DESC"
|
direction = "DESC"
|
||||||
@@ -1651,7 +1904,8 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// shouldUseNestedProcessor determines if we should use nested CUD processing
|
// shouldUseNestedProcessor determines if we should use nested CUD processing
|
||||||
// It checks if the data contains nested relations or a _request field
|
// It recursively checks if the data contains deeply nested relations or _request fields
|
||||||
|
// Simple one-level relations without further nesting don't require the nested processor
|
||||||
func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool {
|
func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool {
|
||||||
return common.ShouldUseNestedProcessor(data, model, h)
|
return common.ShouldUseNestedProcessor(data, model, h)
|
||||||
}
|
}
|
||||||
@@ -1713,12 +1967,40 @@ func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName strin
|
|||||||
// Determine if it's belongsTo or hasMany/hasOne
|
// Determine if it's belongsTo or hasMany/hasOne
|
||||||
if field.Type.Kind() == reflect.Slice {
|
if field.Type.Kind() == reflect.Slice {
|
||||||
info.relationType = "hasMany"
|
info.relationType = "hasMany"
|
||||||
|
// Get the element type for slice
|
||||||
|
elemType := field.Type.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||||
info.relationType = "belongsTo"
|
info.relationType = "belongsTo"
|
||||||
|
elemType := field.Type
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if strings.Contains(gormTag, "many2many") {
|
} else if strings.Contains(gormTag, "many2many") {
|
||||||
info.relationType = "many2many"
|
info.relationType = "many2many"
|
||||||
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
||||||
|
// Get the element type for many2many (always slice)
|
||||||
|
if field.Type.Kind() == reflect.Slice {
|
||||||
|
elemType := field.Type.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Field has no GORM relationship tags, so it's not a relation
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|||||||
423
pkg/restheadspec/handler_nested_test.go
Normal file
423
pkg/restheadspec/handler_nested_test.go
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test models for nested CRUD operations
|
||||||
|
type TestUser struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Posts []TestPost `json:"posts" gorm:"foreignKey:UserID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestPost struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Comments []TestComment `json:"comments" gorm:"foreignKey:PostID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestComment struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||||
|
PostID int64 `json:"post_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (TestUser) TableName() string { return "users" }
|
||||||
|
func (TestPost) TableName() string { return "posts" }
|
||||||
|
func (TestComment) TableName() string { return "comments" }
|
||||||
|
|
||||||
|
// Test extractNestedRelations function
|
||||||
|
func TestExtractNestedRelations(t *testing.T) {
|
||||||
|
// Create handler
|
||||||
|
registry := &mockRegistry{
|
||||||
|
models: map[string]interface{}{
|
||||||
|
"users": TestUser{},
|
||||||
|
"posts": TestPost{},
|
||||||
|
"comments": TestComment{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]interface{}
|
||||||
|
model interface{}
|
||||||
|
expectedCleanCount int
|
||||||
|
expectedRelCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "User with posts",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John Doe",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{"title": "Post 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expectedCleanCount: 1, // name
|
||||||
|
expectedRelCount: 1, // posts
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Post with comments",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"title": "Test Post",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Comment 1"},
|
||||||
|
{"content": "Comment 2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestPost{},
|
||||||
|
expectedCleanCount: 1, // title
|
||||||
|
expectedRelCount: 1, // comments
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User with nested posts and comments",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "Jane Doe",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"title": "Post 1",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Comment 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expectedCleanCount: 1, // name
|
||||||
|
expectedRelCount: 1, // posts (which contains nested comments)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cleanedData, relations, err := handler.extractNestedRelations(tt.data, tt.model)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("extractNestedRelations() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cleanedData) != tt.expectedCleanCount {
|
||||||
|
t.Errorf("Expected %d cleaned fields, got %d: %+v", tt.expectedCleanCount, len(cleanedData), cleanedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(relations) != tt.expectedRelCount {
|
||||||
|
t.Errorf("Expected %d relation fields, got %d: %+v", tt.expectedRelCount, len(relations), relations)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Cleaned data: %+v", cleanedData)
|
||||||
|
t.Logf("Relations: %+v", relations)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test shouldUseNestedProcessor function
|
||||||
|
func TestShouldUseNestedProcessor(t *testing.T) {
|
||||||
|
registry := &mockRegistry{
|
||||||
|
models: map[string]interface{}{
|
||||||
|
"users": TestUser{},
|
||||||
|
"posts": TestPost{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data map[string]interface{}
|
||||||
|
model interface{}
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Data with simple nested posts (no further nesting)",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{"title": "Post 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: false, // Simple one-level nesting doesn't require nested processor
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Data with deeply nested relations",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"title": "Post 1",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Comment 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: true, // Multi-level nesting requires nested processor
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Data without nested relations",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Data with _request field",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"_request": "insert",
|
||||||
|
"name": "John",
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nested data with _request field",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_request": "insert",
|
||||||
|
"title": "Post 1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: true, // _request at nested level requires nested processor
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.shouldUseNestedProcessor(tt.data, tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("shouldUseNestedProcessor() = %v, expected %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test normalizeToSlice function
|
||||||
|
func TestNormalizeToSlice(t *testing.T) {
|
||||||
|
registry := &mockRegistry{}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected int // expected slice length
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Single object",
|
||||||
|
input: map[string]interface{}{"name": "John"},
|
||||||
|
expected: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Slice of objects",
|
||||||
|
input: []map[string]interface{}{
|
||||||
|
{"name": "John"},
|
||||||
|
{"name": "Jane"},
|
||||||
|
},
|
||||||
|
expected: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Array of interfaces",
|
||||||
|
input: []interface{}{
|
||||||
|
map[string]interface{}{"name": "John"},
|
||||||
|
map[string]interface{}{"name": "Jane"},
|
||||||
|
map[string]interface{}{"name": "Bob"},
|
||||||
|
},
|
||||||
|
expected: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil input",
|
||||||
|
input: nil,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.normalizeToSlice(tt.input)
|
||||||
|
if len(result) != tt.expected {
|
||||||
|
t.Errorf("normalizeToSlice() returned slice of length %d, expected %d", len(result), tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetRelationshipInfo function
|
||||||
|
func TestGetRelationshipInfo(t *testing.T) {
|
||||||
|
registry := &mockRegistry{}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
modelType reflect.Type
|
||||||
|
relationName string
|
||||||
|
expectNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "User posts relation",
|
||||||
|
modelType: reflect.TypeOf(TestUser{}),
|
||||||
|
relationName: "posts",
|
||||||
|
expectNil: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Post comments relation",
|
||||||
|
modelType: reflect.TypeOf(TestPost{}),
|
||||||
|
relationName: "comments",
|
||||||
|
expectNil: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-existent relation",
|
||||||
|
modelType: reflect.TypeOf(TestUser{}),
|
||||||
|
relationName: "nonexistent",
|
||||||
|
expectNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := handler.GetRelationshipInfo(tt.modelType, tt.relationName)
|
||||||
|
if tt.expectNil && result != nil {
|
||||||
|
t.Errorf("Expected nil, got %+v", result)
|
||||||
|
}
|
||||||
|
if !tt.expectNil && result == nil {
|
||||||
|
t.Errorf("Expected non-nil relationship info")
|
||||||
|
}
|
||||||
|
if result != nil {
|
||||||
|
t.Logf("Relationship info: FieldName=%s, JSONName=%s, RelationType=%s, ForeignKey=%s",
|
||||||
|
result.FieldName, result.JSONName, result.RelationType, result.ForeignKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock registry for testing
|
||||||
|
type mockRegistry struct {
|
||||||
|
models map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) Register(name string, model interface{}) {
|
||||||
|
m.RegisterModel(name, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) RegisterModel(name string, model interface{}) error {
|
||||||
|
if m.models == nil {
|
||||||
|
m.models = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
m.models[name] = model
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetModelByEntity(schema, entity string) (interface{}, error) {
|
||||||
|
if model, ok := m.models[entity]; ok {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("model not found: %s", entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetModelByName(name string) (interface{}, error) {
|
||||||
|
if model, ok := m.models[name]; ok {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("model not found: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetModel(name string) (interface{}, error) {
|
||||||
|
return m.GetModelByName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) HasModel(schema, entity string) bool {
|
||||||
|
_, ok := m.models[entity]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) ListModels() []string {
|
||||||
|
models := make([]string, 0, len(m.models))
|
||||||
|
for name := range m.models {
|
||||||
|
models = append(models, name)
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRegistry) GetAllModels() map[string]interface{} {
|
||||||
|
return m.models
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
|
||||||
|
func TestMultiLevelRelationExtraction(t *testing.T) {
|
||||||
|
registry := &mockRegistry{
|
||||||
|
models: map[string]interface{}{
|
||||||
|
"users": TestUser{},
|
||||||
|
"posts": TestPost{},
|
||||||
|
"comments": TestComment{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewHandler(nil, registry)
|
||||||
|
|
||||||
|
// Test data with 3 levels: User -> Posts -> Comments
|
||||||
|
testData := map[string]interface{}{
|
||||||
|
"name": "John Doe",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"title": "First Post",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Great post!"},
|
||||||
|
{"content": "Thanks for sharing!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Second Post",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Interesting read"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract relations from user
|
||||||
|
cleanedData, relations, err := handler.extractNestedRelations(testData, TestUser{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to extract relations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify user data is cleaned
|
||||||
|
if len(cleanedData) != 1 || cleanedData["name"] != "John Doe" {
|
||||||
|
t.Errorf("Expected cleaned data to contain only name, got: %+v", cleanedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify posts relation was extracted
|
||||||
|
if len(relations) != 1 {
|
||||||
|
t.Errorf("Expected 1 relation (posts), got %d", len(relations))
|
||||||
|
}
|
||||||
|
|
||||||
|
posts, ok := relations["posts"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected posts relation to be extracted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify posts is a slice with 2 items
|
||||||
|
postsSlice, ok := posts.([]map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected posts to be []map[string]interface{}, got %T", posts)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(postsSlice) != 2 {
|
||||||
|
t.Errorf("Expected 2 posts, got %d", len(postsSlice))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify first post has comments
|
||||||
|
if _, hasComments := postsSlice[0]["comments"]; !hasComments {
|
||||||
|
t.Error("Expected first post to have comments")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Successfully extracted multi-level nested relations")
|
||||||
|
t.Logf("Cleaned data: %+v", cleanedData)
|
||||||
|
t.Logf("Relations: %d posts with nested comments", len(postsSlice))
|
||||||
|
}
|
||||||
@@ -37,6 +37,9 @@ type ExtendedRequestOptions struct {
|
|||||||
// Response format
|
// Response format
|
||||||
ResponseFormat string // "simple", "detail", "syncfusion"
|
ResponseFormat string // "simple", "detail", "syncfusion"
|
||||||
|
|
||||||
|
// Single record normalization - convert single-element arrays to objects
|
||||||
|
SingleRecordAsObject bool
|
||||||
|
|
||||||
// Transaction
|
// Transaction
|
||||||
AtomicTransaction bool
|
AtomicTransaction bool
|
||||||
}
|
}
|
||||||
@@ -99,10 +102,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
|||||||
Sort: make([]common.SortOption, 0),
|
Sort: make([]common.SortOption, 0),
|
||||||
Preload: make([]common.PreloadOption, 0),
|
Preload: make([]common.PreloadOption, 0),
|
||||||
},
|
},
|
||||||
AdvancedSQL: make(map[string]string),
|
AdvancedSQL: make(map[string]string),
|
||||||
ComputedQL: make(map[string]string),
|
ComputedQL: make(map[string]string),
|
||||||
Expand: make([]ExpandOption, 0),
|
Expand: make([]ExpandOption, 0),
|
||||||
ResponseFormat: "simple", // Default response format
|
ResponseFormat: "simple", // Default response format
|
||||||
|
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all headers
|
// Get all headers
|
||||||
@@ -146,7 +150,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
|||||||
|
|
||||||
// Joins & Relations
|
// Joins & Relations
|
||||||
case strings.HasPrefix(normalizedKey, "x-preload"):
|
case strings.HasPrefix(normalizedKey, "x-preload"):
|
||||||
h.parsePreload(&options, decodedValue)
|
if strings.HasSuffix(normalizedKey, "-where") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
whereClaude := headers[fmt.Sprintf("%s-where", key)]
|
||||||
|
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
||||||
|
|
||||||
case strings.HasPrefix(normalizedKey, "x-expand"):
|
case strings.HasPrefix(normalizedKey, "x-expand"):
|
||||||
h.parseExpand(&options, decodedValue)
|
h.parseExpand(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
|
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
|
||||||
@@ -194,6 +203,13 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
|||||||
options.ResponseFormat = "detail"
|
options.ResponseFormat = "detail"
|
||||||
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
|
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
|
||||||
options.ResponseFormat = "syncfusion"
|
options.ResponseFormat = "syncfusion"
|
||||||
|
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
|
||||||
|
// Parse as boolean - "false" disables, "true" enables (default is true)
|
||||||
|
if strings.EqualFold(decodedValue, "false") {
|
||||||
|
options.SingleRecordAsObject = false
|
||||||
|
} else if strings.EqualFold(decodedValue, "true") {
|
||||||
|
options.SingleRecordAsObject = true
|
||||||
|
}
|
||||||
|
|
||||||
// Transaction Control
|
// Transaction Control
|
||||||
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
|
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
|
||||||
@@ -341,7 +357,15 @@ func (h *Handler) mapSearchOperator(colName, operator, value string) common.Filt
|
|||||||
|
|
||||||
// parsePreload parses x-preload header
|
// parsePreload parses x-preload header
|
||||||
// Format: RelationName:field1,field2 or RelationName or multiple separated by |
|
// Format: RelationName:field1,field2 or RelationName or multiple separated by |
|
||||||
func (h *Handler) parsePreload(options *ExtendedRequestOptions, value string) {
|
func (h *Handler) parsePreload(options *ExtendedRequestOptions, values ...string) {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := values[0]
|
||||||
|
whereClause := ""
|
||||||
|
if len(values) > 1 {
|
||||||
|
whereClause = values[1]
|
||||||
|
}
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -358,6 +382,7 @@ func (h *Handler) parsePreload(options *ExtendedRequestOptions, value string) {
|
|||||||
parts := strings.SplitN(preloadStr, ":", 2)
|
parts := strings.SplitN(preloadStr, ":", 2)
|
||||||
preload := common.PreloadOption{
|
preload := common.PreloadOption{
|
||||||
Relation: strings.TrimSpace(parts[0]),
|
Relation: strings.TrimSpace(parts[0]),
|
||||||
|
Where: whereClause,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
|||||||
reqAdapter := router.NewHTTPRequest(r)
|
reqAdapter := router.NewHTTPRequest(r)
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
handler.Handle(respAdapter, reqAdapter, vars)
|
handler.Handle(respAdapter, reqAdapter, vars)
|
||||||
}).Methods("GET", "PUT", "PATCH", "DELETE")
|
}).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
|
||||||
|
|
||||||
// GET for metadata (using HandleGet)
|
// GET for metadata (using HandleGet)
|
||||||
muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) {
|
muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -189,6 +189,18 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
params := map[string]string{
|
||||||
|
"schema": req.Param("schema"),
|
||||||
|
"entity": req.Param("entity"),
|
||||||
|
"id": req.Param("id"),
|
||||||
|
}
|
||||||
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
params := map[string]string{
|
params := map[string]string{
|
||||||
"schema": req.Param("schema"),
|
"schema": req.Param("schema"),
|
||||||
|
|||||||
@@ -402,25 +402,41 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
resp := makeRestHeadSpecRequest(t, serverURL, fmt.Sprintf("/restheadspec/departments/%s", deptID), "GET", nil, nil)
|
resp := makeRestHeadSpecRequest(t, serverURL, fmt.Sprintf("/restheadspec/departments/%s", deptID), "GET", nil, nil)
|
||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
// RestHeadSpec may return data directly as array or wrapped in response object
|
// RestHeadSpec may return data directly as array/object or wrapped in response object
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
assert.NoError(t, err, "Failed to read response body")
|
assert.NoError(t, err, "Failed to read response body")
|
||||||
|
|
||||||
// Try to decode as array first (simple format)
|
// Try to decode as array first (simple format - multiple records or disabled SingleRecordAsObject)
|
||||||
var dataArray []interface{}
|
var dataArray []interface{}
|
||||||
if err := json.Unmarshal(body, &dataArray); err == nil {
|
if err := json.Unmarshal(body, &dataArray); err == nil {
|
||||||
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find department")
|
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find department")
|
||||||
logger.Info("Department read successfully (simple format): %s", deptID)
|
logger.Info("Department read successfully (simple format - array): %s", deptID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to decode as standard response object (detail format)
|
// Try to decode as a single object first (simple format with SingleRecordAsObject enabled)
|
||||||
var result map[string]interface{}
|
var singleObj map[string]interface{}
|
||||||
if err := json.Unmarshal(body, &result); err == nil {
|
if err := json.Unmarshal(body, &singleObj); err == nil {
|
||||||
if success, ok := result["success"]; ok && success != nil && success.(bool) {
|
// Check if it's a data object (not a response wrapper)
|
||||||
if data, ok := result["data"].([]interface{}); ok {
|
if _, hasSuccess := singleObj["success"]; !hasSuccess {
|
||||||
|
// This is a direct data object (simple format, single record)
|
||||||
|
assert.NotEmpty(t, singleObj, "Should find department")
|
||||||
|
logger.Info("Department read successfully (simple format - single object): %s", deptID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise it's a standard response object (detail format)
|
||||||
|
if success, ok := singleObj["success"]; ok && success != nil && success.(bool) {
|
||||||
|
// Check if data is an array
|
||||||
|
if data, ok := singleObj["data"].([]interface{}); ok {
|
||||||
assert.GreaterOrEqual(t, len(data), 1, "Should find department")
|
assert.GreaterOrEqual(t, len(data), 1, "Should find department")
|
||||||
logger.Info("Department read successfully (detail format): %s", deptID)
|
logger.Info("Department read successfully (detail format - array): %s", deptID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check if data is a single object (SingleRecordAsObject feature in detail format)
|
||||||
|
if data, ok := singleObj["data"].(map[string]interface{}); ok {
|
||||||
|
assert.NotEmpty(t, data, "Should find department")
|
||||||
|
logger.Info("Department read successfully (detail format - single object): %s", deptID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -446,25 +462,41 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
resp := makeRestHeadSpecRequest(t, serverURL, "/restheadspec/employees", "GET", nil, headers)
|
resp := makeRestHeadSpecRequest(t, serverURL, "/restheadspec/employees", "GET", nil, headers)
|
||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
// RestHeadSpec may return data directly as array or wrapped in response object
|
// RestHeadSpec may return data directly as array/object or wrapped in response object
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
assert.NoError(t, err, "Failed to read response body")
|
assert.NoError(t, err, "Failed to read response body")
|
||||||
|
|
||||||
// Try array format first
|
// Try array format first (multiple records or disabled SingleRecordAsObject)
|
||||||
var dataArray []interface{}
|
var dataArray []interface{}
|
||||||
if err := json.Unmarshal(body, &dataArray); err == nil {
|
if err := json.Unmarshal(body, &dataArray); err == nil {
|
||||||
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find at least one employee")
|
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find at least one employee")
|
||||||
logger.Info("Employees read with filter successfully (simple format), found: %d", len(dataArray))
|
logger.Info("Employees read with filter successfully (simple format - array), found: %d", len(dataArray))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try standard response format
|
// Try to decode as a single object (simple format with SingleRecordAsObject enabled)
|
||||||
var result map[string]interface{}
|
var singleObj map[string]interface{}
|
||||||
if err := json.Unmarshal(body, &result); err == nil {
|
if err := json.Unmarshal(body, &singleObj); err == nil {
|
||||||
if success, ok := result["success"]; ok && success != nil && success.(bool) {
|
// Check if it's a data object (not a response wrapper)
|
||||||
if data, ok := result["data"].([]interface{}); ok {
|
if _, hasSuccess := singleObj["success"]; !hasSuccess {
|
||||||
|
// This is a direct data object (simple format, single record)
|
||||||
|
assert.NotEmpty(t, singleObj, "Should find at least one employee")
|
||||||
|
logger.Info("Employees read with filter successfully (simple format - single object), found: 1")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise it's a standard response object (detail format)
|
||||||
|
if success, ok := singleObj["success"]; ok && success != nil && success.(bool) {
|
||||||
|
// Check if data is an array
|
||||||
|
if data, ok := singleObj["data"].([]interface{}); ok {
|
||||||
assert.GreaterOrEqual(t, len(data), 1, "Should find at least one employee")
|
assert.GreaterOrEqual(t, len(data), 1, "Should find at least one employee")
|
||||||
logger.Info("Employees read with filter successfully (detail format), found: %d", len(data))
|
logger.Info("Employees read with filter successfully (detail format - array), found: %d", len(data))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check if data is a single object (SingleRecordAsObject feature in detail format)
|
||||||
|
if data, ok := singleObj["data"].(map[string]interface{}); ok {
|
||||||
|
assert.NotEmpty(t, data, "Should find at least one employee")
|
||||||
|
logger.Info("Employees read with filter successfully (detail format - single object), found: 1")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user