Compare commits

...

23 Commits

Author SHA1 Message Date
Hein
c95bc9e633 Added x-files feature 2025-11-20 12:47:36 +02:00
Hein
07b09e2025 handle JSON sql columns 2025-11-20 12:04:19 +02:00
Hein
3d5334002d Fixes on Table Name on insert 2025-11-20 11:49:07 +02:00
Hein
640582d508 Better types 2025-11-20 11:40:16 +02:00
Hein
b0b3ae662b Common Sql Types 2025-11-20 11:18:49 +02:00
Hein
c9b9f75b06 Fixed go mod version issues 2025-11-20 10:34:27 +02:00
Hein
af3260864d INSERT statements were failing with duplicate key errors because the SQL being generated 2025-11-20 10:31:25 +02:00
Hein
ca6d2deff6 Fixed insert statement bug 2025-11-20 10:11:26 +02:00
Hein
1481443516 Fixed double .Model and .Table 2025-11-20 10:02:36 +02:00
Hein
cb54ec5e27 Better responses for updates and inserts 2025-11-20 09:57:24 +02:00
Hein
7d6a9025f5 Fixed hardcoded id 2025-11-20 09:40:11 +02:00
Hein
35089f511f correctly handle structs with embedded fields 2025-11-20 09:28:37 +02:00
Hein
66b6a0d835 Better registry handling
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-19 18:29:24 +02:00
Hein
456c165814 Fixed models being icorrectly set and added SetDefaultRegistry 2025-11-19 18:22:56 +02:00
Hein
850d7b546c Added modelregistry.AddRegistry 2025-11-19 18:18:18 +02:00
Hein
a44ef90d7c Fixes on getRelationshipInfo, ShouldUseNestedProcessor 2025-11-19 18:03:25 +02:00
Hein
8b7db5b31a reflection-based column validation for UpdateQuery 2025-11-19 17:41:15 +02:00
Hein
14daea3b05 Fixes for CUD operations 2025-11-19 15:08:04 +02:00
Hein
35f23b6d9e Recursive crud fix 2025-11-19 14:32:20 +02:00
Hein
53a4e67f70 Specifically call update if a ID was given. 2025-11-19 14:24:39 +02:00
Hein
1289c3af88 Fixed handling post routes as well for the restheadspec
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-19 14:04:56 +02:00
Hein
cdfb7a67fd Added Single Record as Object feature 2025-11-19 13:58:52 +02:00
Hein
7f5b851669 Empty sort appended bug fix
Some checks failed
Tests / Build (push) Has been cancelled
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
2025-11-11 17:16:59 +02:00
24 changed files with 4702 additions and 342 deletions

View File

@@ -31,6 +31,8 @@ Both share the same core architecture and provide dynamic data querying, relatio
- [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
- [Lifecycle Hooks](#lifecycle-hooks)
- [Cursor Pagination](#cursor-pagination)
- [Response Formats](#response-formats)
- [Single Record as Object](#single-record-as-object-default-behavior)
- [Example Usage](#example-usage)
- [Recursive CRUD Operations](#recursive-crud-operations-)
- [Testing](#testing)
@@ -59,6 +61,7 @@ Both share the same core architecture and provide dynamic data querying, relatio
- **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
- **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
- **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
- **🆕 Single Record as Object**: Automatically normalize single-element arrays to objects (enabled by default)
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
- **🆕 Base64 Encoding**: Support for base64-encoded header values
@@ -163,6 +166,7 @@ restheadspec.SetupMuxRoutes(router, handler)
| `X-Limit` | Limit results | `50` |
| `X-Offset` | Offset for pagination | `100` |
| `X-Clean-JSON` | Remove null/empty fields | `true` |
| `X-Single-Record-As-Object` | Return single records as objects (default: `true`) | `false` |
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
@@ -303,6 +307,55 @@ RestHeadSpec supports multiple response formats:
}
```
### Single Record as Object (Default Behavior)
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses. This provides a better developer experience when fetching individual records.
**Default behavior (enabled)**:
```http
GET /public/users/123
```
```json
{
"success": true,
"data": { "id": 123, "name": "John", "email": "john@example.com" }
}
```
Instead of:
```json
{
"success": true,
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
}
```
**To disable** (force arrays for consistency):
```http
GET /public/users/123
X-Single-Record-As-Object: false
```
```json
{
"success": true,
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
}
```
**How it works**:
- When a query returns exactly **one record**, it's returned as an object
- When a query returns **multiple records**, they're returned as an array
- Set `X-Single-Record-As-Object: false` to always receive arrays
- Works with all response formats (simple, detail, syncfusion)
- Applies to both read operations and create/update returning clauses
**Benefits**:
- Cleaner API responses for single-record queries
- No need to unwrap single-element arrays on the client side
- Better TypeScript/type inference support
- Consistent with common REST API patterns
- Backward compatible via header opt-out
## Example Usage
### Reading Data with Related Entities
@@ -924,6 +977,7 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
- **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
- **Advanced Filtering**: Field filters, search operators, AND/OR logic
- **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
- **Single Record as Object**: Automatically return single-element arrays as objects (default, toggleable via header)
- **Base64 Support**: Base64-encoded header values for complex queries
- **Type-Aware Filtering**: Automatic type detection and conversion for filters

19
go.mod
View File

@@ -8,7 +8,12 @@ require (
github.com/glebarez/sqlite v1.11.0
github.com/gorilla/mux v1.8.1
github.com/stretchr/testify v1.8.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/uptrace/bun v1.2.15
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
github.com/uptrace/bun/driver/sqliteshim v1.2.15
github.com/uptrace/bunrouter v1.0.23
go.uber.org/zap v1.27.0
gorm.io/gorm v1.25.12
)
@@ -21,23 +26,23 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/uptrace/bunrouter v1.0.23 // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
golang.org/x/sys v0.34.0 // indirect
golang.org/x/text v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
modernc.org/libc v1.66.3 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.38.0 // indirect
)

55
go.sum
View File

@@ -7,8 +7,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -21,13 +21,16 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -50,6 +53,10 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
@@ -62,11 +69,19 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -75,11 +90,29 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@@ -9,6 +9,8 @@ import (
"github.com/uptrace/bun"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// BunAdapter adapts Bun to work with our Database interface
@@ -305,16 +307,22 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
// BunInsertQuery implements InsertQuery for Bun
type BunInsertQuery struct {
query *bun.InsertQuery
values map[string]interface{}
query *bun.InsertQuery
values map[string]interface{}
hasModel bool
}
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
b.query = b.query.Model(model)
b.hasModel = true
return b
}
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
if b.hasModel {
// If model is set, do not override table name
return b
}
b.query = b.query.Table(table)
return b
}
@@ -340,10 +348,16 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
}
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
if b.values != nil {
// For Bun, we need to handle this differently
for k, v := range b.values {
b.query = b.query.Set("? = ?", bun.Ident(k), v)
if b.values != nil && len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
// Bun can insert map[string]interface{} directly
b.query = b.query.Model(&b.values)
} else {
// If model was set, use Value() to add individual values
for k, v := range b.values {
b.query = b.query.Value(k, "?", v)
}
}
}
result, err := b.query.Exec(ctx)
@@ -353,25 +367,50 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
// BunUpdateQuery implements UpdateQuery for Bun
type BunUpdateQuery struct {
query *bun.UpdateQuery
model interface{}
}
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
b.query = b.query.Model(model)
b.model = model
return b
}
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
b.query = b.query.Table(table)
if b.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
b.model = model
}
}
return b
}
func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
// Validate column is writable if model is set
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
// Skip scan-only columns
return b
}
b.query = b.query.Set(column+" = ?", value)
return b
}
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
pkName := reflection.GetPrimaryKeyName(b.model)
for column, value := range values {
// Validate column is writable if model is set
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
// Skip scan-only columns
continue
}
if pkName != "" && column == pkName {
// Skip primary key updates
continue
}
b.query = b.query.Set(column+" = ?", value)
}
return b

View File

@@ -0,0 +1,213 @@
package database
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/driver/sqliteshim"
)
// TestInsertModel is a test model for insert operations
type TestInsertModel struct {
bun.BaseModel `bun:"table:test_inserts"`
ID int64 `bun:"id,pk,autoincrement"`
Name string `bun:"name,notnull"`
Email string `bun:"email"`
Age int `bun:"age"`
}
func setupBunTestDB(t *testing.T) *bun.DB {
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
require.NoError(t, err, "Failed to open SQLite database")
db := bun.NewDB(sqldb, sqlitedialect.New())
// Create test table
_, err = db.NewCreateTable().
Model((*TestInsertModel)(nil)).
IfNotExists().
Exec(context.Background())
require.NoError(t, err, "Failed to create test table")
return db
}
func TestBunInsertQuery_Model(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting with Model()
model := &TestInsertModel{
Name: "John Doe",
Email: "john@example.com",
Age: 30,
}
result, err := adapter.NewInsert().
Model(model).
Returning("*").
Exec(ctx)
require.NoError(t, err, "Insert should succeed")
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
// Verify the data was inserted
var retrieved TestInsertModel
err = db.NewSelect().
Model(&retrieved).
Where("id = ?", model.ID).
Scan(ctx)
require.NoError(t, err, "Should retrieve inserted row")
assert.Equal(t, "John Doe", retrieved.Name)
assert.Equal(t, "john@example.com", retrieved.Email)
assert.Equal(t, 30, retrieved.Age)
}
func TestBunInsertQuery_Value(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting with Value() method - this was the bug
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Jane Smith").
Value("email", "jane@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err, "Insert with Value() should succeed")
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
// Verify the data was inserted
var retrieved TestInsertModel
err = db.NewSelect().
Model(&retrieved).
Where("name = ?", "Jane Smith").
Scan(ctx)
require.NoError(t, err, "Should retrieve inserted row")
assert.Equal(t, "Jane Smith", retrieved.Name)
assert.Equal(t, "jane@example.com", retrieved.Email)
assert.Equal(t, 25, retrieved.Age)
}
func TestBunInsertQuery_MultipleValues(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting multiple values
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Alice").
Value("email", "alice@example.com").
Value("age", 28).
Exec(ctx)
require.NoError(t, err, "First insert should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
result, err = adapter.NewInsert().
Table("test_inserts").
Value("name", "Bob").
Value("email", "bob@example.com").
Value("age", 35).
Exec(ctx)
require.NoError(t, err, "Second insert should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
// Verify both rows exist
var count int
count, err = db.NewSelect().
Model((*TestInsertModel)(nil)).
Count(ctx)
require.NoError(t, err, "Count should succeed")
assert.Equal(t, 2, count, "Should have 2 rows")
}
func TestBunInsertQuery_ValueWithNil(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting with nil value for nullable field
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Test User").
Value("email", nil). // NULL email
Value("age", 20).
Exec(ctx)
require.NoError(t, err, "Insert with nil value should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
// Verify the data was inserted with NULL email
var retrieved TestInsertModel
err = db.NewSelect().
Model(&retrieved).
Where("name = ?", "Test User").
Scan(ctx)
require.NoError(t, err, "Should retrieve inserted row")
assert.Equal(t, "Test User", retrieved.Name)
assert.Equal(t, "", retrieved.Email) // NULL becomes empty string
assert.Equal(t, 20, retrieved.Age)
}
func TestBunInsertQuery_Returning(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test insert with RETURNING clause
// Note: SQLite has limited RETURNING support, but this tests the API
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Return Test").
Value("email", "return@example.com").
Value("age", 40).
Returning("*").
Exec(ctx)
require.NoError(t, err, "Insert with RETURNING should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
}
func TestBunInsertQuery_EmptyValues(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test insert without calling Value() - should use Model() or fail gracefully
result, err := adapter.NewInsert().
Table("test_inserts").
Exec(ctx)
// This should fail because no values are provided
assert.Error(t, err, "Insert without values should fail")
if result != nil {
assert.Equal(t, int64(0), result.RowsAffected())
}
}

View File

@@ -8,6 +8,8 @@ import (
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// GormAdapter adapts GORM to work with our Database interface
@@ -97,6 +99,7 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table)
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(table)
return g
}
@@ -339,10 +342,23 @@ func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
g.db = g.db.Table(table)
if g.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
g.model = model
}
}
return g
}
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
// Validate column is writable if model is set
if g.model != nil && !reflection.IsColumnWritable(g.model, column) {
// Skip read-only columns
return g
}
if g.updates == nil {
g.updates = make(map[string]interface{})
}
@@ -353,7 +369,25 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
}
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
g.updates = values
// Filter out read-only columns if model is set
if g.model != nil {
pkName := reflection.GetPrimaryKeyName(g.model)
filteredValues := make(map[string]interface{})
for column, value := range values {
if pkName != "" && column == pkName {
// Skip primary key updates
continue
}
if reflection.IsColumnWritable(g.model, column) {
filteredValues[column] = value
}
}
g.updates = filteredValues
} else {
g.updates = values
}
return g
}

View File

@@ -0,0 +1,161 @@
package database
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Test models for bun
type BunTestModel struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Email string `bun:"email"`
ComputedCol string `bun:"computed_col,scanonly"`
}
// Test models for gorm
type GormTestModel struct {
ID int `gorm:"column:id;primaryKey"`
Name string `gorm:"column:name"`
Email string `gorm:"column:email"`
ReadOnlyCol string `gorm:"column:readonly_col;->"`
NoWriteCol string `gorm:"column:nowrite_col;<-:false"`
}
func TestIsColumnWritable_Bun(t *testing.T) {
model := &BunTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "scanonly column should not be writable",
columnName: "computed_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestIsColumnWritable_Gorm(t *testing.T) {
model := &GormTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "read-only column with -> should not be writable",
columnName: "readonly_col",
expected: false,
},
{
name: "column with <-:false should not be writable",
columnName: "nowrite_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) {
// Note: This is a unit test for the validation logic only.
// We can't fully test the bun query without a database connection,
// but we've verified the validation logic in TestIsColumnWritable_Bun
t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun")
}
func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) {
model := &GormTestModel{}
query := &GormUpdateQuery{
model: model,
}
// SetMap should filter out read-only columns
values := map[string]interface{}{
"name": "John",
"email": "john@example.com",
"readonly_col": "should_be_filtered",
"nowrite_col": "should_also_be_filtered",
}
query.SetMap(values)
// Check that the updates map only contains writable columns
if updates, ok := query.updates.(map[string]interface{}); ok {
if _, exists := updates["readonly_col"]; exists {
t.Error("readonly_col should have been filtered out")
}
if _, exists := updates["nowrite_col"]; exists {
t.Error("nowrite_col should have been filtered out")
}
if _, exists := updates["name"]; !exists {
t.Error("name should be in updates")
}
if _, exists := updates["email"]; !exists {
t.Error("email should be in updates")
}
} else {
t.Error("updates should be a map[string]interface{}")
}
}

View File

@@ -111,6 +111,9 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
// Inject parent IDs for foreign key resolution
p.injectForeignKeys(regularData, modelType, parentIDs)
// Get the primary key name for this model
pkName := reflection.GetPrimaryKeyName(model)
// Process based on operation
switch strings.ToLower(operation) {
case "insert", "create":
@@ -128,30 +131,30 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
}
case "update":
rows, err := p.processUpdate(ctx, regularData, tableName, data["id"])
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
if err != nil {
return nil, fmt.Errorf("update failed: %w", err)
}
result.ID = data["id"]
result.ID = data[pkName]
result.AffectedRows = rows
result.Data = regularData
// Process child relations for update
if err := p.processChildRelations(ctx, "update", data["id"], relationFields, result.RelationData, modelType); err != nil {
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations: %w", err)
}
case "delete":
// Process child relations first (for referential integrity)
if err := p.processChildRelations(ctx, "delete", data["id"], relationFields, result.RelationData, modelType); err != nil {
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
}
rows, err := p.processDelete(ctx, tableName, data["id"])
rows, err := p.processDelete(ctx, tableName, data[pkName])
if err != nil {
return nil, fmt.Errorf("delete failed: %w", err)
}
result.ID = data["id"]
result.ID = data[pkName]
result.AffectedRows = rows
result.Data = regularData
@@ -378,8 +381,16 @@ func (p *NestedCUDProcessor) getTableNameForModel(model interface{}, defaultName
}
// ShouldUseNestedProcessor determines if we should use nested CUD processing
// It checks if the data contains nested relations or a _request field
// It recursively checks if the data contains:
// 1. A _request field at any level, OR
// 2. Nested relations that themselves contain further nested relations or _request fields
// This ensures nested processing is only used when there are deeply nested operations
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
return shouldUseNestedProcessorDepth(data, model, relationshipHelper, 0)
}
// shouldUseNestedProcessorDepth is the internal recursive implementation with depth tracking
func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider, depth int) bool {
// Check for _request field
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
return true
@@ -406,10 +417,34 @@ func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, re
relInfo := relationshipHelper.GetRelationshipInfo(modelType, key)
if relInfo != nil {
// Check if the value is actually nested data (object or array)
switch value.(type) {
switch v := value.(type) {
case map[string]interface{}, []interface{}, []map[string]interface{}:
logger.Debug("Found nested relation field: %s", key)
return true
// If we're already at a nested level (depth > 0) and found a relation,
// that means we have multi-level nesting, so return true
if depth > 0 {
return true
}
// At depth 0, recurse to check if the nested data has further nesting
switch typedValue := v.(type) {
case map[string]interface{}:
if shouldUseNestedProcessorDepth(typedValue, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
case []interface{}:
for _, item := range typedValue {
if itemMap, ok := item.(map[string]interface{}); ok {
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
}
}
case []map[string]interface{}:
for _, itemMap := range typedValue {
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
}
}
}
}
}

774
pkg/common/sql_types.go Normal file
View File

@@ -0,0 +1,774 @@
package common
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
func tryParseDT(str string) (time.Time, error) {
var lasterror error
tryFormats := []string{time.RFC3339,
"2006-01-02T15:04:05.000-0700",
"2006-01-02T15:04:05.000",
"06-01-02T15:04:05.000",
"2006-01-02T15:04:05",
"2006-01-02 15:04:05",
"02/01/2006",
"02-01-2006",
"2006-01-02",
"15:04:05.000",
"15:04:05",
"15:04"}
for _, f := range tryFormats {
tx, err := time.Parse(f, str)
if err == nil {
return tx, nil
} else {
lasterror = err
}
}
return time.Now(), lasterror
}
func ToJSONDT(dt time.Time) string {
return dt.Format(time.RFC3339)
}
// SqlInt16 - A Int16 that supports SQL string
type SqlInt16 int16
// Scan -
func (n *SqlInt16) Scan(value interface{}) error {
if value == nil {
*n = 0
return nil
}
switch v := value.(type) {
case int:
*n = SqlInt16(v)
case int32:
*n = SqlInt16(v)
case int64:
*n = SqlInt16(v)
default:
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
*n = SqlInt16(i)
}
return nil
}
// Value -
func (n SqlInt16) Value() (driver.Value, error) {
if n == 0 {
return nil, nil
}
return int64(n), nil
}
// String - Override String format of ZNullInt32
func (n SqlInt16) String() string {
tmstr := fmt.Sprintf("%d", n)
return tmstr
}
// UnmarshalJSON - Overre JidSON format of ZNullInt32
func (n *SqlInt16) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
n64, err := strconv.ParseInt(s, 10, 64)
if err == nil {
*n = SqlInt16(n64)
}
return nil
}
// MarshalJSON - Override JSON format of time
func (n SqlInt16) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%d", n)), nil
}
// SqlInt32 - A int32 that supports SQL string
type SqlInt32 int32
// Scan -
func (n *SqlInt32) Scan(value interface{}) error {
if value == nil {
*n = 0
return nil
}
switch v := value.(type) {
case int:
*n = SqlInt32(v)
case int32:
*n = SqlInt32(v)
case int64:
*n = SqlInt32(v)
default:
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
*n = SqlInt32(i)
}
return nil
}
// Value -
func (n SqlInt32) Value() (driver.Value, error) {
if n == 0 {
return nil, nil
}
return int64(n), nil
}
// String - Override String format of ZNullInt32
func (n SqlInt32) String() string {
tmstr := fmt.Sprintf("%d", n)
return tmstr
}
// UnmarshalJSON - Overre JidSON format of ZNullInt32
func (n *SqlInt32) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
n64, err := strconv.ParseInt(s, 10, 64)
if err == nil {
*n = SqlInt32(n64)
}
return nil
}
// MarshalJSON - Override JSON format of time
func (n SqlInt32) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%d", n)), nil
}
// SqlInt64 - A int64 that supports SQL string
type SqlInt64 int64
// Scan -
func (n *SqlInt64) Scan(value interface{}) error {
if value == nil {
*n = 0
return nil
}
switch v := value.(type) {
case int:
*n = SqlInt64(v)
case int32:
*n = SqlInt64(v)
case uint32:
*n = SqlInt64(v)
case int64:
*n = SqlInt64(v)
case uint64:
*n = SqlInt64(v)
default:
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
*n = SqlInt64(i)
}
return nil
}
// Value -
func (n SqlInt64) Value() (driver.Value, error) {
if n == 0 {
return nil, nil
}
return int64(n), nil
}
// String - Override String format of ZNullInt32
func (n SqlInt64) String() string {
tmstr := fmt.Sprintf("%d", n)
return tmstr
}
// UnmarshalJSON - Overre JidSON format of ZNullInt32
func (n *SqlInt64) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
n64, err := strconv.ParseInt(s, 10, 64)
if err == nil {
*n = SqlInt64(n64)
}
return nil
}
// MarshalJSON - Override JSON format of time
func (n SqlInt64) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%d", n)), nil
}
// SqlTimeStamp - Implementation of SqlTimeStamp with some interfaces.
type SqlTimeStamp time.Time
// MarshalJSON - Override JSON format of time
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
if time.Time(t).IsZero() {
return []byte("null"), nil
}
if time.Time(t).Before(time.Date(0001, 1, 1, 0, 0, 0, 0, time.UTC)) {
return []byte("null"), nil
}
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
if tmstr == "0001-01-01T00:00:00" {
return []byte("null"), nil
}
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
}
// UnmarshalJSON - Override JSON format of time
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
var err error
if b == nil {
t = &SqlTimeStamp{}
return nil
}
s := strings.Trim(strings.Trim(string(b), " "), "\"")
if s == "null" || s == "" || s == "0" ||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
t = &SqlTimeStamp{}
return nil
}
tx, err := tryParseDT(s)
if err != nil {
return err
}
*t = SqlTimeStamp(tx)
return err
}
// Value - SQL Value of custom date
func (t SqlTimeStamp) Value() (driver.Value, error) {
if t.GetTime().IsZero() || t.GetTime().Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
return nil, nil
}
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
if tmstr <= "0001-01-01" || tmstr == "" {
empty := time.Time{}
return empty, nil
}
return tmstr, nil
}
// Scan - Scan custom date from sql
func (t *SqlTimeStamp) Scan(value interface{}) error {
tm, ok := value.(time.Time)
if ok {
*t = SqlTimeStamp(tm)
return nil
}
str, ok := value.(string)
if ok {
tx, err := tryParseDT(str)
if err != nil {
return err
}
*t = SqlTimeStamp(tx)
}
return nil
}
// String - Override String format of time
func (t SqlTimeStamp) String() string {
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
}
// GetTime - Returns Time
func (t SqlTimeStamp) GetTime() time.Time {
return time.Time(t)
}
// SetTime - Returns Time
func (t *SqlTimeStamp) SetTime(pTime time.Time) {
*t = SqlTimeStamp(pTime)
}
// Format - Formats the time
func (t SqlTimeStamp) Format(layout string) string {
return fmt.Sprintf("%s", time.Time(t).Format(layout))
}
func SqlTimeStampNow() SqlTimeStamp {
tx := time.Now()
return SqlTimeStamp(tx)
}
// SqlFloat64 - SQL Int
type SqlFloat64 sql.NullFloat64
// Scan -
func (n *SqlFloat64) Scan(value interface{}) error {
newval := sql.NullFloat64{Float64: 0, Valid: false}
if value == nil {
newval.Valid = false
*n = SqlFloat64(newval)
return nil
}
switch v := value.(type) {
case int:
newval.Float64 = float64(v)
newval.Valid = true
case float64:
newval.Float64 = float64(v)
newval.Valid = true
case float32:
newval.Float64 = float64(v)
newval.Valid = true
case int64:
newval.Float64 = float64(v)
newval.Valid = true
case int32:
newval.Float64 = float64(v)
newval.Valid = true
case uint16:
newval.Float64 = float64(v)
newval.Valid = true
case uint64:
newval.Float64 = float64(v)
newval.Valid = true
case uint32:
newval.Float64 = float64(v)
newval.Valid = true
default:
i, err := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
newval.Float64 = float64(i)
if err == nil {
newval.Valid = false
}
}
*n = SqlFloat64(newval)
return nil
}
// Value -
func (n SqlFloat64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return float64(n.Float64), nil
}
// String -
func (n SqlFloat64) String() string {
if !n.Valid {
return ""
}
tmstr := fmt.Sprintf("%f", n.Float64)
return tmstr
}
// UnmarshalJSON -
func (n *SqlFloat64) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 2) || (strings.Contains(s, "{") || strings.Contains(s, "["))
if invalid {
return nil
}
nval, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return err
}
*n = SqlFloat64(sql.NullFloat64{Valid: true, Float64: float64(nval)})
return nil
}
// MarshalJSON - Override JSON format of time
func (n SqlFloat64) MarshalJSON() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return []byte(fmt.Sprintf("%f", n.Float64)), nil
}
// SqlDate - Implementation of SqlTime with some interfaces.
type SqlDate time.Time
// UnmarshalJSON - Override JSON format of time
func (t *SqlDate) UnmarshalJSON(b []byte) error {
var err error
s := strings.Trim(strings.Trim(string(b), " "), "\"")
if s == "null" || s == "" || s == "0" ||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
s == "0001-01-01" {
t = &SqlDate{}
return nil
}
tx, err := tryParseDT(s)
if err != nil {
return err
}
*t = SqlDate(tx)
return err
}
// MarshalJSON - Override JSON format of time
func (t SqlDate) MarshalJSON() ([]byte, error) {
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
if strings.HasPrefix(tmstr, "0001-01-01") {
return []byte("null"), nil
}
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
}
// Value - SQL Value of custom date
func (t SqlDate) Value() (driver.Value, error) {
var s time.Time
tmstr := time.Time(t).Format("2006-01-02")
if strings.HasPrefix(tmstr, "0001-01-01") || tmstr <= "0001-01-01" {
return nil, nil
}
s = time.Time(t)
return s.Format("2006-01-02"), nil
}
// Scan - Scan custom date from sql
func (t *SqlDate) Scan(value interface{}) error {
tm, ok := value.(time.Time)
if ok {
*t = SqlDate(tm)
return nil
}
str, ok := value.(string)
if ok {
tx, err := tryParseDT(str)
if err != nil {
return err
}
*t = SqlDate(tx)
return err
}
return nil
}
// Int64 - Override date format in unix epoch
func (t SqlDate) Int64() int64 {
return time.Time(t).Unix()
}
// String - Override String format of time
func (t SqlDate) String() string {
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
return "0"
}
return tmstr
}
func SqlDateNow() SqlDate {
tx := time.Now()
return SqlDate(tx)
}
// ////////////////////// SqlTime /////////////////////////
// SqlTime - Implementation of SqlTime with some interfaces.
type SqlTime time.Time
// Int64 - Override Time format in unix epoch
func (t SqlTime) Int64() int64 {
return time.Time(t).Unix()
}
// String - Override String format of time
func (t SqlTime) String() string {
return time.Time(t).Format("15:04:05")
}
// UnmarshalJSON - Override JSON format of time
func (t *SqlTime) UnmarshalJSON(b []byte) error {
var err error
s := strings.Trim(strings.Trim(string(b), " "), "\"")
if s == "null" || s == "" || s == "0" ||
s == "0001-01-01T00:00:00" || s == "00:00:00" {
*t = SqlTime{}
return nil
}
tx := time.Time{}
tx, err = tryParseDT(s)
*t = SqlTime(tx)
return err
}
// Format - Format Function
func (t SqlTime) Format(form string) string {
tmstr := time.Time(t).Format(form)
return tmstr
}
// Scan - Scan custom date from sql
func (t *SqlTime) Scan(value interface{}) error {
tm, ok := value.(time.Time)
if ok {
*t = SqlTime(tm)
return nil
}
str, ok := value.(string)
if ok {
tx, err := tryParseDT(str)
*t = SqlTime(tx)
return err
}
return nil
}
// Value - SQL Value of custom date
func (t SqlTime) Value() (driver.Value, error) {
s := time.Time(t)
st := s.Format("15:04:05")
return st, nil
}
// MarshalJSON - Override JSON format of time
func (t SqlTime) MarshalJSON() ([]byte, error) {
tmstr := time.Time(t).Format("15:04:05")
if tmstr == "0001-01-01T00:00:00" || tmstr == "00:00:00" {
return []byte("null"), nil
}
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
}
func SqlTimeNow() SqlTime {
tx := time.Now()
return SqlTime(tx)
}
// SqlJSONB - Nullable JSONB String
type SqlJSONB []byte
// Scan - Implements sql.Scanner for reading JSONB from database
func (n *SqlJSONB) Scan(value interface{}) error {
if value == nil {
*n = nil
return nil
}
switch v := value.(type) {
case string:
*n = SqlJSONB([]byte(v))
case []byte:
*n = SqlJSONB(v)
default:
// For other types, marshal to JSON
dat, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value to JSON: %v", err)
}
*n = SqlJSONB(dat)
}
return nil
}
// Value - Implements driver.Valuer for writing JSONB to database
func (n SqlJSONB) Value() (driver.Value, error) {
if len(n) == 0 {
return nil, nil
}
// Validate that it's valid JSON before returning
var js interface{}
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
// Return as string for PostgreSQL JSONB/JSON columns
return string(n), nil
}
func (n SqlJSONB) AsMap() (map[string]any, error) {
if len(n) == 0 {
return nil, nil
}
// Validate that it's valid JSON before returning
js := make(map[string]any)
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return js, nil
}
func (n SqlJSONB) AsSlice() ([]any, error) {
if len(n) == 0 {
return nil, nil
}
// Validate that it's valid JSON before returning
js := make([]any, 0)
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return js, nil
}
// UnmarshalJSON - Override JSON
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
if invalid {
s = ""
return nil
}
*n = []byte(s)
return nil
}
// MarshalJSON - Override JSON format of time
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
if n == nil {
return []byte("null"), nil
}
var obj interface{}
err := json.Unmarshal(n, &obj)
if err != nil {
//fmt.Printf("Invalid JSON %v", err)
return []byte("null"), nil
}
// dat, err := json.MarshalIndent(obj, " ", " ")
// if err != nil {
// return nil, fmt.Errorf("failed to convert to JSON: %v", err)
// }
dat := n
return dat, nil
}
// SqlUUID - Nullable UUID String
type SqlUUID sql.NullString
// Scan -
func (n *SqlUUID) Scan(value interface{}) error {
str := sql.NullString{String: "", Valid: false}
if value == nil {
*n = SqlUUID(str)
return nil
}
switch v := value.(type) {
case string:
uuid, err := uuid.Parse(v)
if err == nil {
str.String = uuid.String()
str.Valid = true
*n = SqlUUID(str)
}
case []uint8:
uuid, err := uuid.ParseBytes(v)
if err == nil {
str.String = uuid.String()
str.Valid = true
*n = SqlUUID(str)
}
default:
uuid, err := uuid.Parse(fmt.Sprintf("%v", v))
if err == nil {
str.String = uuid.String()
str.Valid = true
*n = SqlUUID(str)
}
}
return nil
}
// Value -
func (n SqlUUID) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.String, nil
}
// UnmarshalJSON - Override JSON
func (n *SqlUUID) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 30)
if invalid {
s = ""
return nil
}
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})
return nil
}
// MarshalJSON - Override JSON format of time
func (n SqlUUID) MarshalJSON() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return []byte(fmt.Sprintf("\"%s\"", n.String)), nil
}
// TryIfInt64 - Wrapper function to quickly try and cast text to int
func TryIfInt64(v any, def int64) int64 {
str := ""
switch val := v.(type) {
case string:
str = val
case int:
return int64(val)
case int32:
return int64(val)
case int64:
return val
case uint32:
return int64(val)
case uint64:
return int64(val)
case float32:
return int64(val)
case float64:
return int64(val)
case []byte:
str = string(val)
default:
str = fmt.Sprintf("%d", def)
}
val, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return def
}
return val
}

View File

@@ -0,0 +1,566 @@
package common
import (
"database/sql/driver"
"encoding/json"
"testing"
"time"
"github.com/google/uuid"
)
// TestSqlInt16 tests SqlInt16 type
func TestSqlInt16(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt16
}{
{"int", 42, SqlInt16(42)},
{"int32", int32(100), SqlInt16(100)},
{"int64", int64(200), SqlInt16(200)},
{"string", "123", SqlInt16(123)},
{"nil", nil, SqlInt16(0)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlInt16
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n)
}
})
}
}
func TestSqlInt16_Value(t *testing.T) {
tests := []struct {
name string
input SqlInt16
expected driver.Value
}{
{"zero", SqlInt16(0), nil},
{"positive", SqlInt16(42), int64(42)},
{"negative", SqlInt16(-10), int64(-10)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := tt.input.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, val)
}
})
}
}
func TestSqlInt16_JSON(t *testing.T) {
n := SqlInt16(42)
// Marshal
data, err := json.Marshal(n)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := "42"
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var n2 SqlInt16
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if n2 != 123 {
t.Errorf("expected 123, got %d", n2)
}
}
// TestSqlInt64 tests SqlInt64 type
func TestSqlInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt64
}{
{"int", 42, SqlInt64(42)},
{"int32", int32(100), SqlInt64(100)},
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
{"uint32", uint32(100), SqlInt64(100)},
{"uint64", uint64(200), SqlInt64(200)},
{"nil", nil, SqlInt64(0)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlInt64
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n)
}
})
}
}
// TestSqlFloat64 tests SqlFloat64 type
func TestSqlFloat64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected float64
valid bool
}{
{"float64", float64(3.14), 3.14, true},
{"float32", float32(2.5), 2.5, true},
{"int", 42, 42.0, true},
{"int64", int64(100), 100.0, true},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlFloat64
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
}
if tt.valid && n.Float64 != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
}
})
}
}
// TestSqlTimeStamp tests SqlTimeStamp type
func TestSqlTimeStamp(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
}{
{"time.Time", now},
{"string RFC3339", now.Format(time.RFC3339)},
{"string date", "2024-01-15"},
{"string datetime", "2024-01-15T10:30:00"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ts SqlTimeStamp
if err := ts.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if ts.GetTime().IsZero() {
t.Error("expected non-zero time")
}
})
}
}
func TestSqlTimeStamp_JSON(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
ts := SqlTimeStamp(now)
// Marshal
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"2024-01-15T10:30:45"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var ts2 SqlTimeStamp
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if ts2.GetTime().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
}
// Test null
var ts3 SqlTimeStamp
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
}
// TestSqlDate tests SqlDate type
func TestSqlDate(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
}{
{"time.Time", now},
{"string date", "2024-01-15"},
{"string UK format", "15/01/2024"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var d SqlDate
if err := d.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if d.String() == "0" {
t.Error("expected non-zero date")
}
})
}
}
func TestSqlDate_JSON(t *testing.T) {
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
// Marshal
data, err := json.Marshal(date)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"2024-01-15"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var d2 SqlDate
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
}
// TestSqlTime tests SqlTime type
func TestSqlTime(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
expected string
}{
{"time.Time", now, now.Format("15:04:05")},
{"string time", "10:30:45", "10:30:45"},
{"string short time", "10:30", "10:30:00"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tm SqlTime
if err := tm.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if tm.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, tm.String())
}
})
}
}
// TestSqlJSONB tests SqlJSONB type
func TestSqlJSONB_Scan(t *testing.T) {
tests := []struct {
name string
input interface{}
expected string
}{
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
{"nil", nil, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var j SqlJSONB
if err := j.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if tt.expected == "" && j == nil {
return // nil case
}
if string(j) != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, string(j))
}
})
}
}
func TestSqlJSONB_Value(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
expected string
wantErr bool
}{
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
{"empty", SqlJSONB{}, "", false},
{"nil", nil, "", false},
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := tt.input.Value()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if tt.expected == "" && val == nil {
return // nil case
}
if val.(string) != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, val)
}
})
}
}
func TestSqlJSONB_JSON(t *testing.T) {
// Marshal
j := SqlJSONB(`{"name":"test","count":42}`)
data, err := json.Marshal(j)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("Unmarshal result failed: %v", err)
}
if result["name"] != "test" {
t.Errorf("expected name=test, got %v", result["name"])
}
// Unmarshal
var j2 SqlJSONB
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if string(j2) != `{"key":"value"}` {
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
}
// Test null
var j3 SqlJSONB
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
}
func TestSqlJSONB_AsMap(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
wantErr bool
wantNil bool
}{
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
{"empty", SqlJSONB{}, false, true},
{"nil", nil, false, true},
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m, err := tt.input.AsMap()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("AsMap failed: %v", err)
}
if tt.wantNil {
if m != nil {
t.Errorf("expected nil, got %v", m)
}
return
}
if m == nil {
t.Error("expected non-nil map")
}
})
}
}
func TestSqlJSONB_AsSlice(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
wantErr bool
wantNil bool
}{
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
{"empty", SqlJSONB{}, false, true},
{"nil", nil, false, true},
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := tt.input.AsSlice()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("AsSlice failed: %v", err)
}
if tt.wantNil {
if s != nil {
t.Errorf("expected nil, got %v", s)
}
return
}
if s == nil {
t.Error("expected non-nil slice")
}
})
}
}
// TestSqlUUID tests SqlUUID type
func TestSqlUUID_Scan(t *testing.T) {
testUUID := uuid.New()
testUUIDStr := testUUID.String()
tests := []struct {
name string
input interface{}
expected string
valid bool
}{
{"string UUID", testUUIDStr, testUUIDStr, true},
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
{"nil", nil, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var u SqlUUID
if err := u.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if u.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
}
if tt.valid && u.String != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String)
}
})
}
}
func TestSqlUUID_Value(t *testing.T) {
testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true}
val, err := u.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), val)
}
// Test invalid UUID
u2 := SqlUUID{Valid: false}
val2, err := u2.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val2 != nil {
t.Errorf("expected nil, got %v", val2)
}
}
func TestSqlUUID_JSON(t *testing.T) {
testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true}
// Marshal
data, err := json.Marshal(u)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"` + testUUID.String() + `"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var u2 SqlUUID
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if u2.String != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
}
// Test null
var u3 SqlUUID
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
if u3.Valid {
t.Error("expected invalid UUID")
}
}
// TestTryIfInt64 tests the TryIfInt64 helper function
func TestTryIfInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
def int64
expected int64
}{
{"string valid", "123", 0, 123},
{"string invalid", "abc", 99, 99},
{"int", 42, 0, 42},
{"int32", int32(100), 0, 100},
{"int64", int64(200), 0, 200},
{"uint32", uint32(50), 0, 50},
{"uint64", uint64(75), 0, 75},
{"float32", float32(3.14), 0, 3},
{"float64", float64(2.71), 0, 2},
{"bytes", []byte("456"), 0, 456},
{"unknown type", struct{}{}, 999, 999},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := TryIfInt64(tt.input, tt.def)
if result != tt.expected {
t.Errorf("expected %d, got %d", tt.expected, result)
}
})
}
}

View File

@@ -92,9 +92,27 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
return strings.ToLower(field.Name)
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
// Handles PostgreSQL JSON operators (-> and ->>)
func (v *ColumnValidator) ValidateColumn(column string) error {
// Allow empty columns
if column == "" {
@@ -106,8 +124,11 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
return nil
}
// Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := extractSourceColumn(column)
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
}

View File

@@ -0,0 +1,124 @@
package common
import (
"testing"
)
func TestExtractSourceColumn(t *testing.T) {
testCases := []struct {
name string
input string
expected string
}{
{
name: "simple column name",
input: "columna",
expected: "columna",
},
{
name: "column with ->> operator",
input: "columna->>'val'",
expected: "columna",
},
{
name: "column with -> operator",
input: "columna->'key'",
expected: "columna",
},
{
name: "column with table prefix and ->> operator",
input: "table.columna->>'val'",
expected: "table.columna",
},
{
name: "column with table prefix and -> operator",
input: "table.columna->'key'",
expected: "table.columna",
},
{
name: "complex JSON path with ->>",
input: "data->>'nested'->>'value'",
expected: "data",
},
{
name: "column with spaces before operator",
input: "columna ->>'val'",
expected: "columna",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := extractSourceColumn(tc.input)
if result != tc.expected {
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
}
})
}
}
func TestValidateColumnWithJSONOperators(t *testing.T) {
// Create a test model
type TestModel struct {
ID int `json:"id"`
Name string `json:"name"`
Data string `json:"data"` // JSON column
Metadata string `json:"metadata"`
}
validator := NewColumnValidator(TestModel{})
testCases := []struct {
name string
column string
shouldErr bool
}{
{
name: "simple valid column",
column: "name",
shouldErr: false,
},
{
name: "valid column with ->> operator",
column: "data->>'field'",
shouldErr: false,
},
{
name: "valid column with -> operator",
column: "metadata->'key'",
shouldErr: false,
},
{
name: "invalid column",
column: "invalid_column",
shouldErr: true,
},
{
name: "invalid column with ->> operator",
column: "invalid_column->>'field'",
shouldErr: true,
},
{
name: "cql prefixed column (always valid)",
column: "cql_computed",
shouldErr: false,
},
{
name: "empty column",
column: "",
shouldErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := validator.ValidateColumn(tc.column)
if tc.shouldErr && err == nil {
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
}
if !tc.shouldErr && err != nil {
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
}
})
}
}

View File

@@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{
models: make(map[string]interface{}),
}
// Global list of registries (searched in order)
var registries = []*DefaultModelRegistry{defaultRegistry}
var registriesMutex sync.RWMutex
// NewModelRegistry creates a new model registry
func NewModelRegistry() *DefaultModelRegistry {
return &DefaultModelRegistry{
@@ -24,6 +28,34 @@ func NewModelRegistry() *DefaultModelRegistry {
}
}
func SetDefaultRegistry(registry *DefaultModelRegistry) {
registriesMutex.Lock()
foundAt := -1
for idx, r := range registries {
if r == defaultRegistry {
foundAt = idx
break
}
}
defaultRegistry = registry
if foundAt >= 0 {
registries[foundAt] = registry
} else {
registries = append([]*DefaultModelRegistry{registry}, registries...)
}
defer registriesMutex.Unlock()
}
// AddRegistry adds a registry to the global list of registries
// Registries are searched in the order they were added
func AddRegistry(registry *DefaultModelRegistry) {
registriesMutex.Lock()
defer registriesMutex.Unlock()
registries = append(registries, registry)
}
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
r.mutex.Lock()
defer r.mutex.Unlock()
@@ -107,9 +139,19 @@ func RegisterModel(model interface{}, name string) error {
return defaultRegistry.RegisterModel(name, model)
}
// GetModelByName retrieves a model from the default global registry by name
// GetModelByName retrieves a model by searching through all registries in order
// Returns the first match found
func GetModelByName(name string) (interface{}, error) {
return defaultRegistry.GetModel(name)
registriesMutex.RLock()
defer registriesMutex.RUnlock()
for _, registry := range registries {
if model, err := registry.GetModel(name); err == nil {
return model, nil
}
}
return nil, fmt.Errorf("model %s not found in any registry", name)
}
// IterateModels iterates over all models in the default global registry
@@ -122,14 +164,26 @@ func IterateModels(fn func(name string, model interface{})) {
}
}
// GetModels returns a list of all models in the default global registry
// GetModels returns a list of all models from all registries
// Models are collected in registry order, with duplicates included
func GetModels() []interface{} {
defaultRegistry.mutex.RLock()
defer defaultRegistry.mutex.RUnlock()
registriesMutex.RLock()
defer registriesMutex.RUnlock()
models := make([]interface{}, 0, len(defaultRegistry.models))
for _, model := range defaultRegistry.models {
models = append(models, model)
var models []interface{}
seen := make(map[string]bool)
for _, registry := range registries {
registry.mutex.RLock()
for name, model := range registry.models {
// Only add the first occurrence of each model name
if !seen[name] {
models = append(models, model)
seen[name] = true
}
}
registry.mutex.RUnlock()
}
return models
}

View File

@@ -18,6 +18,7 @@ type ModelFieldDetail struct {
}
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
// This function recursively processes embedded structs to include their fields
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
defer func() {
if r := recover(); r != nil {
@@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
if record.Kind() != reflect.Struct {
return lst
}
collectFieldDetails(record, &lst)
return lst
}
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
modeltype := record.Type()
for i := 0; i < modeltype.NumField(); i++ {
fieldtype := modeltype.Field(i)
fieldValue := record.Field(i)
// Check if this is an embedded struct
if fieldtype.Anonymous {
// Unwrap pointer type if necessary
embeddedValue := fieldValue
if fieldValue.Kind() == reflect.Pointer {
if fieldValue.IsNil() {
// Skip nil embedded pointers
continue
}
embeddedValue = fieldValue.Elem()
}
// Recursively process embedded struct
if embeddedValue.Kind() == reflect.Struct {
collectFieldDetails(embeddedValue, lst)
continue
}
}
gormdetail := fieldtype.Tag.Get("gorm")
gormdetail = strings.Trim(gormdetail, " ")
fielddetail := ModelFieldDetail{}
fielddetail.FieldValue = record.Field(i)
fielddetail.FieldValue = fieldValue
fielddetail.Name = fieldtype.Name
fielddetail.DataType = fieldtype.Type.Name()
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
@@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
}
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
lst = append(lst, fielddetail)
*lst = append(*lst, fielddetail)
}
return lst
}
func fnFindKeyVal(src, key string) string {

View File

@@ -45,8 +45,104 @@ func GetPrimaryKeyName(model any) string {
return ""
}
// GetPrimaryKeyValue extracts the primary key value from a model instance
// Returns the value of the primary key field
func GetPrimaryKeyValue(model any) any {
if model == nil || reflect.TypeOf(model) == nil {
return nil
}
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil
}
// Try Bun tag first
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
return pkValue
}
// Fall back to GORM tag
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
return pkValue
}
// Last resort: look for field named "ID" or "Id"
if pkValue := findFieldByName(val, "id"); pkValue != nil {
return pkValue
}
return nil
}
// findPrimaryKeyValue recursively searches for a primary key field in the struct
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
return pkValue
}
continue
}
// Check for primary key tag
switch ormType {
case "bun":
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
return fieldValue.Interface()
}
case "gorm":
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
}
return nil
}
// findFieldByName recursively searches for a field by name in the struct
func findFieldByName(val reflect.Value, name string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if result := findFieldByName(fieldValue, name); result != nil {
return result
}
continue
}
// Check if field name matches
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
return nil
}
// GetModelColumns extracts all column names from a model using reflection
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
// This function recursively processes embedded structs to include their fields
func GetModelColumns(model any) []string {
var columns []string
@@ -62,18 +158,38 @@ func GetModelColumns(model any) []string {
return columns
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
collectColumnsFromType(modelType, &columns)
return columns
}
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively process embedded struct
if fieldType.Kind() == reflect.Struct {
collectColumnsFromType(fieldType, columns)
continue
}
}
// Get column name using the same logic as primary key extraction
columnName := getColumnNameFromField(field)
if columnName != "" {
columns = append(columns, columnName)
*columns = append(*columns, columnName)
}
}
return columns
}
// getColumnNameFromField extracts the column name from a struct field
@@ -110,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string {
}
// getPrimaryKeyFromReflection uses reflection to find the primary key field
// This function recursively searches embedded structs
func getPrimaryKeyFromReflection(model any, ormType string) string {
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
@@ -121,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
}
typ := val.Type()
return findPrimaryKeyNameFromType(typ, ormType)
}
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
return pkName
}
}
continue
}
switch ormType {
case "gorm":
// Check for gorm tag with primaryKey
@@ -175,8 +314,129 @@ func ExtractColumnFromGormTag(tag string) string {
// Example: ",pk" -> "" (will fall back to json tag)
func ExtractColumnFromBunTag(tag string) string {
parts := strings.Split(tag, ",")
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
return ""
}
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}
// IsColumnWritable checks if a column can be written to in the database
// For bun: returns false if the field has "scanonly" tag
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
// This function recursively searches embedded structs
func IsColumnWritable(model any, columnName string) bool {
modelType := reflect.TypeOf(model)
// Unwrap pointers to get to the base struct type
for modelType != nil && modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return false
}
found, writable := isColumnWritableInType(modelType, columnName)
if found {
return writable
}
// Column not found in model, allow it (might be a dynamic column)
return true
}
// isColumnWritableInType recursively searches for a column and checks if it's writable
// Returns (found, writable) where found indicates if the column was found
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if found, writable := isColumnWritableInType(fieldType, columnName); found {
return true, writable
}
}
continue
}
// Check if this field matches the column name
fieldColumnName := getColumnNameFromField(field)
if fieldColumnName != columnName {
continue
}
// Found the field, now check if it's writable
// Check bun tag for scanonly
bunTag := field.Tag.Get("bun")
if bunTag != "" {
if isBunFieldScanOnly(bunTag) {
return true, false
}
}
// Check gorm tag for write restrictions
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
if isGormFieldReadOnly(gormTag) {
return true, false
}
}
// Column is writable
return true, true
}
// Column not found
return false, false
}
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
// Example: "column_name,scanonly" -> true
func isBunFieldScanOnly(tag string) bool {
parts := strings.Split(tag, ",")
for _, part := range parts {
if strings.TrimSpace(part) == "scanonly" {
return true
}
}
return false
}
// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only
// Examples:
// - "<-:false" -> true (no writes allowed)
// - "->" -> true (read-only, common pattern)
// - "column:name;->" -> true
// - "<-:create" -> false (writes allowed on create)
func isGormFieldReadOnly(tag string) bool {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
// Check for read-only marker
if part == "->" {
return true
}
// Check for write restrictions
if value, found := strings.CutPrefix(part, "<-:"); found {
if value == "false" {
return true
}
}
}
return false
}

View File

@@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) {
})
}
}
// Test models with embedded structs
type BaseModel struct {
ID int `bun:"rid_base,pk" json:"id"`
CreatedAt string `bun:"created_at" json:"created_at"`
}
type AdhocBuffer struct {
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
}
type ModelWithEmbedded struct {
BaseModel
Name string `bun:"name" json:"name"`
Description string `bun:"description" json:"description"`
AdhocBuffer
}
type GormBaseModel struct {
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
CreatedAt string `gorm:"column:created_at" json:"created_at"`
}
type GormAdhocBuffer struct {
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
}
type GormModelWithEmbedded struct {
GormBaseModel
Name string `gorm:"column:name" json:"name"`
Description string `gorm:"column:description" json:"description"`
GormAdhocBuffer
}
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
expected string
}{
{
name: "Bun model with embedded base",
model: ModelWithEmbedded{},
expected: "rid_base",
},
{
name: "Bun model with embedded base (pointer)",
model: &ModelWithEmbedded{},
expected: "rid_base",
},
{
name: "GORM model with embedded base",
model: GormModelWithEmbedded{},
expected: "rid_base",
},
{
name: "GORM model with embedded base (pointer)",
model: &GormModelWithEmbedded{},
expected: "rid_base",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyName(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
bunModel := ModelWithEmbedded{
BaseModel: BaseModel{
ID: 123,
CreatedAt: "2024-01-01",
},
Name: "Test",
Description: "Test Description",
}
gormModel := GormModelWithEmbedded{
GormBaseModel: GormBaseModel{
ID: 456,
CreatedAt: "2024-01-02",
},
Name: "GORM Test",
Description: "GORM Test Description",
}
tests := []struct {
name string
model any
expected any
}{
{
name: "Bun model with embedded base",
model: bunModel,
expected: 123,
},
{
name: "Bun model with embedded base (pointer)",
model: &bunModel,
expected: 123,
},
{
name: "GORM model with embedded base",
model: gormModel,
expected: 456,
},
{
name: "GORM model with embedded base (pointer)",
model: &gormModel,
expected: 456,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyValue(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetModelColumnsWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
expected []string
}{
{
name: "Bun model with embedded structs",
model: ModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "Bun model with embedded structs (pointer)",
model: &ModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "GORM model with embedded structs",
model: GormModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
{
name: "GORM model with embedded structs (pointer)",
model: &GormModelWithEmbedded{},
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetModelColumns(tt.model)
if len(result) != len(tt.expected) {
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
return
}
for i, col := range result {
if col != tt.expected[i] {
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
}
}
})
}
}
func TestIsColumnWritableWithEmbedded(t *testing.T) {
tests := []struct {
name string
model any
columnName string
expected bool
}{
{
name: "Bun model - writable column in main struct",
model: ModelWithEmbedded{},
columnName: "name",
expected: true,
},
{
name: "Bun model - writable column in embedded base",
model: ModelWithEmbedded{},
columnName: "rid_base",
expected: true,
},
{
name: "Bun model - scanonly column in embedded adhoc buffer",
model: ModelWithEmbedded{},
columnName: "cql1",
expected: false,
},
{
name: "Bun model - scanonly column _rownumber",
model: ModelWithEmbedded{},
columnName: "_rownumber",
expected: false,
},
{
name: "GORM model - writable column in main struct",
model: GormModelWithEmbedded{},
columnName: "name",
expected: true,
},
{
name: "GORM model - writable column in embedded base",
model: GormModelWithEmbedded{},
columnName: "rid_base",
expected: true,
},
{
name: "GORM model - readonly column in embedded adhoc buffer",
model: GormModelWithEmbedded{},
columnName: "cql1",
expected: false,
},
{
name: "GORM model - readonly column _rownumber",
model: GormModelWithEmbedded{},
columnName: "_rownumber",
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsColumnWritable(tt.model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"reflect"
"runtime/debug"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
@@ -133,9 +134,15 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err)
return
}
h.handleCreate(ctx, w, data, options)
validId, _ := strconv.ParseInt(id, 10, 64)
if validId > 0 {
h.handleUpdate(ctx, w, id, nil, data, options)
} else {
h.handleCreate(ctx, w, data, options)
}
case "PUT", "PATCH":
// Update operation
body, err := r.Body()
if err != nil {
logger.Error("Failed to read request body: %v", err)
@@ -253,6 +260,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Table(tableName)
}
// Note: X-Files configuration is now applied via parseXFiles which populates
// ExtendedRequestOptions fields (columns, filters, sort, preload, etc.)
// These are applied below in the normal query building process
// Apply ComputedQL fields if any
if len(options.ComputedQL) > 0 {
for colName, colExpr := range options.ComputedQL {
@@ -293,6 +304,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
logger.Debug("Applying expand: %s", expand.Relation)
sorts := make([]common.SortOption, 0)
for _, s := range strings.Split(expand.Sort, ",") {
if s == "" {
continue
}
dir := "ASC"
if strings.HasPrefix(s, "-") || strings.HasSuffix(strings.ToUpper(s), " DESC") {
dir = "DESC"
@@ -574,22 +588,6 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
logger.Info("Creating record in %s.%s", schema, entity)
// Check if data is a single map with nested relations
if dataMap, ok := data.(map[string]interface{}); ok {
if h.shouldUseNestedProcessor(dataMap, model) {
logger.Info("Using nested CUD processor for create operation")
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", dataMap, model, make(map[string]interface{}), tableName)
if err != nil {
logger.Error("Error in nested create: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record with nested data", err)
return
}
logger.Info("Successfully created record with nested data, ID: %v", result.ID)
h.sendResponse(w, result.Data, nil)
return
}
}
// Execute BeforeCreate hooks
hookCtx := &HookContext{
Context: ctx,
@@ -612,172 +610,142 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
// Use potentially modified data from hook context
data = hookCtx.Data
// Handle batch creation
dataValue := reflect.ValueOf(data)
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
logger.Debug("Batch creation detected, count: %d", dataValue.Len())
// Normalize data to slice for unified processing
dataSlice := h.normalizeToSlice(data)
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
// Check if any item needs nested processing
hasNestedData := false
for i := 0; i < dataValue.Len(); i++ {
item := dataValue.Index(i).Interface()
if itemMap, ok := item.(map[string]interface{}); ok {
if h.shouldUseNestedProcessor(itemMap, model) {
hasNestedData = true
break
}
}
}
// Store original data maps for merging later
originalDataMaps := make([]map[string]interface{}, 0, len(dataSlice))
if hasNestedData {
logger.Info("Using nested CUD processor for batch create with nested data")
results := make([]interface{}, 0, dataValue.Len())
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Temporarily swap the database to use transaction
originalDB := h.nestedProcessor
h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
defer func() {
h.nestedProcessor = originalDB
}()
// Process all items in a transaction
results := make([]interface{}, 0, len(dataSlice))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Create temporary nested processor with transaction
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
for i := 0; i < dataValue.Len(); i++ {
item := dataValue.Index(i).Interface()
if itemMap, ok := item.(map[string]interface{}); ok {
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", itemMap, model, make(map[string]interface{}), tableName)
if err != nil {
return fmt.Errorf("failed to process item: %w", err)
}
results = append(results, result.Data)
}
}
return nil
})
if err != nil {
logger.Error("Error creating records with nested data: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records with nested data", err)
return
}
// Execute AfterCreate hooks
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
logger.Info("Successfully created %d records with nested data", len(results))
h.sendResponse(w, results, nil)
return
}
// Standard batch insert without nested relations
// Use transaction for batch insert
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for i := 0; i < dataValue.Len(); i++ {
item := dataValue.Index(i).Interface()
// Convert item to model type - create a pointer to the model
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
for i, item := range dataSlice {
itemMap, ok := item.(map[string]interface{})
if !ok {
// Convert to map if needed
jsonData, err := json.Marshal(item)
if err != nil {
return fmt.Errorf("failed to marshal item: %w", err)
return fmt.Errorf("failed to marshal item %d: %w", i, err)
}
if err := json.Unmarshal(jsonData, modelValue); err != nil {
return fmt.Errorf("failed to unmarshal item: %w", err)
}
query := tx.NewInsert().Model(modelValue).Table(tableName)
// Execute BeforeScan hooks - pass query chain so hooks can modify it
batchHookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
Data: modelValue,
Writer: w,
Query: query,
}
if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil {
return fmt.Errorf("BeforeScan hook failed: %w", err)
}
// Use potentially modified query from hook context
if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok {
query = modifiedQuery
}
if _, err := query.Exec(ctx); err != nil {
return fmt.Errorf("failed to insert record: %w", err)
itemMap = make(map[string]interface{})
if err := json.Unmarshal(jsonData, &itemMap); err != nil {
return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
}
}
return nil
})
if err != nil {
logger.Error("Error creating records: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
return
// Store a copy of the original data map for merging later
originalMap := make(map[string]interface{})
for k, v := range itemMap {
originalMap[k] = v
}
originalDataMaps = append(originalDataMaps, originalMap)
// Extract nested relations if present (but don't process them yet)
var nestedRelations map[string]interface{}
if h.shouldUseNestedProcessor(itemMap, model) {
logger.Debug("Extracting nested relations for item %d", i)
cleanedData, relations, err := h.extractNestedRelations(itemMap, model)
if err != nil {
return fmt.Errorf("failed to extract nested relations for item %d: %w", i, err)
}
itemMap = cleanedData
nestedRelations = relations
}
// Convert item to model type - create a pointer to the model
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
jsonData, err := json.Marshal(itemMap)
if err != nil {
return fmt.Errorf("failed to marshal item %d: %w", i, err)
}
if err := json.Unmarshal(jsonData, modelValue); err != nil {
return fmt.Errorf("failed to unmarshal item %d: %w", i, err)
}
// Create insert query
query := tx.NewInsert().Model(modelValue)
// Only set Table() if the model doesn't provide a table name via TableNameProvider
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
query = query.Returning("*")
// Execute BeforeScan hooks - pass query chain so hooks can modify it
itemHookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
Data: modelValue,
Writer: w,
Query: query,
}
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
}
// Use potentially modified query from hook context
if modifiedQuery, ok := itemHookCtx.Query.(common.InsertQuery); ok {
query = modifiedQuery
}
// Execute insert and get the ID
if _, err := query.Exec(ctx); err != nil {
return fmt.Errorf("failed to insert item %d: %w", i, err)
}
// Get the inserted ID
insertedID := reflection.GetPrimaryKeyValue(modelValue)
// Now process nested relations with the parent ID
if len(nestedRelations) > 0 {
logger.Debug("Processing nested relations for item %d with parent ID: %v", i, insertedID)
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "insert", nestedRelations, model, insertedID); err != nil {
return fmt.Errorf("failed to process nested relations for item %d: %w", i, err)
}
}
results = append(results, modelValue)
}
return nil
})
// Execute AfterCreate hooks for batch creation
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
return
}
// Single record creation - create a pointer to the model
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
jsonData, err := json.Marshal(data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
return
}
if err := json.Unmarshal(jsonData, modelValue); err != nil {
logger.Error("Error unmarshaling data: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
logger.Error("Error creating records: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
return
}
query := h.db.NewInsert().Model(modelValue).Table(tableName)
// Execute BeforeScan hooks - pass query chain so hooks can modify it
hookCtx.Data = modelValue
hookCtx.Query = query
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
logger.Error("BeforeScan hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
// Merge created records with original request data
// This preserves extra keys from the request
mergedResults := make([]interface{}, 0, len(results))
for i, result := range results {
if i < len(originalDataMaps) {
merged := h.mergeRecordWithRequest(result, originalDataMaps[i])
mergedResults = append(mergedResults, merged)
} else {
mergedResults = append(mergedResults, result)
}
}
// Use potentially modified query from hook context
if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok {
query = modifiedQuery
// Execute AfterCreate hooks
var responseData interface{}
if len(mergedResults) == 1 {
responseData = mergedResults[0]
hookCtx.Result = mergedResults[0]
} else {
responseData = mergedResults
hookCtx.Result = map[string]interface{}{"created": len(mergedResults), "data": mergedResults}
}
if _, err := query.Exec(ctx); err != nil {
logger.Error("Error creating record: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
return
}
// Execute AfterCreate hooks for single record creation
hookCtx.Result = modelValue
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
@@ -786,7 +754,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
h.sendResponse(w, modelValue, nil)
logger.Info("Successfully created %d record(s)", len(mergedResults))
h.sendResponseWithOptions(w, responseData, nil, &options)
}
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
@@ -804,46 +773,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
logger.Info("Updating record in %s.%s", schema, entity)
// Convert data to map first for nested processor check
dataMap, ok := data.(map[string]interface{})
if !ok {
jsonData, err := json.Marshal(data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
return
}
if err := json.Unmarshal(jsonData, &dataMap); err != nil {
logger.Error("Error unmarshaling data: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data format", err)
return
}
}
// Check if we should use nested processing
if h.shouldUseNestedProcessor(dataMap, model) {
logger.Info("Using nested CUD processor for update operation")
// Ensure ID is in the data map
var targetID interface{}
if id != "" {
targetID = id
} else if idPtr != nil {
targetID = *idPtr
}
if targetID != nil {
dataMap["id"] = targetID
}
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", dataMap, model, make(map[string]interface{}), tableName)
if err != nil {
logger.Error("Error in nested update: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record with nested data", err)
return
}
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
h.sendResponse(w, result.Data, nil)
return
}
// Execute BeforeUpdate hooks
hookCtx := &HookContext{
Context: ctx,
@@ -867,8 +796,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
// Use potentially modified data from hook context
data = hookCtx.Data
// Convert data to map (again if modified by hooks)
dataMap, ok = data.(map[string]interface{})
// Convert data to map
dataMap, ok := data.(map[string]interface{})
if !ok {
jsonData, err := json.Marshal(data)
if err != nil {
@@ -883,53 +812,108 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
}
}
query := h.db.NewUpdate().Table(tableName).SetMap(dataMap)
pkName := reflection.GetPrimaryKeyName(model)
// Apply ID filter
switch {
case id != "":
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
case idPtr != nil:
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *idPtr)
default:
// Determine target ID
var targetID interface{}
if id != "" {
targetID = id
} else if idPtr != nil {
targetID = *idPtr
} else {
h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil)
return
}
// Execute BeforeScan hooks - pass query chain so hooks can modify it
hookCtx.Query = query
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
logger.Error("BeforeScan hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Get the primary key name for the model
pkName := reflection.GetPrimaryKeyName(model)
// Use potentially modified query from hook context
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
query = modifiedQuery
}
// Variable to store the updated record
var updatedRecord interface{}
// Process nested relations if present
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Create temporary nested processor with transaction
txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h)
// Extract nested relations if present (but don't process them yet)
var nestedRelations map[string]interface{}
if h.shouldUseNestedProcessor(dataMap, model) {
logger.Debug("Extracting nested relations for update")
cleanedData, relations, err := h.extractNestedRelations(dataMap, model)
if err != nil {
return fmt.Errorf("failed to extract nested relations: %w", err)
}
dataMap = cleanedData
nestedRelations = relations
}
// Ensure ID is in the data map for the update
dataMap[pkName] = targetID
// Create update query
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
// Execute BeforeScan hooks - pass query chain so hooks can modify it
hookCtx.Query = query
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
return fmt.Errorf("BeforeScan hook failed: %w", err)
}
// Use potentially modified query from hook context
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
query = modifiedQuery
}
// Execute update
result, err := query.Exec(ctx)
if err != nil {
return fmt.Errorf("failed to update record: %w", err)
}
// Now process nested relations with the parent ID
if len(nestedRelations) > 0 {
logger.Debug("Processing nested relations for update with parent ID: %v", targetID)
if err := h.processChildRelationsWithParentID(ctx, txNestedProcessor, "update", nestedRelations, model, targetID); err != nil {
return fmt.Errorf("failed to process nested relations: %w", err)
}
}
// Fetch the updated record to return the new values
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
selectQuery := tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil {
return fmt.Errorf("failed to fetch updated record: %w", err)
}
updatedRecord = modelValue
// Store result for hooks
hookCtx.Result = updatedRecord
_ = result // Keep result variable for potential future use
return nil
})
result, err := query.Exec(ctx)
if err != nil {
logger.Error("Error updating record: %v", err)
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record", err)
return
}
// Execute AfterUpdate hooks
responseData := map[string]interface{}{
"updated": result.RowsAffected(),
}
hookCtx.Result = responseData
hookCtx.Error = nil
// Merge the updated record with the original request data
// This preserves extra keys from the request and updates values from the database
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
// Execute AfterUpdate hooks
hookCtx.Result = mergedData
hookCtx.Error = nil
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
logger.Error("AfterUpdate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, responseData, nil)
logger.Info("Successfully updated record with ID: %v", targetID)
h.sendResponseWithOptions(w, mergedData, nil, &options)
}
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
@@ -1003,6 +987,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
// Array of IDs or objects with ID field
logger.Info("Batch delete with %d items ([]interface{})", len(v))
deletedCount := 0
pkName := reflection.GetPrimaryKeyName(model)
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v {
var itemID interface{}
@@ -1012,7 +997,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
case string:
itemID = v
case map[string]interface{}:
itemID = v["id"]
itemID = v[pkName]
default:
itemID = item
}
@@ -1069,9 +1054,10 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
// Array of objects with id field
logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
deletedCount := 0
pkName := reflection.GetPrimaryKeyName(model)
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v {
if itemID, ok := item["id"]; ok && itemID != nil {
if itemID, ok := item[pkName]; ok && itemID != nil {
itemIDStr := fmt.Sprintf("%v", itemID)
// Execute hooks for each item
@@ -1119,7 +1105,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
case map[string]interface{}:
// Single object with id field
if itemID, ok := v["id"]; ok && itemID != nil {
pkName := reflection.GetPrimaryKeyName(model)
if itemID, ok := v[pkName]; ok && itemID != nil {
id = fmt.Sprintf("%v", itemID)
}
}
@@ -1189,6 +1176,229 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
h.sendResponse(w, responseData, nil)
}
// mergeRecordWithRequest merges a database record with the original request data
// This preserves extra keys from the request that aren't in the database model
// and updates values from the database (e.g., from SQL triggers or defaults)
func (h *Handler) mergeRecordWithRequest(dbRecord interface{}, requestData map[string]interface{}) map[string]interface{} {
// Convert the database record to a map
dbMap := make(map[string]interface{})
// Marshal and unmarshal to convert struct to map
jsonData, err := json.Marshal(dbRecord)
if err != nil {
logger.Warn("Failed to marshal database record for merging: %v", err)
return requestData
}
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
logger.Warn("Failed to unmarshal database record for merging: %v", err)
return requestData
}
// Start with the request data (preserves extra keys)
result := make(map[string]interface{})
for k, v := range requestData {
result[k] = v
}
// Update with values from database (overwrites with DB values, including trigger changes)
for k, v := range dbMap {
result[k] = v
}
return result
}
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
if data == nil {
return []interface{}{}
}
dataValue := reflect.ValueOf(data)
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
result := make([]interface{}, dataValue.Len())
for i := 0; i < dataValue.Len(); i++ {
result[i] = dataValue.Index(i).Interface()
}
return result
}
// Single item - return as 1-item slice
return []interface{}{data}
}
// extractNestedRelations extracts nested relations from data, returning cleaned data and relations
// This does NOT process the relations, just separates them for later processing
func (h *Handler) extractNestedRelations(
data map[string]interface{},
model interface{},
) (map[string]interface{}, map[string]interface{}, error) {
// Get model type for reflection
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return data, nil, fmt.Errorf("model must be a struct type, got %v", modelType)
}
// Separate relation fields from regular fields
cleanedData := make(map[string]interface{})
relations := make(map[string]interface{})
for key, value := range data {
// Skip _request field
if key == "_request" {
continue
}
// Check if this field is a relation
relInfo := h.GetRelationshipInfo(modelType, key)
if relInfo != nil {
logger.Debug("Found nested relation field: %s (type: %s)", key, relInfo.RelationType)
relations[key] = value
} else {
cleanedData[key] = value
}
}
return cleanedData, relations, nil
}
// processChildRelationsWithParentID processes nested relations with a parent ID
func (h *Handler) processChildRelationsWithParentID(
ctx context.Context,
processor *common.NestedCUDProcessor,
operation string,
relations map[string]interface{},
parentModel interface{},
parentID interface{},
) error {
// Get model type for reflection
modelType := reflect.TypeOf(parentModel)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return fmt.Errorf("model must be a struct type, got %v", modelType)
}
// Process each relation
for relationName, relationValue := range relations {
if relationValue == nil {
continue
}
// Get relationship info
relInfo := h.GetRelationshipInfo(modelType, relationName)
if relInfo == nil {
logger.Warn("No relationship info found for %s, skipping", relationName)
continue
}
// Process this relation with parent ID
if err := h.processChildRelationsForField(ctx, processor, operation, relationName, relationValue, relInfo, modelType, parentID); err != nil {
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
}
}
return nil
}
// processChildRelationsForField processes a single nested relation field
func (h *Handler) processChildRelationsForField(
ctx context.Context,
processor *common.NestedCUDProcessor,
operation string,
relationName string,
relationValue interface{},
relInfo *common.RelationshipInfo,
parentModelType reflect.Type,
parentID interface{},
) error {
if relationValue == nil {
return nil
}
// Get the related model
field, found := parentModelType.FieldByName(relInfo.FieldName)
if !found {
return fmt.Errorf("field %s not found in model", relInfo.FieldName)
}
// Get the model type for the relation
relatedModelType := field.Type
if relatedModelType.Kind() == reflect.Slice {
relatedModelType = relatedModelType.Elem()
}
if relatedModelType.Kind() == reflect.Ptr {
relatedModelType = relatedModelType.Elem()
}
// Create an instance of the related model
relatedModel := reflect.New(relatedModelType).Elem().Interface()
// Get table name for related model
relatedTableName := h.getTableNameForRelatedModel(relatedModel, relInfo.JSONName)
// Prepare parent IDs for foreign key injection
parentIDs := make(map[string]interface{})
if relInfo.ForeignKey != "" && parentID != nil {
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
parentIDs[baseName] = parentID
}
// Process based on relation type and data structure
switch v := relationValue.(type) {
case map[string]interface{}:
// Single related object
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process single relation: %w", err)
}
case []interface{}:
// Multiple related objects
for i, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok {
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation item %d: %w", i, err)
}
}
}
case []map[string]interface{}:
// Multiple related objects (typed slice)
for i, itemMap := range v {
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation item %d: %w", i, err)
}
}
default:
return fmt.Errorf("unsupported relation data type: %T", relationValue)
}
return nil
}
// getTableNameForRelatedModel gets the table name for a related model
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
if provider, ok := model.(common.TableNameProvider); ok {
tableName := provider.TableName()
if tableName != "" {
return tableName
}
}
return defaultName
}
// qualifyColumnName ensures column name is fully qualified with table name if not already
func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
// Check if column already has a table/schema prefix (contains a dot)
@@ -1431,19 +1641,52 @@ func (h *Handler) isNullable(field reflect.StructField) bool {
}
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
response := common.Response{
Success: true,
Data: data,
Metadata: metadata,
h.sendResponseWithOptions(w, data, metadata, nil)
}
// sendResponseWithOptions sends a response with optional formatting
func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options *ExtendedRequestOptions) {
// Normalize single-record arrays to objects if requested
if options != nil && options.SingleRecordAsObject {
data = h.normalizeResultArray(data)
}
// Return data as-is without wrapping in common.Response
w.WriteHeader(http.StatusOK)
if err := w.WriteJSON(response); err != nil {
if err := w.WriteJSON(data); err != nil {
logger.Error("Failed to write JSON response: %v", err)
}
}
// normalizeResultArray converts a single-element array to an object if requested
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
if data == nil {
return data
}
// Use reflection to check if data is a slice or array
dataValue := reflect.ValueOf(data)
if dataValue.Kind() == reflect.Ptr {
dataValue = dataValue.Elem()
}
// Check if it's a slice or array with exactly one element
if (dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array) && dataValue.Len() == 1 {
// Return the single element
return dataValue.Index(0).Interface()
}
return data
}
// sendFormattedResponse sends response with formatting options
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
// Normalize single-record arrays to objects if requested
if options.SingleRecordAsObject {
data = h.normalizeResultArray(data)
}
// Clean JSON if requested (remove null/empty fields)
if options.CleanJSON {
data = h.cleanJSON(data)
@@ -1498,22 +1741,22 @@ func (h *Handler) cleanJSON(data interface{}) interface{} {
}
func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) {
var details string
var errorMsg string
if err != nil {
details = err.Error()
errorMsg = err.Error()
} else if message != "" {
errorMsg = message
} else {
errorMsg = code
}
response := common.Response{
Success: false,
Error: &common.APIError{
Code: code,
Message: message,
Details: details,
},
response := map[string]interface{}{
"_error": errorMsg,
"_retval": 1,
}
w.WriteHeader(statusCode)
if err := w.WriteJSON(response); err != nil {
logger.Error("Failed to write JSON error response: %v", err)
if jsonErr := w.WriteJSON(response); jsonErr != nil {
logger.Error("Failed to write JSON error response: %v", jsonErr)
}
}
@@ -1531,6 +1774,9 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
if len(options.Sort) > 0 {
sortParts := make([]string, 0, len(options.Sort))
for _, sort := range options.Sort {
if sort.Column == "" {
continue
}
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
@@ -1741,7 +1987,8 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
}
// shouldUseNestedProcessor determines if we should use nested CUD processing
// It checks if the data contains nested relations or a _request field
// It recursively checks if the data contains deeply nested relations or _request fields
// Simple one-level relations without further nesting don't require the nested processor
func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool {
return common.ShouldUseNestedProcessor(data, model, h)
}
@@ -1803,12 +2050,40 @@ func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName strin
// Determine if it's belongsTo or hasMany/hasOne
if field.Type.Kind() == reflect.Slice {
info.relationType = "hasMany"
// Get the element type for slice
elemType := field.Type.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
info.relatedModel = reflect.New(elemType).Elem().Interface()
}
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
info.relationType = "belongsTo"
elemType := field.Type
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
info.relatedModel = reflect.New(elemType).Elem().Interface()
}
}
} else if strings.Contains(gormTag, "many2many") {
info.relationType = "many2many"
info.joinTable = h.extractTagValue(gormTag, "many2many")
// Get the element type for many2many (always slice)
if field.Type.Kind() == reflect.Slice {
elemType := field.Type.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
info.relatedModel = reflect.New(elemType).Elem().Interface()
}
}
} else {
// Field has no GORM relationship tags, so it's not a relation
return nil
}
return info

View File

@@ -0,0 +1,423 @@
package restheadspec
import (
"fmt"
"reflect"
"testing"
)
// Test models for nested CRUD operations
type TestUser struct {
ID int64 `json:"id" bun:"id,pk,autoincrement"`
Name string `json:"name"`
Posts []TestPost `json:"posts" gorm:"foreignKey:UserID"`
}
type TestPost struct {
ID int64 `json:"id" bun:"id,pk,autoincrement"`
UserID int64 `json:"user_id"`
Title string `json:"title"`
Comments []TestComment `json:"comments" gorm:"foreignKey:PostID"`
}
type TestComment struct {
ID int64 `json:"id" bun:"id,pk,autoincrement"`
PostID int64 `json:"post_id"`
Content string `json:"content"`
}
func (TestUser) TableName() string { return "users" }
func (TestPost) TableName() string { return "posts" }
func (TestComment) TableName() string { return "comments" }
// Test extractNestedRelations function
func TestExtractNestedRelations(t *testing.T) {
// Create handler
registry := &mockRegistry{
models: map[string]interface{}{
"users": TestUser{},
"posts": TestPost{},
"comments": TestComment{},
},
}
handler := NewHandler(nil, registry)
tests := []struct {
name string
data map[string]interface{}
model interface{}
expectedCleanCount int
expectedRelCount int
}{
{
name: "User with posts",
data: map[string]interface{}{
"name": "John Doe",
"posts": []map[string]interface{}{
{"title": "Post 1"},
},
},
model: TestUser{},
expectedCleanCount: 1, // name
expectedRelCount: 1, // posts
},
{
name: "Post with comments",
data: map[string]interface{}{
"title": "Test Post",
"comments": []map[string]interface{}{
{"content": "Comment 1"},
{"content": "Comment 2"},
},
},
model: TestPost{},
expectedCleanCount: 1, // title
expectedRelCount: 1, // comments
},
{
name: "User with nested posts and comments",
data: map[string]interface{}{
"name": "Jane Doe",
"posts": []map[string]interface{}{
{
"title": "Post 1",
"comments": []map[string]interface{}{
{"content": "Comment 1"},
},
},
},
},
model: TestUser{},
expectedCleanCount: 1, // name
expectedRelCount: 1, // posts (which contains nested comments)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cleanedData, relations, err := handler.extractNestedRelations(tt.data, tt.model)
if err != nil {
t.Errorf("extractNestedRelations() error = %v", err)
return
}
if len(cleanedData) != tt.expectedCleanCount {
t.Errorf("Expected %d cleaned fields, got %d: %+v", tt.expectedCleanCount, len(cleanedData), cleanedData)
}
if len(relations) != tt.expectedRelCount {
t.Errorf("Expected %d relation fields, got %d: %+v", tt.expectedRelCount, len(relations), relations)
}
t.Logf("Cleaned data: %+v", cleanedData)
t.Logf("Relations: %+v", relations)
})
}
}
// Test shouldUseNestedProcessor function
func TestShouldUseNestedProcessor(t *testing.T) {
registry := &mockRegistry{
models: map[string]interface{}{
"users": TestUser{},
"posts": TestPost{},
},
}
handler := NewHandler(nil, registry)
tests := []struct {
name string
data map[string]interface{}
model interface{}
expected bool
}{
{
name: "Data with simple nested posts (no further nesting)",
data: map[string]interface{}{
"name": "John",
"posts": []map[string]interface{}{
{"title": "Post 1"},
},
},
model: TestUser{},
expected: false, // Simple one-level nesting doesn't require nested processor
},
{
name: "Data with deeply nested relations",
data: map[string]interface{}{
"name": "John",
"posts": []map[string]interface{}{
{
"title": "Post 1",
"comments": []map[string]interface{}{
{"content": "Comment 1"},
},
},
},
},
model: TestUser{},
expected: true, // Multi-level nesting requires nested processor
},
{
name: "Data without nested relations",
data: map[string]interface{}{
"name": "John",
},
model: TestUser{},
expected: false,
},
{
name: "Data with _request field",
data: map[string]interface{}{
"_request": "insert",
"name": "John",
},
model: TestUser{},
expected: true,
},
{
name: "Nested data with _request field",
data: map[string]interface{}{
"name": "John",
"posts": []map[string]interface{}{
{
"_request": "insert",
"title": "Post 1",
},
},
},
model: TestUser{},
expected: true, // _request at nested level requires nested processor
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.shouldUseNestedProcessor(tt.data, tt.model)
if result != tt.expected {
t.Errorf("shouldUseNestedProcessor() = %v, expected %v", result, tt.expected)
}
})
}
}
// Test normalizeToSlice function
func TestNormalizeToSlice(t *testing.T) {
registry := &mockRegistry{}
handler := NewHandler(nil, registry)
tests := []struct {
name string
input interface{}
expected int // expected slice length
}{
{
name: "Single object",
input: map[string]interface{}{"name": "John"},
expected: 1,
},
{
name: "Slice of objects",
input: []map[string]interface{}{
{"name": "John"},
{"name": "Jane"},
},
expected: 2,
},
{
name: "Array of interfaces",
input: []interface{}{
map[string]interface{}{"name": "John"},
map[string]interface{}{"name": "Jane"},
map[string]interface{}{"name": "Bob"},
},
expected: 3,
},
{
name: "Nil input",
input: nil,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.normalizeToSlice(tt.input)
if len(result) != tt.expected {
t.Errorf("normalizeToSlice() returned slice of length %d, expected %d", len(result), tt.expected)
}
})
}
}
// Test GetRelationshipInfo function
func TestGetRelationshipInfo(t *testing.T) {
registry := &mockRegistry{}
handler := NewHandler(nil, registry)
tests := []struct {
name string
modelType reflect.Type
relationName string
expectNil bool
}{
{
name: "User posts relation",
modelType: reflect.TypeOf(TestUser{}),
relationName: "posts",
expectNil: false,
},
{
name: "Post comments relation",
modelType: reflect.TypeOf(TestPost{}),
relationName: "comments",
expectNil: false,
},
{
name: "Non-existent relation",
modelType: reflect.TypeOf(TestUser{}),
relationName: "nonexistent",
expectNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.GetRelationshipInfo(tt.modelType, tt.relationName)
if tt.expectNil && result != nil {
t.Errorf("Expected nil, got %+v", result)
}
if !tt.expectNil && result == nil {
t.Errorf("Expected non-nil relationship info")
}
if result != nil {
t.Logf("Relationship info: FieldName=%s, JSONName=%s, RelationType=%s, ForeignKey=%s",
result.FieldName, result.JSONName, result.RelationType, result.ForeignKey)
}
})
}
}
// Mock registry for testing
type mockRegistry struct {
models map[string]interface{}
}
func (m *mockRegistry) Register(name string, model interface{}) {
m.RegisterModel(name, model)
}
func (m *mockRegistry) RegisterModel(name string, model interface{}) error {
if m.models == nil {
m.models = make(map[string]interface{})
}
m.models[name] = model
return nil
}
func (m *mockRegistry) GetModelByEntity(schema, entity string) (interface{}, error) {
if model, ok := m.models[entity]; ok {
return model, nil
}
return nil, fmt.Errorf("model not found: %s", entity)
}
func (m *mockRegistry) GetModelByName(name string) (interface{}, error) {
if model, ok := m.models[name]; ok {
return model, nil
}
return nil, fmt.Errorf("model not found: %s", name)
}
func (m *mockRegistry) GetModel(name string) (interface{}, error) {
return m.GetModelByName(name)
}
func (m *mockRegistry) HasModel(schema, entity string) bool {
_, ok := m.models[entity]
return ok
}
func (m *mockRegistry) ListModels() []string {
models := make([]string, 0, len(m.models))
for name := range m.models {
models = append(models, name)
}
return models
}
func (m *mockRegistry) GetAllModels() map[string]interface{} {
return m.models
}
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
func TestMultiLevelRelationExtraction(t *testing.T) {
registry := &mockRegistry{
models: map[string]interface{}{
"users": TestUser{},
"posts": TestPost{},
"comments": TestComment{},
},
}
handler := NewHandler(nil, registry)
// Test data with 3 levels: User -> Posts -> Comments
testData := map[string]interface{}{
"name": "John Doe",
"posts": []map[string]interface{}{
{
"title": "First Post",
"comments": []map[string]interface{}{
{"content": "Great post!"},
{"content": "Thanks for sharing!"},
},
},
{
"title": "Second Post",
"comments": []map[string]interface{}{
{"content": "Interesting read"},
},
},
},
}
// Extract relations from user
cleanedData, relations, err := handler.extractNestedRelations(testData, TestUser{})
if err != nil {
t.Fatalf("Failed to extract relations: %v", err)
}
// Verify user data is cleaned
if len(cleanedData) != 1 || cleanedData["name"] != "John Doe" {
t.Errorf("Expected cleaned data to contain only name, got: %+v", cleanedData)
}
// Verify posts relation was extracted
if len(relations) != 1 {
t.Errorf("Expected 1 relation (posts), got %d", len(relations))
}
posts, ok := relations["posts"]
if !ok {
t.Fatal("Expected posts relation to be extracted")
}
// Verify posts is a slice with 2 items
postsSlice, ok := posts.([]map[string]interface{})
if !ok {
t.Fatalf("Expected posts to be []map[string]interface{}, got %T", posts)
}
if len(postsSlice) != 2 {
t.Errorf("Expected 2 posts, got %d", len(postsSlice))
}
// Verify first post has comments
if _, hasComments := postsSlice[0]["comments"]; !hasComments {
t.Error("Expected first post to have comments")
}
t.Logf("Successfully extracted multi-level nested relations")
t.Logf("Cleaned data: %+v", cleanedData)
t.Logf("Relations: %d posts with nested comments", len(postsSlice))
}

View File

@@ -2,6 +2,7 @@ package restheadspec
import (
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
"strconv"
@@ -37,8 +38,14 @@ type ExtendedRequestOptions struct {
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
// Single record normalization - convert single-element arrays to objects
SingleRecordAsObject bool
// Transaction
AtomicTransaction bool
// X-Files configuration - comprehensive query options as a single JSON object
XFiles *XFiles
}
// ExpandOption represents a relation expansion configuration
@@ -99,10 +106,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
Sort: make([]common.SortOption, 0),
Preload: make([]common.PreloadOption, 0),
},
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
ResponseFormat: "simple", // Default response format
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
ResponseFormat: "simple", // Default response format
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
}
// Get all headers
@@ -199,10 +207,21 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
options.ResponseFormat = "detail"
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
options.ResponseFormat = "syncfusion"
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
// Parse as boolean - "false" disables, "true" enables (default is true)
if strings.EqualFold(decodedValue, "false") {
options.SingleRecordAsObject = false
} else if strings.EqualFold(decodedValue, "true") {
options.SingleRecordAsObject = true
}
// Transaction Control
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
// X-Files - comprehensive JSON configuration
case strings.HasPrefix(normalizedKey, "x-files"):
h.parseXFiles(&options, decodedValue)
}
}
@@ -469,12 +488,285 @@ func (h *Handler) parseCommaSeparated(value string) []string {
return result
}
// parseXFiles parses x-files header containing comprehensive JSON configuration
// and populates ExtendedRequestOptions fields from it
func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
if value == "" {
return
}
var xfiles XFiles
if err := json.Unmarshal([]byte(value), &xfiles); err != nil {
logger.Warn("Failed to parse x-files header: %v", err)
return
}
logger.Debug("Parsed x-files configuration for table: %s", xfiles.TableName)
// Store the original XFiles for reference
options.XFiles = &xfiles
// Map XFiles fields to ExtendedRequestOptions
// Column selection
if len(xfiles.Columns) > 0 {
options.Columns = append(options.Columns, xfiles.Columns...)
logger.Debug("X-Files: Added columns: %v", xfiles.Columns)
}
// Omit columns
if len(xfiles.OmitColumns) > 0 {
options.OmitColumns = append(options.OmitColumns, xfiles.OmitColumns...)
logger.Debug("X-Files: Added omit columns: %v", xfiles.OmitColumns)
}
// Computed columns (CQL) -> ComputedQL
if len(xfiles.CQLColumns) > 0 {
if options.ComputedQL == nil {
options.ComputedQL = make(map[string]string)
}
for i, cqlExpr := range xfiles.CQLColumns {
colName := fmt.Sprintf("cql%d", i+1)
options.ComputedQL[colName] = cqlExpr
logger.Debug("X-Files: Added computed column %s: %s", colName, cqlExpr)
}
}
// Sorting
if len(xfiles.Sort) > 0 {
for _, sortField := range xfiles.Sort {
direction := "ASC"
colName := sortField
// Handle direction prefixes
if strings.HasPrefix(sortField, "-") {
direction = "DESC"
colName = strings.TrimPrefix(sortField, "-")
} else if strings.HasPrefix(sortField, "+") {
colName = strings.TrimPrefix(sortField, "+")
}
// Handle DESC suffix
if strings.HasSuffix(strings.ToLower(colName), " desc") {
direction = "DESC"
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
}
options.Sort = append(options.Sort, common.SortOption{
Column: strings.TrimSpace(colName),
Direction: direction,
})
}
logger.Debug("X-Files: Added %d sort options", len(xfiles.Sort))
}
// Filter fields
if len(xfiles.FilterFields) > 0 {
for _, filterField := range xfiles.FilterFields {
options.Filters = append(options.Filters, common.FilterOption{
Column: filterField.Field,
Operator: filterField.Operator,
Value: filterField.Value,
LogicOperator: "AND", // Default to AND
})
}
logger.Debug("X-Files: Added %d filter fields", len(xfiles.FilterFields))
}
// SQL AND conditions -> CustomSQLWhere
if len(xfiles.SqlAnd) > 0 {
if options.CustomSQLWhere != "" {
options.CustomSQLWhere += " AND "
}
options.CustomSQLWhere += "(" + strings.Join(xfiles.SqlAnd, " AND ") + ")"
logger.Debug("X-Files: Added SQL AND conditions")
}
// SQL OR conditions -> CustomSQLOr
if len(xfiles.SqlOr) > 0 {
if options.CustomSQLOr != "" {
options.CustomSQLOr += " OR "
}
options.CustomSQLOr += "(" + strings.Join(xfiles.SqlOr, " OR ") + ")"
logger.Debug("X-Files: Added SQL OR conditions")
}
// Pagination - Limit
if limitStr := xfiles.Limit.String(); limitStr != "" && limitStr != "0" {
if limitVal, err := xfiles.Limit.Int64(); err == nil && limitVal > 0 {
limit := int(limitVal)
options.Limit = &limit
logger.Debug("X-Files: Set limit: %d", limit)
}
}
// Pagination - Offset
if offsetStr := xfiles.Offset.String(); offsetStr != "" && offsetStr != "0" {
if offsetVal, err := xfiles.Offset.Int64(); err == nil && offsetVal > 0 {
offset := int(offsetVal)
options.Offset = &offset
logger.Debug("X-Files: Set offset: %d", offset)
}
}
// Cursor pagination
if xfiles.CursorForward != "" {
options.CursorForward = xfiles.CursorForward
logger.Debug("X-Files: Set cursor forward")
}
if xfiles.CursorBackward != "" {
options.CursorBackward = xfiles.CursorBackward
logger.Debug("X-Files: Set cursor backward")
}
// Flags
if xfiles.Skipcount {
options.SkipCount = true
logger.Debug("X-Files: Set skip count")
}
// Process ParentTables and ChildTables recursively
h.processXFilesRelations(&xfiles, options, "")
}
// processXFilesRelations processes ParentTables and ChildTables from XFiles
// and adds them as Preload options recursively
func (h *Handler) processXFilesRelations(xfiles *XFiles, options *ExtendedRequestOptions, basePath string) {
if xfiles == nil {
return
}
// Process ParentTables
if len(xfiles.ParentTables) > 0 {
logger.Debug("X-Files: Processing %d parent tables", len(xfiles.ParentTables))
for _, parentTable := range xfiles.ParentTables {
h.addXFilesPreload(parentTable, options, basePath)
}
}
// Process ChildTables
if len(xfiles.ChildTables) > 0 {
logger.Debug("X-Files: Processing %d child tables", len(xfiles.ChildTables))
for _, childTable := range xfiles.ChildTables {
h.addXFilesPreload(childTable, options, basePath)
}
}
}
// addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
if xfile == nil || xfile.TableName == "" {
return
}
// Determine the relation path
relationPath := xfile.TableName
if basePath != "" {
relationPath = basePath + "." + xfile.TableName
}
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
// Create PreloadOption from XFiles configuration
preloadOpt := common.PreloadOption{
Relation: relationPath,
Columns: xfile.Columns,
OmitColumns: xfile.OmitColumns,
}
// Add sorting if specified
if len(xfile.Sort) > 0 {
preloadOpt.Sort = make([]common.SortOption, 0, len(xfile.Sort))
for _, sortField := range xfile.Sort {
direction := "ASC"
colName := sortField
// Handle direction prefixes
if strings.HasPrefix(sortField, "-") {
direction = "DESC"
colName = strings.TrimPrefix(sortField, "-")
} else if strings.HasPrefix(sortField, "+") {
colName = strings.TrimPrefix(sortField, "+")
}
preloadOpt.Sort = append(preloadOpt.Sort, common.SortOption{
Column: strings.TrimSpace(colName),
Direction: direction,
})
}
}
// Add filters if specified
if len(xfile.FilterFields) > 0 {
preloadOpt.Filters = make([]common.FilterOption, 0, len(xfile.FilterFields))
for _, filterField := range xfile.FilterFields {
preloadOpt.Filters = append(preloadOpt.Filters, common.FilterOption{
Column: filterField.Field,
Operator: filterField.Operator,
Value: filterField.Value,
LogicOperator: "AND",
})
}
}
// Add WHERE clause if SQL conditions specified
whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 {
whereConditions = append(whereConditions, xfile.SqlAnd...)
}
if len(whereConditions) > 0 {
preloadOpt.Where = strings.Join(whereConditions, " AND ")
}
// Add limit if specified
if limitStr := xfile.Limit.String(); limitStr != "" && limitStr != "0" {
if limitVal, err := xfile.Limit.Int64(); err == nil && limitVal > 0 {
limit := int(limitVal)
preloadOpt.Limit = &limit
}
}
// Add the preload option
options.Preload = append(options.Preload, preloadOpt)
// Recursively process nested ParentTables and ChildTables
if xfile.Recursive {
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
h.processXFilesRelations(xfile, options, relationPath)
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
h.processXFilesRelations(xfile, options, relationPath)
}
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := extractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
@@ -495,19 +787,19 @@ func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) refl
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == colName {
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, colName) {
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := toSnakeCase(field.Name)
if snakeCaseName == colName {
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}

View File

@@ -106,7 +106,7 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars)
}).Methods("GET", "PUT", "PATCH", "DELETE")
}).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
// GET for metadata (using HandleGet)
muxRouter.HandleFunc("/{schema}/{entity}/metadata", func(w http.ResponseWriter, r *http.Request) {
@@ -189,6 +189,18 @@ func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *H
return nil
})
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewBunRouterRequest(req)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
r.Handle("PUT", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),

View File

@@ -10,7 +10,7 @@ import (
type TestModel struct {
ID int64 `json:"id" bun:"id,pk"`
Name string `json:"name" bun:"name"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
}
func TestSetRowNumbersOnRecords(t *testing.T) {

431
pkg/restheadspec/xfiles.go Normal file
View File

@@ -0,0 +1,431 @@
package restheadspec
import (
"encoding/json"
"reflect"
)
type XFiles struct {
TableName string `json:"tablename"`
Schema string `json:"schema"`
PrimaryKey string `json:"primarykey"`
ForeignKey string `json:"foreignkey"`
RelatedKey string `json:"relatedkey"`
Sort []string `json:"sort"`
Prefix string `json:"prefix"`
Editable bool `json:"editable"`
Recursive bool `json:"recursive"`
Expand bool `json:"expand"`
Rownumber bool `json:"rownumber"`
Skipcount bool `json:"skipcount"`
Offset json.Number `json:"offset"`
Limit json.Number `json:"limit"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
CQLColumns []string `json:"cql_columns"`
SqlJoins []string `json:"sql_joins"`
SqlOr []string `json:"sql_or"`
SqlAnd []string `json:"sql_and"`
ParentTables []*XFiles `json:"parenttables"`
ChildTables []*XFiles `json:"childtables"`
ModelType reflect.Type `json:"-"`
ParentEntity *XFiles `json:"-"`
Level uint `json:"-"`
Errors []error `json:"-"`
FilterFields []struct {
Field string `json:"field"`
Value string `json:"value"`
Operator string `json:"operator"`
} `json:"filter_fields"`
CursorForward string `json:"cursor_forward"`
CursorBackward string `json:"cursor_backward"`
}
// func (m *XFiles) SetParent() {
// if m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity != nil {
// continue
// }
// child.ParentEntity = m
// child.Level = m.Level + 1000
// child.SetParent()
// }
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity != nil {
// continue
// }
// pt.ParentEntity = m
// pt.Level = m.Level + 1
// pt.SetParent()
// }
// }
// }
// func (m *XFiles) GetParentRelations() []reflection.GormRelationType {
// if m.ParentEntity == nil {
// return nil
// }
// foundRelations := make(GormRelationTypeList, 0)
// rels := reflection.GetValidModelRelationTypes(m.ParentEntity.ModelType, false)
// if m.ParentEntity.ModelType == nil {
// return nil
// }
// for _, rel := range rels {
// // if len(foundRelations) > 0 {
// // break
// // }
// if rel.FieldName != "" && rel.AssociationTable.Name() == m.ModelType.Name() {
// if rel.AssociationKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.AssociationKey, m.RelatedKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.AssociationKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.AssociationKey, m.ForeignKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.ForeignKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.ForeignKey, m.ForeignKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.ForeignKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.ForeignKey, m.RelatedKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.ForeignKey != "" && m.ForeignKey == "" && m.RelatedKey == "" {
// foundRelations = append(foundRelations, rel)
// }
// }
// //idName := fmt.Sprintf("%s_to_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
// }
// sort.Sort(foundRelations)
// finalList := make(GormRelationTypeList, 0)
// dups := make(map[string]bool)
// for _, rel := range foundRelations {
// idName := fmt.Sprintf("%s_to_%s_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.FieldName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
// if dups[idName] {
// continue
// }
// finalList = append(finalList, rel)
// dups[idName] = true
// }
// //fmt.Printf("GetParentRelations %s: %+v %d=%d\n", m.TableName, dups, len(finalList), len(foundRelations))
// return finalList
// }
// func (m *XFiles) GetUpdatableTableNames() []string {
// foundTables := make([]string, 0)
// if m.Editable {
// foundTables = append(foundTables, m.TableName)
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// list := pt.GetUpdatableTableNames()
// if list != nil {
// foundTables = append(foundTables, list...)
// }
// }
// }
// if m.ChildTables != nil {
// for _, ct := range m.ChildTables {
// list := ct.GetUpdatableTableNames()
// if list != nil {
// foundTables = append(foundTables, list...)
// }
// }
// }
// return foundTables
// }
// func (m *XFiles) preload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
// path := pPath
// _, colval := JSONSyntaxToSQLIn(path, m.ModelType, "preload")
// if colval != "" {
// path = colval
// }
// if path == "" {
// return db, fmt.Errorf("invalid preload path %s", path)
// }
// sortList := ""
// if m.Sort != nil {
// for _, sort := range m.Sort {
// descSort := false
// if strings.HasPrefix(sort, "-") || strings.Contains(strings.ToLower(sort), " desc") {
// descSort = true
// }
// sort = strings.TrimPrefix(strings.TrimPrefix(sort, "+"), "-")
// sort = strings.ReplaceAll(strings.ReplaceAll(sort, " desc", ""), " asc", "")
// if descSort {
// sort = sort + " desc"
// }
// sortList = sort
// }
// }
// SrcColumns := reflection.GetModelSQLColumns(m.ModelType)
// Columns := make([]string, 0)
// for _, s := range SrcColumns {
// for _, v := range m.Columns {
// if strings.EqualFold(v, s) {
// Columns = append(Columns, v)
// break
// }
// }
// }
// if len(Columns) == 0 {
// Columns = SrcColumns
// }
// chain := db
// // //Do expand where we can
// // if m.Expand {
// // ops := func(subchain *gorm.DB) *gorm.DB {
// // subchain = subchain.Select(strings.Join(m.Columns, ","))
// // if m.Filter != "" {
// // subchain = subchain.Where(m.Filter)
// // }
// // return subchain
// // }
// // chain = chain.Joins(path, ops(chain))
// // }
// //fmt.Printf("Preloading %s: %s lvl:%d \n", m.TableName, path, m.Level)
// //Do preload
// chain = chain.Preload(path, func(db *gorm.DB) *gorm.DB {
// subchain := db
// if sortList != "" {
// subchain = subchain.Order(sortList)
// }
// for _, sql := range m.SqlAnd {
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
// if fnType == 0 {
// colval = ValidSQL(colval, "select")
// }
// subchain = subchain.Where(colval)
// }
// for _, sql := range m.SqlOr {
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
// if fnType == 0 {
// colval = ValidSQL(colval, "select")
// }
// subchain = subchain.Or(colval)
// }
// limitval, err := m.Limit.Int64()
// if err == nil && limitval > 0 {
// subchain = subchain.Limit(int(limitval))
// }
// for _, j := range m.SqlJoins {
// subchain = subchain.Joins(ValidSQL(j, "select"))
// }
// offsetval, err := m.Offset.Int64()
// if err == nil && offsetval > 0 {
// subchain = subchain.Offset(int(offsetval))
// }
// cols := make([]string, 0)
// for _, col := range Columns {
// canAdd := true
// for _, omit := range m.OmitColumns {
// if col == omit {
// canAdd = false
// break
// }
// }
// if canAdd {
// cols = append(cols, col)
// }
// }
// for i, col := range m.CQLColumns {
// cols = append(cols, fmt.Sprintf("(%s) as cql%d", col, i+1))
// }
// if len(cols) > 0 {
// colStr := strings.Join(cols, ",")
// subchain = subchain.Select(colStr)
// }
// if m.Recursive && pCnt < 5 {
// paths := strings.Split(path, ".")
// p := paths[0]
// if len(paths) > 1 {
// p = strings.Join(paths[1:], ".")
// }
// for i := uint(0); i < 3; i++ {
// inlineStr := strings.Repeat(p+".", int(i+1))
// inlineStr = strings.TrimRight(inlineStr, ".")
// fmt.Printf("Preloading Recursive (%d) %s: %s lvl:%d \n", i, m.TableName, inlineStr, m.Level)
// subchain, err = m.preload(subchain, inlineStr, pCnt+i)
// if err != nil {
// cfg.LogError("Preload (%s,%d) error: %v", m.TableName, pCnt, err)
// } else {
// if m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity == nil {
// continue
// }
// subchain, _ = child.ChainPreload(subchain, inlineStr, pCnt+i)
// }
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity == nil {
// continue
// }
// subchain, _ = pt.ChainPreload(subchain, inlineStr, pCnt+i)
// }
// }
// }
// }
// }
// return subchain
// })
// return chain, nil
// }
// func (m *XFiles) ChainPreload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
// var err error
// chain := db
// relations := m.GetParentRelations()
// if pCnt > 10000 {
// cfg.LogError("Preload Max size (%s,%s): %v", m.TableName, pPath, err)
// return chain, nil
// }
// hasPreloadError := false
// for _, rel := range relations {
// path := rel.FieldName
// if pPath != "" {
// path = fmt.Sprintf("%s.%s", pPath, rel.FieldName)
// }
// chain, err = m.preload(chain, path, pCnt)
// if err != nil {
// cfg.LogError("Preload Error (%s,%s): %v", m.TableName, path, err)
// hasPreloadError = true
// //return chain, err
// }
// //fmt.Printf("Preloading Rel %v: %s @ %s lvl:%d \n", m.Recursive, path, m.TableName, m.Level)
// if !hasPreloadError && m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity == nil {
// continue
// }
// chain, err = child.ChainPreload(chain, path, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// if !hasPreloadError && m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity == nil {
// continue
// }
// chain, err = pt.ChainPreload(chain, path, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// }
// if len(relations) == 0 {
// if m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity == nil {
// continue
// }
// chain, err = child.ChainPreload(chain, pPath, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity == nil {
// continue
// }
// chain, err = pt.ChainPreload(chain, pPath, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// }
// return chain, nil
// }
// func (m *XFiles) Fill() {
// m.ModelType = models.GetModelType(m.Schema, m.TableName)
// if m.ModelType == nil {
// m.Errors = append(m.Errors, fmt.Errorf("ModelType not found for %s", m.TableName))
// }
// if m.Prefix == "" {
// m.Prefix = reflection.GetTablePrefixFromType(m.ModelType)
// }
// if m.PrimaryKey == "" {
// m.PrimaryKey = reflection.GetPKNameFromType(m.ModelType)
// }
// if m.Schema == "" {
// m.Schema = reflection.GetSchemaNameFromType(m.ModelType)
// }
// for _, t := range m.ParentTables {
// t.Fill()
// }
// for _, t := range m.ChildTables {
// t.Fill()
// }
// }
// type GormRelationTypeList []reflection.GormRelationType
// func (s GormRelationTypeList) Len() int { return len(s) }
// func (s GormRelationTypeList) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// func (s GormRelationTypeList) Less(i, j int) bool {
// if strings.HasPrefix(strings.ToLower(s[j].FieldName),
// strings.ToLower(fmt.Sprintf("%s_%s_%s", s[i].AssociationSchema, s[i].AssociationTable, s[i].AssociationKey))) {
// return true
// }
// return s[i].FieldName < s[j].FieldName
// }

View File

@@ -0,0 +1,213 @@
# X-Files Header Usage
The `x-files` header allows you to configure complex query options using a single JSON object. The XFiles configuration is parsed and populates the `ExtendedRequestOptions` fields, which means it integrates seamlessly with the existing query building system.
## Architecture
When an `x-files` header is received:
1. It's parsed into an `XFiles` struct
2. The `XFiles` fields populate the `ExtendedRequestOptions` (columns, filters, sort, preload, etc.)
3. The normal query building process applies these options to the SQL query
4. This allows x-files to work alongside individual headers if needed
## Basic Example
```http
GET /public/users
X-Files: {"tablename":"users","columns":["id","name","email"],"limit":"10","offset":"0"}
```
## Complete Example
```http
GET /public/users
X-Files: {
"tablename": "users",
"schema": "public",
"columns": ["id", "name", "email", "created_at"],
"omit_columns": [],
"sort": ["-created_at", "name"],
"limit": "50",
"offset": "0",
"filter_fields": [
{
"field": "status",
"operator": "eq",
"value": "active"
},
{
"field": "age",
"operator": "gt",
"value": "18"
}
],
"sql_and": ["deleted_at IS NULL"],
"sql_or": [],
"cql_columns": ["UPPER(name)"],
"skipcount": false,
"distinct": false
}
```
## Supported Filter Operators
- `eq` - equals
- `neq` - not equals
- `gt` - greater than
- `gte` - greater than or equals
- `lt` - less than
- `lte` - less than or equals
- `like` - SQL LIKE
- `ilike` - case-insensitive LIKE
- `in` - IN clause
- `between` - between (exclusive)
- `between_inclusive` - between (inclusive)
- `is_null` - is NULL
- `is_not_null` - is NOT NULL
## Sorting
Sort fields can be prefixed with:
- `+` for ascending (default)
- `-` for descending
Examples:
- `"sort": ["name"]` - ascending by name
- `"sort": ["-created_at"]` - descending by created_at
- `"sort": ["-created_at", "name"]` - multiple sorts
## Computed Columns (CQL)
Use `cql_columns` to add computed SQL expressions:
```json
{
"cql_columns": [
"UPPER(name)",
"CONCAT(first_name, ' ', last_name)"
]
}
```
These will be available as `cql1`, `cql2`, etc. in the response.
## Cursor Pagination
```json
{
"cursor_forward": "eyJpZCI6MTAwfQ==",
"cursor_backward": ""
}
```
## Base64 Encoding
For complex JSON, you can base64-encode the value and prefix it with `ZIP_` or `__`:
```http
GET /public/users
X-Files: ZIP_eyJ0YWJsZW5hbWUiOiJ1c2VycyIsImxpbWl0IjoiMTAifQ==
```
## XFiles Struct Reference
```go
type XFiles struct {
TableName string `json:"tablename"`
Schema string `json:"schema"`
PrimaryKey string `json:"primarykey"`
ForeignKey string `json:"foreignkey"`
RelatedKey string `json:"relatedkey"`
Sort []string `json:"sort"`
Prefix string `json:"prefix"`
Editable bool `json:"editable"`
Recursive bool `json:"recursive"`
Expand bool `json:"expand"`
Rownumber bool `json:"rownumber"`
Skipcount bool `json:"skipcount"`
Offset json.Number `json:"offset"`
Limit json.Number `json:"limit"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
CQLColumns []string `json:"cql_columns"`
SqlJoins []string `json:"sql_joins"`
SqlOr []string `json:"sql_or"`
SqlAnd []string `json:"sql_and"`
FilterFields []struct {
Field string `json:"field"`
Value string `json:"value"`
Operator string `json:"operator"`
} `json:"filter_fields"`
CursorForward string `json:"cursor_forward"`
CursorBackward string `json:"cursor_backward"`
}
```
## Recursive Preloading with ParentTables and ChildTables
XFiles now supports recursive preloading of related entities:
```json
{
"tablename": "users",
"columns": ["id", "name"],
"limit": "10",
"parenttables": [
{
"tablename": "Company",
"columns": ["id", "name", "industry"],
"sort": ["-created_at"]
}
],
"childtables": [
{
"tablename": "Orders",
"columns": ["id", "total", "status"],
"limit": "5",
"sort": ["-order_date"],
"filter_fields": [
{"field": "status", "operator": "eq", "value": "completed"}
],
"childtables": [
{
"tablename": "OrderItems",
"columns": ["id", "product_name", "quantity"],
"recursive": true
}
]
}
]
}
```
### How Recursive Preloading Works
- **ParentTables**: Preloads parent relationships (e.g., User -> Company)
- **ChildTables**: Preloads child relationships (e.g., User -> Orders -> OrderItems)
- **Recursive**: When `true`, continues preloading the same relation recursively
- Each nested table can have its own:
- Column selection (`columns`, `omit_columns`)
- Filtering (`filter_fields`, `sql_and`)
- Sorting (`sort`)
- Pagination (`limit`)
- Further nesting (`parenttables`, `childtables`)
### Relation Path Building
Relations are built as dot-separated paths:
- `Company` (direct parent)
- `Orders` (direct child)
- `Orders.OrderItems` (nested child)
- `Orders.OrderItems.Product` (deeply nested)
## Notes
- Individual headers (like `x-select-fields`, `x-sort`, etc.) can still be used alongside `x-files`
- X-Files populates `ExtendedRequestOptions` which is then processed by the normal query building logic
- ParentTables and ChildTables are converted to `PreloadOption` entries with full support for:
- Column selection
- Filtering
- Sorting
- Limit
- Recursive nesting
- The relation name in ParentTables/ChildTables should match the GORM/Bun relation field name on the model

View File

@@ -372,7 +372,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Create department should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Create department should succeed")
} else {
// Unwrapped format - verify we got the created data back
assert.NotEmpty(t, result, "Create department should return data")
assert.Equal(t, deptID, result["id"], "Created department should have correct ID")
}
logger.Info("Department created successfully: %s", deptID)
})
@@ -393,7 +400,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Create employee should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Create employee should succeed")
} else {
// Unwrapped format - verify we got the created data back
assert.NotEmpty(t, result, "Create employee should return data")
assert.Equal(t, empID, result["id"], "Created employee should have correct ID")
}
logger.Info("Employee created successfully: %s", empID)
})
@@ -402,25 +416,41 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
resp := makeRestHeadSpecRequest(t, serverURL, fmt.Sprintf("/restheadspec/departments/%s", deptID), "GET", nil, nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// RestHeadSpec may return data directly as array or wrapped in response object
// RestHeadSpec may return data directly as array/object or wrapped in response object
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err, "Failed to read response body")
// Try to decode as array first (simple format)
// Try to decode as array first (simple format - multiple records or disabled SingleRecordAsObject)
var dataArray []interface{}
if err := json.Unmarshal(body, &dataArray); err == nil {
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find department")
logger.Info("Department read successfully (simple format): %s", deptID)
logger.Info("Department read successfully (simple format - array): %s", deptID)
return
}
// Try to decode as standard response object (detail format)
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err == nil {
if success, ok := result["success"]; ok && success != nil && success.(bool) {
if data, ok := result["data"].([]interface{}); ok {
// Try to decode as a single object first (simple format with SingleRecordAsObject enabled)
var singleObj map[string]interface{}
if err := json.Unmarshal(body, &singleObj); err == nil {
// Check if it's a data object (not a response wrapper)
if _, hasSuccess := singleObj["success"]; !hasSuccess {
// This is a direct data object (simple format, single record)
assert.NotEmpty(t, singleObj, "Should find department")
logger.Info("Department read successfully (simple format - single object): %s", deptID)
return
}
// Otherwise it's a standard response object (detail format)
if success, ok := singleObj["success"]; ok && success != nil && success.(bool) {
// Check if data is an array
if data, ok := singleObj["data"].([]interface{}); ok {
assert.GreaterOrEqual(t, len(data), 1, "Should find department")
logger.Info("Department read successfully (detail format): %s", deptID)
logger.Info("Department read successfully (detail format - array): %s", deptID)
return
}
// Check if data is a single object (SingleRecordAsObject feature in detail format)
if data, ok := singleObj["data"].(map[string]interface{}); ok {
assert.NotEmpty(t, data, "Should find department")
logger.Info("Department read successfully (detail format - single object): %s", deptID)
return
}
}
@@ -446,25 +476,41 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
resp := makeRestHeadSpecRequest(t, serverURL, "/restheadspec/employees", "GET", nil, headers)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// RestHeadSpec may return data directly as array or wrapped in response object
// RestHeadSpec may return data directly as array/object or wrapped in response object
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err, "Failed to read response body")
// Try array format first
// Try array format first (multiple records or disabled SingleRecordAsObject)
var dataArray []interface{}
if err := json.Unmarshal(body, &dataArray); err == nil {
assert.GreaterOrEqual(t, len(dataArray), 1, "Should find at least one employee")
logger.Info("Employees read with filter successfully (simple format), found: %d", len(dataArray))
logger.Info("Employees read with filter successfully (simple format - array), found: %d", len(dataArray))
return
}
// Try standard response format
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err == nil {
if success, ok := result["success"]; ok && success != nil && success.(bool) {
if data, ok := result["data"].([]interface{}); ok {
// Try to decode as a single object (simple format with SingleRecordAsObject enabled)
var singleObj map[string]interface{}
if err := json.Unmarshal(body, &singleObj); err == nil {
// Check if it's a data object (not a response wrapper)
if _, hasSuccess := singleObj["success"]; !hasSuccess {
// This is a direct data object (simple format, single record)
assert.NotEmpty(t, singleObj, "Should find at least one employee")
logger.Info("Employees read with filter successfully (simple format - single object), found: 1")
return
}
// Otherwise it's a standard response object (detail format)
if success, ok := singleObj["success"]; ok && success != nil && success.(bool) {
// Check if data is an array
if data, ok := singleObj["data"].([]interface{}); ok {
assert.GreaterOrEqual(t, len(data), 1, "Should find at least one employee")
logger.Info("Employees read with filter successfully (detail format), found: %d", len(data))
logger.Info("Employees read with filter successfully (detail format - array), found: %d", len(data))
return
}
// Check if data is a single object (SingleRecordAsObject feature in detail format)
if data, ok := singleObj["data"].(map[string]interface{}); ok {
assert.NotEmpty(t, data, "Should find at least one employee")
logger.Info("Employees read with filter successfully (detail format - single object), found: 1")
return
}
}
@@ -508,7 +554,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Update department should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Update department should succeed")
} else {
// Unwrapped format - verify we got the updated data back
assert.NotEmpty(t, result, "Update department should return data")
}
logger.Info("Department updated successfully: %s", deptID)
// Verify update by reading the department again
@@ -526,7 +578,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Update employee should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Update employee should succeed")
} else {
// Unwrapped format - verify we got the updated data back
assert.NotEmpty(t, result, "Update employee should return data")
}
logger.Info("Employee updated successfully: %s", empID)
})
@@ -537,7 +595,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Delete employee should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Delete employee should succeed")
} else {
// Unwrapped format - verify we got a response (typically {"deleted": count})
assert.NotEmpty(t, result, "Delete employee should return data")
}
logger.Info("Employee deleted successfully: %s", empID)
// Verify deletion - just log that delete succeeded
@@ -550,7 +614,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Delete department should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Delete department should succeed")
} else {
// Unwrapped format - verify we got a response (typically {"deleted": count})
assert.NotEmpty(t, result, "Delete department should return data")
}
logger.Info("Department deleted successfully: %s", deptID)
})