mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af3260864d | ||
|
|
ca6d2deff6 | ||
|
|
1481443516 | ||
|
|
cb54ec5e27 | ||
|
|
7d6a9025f5 | ||
|
|
35089f511f | ||
|
|
66b6a0d835 | ||
|
|
456c165814 | ||
|
|
850d7b546c | ||
|
|
a44ef90d7c | ||
|
|
8b7db5b31a | ||
|
|
14daea3b05 | ||
|
|
35f23b6d9e | ||
|
|
53a4e67f70 | ||
|
|
1289c3af88 | ||
|
|
cdfb7a67fd | ||
|
|
7f5b851669 | ||
|
|
f0e26b1c0d | ||
|
|
1db1b924ef | ||
|
|
d9cf23b1dc | ||
|
|
94f013c872 | ||
|
|
c52fcff61d | ||
|
|
ce106fa940 | ||
|
|
37b4b75175 | ||
|
|
0cef0f75d3 | ||
|
|
006dc4a2b2 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -23,4 +23,5 @@ go.work.sum
|
||||
|
||||
# env file
|
||||
.env
|
||||
bin/
|
||||
bin/
|
||||
test.db
|
||||
|
||||
110
.golangci.bck.yml
Normal file
110
.golangci.bck.yml
Normal file
@@ -0,0 +1,110 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
skip-dirs:
|
||||
- vendor
|
||||
- .github
|
||||
|
||||
linters:
|
||||
enable:
|
||||
- errcheck
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- staticcheck
|
||||
- unused
|
||||
- gofmt
|
||||
- goimports
|
||||
- misspell
|
||||
- gocritic
|
||||
- revive
|
||||
- stylecheck
|
||||
disable:
|
||||
- typecheck # Can cause issues with generics in some cases
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-type-assertions: false
|
||||
check-blank: false
|
||||
|
||||
govet:
|
||||
check-shadowing: false
|
||||
|
||||
gofmt:
|
||||
simplify: true
|
||||
|
||||
goimports:
|
||||
local-prefixes: github.com/bitechdev/ResolveSpec
|
||||
|
||||
gocritic:
|
||||
enabled-checks:
|
||||
- appendAssign
|
||||
- assignOp
|
||||
- boolExprSimplify
|
||||
- builtinShadow
|
||||
- captLocal
|
||||
- caseOrder
|
||||
- defaultCaseOrder
|
||||
- dupArg
|
||||
- dupBranchBody
|
||||
- dupCase
|
||||
- dupSubExpr
|
||||
- elseif
|
||||
- emptyFallthrough
|
||||
- equalFold
|
||||
- flagName
|
||||
- ifElseChain
|
||||
- indexAlloc
|
||||
- initClause
|
||||
- methodExprCall
|
||||
- nilValReturn
|
||||
- rangeExprCopy
|
||||
- rangeValCopy
|
||||
- regexpMust
|
||||
- singleCaseSwitch
|
||||
- sloppyLen
|
||||
- stringXbytes
|
||||
- switchTrue
|
||||
- typeAssertChain
|
||||
- typeSwitchVar
|
||||
- underef
|
||||
- unlabelStmt
|
||||
- unnamedResult
|
||||
- unnecessaryBlock
|
||||
- weakCond
|
||||
- yodaStyleExpr
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
disabled: true
|
||||
- name: package-comments
|
||||
disabled: true
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
|
||||
# Exclude some linters from running on tests files
|
||||
exclude-rules:
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- errcheck
|
||||
- dupl
|
||||
- gosec
|
||||
- gocritic
|
||||
|
||||
# Ignore "error return value not checked" for defer statements
|
||||
- linters:
|
||||
- errcheck
|
||||
text: "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*print(f|ln)?|os\\.(Un)?Setenv). is not checked"
|
||||
|
||||
# Ignore complexity in test files
|
||||
- path: _test\.go
|
||||
text: "cognitive complexity|cyclomatic complexity"
|
||||
|
||||
output:
|
||||
format: colored-line-number
|
||||
print-issued-lines: true
|
||||
print-linter-name: true
|
||||
58
.vscode/tasks.json
vendored
58
.vscode/tasks.json
vendored
@@ -24,21 +24,63 @@
|
||||
"type": "go",
|
||||
"label": "go: test workspace",
|
||||
"command": "test",
|
||||
|
||||
"options": {
|
||||
"env": {
|
||||
"CGO_ENABLED": "0"
|
||||
},
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"args": [
|
||||
"../..."
|
||||
"-v",
|
||||
"-race",
|
||||
"-coverprofile=coverage.out",
|
||||
"-covermode=atomic",
|
||||
"./..."
|
||||
],
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "build",
|
||||
|
||||
"group": {
|
||||
"kind": "test",
|
||||
"isDefault": true
|
||||
},
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "new"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "go: vet workspace",
|
||||
"command": "go vet ./...",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "go: lint workspace",
|
||||
"command": "golangci-lint run --timeout=5m",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "go: full test suite",
|
||||
"dependsOrder": "sequence",
|
||||
"dependsOn": [
|
||||
"go: vet workspace",
|
||||
"go: test workspace"
|
||||
],
|
||||
"problemMatcher": [],
|
||||
"group": {
|
||||
"kind": "test",
|
||||
"isDefault": false
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
164
README.md
164
README.md
@@ -31,9 +31,12 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
||||
- [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
|
||||
- [Lifecycle Hooks](#lifecycle-hooks)
|
||||
- [Cursor Pagination](#cursor-pagination)
|
||||
- [Response Formats](#response-formats)
|
||||
- [Single Record as Object](#single-record-as-object-default-behavior)
|
||||
- [Example Usage](#example-usage)
|
||||
- [Recursive CRUD Operations](#recursive-crud-operations-)
|
||||
- [Testing](#testing)
|
||||
- [What's New in v2.0](#whats-new-in-v20)
|
||||
- [What's New](#whats-new)
|
||||
|
||||
## Features
|
||||
|
||||
@@ -45,6 +48,7 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
||||
- **Pagination**: Built-in limit/offset and cursor-based pagination
|
||||
- **Computed Columns**: Define virtual columns for complex calculations
|
||||
- **Custom Operators**: Add custom SQL conditions when needed
|
||||
- **🆕 Recursive CRUD Handler**: Automatically handle nested object graphs with foreign key resolution and per-record operation control via `_request` field
|
||||
|
||||
### Architecture (v2.0+)
|
||||
- **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
|
||||
@@ -57,6 +61,7 @@ Both share the same core architecture and provide dynamic data querying, relatio
|
||||
- **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
|
||||
- **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
|
||||
- **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
|
||||
- **🆕 Single Record as Object**: Automatically normalize single-element arrays to objects (enabled by default)
|
||||
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
|
||||
- **🆕 Base64 Encoding**: Support for base64-encoded header values
|
||||
|
||||
@@ -161,6 +166,7 @@ restheadspec.SetupMuxRoutes(router, handler)
|
||||
| `X-Limit` | Limit results | `50` |
|
||||
| `X-Offset` | Offset for pagination | `100` |
|
||||
| `X-Clean-JSON` | Remove null/empty fields | `true` |
|
||||
| `X-Single-Record-As-Object` | Return single records as objects (default: `true`) | `false` |
|
||||
|
||||
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
|
||||
|
||||
@@ -301,6 +307,55 @@ RestHeadSpec supports multiple response formats:
|
||||
}
|
||||
```
|
||||
|
||||
### Single Record as Object (Default Behavior)
|
||||
|
||||
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses. This provides a better developer experience when fetching individual records.
|
||||
|
||||
**Default behavior (enabled)**:
|
||||
```http
|
||||
GET /public/users/123
|
||||
```
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": { "id": 123, "name": "John", "email": "john@example.com" }
|
||||
}
|
||||
```
|
||||
|
||||
Instead of:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
|
||||
}
|
||||
```
|
||||
|
||||
**To disable** (force arrays for consistency):
|
||||
```http
|
||||
GET /public/users/123
|
||||
X-Single-Record-As-Object: false
|
||||
```
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
|
||||
}
|
||||
```
|
||||
|
||||
**How it works**:
|
||||
- When a query returns exactly **one record**, it's returned as an object
|
||||
- When a query returns **multiple records**, they're returned as an array
|
||||
- Set `X-Single-Record-As-Object: false` to always receive arrays
|
||||
- Works with all response formats (simple, detail, syncfusion)
|
||||
- Applies to both read operations and create/update returning clauses
|
||||
|
||||
**Benefits**:
|
||||
- Cleaner API responses for single-record queries
|
||||
- No need to unwrap single-element arrays on the client side
|
||||
- Better TypeScript/type inference support
|
||||
- Consistent with common REST API patterns
|
||||
- Backward compatible via header opt-out
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Reading Data with Related Entities
|
||||
@@ -342,6 +397,92 @@ POST /core/users
|
||||
}
|
||||
```
|
||||
|
||||
### Recursive CRUD Operations (🆕)
|
||||
|
||||
ResolveSpec now supports automatic handling of nested object graphs with intelligent foreign key resolution. This allows you to create, update, or delete entire object hierarchies in a single request.
|
||||
|
||||
#### Creating Nested Objects
|
||||
|
||||
```json
|
||||
POST /core/users
|
||||
{
|
||||
"operation": "create",
|
||||
"data": {
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"posts": [
|
||||
{
|
||||
"title": "My First Post",
|
||||
"content": "Hello World",
|
||||
"tags": [
|
||||
{"name": "tech"},
|
||||
{"name": "programming"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Second Post",
|
||||
"content": "More content"
|
||||
}
|
||||
],
|
||||
"profile": {
|
||||
"bio": "Software Developer",
|
||||
"website": "https://example.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Per-Record Operation Control with `_request`
|
||||
|
||||
Control individual operations for each nested record using the special `_request` field:
|
||||
|
||||
```json
|
||||
POST /core/users/123
|
||||
{
|
||||
"operation": "update",
|
||||
"data": {
|
||||
"name": "John Updated",
|
||||
"posts": [
|
||||
{
|
||||
"_request": "insert",
|
||||
"title": "New Post",
|
||||
"content": "Fresh content"
|
||||
},
|
||||
{
|
||||
"_request": "update",
|
||||
"id": 456,
|
||||
"title": "Updated Post Title"
|
||||
},
|
||||
{
|
||||
"_request": "delete",
|
||||
"id": 789
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Supported `_request` values**:
|
||||
- `insert` - Create a new related record
|
||||
- `update` - Update an existing related record
|
||||
- `delete` - Delete a related record
|
||||
- `upsert` - Create if doesn't exist, update if exists
|
||||
|
||||
#### How It Works
|
||||
|
||||
1. **Automatic Foreign Key Resolution**: Parent IDs are automatically propagated to child records
|
||||
2. **Recursive Processing**: Handles nested relationships at any depth
|
||||
3. **Transaction Safety**: All operations execute within database transactions
|
||||
4. **Relationship Detection**: Automatically detects belongsTo, hasMany, hasOne, and many2many relationships
|
||||
5. **Flexible Operations**: Mix create, update, and delete operations in a single request
|
||||
|
||||
#### Benefits
|
||||
|
||||
- Reduce API round trips for complex object graphs
|
||||
- Maintain referential integrity automatically
|
||||
- Simplify client-side code
|
||||
- Atomic operations with automatic rollback on errors
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
@@ -811,12 +952,32 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
|
||||
### v2.1 (Latest)
|
||||
|
||||
**Recursive CRUD Handler (🆕 Nov 11, 2025)**:
|
||||
- **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
|
||||
- **Foreign Key Resolution**: Automatic propagation of parent IDs to child records
|
||||
- **Per-Record Operations**: Control create/update/delete operations per record via `_request` field
|
||||
- **Transaction Safety**: All nested operations execute atomically within database transactions
|
||||
- **Relationship Detection**: Automatic detection of belongsTo, hasMany, hasOne, and many2many relationships
|
||||
- **Deep Nesting Support**: Handle relationships at any depth level
|
||||
- **Mixed Operations**: Combine insert, update, and delete operations in a single request
|
||||
|
||||
**Primary Key Improvements (Nov 11, 2025)**:
|
||||
- **GetPrimaryKeyName**: Enhanced primary key detection for better preload and ID field handling
|
||||
- **Better GORM/Bun Support**: Improved compatibility with both ORMs for primary key operations
|
||||
- **Computed Column Support**: Fixed computed columns functionality across handlers
|
||||
|
||||
**Database Adapter Enhancements (Nov 11, 2025)**:
|
||||
- **Bun ORM Relations**: Using Scan model method for better has-many and many-to-many relationship handling
|
||||
- **Model Method Support**: Enhanced query building with proper model registration
|
||||
- **Improved Type Safety**: Better handling of relationship queries with type-aware scanning
|
||||
|
||||
**RestHeadSpec - Header-Based REST API**:
|
||||
- **Header-Based Querying**: All query options via HTTP headers instead of request body
|
||||
- **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
|
||||
- **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
|
||||
- **Advanced Filtering**: Field filters, search operators, AND/OR logic
|
||||
- **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
|
||||
- **Single Record as Object**: Automatically return single-element arrays as objects (default, toggleable via header)
|
||||
- **Base64 Support**: Base64-encoded header values for complex queries
|
||||
- **Type-Aware Filtering**: Automatic type detection and conversion for filters
|
||||
|
||||
@@ -826,6 +987,7 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
- Improved reflection safety
|
||||
- Fixed COUNT query issues with table aliasing
|
||||
- Better pointer handling throughout the codebase
|
||||
- **Comprehensive Test Coverage**: Added standalone CRUD tests for both ResolveSpec and RestHeadSpec
|
||||
|
||||
### v2.0
|
||||
|
||||
|
||||
13
go.mod
13
go.mod
@@ -21,6 +21,8 @@ require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
@@ -29,15 +31,18 @@ require (
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 // indirect
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15 // indirect
|
||||
github.com/uptrace/bunrouter v1.0.23 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
modernc.org/mathutil v1.5.0 // indirect
|
||||
modernc.org/memory v1.5.0 // indirect
|
||||
modernc.org/sqlite v1.23.1 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.0 // indirect
|
||||
)
|
||||
|
||||
19
go.sum
19
go.sum
@@ -9,6 +9,7 @@ github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GM
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
@@ -21,6 +22,10 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -50,6 +55,10 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
|
||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
@@ -62,6 +71,8 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
@@ -77,9 +88,17 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
|
||||
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
|
||||
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
||||
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
||||
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
||||
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
||||
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
@@ -119,6 +121,12 @@ func (b *BunSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.ColumnExpr(query, args)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.Where(query, args...)
|
||||
return b
|
||||
@@ -209,6 +217,40 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
if len(apply) == 0 {
|
||||
return sq
|
||||
}
|
||||
|
||||
// Wrap the incoming *bun.SelectQuery in our adapter
|
||||
wrapper := &BunSelectQuery{
|
||||
query: sq,
|
||||
db: b.db,
|
||||
}
|
||||
|
||||
// Start with the interface value (not pointer)
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
// Apply each function in sequence
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
// Pass ¤t (pointer to interface variable), fn modifies and returns new interface value
|
||||
modified := fn(current)
|
||||
current = modified
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the final *bun.SelectQuery
|
||||
if finalBun, ok := current.(*BunSelectQuery); ok {
|
||||
return finalBun.query
|
||||
}
|
||||
|
||||
return sq // fallback
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||
b.query = b.query.Order(order)
|
||||
return b
|
||||
@@ -238,6 +280,10 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
return b.query.Scan(ctx, dest)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) error {
|
||||
return b.query.Scan(ctx)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
// If Model() was set, use bun's native Count() which works properly
|
||||
if b.hasModel {
|
||||
@@ -261,12 +307,14 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
|
||||
// BunInsertQuery implements InsertQuery for Bun
|
||||
type BunInsertQuery struct {
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
hasModel bool
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -296,10 +344,16 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
if b.values != nil {
|
||||
// For Bun, we need to handle this differently
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Set("? = ?", bun.Ident(k), v)
|
||||
if b.values != nil && len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
// If no model was set, use the values map as the model
|
||||
// Bun can insert map[string]interface{} directly
|
||||
b.query = b.query.Model(&b.values)
|
||||
} else {
|
||||
// If model was set, use Value() to add individual values
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Value(k, "?", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := b.query.Exec(ctx)
|
||||
@@ -309,25 +363,50 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
// BunUpdateQuery implements UpdateQuery for Bun
|
||||
type BunUpdateQuery struct {
|
||||
query *bun.UpdateQuery
|
||||
model interface{}
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.model = model
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
b.query = b.query.Table(table)
|
||||
if b.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
b.model = model
|
||||
}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
||||
// Validate column is writable if model is set
|
||||
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||
// Skip scan-only columns
|
||||
return b
|
||||
}
|
||||
b.query = b.query.Set(column+" = ?", value)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||
pkName := reflection.GetPrimaryKeyName(b.model)
|
||||
for column, value := range values {
|
||||
// Validate column is writable if model is set
|
||||
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||
// Skip scan-only columns
|
||||
continue
|
||||
}
|
||||
if pkName != "" && column == pkName {
|
||||
// Skip primary key updates
|
||||
continue
|
||||
}
|
||||
b.query = b.query.Set(column+" = ?", value)
|
||||
}
|
||||
return b
|
||||
|
||||
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
"github.com/uptrace/bun/driver/sqliteshim"
|
||||
)
|
||||
|
||||
// TestInsertModel is a test model for insert operations
|
||||
type TestInsertModel struct {
|
||||
bun.BaseModel `bun:"table:test_inserts"`
|
||||
ID int64 `bun:"id,pk,autoincrement"`
|
||||
Name string `bun:"name,notnull"`
|
||||
Email string `bun:"email"`
|
||||
Age int `bun:"age"`
|
||||
}
|
||||
|
||||
func setupBunTestDB(t *testing.T) *bun.DB {
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||
require.NoError(t, err, "Failed to open SQLite database")
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
|
||||
// Create test table
|
||||
_, err = db.NewCreateTable().
|
||||
Model((*TestInsertModel)(nil)).
|
||||
IfNotExists().
|
||||
Exec(context.Background())
|
||||
require.NoError(t, err, "Failed to create test table")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Model(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with Model()
|
||||
model := &TestInsertModel{
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
Age: 30,
|
||||
}
|
||||
|
||||
result, err := adapter.NewInsert().
|
||||
Model(model).
|
||||
Returning("*").
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||
|
||||
// Verify the data was inserted
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("id = ?", model.ID).
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "John Doe", retrieved.Name)
|
||||
assert.Equal(t, "john@example.com", retrieved.Email)
|
||||
assert.Equal(t, 30, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Value(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with Value() method - this was the bug
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Jane Smith").
|
||||
Value("email", "jane@example.com").
|
||||
Value("age", 25).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with Value() should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||
|
||||
// Verify the data was inserted
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("name = ?", "Jane Smith").
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "Jane Smith", retrieved.Name)
|
||||
assert.Equal(t, "jane@example.com", retrieved.Email)
|
||||
assert.Equal(t, 25, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_MultipleValues(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting multiple values
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Alice").
|
||||
Value("email", "alice@example.com").
|
||||
Value("age", 28).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "First insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
result, err = adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Bob").
|
||||
Value("email", "bob@example.com").
|
||||
Value("age", 35).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Second insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify both rows exist
|
||||
var count int
|
||||
count, err = db.NewSelect().
|
||||
Model((*TestInsertModel)(nil)).
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err, "Count should succeed")
|
||||
assert.Equal(t, 2, count, "Should have 2 rows")
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_ValueWithNil(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with nil value for nullable field
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Test User").
|
||||
Value("email", nil). // NULL email
|
||||
Value("age", 20).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with nil value should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify the data was inserted with NULL email
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("name = ?", "Test User").
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "Test User", retrieved.Name)
|
||||
assert.Equal(t, "", retrieved.Email) // NULL becomes empty string
|
||||
assert.Equal(t, 20, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Returning(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test insert with RETURNING clause
|
||||
// Note: SQLite has limited RETURNING support, but this tests the API
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Return Test").
|
||||
Value("email", "return@example.com").
|
||||
Value("age", 40).
|
||||
Returning("*").
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with RETURNING should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_EmptyValues(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test insert without calling Value() - should use Model() or fail gracefully
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Exec(ctx)
|
||||
|
||||
// This should fail because no values are provided
|
||||
assert.Error(t, err, "Insert without values should fail")
|
||||
if result != nil {
|
||||
assert.Equal(t, int64(0), result.RowsAffected())
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
@@ -97,6 +99,7 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||
g.db = g.db.Table(table)
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
g.schema, g.tableName = parseTableName(table)
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -105,6 +108,11 @@ func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
g.db = g.db.Select(query, args...)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
g.db = g.db.Where(query, args...)
|
||||
return g
|
||||
@@ -192,6 +200,36 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
|
||||
if len(apply) == 0 {
|
||||
return db
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
}
|
||||
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
|
||||
modified := fn(current)
|
||||
current = modified
|
||||
}
|
||||
}
|
||||
|
||||
if finalBun, ok := current.(*GormSelectQuery); ok {
|
||||
return finalBun.db
|
||||
}
|
||||
|
||||
return db // fallback
|
||||
})
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||
g.db = g.db.Order(order)
|
||||
return g
|
||||
@@ -221,6 +259,13 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) error {
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
var count int64
|
||||
err := g.db.WithContext(ctx).Count(&count).Error
|
||||
@@ -297,10 +342,23 @@ func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
|
||||
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
g.db = g.db.Table(table)
|
||||
if g.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
g.model = model
|
||||
}
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
||||
// Validate column is writable if model is set
|
||||
if g.model != nil && !reflection.IsColumnWritable(g.model, column) {
|
||||
// Skip read-only columns
|
||||
return g
|
||||
}
|
||||
|
||||
if g.updates == nil {
|
||||
g.updates = make(map[string]interface{})
|
||||
}
|
||||
@@ -311,7 +369,25 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||
g.updates = values
|
||||
|
||||
// Filter out read-only columns if model is set
|
||||
if g.model != nil {
|
||||
pkName := reflection.GetPrimaryKeyName(g.model)
|
||||
filteredValues := make(map[string]interface{})
|
||||
for column, value := range values {
|
||||
if pkName != "" && column == pkName {
|
||||
// Skip primary key updates
|
||||
continue
|
||||
}
|
||||
if reflection.IsColumnWritable(g.model, column) {
|
||||
filteredValues[column] = value
|
||||
}
|
||||
|
||||
}
|
||||
g.updates = filteredValues
|
||||
} else {
|
||||
g.updates = values
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
|
||||
161
pkg/common/adapters/database/update_validation_test.go
Normal file
161
pkg/common/adapters/database/update_validation_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Test models for bun
|
||||
type BunTestModel struct {
|
||||
ID int `bun:"id,pk"`
|
||||
Name string `bun:"name"`
|
||||
Email string `bun:"email"`
|
||||
ComputedCol string `bun:"computed_col,scanonly"`
|
||||
}
|
||||
|
||||
// Test models for gorm
|
||||
type GormTestModel struct {
|
||||
ID int `gorm:"column:id;primaryKey"`
|
||||
Name string `gorm:"column:name"`
|
||||
Email string `gorm:"column:email"`
|
||||
ReadOnlyCol string `gorm:"column:readonly_col;->"`
|
||||
NoWriteCol string `gorm:"column:nowrite_col;<-:false"`
|
||||
}
|
||||
|
||||
func TestIsColumnWritable_Bun(t *testing.T) {
|
||||
model := &BunTestModel{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "writable column - id",
|
||||
columnName: "id",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - name",
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - email",
|
||||
columnName: "email",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "scanonly column should not be writable",
|
||||
columnName: "computed_col",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent column should be writable (dynamic)",
|
||||
columnName: "nonexistent",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reflection.IsColumnWritable(model, tt.columnName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsColumnWritable_Gorm(t *testing.T) {
|
||||
model := &GormTestModel{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "writable column - id",
|
||||
columnName: "id",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - name",
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - email",
|
||||
columnName: "email",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "read-only column with -> should not be writable",
|
||||
columnName: "readonly_col",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "column with <-:false should not be writable",
|
||||
columnName: "nowrite_col",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent column should be writable (dynamic)",
|
||||
columnName: "nonexistent",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reflection.IsColumnWritable(model, tt.columnName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) {
|
||||
// Note: This is a unit test for the validation logic only.
|
||||
// We can't fully test the bun query without a database connection,
|
||||
// but we've verified the validation logic in TestIsColumnWritable_Bun
|
||||
t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun")
|
||||
}
|
||||
|
||||
func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) {
|
||||
model := &GormTestModel{}
|
||||
query := &GormUpdateQuery{
|
||||
model: model,
|
||||
}
|
||||
|
||||
// SetMap should filter out read-only columns
|
||||
values := map[string]interface{}{
|
||||
"name": "John",
|
||||
"email": "john@example.com",
|
||||
"readonly_col": "should_be_filtered",
|
||||
"nowrite_col": "should_also_be_filtered",
|
||||
}
|
||||
|
||||
query.SetMap(values)
|
||||
|
||||
// Check that the updates map only contains writable columns
|
||||
if updates, ok := query.updates.(map[string]interface{}); ok {
|
||||
if _, exists := updates["readonly_col"]; exists {
|
||||
t.Error("readonly_col should have been filtered out")
|
||||
}
|
||||
if _, exists := updates["nowrite_col"]; exists {
|
||||
t.Error("nowrite_col should have been filtered out")
|
||||
}
|
||||
if _, exists := updates["name"]; !exists {
|
||||
t.Error("name should be in updates")
|
||||
}
|
||||
if _, exists := updates["email"]; !exists {
|
||||
t.Error("email should be in updates")
|
||||
}
|
||||
} else {
|
||||
t.Error("updates should be a map[string]interface{}")
|
||||
}
|
||||
}
|
||||
@@ -26,11 +26,13 @@ type SelectQuery interface {
|
||||
Model(model interface{}) SelectQuery
|
||||
Table(table string) SelectQuery
|
||||
Column(columns ...string) SelectQuery
|
||||
ColumnExpr(query string, args ...interface{}) SelectQuery
|
||||
Where(query string, args ...interface{}) SelectQuery
|
||||
WhereOr(query string, args ...interface{}) SelectQuery
|
||||
Join(query string, args ...interface{}) SelectQuery
|
||||
LeftJoin(query string, args ...interface{}) SelectQuery
|
||||
Preload(relation string, conditions ...interface{}) SelectQuery
|
||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
Order(order string) SelectQuery
|
||||
Limit(n int) SelectQuery
|
||||
Offset(n int) SelectQuery
|
||||
@@ -39,6 +41,7 @@ type SelectQuery interface {
|
||||
|
||||
// Execution methods
|
||||
Scan(ctx context.Context, dest interface{}) error
|
||||
ScanModel(ctx context.Context) error
|
||||
Count(ctx context.Context) (int, error)
|
||||
Exists(ctx context.Context) (bool, error)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// CRUDRequestProvider interface for models that provide CRUD request strings
|
||||
@@ -110,6 +111,9 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
// Inject parent IDs for foreign key resolution
|
||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||
|
||||
// Get the primary key name for this model
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Process based on operation
|
||||
switch strings.ToLower(operation) {
|
||||
case "insert", "create":
|
||||
@@ -127,30 +131,30 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
|
||||
case "update":
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data["id"])
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
result.ID = data["id"]
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations for update
|
||||
if err := p.processChildRelations(ctx, "update", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
|
||||
case "delete":
|
||||
// Process child relations first (for referential integrity)
|
||||
if err := p.processChildRelations(ctx, "delete", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||
}
|
||||
|
||||
rows, err := p.processDelete(ctx, tableName, data["id"])
|
||||
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
result.ID = data["id"]
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
@@ -248,7 +252,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
||||
|
||||
logger.Debug("Updating %s with ID %v, data: %+v", tableName, id, data)
|
||||
|
||||
query := p.db.NewUpdate().Table(tableName).SetMap(data).Where("id = ?", id)
|
||||
query := p.db.NewUpdate().Table(tableName).SetMap(data).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
@@ -268,7 +272,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string
|
||||
|
||||
logger.Debug("Deleting from %s with ID %v", tableName, id)
|
||||
|
||||
query := p.db.NewDelete().Table(tableName).Where("id = ?", id)
|
||||
query := p.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
@@ -377,8 +381,16 @@ func (p *NestedCUDProcessor) getTableNameForModel(model interface{}, defaultName
|
||||
}
|
||||
|
||||
// ShouldUseNestedProcessor determines if we should use nested CUD processing
|
||||
// It checks if the data contains nested relations or a _request field
|
||||
// It recursively checks if the data contains:
|
||||
// 1. A _request field at any level, OR
|
||||
// 2. Nested relations that themselves contain further nested relations or _request fields
|
||||
// This ensures nested processing is only used when there are deeply nested operations
|
||||
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
|
||||
return shouldUseNestedProcessorDepth(data, model, relationshipHelper, 0)
|
||||
}
|
||||
|
||||
// shouldUseNestedProcessorDepth is the internal recursive implementation with depth tracking
|
||||
func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider, depth int) bool {
|
||||
// Check for _request field
|
||||
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
|
||||
return true
|
||||
@@ -405,10 +417,34 @@ func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, re
|
||||
relInfo := relationshipHelper.GetRelationshipInfo(modelType, key)
|
||||
if relInfo != nil {
|
||||
// Check if the value is actually nested data (object or array)
|
||||
switch value.(type) {
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}, []interface{}, []map[string]interface{}:
|
||||
logger.Debug("Found nested relation field: %s", key)
|
||||
return true
|
||||
// If we're already at a nested level (depth > 0) and found a relation,
|
||||
// that means we have multi-level nesting, so return true
|
||||
if depth > 0 {
|
||||
return true
|
||||
}
|
||||
// At depth 0, recurse to check if the nested data has further nesting
|
||||
switch typedValue := v.(type) {
|
||||
case map[string]interface{}:
|
||||
if shouldUseNestedProcessorDepth(typedValue, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||
return true
|
||||
}
|
||||
case []interface{}:
|
||||
for _, item := range typedValue {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
case []map[string]interface{}:
|
||||
for _, itemMap := range typedValue {
|
||||
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,9 @@ type PreloadOption struct {
|
||||
Relation string `json:"relation"`
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
Sort []SortOption `json:"sort"`
|
||||
Filters []FilterOption `json:"filters"`
|
||||
Where string `json:"where"`
|
||||
Limit *int `json:"limit"`
|
||||
Offset *int `json:"offset"`
|
||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||
|
||||
@@ -183,7 +183,8 @@ func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
|
||||
}
|
||||
|
||||
// Validate Preload columns (if specified)
|
||||
for _, preload := range options.Preload {
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
// Note: We don't validate the relation name itself, as it's a relationship
|
||||
// Only validate columns if specified for the preload
|
||||
if err := v.ValidateColumns(preload.Columns); err != nil {
|
||||
@@ -239,7 +240,8 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
|
||||
// Filter Preload columns
|
||||
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||
for _, preload := range options.Preload {
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
filteredPreload := preload
|
||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||
@@ -270,3 +272,11 @@ func (v *ColumnValidator) GetValidColumns() []string {
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func QuoteIdent(qualifier string) string {
|
||||
return `"` + strings.ReplaceAll(qualifier, `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
func QuoteLiteral(value string) string {
|
||||
return `'` + strings.ReplaceAll(value, `'`, `''`) + `'`
|
||||
}
|
||||
|
||||
@@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{
|
||||
models: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Global list of registries (searched in order)
|
||||
var registries = []*DefaultModelRegistry{defaultRegistry}
|
||||
var registriesMutex sync.RWMutex
|
||||
|
||||
// NewModelRegistry creates a new model registry
|
||||
func NewModelRegistry() *DefaultModelRegistry {
|
||||
return &DefaultModelRegistry{
|
||||
@@ -24,6 +28,34 @@ func NewModelRegistry() *DefaultModelRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||
registriesMutex.Lock()
|
||||
foundAt := -1
|
||||
for idx, r := range registries {
|
||||
if r == defaultRegistry {
|
||||
foundAt = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
defaultRegistry = registry
|
||||
if foundAt >= 0 {
|
||||
registries[foundAt] = registry
|
||||
} else {
|
||||
registries = append([]*DefaultModelRegistry{registry}, registries...)
|
||||
}
|
||||
|
||||
defer registriesMutex.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// AddRegistry adds a registry to the global list of registries
|
||||
// Registries are searched in the order they were added
|
||||
func AddRegistry(registry *DefaultModelRegistry) {
|
||||
registriesMutex.Lock()
|
||||
defer registriesMutex.Unlock()
|
||||
registries = append(registries, registry)
|
||||
}
|
||||
|
||||
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
@@ -107,9 +139,19 @@ func RegisterModel(model interface{}, name string) error {
|
||||
return defaultRegistry.RegisterModel(name, model)
|
||||
}
|
||||
|
||||
// GetModelByName retrieves a model from the default global registry by name
|
||||
// GetModelByName retrieves a model by searching through all registries in order
|
||||
// Returns the first match found
|
||||
func GetModelByName(name string) (interface{}, error) {
|
||||
return defaultRegistry.GetModel(name)
|
||||
registriesMutex.RLock()
|
||||
defer registriesMutex.RUnlock()
|
||||
|
||||
for _, registry := range registries {
|
||||
if model, err := registry.GetModel(name); err == nil {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("model %s not found in any registry", name)
|
||||
}
|
||||
|
||||
// IterateModels iterates over all models in the default global registry
|
||||
@@ -122,14 +164,26 @@ func IterateModels(fn func(name string, model interface{})) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetModels returns a list of all models in the default global registry
|
||||
// GetModels returns a list of all models from all registries
|
||||
// Models are collected in registry order, with duplicates included
|
||||
func GetModels() []interface{} {
|
||||
defaultRegistry.mutex.RLock()
|
||||
defer defaultRegistry.mutex.RUnlock()
|
||||
registriesMutex.RLock()
|
||||
defer registriesMutex.RUnlock()
|
||||
|
||||
models := make([]interface{}, 0, len(defaultRegistry.models))
|
||||
for _, model := range defaultRegistry.models {
|
||||
models = append(models, model)
|
||||
var models []interface{}
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, registry := range registries {
|
||||
registry.mutex.RLock()
|
||||
for name, model := range registry.models {
|
||||
// Only add the first occurrence of each model name
|
||||
if !seen[name] {
|
||||
models = append(models, model)
|
||||
seen[name] = true
|
||||
}
|
||||
}
|
||||
registry.mutex.RUnlock()
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ type ModelFieldDetail struct {
|
||||
}
|
||||
|
||||
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||
// This function recursively processes embedded structs to include their fields
|
||||
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
if record.Kind() != reflect.Struct {
|
||||
return lst
|
||||
}
|
||||
|
||||
collectFieldDetails(record, &lst)
|
||||
|
||||
return lst
|
||||
}
|
||||
|
||||
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
|
||||
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
|
||||
modeltype := record.Type()
|
||||
|
||||
for i := 0; i < modeltype.NumField(); i++ {
|
||||
fieldtype := modeltype.Field(i)
|
||||
fieldValue := record.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if fieldtype.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
embeddedValue := fieldValue
|
||||
if fieldValue.Kind() == reflect.Pointer {
|
||||
if fieldValue.IsNil() {
|
||||
// Skip nil embedded pointers
|
||||
continue
|
||||
}
|
||||
embeddedValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if embeddedValue.Kind() == reflect.Struct {
|
||||
collectFieldDetails(embeddedValue, lst)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
gormdetail := fieldtype.Tag.Get("gorm")
|
||||
gormdetail = strings.Trim(gormdetail, " ")
|
||||
fielddetail := ModelFieldDetail{}
|
||||
fielddetail.FieldValue = record.Field(i)
|
||||
fielddetail.FieldValue = fieldValue
|
||||
fielddetail.Name = fieldtype.Name
|
||||
fielddetail.DataType = fieldtype.Type.Name()
|
||||
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||
@@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
}
|
||||
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||
|
||||
lst = append(lst, fielddetail)
|
||||
|
||||
*lst = append(*lst, fielddetail)
|
||||
}
|
||||
return lst
|
||||
}
|
||||
|
||||
func fnFindKeyVal(src, key string) string {
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
package common
|
||||
package reflection
|
||||
|
||||
import "reflect"
|
||||
|
||||
func Len(v any) int {
|
||||
val := reflect.ValueOf(v)
|
||||
valKind := val.Kind()
|
||||
|
||||
if valKind == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
switch val.Kind() {
|
||||
case reflect.Slice, reflect.Array, reflect.Map, reflect.String, reflect.Chan:
|
||||
return val.Len()
|
||||
@@ -4,15 +4,31 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
type PrimaryKeyNameProvider interface {
|
||||
GetIDName() string
|
||||
}
|
||||
|
||||
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||
func GetPrimaryKeyName(model any) string {
|
||||
if reflect.TypeOf(model) == nil {
|
||||
return ""
|
||||
}
|
||||
// If we are given a string model name, look up the model
|
||||
if reflect.TypeOf(model).Kind() == reflect.String {
|
||||
name := model.(string)
|
||||
m, err := modelregistry.GetModelByName(name)
|
||||
if err == nil {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model implements PrimaryKeyNameProvider
|
||||
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
||||
if provider, ok := model.(PrimaryKeyNameProvider); ok {
|
||||
return provider.GetIDName()
|
||||
}
|
||||
|
||||
@@ -22,11 +38,111 @@ func GetPrimaryKeyName(model any) string {
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
return getPrimaryKeyFromReflection(model, "gorm")
|
||||
if pkName := getPrimaryKeyFromReflection(model, "gorm"); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||
// Returns the value of the primary key field
|
||||
func GetPrimaryKeyValue(model any) any {
|
||||
if model == nil || reflect.TypeOf(model) == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try Bun tag first
|
||||
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
// Last resort: look for field named "ID" or "Id"
|
||||
if pkValue := findFieldByName(val, "id"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findPrimaryKeyValue recursively searches for a primary key field in the struct
|
||||
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
|
||||
typ := val.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||
// Recursively search in embedded struct
|
||||
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for primary key tag
|
||||
switch ormType {
|
||||
case "bun":
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
case "gorm":
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findFieldByName recursively searches for a field by name in the struct
|
||||
func findFieldByName(val reflect.Value, name string) any {
|
||||
typ := val.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||
// Recursively search in embedded struct
|
||||
if result := findFieldByName(fieldValue, name); result != nil {
|
||||
return result
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if field name matches
|
||||
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelColumns extracts all column names from a model using reflection
|
||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||
// This function recursively processes embedded structs to include their fields
|
||||
func GetModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
@@ -42,18 +158,38 @@ func GetModelColumns(model any) []string {
|
||||
return columns
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
collectColumnsFromType(modelType, &columns)
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
|
||||
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
collectColumnsFromType(fieldType, columns)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get column name using the same logic as primary key extraction
|
||||
columnName := getColumnNameFromField(field)
|
||||
|
||||
if columnName != "" {
|
||||
columns = append(columns, columnName)
|
||||
*columns = append(*columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// getColumnNameFromField extracts the column name from a struct field
|
||||
@@ -90,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string {
|
||||
}
|
||||
|
||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||
// This function recursively searches embedded structs
|
||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
@@ -101,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
return findPrimaryKeyNameFromType(typ, ormType)
|
||||
}
|
||||
|
||||
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
|
||||
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively search in embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch ormType {
|
||||
case "gorm":
|
||||
// Check for gorm tag with primaryKey
|
||||
@@ -155,8 +314,129 @@ func ExtractColumnFromGormTag(tag string) string {
|
||||
// Example: ",pk" -> "" (will fall back to json tag)
|
||||
func ExtractColumnFromBunTag(tag string) string {
|
||||
parts := strings.Split(tag, ",")
|
||||
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
|
||||
return ""
|
||||
}
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsColumnWritable checks if a column can be written to in the database
|
||||
// For bun: returns false if the field has "scanonly" tag
|
||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||
// This function recursively searches embedded structs
|
||||
func IsColumnWritable(model any, columnName string) bool {
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers to get to the base struct type
|
||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return false
|
||||
}
|
||||
|
||||
found, writable := isColumnWritableInType(modelType, columnName)
|
||||
if found {
|
||||
return writable
|
||||
}
|
||||
|
||||
// Column not found in model, allow it (might be a dynamic column)
|
||||
return true
|
||||
}
|
||||
|
||||
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||
// Returns (found, writable) where found indicates if the column was found
|
||||
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively search in embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
if found, writable := isColumnWritableInType(fieldType, columnName); found {
|
||||
return true, writable
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this field matches the column name
|
||||
fieldColumnName := getColumnNameFromField(field)
|
||||
if fieldColumnName != columnName {
|
||||
continue
|
||||
}
|
||||
|
||||
// Found the field, now check if it's writable
|
||||
// Check bun tag for scanonly
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" {
|
||||
if isBunFieldScanOnly(bunTag) {
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
// Check gorm tag for write restrictions
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
if isGormFieldReadOnly(gormTag) {
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
// Column is writable
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Column not found
|
||||
return false, false
|
||||
}
|
||||
|
||||
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
||||
// Example: "column_name,scanonly" -> true
|
||||
func isBunFieldScanOnly(tag string) bool {
|
||||
parts := strings.Split(tag, ",")
|
||||
for _, part := range parts {
|
||||
if strings.TrimSpace(part) == "scanonly" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only
|
||||
// Examples:
|
||||
// - "<-:false" -> true (no writes allowed)
|
||||
// - "->" -> true (read-only, common pattern)
|
||||
// - "column:name;->" -> true
|
||||
// - "<-:create" -> false (writes allowed on create)
|
||||
func isGormFieldReadOnly(tag string) bool {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
|
||||
// Check for read-only marker
|
||||
if part == "->" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for write restrictions
|
||||
if value, found := strings.CutPrefix(part, "<-:"); found {
|
||||
if value == "false" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test models with embedded structs
|
||||
|
||||
type BaseModel struct {
|
||||
ID int `bun:"rid_base,pk" json:"id"`
|
||||
CreatedAt string `bun:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type AdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
|
||||
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
type ModelWithEmbedded struct {
|
||||
BaseModel
|
||||
Name string `bun:"name" json:"name"`
|
||||
Description string `bun:"description" json:"description"`
|
||||
AdhocBuffer
|
||||
}
|
||||
|
||||
type GormBaseModel struct {
|
||||
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
|
||||
CreatedAt string `gorm:"column:created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type GormAdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
|
||||
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
type GormModelWithEmbedded struct {
|
||||
GormBaseModel
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
Description string `gorm:"column:description" json:"description"`
|
||||
GormAdhocBuffer
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded base",
|
||||
model: ModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded base (pointer)",
|
||||
model: &ModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base",
|
||||
model: GormModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base (pointer)",
|
||||
model: &GormModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyName(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
|
||||
bunModel := ModelWithEmbedded{
|
||||
BaseModel: BaseModel{
|
||||
ID: 123,
|
||||
CreatedAt: "2024-01-01",
|
||||
},
|
||||
Name: "Test",
|
||||
Description: "Test Description",
|
||||
}
|
||||
|
||||
gormModel := GormModelWithEmbedded{
|
||||
GormBaseModel: GormBaseModel{
|
||||
ID: 456,
|
||||
CreatedAt: "2024-01-02",
|
||||
},
|
||||
Name: "GORM Test",
|
||||
Description: "GORM Test Description",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded base",
|
||||
model: bunModel,
|
||||
expected: 123,
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded base (pointer)",
|
||||
model: &bunModel,
|
||||
expected: 123,
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base",
|
||||
model: gormModel,
|
||||
expected: 456,
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base (pointer)",
|
||||
model: &gormModel,
|
||||
expected: 456,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyValue(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelColumnsWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded structs",
|
||||
model: ModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded structs (pointer)",
|
||||
model: &ModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded structs",
|
||||
model: GormModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded structs (pointer)",
|
||||
model: &GormModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
columnName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Bun model - writable column in main struct",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Bun model - writable column in embedded base",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "rid_base",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Bun model - scanonly column in embedded adhoc buffer",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "cql1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Bun model - scanonly column _rownumber",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "_rownumber",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GORM model - writable column in main struct",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GORM model - writable column in embedded base",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "rid_base",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GORM model - readonly column in embedded adhoc buffer",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "cql1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GORM model - readonly column _rownumber",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "_rownumber",
|
||||
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsColumnWritable(tt.model, tt.columnName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Handler handles API requests using database and model abstractions
|
||||
@@ -196,6 +197,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
query = query.Column(options.Columns...)
|
||||
}
|
||||
|
||||
if len(options.ComputedColumns) > 0 {
|
||||
for _, cu := range options.ComputedColumns {
|
||||
logger.Debug("Applying computed column: %s", cu.Name)
|
||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||
}
|
||||
}
|
||||
|
||||
// Apply preloading
|
||||
if len(options.Preload) > 0 {
|
||||
query = h.applyPreloads(model, query, options.Preload)
|
||||
@@ -242,7 +250,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
logger.Debug("Querying single record with ID: %s", id)
|
||||
// For single record, create a new pointer to the struct type
|
||||
singleResult := reflect.New(modelType).Interface()
|
||||
query = query.Where("id = ?", id)
|
||||
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id)
|
||||
if err := query.Scan(ctx, singleResult); err != nil {
|
||||
logger.Error("Error querying record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||
@@ -514,15 +523,15 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
// Apply conditions
|
||||
if urlID != "" {
|
||||
logger.Debug("Updating by URL ID: %s", urlID)
|
||||
query = query.Where("id = ?", urlID)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
logger.Debug("Updating by request ID: %s", id)
|
||||
query = query.Where("id = ?", id)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||
case []string:
|
||||
logger.Debug("Updating by multiple IDs: %v", id)
|
||||
query = query.Where("id IN (?)", id)
|
||||
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -586,7 +595,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range updates {
|
||||
if itemID, ok := item["id"]; ok {
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where("id = ?", itemID)
|
||||
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -652,7 +662,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
for _, item := range updates {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if itemID, ok := itemMap["id"]; ok {
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where("id = ?", itemID)
|
||||
|
||||
txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -688,6 +699,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
schema := GetSchema(ctx)
|
||||
entity := GetEntity(ctx)
|
||||
tableName := GetTableName(ctx)
|
||||
model := GetModel(ctx)
|
||||
|
||||
logger.Info("Deleting records from %s.%s", schema, entity)
|
||||
|
||||
@@ -699,7 +711,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
logger.Info("Batch delete with %d IDs ([]string)", len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, itemID := range v {
|
||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete record %s: %w", itemID, err)
|
||||
}
|
||||
@@ -738,7 +751,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
continue // Skip items without ID
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
||||
@@ -763,7 +776,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
if itemID, ok := item["id"]; ok && itemID != nil {
|
||||
query := tx.NewDelete().Table(tableName).Where("id = ?", itemID)
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete record %v: %w", itemID, err)
|
||||
@@ -797,7 +810,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
query := h.db.NewDelete().Table(tableName).Where("id = ?", id)
|
||||
query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id)
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
@@ -1106,7 +1119,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
return query
|
||||
}
|
||||
|
||||
for _, preload := range preloads {
|
||||
for idx := range preloads {
|
||||
preload := preloads[idx]
|
||||
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
||||
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
||||
if relInfo == nil {
|
||||
@@ -1121,7 +1135,75 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
// For now, we'll preload without conditions
|
||||
// TODO: Implement column selection and filtering for preloads
|
||||
// This requires a more sophisticated approach with callbacks or query builders
|
||||
query = query.Preload(relationFieldName)
|
||||
// Apply preloading
|
||||
|
||||
logger.Debug("Applying preload: %s", relationFieldName)
|
||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||
if len(preload.OmitColumns) > 0 {
|
||||
allCols := reflection.GetModelColumns(model)
|
||||
// Remove omitted columns
|
||||
preload.Columns = []string{}
|
||||
for _, col := range allCols {
|
||||
addCols := true
|
||||
for _, omitCol := range preload.OmitColumns {
|
||||
if col == omitCol {
|
||||
addCols = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if addCols {
|
||||
preload.Columns = append(preload.Columns, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(preload.Columns) > 0 {
|
||||
// Ensure foreign key is included in column selection for GORM to establish the relationship
|
||||
columns := make([]string, len(preload.Columns))
|
||||
copy(columns, preload.Columns)
|
||||
|
||||
// Add foreign key if not already present
|
||||
if relInfo.foreignKey != "" {
|
||||
// Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id)
|
||||
foreignKeyColumn := toSnakeCase(relInfo.foreignKey)
|
||||
|
||||
hasForeignKey := false
|
||||
for _, col := range columns {
|
||||
if col == foreignKeyColumn || col == relInfo.foreignKey {
|
||||
hasForeignKey = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasForeignKey {
|
||||
columns = append(columns, foreignKeyColumn)
|
||||
}
|
||||
}
|
||||
|
||||
sq = sq.Column(columns...)
|
||||
}
|
||||
|
||||
if len(preload.Filters) > 0 {
|
||||
for _, filter := range preload.Filters {
|
||||
sq = h.applyFilter(sq, filter)
|
||||
}
|
||||
}
|
||||
if len(preload.Sort) > 0 {
|
||||
for _, sort := range preload.Sort {
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
}
|
||||
}
|
||||
|
||||
if len(preload.Where) > 0 {
|
||||
sq = sq.Where(preload.Where)
|
||||
}
|
||||
|
||||
if preload.Limit != nil && *preload.Limit > 0 {
|
||||
sq = sq.Limit(*preload.Limit)
|
||||
}
|
||||
|
||||
return sq
|
||||
})
|
||||
|
||||
logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName)
|
||||
}
|
||||
|
||||
@@ -1179,3 +1261,28 @@ func (h *Handler) extractTagValue(tag, key string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// toSnakeCase converts a PascalCase or camelCase string to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
runes := []rune(s)
|
||||
|
||||
for i := 0; i < len(runes); i++ {
|
||||
r := runes[i]
|
||||
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
// Check if previous character is lowercase or if next character is lowercase
|
||||
prevIsLower := runes[i-1] >= 'a' && runes[i-1] <= 'z'
|
||||
nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z'
|
||||
|
||||
// Add underscore if this is the start of a new word
|
||||
// (previous was lowercase OR this is followed by lowercase)
|
||||
if prevIsLower || nextIsLower {
|
||||
result.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
423
pkg/restheadspec/handler_nested_test.go
Normal file
423
pkg/restheadspec/handler_nested_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test models for nested CRUD operations
|
||||
type TestUser struct {
|
||||
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||
Name string `json:"name"`
|
||||
Posts []TestPost `json:"posts" gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
type TestPost struct {
|
||||
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
Comments []TestComment `json:"comments" gorm:"foreignKey:PostID"`
|
||||
}
|
||||
|
||||
type TestComment struct {
|
||||
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||
PostID int64 `json:"post_id"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (TestUser) TableName() string { return "users" }
|
||||
func (TestPost) TableName() string { return "posts" }
|
||||
func (TestComment) TableName() string { return "comments" }
|
||||
|
||||
// Test extractNestedRelations function
|
||||
func TestExtractNestedRelations(t *testing.T) {
|
||||
// Create handler
|
||||
registry := &mockRegistry{
|
||||
models: map[string]interface{}{
|
||||
"users": TestUser{},
|
||||
"posts": TestPost{},
|
||||
"comments": TestComment{},
|
||||
},
|
||||
}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
model interface{}
|
||||
expectedCleanCount int
|
||||
expectedRelCount int
|
||||
}{
|
||||
{
|
||||
name: "User with posts",
|
||||
data: map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
"posts": []map[string]interface{}{
|
||||
{"title": "Post 1"},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expectedCleanCount: 1, // name
|
||||
expectedRelCount: 1, // posts
|
||||
},
|
||||
{
|
||||
name: "Post with comments",
|
||||
data: map[string]interface{}{
|
||||
"title": "Test Post",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Comment 1"},
|
||||
{"content": "Comment 2"},
|
||||
},
|
||||
},
|
||||
model: TestPost{},
|
||||
expectedCleanCount: 1, // title
|
||||
expectedRelCount: 1, // comments
|
||||
},
|
||||
{
|
||||
name: "User with nested posts and comments",
|
||||
data: map[string]interface{}{
|
||||
"name": "Jane Doe",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"title": "Post 1",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Comment 1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expectedCleanCount: 1, // name
|
||||
expectedRelCount: 1, // posts (which contains nested comments)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cleanedData, relations, err := handler.extractNestedRelations(tt.data, tt.model)
|
||||
if err != nil {
|
||||
t.Errorf("extractNestedRelations() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(cleanedData) != tt.expectedCleanCount {
|
||||
t.Errorf("Expected %d cleaned fields, got %d: %+v", tt.expectedCleanCount, len(cleanedData), cleanedData)
|
||||
}
|
||||
|
||||
if len(relations) != tt.expectedRelCount {
|
||||
t.Errorf("Expected %d relation fields, got %d: %+v", tt.expectedRelCount, len(relations), relations)
|
||||
}
|
||||
|
||||
t.Logf("Cleaned data: %+v", cleanedData)
|
||||
t.Logf("Relations: %+v", relations)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test shouldUseNestedProcessor function
|
||||
func TestShouldUseNestedProcessor(t *testing.T) {
|
||||
registry := &mockRegistry{
|
||||
models: map[string]interface{}{
|
||||
"users": TestUser{},
|
||||
"posts": TestPost{},
|
||||
},
|
||||
}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
model interface{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Data with simple nested posts (no further nesting)",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
"posts": []map[string]interface{}{
|
||||
{"title": "Post 1"},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: false, // Simple one-level nesting doesn't require nested processor
|
||||
},
|
||||
{
|
||||
name: "Data with deeply nested relations",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"title": "Post 1",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Comment 1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: true, // Multi-level nesting requires nested processor
|
||||
},
|
||||
{
|
||||
name: "Data without nested relations",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Data with _request field",
|
||||
data: map[string]interface{}{
|
||||
"_request": "insert",
|
||||
"name": "John",
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Nested data with _request field",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"_request": "insert",
|
||||
"title": "Post 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: true, // _request at nested level requires nested processor
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.shouldUseNestedProcessor(tt.data, tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("shouldUseNestedProcessor() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test normalizeToSlice function
|
||||
func TestNormalizeToSlice(t *testing.T) {
|
||||
registry := &mockRegistry{}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected int // expected slice length
|
||||
}{
|
||||
{
|
||||
name: "Single object",
|
||||
input: map[string]interface{}{"name": "John"},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Slice of objects",
|
||||
input: []map[string]interface{}{
|
||||
{"name": "John"},
|
||||
{"name": "Jane"},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Array of interfaces",
|
||||
input: []interface{}{
|
||||
map[string]interface{}{"name": "John"},
|
||||
map[string]interface{}{"name": "Jane"},
|
||||
map[string]interface{}{"name": "Bob"},
|
||||
},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Nil input",
|
||||
input: nil,
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.normalizeToSlice(tt.input)
|
||||
if len(result) != tt.expected {
|
||||
t.Errorf("normalizeToSlice() returned slice of length %d, expected %d", len(result), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetRelationshipInfo function
|
||||
func TestGetRelationshipInfo(t *testing.T) {
|
||||
registry := &mockRegistry{}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modelType reflect.Type
|
||||
relationName string
|
||||
expectNil bool
|
||||
}{
|
||||
{
|
||||
name: "User posts relation",
|
||||
modelType: reflect.TypeOf(TestUser{}),
|
||||
relationName: "posts",
|
||||
expectNil: false,
|
||||
},
|
||||
{
|
||||
name: "Post comments relation",
|
||||
modelType: reflect.TypeOf(TestPost{}),
|
||||
relationName: "comments",
|
||||
expectNil: false,
|
||||
},
|
||||
{
|
||||
name: "Non-existent relation",
|
||||
modelType: reflect.TypeOf(TestUser{}),
|
||||
relationName: "nonexistent",
|
||||
expectNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.GetRelationshipInfo(tt.modelType, tt.relationName)
|
||||
if tt.expectNil && result != nil {
|
||||
t.Errorf("Expected nil, got %+v", result)
|
||||
}
|
||||
if !tt.expectNil && result == nil {
|
||||
t.Errorf("Expected non-nil relationship info")
|
||||
}
|
||||
if result != nil {
|
||||
t.Logf("Relationship info: FieldName=%s, JSONName=%s, RelationType=%s, ForeignKey=%s",
|
||||
result.FieldName, result.JSONName, result.RelationType, result.ForeignKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock registry for testing
|
||||
type mockRegistry struct {
|
||||
models map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Register(name string, model interface{}) {
|
||||
m.RegisterModel(name, model)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) RegisterModel(name string, model interface{}) error {
|
||||
if m.models == nil {
|
||||
m.models = make(map[string]interface{})
|
||||
}
|
||||
m.models[name] = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetModelByEntity(schema, entity string) (interface{}, error) {
|
||||
if model, ok := m.models[entity]; ok {
|
||||
return model, nil
|
||||
}
|
||||
return nil, fmt.Errorf("model not found: %s", entity)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetModelByName(name string) (interface{}, error) {
|
||||
if model, ok := m.models[name]; ok {
|
||||
return model, nil
|
||||
}
|
||||
return nil, fmt.Errorf("model not found: %s", name)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetModel(name string) (interface{}, error) {
|
||||
return m.GetModelByName(name)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) HasModel(schema, entity string) bool {
|
||||
_, ok := m.models[entity]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *mockRegistry) ListModels() []string {
|
||||
models := make([]string, 0, len(m.models))
|
||||
for name := range m.models {
|
||||
models = append(models, name)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetAllModels() map[string]interface{} {
|
||||
return m.models
|
||||
}
|
||||
|
||||
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
|
||||
func TestMultiLevelRelationExtraction(t *testing.T) {
|
||||
registry := &mockRegistry{
|
||||
models: map[string]interface{}{
|
||||
"users": TestUser{},
|
||||
"posts": TestPost{},
|
||||
"comments": TestComment{},
|
||||
},
|
||||
}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
// Test data with 3 levels: User -> Posts -> Comments
|
||||
testData := map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"title": "First Post",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Great post!"},
|
||||
{"content": "Thanks for sharing!"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"title": "Second Post",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Interesting read"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Extract relations from user
|
||||
cleanedData, relations, err := handler.extractNestedRelations(testData, TestUser{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to extract relations: %v", err)
|
||||
}
|
||||
|
||||
// Verify user data is cleaned
|
||||
if len(cleanedData) != 1 || cleanedData["name"] != "John Doe" {
|
||||
t.Errorf("Expected cleaned data to contain only name, got: %+v", cleanedData)
|
||||
}
|
||||
|
||||
// Verify posts relation was extracted
|
||||
if len(relations) != 1 {
|
||||
t.Errorf("Expected 1 relation (posts), got %d", len(relations))
|
||||
}
|
||||
|
||||
posts, ok := relations["posts"]
|
||||
if !ok {
|
||||
t.Fatal("Expected posts relation to be extracted")
|
||||
}
|
||||
|
||||
// Verify posts is a slice with 2 items
|
||||
postsSlice, ok := posts.([]map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected posts to be []map[string]interface{}, got %T", posts)
|
||||
}
|
||||
|
||||
if len(postsSlice) != 2 {
|
||||
t.Errorf("Expected 2 posts, got %d", len(postsSlice))
|
||||
}
|
||||
|
||||
// Verify first post has comments
|
||||
if _, hasComments := postsSlice[0]["comments"]; !hasComments {
|
||||
t.Error("Expected first post to have comments")
|
||||
}
|
||||
|
||||
t.Logf("Successfully extracted multi-level nested relations")
|
||||
t.Logf("Cleaned data: %+v", cleanedData)
|
||||
t.Logf("Relations: %d posts with nested comments", len(postsSlice))
|
||||
}
|
||||
@@ -37,6 +37,9 @@ type ExtendedRequestOptions struct {
|
||||
// Response format
|
||||
ResponseFormat string // "simple", "detail", "syncfusion"
|
||||
|
||||
// Single record normalization - convert single-element arrays to objects
|
||||
SingleRecordAsObject bool
|
||||
|
||||
// Transaction
|
||||
AtomicTransaction bool
|
||||
}
|
||||
@@ -99,10 +102,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
||||
Sort: make([]common.SortOption, 0),
|
||||
Preload: make([]common.PreloadOption, 0),
|
||||
},
|
||||
AdvancedSQL: make(map[string]string),
|
||||
ComputedQL: make(map[string]string),
|
||||
Expand: make([]ExpandOption, 0),
|
||||
ResponseFormat: "simple", // Default response format
|
||||
AdvancedSQL: make(map[string]string),
|
||||
ComputedQL: make(map[string]string),
|
||||
Expand: make([]ExpandOption, 0),
|
||||
ResponseFormat: "simple", // Default response format
|
||||
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||
}
|
||||
|
||||
// Get all headers
|
||||
@@ -146,7 +150,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
||||
|
||||
// Joins & Relations
|
||||
case strings.HasPrefix(normalizedKey, "x-preload"):
|
||||
h.parsePreload(&options, decodedValue)
|
||||
if strings.HasSuffix(normalizedKey, "-where") {
|
||||
continue
|
||||
}
|
||||
whereClaude := headers[fmt.Sprintf("%s-where", key)]
|
||||
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
||||
|
||||
case strings.HasPrefix(normalizedKey, "x-expand"):
|
||||
h.parseExpand(&options, decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
|
||||
@@ -194,6 +203,13 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
||||
options.ResponseFormat = "detail"
|
||||
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
|
||||
options.ResponseFormat = "syncfusion"
|
||||
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
|
||||
// Parse as boolean - "false" disables, "true" enables (default is true)
|
||||
if strings.EqualFold(decodedValue, "false") {
|
||||
options.SingleRecordAsObject = false
|
||||
} else if strings.EqualFold(decodedValue, "true") {
|
||||
options.SingleRecordAsObject = true
|
||||
}
|
||||
|
||||
// Transaction Control
|
||||
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
|
||||
@@ -341,7 +357,15 @@ func (h *Handler) mapSearchOperator(colName, operator, value string) common.Filt
|
||||
|
||||
// parsePreload parses x-preload header
|
||||
// Format: RelationName:field1,field2 or RelationName or multiple separated by |
|
||||
func (h *Handler) parsePreload(options *ExtendedRequestOptions, value string) {
|
||||
func (h *Handler) parsePreload(options *ExtendedRequestOptions, values ...string) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
value := values[0]
|
||||
whereClause := ""
|
||||
if len(values) > 1 {
|
||||
whereClause = values[1]
|
||||
}
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
@@ -358,6 +382,7 @@ func (h *Handler) parsePreload(options *ExtendedRequestOptions, value string) {
|
||||
parts := strings.SplitN(preloadStr, ":", 2)
|
||||
preload := common.PreloadOption{
|
||||
Relation: strings.TrimSpace(parts[0]),
|
||||
Where: whereClause,
|
||||
}
|
||||
|
||||
if len(parts) == 2 {
|
||||
|
||||
@@ -106,7 +106,7 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}).Methods("GET", "PUT", "PATCH", "DELETE")
|
||||
}).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
|
||||
|
||||
// GET for metadata (using HandleGet)
|
||||
muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -189,6 +189,18 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
params := map[string]string{
|
||||
"schema": req.Param("schema"),
|
||||
"entity": req.Param("entity"),
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
params := map[string]string{
|
||||
"schema": req.Param("schema"),
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
type TestModel struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name" bun:"name"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords(t *testing.T) {
|
||||
|
||||
@@ -402,25 +402,41 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
resp := makeRestHeadSpecRequest(t, serverURL, fmt.Sprintf("/restheadspec/departments/%s", deptID), "GET", nil, nil)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// RestHeadSpec may return data directly as array or wrapped in response object
|
||||
// RestHeadSpec may return data directly as array/object or wrapped in response object
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "Failed to read response body")
|
||||
|
||||
// Try to decode as array first (simple format)
|
||||
// Try to decode as array first (simple format - multiple records or disabled SingleRecordAsObject)
|
||||
var dataArray []interface{}
|
||||
if err := json.Unmarshal(body, &dataArray); err == nil {
|
||||
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find department")
|
||||
logger.Info("Department read successfully (simple format): %s", deptID)
|
||||
logger.Info("Department read successfully (simple format - array): %s", deptID)
|
||||
return
|
||||
}
|
||||
|
||||
// Try to decode as standard response object (detail format)
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err == nil {
|
||||
if success, ok := result["success"]; ok && success != nil && success.(bool) {
|
||||
if data, ok := result["data"].([]interface{}); ok {
|
||||
// Try to decode as a single object first (simple format with SingleRecordAsObject enabled)
|
||||
var singleObj map[string]interface{}
|
||||
if err := json.Unmarshal(body, &singleObj); err == nil {
|
||||
// Check if it's a data object (not a response wrapper)
|
||||
if _, hasSuccess := singleObj["success"]; !hasSuccess {
|
||||
// This is a direct data object (simple format, single record)
|
||||
assert.NotEmpty(t, singleObj, "Should find department")
|
||||
logger.Info("Department read successfully (simple format - single object): %s", deptID)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise it's a standard response object (detail format)
|
||||
if success, ok := singleObj["success"]; ok && success != nil && success.(bool) {
|
||||
// Check if data is an array
|
||||
if data, ok := singleObj["data"].([]interface{}); ok {
|
||||
assert.GreaterOrEqual(t, len(data), 1, "Should find department")
|
||||
logger.Info("Department read successfully (detail format): %s", deptID)
|
||||
logger.Info("Department read successfully (detail format - array): %s", deptID)
|
||||
return
|
||||
}
|
||||
// Check if data is a single object (SingleRecordAsObject feature in detail format)
|
||||
if data, ok := singleObj["data"].(map[string]interface{}); ok {
|
||||
assert.NotEmpty(t, data, "Should find department")
|
||||
logger.Info("Department read successfully (detail format - single object): %s", deptID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -446,25 +462,41 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
resp := makeRestHeadSpecRequest(t, serverURL, "/restheadspec/employees", "GET", nil, headers)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// RestHeadSpec may return data directly as array or wrapped in response object
|
||||
// RestHeadSpec may return data directly as array/object or wrapped in response object
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "Failed to read response body")
|
||||
|
||||
// Try array format first
|
||||
// Try array format first (multiple records or disabled SingleRecordAsObject)
|
||||
var dataArray []interface{}
|
||||
if err := json.Unmarshal(body, &dataArray); err == nil {
|
||||
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find at least one employee")
|
||||
logger.Info("Employees read with filter successfully (simple format), found: %d", len(dataArray))
|
||||
logger.Info("Employees read with filter successfully (simple format - array), found: %d", len(dataArray))
|
||||
return
|
||||
}
|
||||
|
||||
// Try standard response format
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err == nil {
|
||||
if success, ok := result["success"]; ok && success != nil && success.(bool) {
|
||||
if data, ok := result["data"].([]interface{}); ok {
|
||||
// Try to decode as a single object (simple format with SingleRecordAsObject enabled)
|
||||
var singleObj map[string]interface{}
|
||||
if err := json.Unmarshal(body, &singleObj); err == nil {
|
||||
// Check if it's a data object (not a response wrapper)
|
||||
if _, hasSuccess := singleObj["success"]; !hasSuccess {
|
||||
// This is a direct data object (simple format, single record)
|
||||
assert.NotEmpty(t, singleObj, "Should find at least one employee")
|
||||
logger.Info("Employees read with filter successfully (simple format - single object), found: 1")
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise it's a standard response object (detail format)
|
||||
if success, ok := singleObj["success"]; ok && success != nil && success.(bool) {
|
||||
// Check if data is an array
|
||||
if data, ok := singleObj["data"].([]interface{}); ok {
|
||||
assert.GreaterOrEqual(t, len(data), 1, "Should find at least one employee")
|
||||
logger.Info("Employees read with filter successfully (detail format), found: %d", len(data))
|
||||
logger.Info("Employees read with filter successfully (detail format - array), found: %d", len(data))
|
||||
return
|
||||
}
|
||||
// Check if data is a single object (SingleRecordAsObject feature in detail format)
|
||||
if data, ok := singleObj["data"].(map[string]interface{}); ok {
|
||||
assert.NotEmpty(t, data, "Should find at least one employee")
|
||||
logger.Info("Employees read with filter successfully (detail format - single object), found: 1")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user