mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e76977dcc | ||
|
|
7853a3f56a | ||
|
|
c2e0c36c79 | ||
|
|
59bd709460 | ||
|
|
05962035b6 | ||
|
|
1cd04b7083 | ||
|
|
0d4909054c | ||
|
|
745564f2e7 | ||
|
|
311e50bfdd | ||
|
|
c95bc9e633 | ||
|
|
07b09e2025 | ||
|
|
3d5334002d | ||
|
|
640582d508 | ||
|
|
b0b3ae662b | ||
|
|
c9b9f75b06 | ||
|
|
af3260864d | ||
|
|
ca6d2deff6 | ||
|
|
1481443516 | ||
|
|
cb54ec5e27 | ||
|
|
7d6a9025f5 | ||
|
|
35089f511f | ||
|
|
66b6a0d835 |
@@ -86,7 +86,6 @@
|
|||||||
"emptyFallthrough",
|
"emptyFallthrough",
|
||||||
"equalFold",
|
"equalFold",
|
||||||
"flagName",
|
"flagName",
|
||||||
"ifElseChain",
|
|
||||||
"indexAlloc",
|
"indexAlloc",
|
||||||
"initClause",
|
"initClause",
|
||||||
"methodExprCall",
|
"methodExprCall",
|
||||||
@@ -106,6 +105,9 @@
|
|||||||
"unnecessaryBlock",
|
"unnecessaryBlock",
|
||||||
"weakCond",
|
"weakCond",
|
||||||
"yodaStyleExpr"
|
"yodaStyleExpr"
|
||||||
|
],
|
||||||
|
"disabled-checks": [
|
||||||
|
"ifElseChain"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"revive": {
|
"revive": {
|
||||||
|
|||||||
19
go.mod
19
go.mod
@@ -8,7 +8,12 @@ require (
|
|||||||
github.com/glebarez/sqlite v1.11.0
|
github.com/glebarez/sqlite v1.11.0
|
||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
github.com/stretchr/testify v1.8.1
|
github.com/stretchr/testify v1.8.1
|
||||||
|
github.com/tidwall/gjson v1.18.0
|
||||||
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/uptrace/bun v1.2.15
|
github.com/uptrace/bun v1.2.15
|
||||||
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
|
||||||
|
github.com/uptrace/bun/driver/sqliteshim v1.2.15
|
||||||
|
github.com/uptrace/bunrouter v1.0.23
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.25.12
|
||||||
)
|
)
|
||||||
@@ -21,23 +26,23 @@ require (
|
|||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||||
|
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/tidwall/gjson v1.18.0 // indirect
|
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/tidwall/sjson v1.2.5 // indirect
|
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||||
github.com/uptrace/bunrouter v1.0.23 // indirect
|
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
go.uber.org/multierr v1.10.0 // indirect
|
go.uber.org/multierr v1.10.0 // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
|
||||||
golang.org/x/sys v0.34.0 // indirect
|
golang.org/x/sys v0.34.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.21.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
modernc.org/libc v1.22.5 // indirect
|
modernc.org/libc v1.66.3 // indirect
|
||||||
modernc.org/mathutil v1.5.0 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.5.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
modernc.org/sqlite v1.23.1 // indirect
|
modernc.org/sqlite v1.38.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
55
go.sum
55
go.sum
@@ -7,8 +7,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
|
|||||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||||
@@ -21,13 +21,16 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
|||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||||
|
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
@@ -50,6 +53,10 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm
|
|||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||||
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
||||||
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
|
||||||
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
|
||||||
|
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
|
||||||
|
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
|
||||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||||
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||||
@@ -62,11 +69,19 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
|||||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
|
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
||||||
|
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||||
|
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||||
|
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||||
|
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||||
|
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||||
|
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||||
|
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
@@ -75,11 +90,29 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
|
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
||||||
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
|
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||||
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||||
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
|
||||||
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
|
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
||||||
|
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
||||||
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
|
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||||
|
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||||
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
@@ -43,12 +44,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
|||||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
return &BunDeleteQuery{query: b.db.NewDelete()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result, err := b.db.ExecContext(ctx, query, args...)
|
result, err := b.db.ExecContext(ctx, query, args...)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +84,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||||
// Create adapter with transaction
|
// Create adapter with transaction
|
||||||
adapter := &BunTxAdapter{tx: tx}
|
adapter := &BunTxAdapter{tx: tx}
|
||||||
@@ -219,6 +235,14 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
|
|
||||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
if len(apply) == 0 {
|
if len(apply) == 0 {
|
||||||
return sq
|
return sq
|
||||||
}
|
}
|
||||||
@@ -276,15 +300,38 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if dest == nil {
|
||||||
|
return fmt.Errorf("destination cannot be nil")
|
||||||
|
}
|
||||||
return b.query.Scan(ctx, dest)
|
return b.query.Scan(ctx, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) error {
|
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if b.query.GetModel() == nil {
|
||||||
|
return fmt.Errorf("model is nil")
|
||||||
|
}
|
||||||
|
|
||||||
return b.query.Scan(ctx)
|
return b.query.Scan(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||||
|
count = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
// If Model() was set, use bun's native Count() which works properly
|
// If Model() was set, use bun's native Count() which works properly
|
||||||
if b.hasModel {
|
if b.hasModel {
|
||||||
count, err := b.query.Count(ctx)
|
count, err := b.query.Count(ctx)
|
||||||
@@ -293,30 +340,40 @@ func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
|||||||
|
|
||||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||||
// This is needed when only Table() is set without a model
|
// This is needed when only Table() is set without a model
|
||||||
var count int
|
err = b.db.NewSelect().
|
||||||
err := b.db.NewSelect().
|
|
||||||
TableExpr("(?) AS subquery", b.query).
|
TableExpr("(?) AS subquery", b.query).
|
||||||
ColumnExpr("COUNT(*)").
|
ColumnExpr("COUNT(*)").
|
||||||
Scan(ctx, &count)
|
Scan(ctx, &count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
|
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||||
|
exists = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
return b.query.Exists(ctx)
|
return b.query.Exists(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BunInsertQuery implements InsertQuery for Bun
|
// BunInsertQuery implements InsertQuery for Bun
|
||||||
type BunInsertQuery struct {
|
type BunInsertQuery struct {
|
||||||
query *bun.InsertQuery
|
query *bun.InsertQuery
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
|
hasModel bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
|
b.hasModel = true
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||||
|
if b.hasModel {
|
||||||
|
return b
|
||||||
|
}
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
@@ -341,11 +398,22 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
if b.values != nil {
|
defer func() {
|
||||||
// For Bun, we need to handle this differently
|
if r := recover(); r != nil {
|
||||||
for k, v := range b.values {
|
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||||
b.query = b.query.Set("? = ?", bun.Ident(k), v)
|
}
|
||||||
|
}()
|
||||||
|
if len(b.values) > 0 {
|
||||||
|
if !b.hasModel {
|
||||||
|
// If no model was set, use the values map as the model
|
||||||
|
// Bun can insert map[string]interface{} directly
|
||||||
|
b.query = b.query.Model(&b.values)
|
||||||
|
} else {
|
||||||
|
// If model was set, use Value() to add individual values
|
||||||
|
for k, v := range b.values {
|
||||||
|
b.query = b.query.Value(k, "?", v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
@@ -388,12 +456,17 @@ func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuer
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(b.model)
|
||||||
for column, value := range values {
|
for column, value := range values {
|
||||||
// Validate column is writable if model is set
|
// Validate column is writable if model is set
|
||||||
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||||
// Skip scan-only columns
|
// Skip scan-only columns
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if pkName != "" && column == pkName {
|
||||||
|
// Skip primary key updates
|
||||||
|
continue
|
||||||
|
}
|
||||||
b.query = b.query.Set(column+" = ?", value)
|
b.query = b.query.Set(column+" = ?", value)
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
@@ -411,7 +484,12 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
@@ -436,7 +514,12 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|||||||
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||||
|
"github.com/uptrace/bun/driver/sqliteshim"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestInsertModel is a test model for insert operations
|
||||||
|
type TestInsertModel struct {
|
||||||
|
bun.BaseModel `bun:"table:test_inserts"`
|
||||||
|
ID int64 `bun:"id,pk,autoincrement"`
|
||||||
|
Name string `bun:"name,notnull"`
|
||||||
|
Email string `bun:"email"`
|
||||||
|
Age int `bun:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupBunTestDB(t *testing.T) *bun.DB {
|
||||||
|
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||||
|
require.NoError(t, err, "Failed to open SQLite database")
|
||||||
|
|
||||||
|
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||||
|
|
||||||
|
// Create test table
|
||||||
|
_, err = db.NewCreateTable().
|
||||||
|
Model((*TestInsertModel)(nil)).
|
||||||
|
IfNotExists().
|
||||||
|
Exec(context.Background())
|
||||||
|
require.NoError(t, err, "Failed to create test table")
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunInsertQuery_Model(t *testing.T) {
|
||||||
|
db := setupBunTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
adapter := NewBunAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test inserting with Model()
|
||||||
|
model := &TestInsertModel{
|
||||||
|
Name: "John Doe",
|
||||||
|
Email: "john@example.com",
|
||||||
|
Age: 30,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Model(model).
|
||||||
|
Returning("*").
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Insert should succeed")
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||||
|
|
||||||
|
// Verify the data was inserted
|
||||||
|
var retrieved TestInsertModel
|
||||||
|
err = db.NewSelect().
|
||||||
|
Model(&retrieved).
|
||||||
|
Where("id = ?", model.ID).
|
||||||
|
Scan(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Should retrieve inserted row")
|
||||||
|
assert.Equal(t, "John Doe", retrieved.Name)
|
||||||
|
assert.Equal(t, "john@example.com", retrieved.Email)
|
||||||
|
assert.Equal(t, 30, retrieved.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunInsertQuery_Value(t *testing.T) {
|
||||||
|
db := setupBunTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
adapter := NewBunAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test inserting with Value() method - this was the bug
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Table("test_inserts").
|
||||||
|
Value("name", "Jane Smith").
|
||||||
|
Value("email", "jane@example.com").
|
||||||
|
Value("age", 25).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Insert with Value() should succeed")
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||||
|
|
||||||
|
// Verify the data was inserted
|
||||||
|
var retrieved TestInsertModel
|
||||||
|
err = db.NewSelect().
|
||||||
|
Model(&retrieved).
|
||||||
|
Where("name = ?", "Jane Smith").
|
||||||
|
Scan(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Should retrieve inserted row")
|
||||||
|
assert.Equal(t, "Jane Smith", retrieved.Name)
|
||||||
|
assert.Equal(t, "jane@example.com", retrieved.Email)
|
||||||
|
assert.Equal(t, 25, retrieved.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunInsertQuery_MultipleValues(t *testing.T) {
|
||||||
|
db := setupBunTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
adapter := NewBunAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test inserting multiple values
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Table("test_inserts").
|
||||||
|
Value("name", "Alice").
|
||||||
|
Value("email", "alice@example.com").
|
||||||
|
Value("age", 28).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "First insert should succeed")
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected())
|
||||||
|
|
||||||
|
result, err = adapter.NewInsert().
|
||||||
|
Table("test_inserts").
|
||||||
|
Value("name", "Bob").
|
||||||
|
Value("email", "bob@example.com").
|
||||||
|
Value("age", 35).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Second insert should succeed")
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected())
|
||||||
|
|
||||||
|
// Verify both rows exist
|
||||||
|
var count int
|
||||||
|
count, err = db.NewSelect().
|
||||||
|
Model((*TestInsertModel)(nil)).
|
||||||
|
Count(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Count should succeed")
|
||||||
|
assert.Equal(t, 2, count, "Should have 2 rows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunInsertQuery_ValueWithNil(t *testing.T) {
|
||||||
|
db := setupBunTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
adapter := NewBunAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test inserting with nil value for nullable field
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Table("test_inserts").
|
||||||
|
Value("name", "Test User").
|
||||||
|
Value("email", nil). // NULL email
|
||||||
|
Value("age", 20).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Insert with nil value should succeed")
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected())
|
||||||
|
|
||||||
|
// Verify the data was inserted with NULL email
|
||||||
|
var retrieved TestInsertModel
|
||||||
|
err = db.NewSelect().
|
||||||
|
Model(&retrieved).
|
||||||
|
Where("name = ?", "Test User").
|
||||||
|
Scan(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Should retrieve inserted row")
|
||||||
|
assert.Equal(t, "Test User", retrieved.Name)
|
||||||
|
assert.Equal(t, "", retrieved.Email) // NULL becomes empty string
|
||||||
|
assert.Equal(t, 20, retrieved.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunInsertQuery_Returning(t *testing.T) {
|
||||||
|
db := setupBunTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
adapter := NewBunAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test insert with RETURNING clause
|
||||||
|
// Note: SQLite has limited RETURNING support, but this tests the API
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Table("test_inserts").
|
||||||
|
Value("name", "Return Test").
|
||||||
|
Value("email", "return@example.com").
|
||||||
|
Value("age", 40).
|
||||||
|
Returning("*").
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err, "Insert with RETURNING should succeed")
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunInsertQuery_EmptyValues(t *testing.T) {
|
||||||
|
db := setupBunTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
adapter := NewBunAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test insert without calling Value() - should use Model() or fail gracefully
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Table("test_inserts").
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
// This should fail because no values are provided
|
||||||
|
assert.Error(t, err, "Insert without values should fail")
|
||||||
|
if result != nil {
|
||||||
|
assert.Equal(t, int64(0), result.RowsAffected())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
@@ -38,12 +39,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
|||||||
return &GormDeleteQuery{db: g.db}
|
return &GormDeleteQuery{db: g.db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,7 +74,12 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
return g.db.WithContext(ctx).Rollback().Error
|
return g.db.WithContext(ctx).Rollback().Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
adapter := &GormAdapter{db: tx}
|
adapter := &GormAdapter{db: tx}
|
||||||
return fn(adapter)
|
return fn(adapter)
|
||||||
@@ -255,26 +271,48 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return g.db.WithContext(ctx).Find(dest).Error
|
return g.db.WithContext(ctx).Find(dest).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) error {
|
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
if g.db.Statement.Model == nil {
|
if g.db.Statement.Model == nil {
|
||||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||||
}
|
}
|
||||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
|
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
var count int64
|
defer func() {
|
||||||
err := g.db.WithContext(ctx).Count(&count).Error
|
if r := recover(); r != nil {
|
||||||
return int(count), err
|
err = logger.HandlePanic("GormSelectQuery.Count", r)
|
||||||
|
count = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var count64 int64
|
||||||
|
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||||
|
return int(count64), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) {
|
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormSelectQuery.Exists", r)
|
||||||
|
exists = false
|
||||||
|
}
|
||||||
|
}()
|
||||||
var count int64
|
var count int64
|
||||||
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||||
return count > 0, err
|
return count > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,7 +352,12 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
var result *gorm.DB
|
var result *gorm.DB
|
||||||
switch {
|
switch {
|
||||||
case g.model != nil:
|
case g.model != nil:
|
||||||
@@ -369,13 +412,20 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||||
|
|
||||||
// Filter out read-only columns if model is set
|
// Filter out read-only columns if model is set
|
||||||
if g.model != nil {
|
if g.model != nil {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(g.model)
|
||||||
filteredValues := make(map[string]interface{})
|
filteredValues := make(map[string]interface{})
|
||||||
for column, value := range values {
|
for column, value := range values {
|
||||||
|
if pkName != "" && column == pkName {
|
||||||
|
// Skip primary key updates
|
||||||
|
continue
|
||||||
|
}
|
||||||
if reflection.IsColumnWritable(g.model, column) {
|
if reflection.IsColumnWritable(g.model, column) {
|
||||||
filteredValues[column] = value
|
filteredValues[column] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
g.updates = filteredValues
|
g.updates = filteredValues
|
||||||
} else {
|
} else {
|
||||||
@@ -394,7 +444,12 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
@@ -421,7 +476,12 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
|
|||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
result := g.db.WithContext(ctx).Delete(g.model)
|
result := g.db.WithContext(ctx).Delete(g.model)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,6 +121,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
|
|||||||
return b.req.URL.Query().Get(key)
|
return b.req.URL.Query().Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunRouterRequest) AllQueryParams() map[string]string {
|
||||||
|
params := make(map[string]string)
|
||||||
|
for key, values := range b.req.URL.Query() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
params[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunRouterRequest) AllHeaders() map[string]string {
|
func (b *BunRouterRequest) AllHeaders() map[string]string {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
for key, values := range b.req.Header {
|
for key, values := range b.req.Header {
|
||||||
|
|||||||
@@ -117,6 +117,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
|
|||||||
return h.req.URL.Query().Get(key)
|
return h.req.URL.Query().Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *HTTPRequest) AllQueryParams() map[string]string {
|
||||||
|
params := make(map[string]string)
|
||||||
|
for key, values := range h.req.URL.Query() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
params[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
func (h *HTTPRequest) AllHeaders() map[string]string {
|
func (h *HTTPRequest) AllHeaders() map[string]string {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
for key, values := range h.req.Header {
|
for key, values := range h.req.Header {
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ type Request interface {
|
|||||||
Body() ([]byte, error)
|
Body() ([]byte, error)
|
||||||
PathParam(key string) string
|
PathParam(key string) string
|
||||||
QueryParam(key string) string
|
QueryParam(key string) string
|
||||||
|
AllQueryParams() map[string]string // Get all query parameters as a map
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResponseWriter interface abstracts HTTP response
|
// ResponseWriter interface abstracts HTTP response
|
||||||
|
|||||||
@@ -111,6 +111,9 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
// Inject parent IDs for foreign key resolution
|
// Inject parent IDs for foreign key resolution
|
||||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||||
|
|
||||||
|
// Get the primary key name for this model
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
// Process based on operation
|
// Process based on operation
|
||||||
switch strings.ToLower(operation) {
|
switch strings.ToLower(operation) {
|
||||||
case "insert", "create":
|
case "insert", "create":
|
||||||
@@ -128,30 +131,30 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "update":
|
case "update":
|
||||||
rows, err := p.processUpdate(ctx, regularData, tableName, data["id"])
|
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("update failed: %w", err)
|
return nil, fmt.Errorf("update failed: %w", err)
|
||||||
}
|
}
|
||||||
result.ID = data["id"]
|
result.ID = data[pkName]
|
||||||
result.AffectedRows = rows
|
result.AffectedRows = rows
|
||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
// Process child relations for update
|
// Process child relations for update
|
||||||
if err := p.processChildRelations(ctx, "update", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "delete":
|
case "delete":
|
||||||
// Process child relations first (for referential integrity)
|
// Process child relations first (for referential integrity)
|
||||||
if err := p.processChildRelations(ctx, "delete", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := p.processDelete(ctx, tableName, data["id"])
|
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("delete failed: %w", err)
|
return nil, fmt.Errorf("delete failed: %w", err)
|
||||||
}
|
}
|
||||||
result.ID = data["id"]
|
result.ID = data[pkName]
|
||||||
result.AffectedRows = rows
|
result.AffectedRows = rows
|
||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
|
|||||||
247
pkg/common/sql_helpers.go
Normal file
247
pkg/common/sql_helpers.go
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
|
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||||
|
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||||
|
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||||
|
if where == "" {
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the relation name is already present in the WHERE clause
|
||||||
|
lowerWhere := strings.ToLower(where)
|
||||||
|
lowerRelation := strings.ToLower(relationName)
|
||||||
|
|
||||||
|
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||||
|
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||||
|
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||||
|
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||||
|
// Relation prefix is already present
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||||
|
// we can't safely auto-fix it - require explicit prefix
|
||||||
|
if strings.Contains(lowerWhere, " or ") ||
|
||||||
|
strings.Contains(where, "(") ||
|
||||||
|
strings.Contains(where, ")") {
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to add the relation prefix to simple column references
|
||||||
|
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||||
|
// Split by AND to handle multiple conditions (case-insensitive)
|
||||||
|
originalConditions := strings.Split(where, " AND ")
|
||||||
|
|
||||||
|
// If uppercase split didn't work, try lowercase
|
||||||
|
if len(originalConditions) == 1 {
|
||||||
|
originalConditions = strings.Split(where, " and ")
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedConditions := make([]string, 0, len(originalConditions))
|
||||||
|
|
||||||
|
for _, cond := range originalConditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this condition already has a table prefix (contains a dot)
|
||||||
|
if strings.Contains(cond, ".") {
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||||
|
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||||
|
if IsSQLExpression(lowerCond) {
|
||||||
|
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the column name (first identifier before operator)
|
||||||
|
columnName := ExtractColumnName(cond)
|
||||||
|
if columnName == "" {
|
||||||
|
// Can't identify column name, require explicit prefix
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add relation prefix to the column name only
|
||||||
|
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||||
|
fixedConditions = append(fixedConditions, fixedCond)
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||||
|
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||||
|
return fixedWhere, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||||
|
func IsSQLExpression(cond string) bool {
|
||||||
|
// Common SQL literals and expressions
|
||||||
|
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
|
||||||
|
for _, literal := range sqlLiterals {
|
||||||
|
if cond == literal {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
|
||||||
|
// These conditions should be removed from WHERE clauses as they have no filtering effect
|
||||||
|
func IsTrivialCondition(cond string) bool {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
lowerCond := strings.ToLower(cond)
|
||||||
|
|
||||||
|
// Conditions that always evaluate to true
|
||||||
|
trivialConditions := []string{
|
||||||
|
"1=1", "1 = 1", "1= 1", "1 =1",
|
||||||
|
"true", "true = true", "true=true", "true= true", "true =true",
|
||||||
|
"0=0", "0 = 0", "0= 0", "0 =0",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, trivial := range trivialConditions {
|
||||||
|
if lowerCond == trivial {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
||||||
|
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - where: The WHERE clause string to sanitize
|
||||||
|
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
||||||
|
// - An empty string if all conditions were trivial or the input was empty
|
||||||
|
func SanitizeWhereClause(where string, tableName string) string {
|
||||||
|
if where == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
|
// Split by AND to handle multiple conditions
|
||||||
|
conditions := splitByAND(where)
|
||||||
|
|
||||||
|
validConditions := make([]string, 0, len(conditions))
|
||||||
|
|
||||||
|
for _, cond := range conditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip trivial conditions that always evaluate to true
|
||||||
|
if IsTrivialCondition(cond) {
|
||||||
|
logger.Debug("Removing trivial condition: '%s'", cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||||
|
// attempt to add it
|
||||||
|
if tableName != "" && !hasTablePrefix(cond) {
|
||||||
|
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||||
|
if !IsSQLExpression(strings.ToLower(cond)) {
|
||||||
|
// Extract the column name and prefix it
|
||||||
|
columnName := ExtractColumnName(cond)
|
||||||
|
if columnName != "" {
|
||||||
|
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||||
|
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
validConditions = append(validConditions, cond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validConditions) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
result := strings.Join(validConditions, " AND ")
|
||||||
|
|
||||||
|
if result != where {
|
||||||
|
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||||
|
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||||
|
func splitByAND(where string) []string {
|
||||||
|
// First try uppercase AND
|
||||||
|
conditions := strings.Split(where, " AND ")
|
||||||
|
|
||||||
|
// If we didn't split on uppercase, try lowercase
|
||||||
|
if len(conditions) == 1 {
|
||||||
|
conditions = strings.Split(where, " and ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we still didn't split, try mixed case
|
||||||
|
if len(conditions) == 1 {
|
||||||
|
conditions = strings.Split(where, " And ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return conditions
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
|
||||||
|
func hasTablePrefix(cond string) bool {
|
||||||
|
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
|
||||||
|
return strings.Contains(cond, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractColumnName extracts the column name from a WHERE condition
|
||||||
|
// For example: "status = 'active'" returns "status"
|
||||||
|
func ExtractColumnName(cond string) string {
|
||||||
|
// Common SQL operators
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||||
|
|
||||||
|
for _, op := range operators {
|
||||||
|
if idx := strings.Index(cond, op); idx > 0 {
|
||||||
|
columnName := strings.TrimSpace(cond[:idx])
|
||||||
|
// Remove quotes if present
|
||||||
|
columnName = strings.Trim(columnName, "`\"'")
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no operator found, check if it's a simple identifier (for boolean columns)
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnName := strings.Trim(parts[0], "`\"'")
|
||||||
|
// Check if it's a valid identifier (not a SQL keyword)
|
||||||
|
if !IsSQLKeyword(strings.ToLower(columnName)) {
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
|
||||||
|
func IsSQLKeyword(word string) bool {
|
||||||
|
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
|
||||||
|
for _, kw := range keywords {
|
||||||
|
if word == kw {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
771
pkg/common/sql_types.go
Normal file
771
pkg/common/sql_types.go
Normal file
@@ -0,0 +1,771 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func tryParseDT(str string) (time.Time, error) {
|
||||||
|
var lasterror error
|
||||||
|
tryFormats := []string{time.RFC3339,
|
||||||
|
"2006-01-02T15:04:05.000-0700",
|
||||||
|
"2006-01-02T15:04:05.000",
|
||||||
|
"06-01-02T15:04:05.000",
|
||||||
|
"2006-01-02T15:04:05",
|
||||||
|
"2006-01-02 15:04:05",
|
||||||
|
"02/01/2006",
|
||||||
|
"02-01-2006",
|
||||||
|
"2006-01-02",
|
||||||
|
"15:04:05.000",
|
||||||
|
"15:04:05",
|
||||||
|
"15:04"}
|
||||||
|
|
||||||
|
for _, f := range tryFormats {
|
||||||
|
tx, err := time.Parse(f, str)
|
||||||
|
if err == nil {
|
||||||
|
return tx, nil
|
||||||
|
} else {
|
||||||
|
lasterror = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Now(), lasterror
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToJSONDT(dt time.Time) string {
|
||||||
|
return dt.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlInt16 - A Int16 that supports SQL string
|
||||||
|
type SqlInt16 int16
|
||||||
|
|
||||||
|
// Scan -
|
||||||
|
func (n *SqlInt16) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*n = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
*n = SqlInt16(v)
|
||||||
|
case int32:
|
||||||
|
*n = SqlInt16(v)
|
||||||
|
case int64:
|
||||||
|
*n = SqlInt16(v)
|
||||||
|
default:
|
||||||
|
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||||
|
*n = SqlInt16(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value -
|
||||||
|
func (n SqlInt16) Value() (driver.Value, error) {
|
||||||
|
if n == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return int64(n), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String - Override String format of ZNullInt32
|
||||||
|
func (n SqlInt16) String() string {
|
||||||
|
tmstr := fmt.Sprintf("%d", n)
|
||||||
|
return tmstr
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||||
|
func (n *SqlInt16) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
|
|
||||||
|
n64, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
*n = SqlInt16(n64)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON - Override JSON format of time
|
||||||
|
func (n SqlInt16) MarshalJSON() ([]byte, error) {
|
||||||
|
return []byte(fmt.Sprintf("%d", n)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlInt32 - A int32 that supports SQL string
type SqlInt32 int32

// Scan - Implements sql.Scanner. NULL resets the receiver to zero;
// common integer types convert directly, and any other value is parsed
// from its printed form (parse failures silently yield zero).
func (n *SqlInt32) Scan(value interface{}) error {
	if value == nil {
		*n = 0
		return nil
	}
	switch typed := value.(type) {
	case int:
		*n = SqlInt32(typed)
	case int32:
		*n = SqlInt32(typed)
	case int64:
		*n = SqlInt32(typed)
	default:
		parsed, _ := strconv.ParseInt(fmt.Sprintf("%v", typed), 10, 64)
		*n = SqlInt32(parsed)
	}
	return nil
}

// Value - Implements driver.Valuer; zero is persisted as SQL NULL.
func (n SqlInt32) Value() (driver.Value, error) {
	if n != 0 {
		return int64(n), nil
	}
	return nil, nil
}

// String - Renders the value in base 10.
func (n SqlInt32) String() string {
	return strconv.FormatInt(int64(n), 10)
}

// UnmarshalJSON - Accepts a bare or quoted integer. Unparseable input
// leaves the receiver untouched and never reports an error.
func (n *SqlInt32) UnmarshalJSON(b []byte) error {
	raw := strings.Trim(strings.Trim(string(b), " "), "\"")
	if parsed, err := strconv.ParseInt(raw, 10, 64); err == nil {
		*n = SqlInt32(parsed)
	}
	return nil
}

// MarshalJSON - Emits the value as a bare JSON number.
func (n SqlInt32) MarshalJSON() ([]byte, error) {
	return []byte(strconv.FormatInt(int64(n), 10)), nil
}
|
||||||
|
|
||||||
|
// SqlInt64 - A int64 that supports SQL string
type SqlInt64 int64

// Scan - Implements sql.Scanner. NULL resets the receiver to zero;
// common integer types convert directly, and any other value is parsed
// from its printed form (parse failures silently yield zero).
func (n *SqlInt64) Scan(value interface{}) error {
	if value == nil {
		*n = 0
		return nil
	}
	switch typed := value.(type) {
	case int:
		*n = SqlInt64(typed)
	case int32:
		*n = SqlInt64(typed)
	case uint32:
		*n = SqlInt64(typed)
	case int64:
		*n = SqlInt64(typed)
	case uint64:
		*n = SqlInt64(typed)
	default:
		parsed, _ := strconv.ParseInt(fmt.Sprintf("%v", typed), 10, 64)
		*n = SqlInt64(parsed)
	}
	return nil
}

// Value - Implements driver.Valuer; zero is persisted as SQL NULL.
func (n SqlInt64) Value() (driver.Value, error) {
	if n != 0 {
		return int64(n), nil
	}
	return nil, nil
}

// String - Renders the value in base 10.
func (n SqlInt64) String() string {
	return strconv.FormatInt(int64(n), 10)
}

// UnmarshalJSON - Accepts a bare or quoted integer. Unparseable input
// leaves the receiver untouched and never reports an error.
func (n *SqlInt64) UnmarshalJSON(b []byte) error {
	raw := strings.Trim(strings.Trim(string(b), " "), "\"")
	if parsed, err := strconv.ParseInt(raw, 10, 64); err == nil {
		*n = SqlInt64(parsed)
	}
	return nil
}

// MarshalJSON - Emits the value as a bare JSON number.
func (n SqlInt64) MarshalJSON() ([]byte, error) {
	return []byte(strconv.FormatInt(int64(n), 10)), nil
}
|
||||||
|
|
||||||
|
// SqlTimeStamp - Implementation of SqlTimeStamp with some interfaces.
// Wraps time.Time; serializes as "2006-01-02T15:04:05" (no zone) and treats
// zero/year-1 sentinel values as JSON null / SQL NULL.
type SqlTimeStamp time.Time

// MarshalJSON - Override JSON format of time.
// Zero times, times before year 1, and the "0001-01-01T00:00:00" sentinel
// all marshal as JSON null; anything else as a quoted timestamp string.
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
	if time.Time(t).IsZero() {
		return []byte("null"), nil
	}
	if time.Time(t).Before(time.Date(0001, 1, 1, 0, 0, 0, 0, time.UTC)) {
		return []byte("null"), nil
	}
	tmstr := time.Time(t).Format("2006-01-02T15:04:05")
	if tmstr == "0001-01-01T00:00:00" {
		return []byte("null"), nil
	}
	return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
}

// UnmarshalJSON - Override JSON format of time.
// null/empty/zero sentinels leave the receiver untouched; other input is
// parsed via tryParseDT (defined elsewhere in this package — presumably a
// multi-layout date/time parser; TODO confirm accepted layouts).
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
	var err error

	if b == nil {
		return nil
	}
	// Strip surrounding spaces and quotes from the raw JSON token.
	s := strings.Trim(strings.Trim(string(b), " "), "\"")
	if s == "null" || s == "" || s == "0" ||
		s == "0001-01-01T00:00:00" || s == "0001-01-01" {
		return nil
	}

	tx, err := tryParseDT(s)
	if err != nil {
		return err
	}

	*t = SqlTimeStamp(tx)
	return err
}

// Value - SQL Value of custom date.
// Zero times and times before year 2 persist as SQL NULL. The later
// lexicographic check (tmstr <= "0001-01-01") returns an empty time.Time
// instead of NULL; it appears unreachable after the first guard but is
// preserved as-is. NOTE(review): confirm the empty-time branch is intended.
func (t SqlTimeStamp) Value() (driver.Value, error) {
	if t.GetTime().IsZero() || t.GetTime().Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
		return nil, nil
	}
	tmstr := time.Time(t).Format("2006-01-02T15:04:05")
	if tmstr <= "0001-01-01" || tmstr == "" {
		empty := time.Time{}
		return empty, nil
	}

	return tmstr, nil
}

// Scan - Scan custom date from sql.
// Accepts time.Time directly, or a string parsed via tryParseDT.
// Any other driver type is silently ignored (receiver unchanged).
func (t *SqlTimeStamp) Scan(value interface{}) error {
	tm, ok := value.(time.Time)
	if ok {
		*t = SqlTimeStamp(tm)
		return nil
	}

	str, ok := value.(string)
	if ok {
		tx, err := tryParseDT(str)
		if err != nil {
			return err
		}
		*t = SqlTimeStamp(tx)
	}

	return nil
}

// String - Override String format of time ("2006-01-02T15:04:05").
func (t SqlTimeStamp) String() string {
	return time.Time(t).Format("2006-01-02T15:04:05")
}

// GetTime - Returns the underlying time.Time.
func (t SqlTimeStamp) GetTime() time.Time {
	return time.Time(t)
}

// SetTime - Replaces the underlying time.Time.
func (t *SqlTimeStamp) SetTime(pTime time.Time) {
	*t = SqlTimeStamp(pTime)
}

// Format - Formats the time using the given layout string.
func (t SqlTimeStamp) Format(layout string) string {
	return time.Time(t).Format(layout)
}

// SqlTimeStampNow - Returns the current time as a SqlTimeStamp.
func SqlTimeStampNow() SqlTimeStamp {
	tx := time.Now()

	return SqlTimeStamp(tx)
}
|
||||||
|
|
||||||
|
// SqlFloat64 - Nullable float64 backed by sql.NullFloat64. Scans leniently
// from numeric driver values and marshals as a JSON number (null when invalid).
type SqlFloat64 sql.NullFloat64

// Scan - Implements sql.Scanner. NULL produces an invalid (null) value;
// numeric types convert directly; any other type is parsed from its printed
// form and is marked valid only when parsing succeeds.
func (n *SqlFloat64) Scan(value interface{}) error {
	newval := sql.NullFloat64{Float64: 0, Valid: false}
	if value == nil {
		newval.Valid = false
		*n = SqlFloat64(newval)
		return nil
	}
	switch v := value.(type) {
	case int:
		newval.Float64 = float64(v)
		newval.Valid = true
	case float64:
		newval.Float64 = v
		newval.Valid = true
	case float32:
		newval.Float64 = float64(v)
		newval.Valid = true
	case int64:
		newval.Float64 = float64(v)
		newval.Valid = true
	case int32:
		newval.Float64 = float64(v)
		newval.Valid = true
	case uint16:
		newval.Float64 = float64(v)
		newval.Valid = true
	case uint64:
		newval.Float64 = float64(v)
		newval.Valid = true
	case uint32:
		newval.Float64 = float64(v)
		newval.Valid = true
	default:
		// BUGFIX: the previous code parsed with strconv.ParseInt (dropping
		// fractional parts) and set Valid = false when parsing *succeeded*.
		// Parse as a float and mark valid exactly when parsing succeeds.
		f, err := strconv.ParseFloat(fmt.Sprintf("%v", v), 64)
		newval.Float64 = f
		newval.Valid = err == nil
	}

	*n = SqlFloat64(newval)
	return nil
}

// Value - Implements driver.Valuer; invalid values persist as SQL NULL.
func (n SqlFloat64) Value() (driver.Value, error) {
	if !n.Valid {
		return nil, nil
	}
	return float64(n.Float64), nil
}

// String - Returns the value formatted with %f, or "" when invalid.
func (n SqlFloat64) String() string {
	if !n.Valid {
		return ""
	}
	tmstr := fmt.Sprintf("%f", n.Float64)
	return tmstr
}

// UnmarshalJSON - Accepts a bare or quoted number. null, empty, and
// object/array input is ignored (receiver unchanged); an unparseable
// number returns the parse error.
func (n *SqlFloat64) UnmarshalJSON(b []byte) error {
	s := strings.Trim(strings.Trim(string(b), " "), "\"")
	// BUGFIX: previously also rejected single-character numbers (len(s) < 2),
	// so valid input like "5" was silently dropped.
	if s == "null" || s == "" || strings.ContainsAny(s, "{[") {
		return nil
	}

	// BUGFIX: parse as float (was strconv.ParseInt, which errored on "3.14").
	nval, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return err
	}

	*n = SqlFloat64(sql.NullFloat64{Valid: true, Float64: nval})

	return nil
}

// MarshalJSON - Emits the value with %f, or null when invalid.
func (n SqlFloat64) MarshalJSON() ([]byte, error) {
	if !n.Valid {
		return []byte("null"), nil
	}
	return []byte(fmt.Sprintf("%f", n.Float64)), nil
}
|
||||||
|
|
||||||
|
// SqlDate - Implementation of SqlTime with some interfaces.
// Wraps time.Time; serializes date-only as "2006-01-02" and treats the
// year-1 sentinel as JSON null / SQL NULL.
type SqlDate time.Time

// UnmarshalJSON - Override JSON format of time.
// null/empty/zero sentinels leave the receiver untouched; other input is
// parsed via tryParseDT (package helper — presumably a multi-layout
// date parser; TODO confirm accepted layouts).
func (t *SqlDate) UnmarshalJSON(b []byte) error {
	var err error

	// Strip surrounding spaces and quotes from the raw JSON token.
	s := strings.Trim(strings.Trim(string(b), " "), "\"")
	if s == "null" || s == "" || s == "0" ||
		strings.HasPrefix(s, "0001-01-01T00:00:00") ||
		s == "0001-01-01" {
		return nil
	}

	tx, err := tryParseDT(s)
	if err != nil {
		return err
	}
	*t = SqlDate(tx)
	return err
}

// MarshalJSON - Override JSON format of time.
// The year-1 sentinel marshals as null; otherwise a quoted "2006-01-02".
func (t SqlDate) MarshalJSON() ([]byte, error) {
	tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
	if strings.HasPrefix(tmstr, "0001-01-01") {
		return []byte("null"), nil
	}
	return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
}

// Value - SQL Value of custom date.
// The year-1 sentinel (checked both by prefix and lexicographically)
// persists as SQL NULL; otherwise the date string "2006-01-02".
func (t SqlDate) Value() (driver.Value, error) {
	var s time.Time
	tmstr := time.Time(t).Format("2006-01-02")
	if strings.HasPrefix(tmstr, "0001-01-01") || tmstr <= "0001-01-01" {
		return nil, nil
	}
	s = time.Time(t)

	return s.Format("2006-01-02"), nil
}

// Scan - Scan custom date from sql.
// Accepts time.Time directly, or a string parsed via tryParseDT.
// Any other driver type is silently ignored (receiver unchanged).
func (t *SqlDate) Scan(value interface{}) error {
	tm, ok := value.(time.Time)
	if ok {
		*t = SqlDate(tm)
		return nil
	}

	str, ok := value.(string)
	if ok {
		tx, err := tryParseDT(str)
		if err != nil {
			return err
		}

		*t = SqlDate(tx)
		return err
	}

	return nil
}

// Int64 - Override date format in unix epoch (seconds).
func (t SqlDate) Int64() int64 {
	return time.Time(t).Unix()
}

// String - Override String format of time.
// Year-1 and 1800-12-31 sentinels render as "0". NOTE(review): the
// 1800-12-31 sentinel looks like a legacy database default — confirm.
func (t SqlDate) String() string {
	tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
	if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
		return "0"
	}
	return tmstr
}

// SqlDateNow - Returns the current time as a SqlDate.
func SqlDateNow() SqlDate {
	tx := time.Now()
	return SqlDate(tx)
}
|
||||||
|
|
||||||
|
// ////////////////////// SqlTime /////////////////////////
// SqlTime - Implementation of SqlTime with some interfaces.
// Wraps time.Time but only the clock portion is meaningful; serializes
// as "15:04:05" and treats midnight ("00:00:00") as JSON null.
type SqlTime time.Time

// Int64 - Override Time format in unix epoch (seconds).
func (t SqlTime) Int64() int64 {
	return time.Time(t).Unix()
}

// String - Override String format of time ("15:04:05").
func (t SqlTime) String() string {
	return time.Time(t).Format("15:04:05")
}

// UnmarshalJSON - Override JSON format of time.
// null/empty/zero sentinels reset the receiver to the zero time; other
// input is parsed via tryParseDT (package helper; parse errors propagate).
func (t *SqlTime) UnmarshalJSON(b []byte) error {
	var err error
	// Strip surrounding spaces and quotes from the raw JSON token.
	s := strings.Trim(strings.Trim(string(b), " "), "\"")
	if s == "null" || s == "" || s == "0" ||
		s == "0001-01-01T00:00:00" || s == "00:00:00" {
		*t = SqlTime{}
		return nil
	}

	tx, err := tryParseDT(s)
	*t = SqlTime(tx)

	return err
}

// Format - Formats the time using the given layout string.
func (t SqlTime) Format(form string) string {
	tmstr := time.Time(t).Format(form)
	return tmstr
}

// Scan - Scan custom date from sql.
// Accepts time.Time directly, or a string parsed via tryParseDT.
// Any other driver type is silently ignored (receiver unchanged).
func (t *SqlTime) Scan(value interface{}) error {
	tm, ok := value.(time.Time)
	if ok {
		*t = SqlTime(tm)
		return nil
	}

	str, ok := value.(string)
	if ok {
		tx, err := tryParseDT(str)
		*t = SqlTime(tx)
		return err
	}

	return nil
}

// Value - SQL Value of custom date.
// Always persists the clock portion as "15:04:05" (never NULL).
func (t SqlTime) Value() (driver.Value, error) {
	s := time.Time(t)
	st := s.Format("15:04:05")

	return st, nil
}

// MarshalJSON - Override JSON format of time.
// Midnight ("00:00:00") marshals as null; otherwise a quoted time string.
// NOTE(review): a genuine midnight value is indistinguishable from unset.
func (t SqlTime) MarshalJSON() ([]byte, error) {
	tmstr := time.Time(t).Format("15:04:05")
	if tmstr == "0001-01-01T00:00:00" || tmstr == "00:00:00" {
		return []byte("null"), nil
	}
	return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
}

// SqlTimeNow - Returns the current time as a SqlTime.
func SqlTimeNow() SqlTime {
	tx := time.Now()
	return SqlTime(tx)
}
|
||||||
|
|
||||||
|
// SqlJSONB - Nullable JSONB String
type SqlJSONB []byte

// Scan - Implements sql.Scanner for reading JSONB from the database.
// Strings and byte slices are taken verbatim; any other value is
// serialized to JSON.
func (n *SqlJSONB) Scan(value interface{}) error {
	if value == nil {
		*n = nil
		return nil
	}

	if s, ok := value.(string); ok {
		*n = SqlJSONB(s)
		return nil
	}
	if raw, ok := value.([]byte); ok {
		*n = SqlJSONB(raw)
		return nil
	}

	// Fallback: marshal arbitrary driver values to JSON.
	encoded, err := json.Marshal(value)
	if err != nil {
		return fmt.Errorf("failed to marshal value to JSON: %v", err)
	}
	*n = SqlJSONB(encoded)
	return nil
}

// Value - Implements driver.Valuer for writing JSONB to the database.
// Empty content persists as SQL NULL; invalid JSON is rejected.
func (n SqlJSONB) Value() (driver.Value, error) {
	if len(n) == 0 {
		return nil, nil
	}

	// Reject content that is not well-formed JSON before it hits the DB.
	var probe interface{}
	if err := json.Unmarshal(n, &probe); err != nil {
		return nil, fmt.Errorf("invalid JSON: %v", err)
	}

	// PostgreSQL JSONB/JSON columns accept a plain string.
	return string(n), nil
}

// AsMap - Decodes the content as a JSON object. Empty content yields nil.
func (n SqlJSONB) AsMap() (map[string]any, error) {
	if len(n) == 0 {
		return nil, nil
	}
	result := make(map[string]any)
	if err := json.Unmarshal(n, &result); err != nil {
		return nil, fmt.Errorf("invalid JSON: %v", err)
	}
	return result, nil
}

// AsSlice - Decodes the content as a JSON array. Empty content yields nil.
func (n SqlJSONB) AsSlice() ([]any, error) {
	if len(n) == 0 {
		return nil, nil
	}
	result := make([]any, 0)
	if err := json.Unmarshal(n, &result); err != nil {
		return nil, fmt.Errorf("invalid JSON: %v", err)
	}
	return result, nil
}

// UnmarshalJSON - Stores the raw JSON token, but only when it looks like
// an object or array; null/empty/scalar input is silently ignored.
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
	trimmed := strings.Trim(strings.Trim(string(b), " "), "\"")
	looksStructured := strings.Contains(trimmed, "{") || strings.Contains(trimmed, "[")
	if trimmed == "null" || trimmed == "" || len(trimmed) < 2 || !looksStructured {
		return nil
	}

	*n = []byte(trimmed)

	return nil
}

// MarshalJSON - Emits the stored JSON verbatim; nil or malformed content
// is rendered as null so the surrounding document stays valid.
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
	if n == nil {
		return []byte("null"), nil
	}
	var probe interface{}
	if json.Unmarshal(n, &probe) != nil {
		return []byte("null"), nil
	}

	return n, nil
}
|
||||||
|
|
||||||
|
// SqlUUID - Nullable UUID String backed by sql.NullString. Values are
// canonicalized through github.com/google/uuid on scan.
type SqlUUID sql.NullString

// Scan - Implements sql.Scanner. Strings, byte slices, and any other
// printable value are parsed as UUIDs; unparseable input leaves the
// receiver invalid for nil input, or unchanged otherwise.
// NOTE: the local variable `uuid` shadows the imported uuid package
// inside each case — intentional-looking but worth confirming.
func (n *SqlUUID) Scan(value interface{}) error {
	str := sql.NullString{String: "", Valid: false}
	if value == nil {
		*n = SqlUUID(str)
		return nil
	}
	switch v := value.(type) {
	case string:
		uuid, err := uuid.Parse(v)
		if err == nil {
			str.String = uuid.String()
			str.Valid = true
			*n = SqlUUID(str)
		}
	case []uint8:
		uuid, err := uuid.ParseBytes(v)
		if err == nil {
			str.String = uuid.String()
			str.Valid = true
			*n = SqlUUID(str)
		}
	default:
		// Fallback: render the value and attempt a parse of its text form.
		uuid, err := uuid.Parse(fmt.Sprintf("%v", v))
		if err == nil {
			str.String = uuid.String()
			str.Valid = true
			*n = SqlUUID(str)
		}
	}

	return nil
}

// Value - Implements driver.Valuer; invalid values persist as SQL NULL.
func (n SqlUUID) Value() (driver.Value, error) {
	if !n.Valid {
		return nil, nil
	}
	return n.String, nil
}

// UnmarshalJSON - Accepts a quoted UUID string. null/empty/short input is
// ignored (receiver unchanged). NOTE(review): the len(s) < 30 heuristic
// rejects obviously-short tokens without a full UUID parse — a canonical
// UUID is 36 characters; confirm 30 is the intended lower bound.
func (n *SqlUUID) UnmarshalJSON(b []byte) error {
	s := strings.Trim(strings.Trim(string(b), " "), "\"")
	invalid := (s == "null" || s == "" || len(s) < 30)
	if invalid {
		return nil
	}
	*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})

	return nil
}

// MarshalJSON - Emits the UUID as a quoted JSON string, or null when invalid.
func (n SqlUUID) MarshalJSON() ([]byte, error) {
	if !n.Valid {
		return []byte("null"), nil
	}
	return []byte(fmt.Sprintf("\"%s\"", n.String)), nil
}
|
||||||
|
|
||||||
|
// TryIfInt64 - Best-effort coercion of an arbitrary value to int64.
// Numeric types convert directly (floats truncate toward zero); strings
// and byte slices are parsed as base-10 integers; anything else — or any
// parse failure — yields def.
func TryIfInt64(v any, def int64) int64 {
	var text string
	switch val := v.(type) {
	case int:
		return int64(val)
	case int32:
		return int64(val)
	case int64:
		return val
	case uint32:
		return int64(val)
	case uint64:
		return int64(val)
	case float32:
		return int64(val)
	case float64:
		return int64(val)
	case string:
		text = val
	case []byte:
		text = string(val)
	default:
		// Unknown type: fall through to parsing the default's own text,
		// which always succeeds and returns def.
		text = fmt.Sprintf("%d", def)
	}
	parsed, err := strconv.ParseInt(text, 10, 64)
	if err != nil {
		return def
	}
	return parsed
}
|
||||||
566
pkg/common/sql_types_test.go
Normal file
566
pkg/common/sql_types_test.go
Normal file
@@ -0,0 +1,566 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSqlInt16 tests SqlInt16 type: table-driven Scan coverage for the
// integer, string, and nil driver-value paths.
func TestSqlInt16(t *testing.T) {
	tests := []struct {
		name     string
		input    interface{}
		expected SqlInt16
	}{
		{"int", 42, SqlInt16(42)},
		{"int32", int32(100), SqlInt16(100)},
		{"int64", int64(200), SqlInt16(200)},
		{"string", "123", SqlInt16(123)},
		{"nil", nil, SqlInt16(0)},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var n SqlInt16
			if err := n.Scan(tt.input); err != nil {
				t.Fatalf("Scan failed: %v", err)
			}
			if n != tt.expected {
				t.Errorf("expected %v, got %v", tt.expected, n)
			}
		})
	}
}

// TestSqlInt16_Value checks driver.Valuer behavior: zero maps to SQL NULL,
// non-zero values map to int64.
func TestSqlInt16_Value(t *testing.T) {
	tests := []struct {
		name     string
		input    SqlInt16
		expected driver.Value
	}{
		{"zero", SqlInt16(0), nil},
		{"positive", SqlInt16(42), int64(42)},
		{"negative", SqlInt16(-10), int64(-10)},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			val, err := tt.input.Value()
			if err != nil {
				t.Fatalf("Value failed: %v", err)
			}
			if val != tt.expected {
				t.Errorf("expected %v, got %v", tt.expected, val)
			}
		})
	}
}

// TestSqlInt16_JSON round-trips a value through MarshalJSON/UnmarshalJSON.
func TestSqlInt16_JSON(t *testing.T) {
	n := SqlInt16(42)

	// Marshal: expect a bare (unquoted) JSON number.
	data, err := json.Marshal(n)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}
	expected := "42"
	if string(data) != expected {
		t.Errorf("expected %s, got %s", expected, string(data))
	}

	// Unmarshal: a bare number parses back into the value.
	var n2 SqlInt16
	if err := json.Unmarshal([]byte("123"), &n2); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}
	if n2 != 123 {
		t.Errorf("expected 123, got %d", n2)
	}
}
|
||||||
|
|
||||||
|
// TestSqlInt64 tests SqlInt64 type: table-driven Scan coverage including
// the unsigned cases and the int64 maximum.
func TestSqlInt64(t *testing.T) {
	tests := []struct {
		name     string
		input    interface{}
		expected SqlInt64
	}{
		{"int", 42, SqlInt64(42)},
		{"int32", int32(100), SqlInt64(100)},
		{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
		{"uint32", uint32(100), SqlInt64(100)},
		{"uint64", uint64(200), SqlInt64(200)},
		{"nil", nil, SqlInt64(0)},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var n SqlInt64
			if err := n.Scan(tt.input); err != nil {
				t.Fatalf("Scan failed: %v", err)
			}
			if n != tt.expected {
				t.Errorf("expected %v, got %v", tt.expected, n)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestSqlFloat64 tests SqlFloat64 type: Scan must set the Valid flag and
// the converted value for numeric inputs, and leave the value invalid
// (SQL NULL) for nil input.
func TestSqlFloat64(t *testing.T) {
	tests := []struct {
		name     string
		input    interface{}
		expected float64
		valid    bool
	}{
		{"float64", float64(3.14), 3.14, true},
		{"float32", float32(2.5), 2.5, true},
		{"int", 42, 42.0, true},
		{"int64", int64(100), 100.0, true},
		{"nil", nil, 0, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var n SqlFloat64
			if err := n.Scan(tt.input); err != nil {
				t.Fatalf("Scan failed: %v", err)
			}
			if n.Valid != tt.valid {
				t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
			}
			// The stored value is only meaningful when Valid is set.
			if tt.valid && n.Float64 != tt.expected {
				t.Errorf("expected %v, got %v", tt.expected, n.Float64)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestSqlTimeStamp tests SqlTimeStamp type: Scan must accept time.Time
// and several string layouts and always produce a non-zero time.
func TestSqlTimeStamp(t *testing.T) {
	now := time.Now()

	tests := []struct {
		name  string
		input interface{}
	}{
		{"time.Time", now},
		{"string RFC3339", now.Format(time.RFC3339)},
		{"string date", "2024-01-15"},
		{"string datetime", "2024-01-15T10:30:00"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var ts SqlTimeStamp
			if err := ts.Scan(tt.input); err != nil {
				t.Fatalf("Scan failed: %v", err)
			}
			if ts.GetTime().IsZero() {
				t.Error("expected non-zero time")
			}
		})
	}
}

// TestSqlTimeStamp_JSON round-trips a timestamp through JSON and checks
// that the "null" token is accepted on unmarshal.
func TestSqlTimeStamp_JSON(t *testing.T) {
	now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
	ts := SqlTimeStamp(now)

	// Marshal: expect the zone-less quoted layout "2006-01-02T15:04:05".
	data, err := json.Marshal(ts)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}
	expected := `"2024-01-15T10:30:45"`
	if string(data) != expected {
		t.Errorf("expected %s, got %s", expected, string(data))
	}

	// Unmarshal: the same layout parses back.
	var ts2 SqlTimeStamp
	if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}
	if ts2.GetTime().Year() != 2024 {
		t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
	}

	// Test null: must be accepted without error.
	var ts3 SqlTimeStamp
	if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
		t.Fatalf("Unmarshal null failed: %v", err)
	}
}
|
||||||
|
|
||||||
|
// TestSqlDate tests SqlDate type: Scan must accept time.Time, ISO dates,
// and the DD/MM/YYYY layout, producing a non-sentinel date.
func TestSqlDate(t *testing.T) {
	now := time.Now()

	tests := []struct {
		name  string
		input interface{}
	}{
		{"time.Time", now},
		{"string date", "2024-01-15"},
		{"string UK format", "15/01/2024"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var d SqlDate
			if err := d.Scan(tt.input); err != nil {
				t.Fatalf("Scan failed: %v", err)
			}
			// String() renders sentinel/zero dates as "0".
			if d.String() == "0" {
				t.Error("expected non-zero date")
			}
		})
	}
}

// TestSqlDate_JSON round-trips a date through JSON ("2006-01-02" layout).
func TestSqlDate_JSON(t *testing.T) {
	date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))

	// Marshal: expect a quoted date-only string.
	data, err := json.Marshal(date)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}
	expected := `"2024-01-15"`
	if string(data) != expected {
		t.Errorf("expected %s, got %s", expected, string(data))
	}

	// Unmarshal: the same layout parses back without error.
	var d2 SqlDate
	if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}
}
|
||||||
|
|
||||||
|
// TestSqlTime tests SqlTime type: Scan must accept time.Time and clock
// strings (with or without seconds) and render them as "15:04:05".
func TestSqlTime(t *testing.T) {
	now := time.Now()

	tests := []struct {
		name     string
		input    interface{}
		expected string
	}{
		{"time.Time", now, now.Format("15:04:05")},
		{"string time", "10:30:45", "10:30:45"},
		{"string short time", "10:30", "10:30:00"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var tm SqlTime
			if err := tm.Scan(tt.input); err != nil {
				t.Fatalf("Scan failed: %v", err)
			}
			if tm.String() != tt.expected {
				t.Errorf("expected %s, got %s", tt.expected, tm.String())
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestSqlJSONB tests SqlJSONB type.
// TestSqlJSONB_Scan covers the string, []byte, and nil driver-value paths.
func TestSqlJSONB_Scan(t *testing.T) {
	tests := []struct {
		name     string
		input    interface{}
		expected string
	}{
		{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
		{"string JSON array", `[1,2,3]`, `[1,2,3]`},
		{"bytes", []byte(`{"test":true}`), `{"test":true}`},
		{"nil", nil, ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var j SqlJSONB
			if err := j.Scan(tt.input); err != nil {
				t.Fatalf("Scan failed: %v", err)
			}
			if tt.expected == "" && j == nil {
				return // nil case
			}
			if string(j) != tt.expected {
				t.Errorf("expected %s, got %s", tt.expected, string(j))
			}
		})
	}
}

// TestSqlJSONB_Value checks driver.Valuer behavior: valid JSON is written
// as a string, empty/nil persists as SQL NULL, malformed JSON errors.
func TestSqlJSONB_Value(t *testing.T) {
	tests := []struct {
		name     string
		input    SqlJSONB
		expected string
		wantErr  bool
	}{
		{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
		{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
		{"empty", SqlJSONB{}, "", false},
		{"nil", nil, "", false},
		{"invalid JSON", SqlJSONB(`{invalid`), "", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			val, err := tt.input.Value()
			if tt.wantErr {
				if err == nil {
					t.Error("expected error, got nil")
				}
				return
			}
			if err != nil {
				t.Fatalf("Value failed: %v", err)
			}
			if tt.expected == "" && val == nil {
				return // nil case
			}
			if val.(string) != tt.expected {
				t.Errorf("expected %s, got %s", tt.expected, val)
			}
		})
	}
}

// TestSqlJSONB_JSON checks that MarshalJSON embeds the stored document
// verbatim and that UnmarshalJSON stores objects and accepts null.
func TestSqlJSONB_JSON(t *testing.T) {
	// Marshal: the embedded JSON must decode to the original structure.
	j := SqlJSONB(`{"name":"test","count":42}`)
	data, err := json.Marshal(j)
	if err != nil {
		t.Fatalf("Marshal failed: %v", err)
	}
	var result map[string]interface{}
	if err := json.Unmarshal(data, &result); err != nil {
		t.Fatalf("Unmarshal result failed: %v", err)
	}
	if result["name"] != "test" {
		t.Errorf("expected name=test, got %v", result["name"])
	}

	// Unmarshal: an object is stored byte-for-byte.
	var j2 SqlJSONB
	if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}
	if string(j2) != `{"key":"value"}` {
		t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
	}

	// Test null: must be accepted without error.
	var j3 SqlJSONB
	if err := json.Unmarshal([]byte("null"), &j3); err != nil {
		t.Fatalf("Unmarshal null failed: %v", err)
	}
}
|
||||||
|
|
||||||
|
func TestSqlJSONB_AsMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input SqlJSONB
|
||||||
|
wantErr bool
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
|
||||||
|
{"empty", SqlJSONB{}, false, true},
|
||||||
|
{"nil", nil, false, true},
|
||||||
|
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
|
||||||
|
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m, err := tt.input.AsMap()
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AsMap failed: %v", err)
|
||||||
|
}
|
||||||
|
if tt.wantNil {
|
||||||
|
if m != nil {
|
||||||
|
t.Errorf("expected nil, got %v", m)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m == nil {
|
||||||
|
t.Error("expected non-nil map")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlJSONB_AsSlice(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input SqlJSONB
|
||||||
|
wantErr bool
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
|
||||||
|
{"empty", SqlJSONB{}, false, true},
|
||||||
|
{"nil", nil, false, true},
|
||||||
|
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
|
||||||
|
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s, err := tt.input.AsSlice()
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AsSlice failed: %v", err)
|
||||||
|
}
|
||||||
|
if tt.wantNil {
|
||||||
|
if s != nil {
|
||||||
|
t.Errorf("expected nil, got %v", s)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s == nil {
|
||||||
|
t.Error("expected non-nil slice")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSqlUUID tests SqlUUID type
|
||||||
|
func TestSqlUUID_Scan(t *testing.T) {
|
||||||
|
testUUID := uuid.New()
|
||||||
|
testUUIDStr := testUUID.String()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{"string UUID", testUUIDStr, testUUIDStr, true},
|
||||||
|
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
|
||||||
|
{"nil", nil, "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var u SqlUUID
|
||||||
|
if err := u.Scan(tt.input); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if u.Valid != tt.valid {
|
||||||
|
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||||
|
}
|
||||||
|
if tt.valid && u.String != tt.expected {
|
||||||
|
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlUUID_Value(t *testing.T) {
|
||||||
|
testUUID := uuid.New()
|
||||||
|
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||||
|
|
||||||
|
val, err := u.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Value failed: %v", err)
|
||||||
|
}
|
||||||
|
if val != testUUID.String() {
|
||||||
|
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid UUID
|
||||||
|
u2 := SqlUUID{Valid: false}
|
||||||
|
val2, err := u2.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Value failed: %v", err)
|
||||||
|
}
|
||||||
|
if val2 != nil {
|
||||||
|
t.Errorf("expected nil, got %v", val2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlUUID_JSON(t *testing.T) {
|
||||||
|
testUUID := uuid.New()
|
||||||
|
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
data, err := json.Marshal(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
expected := `"` + testUUID.String() + `"`
|
||||||
|
if string(data) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var u2 SqlUUID
|
||||||
|
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if u2.String != testUUID.String() {
|
||||||
|
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test null
|
||||||
|
var u3 SqlUUID
|
||||||
|
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
|
||||||
|
t.Fatalf("Unmarshal null failed: %v", err)
|
||||||
|
}
|
||||||
|
if u3.Valid {
|
||||||
|
t.Error("expected invalid UUID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTryIfInt64 tests the TryIfInt64 helper function
|
||||||
|
func TestTryIfInt64(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
def int64
|
||||||
|
expected int64
|
||||||
|
}{
|
||||||
|
{"string valid", "123", 0, 123},
|
||||||
|
{"string invalid", "abc", 99, 99},
|
||||||
|
{"int", 42, 0, 42},
|
||||||
|
{"int32", int32(100), 0, 100},
|
||||||
|
{"int64", int64(200), 0, 200},
|
||||||
|
{"uint32", uint32(50), 0, 50},
|
||||||
|
{"uint64", uint64(75), 0, 75},
|
||||||
|
{"float32", float32(3.14), 0, 3},
|
||||||
|
{"float64", float64(2.71), 0, 2},
|
||||||
|
{"bytes", []byte("456"), 0, 456},
|
||||||
|
{"unknown type", struct{}{}, 999, 999},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := TryIfInt64(tt.input, tt.def)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("expected %d, got %d", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,15 +32,17 @@ type Parameter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PreloadOption struct {
|
type PreloadOption struct {
|
||||||
Relation string `json:"relation"`
|
Relation string `json:"relation"`
|
||||||
Columns []string `json:"columns"`
|
Columns []string `json:"columns"`
|
||||||
OmitColumns []string `json:"omit_columns"`
|
OmitColumns []string `json:"omit_columns"`
|
||||||
Sort []SortOption `json:"sort"`
|
Sort []SortOption `json:"sort"`
|
||||||
Filters []FilterOption `json:"filters"`
|
Filters []FilterOption `json:"filters"`
|
||||||
Where string `json:"where"`
|
Where string `json:"where"`
|
||||||
Limit *int `json:"limit"`
|
Limit *int `json:"limit"`
|
||||||
Offset *int `json:"offset"`
|
Offset *int `json:"offset"`
|
||||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||||
|
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
||||||
|
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||||
}
|
}
|
||||||
|
|
||||||
type FilterOption struct {
|
type FilterOption struct {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ColumnValidator validates column names against a model's fields
|
// ColumnValidator validates column names against a model's fields
|
||||||
@@ -95,6 +96,7 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
|||||||
// ValidateColumn validates a single column name
|
// ValidateColumn validates a single column name
|
||||||
// Returns nil if valid, error if invalid
|
// Returns nil if valid, error if invalid
|
||||||
// Columns prefixed with "cql" (case insensitive) are always valid
|
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||||
|
// Handles PostgreSQL JSON operators (-> and ->>)
|
||||||
func (v *ColumnValidator) ValidateColumn(column string) error {
|
func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||||
// Allow empty columns
|
// Allow empty columns
|
||||||
if column == "" {
|
if column == "" {
|
||||||
@@ -106,8 +108,11 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract source column name (remove JSON operators like ->> or ->)
|
||||||
|
sourceColumn := reflection.ExtractSourceColumn(column)
|
||||||
|
|
||||||
// Check if column exists in model
|
// Check if column exists in model
|
||||||
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
|
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
|
||||||
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
126
pkg/common/validation_json_test.go
Normal file
126
pkg/common/validation_json_test.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractSourceColumn(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple column name",
|
||||||
|
input: "columna",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with ->> operator",
|
||||||
|
input: "columna->>'val'",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with -> operator",
|
||||||
|
input: "columna->'key'",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with table prefix and ->> operator",
|
||||||
|
input: "table.columna->>'val'",
|
||||||
|
expected: "table.columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with table prefix and -> operator",
|
||||||
|
input: "table.columna->'key'",
|
||||||
|
expected: "table.columna",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex JSON path with ->>",
|
||||||
|
input: "data->>'nested'->>'value'",
|
||||||
|
expected: "data",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with spaces before operator",
|
||||||
|
input: "columna ->>'val'",
|
||||||
|
expected: "columna",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := reflection.ExtractSourceColumn(tc.input)
|
||||||
|
if result != tc.expected {
|
||||||
|
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumnWithJSONOperators(t *testing.T) {
|
||||||
|
// Create a test model
|
||||||
|
type TestModel struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Data string `json:"data"` // JSON column
|
||||||
|
Metadata string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
validator := NewColumnValidator(TestModel{})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
column string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple valid column",
|
||||||
|
column: "name",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid column with ->> operator",
|
||||||
|
column: "data->>'field'",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid column with -> operator",
|
||||||
|
column: "metadata->'key'",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid column",
|
||||||
|
column: "invalid_column",
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid column with ->> operator",
|
||||||
|
column: "invalid_column->>'field'",
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cql prefixed column (always valid)",
|
||||||
|
column: "cql_computed",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty column",
|
||||||
|
column: "",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumn(tc.column)
|
||||||
|
if tc.shouldErr && err == nil {
|
||||||
|
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
|
||||||
|
}
|
||||||
|
if !tc.shouldErr && err != nil {
|
||||||
|
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -103,3 +103,18 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
|||||||
func CatchPanic(location string) {
|
func CatchPanic(location string) {
|
||||||
CatchPanicCallback(location, nil)
|
CatchPanicCallback(location, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandlePanic logs a panic and returns it as an error
|
||||||
|
// This should be called with the result of recover() from a deferred function
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// defer func() {
|
||||||
|
// if r := recover(); r != nil {
|
||||||
|
// err = logger.HandlePanic("MethodName", r)
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
func HandlePanic(methodName string, r any) error {
|
||||||
|
stack := debug.Stack()
|
||||||
|
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||||
|
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,7 +29,23 @@ func NewModelRegistry() *DefaultModelRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||||
|
registriesMutex.Lock()
|
||||||
|
foundAt := -1
|
||||||
|
for idx, r := range registries {
|
||||||
|
if r == defaultRegistry {
|
||||||
|
foundAt = idx
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
defaultRegistry = registry
|
defaultRegistry = registry
|
||||||
|
if foundAt >= 0 {
|
||||||
|
registries[foundAt] = registry
|
||||||
|
} else {
|
||||||
|
registries = append([]*DefaultModelRegistry{registry}, registries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer registriesMutex.Unlock()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRegistry adds a registry to the global list of registries
|
// AddRegistry adds a registry to the global list of registries
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type ModelFieldDetail struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||||
|
// This function recursively processes embedded structs to include their fields
|
||||||
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -25,8 +26,7 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var lst []ModelFieldDetail
|
lst := make([]ModelFieldDetail, 0)
|
||||||
lst = make([]ModelFieldDetail, 0)
|
|
||||||
|
|
||||||
if !record.IsValid() {
|
if !record.IsValid() {
|
||||||
return lst
|
return lst
|
||||||
@@ -37,14 +37,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
if record.Kind() != reflect.Struct {
|
if record.Kind() != reflect.Struct {
|
||||||
return lst
|
return lst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
collectFieldDetails(record, &lst)
|
||||||
|
|
||||||
|
return lst
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
|
||||||
|
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
|
||||||
modeltype := record.Type()
|
modeltype := record.Type()
|
||||||
|
|
||||||
for i := 0; i < modeltype.NumField(); i++ {
|
for i := 0; i < modeltype.NumField(); i++ {
|
||||||
fieldtype := modeltype.Field(i)
|
fieldtype := modeltype.Field(i)
|
||||||
|
fieldValue := record.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if fieldtype.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
embeddedValue := fieldValue
|
||||||
|
if fieldValue.Kind() == reflect.Pointer {
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
// Skip nil embedded pointers
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
embeddedValue = fieldValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if embeddedValue.Kind() == reflect.Struct {
|
||||||
|
collectFieldDetails(embeddedValue, lst)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
gormdetail := fieldtype.Tag.Get("gorm")
|
gormdetail := fieldtype.Tag.Get("gorm")
|
||||||
gormdetail = strings.Trim(gormdetail, " ")
|
gormdetail = strings.Trim(gormdetail, " ")
|
||||||
fielddetail := ModelFieldDetail{}
|
fielddetail := ModelFieldDetail{}
|
||||||
fielddetail.FieldValue = record.Field(i)
|
fielddetail.FieldValue = fieldValue
|
||||||
fielddetail.Name = fieldtype.Name
|
fielddetail.Name = fieldtype.Name
|
||||||
fielddetail.DataType = fieldtype.Type.Name()
|
fielddetail.DataType = fieldtype.Type.Name()
|
||||||
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||||
@@ -80,10 +109,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
}
|
}
|
||||||
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||||
|
|
||||||
lst = append(lst, fielddetail)
|
*lst = append(*lst, fielddetail)
|
||||||
|
|
||||||
}
|
}
|
||||||
return lst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fnFindKeyVal(src, key string) string {
|
func fnFindKeyVal(src, key string) string {
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package reflection
|
package reflection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
@@ -47,7 +49,7 @@ func GetPrimaryKeyName(model any) string {
|
|||||||
|
|
||||||
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||||
// Returns the value of the primary key field
|
// Returns the value of the primary key field
|
||||||
func GetPrimaryKeyValue(model any) interface{} {
|
func GetPrimaryKeyValue(model any) any {
|
||||||
if model == nil || reflect.TypeOf(model) == nil {
|
if model == nil || reflect.TypeOf(model) == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -61,38 +63,51 @@ func GetPrimaryKeyValue(model any) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
typ := val.Type()
|
|
||||||
|
|
||||||
// Try Bun tag first
|
// Try Bun tag first
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
|
||||||
field := typ.Field(i)
|
return pkValue
|
||||||
bunTag := field.Tag.Get("bun")
|
|
||||||
if strings.Contains(bunTag, "pk") {
|
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
|
||||||
return fieldValue.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to GORM tag
|
// Fall back to GORM tag
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
|
||||||
field := typ.Field(i)
|
return pkValue
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
if strings.Contains(gormTag, "primaryKey") {
|
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
|
||||||
return fieldValue.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Last resort: look for field named "ID" or "Id"
|
// Last resort: look for field named "ID" or "Id"
|
||||||
|
if pkValue := findFieldByName(val, "id"); pkValue != nil {
|
||||||
|
return pkValue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPrimaryKeyValue recursively searches for a primary key field in the struct
|
||||||
|
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
if strings.ToLower(field.Name) == "id" {
|
fieldValue := val.Field(i)
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
|
||||||
|
return pkValue
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for primary key tag
|
||||||
|
switch ormType {
|
||||||
|
case "bun":
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
case "gorm":
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
|
||||||
return fieldValue.Interface()
|
return fieldValue.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,8 +116,35 @@ func GetPrimaryKeyValue(model any) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findFieldByName recursively searches for a field by name in the struct
|
||||||
|
func findFieldByName(val reflect.Value, name string) any {
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if result := findFieldByName(fieldValue, name); result != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name matches
|
||||||
|
if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelColumns extracts all column names from a model using reflection
|
// GetModelColumns extracts all column names from a model using reflection
|
||||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||||
|
// This function recursively processes embedded structs to include their fields
|
||||||
func GetModelColumns(model any) []string {
|
func GetModelColumns(model any) []string {
|
||||||
var columns []string
|
var columns []string
|
||||||
|
|
||||||
@@ -118,18 +160,38 @@ func GetModelColumns(model any) []string {
|
|||||||
return columns
|
return columns
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
collectColumnsFromType(modelType, &columns)
|
||||||
field := modelType.Field(i)
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
|
||||||
|
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
collectColumnsFromType(fieldType, columns)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get column name using the same logic as primary key extraction
|
// Get column name using the same logic as primary key extraction
|
||||||
columnName := getColumnNameFromField(field)
|
columnName := getColumnNameFromField(field)
|
||||||
|
|
||||||
if columnName != "" {
|
if columnName != "" {
|
||||||
columns = append(columns, columnName)
|
*columns = append(*columns, columnName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return columns
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getColumnNameFromField extracts the column name from a struct field
|
// getColumnNameFromField extracts the column name from a struct field
|
||||||
@@ -166,6 +228,7 @@ func getColumnNameFromField(field reflect.StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||||
|
// This function recursively searches embedded structs
|
||||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||||
val := reflect.ValueOf(model)
|
val := reflect.ValueOf(model)
|
||||||
if val.Kind() == reflect.Pointer {
|
if val.Kind() == reflect.Pointer {
|
||||||
@@ -177,9 +240,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
typ := val.Type()
|
typ := val.Type()
|
||||||
|
return findPrimaryKeyNameFromType(typ, ormType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
|
||||||
|
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
|
||||||
|
return pkName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch ormType {
|
switch ormType {
|
||||||
case "gorm":
|
case "gorm":
|
||||||
// Check for gorm tag with primaryKey
|
// Check for gorm tag with primaryKey
|
||||||
@@ -231,15 +316,140 @@ func ExtractColumnFromGormTag(tag string) string {
|
|||||||
// Example: ",pk" -> "" (will fall back to json tag)
|
// Example: ",pk" -> "" (will fall back to json tag)
|
||||||
func ExtractColumnFromBunTag(tag string) string {
|
func ExtractColumnFromBunTag(tag string) string {
|
||||||
parts := strings.Split(tag, ",")
|
parts := strings.Split(tag, ",")
|
||||||
|
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
if len(parts) > 0 && parts[0] != "" {
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
return parts[0]
|
return parts[0]
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSQLModelColumns extracts column names that have valid SQL field mappings
|
||||||
|
// This function only returns columns that:
|
||||||
|
// 1. Have bun or gorm tags (not just json tags)
|
||||||
|
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
|
||||||
|
// 3. Are not scan-only embedded fields
|
||||||
|
func GetSQLModelColumns(model any) []string {
|
||||||
|
var columns []string
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
collectSQLColumnsFromType(modelType, &columns, false)
|
||||||
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
|
||||||
|
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
|
||||||
|
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the embedded struct itself is scan-only
|
||||||
|
isScanOnly := scanOnlyEmbedded
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
isScanOnly = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip fields in scan-only embedded structs
|
||||||
|
if scanOnlyEmbedded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get bun and gorm tags
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
|
||||||
|
// Skip if neither bun nor gorm tag exists
|
||||||
|
if bunTag == "" && gormTag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if explicitly marked with "-"
|
||||||
|
if bunTag == "-" || gormTag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if field itself is scan-only (bun)
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if field itself is read-only (gorm)
|
||||||
|
if gormTag != "" && isGormFieldReadOnly(gormTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (bun)
|
||||||
|
if bunTag != "" {
|
||||||
|
// Skip if it's a bun relation (rel:, join:, or m2m:)
|
||||||
|
if strings.Contains(bunTag, "rel:") ||
|
||||||
|
strings.Contains(bunTag, "join:") ||
|
||||||
|
strings.Contains(bunTag, "m2m:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (gorm)
|
||||||
|
if gormTag != "" {
|
||||||
|
// Skip if it has gorm relationship tags
|
||||||
|
if strings.Contains(gormTag, "foreignKey:") ||
|
||||||
|
strings.Contains(gormTag, "references:") ||
|
||||||
|
strings.Contains(gormTag, "many2many:") ||
|
||||||
|
strings.Contains(gormTag, "constraint:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get column name
|
||||||
|
columnName := ""
|
||||||
|
if bunTag != "" {
|
||||||
|
columnName = ExtractColumnFromBunTag(bunTag)
|
||||||
|
}
|
||||||
|
if columnName == "" && gormTag != "" {
|
||||||
|
columnName = ExtractColumnFromGormTag(gormTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if we couldn't extract a column name
|
||||||
|
if columnName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
*columns = append(*columns, columnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// IsColumnWritable checks if a column can be written to in the database
|
// IsColumnWritable checks if a column can be written to in the database
|
||||||
// For bun: returns false if the field has "scanonly" tag
|
// For bun: returns false if the field has "scanonly" tag
|
||||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||||
|
// This function recursively searches embedded structs
|
||||||
func IsColumnWritable(model any, columnName string) bool {
|
func IsColumnWritable(model any, columnName string) bool {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
@@ -253,8 +463,37 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
found, writable := isColumnWritableInType(modelType, columnName)
|
||||||
field := modelType.Field(i)
|
if found {
|
||||||
|
return writable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column not found in model, allow it (might be a dynamic column)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||||
|
// Returns (found, writable) where found indicates if the column was found
|
||||||
|
func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
if found, writable := isColumnWritableInType(fieldType, columnName); found {
|
||||||
|
return true, writable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Check if this field matches the column name
|
// Check if this field matches the column name
|
||||||
fieldColumnName := getColumnNameFromField(field)
|
fieldColumnName := getColumnNameFromField(field)
|
||||||
@@ -262,11 +501,12 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Found the field, now check if it's writable
|
||||||
// Check bun tag for scanonly
|
// Check bun tag for scanonly
|
||||||
bunTag := field.Tag.Get("bun")
|
bunTag := field.Tag.Get("bun")
|
||||||
if bunTag != "" {
|
if bunTag != "" {
|
||||||
if isBunFieldScanOnly(bunTag) {
|
if isBunFieldScanOnly(bunTag) {
|
||||||
return false
|
return true, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -274,16 +514,16 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
gormTag := field.Tag.Get("gorm")
|
gormTag := field.Tag.Get("gorm")
|
||||||
if gormTag != "" {
|
if gormTag != "" {
|
||||||
if isGormFieldReadOnly(gormTag) {
|
if isGormFieldReadOnly(gormTag) {
|
||||||
return false
|
return true, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Column is writable
|
// Column is writable
|
||||||
return true
|
return true, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Column not found in model, allow it (might be a dynamic column)
|
// Column not found
|
||||||
return true
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
||||||
@@ -323,3 +563,290 @@ func isGormFieldReadOnly(tag string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||||
|
// Examples:
|
||||||
|
// - "columna->>'val'" returns "columna"
|
||||||
|
// - "columna->'key'" returns "columna"
|
||||||
|
// - "columna" returns "columna"
|
||||||
|
// - "table.columna->>'val'" returns "table.columna"
|
||||||
|
func ExtractSourceColumn(colName string) string {
|
||||||
|
// Check for PostgreSQL JSON operators: -> and ->>
|
||||||
|
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToSnakeCase converts a string from CamelCase to snake_case
|
||||||
|
func ToSnakeCase(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
for i, r := range s {
|
||||||
|
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||||
|
result.WriteRune('_')
|
||||||
|
}
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
return strings.ToLower(result.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||||
|
func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||||
|
if model == nil {
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the source column name (remove JSON operators like ->> or ->)
|
||||||
|
sourceColName := ExtractSourceColumn(colName)
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
// Dereference pointer if needed
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a struct
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field by JSON tag or field name
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Check JSON tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" {
|
||||||
|
// Parse JSON tag (format: "name,omitempty")
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if parts[0] == sourceColName {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check field name (case-insensitive)
|
||||||
|
if strings.EqualFold(field.Name, sourceColName) {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check snake_case conversion
|
||||||
|
snakeCaseName := ToSnakeCase(field.Name)
|
||||||
|
if snakeCaseName == sourceColName {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNumericType checks if a reflect.Kind is a numeric type
|
||||||
|
func IsNumericType(kind reflect.Kind) bool {
|
||||||
|
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
||||||
|
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
||||||
|
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
||||||
|
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsStringType checks if a reflect.Kind is a string type
|
||||||
|
func IsStringType(kind reflect.Kind) bool {
|
||||||
|
return kind == reflect.String
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNumericValue checks if a string value can be parsed as a number
|
||||||
|
func IsNumericValue(value string) bool {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
_, err := strconv.ParseFloat(value, 64)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertToNumericType converts a string value to the appropriate numeric type
|
||||||
|
func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
// Parse as integer
|
||||||
|
bitSize := 64
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int8:
|
||||||
|
bitSize = 8
|
||||||
|
case reflect.Int16:
|
||||||
|
bitSize = 16
|
||||||
|
case reflect.Int32:
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid integer value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the appropriate type
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int:
|
||||||
|
return int(intVal), nil
|
||||||
|
case reflect.Int8:
|
||||||
|
return int8(intVal), nil
|
||||||
|
case reflect.Int16:
|
||||||
|
return int16(intVal), nil
|
||||||
|
case reflect.Int32:
|
||||||
|
return int32(intVal), nil
|
||||||
|
case reflect.Int64:
|
||||||
|
return intVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
// Parse as unsigned integer
|
||||||
|
bitSize := 64
|
||||||
|
switch kind {
|
||||||
|
case reflect.Uint8:
|
||||||
|
bitSize = 8
|
||||||
|
case reflect.Uint16:
|
||||||
|
bitSize = 16
|
||||||
|
case reflect.Uint32:
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the appropriate type
|
||||||
|
switch kind {
|
||||||
|
case reflect.Uint:
|
||||||
|
return uint(uintVal), nil
|
||||||
|
case reflect.Uint8:
|
||||||
|
return uint8(uintVal), nil
|
||||||
|
case reflect.Uint16:
|
||||||
|
return uint16(uintVal), nil
|
||||||
|
case reflect.Uint32:
|
||||||
|
return uint32(uintVal), nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
return uintVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
// Parse as float
|
||||||
|
bitSize := 64
|
||||||
|
if kind == reflect.Float32 {
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
floatVal, err := strconv.ParseFloat(value, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid float value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind == reflect.Float32 {
|
||||||
|
return float32(floatVal), nil
|
||||||
|
}
|
||||||
|
return floatVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationModel gets the model type for a relation field
|
||||||
|
// It searches for the field by name in the following order (case-insensitive):
|
||||||
|
// 1. Actual field name
|
||||||
|
// 2. Bun tag name (if exists)
|
||||||
|
// 3. Gorm tag name (if exists)
|
||||||
|
// 4. JSON tag name (if exists)
|
||||||
|
func GetRelationModel(model interface{}, fieldName string) interface{} {
|
||||||
|
if model == nil || fieldName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field by checking in priority order (case-insensitive)
|
||||||
|
var field *reflect.StructField
|
||||||
|
normalizedFieldName := strings.ToLower(fieldName)
|
||||||
|
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
f := modelType.Field(i)
|
||||||
|
|
||||||
|
// 1. Check actual field name (case-insensitive)
|
||||||
|
if strings.EqualFold(f.Name, fieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check bun tag name
|
||||||
|
bunTag := f.Tag.Get("bun")
|
||||||
|
if bunTag != "" {
|
||||||
|
bunColName := ExtractColumnFromBunTag(bunTag)
|
||||||
|
if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Check gorm tag name
|
||||||
|
gormTag := f.Tag.Get("gorm")
|
||||||
|
if gormTag != "" {
|
||||||
|
gormColName := ExtractColumnFromGormTag(gormTag)
|
||||||
|
if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Check JSON tag name
|
||||||
|
jsonTag := f.Tag.Get("json")
|
||||||
|
if jsonTag != "" {
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
||||||
|
if strings.EqualFold(parts[0], normalizedFieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if field == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the target type
|
||||||
|
targetType := field.Type
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType.Kind() == reflect.Slice {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a zero value of the target type
|
||||||
|
return reflect.New(targetType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|||||||
@@ -231,3 +231,386 @@ func TestGetModelColumns(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test models with embedded structs
|
||||||
|
|
||||||
|
type BaseModel struct {
|
||||||
|
ID int `bun:"rid_base,pk" json:"id"`
|
||||||
|
CreatedAt string `bun:"created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
|
||||||
|
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelWithEmbedded struct {
|
||||||
|
BaseModel
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Description string `bun:"description" json:"description"`
|
||||||
|
AdhocBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormBaseModel struct {
|
||||||
|
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
|
||||||
|
CreatedAt string `gorm:"column:created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormAdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
|
||||||
|
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormModelWithEmbedded struct {
|
||||||
|
GormBaseModel
|
||||||
|
Name string `gorm:"column:name" json:"name"`
|
||||||
|
Description string `gorm:"column:description" json:"description"`
|
||||||
|
GormAdhocBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base (pointer)",
|
||||||
|
model: &ModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base (pointer)",
|
||||||
|
model: &GormModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPrimaryKeyName(tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
|
||||||
|
bunModel := ModelWithEmbedded{
|
||||||
|
BaseModel: BaseModel{
|
||||||
|
ID: 123,
|
||||||
|
CreatedAt: "2024-01-01",
|
||||||
|
},
|
||||||
|
Name: "Test",
|
||||||
|
Description: "Test Description",
|
||||||
|
}
|
||||||
|
|
||||||
|
gormModel := GormModelWithEmbedded{
|
||||||
|
GormBaseModel: GormBaseModel{
|
||||||
|
ID: 456,
|
||||||
|
CreatedAt: "2024-01-02",
|
||||||
|
},
|
||||||
|
Name: "GORM Test",
|
||||||
|
Description: "GORM Test Description",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base",
|
||||||
|
model: bunModel,
|
||||||
|
expected: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base (pointer)",
|
||||||
|
model: &bunModel,
|
||||||
|
expected: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base",
|
||||||
|
model: gormModel,
|
||||||
|
expected: 456,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base (pointer)",
|
||||||
|
model: &gormModel,
|
||||||
|
expected: 456,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPrimaryKeyValue(tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelColumnsWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded structs",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded structs (pointer)",
|
||||||
|
model: &ModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded structs",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded structs (pointer)",
|
||||||
|
model: &GormModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetModelColumns(tt.model)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expected[i] {
|
||||||
|
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
columnName string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model - writable column in main struct",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - writable column in embedded base",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "rid_base",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - scanonly column in embedded adhoc buffer",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "cql1",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - scanonly column _rownumber",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "_rownumber",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - writable column in main struct",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - writable column in embedded base",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "rid_base",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - readonly column in embedded adhoc buffer",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "cql1",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - readonly column _rownumber",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "_rownumber",
|
||||||
|
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := IsColumnWritable(tt.model, tt.columnName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test models with relations for GetSQLModelColumns
|
||||||
|
type User struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Email string `bun:"email" json:"email"`
|
||||||
|
ProfileData string `json:"profile_data"` // No bun/gorm tag
|
||||||
|
Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"`
|
||||||
|
Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"`
|
||||||
|
RowNumber int64 `bun:",scanonly" json:"_rownumber"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
Title string `gorm:"column:title" json:"title"`
|
||||||
|
UserID int `gorm:"column:user_id;foreignKey" json:"user_id"`
|
||||||
|
User *User `gorm:"foreignKey:UserID;references:ID" json:"user"`
|
||||||
|
Tags []Tag `gorm:"many2many:post_tags" json:"tags"`
|
||||||
|
Content string `json:"content"` // No bun/gorm tag
|
||||||
|
}
|
||||||
|
|
||||||
|
type Profile struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Bio string `bun:"bio" json:"bio"`
|
||||||
|
UserID int `bun:"user_id" json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tag struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
Name string `gorm:"column:name" json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model with scan-only embedded struct
|
||||||
|
type EntityWithScanOnlyEmbedded struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLModelColumns(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with relations - excludes relations and non-SQL fields",
|
||||||
|
model: User{},
|
||||||
|
// Should include: id, name, email (has bun tags)
|
||||||
|
// Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded)
|
||||||
|
expected: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with relations - excludes relations and non-SQL fields",
|
||||||
|
model: Post{},
|
||||||
|
// Should include: id, title, user_id (has gorm tags)
|
||||||
|
// Should exclude: content (no gorm tag), User/Tags (relations)
|
||||||
|
expected: []string{"id", "title", "user_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Model with embedded base and scan-only embedded",
|
||||||
|
model: EntityWithScanOnlyEmbedded{},
|
||||||
|
// Should include: id, name from main struct
|
||||||
|
// Should exclude: all fields from AdhocBuffer (scan-only embedded struct)
|
||||||
|
expected: []string{"id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Model with embedded - includes SQL fields, excludes scan-only",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
// Should include: rid_base, created_at (from BaseModel), name, description (from main)
|
||||||
|
// Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields)
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded - includes SQL fields, excludes scan-only",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
// Should include: rid_base, created_at (from GormBaseModel), name, description (from main)
|
||||||
|
// Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -)
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple Profile model",
|
||||||
|
model: Profile{},
|
||||||
|
// Should include all fields with bun tags
|
||||||
|
expected: []string{"id", "bio", "user_id"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetSQLModelColumns(tt.model)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v",
|
||||||
|
len(result), len(tt.expected), result, tt.expected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expected[i] {
|
||||||
|
t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v",
|
||||||
|
i, col, tt.expected[i], result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) {
|
||||||
|
// Demonstrate the difference between GetModelColumns and GetSQLModelColumns
|
||||||
|
user := User{}
|
||||||
|
|
||||||
|
allColumns := GetModelColumns(user)
|
||||||
|
sqlColumns := GetSQLModelColumns(user)
|
||||||
|
|
||||||
|
t.Logf("GetModelColumns(User): %v", allColumns)
|
||||||
|
t.Logf("GetSQLModelColumns(User): %v", sqlColumns)
|
||||||
|
|
||||||
|
// GetModelColumns should return more columns (includes fields with only json tags)
|
||||||
|
if len(allColumns) <= len(sqlColumns) {
|
||||||
|
t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSQLModelColumns should not include 'profile_data' (no bun tag)
|
||||||
|
for _, col := range sqlColumns {
|
||||||
|
if col == "profile_data" {
|
||||||
|
t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelColumns should include 'profile_data' (has json tag)
|
||||||
|
hasProfileData := false
|
||||||
|
for _, col := range allColumns {
|
||||||
|
if col == "profile_data" {
|
||||||
|
hasProfileData = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasProfileData {
|
||||||
|
t.Errorf("GetModelColumns should include 'profile_data' (has json tag)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -191,6 +191,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Table(tableName)
|
query = query.Table(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
|
||||||
|
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||||
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
logger.Debug("Selecting columns: %v", options.Columns)
|
logger.Debug("Selecting columns: %v", options.Columns)
|
||||||
@@ -1132,15 +1137,25 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
||||||
relationFieldName := relInfo.fieldName
|
relationFieldName := relInfo.fieldName
|
||||||
|
|
||||||
// For now, we'll preload without conditions
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
// TODO: Implement column selection and filtering for preloads
|
if len(preload.Where) > 0 {
|
||||||
// This requires a more sophisticated approach with callbacks or query builders
|
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
|
||||||
// Apply preloading
|
if err != nil {
|
||||||
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
|
||||||
|
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
|
||||||
|
}
|
||||||
|
preload.Where = fixedWhere
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debug("Applying preload: %s", relationFieldName)
|
logger.Debug("Applying preload: %s", relationFieldName)
|
||||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
|
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
|
||||||
|
preload.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle column selection and omission
|
||||||
if len(preload.OmitColumns) > 0 {
|
if len(preload.OmitColumns) > 0 {
|
||||||
allCols := reflection.GetModelColumns(model)
|
allCols := reflection.GetSQLModelColumns(model)
|
||||||
// Remove omitted columns
|
// Remove omitted columns
|
||||||
preload.Columns = []string{}
|
preload.Columns = []string{}
|
||||||
for _, col := range allCols {
|
for _, col := range allCols {
|
||||||
@@ -1194,7 +1209,10 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sq = sq.Where(preload.Where)
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
|
||||||
|
if len(sanitizedWhere) > 0 {
|
||||||
|
sq = sq.Where(sanitizedWhere)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if preload.Limit != nil && *preload.Limit > 0 {
|
if preload.Limit != nil && *preload.Limit > 0 {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ const (
|
|||||||
contextKeyTableName contextKey = "tableName"
|
contextKeyTableName contextKey = "tableName"
|
||||||
contextKeyModel contextKey = "model"
|
contextKeyModel contextKey = "model"
|
||||||
contextKeyModelPtr contextKey = "modelPtr"
|
contextKeyModelPtr contextKey = "modelPtr"
|
||||||
|
contextKeyOptions contextKey = "options"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WithSchema adds schema to context
|
// WithSchema adds schema to context
|
||||||
@@ -74,12 +75,28 @@ func GetModelPtr(ctx context.Context) interface{} {
|
|||||||
return ctx.Value(contextKeyModelPtr)
|
return ctx.Value(contextKeyModelPtr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithOptions adds request options to context
|
||||||
|
func WithOptions(ctx context.Context, options ExtendedRequestOptions) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyOptions, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOptions retrieves request options from context
|
||||||
|
func GetOptions(ctx context.Context) *ExtendedRequestOptions {
|
||||||
|
if v := ctx.Value(contextKeyOptions); v != nil {
|
||||||
|
if opts, ok := v.(ExtendedRequestOptions); ok {
|
||||||
|
return &opts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// WithRequestData adds all request-scoped data to context at once
|
// WithRequestData adds all request-scoped data to context at once
|
||||||
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}, options ExtendedRequestOptions) context.Context {
|
||||||
ctx = WithSchema(ctx, schema)
|
ctx = WithSchema(ctx, schema)
|
||||||
ctx = WithEntity(ctx, entity)
|
ctx = WithEntity(ctx, entity)
|
||||||
ctx = WithTableName(ctx, tableName)
|
ctx = WithTableName(ctx, tableName)
|
||||||
ctx = WithModel(ctx, model)
|
ctx = WithModel(ctx, model)
|
||||||
ctx = WithModelPtr(ctx, modelPtr)
|
ctx = WithModelPtr(ctx, modelPtr)
|
||||||
|
ctx = WithOptions(ctx, options)
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,9 +65,6 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
entity := params["entity"]
|
entity := params["entity"]
|
||||||
id := params["id"]
|
id := params["id"]
|
||||||
|
|
||||||
// Parse options from headers (now returns ExtendedRequestOptions)
|
|
||||||
options := h.parseOptionsFromHeaders(r)
|
|
||||||
|
|
||||||
// Determine operation based on HTTP method
|
// Determine operation based on HTTP method
|
||||||
method := r.Method()
|
method := r.Method()
|
||||||
|
|
||||||
@@ -104,13 +101,16 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
|
||||||
// Add request-scoped data to context
|
// Parse options from headers - this now includes relation name resolution
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
options := h.parseOptionsFromHeaders(r, model)
|
||||||
|
|
||||||
// Validate and filter columns in options (log warnings for invalid columns)
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
validator := common.NewColumnValidator(model)
|
validator := common.NewColumnValidator(model)
|
||||||
options = filterExtendedOptions(validator, options)
|
options = filterExtendedOptions(validator, options)
|
||||||
|
|
||||||
|
// Add request-scoped data to context (including options)
|
||||||
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
if id != "" {
|
if id != "" {
|
||||||
@@ -260,6 +260,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Table(tableName)
|
query = query.Table(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we have computed columns/expressions but options.Columns is empty,
|
||||||
|
// populate it with all model columns first since computed columns are additions
|
||||||
|
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
|
||||||
|
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||||
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply ComputedQL fields if any
|
// Apply ComputedQL fields if any
|
||||||
if len(options.ComputedQL) > 0 {
|
if len(options.ComputedQL) > 0 {
|
||||||
for colName, colExpr := range options.ComputedQL {
|
for colName, colExpr := range options.ComputedQL {
|
||||||
@@ -340,50 +347,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
for idx := range options.Preload {
|
for idx := range options.Preload {
|
||||||
preload := options.Preload[idx]
|
preload := options.Preload[idx]
|
||||||
logger.Debug("Applying preload: %s", preload.Relation)
|
logger.Debug("Applying preload: %s", preload.Relation)
|
||||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
|
||||||
if len(preload.OmitColumns) > 0 {
|
|
||||||
allCols := reflection.GetModelColumns(model)
|
|
||||||
// Remove omitted columns
|
|
||||||
preload.Columns = []string{}
|
|
||||||
for _, col := range allCols {
|
|
||||||
addCols := true
|
|
||||||
for _, omitCol := range preload.OmitColumns {
|
|
||||||
if col == omitCol {
|
|
||||||
addCols = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if addCols {
|
|
||||||
preload.Columns = append(preload.Columns, col)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(preload.Columns) > 0 {
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
sq = sq.Column(preload.Columns...)
|
if len(preload.Where) > 0 {
|
||||||
|
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
|
||||||
|
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
preload.Where = fixedWhere
|
||||||
|
}
|
||||||
|
|
||||||
if len(preload.Filters) > 0 {
|
// Apply the preload with recursive support
|
||||||
for _, filter := range preload.Filters {
|
query = h.applyPreloadWithRecursion(query, preload, model, 0)
|
||||||
sq = h.applyFilter(sq, filter, "", false, "AND")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(preload.Sort) > 0 {
|
|
||||||
for _, sort := range preload.Sort {
|
|
||||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(preload.Where) > 0 {
|
|
||||||
sq = sq.Where(preload.Where)
|
|
||||||
}
|
|
||||||
|
|
||||||
if preload.Limit != nil && *preload.Limit > 0 {
|
|
||||||
sq = sq.Limit(*preload.Limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
return sq
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply DISTINCT if requested
|
// Apply DISTINCT if requested
|
||||||
@@ -413,13 +391,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (AND condition)
|
// Apply custom SQL WHERE clause (AND condition)
|
||||||
if options.CustomSQLWhere != "" {
|
if options.CustomSQLWhere != "" {
|
||||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||||
query = query.Where(options.CustomSQLWhere)
|
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "")
|
||||||
|
if sanitizedWhere != "" {
|
||||||
|
query = query.Where(sanitizedWhere)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply custom SQL WHERE clause (OR condition)
|
// Apply custom SQL WHERE clause (OR condition)
|
||||||
if options.CustomSQLOr != "" {
|
if options.CustomSQLOr != "" {
|
||||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||||
query = query.WhereOr(options.CustomSQLOr)
|
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||||
|
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "")
|
||||||
|
if sanitizedOr != "" {
|
||||||
|
query = query.WhereOr(sanitizedOr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If ID is provided, filter by ID
|
// If ID is provided, filter by ID
|
||||||
@@ -495,7 +481,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply cursor filter to query
|
// Apply cursor filter to query
|
||||||
if cursorFilter != "" {
|
if cursorFilter != "" {
|
||||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||||
query = query.Where(cursorFilter)
|
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "")
|
||||||
|
if sanitizedCursor != "" {
|
||||||
|
query = query.Where(sanitizedCursor)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -569,6 +558,120 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||||
|
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||||
|
// Apply the preload
|
||||||
|
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
|
// Get the related model for column operations
|
||||||
|
relationParts := strings.Split(preload.Relation, ",")
|
||||||
|
relatedModel := reflection.GetRelationModel(model, relationParts[0])
|
||||||
|
if relatedModel == nil {
|
||||||
|
logger.Warn("Could not get related model for preload: %s", preload.Relation)
|
||||||
|
// relatedModel = model // fallback to parent model
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// If we have computed columns but no explicit columns, populate with all model columns first
|
||||||
|
// since computed columns are additions
|
||||||
|
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
|
||||||
|
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
|
||||||
|
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply ComputedQL fields if any
|
||||||
|
if len(preload.ComputedQL) > 0 {
|
||||||
|
for colName, colExpr := range preload.ComputedQL {
|
||||||
|
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
||||||
|
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
||||||
|
// Remove the computed column from selected columns to avoid duplication
|
||||||
|
for colIndex := range preload.Columns {
|
||||||
|
if preload.Columns[colIndex] == colName {
|
||||||
|
preload.Columns = append(preload.Columns[:colIndex], preload.Columns[colIndex+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle OmitColumns
|
||||||
|
if len(preload.OmitColumns) > 0 {
|
||||||
|
allCols := preload.Columns
|
||||||
|
// Remove omitted columns
|
||||||
|
preload.Columns = []string{}
|
||||||
|
for _, col := range allCols {
|
||||||
|
addCols := true
|
||||||
|
for _, omitCol := range preload.OmitColumns {
|
||||||
|
if col == omitCol {
|
||||||
|
addCols = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if addCols {
|
||||||
|
preload.Columns = append(preload.Columns, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply column selection
|
||||||
|
if len(preload.Columns) > 0 {
|
||||||
|
sq = sq.Column(preload.Columns...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply filters
|
||||||
|
if len(preload.Filters) > 0 {
|
||||||
|
for _, filter := range preload.Filters {
|
||||||
|
sq = h.applyFilter(sq, filter, "", false, "AND")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply sorting
|
||||||
|
if len(preload.Sort) > 0 {
|
||||||
|
for _, sort := range preload.Sort {
|
||||||
|
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply WHERE clause
|
||||||
|
if len(preload.Where) > 0 {
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
|
||||||
|
if len(sanitizedWhere) > 0 {
|
||||||
|
sq = sq.Where(sanitizedWhere)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply limit
|
||||||
|
if preload.Limit != nil && *preload.Limit > 0 {
|
||||||
|
sq = sq.Limit(*preload.Limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if preload.Offset != nil && *preload.Offset > 0 {
|
||||||
|
sq = sq.Offset(*preload.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sq
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle recursive preloading
|
||||||
|
if preload.Recursive && depth < 5 {
|
||||||
|
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
|
||||||
|
|
||||||
|
// For recursive relationships, we need to get the last part of the relation path
|
||||||
|
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
|
||||||
|
relationParts := strings.Split(preload.Relation, ".")
|
||||||
|
lastRelationName := relationParts[len(relationParts)-1]
|
||||||
|
|
||||||
|
// Create a recursive preload with the same configuration
|
||||||
|
// but with the relation path extended
|
||||||
|
recursivePreload := preload
|
||||||
|
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
||||||
|
|
||||||
|
// Recursively apply preload until we reach depth 5
|
||||||
|
query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -610,6 +713,9 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
dataSlice := h.normalizeToSlice(data)
|
dataSlice := h.normalizeToSlice(data)
|
||||||
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
|
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
|
||||||
|
|
||||||
|
// Store original data maps for merging later
|
||||||
|
originalDataMaps := make([]map[string]interface{}, 0, len(dataSlice))
|
||||||
|
|
||||||
// Process all items in a transaction
|
// Process all items in a transaction
|
||||||
results := make([]interface{}, 0, len(dataSlice))
|
results := make([]interface{}, 0, len(dataSlice))
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
@@ -630,6 +736,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store a copy of the original data map for merging later
|
||||||
|
originalMap := make(map[string]interface{})
|
||||||
|
for k, v := range itemMap {
|
||||||
|
originalMap[k] = v
|
||||||
|
}
|
||||||
|
originalDataMaps = append(originalDataMaps, originalMap)
|
||||||
|
|
||||||
// Extract nested relations if present (but don't process them yet)
|
// Extract nested relations if present (but don't process them yet)
|
||||||
var nestedRelations map[string]interface{}
|
var nestedRelations map[string]interface{}
|
||||||
if h.shouldUseNestedProcessor(itemMap, model) {
|
if h.shouldUseNestedProcessor(itemMap, model) {
|
||||||
@@ -653,7 +766,14 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create insert query
|
// Create insert query
|
||||||
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
|
query := tx.NewInsert().Model(modelValue)
|
||||||
|
|
||||||
|
// Only set Table() if the model doesn't provide a table name via TableNameProvider
|
||||||
|
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
|
query = query.Table(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
query = query.Returning("*")
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
itemHookCtx := &HookContext{
|
itemHookCtx := &HookContext{
|
||||||
@@ -704,14 +824,26 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Merge created records with original request data
|
||||||
|
// This preserves extra keys from the request
|
||||||
|
mergedResults := make([]interface{}, 0, len(results))
|
||||||
|
for i, result := range results {
|
||||||
|
if i < len(originalDataMaps) {
|
||||||
|
merged := h.mergeRecordWithRequest(result, originalDataMaps[i])
|
||||||
|
mergedResults = append(mergedResults, merged)
|
||||||
|
} else {
|
||||||
|
mergedResults = append(mergedResults, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Execute AfterCreate hooks
|
// Execute AfterCreate hooks
|
||||||
var responseData interface{}
|
var responseData interface{}
|
||||||
if len(results) == 1 {
|
if len(mergedResults) == 1 {
|
||||||
responseData = results[0]
|
responseData = mergedResults[0]
|
||||||
hookCtx.Result = results[0]
|
hookCtx.Result = mergedResults[0]
|
||||||
} else {
|
} else {
|
||||||
responseData = results
|
responseData = mergedResults
|
||||||
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
|
hookCtx.Result = map[string]interface{}{"created": len(mergedResults), "data": mergedResults}
|
||||||
}
|
}
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
|
|
||||||
@@ -721,7 +853,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully created %d record(s)", len(results))
|
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
||||||
h.sendResponseWithOptions(w, responseData, nil, &options)
|
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -790,6 +922,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the primary key name for the model
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
// Variable to store the updated record
|
||||||
|
var updatedRecord interface{}
|
||||||
|
|
||||||
// Process nested relations if present
|
// Process nested relations if present
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
// Create temporary nested processor with transaction
|
// Create temporary nested processor with transaction
|
||||||
@@ -808,11 +946,10 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ensure ID is in the data map for the update
|
// Ensure ID is in the data map for the update
|
||||||
dataMap["id"] = targetID
|
dataMap[pkName] = targetID
|
||||||
|
|
||||||
// Create update query
|
// Create update query
|
||||||
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
||||||
pkName := reflection.GetPrimaryKeyName(model)
|
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
@@ -840,10 +977,18 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store result for hooks
|
// Fetch the updated record to return the new values
|
||||||
hookCtx.Result = map[string]interface{}{
|
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
"updated": result.RowsAffected(),
|
selectQuery := tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch updated record: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updatedRecord = modelValue
|
||||||
|
|
||||||
|
// Store result for hooks
|
||||||
|
hookCtx.Result = updatedRecord
|
||||||
|
_ = result // Keep result variable for potential future use
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -853,7 +998,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Merge the updated record with the original request data
|
||||||
|
// This preserves extra keys from the request and updates values from the database
|
||||||
|
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
||||||
|
|
||||||
// Execute AfterUpdate hooks
|
// Execute AfterUpdate hooks
|
||||||
|
hookCtx.Result = mergedData
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
logger.Error("AfterUpdate hook failed: %v", err)
|
logger.Error("AfterUpdate hook failed: %v", err)
|
||||||
@@ -862,7 +1012,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated record with ID: %v", targetID)
|
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||||
h.sendResponseWithOptions(w, hookCtx.Result, nil, &options)
|
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
||||||
@@ -936,6 +1086,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Array of IDs or objects with ID field
|
// Array of IDs or objects with ID field
|
||||||
logger.Info("Batch delete with %d items ([]interface{})", len(v))
|
logger.Info("Batch delete with %d items ([]interface{})", len(v))
|
||||||
deletedCount := 0
|
deletedCount := 0
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
var itemID interface{}
|
var itemID interface{}
|
||||||
@@ -945,7 +1096,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
case string:
|
case string:
|
||||||
itemID = v
|
itemID = v
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
itemID = v["id"]
|
itemID = v[pkName]
|
||||||
default:
|
default:
|
||||||
itemID = item
|
itemID = item
|
||||||
}
|
}
|
||||||
@@ -1002,9 +1153,10 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Array of objects with id field
|
// Array of objects with id field
|
||||||
logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
|
logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
|
||||||
deletedCount := 0
|
deletedCount := 0
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if itemID, ok := item["id"]; ok && itemID != nil {
|
if itemID, ok := item[pkName]; ok && itemID != nil {
|
||||||
itemIDStr := fmt.Sprintf("%v", itemID)
|
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||||
|
|
||||||
// Execute hooks for each item
|
// Execute hooks for each item
|
||||||
@@ -1052,7 +1204,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
// Single object with id field
|
// Single object with id field
|
||||||
if itemID, ok := v["id"]; ok && itemID != nil {
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
if itemID, ok := v[pkName]; ok && itemID != nil {
|
||||||
id = fmt.Sprintf("%v", itemID)
|
id = fmt.Sprintf("%v", itemID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1122,6 +1275,39 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
h.sendResponse(w, responseData, nil)
|
h.sendResponse(w, responseData, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeRecordWithRequest merges a database record with the original request data
|
||||||
|
// This preserves extra keys from the request that aren't in the database model
|
||||||
|
// and updates values from the database (e.g., from SQL triggers or defaults)
|
||||||
|
func (h *Handler) mergeRecordWithRequest(dbRecord interface{}, requestData map[string]interface{}) map[string]interface{} {
|
||||||
|
// Convert the database record to a map
|
||||||
|
dbMap := make(map[string]interface{})
|
||||||
|
|
||||||
|
// Marshal and unmarshal to convert struct to map
|
||||||
|
jsonData, err := json.Marshal(dbRecord)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to marshal database record for merging: %v", err)
|
||||||
|
return requestData
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||||
|
logger.Warn("Failed to unmarshal database record for merging: %v", err)
|
||||||
|
return requestData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start with the request data (preserves extra keys)
|
||||||
|
result := make(map[string]interface{})
|
||||||
|
for k, v := range requestData {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update with values from database (overwrites with DB values, including trigger changes)
|
||||||
|
for k, v := range dbMap {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
|
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
|
||||||
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
@@ -1146,7 +1332,7 @@ func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
|||||||
func (h *Handler) extractNestedRelations(
|
func (h *Handler) extractNestedRelations(
|
||||||
data map[string]interface{},
|
data map[string]interface{},
|
||||||
model interface{},
|
model interface{},
|
||||||
) (map[string]interface{}, map[string]interface{}, error) {
|
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
|
||||||
// Get model type for reflection
|
// Get model type for reflection
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
@@ -1564,13 +1750,9 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
data = h.normalizeResultArray(data)
|
data = h.normalizeResultArray(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
response := common.Response{
|
// Return data as-is without wrapping in common.Response
|
||||||
Success: true,
|
|
||||||
Data: data,
|
|
||||||
Metadata: metadata,
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if err := w.WriteJSON(data); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1579,7 +1761,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||||
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
return data
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use reflection to check if data is a slice or array
|
// Use reflection to check if data is a slice or array
|
||||||
@@ -1658,22 +1840,22 @@ func (h *Handler) cleanJSON(data interface{}) interface{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) {
|
func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) {
|
||||||
var details string
|
var errorMsg string
|
||||||
if err != nil {
|
if err != nil {
|
||||||
details = err.Error()
|
errorMsg = err.Error()
|
||||||
|
} else if message != "" {
|
||||||
|
errorMsg = message
|
||||||
|
} else {
|
||||||
|
errorMsg = code
|
||||||
}
|
}
|
||||||
|
|
||||||
response := common.Response{
|
response := map[string]interface{}{
|
||||||
Success: false,
|
"_error": errorMsg,
|
||||||
Error: &common.APIError{
|
"_retval": 1,
|
||||||
Code: code,
|
|
||||||
Message: message,
|
|
||||||
Details: details,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
w.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if jsonErr := w.WriteJSON(response); jsonErr != nil {
|
||||||
logger.Error("Failed to write JSON error response: %v", err)
|
logger.Error("Failed to write JSON error response: %v", jsonErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExtendedRequestOptions extends common.RequestOptions with additional features
|
// ExtendedRequestOptions extends common.RequestOptions with additional features
|
||||||
@@ -42,6 +44,9 @@ type ExtendedRequestOptions struct {
|
|||||||
|
|
||||||
// Transaction
|
// Transaction
|
||||||
AtomicTransaction bool
|
AtomicTransaction bool
|
||||||
|
|
||||||
|
// X-Files configuration - comprehensive query options as a single JSON object
|
||||||
|
XFiles *XFiles
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpandOption represents a relation expansion configuration
|
// ExpandOption represents a relation expansion configuration
|
||||||
@@ -95,7 +100,8 @@ func DecodeParam(pStr string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseOptionsFromHeaders parses all request options from HTTP headers
|
// parseOptionsFromHeaders parses all request options from HTTP headers
|
||||||
func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptions {
|
// If model is provided, it will resolve table names to field names in preload/expand options
|
||||||
|
func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) ExtendedRequestOptions {
|
||||||
options := ExtendedRequestOptions{
|
options := ExtendedRequestOptions{
|
||||||
RequestOptions: common.RequestOptions{
|
RequestOptions: common.RequestOptions{
|
||||||
Filters: make([]common.FilterOption, 0),
|
Filters: make([]common.FilterOption, 0),
|
||||||
@@ -105,105 +111,140 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
|||||||
AdvancedSQL: make(map[string]string),
|
AdvancedSQL: make(map[string]string),
|
||||||
ComputedQL: make(map[string]string),
|
ComputedQL: make(map[string]string),
|
||||||
Expand: make([]ExpandOption, 0),
|
Expand: make([]ExpandOption, 0),
|
||||||
ResponseFormat: "simple", // Default response format
|
ResponseFormat: "simple", // Default response format
|
||||||
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all headers
|
// Get all headers
|
||||||
headers := r.AllHeaders()
|
headers := r.AllHeaders()
|
||||||
|
|
||||||
// Process each header
|
// Get all query parameters
|
||||||
for key, value := range headers {
|
queryParams := r.AllQueryParams()
|
||||||
// Normalize header key to lowercase for consistent matching
|
|
||||||
normalizedKey := strings.ToLower(key)
|
|
||||||
|
|
||||||
|
// Merge headers and query parameters - query parameters take precedence
|
||||||
|
// This allows the same parameters to be specified in either headers or query string
|
||||||
|
// Normalize keys to lowercase to ensure query params properly override headers
|
||||||
|
combinedParams := make(map[string]string)
|
||||||
|
for key, value := range headers {
|
||||||
|
combinedParams[strings.ToLower(key)] = value
|
||||||
|
}
|
||||||
|
for key, value := range queryParams {
|
||||||
|
combinedParams[strings.ToLower(key)] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each parameter (from both headers and query params)
|
||||||
|
// Note: keys are already normalized to lowercase in combinedParams
|
||||||
|
for key, value := range combinedParams {
|
||||||
// Decode value if it's base64 encoded
|
// Decode value if it's base64 encoded
|
||||||
decodedValue := decodeHeaderValue(value)
|
decodedValue := decodeHeaderValue(value)
|
||||||
|
|
||||||
// Parse based on header prefix/name
|
// Parse based on parameter prefix/name
|
||||||
switch {
|
switch {
|
||||||
// Field Selection
|
// Field Selection
|
||||||
case strings.HasPrefix(normalizedKey, "x-select-fields"):
|
case strings.HasPrefix(key, "x-select-fields"):
|
||||||
h.parseSelectFields(&options, decodedValue)
|
h.parseSelectFields(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-not-select-fields"):
|
case strings.HasPrefix(key, "x-not-select-fields"):
|
||||||
h.parseNotSelectFields(&options, decodedValue)
|
h.parseNotSelectFields(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-clean-json"):
|
case strings.HasPrefix(key, "x-clean-json"):
|
||||||
options.CleanJSON = strings.EqualFold(decodedValue, "true")
|
options.CleanJSON = strings.EqualFold(decodedValue, "true")
|
||||||
|
|
||||||
// Filtering & Search
|
// Filtering & Search
|
||||||
case strings.HasPrefix(normalizedKey, "x-fieldfilter-"):
|
case strings.HasPrefix(key, "x-fieldfilter-"):
|
||||||
h.parseFieldFilter(&options, normalizedKey, decodedValue)
|
h.parseFieldFilter(&options, key, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchfilter-"):
|
case strings.HasPrefix(key, "x-searchfilter-"):
|
||||||
h.parseSearchFilter(&options, normalizedKey, decodedValue)
|
h.parseSearchFilter(&options, key, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchop-"):
|
case strings.HasPrefix(key, "x-searchop-"):
|
||||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
|
h.parseSearchOp(&options, key, decodedValue, "AND")
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchor-"):
|
case strings.HasPrefix(key, "x-searchor-"):
|
||||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "OR")
|
h.parseSearchOp(&options, key, decodedValue, "OR")
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchand-"):
|
case strings.HasPrefix(key, "x-searchand-"):
|
||||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
|
h.parseSearchOp(&options, key, decodedValue, "AND")
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchcols"):
|
case strings.HasPrefix(key, "x-searchcols"):
|
||||||
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-w"):
|
case strings.HasPrefix(key, "x-custom-sql-w"):
|
||||||
options.CustomSQLWhere = decodedValue
|
options.CustomSQLWhere = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-or"):
|
case strings.HasPrefix(key, "x-custom-sql-or"):
|
||||||
options.CustomSQLOr = decodedValue
|
options.CustomSQLOr = decodedValue
|
||||||
|
|
||||||
// Joins & Relations
|
// Joins & Relations
|
||||||
case strings.HasPrefix(normalizedKey, "x-preload"):
|
case strings.HasPrefix(key, "x-preload"):
|
||||||
if strings.HasSuffix(normalizedKey, "-where") {
|
if strings.HasSuffix(key, "-where") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
whereClaude := headers[fmt.Sprintf("%s-where", key)]
|
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
|
||||||
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
||||||
|
|
||||||
case strings.HasPrefix(normalizedKey, "x-expand"):
|
case strings.HasPrefix(key, "x-expand"):
|
||||||
h.parseExpand(&options, decodedValue)
|
h.parseExpand(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
|
case strings.HasPrefix(key, "x-custom-sql-join"):
|
||||||
// TODO: Implement custom SQL join
|
// TODO: Implement custom SQL join
|
||||||
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
||||||
|
|
||||||
// Sorting & Pagination
|
// Sorting & Pagination
|
||||||
case strings.HasPrefix(normalizedKey, "x-sort"):
|
case strings.HasPrefix(key, "x-sort"):
|
||||||
h.parseSorting(&options, decodedValue)
|
h.parseSorting(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-limit"):
|
// Special cases for older clients using sort(a,b,-c) syntax
|
||||||
|
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
|
||||||
|
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||||
|
h.parseSorting(&options, sortValue)
|
||||||
|
case strings.HasPrefix(key, "x-limit"):
|
||||||
if limit, err := strconv.Atoi(decodedValue); err == nil {
|
if limit, err := strconv.Atoi(decodedValue); err == nil {
|
||||||
options.Limit = &limit
|
options.Limit = &limit
|
||||||
}
|
}
|
||||||
case strings.HasPrefix(normalizedKey, "x-offset"):
|
// Special cases for older clients using limit(n) syntax
|
||||||
|
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
|
||||||
|
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||||
|
limitValueParts := strings.Split(limitValue, ",")
|
||||||
|
|
||||||
|
if len(limitValueParts) > 1 {
|
||||||
|
if offset, err := strconv.Atoi(limitValueParts[0]); err == nil {
|
||||||
|
options.Offset = &offset
|
||||||
|
}
|
||||||
|
if limit, err := strconv.Atoi(limitValueParts[1]); err == nil {
|
||||||
|
options.Limit = &limit
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if limit, err := strconv.Atoi(limitValueParts[0]); err == nil {
|
||||||
|
options.Limit = &limit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case strings.HasPrefix(key, "x-offset"):
|
||||||
if offset, err := strconv.Atoi(decodedValue); err == nil {
|
if offset, err := strconv.Atoi(decodedValue); err == nil {
|
||||||
options.Offset = &offset
|
options.Offset = &offset
|
||||||
}
|
}
|
||||||
case strings.HasPrefix(normalizedKey, "x-cursor-forward"):
|
|
||||||
|
case strings.HasPrefix(key, "x-cursor-forward"):
|
||||||
options.CursorForward = decodedValue
|
options.CursorForward = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-cursor-backward"):
|
case strings.HasPrefix(key, "x-cursor-backward"):
|
||||||
options.CursorBackward = decodedValue
|
options.CursorBackward = decodedValue
|
||||||
|
|
||||||
// Advanced Features
|
// Advanced Features
|
||||||
case strings.HasPrefix(normalizedKey, "x-advsql-"):
|
case strings.HasPrefix(key, "x-advsql-"):
|
||||||
colName := strings.TrimPrefix(normalizedKey, "x-advsql-")
|
colName := strings.TrimPrefix(key, "x-advsql-")
|
||||||
options.AdvancedSQL[colName] = decodedValue
|
options.AdvancedSQL[colName] = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-cql-sel-"):
|
case strings.HasPrefix(key, "x-cql-sel-"):
|
||||||
colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-")
|
colName := strings.TrimPrefix(key, "x-cql-sel-")
|
||||||
options.ComputedQL[colName] = decodedValue
|
options.ComputedQL[colName] = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-distinct"):
|
case strings.HasPrefix(key, "x-distinct"):
|
||||||
options.Distinct = strings.EqualFold(decodedValue, "true")
|
options.Distinct = strings.EqualFold(decodedValue, "true")
|
||||||
case strings.HasPrefix(normalizedKey, "x-skipcount"):
|
case strings.HasPrefix(key, "x-skipcount"):
|
||||||
options.SkipCount = strings.EqualFold(decodedValue, "true")
|
options.SkipCount = strings.EqualFold(decodedValue, "true")
|
||||||
case strings.HasPrefix(normalizedKey, "x-skipcache"):
|
case strings.HasPrefix(key, "x-skipcache"):
|
||||||
options.SkipCache = strings.EqualFold(decodedValue, "true")
|
options.SkipCache = strings.EqualFold(decodedValue, "true")
|
||||||
case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"):
|
case strings.HasPrefix(key, "x-fetch-rownumber"):
|
||||||
options.FetchRowNumber = &decodedValue
|
options.FetchRowNumber = &decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-pkrow"):
|
case strings.HasPrefix(key, "x-pkrow"):
|
||||||
options.PKRow = &decodedValue
|
options.PKRow = &decodedValue
|
||||||
|
|
||||||
// Response Format
|
// Response Format
|
||||||
case strings.HasPrefix(normalizedKey, "x-simpleapi"):
|
case strings.HasPrefix(key, "x-simpleapi"):
|
||||||
options.ResponseFormat = "simple"
|
options.ResponseFormat = "simple"
|
||||||
case strings.HasPrefix(normalizedKey, "x-detailapi"):
|
case strings.HasPrefix(key, "x-detailapi"):
|
||||||
options.ResponseFormat = "detail"
|
options.ResponseFormat = "detail"
|
||||||
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
|
case strings.HasPrefix(key, "x-syncfusion"):
|
||||||
options.ResponseFormat = "syncfusion"
|
options.ResponseFormat = "syncfusion"
|
||||||
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
|
case strings.HasPrefix(key, "x-single-record-as-object"):
|
||||||
// Parse as boolean - "false" disables, "true" enables (default is true)
|
// Parse as boolean - "false" disables, "true" enables (default is true)
|
||||||
if strings.EqualFold(decodedValue, "false") {
|
if strings.EqualFold(decodedValue, "false") {
|
||||||
options.SingleRecordAsObject = false
|
options.SingleRecordAsObject = false
|
||||||
@@ -212,11 +253,20 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Transaction Control
|
// Transaction Control
|
||||||
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
|
case strings.HasPrefix(key, "x-transaction-atomic"):
|
||||||
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
|
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
|
||||||
|
|
||||||
|
// X-Files - comprehensive JSON configuration
|
||||||
|
case strings.HasPrefix(key, "x-files"):
|
||||||
|
h.parseXFiles(&options, decodedValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Resolve relation names (convert table names to field names) if model is provided
|
||||||
|
if model != nil {
|
||||||
|
h.resolveRelationNamesInOptions(&options, model)
|
||||||
|
}
|
||||||
|
|
||||||
return options
|
return options
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -480,170 +530,405 @@ func (h *Handler) parseCommaSeparated(value string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
// parseXFiles parses x-files header containing comprehensive JSON configuration
|
||||||
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
// and populates ExtendedRequestOptions fields from it
|
||||||
if model == nil {
|
func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
|
||||||
return reflect.Invalid
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var xfiles XFiles
|
||||||
|
if err := json.Unmarshal([]byte(value), &xfiles); err != nil {
|
||||||
|
logger.Warn("Failed to parse x-files header: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Parsed x-files configuration for table: %s", xfiles.TableName)
|
||||||
|
|
||||||
|
// Store the original XFiles for reference
|
||||||
|
options.XFiles = &xfiles
|
||||||
|
|
||||||
|
// Map XFiles fields to ExtendedRequestOptions
|
||||||
|
|
||||||
|
// Column selection
|
||||||
|
if len(xfiles.Columns) > 0 {
|
||||||
|
options.Columns = append(options.Columns, xfiles.Columns...)
|
||||||
|
logger.Debug("X-Files: Added columns: %v", xfiles.Columns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Omit columns
|
||||||
|
if len(xfiles.OmitColumns) > 0 {
|
||||||
|
options.OmitColumns = append(options.OmitColumns, xfiles.OmitColumns...)
|
||||||
|
logger.Debug("X-Files: Added omit columns: %v", xfiles.OmitColumns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computed columns (CQL) -> ComputedQL
|
||||||
|
if len(xfiles.CQLColumns) > 0 {
|
||||||
|
if options.ComputedQL == nil {
|
||||||
|
options.ComputedQL = make(map[string]string)
|
||||||
|
}
|
||||||
|
for i, cqlExpr := range xfiles.CQLColumns {
|
||||||
|
colName := fmt.Sprintf("cql%d", i+1)
|
||||||
|
options.ComputedQL[colName] = cqlExpr
|
||||||
|
logger.Debug("X-Files: Added computed column %s: %s", colName, cqlExpr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorting
|
||||||
|
if len(xfiles.Sort) > 0 {
|
||||||
|
for _, sortField := range xfiles.Sort {
|
||||||
|
direction := "ASC"
|
||||||
|
colName := sortField
|
||||||
|
|
||||||
|
// Handle direction prefixes
|
||||||
|
if strings.HasPrefix(sortField, "-") {
|
||||||
|
direction = "DESC"
|
||||||
|
colName = strings.TrimPrefix(sortField, "-")
|
||||||
|
} else if strings.HasPrefix(sortField, "+") {
|
||||||
|
colName = strings.TrimPrefix(sortField, "+")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle DESC suffix
|
||||||
|
if strings.HasSuffix(strings.ToLower(colName), " desc") {
|
||||||
|
direction = "DESC"
|
||||||
|
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
|
||||||
|
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
|
||||||
|
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
|
||||||
|
}
|
||||||
|
|
||||||
|
options.Sort = append(options.Sort, common.SortOption{
|
||||||
|
Column: strings.TrimSpace(colName),
|
||||||
|
Direction: direction,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
logger.Debug("X-Files: Added %d sort options", len(xfiles.Sort))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter fields
|
||||||
|
if len(xfiles.FilterFields) > 0 {
|
||||||
|
for _, filterField := range xfiles.FilterFields {
|
||||||
|
options.Filters = append(options.Filters, common.FilterOption{
|
||||||
|
Column: filterField.Field,
|
||||||
|
Operator: filterField.Operator,
|
||||||
|
Value: filterField.Value,
|
||||||
|
LogicOperator: "AND", // Default to AND
|
||||||
|
})
|
||||||
|
}
|
||||||
|
logger.Debug("X-Files: Added %d filter fields", len(xfiles.FilterFields))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQL AND conditions -> CustomSQLWhere
|
||||||
|
if len(xfiles.SqlAnd) > 0 {
|
||||||
|
if options.CustomSQLWhere != "" {
|
||||||
|
options.CustomSQLWhere += " AND "
|
||||||
|
}
|
||||||
|
options.CustomSQLWhere += "(" + strings.Join(xfiles.SqlAnd, " AND ") + ")"
|
||||||
|
logger.Debug("X-Files: Added SQL AND conditions")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQL OR conditions -> CustomSQLOr
|
||||||
|
if len(xfiles.SqlOr) > 0 {
|
||||||
|
if options.CustomSQLOr != "" {
|
||||||
|
options.CustomSQLOr += " OR "
|
||||||
|
}
|
||||||
|
options.CustomSQLOr += "(" + strings.Join(xfiles.SqlOr, " OR ") + ")"
|
||||||
|
logger.Debug("X-Files: Added SQL OR conditions")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pagination - Limit
|
||||||
|
if limitStr := xfiles.Limit.String(); limitStr != "" && limitStr != "0" {
|
||||||
|
if limitVal, err := xfiles.Limit.Int64(); err == nil && limitVal > 0 {
|
||||||
|
limit := int(limitVal)
|
||||||
|
options.Limit = &limit
|
||||||
|
logger.Debug("X-Files: Set limit: %d", limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pagination - Offset
|
||||||
|
if offsetStr := xfiles.Offset.String(); offsetStr != "" && offsetStr != "0" {
|
||||||
|
if offsetVal, err := xfiles.Offset.Int64(); err == nil && offsetVal > 0 {
|
||||||
|
offset := int(offsetVal)
|
||||||
|
options.Offset = &offset
|
||||||
|
logger.Debug("X-Files: Set offset: %d", offset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cursor pagination
|
||||||
|
if xfiles.CursorForward != "" {
|
||||||
|
options.CursorForward = xfiles.CursorForward
|
||||||
|
logger.Debug("X-Files: Set cursor forward")
|
||||||
|
}
|
||||||
|
if xfiles.CursorBackward != "" {
|
||||||
|
options.CursorBackward = xfiles.CursorBackward
|
||||||
|
logger.Debug("X-Files: Set cursor backward")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flags
|
||||||
|
if xfiles.Skipcount {
|
||||||
|
options.SkipCount = true
|
||||||
|
logger.Debug("X-Files: Set skip count")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process ParentTables and ChildTables recursively
|
||||||
|
h.processXFilesRelations(&xfiles, options, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// processXFilesRelations processes ParentTables and ChildTables from XFiles
|
||||||
|
// and adds them as Preload options recursively
|
||||||
|
func (h *Handler) processXFilesRelations(xfiles *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||||
|
if xfiles == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process ParentTables
|
||||||
|
if len(xfiles.ParentTables) > 0 {
|
||||||
|
logger.Debug("X-Files: Processing %d parent tables", len(xfiles.ParentTables))
|
||||||
|
for _, parentTable := range xfiles.ParentTables {
|
||||||
|
h.addXFilesPreload(parentTable, options, basePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process ChildTables
|
||||||
|
if len(xfiles.ChildTables) > 0 {
|
||||||
|
logger.Debug("X-Files: Processing %d child tables", len(xfiles.ChildTables))
|
||||||
|
for _, childTable := range xfiles.ChildTables {
|
||||||
|
h.addXFilesPreload(childTable, options, basePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRelationNamesInOptions resolves all table names to field names in preload options
|
||||||
|
// This is called internally by parseOptionsFromHeaders when a model is provided
|
||||||
|
func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, model interface{}) {
|
||||||
|
if options == nil || model == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve relation names in all preload options
|
||||||
|
for i := range options.Preload {
|
||||||
|
preload := &options.Preload[i]
|
||||||
|
|
||||||
|
// Split the relation path (e.g., "parent.child.grandchild")
|
||||||
|
parts := strings.Split(preload.Relation, ".")
|
||||||
|
resolvedParts := make([]string, 0, len(parts))
|
||||||
|
|
||||||
|
// Resolve each part of the path
|
||||||
|
currentModel := model
|
||||||
|
for _, part := range parts {
|
||||||
|
resolvedPart := h.resolveRelationName(currentModel, part)
|
||||||
|
resolvedParts = append(resolvedParts, resolvedPart)
|
||||||
|
|
||||||
|
// Try to get the model type for the next level
|
||||||
|
// This allows nested resolution
|
||||||
|
if nextModel := reflection.GetRelationModel(currentModel, resolvedPart); nextModel != nil {
|
||||||
|
currentModel = nextModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the relation path with resolved names
|
||||||
|
resolvedPath := strings.Join(resolvedParts, ".")
|
||||||
|
if resolvedPath != preload.Relation {
|
||||||
|
logger.Debug("Resolved relation path '%s' -> '%s'", preload.Relation, resolvedPath)
|
||||||
|
preload.Relation = resolvedPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve relation names in expand options
|
||||||
|
for i := range options.Expand {
|
||||||
|
expand := &options.Expand[i]
|
||||||
|
resolved := h.resolveRelationName(model, expand.Relation)
|
||||||
|
if resolved != expand.Relation {
|
||||||
|
logger.Debug("Resolved expand relation '%s' -> '%s'", expand.Relation, resolved)
|
||||||
|
expand.Relation = resolved
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRelationName resolves a relation name or table name to the actual field name in the model
|
||||||
|
// If the input is already a field name, it returns it as-is
|
||||||
|
// If the input is a table name, it looks up the corresponding relation field
|
||||||
|
func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) string {
|
||||||
|
if model == nil || nameOrTable == "" {
|
||||||
|
return nameOrTable
|
||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
// Dereference pointer if needed
|
// Dereference pointer if needed
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Ptr {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check again after dereferencing
|
||||||
|
if modelType == nil {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure it's a struct
|
// Ensure it's a struct
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
return reflect.Invalid
|
return nameOrTable
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the field by JSON tag or field name
|
// First, check if the input matches a field name directly
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
field := modelType.Field(i)
|
field := modelType.Field(i)
|
||||||
|
if field.Name == nameOrTable {
|
||||||
|
// It's already a field name
|
||||||
|
logger.Debug("Input '%s' is a field name", nameOrTable)
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check JSON tag
|
// If not found as a field name, try to look it up as a table name
|
||||||
jsonTag := field.Tag.Get("json")
|
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
|
||||||
if jsonTag != "" {
|
|
||||||
// Parse JSON tag (format: "name,omitempty")
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
parts := strings.Split(jsonTag, ",")
|
field := modelType.Field(i)
|
||||||
if parts[0] == colName {
|
fieldType := field.Type
|
||||||
return field.Type.Kind()
|
|
||||||
|
// Check if it's a slice or pointer to a struct
|
||||||
|
var targetType reflect.Type
|
||||||
|
if fieldType.Kind() == reflect.Slice {
|
||||||
|
targetType = fieldType.Elem()
|
||||||
|
} else if fieldType.Kind() == reflect.Ptr {
|
||||||
|
targetType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType != nil {
|
||||||
|
// Dereference pointer if the slice contains pointers
|
||||||
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a struct type
|
||||||
|
if targetType.Kind() == reflect.Struct {
|
||||||
|
// Get the type name and normalize it
|
||||||
|
typeName := targetType.Name()
|
||||||
|
|
||||||
|
// Extract the table name from type name
|
||||||
|
// Patterns: ModelCoreMastertaskitem -> mastertaskitem
|
||||||
|
// ModelMastertaskitem -> mastertaskitem
|
||||||
|
normalizedTypeName := strings.ToLower(typeName)
|
||||||
|
|
||||||
|
// Remove common prefixes like "model", "modelcore", etc.
|
||||||
|
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
|
||||||
|
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
|
||||||
|
|
||||||
|
// Compare normalized names
|
||||||
|
if normalizedTypeName == normalizedInput {
|
||||||
|
logger.Debug("Resolved table name '%s' to field '%s' (type: %s)", nameOrTable, field.Name, typeName)
|
||||||
|
return field.Name
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check field name (case-insensitive)
|
// If no match found, return the original input
|
||||||
if strings.EqualFold(field.Name, colName) {
|
logger.Debug("No field found for '%s', using as-is", nameOrTable)
|
||||||
return field.Type.Kind()
|
return nameOrTable
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check snake_case conversion
|
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
||||||
snakeCaseName := toSnakeCase(field.Name)
|
// and recursively processes its children
|
||||||
if snakeCaseName == colName {
|
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||||
return field.Type.Kind()
|
if xfile == nil || xfile.TableName == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the table name as-is for now - it will be resolved to field name later
|
||||||
|
// when we have the model instance available
|
||||||
|
relationPath := xfile.TableName
|
||||||
|
if basePath != "" {
|
||||||
|
relationPath = basePath + "." + xfile.TableName
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
|
||||||
|
|
||||||
|
// Create PreloadOption from XFiles configuration
|
||||||
|
preloadOpt := common.PreloadOption{
|
||||||
|
Relation: relationPath,
|
||||||
|
Columns: xfile.Columns,
|
||||||
|
OmitColumns: xfile.OmitColumns,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add sorting if specified
|
||||||
|
if len(xfile.Sort) > 0 {
|
||||||
|
preloadOpt.Sort = make([]common.SortOption, 0, len(xfile.Sort))
|
||||||
|
for _, sortField := range xfile.Sort {
|
||||||
|
direction := "ASC"
|
||||||
|
colName := sortField
|
||||||
|
|
||||||
|
// Handle direction prefixes
|
||||||
|
if strings.HasPrefix(sortField, "-") {
|
||||||
|
direction = "DESC"
|
||||||
|
colName = strings.TrimPrefix(sortField, "-")
|
||||||
|
} else if strings.HasPrefix(sortField, "+") {
|
||||||
|
colName = strings.TrimPrefix(sortField, "+")
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadOpt.Sort = append(preloadOpt.Sort, common.SortOption{
|
||||||
|
Column: strings.TrimSpace(colName),
|
||||||
|
Direction: direction,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return reflect.Invalid
|
// Add filters if specified
|
||||||
}
|
if len(xfile.FilterFields) > 0 {
|
||||||
|
preloadOpt.Filters = make([]common.FilterOption, 0, len(xfile.FilterFields))
|
||||||
// toSnakeCase converts a string from CamelCase to snake_case
|
for _, filterField := range xfile.FilterFields {
|
||||||
func toSnakeCase(s string) string {
|
preloadOpt.Filters = append(preloadOpt.Filters, common.FilterOption{
|
||||||
var result strings.Builder
|
Column: filterField.Field,
|
||||||
for i, r := range s {
|
Operator: filterField.Operator,
|
||||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
Value: filterField.Value,
|
||||||
result.WriteRune('_')
|
LogicOperator: "AND",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
result.WriteRune(r)
|
|
||||||
}
|
|
||||||
return strings.ToLower(result.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// isNumericType checks if a reflect.Kind is a numeric type
|
|
||||||
func isNumericType(kind reflect.Kind) bool {
|
|
||||||
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
|
||||||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
|
||||||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
|
||||||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
|
||||||
}
|
|
||||||
|
|
||||||
// isStringType checks if a reflect.Kind is a string type
|
|
||||||
func isStringType(kind reflect.Kind) bool {
|
|
||||||
return kind == reflect.String
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToNumericType converts a string value to the appropriate numeric type
|
|
||||||
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
|
||||||
value = strings.TrimSpace(value)
|
|
||||||
|
|
||||||
switch kind {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
||||||
// Parse as integer
|
|
||||||
bitSize := 64
|
|
||||||
switch kind {
|
|
||||||
case reflect.Int8:
|
|
||||||
bitSize = 8
|
|
||||||
case reflect.Int16:
|
|
||||||
bitSize = 16
|
|
||||||
case reflect.Int32:
|
|
||||||
bitSize = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid integer value: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the appropriate type
|
|
||||||
switch kind {
|
|
||||||
case reflect.Int:
|
|
||||||
return int(intVal), nil
|
|
||||||
case reflect.Int8:
|
|
||||||
return int8(intVal), nil
|
|
||||||
case reflect.Int16:
|
|
||||||
return int16(intVal), nil
|
|
||||||
case reflect.Int32:
|
|
||||||
return int32(intVal), nil
|
|
||||||
case reflect.Int64:
|
|
||||||
return intVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
// Parse as unsigned integer
|
|
||||||
bitSize := 64
|
|
||||||
switch kind {
|
|
||||||
case reflect.Uint8:
|
|
||||||
bitSize = 8
|
|
||||||
case reflect.Uint16:
|
|
||||||
bitSize = 16
|
|
||||||
case reflect.Uint32:
|
|
||||||
bitSize = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the appropriate type
|
|
||||||
switch kind {
|
|
||||||
case reflect.Uint:
|
|
||||||
return uint(uintVal), nil
|
|
||||||
case reflect.Uint8:
|
|
||||||
return uint8(uintVal), nil
|
|
||||||
case reflect.Uint16:
|
|
||||||
return uint16(uintVal), nil
|
|
||||||
case reflect.Uint32:
|
|
||||||
return uint32(uintVal), nil
|
|
||||||
case reflect.Uint64:
|
|
||||||
return uintVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
// Parse as float
|
|
||||||
bitSize := 64
|
|
||||||
if kind == reflect.Float32 {
|
|
||||||
bitSize = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
floatVal, err := strconv.ParseFloat(value, bitSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid float value: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if kind == reflect.Float32 {
|
|
||||||
return float32(floatVal), nil
|
|
||||||
}
|
|
||||||
return floatVal, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
// Add WHERE clause if SQL conditions specified
|
||||||
}
|
whereConditions := make([]string, 0)
|
||||||
|
if len(xfile.SqlAnd) > 0 {
|
||||||
|
whereConditions = append(whereConditions, xfile.SqlAnd...)
|
||||||
|
}
|
||||||
|
if len(whereConditions) > 0 {
|
||||||
|
preloadOpt.Where = strings.Join(whereConditions, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
// isNumericValue checks if a string value can be parsed as a number
|
// Add limit if specified
|
||||||
func isNumericValue(value string) bool {
|
if limitStr := xfile.Limit.String(); limitStr != "" && limitStr != "0" {
|
||||||
value = strings.TrimSpace(value)
|
if limitVal, err := xfile.Limit.Int64(); err == nil && limitVal > 0 {
|
||||||
_, err := strconv.ParseFloat(value, 64)
|
limit := int(limitVal)
|
||||||
return err == nil
|
preloadOpt.Limit = &limit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add computed columns (CQL) -> ComputedQL
|
||||||
|
if len(xfile.CQLColumns) > 0 {
|
||||||
|
preloadOpt.ComputedQL = make(map[string]string)
|
||||||
|
for i, cqlExpr := range xfile.CQLColumns {
|
||||||
|
colName := fmt.Sprintf("cql%d", i+1)
|
||||||
|
preloadOpt.ComputedQL[colName] = cqlExpr
|
||||||
|
logger.Debug("X-Files: Added computed column %s to preload %s: %s", colName, relationPath, cqlExpr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set recursive flag
|
||||||
|
preloadOpt.Recursive = xfile.Recursive
|
||||||
|
|
||||||
|
// Add the preload option
|
||||||
|
options.Preload = append(options.Preload, preloadOpt)
|
||||||
|
|
||||||
|
// Recursively process nested ParentTables and ChildTables
|
||||||
|
if xfile.Recursive {
|
||||||
|
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
|
||||||
|
h.processXFilesRelations(xfile, options, relationPath)
|
||||||
|
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
|
||||||
|
h.processXFilesRelations(xfile, options, relationPath)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ColumnCastInfo holds information about whether a column needs casting
|
// ColumnCastInfo holds information about whether a column needs casting
|
||||||
@@ -659,7 +944,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
|||||||
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||||
}
|
}
|
||||||
|
|
||||||
colType := h.getColumnTypeFromModel(model, filter.Column)
|
colType := reflection.GetColumnTypeFromModel(model, filter.Column)
|
||||||
if colType == reflect.Invalid {
|
if colType == reflect.Invalid {
|
||||||
// Column not found in model, no casting needed
|
// Column not found in model, no casting needed
|
||||||
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
|
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
|
||||||
@@ -670,18 +955,18 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
|||||||
valueIsNumeric := false
|
valueIsNumeric := false
|
||||||
if strVal, ok := filter.Value.(string); ok {
|
if strVal, ok := filter.Value.(string); ok {
|
||||||
strVal = strings.Trim(strVal, "%")
|
strVal = strings.Trim(strVal, "%")
|
||||||
valueIsNumeric = isNumericValue(strVal)
|
valueIsNumeric = reflection.IsNumericValue(strVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adjust based on column type
|
// Adjust based on column type
|
||||||
switch {
|
switch {
|
||||||
case isNumericType(colType):
|
case reflection.IsNumericType(colType):
|
||||||
// Column is numeric
|
// Column is numeric
|
||||||
if valueIsNumeric {
|
if valueIsNumeric {
|
||||||
// Value is numeric - try to convert it
|
// Value is numeric - try to convert it
|
||||||
if strVal, ok := filter.Value.(string); ok {
|
if strVal, ok := filter.Value.(string); ok {
|
||||||
strVal = strings.Trim(strVal, "%")
|
strVal = strings.Trim(strVal, "%")
|
||||||
numericVal, err := convertToNumericType(strVal, colType)
|
numericVal, err := reflection.ConvertToNumericType(strVal, colType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
|
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
|
||||||
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||||
@@ -696,7 +981,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
|||||||
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
case isStringType(colType):
|
case reflection.IsStringType(colType):
|
||||||
// String columns don't need casting
|
// String columns don't need casting
|
||||||
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||||
|
|
||||||
|
|||||||
403
pkg/restheadspec/query_params_test.go
Normal file
403
pkg/restheadspec/query_params_test.go
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockRequest implements common.Request interface for testing
|
||||||
|
type MockRequest struct {
|
||||||
|
headers map[string]string
|
||||||
|
queryParams map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) Method() string {
|
||||||
|
return "GET"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) URL() string {
|
||||||
|
return "http://example.com/test"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) Header(key string) string {
|
||||||
|
return m.headers[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) AllHeaders() map[string]string {
|
||||||
|
return m.headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) Body() ([]byte, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) PathParam(key string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) QueryParam(key string) string {
|
||||||
|
return m.queryParams[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) AllQueryParams() map[string]string {
|
||||||
|
return m.queryParams
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionsFromQueryParams(t *testing.T) {
|
||||||
|
handler := NewHandler(nil, nil)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
headers map[string]string
|
||||||
|
validate func(t *testing.T, options ExtendedRequestOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Parse custom SQL WHERE from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.CustomSQLWhere == "" {
|
||||||
|
t.Error("Expected CustomSQLWhere to be set from query param")
|
||||||
|
}
|
||||||
|
expected := `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`
|
||||||
|
if options.CustomSQLWhere != expected {
|
||||||
|
t.Errorf("Expected CustomSQLWhere=%q, got %q", expected, options.CustomSQLWhere)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse sort from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-sort": "-applicationdate,name",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Sort) != 2 {
|
||||||
|
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
|
||||||
|
t.Errorf("Expected first sort: applicationdate DESC, got %s %s", options.Sort[0].Column, options.Sort[0].Direction)
|
||||||
|
}
|
||||||
|
if options.Sort[1].Column != "name" || options.Sort[1].Direction != "ASC" {
|
||||||
|
t.Errorf("Expected second sort: name ASC, got %s %s", options.Sort[1].Column, options.Sort[1].Direction)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse limit and offset from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-limit": "100",
|
||||||
|
"x-offset": "50",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.Limit == nil || *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected limit=100, got %v", options.Limit)
|
||||||
|
}
|
||||||
|
if options.Offset == nil || *options.Offset != 50 {
|
||||||
|
t.Errorf("Expected offset=50, got %v", options.Offset)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse field filters from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-fieldfilter-status": "active",
|
||||||
|
"x-fieldfilter-type": "user",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Filters) != 2 {
|
||||||
|
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check that filters were created
|
||||||
|
foundStatus := false
|
||||||
|
foundType := false
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if filter.Column == "status" && filter.Value == "active" && filter.Operator == "eq" {
|
||||||
|
foundStatus = true
|
||||||
|
}
|
||||||
|
if filter.Column == "type" && filter.Value == "user" && filter.Operator == "eq" {
|
||||||
|
foundType = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundStatus {
|
||||||
|
t.Error("Expected status filter not found")
|
||||||
|
}
|
||||||
|
if !foundType {
|
||||||
|
t.Error("Expected type filter not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse select fields from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-select-fields": "id,name,email",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Columns) != 3 {
|
||||||
|
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected := []string{"id", "name", "email"}
|
||||||
|
for i, col := range expected {
|
||||||
|
if i >= len(options.Columns) || options.Columns[i] != col {
|
||||||
|
t.Errorf("Expected column[%d]=%s, got %v", i, col, options.Columns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse preload from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-preload": "posts:title,content|comments",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Preload) != 2 {
|
||||||
|
t.Errorf("Expected 2 preload options, got %d", len(options.Preload))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check first preload (posts with columns)
|
||||||
|
if options.Preload[0].Relation != "posts" {
|
||||||
|
t.Errorf("Expected first preload relation=posts, got %s", options.Preload[0].Relation)
|
||||||
|
}
|
||||||
|
if len(options.Preload[0].Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns for posts preload, got %d", len(options.Preload[0].Columns))
|
||||||
|
}
|
||||||
|
// Check second preload (comments without columns)
|
||||||
|
if options.Preload[1].Relation != "comments" {
|
||||||
|
t.Errorf("Expected second preload relation=comments, got %s", options.Preload[1].Relation)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query params take precedence over headers",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-limit": "100",
|
||||||
|
},
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Limit": "50",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.Limit == nil || *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected query param limit=100 to override header, got %v", options.Limit)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse search operators from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-searchop-contains-name": "john",
|
||||||
|
"x-searchop-gt-age": "18",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Filters) != 2 {
|
||||||
|
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check for ILIKE filter
|
||||||
|
foundContains := false
|
||||||
|
foundGt := false
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if filter.Column == "name" && filter.Operator == "ilike" {
|
||||||
|
foundContains = true
|
||||||
|
}
|
||||||
|
if filter.Column == "age" && filter.Operator == "gt" && filter.Value == "18" {
|
||||||
|
foundGt = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundContains {
|
||||||
|
t.Error("Expected contains filter not found")
|
||||||
|
}
|
||||||
|
if !foundGt {
|
||||||
|
t.Error("Expected gt filter not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse complex example with multiple params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0)`,
|
||||||
|
"x-sort": "-applicationdate",
|
||||||
|
"x-limit": "100",
|
||||||
|
"x-select-fields": "id,name,status",
|
||||||
|
"x-fieldfilter-active": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
// Validate CustomSQLWhere
|
||||||
|
if options.CustomSQLWhere == "" {
|
||||||
|
t.Error("Expected CustomSQLWhere to be set")
|
||||||
|
}
|
||||||
|
// Validate Sort
|
||||||
|
if len(options.Sort) != 1 || options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
|
||||||
|
t.Errorf("Expected sort by applicationdate DESC, got %v", options.Sort)
|
||||||
|
}
|
||||||
|
// Validate Limit
|
||||||
|
if options.Limit == nil || *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected limit=100, got %v", options.Limit)
|
||||||
|
}
|
||||||
|
// Validate Columns
|
||||||
|
if len(options.Columns) != 3 {
|
||||||
|
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
|
||||||
|
}
|
||||||
|
// Validate Filters
|
||||||
|
if len(options.Filters) < 1 {
|
||||||
|
t.Error("Expected at least 1 filter")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse distinct flag from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-distinct": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if !options.Distinct {
|
||||||
|
t.Error("Expected Distinct to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse skip count flag from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-skipcount": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if !options.SkipCount {
|
||||||
|
t.Error("Expected SkipCount to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse response format from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-syncfusion": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.ResponseFormat != "syncfusion" {
|
||||||
|
t.Errorf("Expected ResponseFormat=syncfusion, got %s", options.ResponseFormat)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse custom SQL OR from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-or": `("field1" = 'value1' OR "field2" = 'value2')`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.CustomSQLOr == "" {
|
||||||
|
t.Error("Expected CustomSQLOr to be set")
|
||||||
|
}
|
||||||
|
expected := `("field1" = 'value1' OR "field2" = 'value2')`
|
||||||
|
if options.CustomSQLOr != expected {
|
||||||
|
t.Errorf("Expected CustomSQLOr=%q, got %q", expected, options.CustomSQLOr)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create mock request
|
||||||
|
req := &MockRequest{
|
||||||
|
headers: tt.headers,
|
||||||
|
queryParams: tt.queryParams,
|
||||||
|
}
|
||||||
|
if req.headers == nil {
|
||||||
|
req.headers = make(map[string]string)
|
||||||
|
}
|
||||||
|
if req.queryParams == nil {
|
||||||
|
req.queryParams = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse options
|
||||||
|
options := handler.parseOptionsFromHeaders(req, nil)
|
||||||
|
|
||||||
|
// Validate
|
||||||
|
tt.validate(t, options)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryParamsWithURLEncoding(t *testing.T) {
|
||||||
|
handler := NewHandler(nil, nil)
|
||||||
|
|
||||||
|
// Test with URL-encoded query parameter (like the user's example)
|
||||||
|
req := &MockRequest{
|
||||||
|
headers: make(map[string]string),
|
||||||
|
queryParams: map[string]string{
|
||||||
|
// URL-encoded version of the SQL WHERE clause
|
||||||
|
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null) and ("v_webui_clients".inactive = 0 or "v_webui_clients".inactive is null)`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
options := handler.parseOptionsFromHeaders(req, nil)
|
||||||
|
|
||||||
|
if options.CustomSQLWhere == "" {
|
||||||
|
t.Error("Expected CustomSQLWhere to be set from URL-encoded query param")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The SQL should contain the expected conditions
|
||||||
|
if !contains(options.CustomSQLWhere, "clientstatus") {
|
||||||
|
t.Error("Expected CustomSQLWhere to contain 'clientstatus'")
|
||||||
|
}
|
||||||
|
if !contains(options.CustomSQLWhere, "inactive") {
|
||||||
|
t.Error("Expected CustomSQLWhere to contain 'inactive'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadersAndQueryParamsCombined(t *testing.T) {
|
||||||
|
handler := NewHandler(nil, nil)
|
||||||
|
|
||||||
|
// Test that headers and query params can work together
|
||||||
|
req := &MockRequest{
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Select-Fields": "id,name",
|
||||||
|
"X-Limit": "50",
|
||||||
|
},
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-sort": "-created_at",
|
||||||
|
"x-offset": "10",
|
||||||
|
// This should override the header value
|
||||||
|
"x-limit": "100",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
options := handler.parseOptionsFromHeaders(req, nil)
|
||||||
|
|
||||||
|
// Verify columns from header
|
||||||
|
if len(options.Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns from header, got %d", len(options.Columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify sort from query param
|
||||||
|
if len(options.Sort) != 1 || options.Sort[0].Column != "created_at" {
|
||||||
|
t.Errorf("Expected sort from query param, got %v", options.Sort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify offset from query param
|
||||||
|
if options.Offset == nil || *options.Offset != 10 {
|
||||||
|
t.Errorf("Expected offset=10 from query param, got %v", options.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify limit from query param (should override header)
|
||||||
|
if options.Limit == nil {
|
||||||
|
t.Error("Expected limit to be set from query param")
|
||||||
|
} else if *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected limit=100 from query param (overriding header), got %d", *options.Limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to check if a string contains a substring
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
type TestModel struct {
|
type TestModel struct {
|
||||||
ID int64 `json:"id" bun:"id,pk"`
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
Name string `json:"name" bun:"name"`
|
Name string `json:"name" bun:"name"`
|
||||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetRowNumbersOnRecords(t *testing.T) {
|
func TestSetRowNumbersOnRecords(t *testing.T) {
|
||||||
|
|||||||
431
pkg/restheadspec/xfiles.go
Normal file
431
pkg/restheadspec/xfiles.go
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type XFiles struct {
|
||||||
|
TableName string `json:"tablename"`
|
||||||
|
Schema string `json:"schema"`
|
||||||
|
PrimaryKey string `json:"primarykey"`
|
||||||
|
ForeignKey string `json:"foreignkey"`
|
||||||
|
RelatedKey string `json:"relatedkey"`
|
||||||
|
Sort []string `json:"sort"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
Editable bool `json:"editable"`
|
||||||
|
Recursive bool `json:"recursive"`
|
||||||
|
Expand bool `json:"expand"`
|
||||||
|
Rownumber bool `json:"rownumber"`
|
||||||
|
Skipcount bool `json:"skipcount"`
|
||||||
|
Offset json.Number `json:"offset"`
|
||||||
|
Limit json.Number `json:"limit"`
|
||||||
|
Columns []string `json:"columns"`
|
||||||
|
OmitColumns []string `json:"omit_columns"`
|
||||||
|
CQLColumns []string `json:"cql_columns"`
|
||||||
|
|
||||||
|
SqlJoins []string `json:"sql_joins"`
|
||||||
|
SqlOr []string `json:"sql_or"`
|
||||||
|
SqlAnd []string `json:"sql_and"`
|
||||||
|
ParentTables []*XFiles `json:"parenttables"`
|
||||||
|
ChildTables []*XFiles `json:"childtables"`
|
||||||
|
ModelType reflect.Type `json:"-"`
|
||||||
|
ParentEntity *XFiles `json:"-"`
|
||||||
|
Level uint `json:"-"`
|
||||||
|
Errors []error `json:"-"`
|
||||||
|
FilterFields []struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Operator string `json:"operator"`
|
||||||
|
} `json:"filter_fields"`
|
||||||
|
CursorForward string `json:"cursor_forward"`
|
||||||
|
CursorBackward string `json:"cursor_backward"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// func (m *XFiles) SetParent() {
|
||||||
|
// if m.ChildTables != nil {
|
||||||
|
// for _, child := range m.ChildTables {
|
||||||
|
// if child.ParentEntity != nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// child.ParentEntity = m
|
||||||
|
// child.Level = m.Level + 1000
|
||||||
|
// child.SetParent()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if m.ParentTables != nil {
|
||||||
|
// for _, pt := range m.ParentTables {
|
||||||
|
// if pt.ParentEntity != nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// pt.ParentEntity = m
|
||||||
|
// pt.Level = m.Level + 1
|
||||||
|
// pt.SetParent()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (m *XFiles) GetParentRelations() []reflection.GormRelationType {
|
||||||
|
// if m.ParentEntity == nil {
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// foundRelations := make(GormRelationTypeList, 0)
|
||||||
|
// rels := reflection.GetValidModelRelationTypes(m.ParentEntity.ModelType, false)
|
||||||
|
|
||||||
|
// if m.ParentEntity.ModelType == nil {
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, rel := range rels {
|
||||||
|
// // if len(foundRelations) > 0 {
|
||||||
|
// // break
|
||||||
|
// // }
|
||||||
|
// if rel.FieldName != "" && rel.AssociationTable.Name() == m.ModelType.Name() {
|
||||||
|
|
||||||
|
// if rel.AssociationKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.AssociationKey, m.RelatedKey) {
|
||||||
|
// foundRelations = append(foundRelations, rel)
|
||||||
|
// } else if rel.AssociationKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.AssociationKey, m.ForeignKey) {
|
||||||
|
// foundRelations = append(foundRelations, rel)
|
||||||
|
// } else if rel.ForeignKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.ForeignKey, m.ForeignKey) {
|
||||||
|
// foundRelations = append(foundRelations, rel)
|
||||||
|
// } else if rel.ForeignKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.ForeignKey, m.RelatedKey) {
|
||||||
|
// foundRelations = append(foundRelations, rel)
|
||||||
|
// } else if rel.ForeignKey != "" && m.ForeignKey == "" && m.RelatedKey == "" {
|
||||||
|
// foundRelations = append(foundRelations, rel)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// //idName := fmt.Sprintf("%s_to_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// sort.Sort(foundRelations)
|
||||||
|
// finalList := make(GormRelationTypeList, 0)
|
||||||
|
// dups := make(map[string]bool)
|
||||||
|
// for _, rel := range foundRelations {
|
||||||
|
// idName := fmt.Sprintf("%s_to_%s_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.FieldName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
|
||||||
|
// if dups[idName] {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// finalList = append(finalList, rel)
|
||||||
|
// dups[idName] = true
|
||||||
|
// }
|
||||||
|
|
||||||
|
// //fmt.Printf("GetParentRelations %s: %+v %d=%d\n", m.TableName, dups, len(finalList), len(foundRelations))
|
||||||
|
|
||||||
|
// return finalList
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (m *XFiles) GetUpdatableTableNames() []string {
|
||||||
|
// foundTables := make([]string, 0)
|
||||||
|
// if m.Editable {
|
||||||
|
// foundTables = append(foundTables, m.TableName)
|
||||||
|
// }
|
||||||
|
// if m.ParentTables != nil {
|
||||||
|
// for _, pt := range m.ParentTables {
|
||||||
|
// list := pt.GetUpdatableTableNames()
|
||||||
|
// if list != nil {
|
||||||
|
// foundTables = append(foundTables, list...)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if m.ChildTables != nil {
|
||||||
|
// for _, ct := range m.ChildTables {
|
||||||
|
// list := ct.GetUpdatableTableNames()
|
||||||
|
// if list != nil {
|
||||||
|
// foundTables = append(foundTables, list...)
|
||||||
|
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return foundTables
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (m *XFiles) preload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
|
||||||
|
|
||||||
|
// path := pPath
|
||||||
|
// _, colval := JSONSyntaxToSQLIn(path, m.ModelType, "preload")
|
||||||
|
// if colval != "" {
|
||||||
|
// path = colval
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if path == "" {
|
||||||
|
// return db, fmt.Errorf("invalid preload path %s", path)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// sortList := ""
|
||||||
|
// if m.Sort != nil {
|
||||||
|
// for _, sort := range m.Sort {
|
||||||
|
// descSort := false
|
||||||
|
// if strings.HasPrefix(sort, "-") || strings.Contains(strings.ToLower(sort), " desc") {
|
||||||
|
// descSort = true
|
||||||
|
// }
|
||||||
|
// sort = strings.TrimPrefix(strings.TrimPrefix(sort, "+"), "-")
|
||||||
|
// sort = strings.ReplaceAll(strings.ReplaceAll(sort, " desc", ""), " asc", "")
|
||||||
|
// if descSort {
|
||||||
|
// sort = sort + " desc"
|
||||||
|
// }
|
||||||
|
// sortList = sort
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// SrcColumns := reflection.GetModelSQLColumns(m.ModelType)
|
||||||
|
// Columns := make([]string, 0)
|
||||||
|
|
||||||
|
// for _, s := range SrcColumns {
|
||||||
|
// for _, v := range m.Columns {
|
||||||
|
// if strings.EqualFold(v, s) {
|
||||||
|
// Columns = append(Columns, v)
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if len(Columns) == 0 {
|
||||||
|
// Columns = SrcColumns
|
||||||
|
// }
|
||||||
|
|
||||||
|
// chain := db
|
||||||
|
|
||||||
|
// // //Do expand where we can
|
||||||
|
// // if m.Expand {
|
||||||
|
// // ops := func(subchain *gorm.DB) *gorm.DB {
|
||||||
|
// // subchain = subchain.Select(strings.Join(m.Columns, ","))
|
||||||
|
|
||||||
|
// // if m.Filter != "" {
|
||||||
|
// // subchain = subchain.Where(m.Filter)
|
||||||
|
// // }
|
||||||
|
// // return subchain
|
||||||
|
// // }
|
||||||
|
// // chain = chain.Joins(path, ops(chain))
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// //fmt.Printf("Preloading %s: %s lvl:%d \n", m.TableName, path, m.Level)
|
||||||
|
// //Do preload
|
||||||
|
// chain = chain.Preload(path, func(db *gorm.DB) *gorm.DB {
|
||||||
|
// subchain := db
|
||||||
|
|
||||||
|
// if sortList != "" {
|
||||||
|
// subchain = subchain.Order(sortList)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, sql := range m.SqlAnd {
|
||||||
|
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
|
||||||
|
// if fnType == 0 {
|
||||||
|
// colval = ValidSQL(colval, "select")
|
||||||
|
// }
|
||||||
|
// subchain = subchain.Where(colval)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, sql := range m.SqlOr {
|
||||||
|
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
|
||||||
|
// if fnType == 0 {
|
||||||
|
// colval = ValidSQL(colval, "select")
|
||||||
|
// }
|
||||||
|
// subchain = subchain.Or(colval)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// limitval, err := m.Limit.Int64()
|
||||||
|
// if err == nil && limitval > 0 {
|
||||||
|
// subchain = subchain.Limit(int(limitval))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, j := range m.SqlJoins {
|
||||||
|
// subchain = subchain.Joins(ValidSQL(j, "select"))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// offsetval, err := m.Offset.Int64()
|
||||||
|
// if err == nil && offsetval > 0 {
|
||||||
|
// subchain = subchain.Offset(int(offsetval))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// cols := make([]string, 0)
|
||||||
|
|
||||||
|
// for _, col := range Columns {
|
||||||
|
// canAdd := true
|
||||||
|
// for _, omit := range m.OmitColumns {
|
||||||
|
// if col == omit {
|
||||||
|
// canAdd = false
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if canAdd {
|
||||||
|
// cols = append(cols, col)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for i, col := range m.CQLColumns {
|
||||||
|
// cols = append(cols, fmt.Sprintf("(%s) as cql%d", col, i+1))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if len(cols) > 0 {
|
||||||
|
|
||||||
|
// colStr := strings.Join(cols, ",")
|
||||||
|
// subchain = subchain.Select(colStr)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if m.Recursive && pCnt < 5 {
|
||||||
|
// paths := strings.Split(path, ".")
|
||||||
|
|
||||||
|
// p := paths[0]
|
||||||
|
// if len(paths) > 1 {
|
||||||
|
// p = strings.Join(paths[1:], ".")
|
||||||
|
// }
|
||||||
|
// for i := uint(0); i < 3; i++ {
|
||||||
|
// inlineStr := strings.Repeat(p+".", int(i+1))
|
||||||
|
// inlineStr = strings.TrimRight(inlineStr, ".")
|
||||||
|
|
||||||
|
// fmt.Printf("Preloading Recursive (%d) %s: %s lvl:%d \n", i, m.TableName, inlineStr, m.Level)
|
||||||
|
// subchain, err = m.preload(subchain, inlineStr, pCnt+i)
|
||||||
|
// if err != nil {
|
||||||
|
// cfg.LogError("Preload (%s,%d) error: %v", m.TableName, pCnt, err)
|
||||||
|
// } else {
|
||||||
|
|
||||||
|
// if m.ChildTables != nil {
|
||||||
|
// for _, child := range m.ChildTables {
|
||||||
|
// if child.ParentEntity == nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// subchain, _ = child.ChainPreload(subchain, inlineStr, pCnt+i)
|
||||||
|
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if m.ParentTables != nil {
|
||||||
|
// for _, pt := range m.ParentTables {
|
||||||
|
// if pt.ParentEntity == nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// subchain, _ = pt.ChainPreload(subchain, inlineStr, pCnt+i)
|
||||||
|
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return subchain
|
||||||
|
// })
|
||||||
|
|
||||||
|
// return chain, nil
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (m *XFiles) ChainPreload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
|
||||||
|
// var err error
|
||||||
|
// chain := db
|
||||||
|
|
||||||
|
// relations := m.GetParentRelations()
|
||||||
|
// if pCnt > 10000 {
|
||||||
|
// cfg.LogError("Preload Max size (%s,%s): %v", m.TableName, pPath, err)
|
||||||
|
// return chain, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// hasPreloadError := false
|
||||||
|
// for _, rel := range relations {
|
||||||
|
// path := rel.FieldName
|
||||||
|
// if pPath != "" {
|
||||||
|
// path = fmt.Sprintf("%s.%s", pPath, rel.FieldName)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// chain, err = m.preload(chain, path, pCnt)
|
||||||
|
// if err != nil {
|
||||||
|
// cfg.LogError("Preload Error (%s,%s): %v", m.TableName, path, err)
|
||||||
|
// hasPreloadError = true
|
||||||
|
// //return chain, err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// //fmt.Printf("Preloading Rel %v: %s @ %s lvl:%d \n", m.Recursive, path, m.TableName, m.Level)
|
||||||
|
// if !hasPreloadError && m.ChildTables != nil {
|
||||||
|
// for _, child := range m.ChildTables {
|
||||||
|
// if child.ParentEntity == nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// chain, err = child.ChainPreload(chain, path, pCnt)
|
||||||
|
// if err != nil {
|
||||||
|
// return chain, err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if !hasPreloadError && m.ParentTables != nil {
|
||||||
|
// for _, pt := range m.ParentTables {
|
||||||
|
// if pt.ParentEntity == nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// chain, err = pt.ChainPreload(chain, path, pCnt)
|
||||||
|
// if err != nil {
|
||||||
|
// return chain, err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if len(relations) == 0 {
|
||||||
|
// if m.ChildTables != nil {
|
||||||
|
// for _, child := range m.ChildTables {
|
||||||
|
// if child.ParentEntity == nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// chain, err = child.ChainPreload(chain, pPath, pCnt)
|
||||||
|
// if err != nil {
|
||||||
|
// return chain, err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if m.ParentTables != nil {
|
||||||
|
// for _, pt := range m.ParentTables {
|
||||||
|
// if pt.ParentEntity == nil {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// chain, err = pt.ChainPreload(chain, pPath, pCnt)
|
||||||
|
// if err != nil {
|
||||||
|
// return chain, err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return chain, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (m *XFiles) Fill() {
|
||||||
|
// m.ModelType = models.GetModelType(m.Schema, m.TableName)
|
||||||
|
|
||||||
|
// if m.ModelType == nil {
|
||||||
|
// m.Errors = append(m.Errors, fmt.Errorf("ModelType not found for %s", m.TableName))
|
||||||
|
// }
|
||||||
|
// if m.Prefix == "" {
|
||||||
|
// m.Prefix = reflection.GetTablePrefixFromType(m.ModelType)
|
||||||
|
// }
|
||||||
|
// if m.PrimaryKey == "" {
|
||||||
|
// m.PrimaryKey = reflection.GetPKNameFromType(m.ModelType)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if m.Schema == "" {
|
||||||
|
// m.Schema = reflection.GetSchemaNameFromType(m.ModelType)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, t := range m.ParentTables {
|
||||||
|
// t.Fill()
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, t := range m.ChildTables {
|
||||||
|
// t.Fill()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type GormRelationTypeList []reflection.GormRelationType
|
||||||
|
|
||||||
|
// func (s GormRelationTypeList) Len() int { return len(s) }
|
||||||
|
// func (s GormRelationTypeList) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
|
// func (s GormRelationTypeList) Less(i, j int) bool {
|
||||||
|
// if strings.HasPrefix(strings.ToLower(s[j].FieldName),
|
||||||
|
// strings.ToLower(fmt.Sprintf("%s_%s_%s", s[i].AssociationSchema, s[i].AssociationTable, s[i].AssociationKey))) {
|
||||||
|
// return true
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return s[i].FieldName < s[j].FieldName
|
||||||
|
// }
|
||||||
213
pkg/restheadspec/xfiles_example.md
Normal file
213
pkg/restheadspec/xfiles_example.md
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
# X-Files Header Usage
|
||||||
|
|
||||||
|
The `x-files` header allows you to configure complex query options using a single JSON object. The XFiles configuration is parsed and populates the `ExtendedRequestOptions` fields, which means it integrates seamlessly with the existing query building system.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
When an `x-files` header is received:
|
||||||
|
1. It's parsed into an `XFiles` struct
|
||||||
|
2. The `XFiles` fields populate the `ExtendedRequestOptions` (columns, filters, sort, preload, etc.)
|
||||||
|
3. The normal query building process applies these options to the SQL query
|
||||||
|
4. This allows x-files to work alongside individual headers if needed
|
||||||
|
|
||||||
|
## Basic Example
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users
|
||||||
|
X-Files: {"tablename":"users","columns":["id","name","email"],"limit":"10","offset":"0"}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users
|
||||||
|
X-Files: {
|
||||||
|
"tablename": "users",
|
||||||
|
"schema": "public",
|
||||||
|
"columns": ["id", "name", "email", "created_at"],
|
||||||
|
"omit_columns": [],
|
||||||
|
"sort": ["-created_at", "name"],
|
||||||
|
"limit": "50",
|
||||||
|
"offset": "0",
|
||||||
|
"filter_fields": [
|
||||||
|
{
|
||||||
|
"field": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"field": "age",
|
||||||
|
"operator": "gt",
|
||||||
|
"value": "18"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sql_and": ["deleted_at IS NULL"],
|
||||||
|
"sql_or": [],
|
||||||
|
"cql_columns": ["UPPER(name)"],
|
||||||
|
"skipcount": false,
|
||||||
|
"distinct": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Filter Operators
|
||||||
|
|
||||||
|
- `eq` - equals
|
||||||
|
- `neq` - not equals
|
||||||
|
- `gt` - greater than
|
||||||
|
- `gte` - greater than or equals
|
||||||
|
- `lt` - less than
|
||||||
|
- `lte` - less than or equals
|
||||||
|
- `like` - SQL LIKE
|
||||||
|
- `ilike` - case-insensitive LIKE
|
||||||
|
- `in` - IN clause
|
||||||
|
- `between` - between (exclusive)
|
||||||
|
- `between_inclusive` - between (inclusive)
|
||||||
|
- `is_null` - is NULL
|
||||||
|
- `is_not_null` - is NOT NULL
|
||||||
|
|
||||||
|
## Sorting
|
||||||
|
|
||||||
|
Sort fields can be prefixed with:
|
||||||
|
- `+` for ascending (default)
|
||||||
|
- `-` for descending
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- `"sort": ["name"]` - ascending by name
|
||||||
|
- `"sort": ["-created_at"]` - descending by created_at
|
||||||
|
- `"sort": ["-created_at", "name"]` - multiple sorts
|
||||||
|
|
||||||
|
## Computed Columns (CQL)
|
||||||
|
|
||||||
|
Use `cql_columns` to add computed SQL expressions:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"cql_columns": [
|
||||||
|
"UPPER(name)",
|
||||||
|
"CONCAT(first_name, ' ', last_name)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
These will be available as `cql1`, `cql2`, etc. in the response.
|
||||||
|
|
||||||
|
## Cursor Pagination
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"cursor_forward": "eyJpZCI6MTAwfQ==",
|
||||||
|
"cursor_backward": ""
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Base64 Encoding
|
||||||
|
|
||||||
|
For complex JSON, you can base64-encode the value and prefix it with `ZIP_` or `__`:
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /public/users
|
||||||
|
X-Files: ZIP_eyJ0YWJsZW5hbWUiOiJ1c2VycyIsImxpbWl0IjoiMTAifQ==
|
||||||
|
```
|
||||||
|
|
||||||
|
## XFiles Struct Reference
|
||||||
|
|
||||||
|
```go
|
||||||
|
type XFiles struct {
|
||||||
|
TableName string `json:"tablename"`
|
||||||
|
Schema string `json:"schema"`
|
||||||
|
PrimaryKey string `json:"primarykey"`
|
||||||
|
ForeignKey string `json:"foreignkey"`
|
||||||
|
RelatedKey string `json:"relatedkey"`
|
||||||
|
Sort []string `json:"sort"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
Editable bool `json:"editable"`
|
||||||
|
Recursive bool `json:"recursive"`
|
||||||
|
Expand bool `json:"expand"`
|
||||||
|
Rownumber bool `json:"rownumber"`
|
||||||
|
Skipcount bool `json:"skipcount"`
|
||||||
|
Offset json.Number `json:"offset"`
|
||||||
|
Limit json.Number `json:"limit"`
|
||||||
|
Columns []string `json:"columns"`
|
||||||
|
OmitColumns []string `json:"omit_columns"`
|
||||||
|
CQLColumns []string `json:"cql_columns"`
|
||||||
|
SqlJoins []string `json:"sql_joins"`
|
||||||
|
SqlOr []string `json:"sql_or"`
|
||||||
|
SqlAnd []string `json:"sql_and"`
|
||||||
|
FilterFields []struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Operator string `json:"operator"`
|
||||||
|
} `json:"filter_fields"`
|
||||||
|
CursorForward string `json:"cursor_forward"`
|
||||||
|
CursorBackward string `json:"cursor_backward"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Recursive Preloading with ParentTables and ChildTables
|
||||||
|
|
||||||
|
XFiles now supports recursive preloading of related entities:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tablename": "users",
|
||||||
|
"columns": ["id", "name"],
|
||||||
|
"limit": "10",
|
||||||
|
"parenttables": [
|
||||||
|
{
|
||||||
|
"tablename": "Company",
|
||||||
|
"columns": ["id", "name", "industry"],
|
||||||
|
"sort": ["-created_at"]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"childtables": [
|
||||||
|
{
|
||||||
|
"tablename": "Orders",
|
||||||
|
"columns": ["id", "total", "status"],
|
||||||
|
"limit": "5",
|
||||||
|
"sort": ["-order_date"],
|
||||||
|
"filter_fields": [
|
||||||
|
{"field": "status", "operator": "eq", "value": "completed"}
|
||||||
|
],
|
||||||
|
"childtables": [
|
||||||
|
{
|
||||||
|
"tablename": "OrderItems",
|
||||||
|
"columns": ["id", "product_name", "quantity"],
|
||||||
|
"recursive": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### How Recursive Preloading Works
|
||||||
|
|
||||||
|
- **ParentTables**: Preloads parent relationships (e.g., User -> Company)
|
||||||
|
- **ChildTables**: Preloads child relationships (e.g., User -> Orders -> OrderItems)
|
||||||
|
- **Recursive**: When `true`, continues preloading the same relation recursively
|
||||||
|
- Each nested table can have its own:
|
||||||
|
- Column selection (`columns`, `omit_columns`)
|
||||||
|
- Filtering (`filter_fields`, `sql_and`)
|
||||||
|
- Sorting (`sort`)
|
||||||
|
- Pagination (`limit`)
|
||||||
|
- Further nesting (`parenttables`, `childtables`)
|
||||||
|
|
||||||
|
### Relation Path Building
|
||||||
|
|
||||||
|
Relations are built as dot-separated paths:
|
||||||
|
- `Company` (direct parent)
|
||||||
|
- `Orders` (direct child)
|
||||||
|
- `Orders.OrderItems` (nested child)
|
||||||
|
- `Orders.OrderItems.Product` (deeply nested)
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- Individual headers (like `x-select-fields`, `x-sort`, etc.) can still be used alongside `x-files`
|
||||||
|
- X-Files populates `ExtendedRequestOptions` which is then processed by the normal query building logic
|
||||||
|
- ParentTables and ChildTables are converted to `PreloadOption` entries with full support for:
|
||||||
|
- Column selection
|
||||||
|
- Filtering
|
||||||
|
- Sorting
|
||||||
|
- Limit
|
||||||
|
- Recursive nesting
|
||||||
|
- The relation name in ParentTables/ChildTables should match the GORM/Bun relation field name on the model
|
||||||
@@ -372,7 +372,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
assert.True(t, result["success"].(bool), "Create department should succeed")
|
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||||
|
if success, ok := result["success"]; ok && success != nil {
|
||||||
|
assert.True(t, success.(bool), "Create department should succeed")
|
||||||
|
} else {
|
||||||
|
// Unwrapped format - verify we got the created data back
|
||||||
|
assert.NotEmpty(t, result, "Create department should return data")
|
||||||
|
assert.Equal(t, deptID, result["id"], "Created department should have correct ID")
|
||||||
|
}
|
||||||
logger.Info("Department created successfully: %s", deptID)
|
logger.Info("Department created successfully: %s", deptID)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -393,7 +400,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
assert.True(t, result["success"].(bool), "Create employee should succeed")
|
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||||
|
if success, ok := result["success"]; ok && success != nil {
|
||||||
|
assert.True(t, success.(bool), "Create employee should succeed")
|
||||||
|
} else {
|
||||||
|
// Unwrapped format - verify we got the created data back
|
||||||
|
assert.NotEmpty(t, result, "Create employee should return data")
|
||||||
|
assert.Equal(t, empID, result["id"], "Created employee should have correct ID")
|
||||||
|
}
|
||||||
logger.Info("Employee created successfully: %s", empID)
|
logger.Info("Employee created successfully: %s", empID)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -540,7 +554,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
assert.True(t, result["success"].(bool), "Update department should succeed")
|
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||||
|
if success, ok := result["success"]; ok && success != nil {
|
||||||
|
assert.True(t, success.(bool), "Update department should succeed")
|
||||||
|
} else {
|
||||||
|
// Unwrapped format - verify we got the updated data back
|
||||||
|
assert.NotEmpty(t, result, "Update department should return data")
|
||||||
|
}
|
||||||
logger.Info("Department updated successfully: %s", deptID)
|
logger.Info("Department updated successfully: %s", deptID)
|
||||||
|
|
||||||
// Verify update by reading the department again
|
// Verify update by reading the department again
|
||||||
@@ -558,7 +578,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
assert.True(t, result["success"].(bool), "Update employee should succeed")
|
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||||
|
if success, ok := result["success"]; ok && success != nil {
|
||||||
|
assert.True(t, success.(bool), "Update employee should succeed")
|
||||||
|
} else {
|
||||||
|
// Unwrapped format - verify we got the updated data back
|
||||||
|
assert.NotEmpty(t, result, "Update employee should return data")
|
||||||
|
}
|
||||||
logger.Info("Employee updated successfully: %s", empID)
|
logger.Info("Employee updated successfully: %s", empID)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -569,7 +595,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
assert.True(t, result["success"].(bool), "Delete employee should succeed")
|
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||||
|
if success, ok := result["success"]; ok && success != nil {
|
||||||
|
assert.True(t, success.(bool), "Delete employee should succeed")
|
||||||
|
} else {
|
||||||
|
// Unwrapped format - verify we got a response (typically {"deleted": count})
|
||||||
|
assert.NotEmpty(t, result, "Delete employee should return data")
|
||||||
|
}
|
||||||
logger.Info("Employee deleted successfully: %s", empID)
|
logger.Info("Employee deleted successfully: %s", empID)
|
||||||
|
|
||||||
// Verify deletion - just log that delete succeeded
|
// Verify deletion - just log that delete succeeded
|
||||||
@@ -582,7 +614,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
|||||||
|
|
||||||
var result map[string]interface{}
|
var result map[string]interface{}
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
assert.True(t, result["success"].(bool), "Delete department should succeed")
|
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||||
|
if success, ok := result["success"]; ok && success != nil {
|
||||||
|
assert.True(t, success.(bool), "Delete department should succeed")
|
||||||
|
} else {
|
||||||
|
// Unwrapped format - verify we got a response (typically {"deleted": count})
|
||||||
|
assert.NotEmpty(t, result, "Delete department should return data")
|
||||||
|
}
|
||||||
logger.Info("Department deleted successfully: %s", deptID)
|
logger.Info("Department deleted successfully: %s", deptID)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user