mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 07:44:25 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0b3ae662b | ||
|
|
c9b9f75b06 | ||
|
|
af3260864d | ||
|
|
ca6d2deff6 | ||
|
|
1481443516 | ||
|
|
cb54ec5e27 | ||
|
|
7d6a9025f5 | ||
|
|
35089f511f | ||
|
|
66b6a0d835 | ||
|
|
456c165814 |
19
go.mod
19
go.mod
@@ -8,7 +8,12 @@ require (
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/stretchr/testify v1.8.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/uptrace/bun v1.2.15
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15
|
||||
github.com/uptrace/bunrouter v1.0.23
|
||||
go.uber.org/zap v1.27.0
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
@@ -21,23 +26,23 @@ require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/uptrace/bunrouter v1.0.23 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
modernc.org/mathutil v1.5.0 // indirect
|
||||
modernc.org/memory v1.5.0 // indirect
|
||||
modernc.org/sqlite v1.23.1 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.0 // indirect
|
||||
)
|
||||
|
||||
55
go.sum
55
go.sum
@@ -7,8 +7,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
@@ -21,13 +21,16 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -50,6 +53,10 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
|
||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
@@ -62,11 +69,19 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@@ -75,11 +90,29 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
|
||||
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
|
||||
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
||||
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
||||
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
||||
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
||||
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
||||
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
|
||||
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
||||
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
||||
@@ -307,12 +307,14 @@ func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
|
||||
// BunInsertQuery implements InsertQuery for Bun
|
||||
type BunInsertQuery struct {
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
hasModel bool
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -342,10 +344,16 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
if b.values != nil {
|
||||
// For Bun, we need to handle this differently
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Set("? = ?", bun.Ident(k), v)
|
||||
if b.values != nil && len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
// If no model was set, use the values map as the model
|
||||
// Bun can insert map[string]interface{} directly
|
||||
b.query = b.query.Model(&b.values)
|
||||
} else {
|
||||
// If model was set, use Value() to add individual values
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Value(k, "?", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := b.query.Exec(ctx)
|
||||
@@ -388,12 +396,17 @@ func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuer
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||
pkName := reflection.GetPrimaryKeyName(b.model)
|
||||
for column, value := range values {
|
||||
// Validate column is writable if model is set
|
||||
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||
// Skip scan-only columns
|
||||
continue
|
||||
}
|
||||
if pkName != "" && column == pkName {
|
||||
// Skip primary key updates
|
||||
continue
|
||||
}
|
||||
b.query = b.query.Set(column+" = ?", value)
|
||||
}
|
||||
return b
|
||||
|
||||
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
"github.com/uptrace/bun/driver/sqliteshim"
|
||||
)
|
||||
|
||||
// TestInsertModel is a test model for insert operations
|
||||
type TestInsertModel struct {
|
||||
bun.BaseModel `bun:"table:test_inserts"`
|
||||
ID int64 `bun:"id,pk,autoincrement"`
|
||||
Name string `bun:"name,notnull"`
|
||||
Email string `bun:"email"`
|
||||
Age int `bun:"age"`
|
||||
}
|
||||
|
||||
func setupBunTestDB(t *testing.T) *bun.DB {
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||
require.NoError(t, err, "Failed to open SQLite database")
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
|
||||
// Create test table
|
||||
_, err = db.NewCreateTable().
|
||||
Model((*TestInsertModel)(nil)).
|
||||
IfNotExists().
|
||||
Exec(context.Background())
|
||||
require.NoError(t, err, "Failed to create test table")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Model(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with Model()
|
||||
model := &TestInsertModel{
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
Age: 30,
|
||||
}
|
||||
|
||||
result, err := adapter.NewInsert().
|
||||
Model(model).
|
||||
Returning("*").
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||
|
||||
// Verify the data was inserted
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("id = ?", model.ID).
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "John Doe", retrieved.Name)
|
||||
assert.Equal(t, "john@example.com", retrieved.Email)
|
||||
assert.Equal(t, 30, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Value(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with Value() method - this was the bug
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Jane Smith").
|
||||
Value("email", "jane@example.com").
|
||||
Value("age", 25).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with Value() should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||
|
||||
// Verify the data was inserted
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("name = ?", "Jane Smith").
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "Jane Smith", retrieved.Name)
|
||||
assert.Equal(t, "jane@example.com", retrieved.Email)
|
||||
assert.Equal(t, 25, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_MultipleValues(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting multiple values
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Alice").
|
||||
Value("email", "alice@example.com").
|
||||
Value("age", 28).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "First insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
result, err = adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Bob").
|
||||
Value("email", "bob@example.com").
|
||||
Value("age", 35).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Second insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify both rows exist
|
||||
var count int
|
||||
count, err = db.NewSelect().
|
||||
Model((*TestInsertModel)(nil)).
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err, "Count should succeed")
|
||||
assert.Equal(t, 2, count, "Should have 2 rows")
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_ValueWithNil(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with nil value for nullable field
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Test User").
|
||||
Value("email", nil). // NULL email
|
||||
Value("age", 20).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with nil value should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify the data was inserted with NULL email
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("name = ?", "Test User").
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "Test User", retrieved.Name)
|
||||
assert.Equal(t, "", retrieved.Email) // NULL becomes empty string
|
||||
assert.Equal(t, 20, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Returning(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test insert with RETURNING clause
|
||||
// Note: SQLite has limited RETURNING support, but this tests the API
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Return Test").
|
||||
Value("email", "return@example.com").
|
||||
Value("age", 40).
|
||||
Returning("*").
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with RETURNING should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_EmptyValues(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test insert without calling Value() - should use Model() or fail gracefully
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Exec(ctx)
|
||||
|
||||
// This should fail because no values are provided
|
||||
assert.Error(t, err, "Insert without values should fail")
|
||||
if result != nil {
|
||||
assert.Equal(t, int64(0), result.RowsAffected())
|
||||
}
|
||||
}
|
||||
@@ -369,13 +369,20 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||
|
||||
// Filter out read-only columns if model is set
|
||||
if g.model != nil {
|
||||
pkName := reflection.GetPrimaryKeyName(g.model)
|
||||
filteredValues := make(map[string]interface{})
|
||||
for column, value := range values {
|
||||
if pkName != "" && column == pkName {
|
||||
// Skip primary key updates
|
||||
continue
|
||||
}
|
||||
if reflection.IsColumnWritable(g.model, column) {
|
||||
filteredValues[column] = value
|
||||
}
|
||||
|
||||
}
|
||||
g.updates = filteredValues
|
||||
} else {
|
||||
|
||||
@@ -111,6 +111,9 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
// Inject parent IDs for foreign key resolution
|
||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||
|
||||
// Get the primary key name for this model
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Process based on operation
|
||||
switch strings.ToLower(operation) {
|
||||
case "insert", "create":
|
||||
@@ -128,30 +131,30 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
|
||||
case "update":
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data["id"])
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
result.ID = data["id"]
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations for update
|
||||
if err := p.processChildRelations(ctx, "update", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
|
||||
case "delete":
|
||||
// Process child relations first (for referential integrity)
|
||||
if err := p.processChildRelations(ctx, "delete", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||
}
|
||||
|
||||
rows, err := p.processDelete(ctx, tableName, data["id"])
|
||||
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
result.ID = data["id"]
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
|
||||
673
pkg/common/sql_types.go
Normal file
673
pkg/common/sql_types.go
Normal file
@@ -0,0 +1,673 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func tryParseDT(str string) (time.Time, error) {
|
||||
var lasterror error
|
||||
tryFormats := []string{time.RFC3339,
|
||||
"2006-01-02T15:04:05.000-0700",
|
||||
"2006-01-02T15:04:05.000",
|
||||
"06-01-02T15:04:05.000",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04:05",
|
||||
"02/01/2006",
|
||||
"02-01-2006",
|
||||
"2006-01-02",
|
||||
"15:04:05.000",
|
||||
"15:04:05",
|
||||
"15:04"}
|
||||
|
||||
for _, f := range tryFormats {
|
||||
tx, err := time.Parse(f, str)
|
||||
if err == nil {
|
||||
return tx, nil
|
||||
} else {
|
||||
lasterror = err
|
||||
}
|
||||
}
|
||||
|
||||
return time.Now(), lasterror
|
||||
}
|
||||
|
||||
func ToJSONDT(dt time.Time) string {
|
||||
return dt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// SqlInt16 - A Int16 that supports SQL string
|
||||
type SqlInt16 int16
|
||||
|
||||
// Scan -
|
||||
func (n *SqlInt16) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = 0
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
*n = SqlInt16(v)
|
||||
case int32:
|
||||
*n = SqlInt16(v)
|
||||
case int64:
|
||||
*n = SqlInt16(v)
|
||||
default:
|
||||
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
*n = SqlInt16(i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlInt16) Value() (driver.Value, error) {
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(n), nil
|
||||
}
|
||||
|
||||
// String - Override String format of ZNullInt32
|
||||
func (n SqlInt16) String() string {
|
||||
tmstr := fmt.Sprintf("%d", n)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||
func (n *SqlInt16) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
|
||||
n64, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
*n = SqlInt16(n64)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlInt16) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("%d", n)), nil
|
||||
}
|
||||
|
||||
// SqlInt32 - A int32 that supports SQL string
|
||||
type SqlInt32 int32
|
||||
|
||||
// Scan -
|
||||
func (n *SqlInt32) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = 0
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
*n = SqlInt32(v)
|
||||
case int32:
|
||||
*n = SqlInt32(v)
|
||||
case int64:
|
||||
*n = SqlInt32(v)
|
||||
default:
|
||||
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
*n = SqlInt32(i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlInt32) Value() (driver.Value, error) {
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(n), nil
|
||||
}
|
||||
|
||||
// String - Override String format of ZNullInt32
|
||||
func (n SqlInt32) String() string {
|
||||
tmstr := fmt.Sprintf("%d", n)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||
func (n *SqlInt32) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
|
||||
n64, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
*n = SqlInt32(n64)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlInt32) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("%d", n)), nil
|
||||
}
|
||||
|
||||
// SqlInt64 - A int64 that supports SQL string
|
||||
type SqlInt64 int64
|
||||
|
||||
// Scan -
|
||||
func (n *SqlInt64) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = 0
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
*n = SqlInt64(v)
|
||||
case int32:
|
||||
*n = SqlInt64(v)
|
||||
case uint32:
|
||||
*n = SqlInt64(v)
|
||||
case int64:
|
||||
*n = SqlInt64(v)
|
||||
case uint64:
|
||||
*n = SqlInt64(v)
|
||||
default:
|
||||
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
*n = SqlInt64(i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlInt64) Value() (driver.Value, error) {
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(n), nil
|
||||
}
|
||||
|
||||
// String - Override String format of ZNullInt32
|
||||
func (n SqlInt64) String() string {
|
||||
tmstr := fmt.Sprintf("%d", n)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||
func (n *SqlInt64) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
|
||||
n64, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
*n = SqlInt64(n64)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlInt64) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("%d", n)), nil
|
||||
}
|
||||
|
||||
// SqlTimeStamp - Implementation of SqlTimeStamp with some interfaces.
|
||||
type SqlTimeStamp time.Time
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
|
||||
if time.Time(t).IsZero() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
if time.Time(t).Before(time.Date(0001, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
|
||||
if tmstr == "0001-01-01T00:00:00" {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON format of time
|
||||
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
|
||||
var err error
|
||||
|
||||
if b == nil {
|
||||
t = &SqlTimeStamp{}
|
||||
return nil
|
||||
}
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
|
||||
t = &SqlTimeStamp{}
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := tryParseDT(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*t = SqlTimeStamp(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Value - SQL Value of custom date
|
||||
func (t SqlTimeStamp) Value() (driver.Value, error) {
|
||||
if t.GetTime().IsZero() || t.GetTime().Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||
return nil, nil
|
||||
}
|
||||
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
|
||||
if tmstr <= "0001-01-01" || tmstr == "" {
|
||||
empty := time.Time{}
|
||||
return empty, nil
|
||||
}
|
||||
|
||||
return tmstr, nil
|
||||
}
|
||||
|
||||
// Scan - Scan custom date from sql
|
||||
func (t *SqlTimeStamp) Scan(value interface{}) error {
|
||||
tm, ok := value.(time.Time)
|
||||
if ok {
|
||||
*t = SqlTimeStamp(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if ok {
|
||||
tx, err := tryParseDT(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*t = SqlTimeStamp(tx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlTimeStamp) String() string {
|
||||
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
|
||||
}
|
||||
|
||||
// GetTime - Returns Time
|
||||
func (t SqlTimeStamp) GetTime() time.Time {
|
||||
return time.Time(t)
|
||||
}
|
||||
|
||||
// SetTime - Returns Time
|
||||
func (t *SqlTimeStamp) SetTime(pTime time.Time) {
|
||||
*t = SqlTimeStamp(pTime)
|
||||
}
|
||||
|
||||
// Format - Formats the time
|
||||
func (t SqlTimeStamp) Format(layout string) string {
|
||||
return fmt.Sprintf("%s", time.Time(t).Format(layout))
|
||||
}
|
||||
|
||||
func SqlTimeStampNow() SqlTimeStamp {
|
||||
tx := time.Now()
|
||||
|
||||
return SqlTimeStamp(tx)
|
||||
}
|
||||
|
||||
// SqlFloat64 - SQL Int
|
||||
type SqlFloat64 sql.NullFloat64
|
||||
|
||||
// Scan -
|
||||
func (n *SqlFloat64) Scan(value interface{}) error {
|
||||
newval := sql.NullFloat64{Float64: 0, Valid: false}
|
||||
if value == nil {
|
||||
newval.Valid = false
|
||||
*n = SqlFloat64(newval)
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case float64:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case float32:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case int64:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case int32:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case uint16:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case uint64:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case uint32:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
default:
|
||||
i, err := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
newval.Float64 = float64(i)
|
||||
if err == nil {
|
||||
newval.Valid = false
|
||||
}
|
||||
}
|
||||
|
||||
*n = SqlFloat64(newval)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlFloat64) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return float64(n.Float64), nil
|
||||
}
|
||||
|
||||
// String -
|
||||
func (n SqlFloat64) String() string {
|
||||
if !n.Valid {
|
||||
return ""
|
||||
}
|
||||
tmstr := fmt.Sprintf("%f", n.Float64)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON -
|
||||
func (n *SqlFloat64) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 2) || (strings.Contains(s, "{") || strings.Contains(s, "["))
|
||||
if invalid {
|
||||
return nil
|
||||
}
|
||||
|
||||
nval, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*n = SqlFloat64(sql.NullFloat64{Valid: true, Float64: float64(nval)})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlFloat64) MarshalJSON() ([]byte, error) {
|
||||
if !n.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("%f", n.Float64)), nil
|
||||
}
|
||||
|
||||
// SqlDate - Implementation of SqlTime with some interfaces.
|
||||
type SqlDate time.Time
|
||||
|
||||
// UnmarshalJSON - Override JSON format of time
|
||||
func (t *SqlDate) UnmarshalJSON(b []byte) error {
|
||||
var err error
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
|
||||
s == "0001-01-01" {
|
||||
t = &SqlDate{}
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := tryParseDT(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*t = SqlDate(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (t SqlDate) MarshalJSON() ([]byte, error) {
|
||||
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||
}
|
||||
|
||||
// Value - SQL Value of custom date
|
||||
func (t SqlDate) Value() (driver.Value, error) {
|
||||
var s time.Time
|
||||
tmstr := time.Time(t).Format("2006-01-02")
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") || tmstr <= "0001-01-01" {
|
||||
return nil, nil
|
||||
}
|
||||
s = time.Time(t)
|
||||
|
||||
return s.Format("2006-01-02"), nil
|
||||
}
|
||||
|
||||
// Scan - Scan custom date from sql
|
||||
func (t *SqlDate) Scan(value interface{}) error {
|
||||
tm, ok := value.(time.Time)
|
||||
if ok {
|
||||
*t = SqlDate(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if ok {
|
||||
tx, err := tryParseDT(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*t = SqlDate(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int64 - Override date format in unix epoch
|
||||
func (t SqlDate) Int64() int64 {
|
||||
return time.Time(t).Unix()
|
||||
}
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlDate) String() string {
|
||||
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
|
||||
return "0"
|
||||
}
|
||||
return tmstr
|
||||
}
|
||||
|
||||
func SqlDateNow() SqlDate {
|
||||
tx := time.Now()
|
||||
return SqlDate(tx)
|
||||
}
|
||||
|
||||
// ////////////////////// SqlTime /////////////////////////
|
||||
// SqlTime - Implementation of SqlTime with some interfaces.
|
||||
type SqlTime time.Time
|
||||
|
||||
// Int64 - Override Time format in unix epoch
|
||||
func (t SqlTime) Int64() int64 {
|
||||
return time.Time(t).Unix()
|
||||
}
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlTime) String() string {
|
||||
return time.Time(t).Format("15:04:05")
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON format of time
|
||||
func (t *SqlTime) UnmarshalJSON(b []byte) error {
|
||||
var err error
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
s == "0001-01-01T00:00:00" || s == "00:00:00" {
|
||||
*t = SqlTime{}
|
||||
return nil
|
||||
}
|
||||
tx := time.Time{}
|
||||
tx, err = tryParseDT(s)
|
||||
*t = SqlTime(tx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Format - Format Function
|
||||
func (t SqlTime) Format(form string) string {
|
||||
tmstr := time.Time(t).Format(form)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// Scan - Scan custom date from sql
|
||||
func (t *SqlTime) Scan(value interface{}) error {
|
||||
tm, ok := value.(time.Time)
|
||||
if ok {
|
||||
*t = SqlTime(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if ok {
|
||||
tx, err := tryParseDT(str)
|
||||
*t = SqlTime(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value - SQL Value of custom date
|
||||
func (t SqlTime) Value() (driver.Value, error) {
|
||||
|
||||
s := time.Time(t)
|
||||
st := s.Format("15:04:05")
|
||||
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (t SqlTime) MarshalJSON() ([]byte, error) {
|
||||
tmstr := time.Time(t).Format("15:04:05")
|
||||
if tmstr == "0001-01-01T00:00:00" || tmstr == "00:00:00" {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||
}
|
||||
|
||||
func SqlTimeNow() SqlTime {
|
||||
tx := time.Now()
|
||||
return SqlTime(tx)
|
||||
}
|
||||
|
||||
// SqlJSONB - Nullable JSONB String
|
||||
type SqlJSONB []byte
|
||||
|
||||
// Scan - Implements sql.Scanner for reading JSONB from database
|
||||
func (n *SqlJSONB) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
*n = SqlJSONB([]byte(v))
|
||||
case []byte:
|
||||
*n = SqlJSONB(v)
|
||||
default:
|
||||
// For other types, marshal to JSON
|
||||
dat, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value to JSON: %v", err)
|
||||
}
|
||||
*n = SqlJSONB(dat)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value - Implements driver.Valuer for writing JSONB to database
|
||||
func (n SqlJSONB) Value() (driver.Value, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Validate that it's valid JSON before returning
|
||||
var js interface{}
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Return as string for PostgreSQL JSONB/JSON columns
|
||||
return string(n), nil
|
||||
}
|
||||
|
||||
func (n SqlJSONB) AsMap() (map[string]any, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Validate that it's valid JSON before returning
|
||||
js := make(map[string]any)
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func (n SqlJSONB) AsSlice() ([]any, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Validate that it's valid JSON before returning
|
||||
js := make([]any, 0)
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON
|
||||
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
|
||||
if invalid {
|
||||
s = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
*n = []byte(s)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
||||
if n == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
var obj interface{}
|
||||
err := json.Unmarshal(n, &obj)
|
||||
if err != nil {
|
||||
//fmt.Printf("Invalid JSON %v", err)
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// dat, err := json.MarshalIndent(obj, " ", " ")
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("failed to convert to JSON: %v", err)
|
||||
// }
|
||||
dat := n
|
||||
|
||||
return dat, nil
|
||||
}
|
||||
@@ -28,6 +28,26 @@ func NewModelRegistry() *DefaultModelRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||
registriesMutex.Lock()
|
||||
foundAt := -1
|
||||
for idx, r := range registries {
|
||||
if r == defaultRegistry {
|
||||
foundAt = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
defaultRegistry = registry
|
||||
if foundAt >= 0 {
|
||||
registries[foundAt] = registry
|
||||
} else {
|
||||
registries = append([]*DefaultModelRegistry{registry}, registries...)
|
||||
}
|
||||
|
||||
defer registriesMutex.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// AddRegistry adds a registry to the global list of registries
|
||||
// Registries are searched in the order they were added
|
||||
func AddRegistry(registry *DefaultModelRegistry) {
|
||||
|
||||
@@ -18,6 +18,7 @@ type ModelFieldDetail struct {
|
||||
}
|
||||
|
||||
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||
// This function recursively processes embedded structs to include their fields
|
||||
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
if record.Kind() != reflect.Struct {
|
||||
return lst
|
||||
}
|
||||
|
||||
collectFieldDetails(record, &lst)
|
||||
|
||||
return lst
|
||||
}
|
||||
|
||||
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
|
||||
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
|
||||
modeltype := record.Type()
|
||||
|
||||
for i := 0; i < modeltype.NumField(); i++ {
|
||||
fieldtype := modeltype.Field(i)
|
||||
fieldValue := record.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if fieldtype.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
embeddedValue := fieldValue
|
||||
if fieldValue.Kind() == reflect.Pointer {
|
||||
if fieldValue.IsNil() {
|
||||
// Skip nil embedded pointers
|
||||
continue
|
||||
}
|
||||
embeddedValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if embeddedValue.Kind() == reflect.Struct {
|
||||
collectFieldDetails(embeddedValue, lst)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
gormdetail := fieldtype.Tag.Get("gorm")
|
||||
gormdetail = strings.Trim(gormdetail, " ")
|
||||
fielddetail := ModelFieldDetail{}
|
||||
fielddetail.FieldValue = record.Field(i)
|
||||
fielddetail.FieldValue = fieldValue
|
||||
fielddetail.Name = fieldtype.Name
|
||||
fielddetail.DataType = fieldtype.Type.Name()
|
||||
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||
@@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
}
|
||||
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||
|
||||
lst = append(lst, fielddetail)
|
||||
|
||||
*lst = append(*lst, fielddetail)
|
||||
}
|
||||
return lst
|
||||
}
|
||||
|
||||
func fnFindKeyVal(src, key string) string {
|
||||
|
||||
@@ -47,7 +47,7 @@ func GetPrimaryKeyName(model any) string {
|
||||
|
||||
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||
// Returns the value of the primary key field
|
||||
func GetPrimaryKeyValue(model any) interface{} {
|
||||
func GetPrimaryKeyValue(model any) any {
|
||||
if model == nil || reflect.TypeOf(model) == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -61,38 +61,51 @@ func GetPrimaryKeyValue(model any) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
|
||||
// Try Bun tag first
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") {
|
||||
fieldValue := val.Field(i)
|
||||
if fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") {
|
||||
fieldValue := val.Field(i)
|
||||
if fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
// Last resort: look for field named "ID" or "Id"
|
||||
if pkValue := findFieldByName(val, "id"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findPrimaryKeyValue recursively searches for a primary key field in the struct
|
||||
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
|
||||
typ := val.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
if strings.ToLower(field.Name) == "id" {
|
||||
fieldValue := val.Field(i)
|
||||
if fieldValue.CanInterface() {
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||
// Recursively search in embedded struct
|
||||
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for primary key tag
|
||||
switch ormType {
|
||||
case "bun":
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
case "gorm":
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
@@ -101,8 +114,35 @@ func GetPrimaryKeyValue(model any) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// findFieldByName recursively searches for a field by name in the struct
|
||||
func findFieldByName(val reflect.Value, name string) any {
|
||||
typ := val.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||
// Recursively search in embedded struct
|
||||
if result := findFieldByName(fieldValue, name); result != nil {
|
||||
return result
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if field name matches
|
||||
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelColumns extracts all column names from a model using reflection
|
||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||
// This function recursively processes embedded structs to include their fields
|
||||
func GetModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
@@ -118,18 +158,38 @@ func GetModelColumns(model any) []string {
|
||||
return columns
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
collectColumnsFromType(modelType, &columns)
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
|
||||
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
collectColumnsFromType(fieldType, columns)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get column name using the same logic as primary key extraction
|
||||
columnName := getColumnNameFromField(field)
|
||||
|
||||
if columnName != "" {
|
||||
columns = append(columns, columnName)
|
||||
*columns = append(*columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// getColumnNameFromField extracts the column name from a struct field
|
||||
@@ -166,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string {
|
||||
}
|
||||
|
||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||
// This function recursively searches embedded structs
|
||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
@@ -177,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
return findPrimaryKeyNameFromType(typ, ormType)
|
||||
}
|
||||
|
||||
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
|
||||
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively search in embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch ormType {
|
||||
case "gorm":
|
||||
// Check for gorm tag with primaryKey
|
||||
@@ -231,6 +314,9 @@ func ExtractColumnFromGormTag(tag string) string {
|
||||
// Example: ",pk" -> "" (will fall back to json tag)
|
||||
func ExtractColumnFromBunTag(tag string) string {
|
||||
parts := strings.Split(tag, ",")
|
||||
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
|
||||
return ""
|
||||
}
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
@@ -240,6 +326,7 @@ func ExtractColumnFromBunTag(tag string) string {
|
||||
// IsColumnWritable checks if a column can be written to in the database
|
||||
// For bun: returns false if the field has "scanonly" tag
|
||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||
// This function recursively searches embedded structs
|
||||
func IsColumnWritable(model any, columnName string) bool {
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
@@ -253,8 +340,37 @@ func IsColumnWritable(model any, columnName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
found, writable := isColumnWritableInType(modelType, columnName)
|
||||
if found {
|
||||
return writable
|
||||
}
|
||||
|
||||
// Column not found in model, allow it (might be a dynamic column)
|
||||
return true
|
||||
}
|
||||
|
||||
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||
// Returns (found, writable) where found indicates if the column was found
|
||||
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively search in embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
if found, writable := isColumnWritableInType(fieldType, columnName); found {
|
||||
return true, writable
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this field matches the column name
|
||||
fieldColumnName := getColumnNameFromField(field)
|
||||
@@ -262,11 +378,12 @@ func IsColumnWritable(model any, columnName string) bool {
|
||||
continue
|
||||
}
|
||||
|
||||
// Found the field, now check if it's writable
|
||||
// Check bun tag for scanonly
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" {
|
||||
if isBunFieldScanOnly(bunTag) {
|
||||
return false
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,16 +391,16 @@ func IsColumnWritable(model any, columnName string) bool {
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
if isGormFieldReadOnly(gormTag) {
|
||||
return false
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
// Column is writable
|
||||
return true
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Column not found in model, allow it (might be a dynamic column)
|
||||
return true
|
||||
// Column not found
|
||||
return false, false
|
||||
}
|
||||
|
||||
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
||||
|
||||
@@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test models with embedded structs
|
||||
|
||||
type BaseModel struct {
|
||||
ID int `bun:"rid_base,pk" json:"id"`
|
||||
CreatedAt string `bun:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type AdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
|
||||
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
type ModelWithEmbedded struct {
|
||||
BaseModel
|
||||
Name string `bun:"name" json:"name"`
|
||||
Description string `bun:"description" json:"description"`
|
||||
AdhocBuffer
|
||||
}
|
||||
|
||||
type GormBaseModel struct {
|
||||
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
|
||||
CreatedAt string `gorm:"column:created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type GormAdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
|
||||
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
type GormModelWithEmbedded struct {
|
||||
GormBaseModel
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
Description string `gorm:"column:description" json:"description"`
|
||||
GormAdhocBuffer
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded base",
|
||||
model: ModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded base (pointer)",
|
||||
model: &ModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base",
|
||||
model: GormModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base (pointer)",
|
||||
model: &GormModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyName(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
|
||||
bunModel := ModelWithEmbedded{
|
||||
BaseModel: BaseModel{
|
||||
ID: 123,
|
||||
CreatedAt: "2024-01-01",
|
||||
},
|
||||
Name: "Test",
|
||||
Description: "Test Description",
|
||||
}
|
||||
|
||||
gormModel := GormModelWithEmbedded{
|
||||
GormBaseModel: GormBaseModel{
|
||||
ID: 456,
|
||||
CreatedAt: "2024-01-02",
|
||||
},
|
||||
Name: "GORM Test",
|
||||
Description: "GORM Test Description",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded base",
|
||||
model: bunModel,
|
||||
expected: 123,
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded base (pointer)",
|
||||
model: &bunModel,
|
||||
expected: 123,
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base",
|
||||
model: gormModel,
|
||||
expected: 456,
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base (pointer)",
|
||||
model: &gormModel,
|
||||
expected: 456,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyValue(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelColumnsWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded structs",
|
||||
model: ModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded structs (pointer)",
|
||||
model: &ModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded structs",
|
||||
model: GormModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded structs (pointer)",
|
||||
model: &GormModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
columnName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Bun model - writable column in main struct",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Bun model - writable column in embedded base",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "rid_base",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Bun model - scanonly column in embedded adhoc buffer",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "cql1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Bun model - scanonly column _rownumber",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "_rownumber",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GORM model - writable column in main struct",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GORM model - writable column in embedded base",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "rid_base",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GORM model - readonly column in embedded adhoc buffer",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "cql1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GORM model - readonly column _rownumber",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "_rownumber",
|
||||
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsColumnWritable(tt.model, tt.columnName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -610,6 +610,9 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
dataSlice := h.normalizeToSlice(data)
|
||||
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
|
||||
|
||||
// Store original data maps for merging later
|
||||
originalDataMaps := make([]map[string]interface{}, 0, len(dataSlice))
|
||||
|
||||
// Process all items in a transaction
|
||||
results := make([]interface{}, 0, len(dataSlice))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
@@ -630,6 +633,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
}
|
||||
|
||||
// Store a copy of the original data map for merging later
|
||||
originalMap := make(map[string]interface{})
|
||||
for k, v := range itemMap {
|
||||
originalMap[k] = v
|
||||
}
|
||||
originalDataMaps = append(originalDataMaps, originalMap)
|
||||
|
||||
// Extract nested relations if present (but don't process them yet)
|
||||
var nestedRelations map[string]interface{}
|
||||
if h.shouldUseNestedProcessor(itemMap, model) {
|
||||
@@ -704,14 +714,26 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
|
||||
// Merge created records with original request data
|
||||
// This preserves extra keys from the request
|
||||
mergedResults := make([]interface{}, 0, len(results))
|
||||
for i, result := range results {
|
||||
if i < len(originalDataMaps) {
|
||||
merged := h.mergeRecordWithRequest(result, originalDataMaps[i])
|
||||
mergedResults = append(mergedResults, merged)
|
||||
} else {
|
||||
mergedResults = append(mergedResults, result)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute AfterCreate hooks
|
||||
var responseData interface{}
|
||||
if len(results) == 1 {
|
||||
responseData = results[0]
|
||||
hookCtx.Result = results[0]
|
||||
if len(mergedResults) == 1 {
|
||||
responseData = mergedResults[0]
|
||||
hookCtx.Result = mergedResults[0]
|
||||
} else {
|
||||
responseData = results
|
||||
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
|
||||
responseData = mergedResults
|
||||
hookCtx.Result = map[string]interface{}{"created": len(mergedResults), "data": mergedResults}
|
||||
}
|
||||
hookCtx.Error = nil
|
||||
|
||||
@@ -721,7 +743,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully created %d record(s)", len(results))
|
||||
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
||||
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||
}
|
||||
|
||||
@@ -790,6 +812,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
// Get the primary key name for the model
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Variable to store the updated record
|
||||
var updatedRecord interface{}
|
||||
|
||||
// Process nested relations if present
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Create temporary nested processor with transaction
|
||||
@@ -808,11 +836,10 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
// Ensure ID is in the data map for the update
|
||||
dataMap["id"] = targetID
|
||||
dataMap[pkName] = targetID
|
||||
|
||||
// Create update query
|
||||
query := tx.NewUpdate().Model(model).Table(tableName).SetMap(dataMap)
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
@@ -840,10 +867,18 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
}
|
||||
|
||||
// Store result for hooks
|
||||
hookCtx.Result = map[string]interface{}{
|
||||
"updated": result.RowsAffected(),
|
||||
// Fetch the updated record to return the new values
|
||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||
selectQuery := tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
return fmt.Errorf("failed to fetch updated record: %w", err)
|
||||
}
|
||||
|
||||
updatedRecord = modelValue
|
||||
|
||||
// Store result for hooks
|
||||
hookCtx.Result = updatedRecord
|
||||
_ = result // Keep result variable for potential future use
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -853,7 +888,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
// Merge the updated record with the original request data
|
||||
// This preserves extra keys from the request and updates values from the database
|
||||
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
||||
|
||||
// Execute AfterUpdate hooks
|
||||
hookCtx.Result = mergedData
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
logger.Error("AfterUpdate hook failed: %v", err)
|
||||
@@ -862,7 +902,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||
h.sendResponseWithOptions(w, hookCtx.Result, nil, &options)
|
||||
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
||||
}
|
||||
|
||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
||||
@@ -936,6 +976,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
// Array of IDs or objects with ID field
|
||||
logger.Info("Batch delete with %d items ([]interface{})", len(v))
|
||||
deletedCount := 0
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
var itemID interface{}
|
||||
@@ -945,7 +986,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
case string:
|
||||
itemID = v
|
||||
case map[string]interface{}:
|
||||
itemID = v["id"]
|
||||
itemID = v[pkName]
|
||||
default:
|
||||
itemID = item
|
||||
}
|
||||
@@ -1002,9 +1043,10 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
// Array of objects with id field
|
||||
logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
|
||||
deletedCount := 0
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
if itemID, ok := item["id"]; ok && itemID != nil {
|
||||
if itemID, ok := item[pkName]; ok && itemID != nil {
|
||||
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||
|
||||
// Execute hooks for each item
|
||||
@@ -1052,7 +1094,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
case map[string]interface{}:
|
||||
// Single object with id field
|
||||
if itemID, ok := v["id"]; ok && itemID != nil {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
if itemID, ok := v[pkName]; ok && itemID != nil {
|
||||
id = fmt.Sprintf("%v", itemID)
|
||||
}
|
||||
}
|
||||
@@ -1122,6 +1165,39 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
h.sendResponse(w, responseData, nil)
|
||||
}
|
||||
|
||||
// mergeRecordWithRequest merges a database record with the original request data
|
||||
// This preserves extra keys from the request that aren't in the database model
|
||||
// and updates values from the database (e.g., from SQL triggers or defaults)
|
||||
func (h *Handler) mergeRecordWithRequest(dbRecord interface{}, requestData map[string]interface{}) map[string]interface{} {
|
||||
// Convert the database record to a map
|
||||
dbMap := make(map[string]interface{})
|
||||
|
||||
// Marshal and unmarshal to convert struct to map
|
||||
jsonData, err := json.Marshal(dbRecord)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to marshal database record for merging: %v", err)
|
||||
return requestData
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||
logger.Warn("Failed to unmarshal database record for merging: %v", err)
|
||||
return requestData
|
||||
}
|
||||
|
||||
// Start with the request data (preserves extra keys)
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range requestData {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Update with values from database (overwrites with DB values, including trigger changes)
|
||||
for k, v := range dbMap {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
|
||||
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
||||
if data == nil {
|
||||
@@ -1658,22 +1734,22 @@ func (h *Handler) cleanJSON(data interface{}) interface{} {
|
||||
}
|
||||
|
||||
func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) {
|
||||
var details string
|
||||
var errorMsg string
|
||||
if err != nil {
|
||||
details = err.Error()
|
||||
errorMsg = err.Error()
|
||||
} else if message != "" {
|
||||
errorMsg = message
|
||||
} else {
|
||||
errorMsg = code
|
||||
}
|
||||
|
||||
response := common.Response{
|
||||
Success: false,
|
||||
Error: &common.APIError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
},
|
||||
response := map[string]interface{}{
|
||||
"_error": errorMsg,
|
||||
"_retval": 1,
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
if err := w.WriteJSON(response); err != nil {
|
||||
logger.Error("Failed to write JSON error response: %v", err)
|
||||
if jsonErr := w.WriteJSON(response); jsonErr != nil {
|
||||
logger.Error("Failed to write JSON error response: %v", jsonErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
type TestModel struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name" bun:"name"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user