Compare commits

...

18 Commits

Author SHA1 Message Date
Hein b9bed67bd7 feat(security): add program user ID and table to user context
Build , Vet Test, and Lint / Lint Code (push) Failing after 0s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Failing after 0s
Build , Vet Test, and Lint / Build (push) Failing after 1s
Tests / Unit Tests (push) Failing after 0s
Tests / Integration Tests (push) Failing after 1s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Failing after 14m57s
2026-06-23 10:18:22 +02:00
Hein 11ef16f75a fix(sql_helpers): adjust parenthesis nesting depth comment 2026-06-23 09:41:40 +02:00
Hein 48b72a7631 fix(sql_helpers): enhance splitByAND to handle BETWEEN and quotes
* Add support for BETWEEN-aware AND detection
* Ensure AND inside single-quoted strings does not cause splits
* Update tests to cover new BETWEEN and quote scenarios
2026-06-23 09:41:27 +02:00
Hein 4c512acf25 test(function_api): add test for x-detailapi header response 2026-06-23 08:53:33 +02:00
Hein 07a402634e fix(function_api): enhance detail format with table metadata
* include table name and prefix in response
* add field metadata extraction for raw SQL results
2026-06-23 08:50:29 +02:00
Hein 0e8f8925c6 fix(reflection): replace reflect.Ptr with reflect.Pointer
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Failing after 0s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Failing after 1s
Build , Vet Test, and Lint / Lint Code (push) Failing after 0s
Build , Vet Test, and Lint / Build (push) Failing after 0s
Tests / Unit Tests (push) Failing after 1s
Tests / Integration Tests (push) Failing after 1s
* Updated all instances of reflect.Ptr to reflect.Pointer for consistency in type checking.
2026-06-22 16:40:07 +02:00
Hein 5a359a160b fix(handler): update sendFormattedResponse to include table name and model 2026-06-22 16:38:21 +02:00
Hein a2799fa224 fix(handler): re-fetch records to capture DB-generated values
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
2026-06-12 16:28:51 +02:00
Hein 1419542650 fix(handler): re-fetch records to capture DB-generated changes 2026-06-12 13:37:07 +02:00
Hein c120b49529 fix(router): prevent HTML escaping in JSON responses
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
fix(sql_helpers): avoid prefix extraction in subqueries
2026-06-08 15:13:58 +02:00
Hein 66348dac97 test(handler): add tests for valid nested request verbs 2026-06-08 09:06:29 +02:00
Hein a87cd18b1b fix(handler): validate nested request structure for relations
* added checks for valid _request values in single and multiple relations
* introduced isValidNestedRequest function to encapsulate validation logic
fix(crud): expand operation handling for nested CUD
* added "add" to insert operations and "modify" to update operations
* included "remove" in delete operations
2026-06-08 09:02:29 +02:00
Hein 29449c93d5 fix(test): add tests for asymmetric join column handling
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
2026-06-07 19:13:59 +02:00
Hein 3b6e5c75be fix(handler): update foreign key field resolution logic
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
* Adjust foreign key field name selection for has-many/has-one relationships
* Improve logging to clarify foreign key and child field usage
2026-06-07 14:20:55 +02:00
Hein 549ccb8468 fix(handler): fetch updated records after transaction commits
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
* Update selection queries to use model columns
* Ensure updated records are fetched and returned in responses
2026-06-05 11:12:04 +02:00
Hein 1af9c76337 fix(handler): fetch updated record after transaction commits
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
2026-06-04 18:23:18 +02:00
Hein 938a2ef3d9 fix(staticweb): add fallback for extensionless file paths
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Failing after 6s
Tests / Integration Tests (push) Failing after 13m59s
Tests / Unit Tests (push) Failing after 14m11s
Build , Vet Test, and Lint / Build (push) Failing after 14m21s
Build , Vet Test, and Lint / Lint Code (push) Failing after 14m31s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Failing after 14m45s
2026-05-27 18:41:43 +02:00
Hein 69cc3e2839 fix(db): update Returning method to accept multiple columns 2026-05-27 14:11:20 +02:00
32 changed files with 1374 additions and 224 deletions
+17 -17
View File
@@ -39,7 +39,7 @@ func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent)
// This helps identify which specific field is causing scanning issues // This helps identify which specific field is causing scanning issues
func debugScanIntoStruct(rows interface{}, dest interface{}) error { func debugScanIntoStruct(rows interface{}, dest interface{}) error {
v := reflect.ValueOf(dest) v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr { if v.Kind() != reflect.Pointer {
return fmt.Errorf("dest must be a pointer") return fmt.Errorf("dest must be a pointer")
} }
@@ -59,7 +59,7 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
logger.Debug(" Slice element type: %s", elemType) logger.Debug(" Slice element type: %s", elemType)
// If slice of pointers, get the underlying type // If slice of pointers, get the underlying type
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Pointer {
structType = elemType.Elem() structType = elemType.Elem()
} else { } else {
structType = elemType structType = elemType
@@ -747,7 +747,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
// Get the first parent to check the relation field // Get the first parent to check the relation field
firstParent := parents.Index(0) firstParent := parents.Index(0)
if firstParent.Kind() == reflect.Ptr { if firstParent.Kind() == reflect.Pointer {
firstParent = firstParent.Elem() firstParent = firstParent.Elem()
} }
@@ -762,7 +762,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
// Check if any parent has a non-empty slice // Check if any parent has a non-empty slice
for i := 0; i < parents.Len(); i++ { for i := 0; i < parents.Len(); i++ {
parent := parents.Index(i) parent := parents.Index(i)
if parent.Kind() == reflect.Ptr { if parent.Kind() == reflect.Pointer {
parent = parent.Elem() parent = parent.Elem()
} }
field := parent.FieldByName(relationName) field := parent.FieldByName(relationName)
@@ -771,7 +771,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len()) allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len())
for j := 0; j < parents.Len(); j++ { for j := 0; j < parents.Len(); j++ {
p := parents.Index(j) p := parents.Index(j)
if p.Kind() == reflect.Ptr { if p.Kind() == reflect.Pointer {
p = p.Elem() p = p.Elem()
} }
f := p.FieldByName(relationName) f := p.FieldByName(relationName)
@@ -784,7 +784,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
return allRelated, true return allRelated, true
} }
} }
} else if relationField.Kind() == reflect.Ptr { } else if relationField.Kind() == reflect.Pointer {
// Check if it's a pointer (has-one/belongs-to) // Check if it's a pointer (has-one/belongs-to)
if !relationField.IsNil() { if !relationField.IsNil() {
// Already loaded! Collect all related records from all parents // Already loaded! Collect all related records from all parents
@@ -792,7 +792,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len()) allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len())
for j := 0; j < parents.Len(); j++ { for j := 0; j < parents.Len(); j++ {
p := parents.Index(j) p := parents.Index(j)
if p.Kind() == reflect.Ptr { if p.Kind() == reflect.Pointer {
p = p.Elem() p = p.Elem()
} }
f := p.FieldByName(relationName) f := p.FieldByName(relationName)
@@ -816,7 +816,7 @@ func (b *BunSelectQuery) loadCustomPreloads(ctx context.Context) error {
// Get the actual data from the model // Get the actual data from the model
modelValue := reflect.ValueOf(model.Value()) modelValue := reflect.ValueOf(model.Value())
if modelValue.Kind() == reflect.Ptr { if modelValue.Kind() == reflect.Pointer {
modelValue = modelValue.Elem() modelValue = modelValue.Elem()
} }
@@ -884,7 +884,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
// Get the first record to inspect the struct type // Get the first record to inspect the struct type
firstRecord := parentRecords.Index(0) firstRecord := parentRecords.Index(0)
if firstRecord.Kind() == reflect.Ptr { if firstRecord.Kind() == reflect.Pointer {
firstRecord = firstRecord.Elem() firstRecord = firstRecord.Elem()
} }
@@ -930,7 +930,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
if isSlice { if isSlice {
relatedType = relatedType.Elem() relatedType = relatedType.Elem()
} }
if relatedType.Kind() == reflect.Ptr { if relatedType.Kind() == reflect.Pointer {
relatedType = relatedType.Elem() relatedType = relatedType.Elem()
} }
@@ -1018,7 +1018,7 @@ func extractForeignKeyValues(records reflect.Value, fkFieldName string) ([]inter
for i := 0; i < records.Len(); i++ { for i := 0; i < records.Len(); i++ {
record := records.Index(i) record := records.Index(i)
if record.Kind() == reflect.Ptr { if record.Kind() == reflect.Pointer {
record = record.Elem() record = record.Elem()
} }
@@ -1083,7 +1083,7 @@ func associateRelatedRecords(parents, related reflect.Value, fieldName string, r
for i := 0; i < related.Len(); i++ { for i := 0; i < related.Len(); i++ {
relRecord := related.Index(i) relRecord := related.Index(i)
relRecordElem := relRecord relRecordElem := relRecord
if relRecordElem.Kind() == reflect.Ptr { if relRecordElem.Kind() == reflect.Pointer {
relRecordElem = relRecordElem.Elem() relRecordElem = relRecordElem.Elem()
} }
@@ -1109,7 +1109,7 @@ func associateRelatedRecords(parents, related reflect.Value, fieldName string, r
for i := 0; i < parents.Len(); i++ { for i := 0; i < parents.Len(); i++ {
parentPtr := parents.Index(i) parentPtr := parents.Index(i)
parent := parentPtr parent := parentPtr
if parent.Kind() == reflect.Ptr { if parent.Kind() == reflect.Pointer {
parent = parent.Elem() parent = parent.Elem()
} }
@@ -1332,11 +1332,11 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
modelInfo = fmt.Sprintf("Model type: %T", modelValue) modelInfo = fmt.Sprintf("Model type: %T", modelValue)
v := reflect.ValueOf(modelValue) v := reflect.ValueOf(modelValue)
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Pointer {
v = v.Elem() v = v.Elem()
} }
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
if v.Type().Elem().Kind() == reflect.Ptr { if v.Type().Elem().Kind() == reflect.Pointer {
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name()) modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
} else { } else {
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name()) modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
@@ -1489,7 +1489,7 @@ func (b *BunInsertQuery) OnConflict(action string) common.InsertQuery {
func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery { func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
if len(columns) > 0 { if len(columns) > 0 {
b.query = b.query.Returning(columns[0]) b.query = b.query.Returning(strings.Join(columns, ", "))
} }
return b return b
} }
@@ -1606,7 +1606,7 @@ func (b *BunUpdateQuery) Where(query string, args ...interface{}) common.UpdateQ
func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery { func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
if len(columns) > 0 { if len(columns) > 0 {
b.query = b.query.Returning(columns[0]) b.query = b.query.Returning(strings.Join(columns, ", "))
} }
return b return b
} }
+1 -1
View File
@@ -800,7 +800,7 @@ func (g *GormInsertQuery) Scan(ctx context.Context, dest interface{}) (err error
col := g.returningColumns[0] col := g.returningColumns[0]
if g.model != nil { if g.model != nil {
val := reflect.ValueOf(g.model) val := reflect.ValueOf(g.model)
if val.Kind() == reflect.Ptr { if val.Kind() == reflect.Pointer {
val = val.Elem() val = val.Elem()
} }
if val.Kind() == reflect.Struct { if val.Kind() == reflect.Struct {
+8 -8
View File
@@ -1195,7 +1195,7 @@ func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest inter
// Use reflection to process the destination // Use reflection to process the destination
destValue := reflect.ValueOf(dest) destValue := reflect.ValueOf(dest)
if destValue.Kind() != reflect.Ptr { if destValue.Kind() != reflect.Pointer {
return fmt.Errorf("dest must be a pointer") return fmt.Errorf("dest must be a pointer")
} }
@@ -1222,7 +1222,7 @@ func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest inter
// loadPreloadsForRecord loads all preload relationships for a single record // loadPreloadsForRecord loads all preload relationships for a single record
func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error { func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error {
if record.Kind() == reflect.Ptr { if record.Kind() == reflect.Pointer {
if record.IsNil() { if record.IsNil() {
return nil return nil
} }
@@ -1299,7 +1299,7 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
} else { } else {
// Single struct - create a pointer if needed // Single struct - create a pointer if needed
var target reflect.Value var target reflect.Value
if field.Kind() == reflect.Ptr { if field.Kind() == reflect.Pointer {
target = reflect.New(field.Type().Elem()) target = reflect.New(field.Type().Elem())
} else { } else {
target = reflect.New(field.Type()) target = reflect.New(field.Type())
@@ -1312,7 +1312,7 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
} }
// Set the field // Set the field
if field.Kind() == reflect.Ptr { if field.Kind() == reflect.Pointer {
field.Set(target) field.Set(target)
} else { } else {
field.Set(target.Elem()) field.Set(target.Elem())
@@ -1329,7 +1329,7 @@ func (p *PgSQLSelectQuery) getRelationMetadata(fieldName string) *relationMetada
} }
modelType := reflect.TypeOf(p.model) modelType := reflect.TypeOf(p.model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -1378,7 +1378,7 @@ func (p *PgSQLSelectQuery) getRelationMetadataFromField(modelType reflect.Type,
if fieldType.Kind() == reflect.Slice { if fieldType.Kind() == reflect.Slice {
fieldType = fieldType.Elem() fieldType = fieldType.Elem()
} }
if fieldType.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem() fieldType = fieldType.Elem()
} }
@@ -1411,7 +1411,7 @@ func scanRows(rows *sql.Rows, dest interface{}) error {
// Get destination type // Get destination type
destValue := reflect.ValueOf(dest) destValue := reflect.ValueOf(dest)
if destValue.Kind() != reflect.Ptr { if destValue.Kind() != reflect.Pointer {
return fmt.Errorf("dest must be a pointer") return fmt.Errorf("dest must be a pointer")
} }
@@ -1466,7 +1466,7 @@ func scanRowsToMapSlice(rows *sql.Rows, columns []string, destValue reflect.Valu
// scanRowsToStructSlice scans rows into a slice of structs // scanRowsToStructSlice scans rows into a slice of structs
func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error { func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
elemType := destValue.Type().Elem() elemType := destValue.Type().Elem()
isPtr := elemType.Kind() == reflect.Ptr isPtr := elemType.Kind() == reflect.Pointer
if isPtr { if isPtr {
elemType = elemType.Elem() elemType = elemType.Elem()
@@ -71,7 +71,7 @@ func entityNameFromModel(model interface{}, table string) string {
} }
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -108,7 +108,7 @@ func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bo
} }
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
+3 -1
View File
@@ -174,7 +174,9 @@ func (h *HTTPResponseWriter) Write(data []byte) (int, error) {
func (h *HTTPResponseWriter) WriteJSON(data interface{}) error { func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
h.SetHeader("Content-Type", "application/json") h.SetHeader("Content-Type", "application/json")
return json.NewEncoder(h.resp).Encode(data) enc := json.NewEncoder(h.resp)
enc.SetEscapeHTML(false)
return enc.Encode(data)
} }
// UnderlyingResponseWriter returns the underlying http.ResponseWriter // UnderlyingResponseWriter returns the underlying http.ResponseWriter
+9 -9
View File
@@ -25,7 +25,7 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e
originalType := modelType originalType := modelType
// Unwrap pointers, slices, and arrays to get to the base struct type // Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -126,15 +126,15 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
// Get related model type // Get related model type
if field.Type.Kind() == reflect.Slice { if field.Type.Kind() == reflect.Slice {
elemType := field.Type.Elem() elemType := field.Type.Elem()
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Pointer {
elemType = elemType.Elem() elemType = elemType.Elem()
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
info.RelatedModel = reflect.New(elemType).Elem().Interface() info.RelatedModel = reflect.New(elemType).Elem().Interface()
} }
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { } else if field.Type.Kind() == reflect.Pointer || field.Type.Kind() == reflect.Struct {
elemType := field.Type elemType := field.Type
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Pointer {
elemType = elemType.Elem() elemType = elemType.Elem()
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
@@ -155,16 +155,16 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
info.RelationType = "hasMany" info.RelationType = "hasMany"
// Get the element type for slice // Get the element type for slice
elemType := field.Type.Elem() elemType := field.Type.Elem()
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Pointer {
elemType = elemType.Elem() elemType = elemType.Elem()
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
info.RelatedModel = reflect.New(elemType).Elem().Interface() info.RelatedModel = reflect.New(elemType).Elem().Interface()
} }
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { } else if field.Type.Kind() == reflect.Pointer || field.Type.Kind() == reflect.Struct {
info.RelationType = "belongsTo" info.RelationType = "belongsTo"
elemType := field.Type elemType := field.Type
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Pointer {
elemType = elemType.Elem() elemType = elemType.Elem()
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
@@ -177,7 +177,7 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
// Get the element type for many2many (always slice) // Get the element type for many2many (always slice)
if field.Type.Kind() == reflect.Slice { if field.Type.Kind() == reflect.Slice {
elemType := field.Type.Elem() elemType := field.Type.Elem()
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Pointer {
elemType = elemType.Elem() elemType = elemType.Elem()
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
@@ -239,7 +239,7 @@ func GetTableNameFromModel(model interface{}) string {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
// Unwrap pointers // Unwrap pointers
for modelType != nil && modelType.Kind() == reflect.Ptr { for modelType != nil && modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
+3 -1
View File
@@ -178,7 +178,9 @@ func (s *StandardResponseWriter) Write(data []byte) (int, error) {
func (s *StandardResponseWriter) WriteJSON(data interface{}) error { func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
s.SetHeader("Content-Type", "application/json") s.SetHeader("Content-Type", "application/json")
return json.NewEncoder(s.w).Encode(data) enc := json.NewEncoder(s.w)
enc.SetEscapeHTML(false)
return enc.Encode(data)
} }
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter { func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
+17 -13
View File
@@ -69,7 +69,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
// Get model type for reflection // Get model type for reflection
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -113,7 +113,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
// Process based on operation // Process based on operation
switch strings.ToLower(operation) { switch strings.ToLower(operation) {
case "insert", "create": case "insert", "create", "add":
// Only perform insert if we have data to insert // Only perform insert if we have data to insert
if hasData { if hasData {
id, err := p.processInsert(ctx, regularData, tableName) id, err := p.processInsert(ctx, regularData, tableName)
@@ -141,7 +141,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName) logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
} }
case "update", "change": case "update", "change", "modify":
// Only perform update if we have data to update // Only perform update if we have data to update
if reflection.IsEmptyValue(data[pkName]) { if reflection.IsEmptyValue(data[pkName]) {
logger.Warn("Skipping update for %s - no primary key", tableName) logger.Warn("Skipping update for %s - no primary key", tableName)
@@ -174,7 +174,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
result.ID = data[pkName] result.ID = data[pkName]
} }
case "delete": case "delete", "remove":
if reflection.IsEmptyValue(data[pkName]) { if reflection.IsEmptyValue(data[pkName]) {
logger.Warn("Skipping delete for %s - no primary key", tableName) logger.Warn("Skipping delete for %s - no primary key", tableName)
return result, nil return result, nil
@@ -224,7 +224,7 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
} }
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -410,7 +410,7 @@ func (p *NestedCUDProcessor) processChildRelations(
if relatedModelType.Kind() == reflect.Slice { if relatedModelType.Kind() == reflect.Slice {
relatedModelType = relatedModelType.Elem() relatedModelType = relatedModelType.Elem()
} }
if relatedModelType.Kind() == reflect.Ptr { if relatedModelType.Kind() == reflect.Pointer {
relatedModelType = relatedModelType.Elem() relatedModelType = relatedModelType.Elem()
} }
@@ -471,13 +471,17 @@ func (p *NestedCUDProcessor) processChildRelations(
// Priority: Use foreign key field name if specified // Priority: Use foreign key field name if specified
var foreignKeyFieldName string var foreignKeyFieldName string
if relInfo.ForeignKey != "" { if relInfo.ForeignKey != "" {
// Get the JSON name for the foreign key field in the child model // For has-many/has-one: join:parentCol=childCol
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) // ForeignKey = parent side, References = child side (where we actually set the value)
if foreignKeyFieldName == "" { childField := relInfo.ForeignKey
// Fallback to lowercase field name if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey) childField = relInfo.References
} }
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey) foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
if foreignKeyFieldName == "" {
foreignKeyFieldName = strings.ToLower(childField)
}
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s -> child %s)", foreignKeyFieldName, relInfo.ForeignKey, childField)
} }
// Get the primary key name for the child model to avoid overwriting it in recursive relationships // Get the primary key name for the child model to avoid overwriting it in recursive relationships
@@ -586,7 +590,7 @@ func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{
// Get model type // Get model type
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
+214
View File
@@ -713,6 +713,220 @@ func TestInjectForeignKeys(t *testing.T) {
} }
} }
// Models for asymmetric join column tests (mirrors the bun has-many join:parentCol=childCol pattern).
// ActionOption has-many ActionOptionLinks via join:rid_actionoption=rid_actionoption_child.
// The child column ("rid_actionoption_child") differs from the parent column ("rid_actionoption").
type ActionOption struct {
RidActionoption int64 `json:"rid_actionoption" bun:"rid_actionoption,pk"`
Label string `json:"label"`
Links []*ActionOptionLink `json:"aol_rid_actionoption_child,omitempty"`
}
func (a ActionOption) TableName() string { return "action_options" }
func (a ActionOption) GetIDName() string { return "RidActionoption" }
type ActionOptionLink struct {
RidActionoptionlink int64 `json:"rid_actionoptionlink" bun:"rid_actionoptionlink,pk"`
RidActionoptionChild int64 `json:"rid_actionoption_child" bun:"rid_actionoption_child"`
Label string `json:"label"`
// Note: no field named "rid_actionoption" — that is the parent's column.
}
func (a ActionOptionLink) TableName() string { return "action_option_links" }
func (a ActionOptionLink) GetIDName() string { return "RidActionoptionlink" }
// TestProcessNestedCUD_AsymmetricJoinColumns verifies that for a has-many relation with
// join:parentCol=childCol, the child rows are stamped with the child-side column (References),
// not the parent-side column (ForeignKey).
func TestProcessNestedCUD_AsymmetricJoinColumns(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
// Mirrors: bun:"rel:has-many,join:rid_actionoption=rid_actionoption_child"
relProvider.RegisterRelation("ActionOption", "aol_rid_actionoption_child", &RelationshipInfo{
FieldName: "Links",
JSONName: "aol_rid_actionoption_child",
RelationType: "hasMany",
ForeignKey: "rid_actionoption", // parent-side column (left of join:)
References: "rid_actionoption_child", // child-side column (right of join:)
RelatedModel: ActionOptionLink{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"label": "option-a",
"aol_rid_actionoption_child": []interface{}{
map[string]interface{}{"label": "link-1"},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
ActionOption{},
nil,
"action_options",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
if len(db.insertCalls) < 2 {
t.Fatalf("Expected at least 2 insert calls (parent + child), got %d", len(db.insertCalls))
}
childInsert := db.insertCalls[1]
// The fix: child must receive "rid_actionoption_child", NOT "rid_actionoption".
if childInsert["rid_actionoption_child"] == nil {
t.Error("Expected child to have rid_actionoption_child set (child-side FK column)")
}
if childInsert["rid_actionoption"] != nil {
t.Errorf("Child must not receive parent-side column rid_actionoption, got %v", childInsert["rid_actionoption"])
}
}
// TestProcessNestedCUD_BelongsToUnchanged verifies that the fix does not regress belongsTo
// relations, where ForeignKey is already the local (child) column.
func TestProcessNestedCUD_BelongsToUnchanged(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
// For belongsTo, ForeignKey is the column on the child; References is on the parent.
// The old and new code must behave identically here.
relProvider.RegisterRelation("Employee", "department", &RelationshipInfo{
FieldName: "Department",
JSONName: "department",
RelationType: "belongsTo",
ForeignKey: "DepartmentID", // child's own column
References: "ID", // parent's PK
RelatedModel: Department{},
})
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{"name": "Alice"},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
if len(db.insertCalls) < 2 {
t.Fatalf("Expected at least 2 inserts, got %d", len(db.insertCalls))
}
// Employees relation uses has_many (old-style) so it goes through the parentIDs injection path,
// not the foreignKeyFieldName path. Just confirm no panic and employee is inserted.
if db.insertCalls[0]["name"] != "Engineering" {
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
}
}
func TestProcessNestedCUD_AddAlias(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"_request": "add",
"name": "New Department",
}
result, err := processor.ProcessNestedCUD(context.Background(), "insert", data, Department{}, nil, "departments")
if err != nil {
t.Fatalf("ProcessNestedCUD with _request=add failed: %v", err)
}
if result.ID == nil {
t.Error("Expected result.ID to be set after add")
}
if len(db.insertCalls) != 1 {
t.Errorf("Expected 1 insert call, got %d", len(db.insertCalls))
}
}
func TestProcessNestedCUD_RemoveAlias(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"_request": "remove",
"ID": int64(42),
}
_, err := processor.ProcessNestedCUD(context.Background(), "delete", data, Department{}, nil, "departments")
if err != nil {
t.Fatalf("ProcessNestedCUD with _request=remove failed: %v", err)
}
if len(db.deleteCalls) != 1 {
t.Errorf("Expected 1 delete call, got %d", len(db.deleteCalls))
}
}
func TestProcessNestedCUD_NestedAddRemoveAliases(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"ID": int64(1),
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{"_request": "add", "name": "Alice"},
map[string]interface{}{"_request": "remove", "ID": int64(5)},
},
}
_, err := processor.ProcessNestedCUD(context.Background(), "update", data, Department{}, nil, "departments")
if err != nil {
t.Fatalf("ProcessNestedCUD with nested add/remove failed: %v", err)
}
if len(db.insertCalls) != 1 {
t.Errorf("Expected 1 insert (add alias) for employee, got %d", len(db.insertCalls))
}
if len(db.deleteCalls) != 1 {
t.Errorf("Expected 1 delete (remove alias) for employee, got %d", len(db.deleteCalls))
}
}
func TestGetPrimaryKeyName(t *testing.T) { func TestGetPrimaryKeyName(t *testing.T) {
dept := Department{} dept := Department{}
pkName := reflection.GetPrimaryKeyName(dept) pkName := reflection.GetPrimaryKeyName(dept)
+52 -18
View File
@@ -446,18 +446,36 @@ func containsTopLevelOR(clause string) bool {
return false return false
} }
// splitByAND splits a WHERE clause by AND operators (case-insensitive) // splitByAND splits a WHERE clause by AND operators (case-insensitive).
// This is parenthesis-aware and won't split on AND operators inside subqueries // It is parenthesis-aware (won't split inside subqueries), quote-aware
// (won't split on AND inside single-quoted strings), and BETWEEN-aware
// (won't split on the AND that separates the two operands of BETWEEN x AND y).
func splitByAND(where string) []string { func splitByAND(where string) []string {
conditions := []string{} conditions := []string{}
currentCondition := strings.Builder{} currentCondition := strings.Builder{}
depth := 0 // Track parenthesis depth depth := 0 // parenthesis nesting depth
inSingleQuote := false
afterBetween := false // true after seeing BETWEEN at depth 0; next AND belongs to it
i := 0 i := 0
for i < len(where) { for i < len(where) {
ch := where[i] ch := where[i]
// Track parenthesis depth // Track single-quote state so we never split on AND inside string literals.
if ch == '\'' {
inSingleQuote = !inSingleQuote
currentCondition.WriteByte(ch)
i++
continue
}
if inSingleQuote {
currentCondition.WriteByte(ch)
i++
continue
}
// Track parenthesis depth (outside quotes only).
if ch == '(' { if ch == '(' {
depth++ depth++
currentCondition.WriteByte(ch) currentCondition.WriteByte(ch)
@@ -470,32 +488,39 @@ func splitByAND(where string) []string {
continue continue
} }
// Only look for AND operators at depth 0 (not inside parentheses) // All keyword checks only apply at depth 0 (not inside subqueries).
if depth == 0 { if depth == 0 {
// Check if we're at an AND operator (case-insensitive) // Detect " BETWEEN " (9 chars, case-insensitive) so the very next
// We need at least " AND " (5 chars) or " and " (5 chars) // top-level AND is recognised as part of the BETWEEN syntax.
if i+5 <= len(where) { if i+9 <= len(where) && strings.ToLower(where[i:i+9]) == " between " {
substring := where[i : i+5] afterBetween = true
lowerSubstring := strings.ToLower(substring) currentCondition.WriteString(where[i : i+9])
i += 9
continue
}
if lowerSubstring == " and " { // Detect " AND " (5 chars, case-insensitive).
// Found an AND operator at the top level if i+5 <= len(where) && strings.ToLower(where[i:i+5]) == " and " {
// Add the current condition to the list if afterBetween {
conditions = append(conditions, currentCondition.String()) // This AND closes a BETWEEN expression — do NOT split.
currentCondition.Reset() afterBetween = false
// Skip past the AND operator currentCondition.WriteString(where[i : i+5])
i += 5 i += 5
continue continue
} }
// Regular conjunction — split here.
conditions = append(conditions, currentCondition.String())
currentCondition.Reset()
i += 5
continue
} }
} }
// Not an AND operator or we're inside parentheses, just add the character
currentCondition.WriteByte(ch) currentCondition.WriteByte(ch)
i++ i++
} }
// Add the last condition // Add the last condition.
if currentCondition.Len() > 0 { if currentCondition.Len() > 0 {
conditions = append(conditions, currentCondition.String()) conditions = append(conditions, currentCondition.String())
} }
@@ -614,6 +639,15 @@ func extractTableAndColumn(cond string) (table string, column string) {
// Remove any quotes // Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'") columnRef = strings.Trim(columnRef, "`\"'")
// If the left side is a parenthesized subquery (starts with '(' and contains SQL keywords),
// don't attempt prefix extraction from inside it.
if len(columnRef) > 0 && columnRef[0] == '(' {
lowerRef := strings.ToLower(columnRef)
if strings.Contains(lowerRef, "select ") || strings.Contains(lowerRef, " from ") || strings.Contains(lowerRef, " where ") {
return "", ""
}
}
// Check if there's a function call (contains opening parenthesis) // Check if there's a function call (contains opening parenthesis)
openParenIdx := strings.Index(columnRef, "(") openParenIdx := strings.Index(columnRef, "(")
+51
View File
@@ -520,6 +520,38 @@ func TestSplitByAND(t *testing.T) {
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3", input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"}, expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
}, },
// BETWEEN-aware cases: the AND inside BETWEEN x AND y must not cause a split.
{
name: "BETWEEN does not split on its AND",
input: "col between '2025-08-31' and '1970-01-01'",
expected: []string{"col between '2025-08-31' and '1970-01-01'"},
},
{
name: "BETWEEN uppercase AND",
input: "col BETWEEN '2025-08-31' AND '1970-01-01'",
expected: []string{"col BETWEEN '2025-08-31' AND '1970-01-01'"},
},
{
name: "BETWEEN followed by a regular AND conjunction",
input: "col between 1 and 5 and other = 'x'",
expected: []string{"col between 1 and 5", "other = 'x'"},
},
{
name: "two BETWEEN conditions joined by AND",
input: "col1 between 1 and 5 and col2 between 10 and 20",
expected: []string{"col1 between 1 and 5", "col2 between 10 and 20"},
},
{
name: "complex OR block with multiple BETWEENs (real-world case)",
input: "tbl.applicationdate between '2025-08-31' and '1970-01-01'\n or tbl.capturedate between '2025-08-31' and '1970-01-01'\n or tbl.startdate between '2025-08-31' AND '1970-01-01'",
expected: []string{"tbl.applicationdate between '2025-08-31' and '1970-01-01'\n or tbl.capturedate between '2025-08-31' and '1970-01-01'\n or tbl.startdate between '2025-08-31' AND '1970-01-01'"},
},
// Quote-aware cases: AND inside a string literal must not split.
{
name: "AND inside single-quoted string is not a split point",
input: "comment = 'this and that' and status = 'active'",
expected: []string{"comment = 'this and that'", "status = 'active'"},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -917,6 +949,25 @@ where: "(true AND status = 'active')",
tableName: "unregistered_table", tableName: "unregistered_table",
expected: "(true AND unregistered_table.status = 'active')", expected: "(true AND unregistered_table.status = 'active')",
}, },
// BETWEEN regression: date literals inside BETWEEN must not be prefixed as columns.
{
name: "BETWEEN date range - second date must not be prefixed",
where: "applicationdate between '2025-08-31' and '1970-01-01'",
tableName: "unregistered_table",
expected: "unregistered_table.applicationdate between '2025-08-31' and '1970-01-01'",
},
{
name: "Already-prefixed BETWEEN column - unchanged",
where: `"v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01'`,
tableName: "v_webui_clients",
expected: `"v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01'`,
},
{
name: "Complex OR block with multiple BETWEENs - date values must not be prefixed",
where: `("v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01' or "v_webui_clients".clientcapturedate between '2025-08-31' and '1970-01-01' or "v_webui_clients".startdate between '2025-08-31' AND '1970-01-01')`,
tableName: "v_webui_clients",
expected: `("v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01' or "v_webui_clients".clientcapturedate between '2025-08-31' and '1970-01-01' or "v_webui_clients".startdate between '2025-08-31' AND '1970-01-01')`,
},
} }
for _, tt := range tests { for _, tt := range tests {
+2 -2
View File
@@ -31,7 +31,7 @@ func (v *ColumnValidator) buildValidColumns() {
modelType := reflect.TypeOf(v.model) modelType := reflect.TypeOf(v.model)
// Unwrap pointers, slices, and arrays to get to the base struct type // Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -290,7 +290,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
// Filter Preload columns // Filter Preload columns
validPreloads := make([]PreloadOption, 0, len(options.Preload)) validPreloads := make([]PreloadOption, 0, len(options.Preload))
modelType := reflect.TypeOf(v.model) modelType := reflect.TypeOf(v.model)
if modelType != nil && modelType.Kind() == reflect.Ptr { if modelType != nil && modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
for idx := range options.Preload { for idx := range options.Preload {
+52 -4
View File
@@ -17,6 +17,7 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec" "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/bitechdev/ResolveSpec/pkg/security" "github.com/bitechdev/ResolveSpec/pkg/security"
) )
@@ -367,13 +368,17 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
} }
case "detail": case "detail":
// Detail format: complex API with metadata // Detail format: { count, fields, items, tablename, tableprefix, total }
tableName := r.URL.Path
tablePrefix := reflection.ExtractTableNameOnly(tableName)
fields := buildDetailFieldsFromRows(dbobjlist)
metaobj := map[string]interface{}{ metaobj := map[string]interface{}{
"items": dbobjlist,
"count": fmt.Sprintf("%d", len(dbobjlist)), "count": fmt.Sprintf("%d", len(dbobjlist)),
"fields": fields,
"items": dbobjlist,
"tablename": tableName,
"tableprefix": tablePrefix,
"total": fmt.Sprintf("%d", total), "total": fmt.Sprintf("%d", total),
"tablename": r.URL.Path,
"tableprefix": "gsql",
} }
data, err := json.Marshal(metaobj) data, err := json.Marshal(metaobj)
if err != nil { if err != nil {
@@ -1079,6 +1084,49 @@ func getReplacementForBlankParam(sqlquery, param string) string {
// return result // return result
// } // }
// buildDetailFieldsFromRows builds a field metadata list from the column names and value types
// of a raw SQL result set. Used when no model struct is available (funcspec raw queries).
func buildDetailFieldsFromRows(rows []map[string]interface{}) []reflection.ModelFieldDetail {
if len(rows) == 0 {
return []reflection.ModelFieldDetail{}
}
first := rows[0]
fields := make([]reflection.ModelFieldDetail, 0, len(first))
for colName, val := range first {
dataType := inferGoType(val)
fields = append(fields, reflection.ModelFieldDetail{
Name: colName,
DataType: dataType,
SQLName: colName,
SQLDataType: "",
SQLKey: "",
Nullable: val == nil,
})
}
return fields
}
// inferGoType returns a simple type name for a value, used for detail field metadata.
func inferGoType(val interface{}) string {
if val == nil {
return "interface{}"
}
switch val.(type) {
case bool:
return "bool"
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return "int64"
case float32, float64:
return "float64"
case string:
return "string"
case []byte:
return "[]byte"
default:
return "interface{}"
}
}
// getIPAddress extracts the real IP address from the request // getIPAddress extracts the real IP address from the request
func getIPAddress(r *http.Request) string { func getIPAddress(r *http.Request) string {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
+85
View File
@@ -617,6 +617,91 @@ func TestSqlQueryList(t *testing.T) {
} }
}, },
}, },
{
name: "x-detailapi header returns detail format",
sqlQuery: "SELECT * FROM myschema.myentity",
noCount: false,
blankParams: false,
allowFilter: false,
headers: map[string]string{"x-detailapi": "true"},
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if strings.Contains(query, "COUNT") {
dest.(*struct{ Count int64 }).Count = 3
return nil
}
*dest.(*[]map[string]interface{}) = []map[string]interface{}{
{"id": float64(1), "name": "Alice"},
{"id": float64(2), "name": "Bob"},
{"id": float64(3), "name": "Carol"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var resp map[string]json.RawMessage
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("expected JSON object, got: %s", w.Body.String())
}
for _, key := range []string{"count", "fields", "items", "tablename", "tableprefix", "total"} {
if _, ok := resp[key]; !ok {
t.Errorf("missing key %q in detail response", key)
}
}
var count, total string
json.Unmarshal(resp["count"], &count)
json.Unmarshal(resp["total"], &total)
if count != "3" {
t.Errorf("expected count %q, got %q", "3", count)
}
if total != "3" {
t.Errorf("expected total %q, got %q", "3", total)
}
var items []map[string]interface{}
if err := json.Unmarshal(resp["items"], &items); err != nil {
t.Fatalf("items is not an array: %v", err)
}
if len(items) != 3 {
t.Errorf("expected 3 items, got %d", len(items))
}
var fields []map[string]interface{}
if err := json.Unmarshal(resp["fields"], &fields); err != nil {
t.Fatalf("fields is not an array: %v", err)
}
if len(fields) == 0 {
t.Error("expected non-empty fields list")
}
for _, f := range fields {
for _, key := range []string{"name", "datatype", "sqlname", "sqldatatype", "sqlkey", "nullable"} {
if _, ok := f[key]; !ok {
t.Errorf("field %v missing key %q", f, key)
}
}
}
var tablename, tableprefix string
json.Unmarshal(resp["tablename"], &tablename)
json.Unmarshal(resp["tableprefix"], &tableprefix)
if tablename == "" {
t.Error("expected non-empty tablename")
}
if tableprefix == "" {
t.Error("expected non-empty tableprefix")
}
},
},
{ {
name: "List query with noCount", name: "List query with noCount",
sqlQuery: "SELECT * FROM users", sqlQuery: "SELECT * FROM users",
+2 -2
View File
@@ -107,7 +107,7 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
originalType := modelType originalType := modelType
// Unwrap pointers, slices, and arrays to check the underlying type // Unwrap pointers, slices, and arrays to check the underlying type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array { for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -124,7 +124,7 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
// Additional check: ensure model is not a pointer // Additional check: ensure model is not a pointer
finalType := reflect.TypeOf(model) finalType := reflect.TypeOf(model)
if finalType.Kind() == reflect.Ptr { if finalType.Kind() == reflect.Pointer {
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name()) return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
} }
+6
View File
@@ -781,6 +781,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
return nil, fmt.Errorf("failed to create record: %w", err) return nil, fmt.Errorf("failed to create record: %w", err)
} }
// Re-fetch the created record to capture DB-generated defaults/triggers.
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
hookCtx.ID = fmt.Sprintf("%v", pkVal)
return h.readByID(hookCtx)
}
return hookCtx.ModelPtr, nil return hookCtx.ModelPtr, nil
} }
+4 -4
View File
@@ -387,7 +387,7 @@ func (g *Generator) generateModelSchema(model interface{}) Schema {
} }
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
@@ -418,7 +418,7 @@ func (g *Generator) generateModelSchema(model interface{}) Schema {
schema.Properties[fieldName] = propSchema schema.Properties[fieldName] = propSchema
// Check if field is required (not a pointer and no omitempty) // Check if field is required (not a pointer and no omitempty)
if field.Type.Kind() != reflect.Ptr && !strings.Contains(jsonTag, "omitempty") { if field.Type.Kind() != reflect.Pointer && !strings.Contains(jsonTag, "omitempty") {
schema.Required = append(schema.Required, fieldName) schema.Required = append(schema.Required, fieldName)
} }
} }
@@ -431,7 +431,7 @@ func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
schema := &Schema{} schema := &Schema{}
fieldType := field.Type fieldType := field.Type
if fieldType.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem() fieldType = fieldType.Elem()
} }
@@ -453,7 +453,7 @@ func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
schema.Type = "array" schema.Type = "array"
elemType := fieldType.Elem() elemType := fieldType.Elem()
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Pointer {
elemType = elemType.Elem() elemType = elemType.Elem()
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
+6 -6
View File
@@ -9,7 +9,7 @@ func Len(v any) int {
val := reflect.ValueOf(v) val := reflect.ValueOf(v)
valKind := val.Kind() valKind := val.Kind()
if valKind == reflect.Ptr { if valKind == reflect.Pointer {
val = val.Elem() val = val.Elem()
} }
@@ -57,7 +57,7 @@ func IsEmptyValue(v any) bool {
return true return true
} }
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr { if rv.Kind() == reflect.Pointer {
if rv.IsNil() { if rv.IsNil() {
return true return true
} }
@@ -80,12 +80,12 @@ func IsEmptyValue(v any) bool {
// If the type is a slice of pointers, it returns the element type of the pointer within the slice. // If the type is a slice of pointers, it returns the element type of the pointer within the slice.
// If neither condition is met, it returns the original type. // If neither condition is met, it returns the original type.
func GetPointerElement(v reflect.Type) reflect.Type { func GetPointerElement(v reflect.Type) reflect.Type {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Pointer {
return v.Elem() return v.Elem()
} }
if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr { if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Pointer {
subElem := v.Elem() subElem := v.Elem()
if subElem.Elem().Kind() == reflect.Ptr { if subElem.Elem().Kind() == reflect.Pointer {
return subElem.Elem().Elem() return subElem.Elem().Elem()
} }
return v.Elem() return v.Elem()
@@ -104,7 +104,7 @@ func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
// Unwrap pointer and slice indirections to reach the struct type // Unwrap pointer and slice indirections to reach the struct type
for { for {
switch modelType.Kind() { switch modelType.Kind() {
case reflect.Ptr, reflect.Slice: case reflect.Pointer, reflect.Slice:
modelType = modelType.Elem() modelType = modelType.Elem()
continue continue
} }
+16 -16
View File
@@ -226,7 +226,7 @@ func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly
// Handle embedded structs // Handle embedded structs
if field.Anonymous { if field.Anonymous {
ft := field.Type ft := field.Type
if ft.Kind() == reflect.Ptr { if ft.Kind() == reflect.Pointer {
ft = ft.Elem() ft = ft.Elem()
} }
isScanOnly := scanOnly isScanOnly := scanOnly
@@ -544,7 +544,7 @@ func IsColumnWritable(model any, columnName string) bool {
// Unwrap pointers and slices to get to the base struct type // Unwrap pointers and slices to get to the base struct type
for modelType != nil { for modelType != nil {
switch modelType.Kind() { switch modelType.Kind() {
case reflect.Ptr, reflect.Slice: case reflect.Pointer, reflect.Slice:
modelType = modelType.Elem() modelType = modelType.Elem()
continue continue
} }
@@ -709,7 +709,7 @@ func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
// Dereference pointer if needed // Dereference pointer if needed
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -886,7 +886,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
// Unwrap pointer → slice → pointer chains to reach the underlying struct // Unwrap pointer → slice → pointer chains to reach the underlying struct
for { for {
switch modelType.Kind() { switch modelType.Kind() {
case reflect.Ptr, reflect.Slice: case reflect.Pointer, reflect.Slice:
modelType = modelType.Elem() modelType = modelType.Elem()
continue continue
} }
@@ -947,7 +947,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
// Slice indicates has-many or many-to-many // Slice indicates has-many or many-to-many
return RelationHasMany return RelationHasMany
} }
if fieldType.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Pointer {
// Pointer to single struct usually indicates belongs-to or has-one // Pointer to single struct usually indicates belongs-to or has-one
// Check if it has foreignKey (belongs-to) or references (has-one) // Check if it has foreignKey (belongs-to) or references (has-one)
if strings.Contains(gormTag, "foreignKey:") { if strings.Contains(gormTag, "foreignKey:") {
@@ -963,7 +963,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
// Slice of structs → has-many // Slice of structs → has-many
return RelationHasMany return RelationHasMany
} }
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct { if fieldType.Kind() == reflect.Pointer || fieldType.Kind() == reflect.Struct {
// Single struct → belongs-to (default assumption for safety) // Single struct → belongs-to (default assumption for safety)
// Using belongs-to as default ensures we use JOIN, which is safer // Using belongs-to as default ensures we use JOIN, which is safer
return RelationBelongsTo return RelationBelongsTo
@@ -990,7 +990,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
// Strategy 1 is skipped if the matched field is a declared relation (rel:) or // Strategy 1 is skipped if the matched field is a declared relation (rel:) or
// has a GORM tag but carries no explicit FK — callers should use convention. // has a GORM tag but carries no explicit FK — callers should use convention.
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string { func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice { for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
@@ -1123,7 +1123,7 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
} }
targetValue := reflect.ValueOf(target) targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr { if targetValue.Kind() != reflect.Pointer {
return fmt.Errorf("target must be a pointer to a struct") return fmt.Errorf("target must be a pointer to a struct")
} }
@@ -1226,8 +1226,8 @@ func setFieldValue(field reflect.Value, value interface{}) error {
} }
// Handle pointer fields // Handle pointer fields
if field.Kind() == reflect.Ptr { if field.Kind() == reflect.Pointer {
if valueReflect.Kind() != reflect.Ptr { if valueReflect.Kind() != reflect.Pointer {
// Create a new pointer and set its value // Create a new pointer and set its value
newPtr := reflect.New(field.Type().Elem()) newPtr := reflect.New(field.Type().Elem())
if err := setFieldValue(newPtr.Elem(), value); err != nil { if err := setFieldValue(newPtr.Elem(), value); err != nil {
@@ -1418,14 +1418,14 @@ func convertSlice(targetSlice reflect.Value, sourceSlice reflect.Value) error {
// Handle nil elements // Handle nil elements
if sourceValue == nil { if sourceValue == nil {
// For pointer types, nil is valid // For pointer types, nil is valid
if targetElemType.Kind() == reflect.Ptr { if targetElemType.Kind() == reflect.Pointer {
targetElem.Set(reflect.Zero(targetElemType)) targetElem.Set(reflect.Zero(targetElemType))
} }
continue continue
} }
// If target element type is a pointer to struct, we need to create new instances // If target element type is a pointer to struct, we need to create new instances
if targetElemType.Kind() == reflect.Ptr { if targetElemType.Kind() == reflect.Pointer {
// Create a new instance of the pointed-to type // Create a new instance of the pointed-to type
newElemPtr := reflect.New(targetElemType.Elem()) newElemPtr := reflect.New(targetElemType.Elem())
@@ -1588,7 +1588,7 @@ func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
// Unwrap pointers and slices to get to the base struct type // Unwrap pointers and slices to get to the base struct type
for modelType != nil { for modelType != nil {
switch modelType.Kind() { switch modelType.Kind() {
case reflect.Ptr, reflect.Slice: case reflect.Pointer, reflect.Slice:
modelType = modelType.Elem() modelType = modelType.Elem()
continue continue
} }
@@ -1616,7 +1616,7 @@ func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
// Check for embedded structs // Check for embedded structs
if field.Anonymous { if field.Anonymous {
fieldType := field.Type fieldType := field.Type
if fieldType.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem() fieldType = fieldType.Elem()
} }
if fieldType.Kind() == reflect.Struct { if fieldType.Kind() == reflect.Struct {
@@ -1655,7 +1655,7 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
for { for {
switch modelType.Kind() { switch modelType.Kind() {
case reflect.Ptr, reflect.Slice: case reflect.Pointer, reflect.Slice:
modelType = modelType.Elem() modelType = modelType.Elem()
continue continue
} }
@@ -1724,7 +1724,7 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
for { for {
switch targetType.Kind() { switch targetType.Kind() {
case reflect.Ptr, reflect.Slice: case reflect.Pointer, reflect.Slice:
targetType = targetType.Elem() targetType = targetType.Elem()
if targetType == nil { if targetType == nil {
return nil return nil
+102 -7
View File
@@ -428,14 +428,36 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
// Use potentially modified data // Use potentially modified data
data = hookCtx.Data data = hookCtx.Data
pkName := reflection.GetPrimaryKeyName(model)
switch v := data.(type) { switch v := data.(type) {
case map[string]interface{}: case map[string]interface{}:
query := h.db.NewInsert().Table(tableName) query := h.db.NewInsert().Table(tableName)
for key, value := range v { for key, value := range v {
query = query.Value(key, value) query = query.Value(key, value)
} }
if _, err := query.Exec(ctx); err != nil { if pkName != "" {
return nil, fmt.Errorf("create error: %w", err) var insertedID interface{}
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
return nil, fmt.Errorf("create error: %w", err)
}
// Re-fetch after insert to capture DB-generated defaults/triggers.
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
fetchedRecord := reflect.New(modelType).Interface()
if err := h.db.NewSelect().Model(fetchedRecord).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
ScanModel(ctx); err == nil {
v = mergeWithInput(fetchedRecord, v)
} else {
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, err)
}
} else {
if _, err := query.Exec(ctx); err != nil {
return nil, fmt.Errorf("create error: %w", err)
}
} }
hookCtx.Result = v hookCtx.Result = v
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
@@ -444,7 +466,12 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
return v, nil return v, nil
case []interface{}: case []interface{}:
results := make([]interface{}, 0, len(v)) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
originals := make([]map[string]interface{}, 0, len(v))
insertedIDs := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v { for _, item := range v {
itemMap, ok := item.(map[string]interface{}) itemMap, ok := item.(map[string]interface{})
@@ -455,16 +482,43 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
for key, value := range itemMap { for key, value := range itemMap {
q = q.Value(key, value) q = q.Value(key, value)
} }
if _, err := q.Exec(ctx); err != nil { if pkName == "" {
if _, err := q.Exec(ctx); err != nil {
return err
}
originals = append(originals, itemMap)
insertedIDs = append(insertedIDs, nil)
continue
}
var returnedID interface{}
if err := q.Returning(pkName).Scan(ctx, &returnedID); err != nil {
return err return err
} }
results = append(results, item) originals = append(originals, itemMap)
insertedIDs = append(insertedIDs, returnedID)
} }
return nil return nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("batch create error: %w", err) return nil, fmt.Errorf("batch create error: %w", err)
} }
// Re-fetch each record after transaction commits; fall back to input on failure.
results := make([]interface{}, 0, len(insertedIDs))
for i, pkVal := range insertedIDs {
if pkVal == nil {
results = append(results, originals[i])
continue
}
fetchedRecord := reflect.New(modelType).Interface()
if err := h.db.NewSelect().Model(fetchedRecord).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
ScanModel(ctx); err == nil {
results = append(results, mergeWithInput(fetchedRecord, originals[i]))
} else {
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, err)
results = append(results, originals[i])
}
}
hookCtx.Result = results hookCtx.Result = results
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err) return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
@@ -513,7 +567,7 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
err = h.db.RunInTransaction(ctx, func(tx common.Database) error { err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Read existing record // Read existing record
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
existingRecord := reflect.New(modelType).Interface() existingRecord := reflect.New(modelType).Interface()
@@ -584,6 +638,25 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Re-fetch the record after transaction commits to capture DB-generated changes.
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
fetchedRecord := reflect.New(modelType).Interface()
if err := h.db.NewSelect().Model(fetchedRecord).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
ScanModel(ctx); err == nil {
jsonData, marshalErr := json.Marshal(fetchedRecord)
if marshalErr == nil {
var fetchedMap map[string]interface{}
if json.Unmarshal(jsonData, &fetchedMap) == nil {
updateResult = fetchedMap
}
}
}
return updateResult, nil return updateResult, nil
} }
@@ -628,7 +701,7 @@ func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string)
} }
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -749,6 +822,28 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition st
return "", nil return "", nil
} }
// mergeWithInput merges a database record with the original request data.
// DB values take precedence (capturing triggers/defaults), while extra
// input keys that have no DB column are preserved in the response.
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{}, len(input))
for k, v := range input {
result[k] = v
}
jsonData, err := json.Marshal(dbRecord)
if err != nil {
return result
}
var dbMap map[string]interface{}
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
return result
}
for k, v := range dbMap {
result[k] = v
}
return result
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for i := range preloads { for i := range preloads {
preload := &preloads[i] preload := &preloads[i]
+3 -3
View File
@@ -67,7 +67,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
// Unwrap to base struct type // Unwrap to base struct type
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
if modelType == nil || modelType.Kind() != reflect.Struct { if modelType == nil || modelType.Kind() != reflect.Struct {
@@ -87,7 +87,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
fieldType, found := modelType.FieldByName(d.Name) fieldType, found := modelType.FieldByName(d.Name)
if found { if found {
ft := fieldType.Type ft := fieldType.Type
if ft.Kind() == reflect.Ptr { if ft.Kind() == reflect.Pointer {
ft = ft.Elem() ft = ft.Elem()
} }
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != "" isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
@@ -106,7 +106,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
goType := d.DataType goType := d.DataType
if goType == "" && found { if goType == "" && found {
ft := fieldType.Type ft := fieldType.Type
for ft.Kind() == reflect.Ptr { for ft.Kind() == reflect.Pointer {
ft = ft.Elem() ft = ft.Elem()
} }
goType = ft.Name() goType = ft.Name()
+194 -32
View File
@@ -243,7 +243,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Validate and unwrap model type to get base struct // Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -602,23 +602,44 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
} }
// Standard processing without nested relations // Standard processing without nested relations
pkName := reflection.GetPrimaryKeyName(model)
query := h.db.NewInsert().Table(tableName) query := h.db.NewInsert().Table(tableName)
for key, value := range v { for key, value := range v {
query = query.Value(key, common.ConvertSliceForBun(value)) query = query.Value(key, common.ConvertSliceForBun(value))
} }
result, err := query.Exec(ctx) var responseData interface{} = v
if err != nil { if pkName == "" {
logger.Error("Error creating record: %v", err) // No PK on model — insert and return input as-is.
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err) result, err := query.Exec(ctx)
return if err != nil {
logger.Error("Error creating record: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
return
}
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
} else {
var insertedID interface{}
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
logger.Error("Error creating record: %v", err)
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
return
}
logger.Info("Successfully created record with %s: %v", pkName, insertedID)
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
ScanModel(ctx); fetchErr == nil {
responseData = mergeWithInput(fetchedRecord, v)
} else {
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, fetchErr)
}
} }
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
// Invalidate cache for this table // Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName) cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, v, nil) h.sendResponse(w, responseData, nil)
case []map[string]interface{}: case []map[string]interface{}:
// Check if any item needs nested processing // Check if any item needs nested processing
@@ -666,15 +687,30 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
} }
// Standard batch insert without nested relations // Standard batch insert without nested relations
pkName := reflection.GetPrimaryKeyName(model)
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
originals := make([]map[string]interface{}, 0, len(v))
insertedIDs := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v { for _, item := range v {
txQuery := tx.NewInsert().Table(tableName) txQuery := tx.NewInsert().Table(tableName)
for key, value := range item { for key, value := range item {
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value)) txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
} }
if _, err := txQuery.Exec(ctx); err != nil { if pkName == "" {
if _, err := txQuery.Exec(ctx); err != nil {
return err
}
originals = append(originals, item)
insertedIDs = append(insertedIDs, nil)
continue
}
var returnedID interface{}
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
return err return err
} }
originals = append(originals, item)
insertedIDs = append(insertedIDs, returnedID)
} }
return nil return nil
}) })
@@ -689,7 +725,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, v, nil) // Re-fetch each record after transaction commits; fall back to input on failure.
responseItems := make([]interface{}, 0, len(insertedIDs))
for i, pkVal := range insertedIDs {
if pkVal == nil {
responseItems = append(responseItems, originals[i])
continue
}
fetchedRecord := reflect.New(modelElemType).Interface()
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
ScanModel(ctx); fetchErr == nil {
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
} else {
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
responseItems = append(responseItems, originals[i])
}
}
h.sendResponse(w, responseItems, nil)
case []interface{}: case []interface{}:
// Handle []interface{} type from JSON unmarshaling // Handle []interface{} type from JSON unmarshaling
@@ -742,19 +795,34 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
} }
// Standard batch insert without nested relations // Standard batch insert without nested relations
list := make([]interface{}, 0) pkName := reflection.GetPrimaryKeyName(model)
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
originals := make([]map[string]interface{}, 0, len(v))
insertedIDs := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v { for _, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok { itemMap, ok := item.(map[string]interface{})
txQuery := tx.NewInsert().Table(tableName) if !ok {
for key, value := range itemMap { continue
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value)) }
} txQuery := tx.NewInsert().Table(tableName)
for key, value := range itemMap {
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
}
if pkName == "" {
if _, err := txQuery.Exec(ctx); err != nil { if _, err := txQuery.Exec(ctx); err != nil {
return err return err
} }
list = append(list, item) originals = append(originals, itemMap)
insertedIDs = append(insertedIDs, nil)
continue
} }
var returnedID interface{}
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
return err
}
originals = append(originals, itemMap)
insertedIDs = append(insertedIDs, returnedID)
} }
return nil return nil
}) })
@@ -769,7 +837,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, list, nil) // Re-fetch each record after transaction commits; fall back to input on failure.
responseItems := make([]interface{}, 0, len(insertedIDs))
for i, pkVal := range insertedIDs {
if pkVal == nil {
responseItems = append(responseItems, originals[i])
continue
}
fetchedRecord := reflect.New(modelElemType).Interface()
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
ScanModel(ctx); fetchErr == nil {
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
} else {
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
responseItems = append(responseItems, originals[i])
}
}
h.sendResponse(w, responseItems, nil)
default: default:
logger.Error("Invalid data type for create operation: %T", data) logger.Error("Invalid data type for create operation: %T", data)
@@ -836,7 +921,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// First, read the existing record from the database // First, read the existing record from the database
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*") selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...)
// Apply conditions to select // Apply conditions to select
if urlID != "" { if urlID != "" {
@@ -955,13 +1040,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return return
} }
// Fetch the updated record after the transaction commits to capture any trigger changes
updatedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
fetchQuery := h.db.NewSelect().Model(updatedRecord).Column(reflection.GetSQLModelColumns(model)...)
if urlID != "" {
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
} else if reqID != nil {
switch id := reqID.(type) {
case string:
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
case []string:
if len(id) > 0 {
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
}
}
}
if err := fetchQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record: %v", err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
logger.Info("Successfully updated record(s)") logger.Info("Successfully updated record(s)")
// Invalidate cache for this table // Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName) cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, data, nil) h.sendResponse(w, updatedRecord, nil)
case []map[string]interface{}: case []map[string]interface{}:
// Batch update with array of objects // Batch update with array of objects
@@ -1017,7 +1123,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
// First, read the existing record // First, read the existing record
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := selectQuery.ScanModel(ctx); err != nil { if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue // Skip if record not found continue // Skip if record not found
@@ -1089,13 +1195,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err) h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
return return
} }
logger.Info("Successfully updated %d records", len(updates))
// Fetch updated records after the transaction commits to capture any trigger changes
fetchedUpdates := make([]interface{}, 0, len(updates))
for _, item := range updates {
if itemID, ok := item["id"]; ok && itemID != nil {
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := fetchQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
fetchedUpdates = append(fetchedUpdates, fetchedRecord)
}
}
logger.Info("Successfully updated %d records", len(fetchedUpdates))
// Invalidate cache for this table // Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName) cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, updates, nil) h.sendResponse(w, fetchedUpdates, nil)
case []interface{}: case []interface{}:
// Batch update with []interface{} // Batch update with []interface{}
@@ -1157,7 +1279,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
// First, read the existing record // First, read the existing record
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := selectQuery.ScanModel(ctx); err != nil { if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue // Skip if record not found continue // Skip if record not found
@@ -1232,13 +1354,31 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err) h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
return return
} }
logger.Info("Successfully updated %d records", len(list))
// Fetch updated records after the transaction commits to capture any trigger changes
fetchedList := make([]interface{}, 0, len(list))
for _, item := range list {
if itemMap, ok := item.(map[string]interface{}); ok {
if itemID, ok := itemMap["id"]; ok && itemID != nil {
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := fetchQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
fetchedList = append(fetchedList, fetchedRecord)
}
}
}
logger.Info("Successfully updated %d records", len(fetchedList))
// Invalidate cache for this table // Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName) cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, list, nil) h.sendResponse(w, fetchedList, nil)
default: default:
logger.Error("Invalid data type for update operation: %T", data) logger.Error("Invalid data type for update operation: %T", data)
@@ -1407,7 +1547,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
// First, fetch the record that will be deleted // First, fetch the record that will be deleted
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
recordToDelete := reflect.New(modelType).Interface() recordToDelete := reflect.New(modelType).Interface()
@@ -1682,7 +1822,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type // Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -1826,7 +1966,7 @@ func getColumnType(field reflect.StructField) string {
func isNullable(field reflect.StructField) bool { func isNullable(field reflect.StructField) bool {
// Check if it's a pointer type // Check if it's a pointer type
if field.Type.Kind() == reflect.Ptr { if field.Type.Kind() == reflect.Pointer {
return true return true
} }
@@ -1852,7 +1992,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type // Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -2000,7 +2140,7 @@ func toSnakeCase(s string) string {
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) { func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
// Get the reflect value of the records // Get the reflect value of the records
recordsValue := reflect.ValueOf(records) recordsValue := reflect.ValueOf(records)
if recordsValue.Kind() == reflect.Ptr { if recordsValue.Kind() == reflect.Pointer {
recordsValue = recordsValue.Elem() recordsValue = recordsValue.Elem()
} }
@@ -2015,7 +2155,7 @@ func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
record := recordsValue.Index(i) record := recordsValue.Index(i)
// Dereference if it's a pointer // Dereference if it's a pointer
if record.Kind() == reflect.Ptr { if record.Kind() == reflect.Pointer {
if record.IsNil() { if record.IsNil() {
continue continue
} }
@@ -2067,3 +2207,25 @@ func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) { func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
h.openAPIGenerator = generator h.openAPIGenerator = generator
} }
// mergeWithInput merges a database record with the original request data.
// DB values take precedence (capturing triggers/defaults), while extra
// input keys that have no DB column are preserved in the response.
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{}, len(input))
for k, v := range input {
result[k] = v
}
jsonData, err := json.Marshal(dbRecord)
if err != nil {
return result
}
var dbMap map[string]interface{}
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
return result
}
for k, v := range dbMap {
result[k] = v
}
return result
}
+209
View File
@@ -0,0 +1,209 @@
package restheadspec
import (
"encoding/json"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// detailTestModel is a simple model with gorm column/type tags for detail format tests.
type detailTestModel struct {
ID int64 `bun:"rid,pk" gorm:"column:rid;primaryKey" json:"rid"`
Name string `bun:"name" gorm:"column:name;type:citext" json:"name"`
Description *string `bun:"description" gorm:"column:description;type:text;nullable" json:"description"`
Score float64 `bun:"score" gorm:"column:score;type:numeric" json:"score"`
Active bool `bun:"active" gorm:"column:active;type:boolean;not null" json:"active"`
}
func TestSendFormattedResponse_DetailFormat(t *testing.T) {
handler := &Handler{}
name := "hello"
items := []*detailTestModel{
{ID: 1, Name: "first", Description: &name, Score: 1.5, Active: true},
{ID: 2, Name: "second", Description: nil, Score: 2.0, Active: false},
}
metadata := &common.Metadata{
Total: 36,
Count: 2,
Filtered: 36,
Limit: 10,
Offset: 0,
}
options := ExtendedRequestOptions{
ResponseFormat: "detail",
}
mockWriter := &MockTestResponseWriter{headers: make(map[string]string)}
handler.sendFormattedResponse(mockWriter, items, metadata, "myschema.myentity", detailTestModel{}, options)
if mockWriter.statusCode != 200 {
t.Fatalf("expected status 200, got %d", mockWriter.statusCode)
}
body, err := json.Marshal(mockWriter.body)
if err != nil {
t.Fatalf("failed to marshal body: %v", err)
}
var resp map[string]json.RawMessage
if err := json.Unmarshal(body, &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
t.Run("top-level keys", func(t *testing.T) {
for _, key := range []string{"count", "fields", "items", "tablename", "tableprefix", "total"} {
if _, ok := resp[key]; !ok {
t.Errorf("missing key %q in detail response", key)
}
}
})
t.Run("count and total are string", func(t *testing.T) {
var count, total string
if err := json.Unmarshal(resp["count"], &count); err != nil {
t.Errorf("count is not a string: %v", err)
}
if err := json.Unmarshal(resp["total"], &total); err != nil {
t.Errorf("total is not a string: %v", err)
}
if count != "2" {
t.Errorf("expected count %q, got %q", "2", count)
}
if total != "36" {
t.Errorf("expected total %q, got %q", "36", total)
}
})
t.Run("tablename and tableprefix", func(t *testing.T) {
var tablename, tableprefix string
json.Unmarshal(resp["tablename"], &tablename)
json.Unmarshal(resp["tableprefix"], &tableprefix)
if tablename != "myschema.myentity" {
t.Errorf("expected tablename %q, got %q", "myschema.myentity", tablename)
}
if tableprefix != "myentity" {
t.Errorf("expected tableprefix %q, got %q", "myentity", tableprefix)
}
})
t.Run("items contains data", func(t *testing.T) {
var itemSlice []map[string]interface{}
if err := json.Unmarshal(resp["items"], &itemSlice); err != nil {
t.Fatalf("items is not an array: %v", err)
}
if len(itemSlice) != 2 {
t.Errorf("expected 2 items, got %d", len(itemSlice))
}
})
t.Run("fields contains column metadata", func(t *testing.T) {
var fields []map[string]interface{}
if err := json.Unmarshal(resp["fields"], &fields); err != nil {
t.Fatalf("fields is not an array: %v", err)
}
if len(fields) == 0 {
t.Fatal("expected fields to be non-empty")
}
bySQL := make(map[string]map[string]interface{}, len(fields))
for _, f := range fields {
if sqlname, ok := f["sqlname"].(string); ok {
bySQL[sqlname] = f
}
}
// Check required field keys are present
for _, f := range fields {
for _, key := range []string{"name", "datatype", "sqlname", "sqldatatype", "sqlkey", "nullable"} {
if _, ok := f[key]; !ok {
t.Errorf("field %v missing key %q", f, key)
}
}
}
// Validate specific columns
if col, ok := bySQL["rid"]; ok {
if col["sqlkey"] != "primary_key" {
t.Errorf("rid: expected sqlkey %q, got %v", "primary_key", col["sqlkey"])
}
} else {
t.Error("expected column 'rid' in fields")
}
if col, ok := bySQL["name"]; ok {
if col["sqldatatype"] != "citext" {
t.Errorf("name: expected sqldatatype %q, got %v", "citext", col["sqldatatype"])
}
if col["nullable"] != false {
t.Errorf("name: expected nullable false, got %v", col["nullable"])
}
} else {
t.Error("expected column 'name' in fields")
}
if col, ok := bySQL["description"]; ok {
if col["sqldatatype"] != "text" {
t.Errorf("description: expected sqldatatype %q, got %v", "text", col["sqldatatype"])
}
if col["nullable"] != true {
t.Errorf("description: expected nullable true, got %v", col["nullable"])
}
} else {
t.Error("expected column 'description' in fields")
}
})
}
func TestSendFormattedResponse_DetailFormat_EmptyItems(t *testing.T) {
handler := &Handler{}
metadata := &common.Metadata{Total: 0, Count: 0, Filtered: 0}
options := ExtendedRequestOptions{ResponseFormat: "detail"}
mockWriter := &MockTestResponseWriter{headers: make(map[string]string)}
handler.sendFormattedResponse(mockWriter, []*detailTestModel{}, metadata, "s.t", detailTestModel{}, options)
body, _ := json.Marshal(mockWriter.body)
var resp map[string]json.RawMessage
json.Unmarshal(body, &resp)
var count, total string
json.Unmarshal(resp["count"], &count)
json.Unmarshal(resp["total"], &total)
if count != "0" || total != "0" {
t.Errorf("expected count/total both %q, got count=%q total=%q", "0", count, total)
}
var fields []interface{}
json.Unmarshal(resp["fields"], &fields)
if len(fields) == 0 {
t.Error("fields should still list column metadata even when items is empty")
}
}
func TestBuildDetailFields_SkipsRelations(t *testing.T) {
type child struct {
ID int64 `bun:"id,pk" gorm:"column:id;primaryKey" json:"id"`
}
type parent struct {
ID int64 `bun:"id,pk" gorm:"column:id;primaryKey" json:"id"`
Name string `bun:"name" gorm:"column:name" json:"name"`
Children []child `bun:"rel:has-many" json:"children"`
Child *child `bun:"rel:has-one" json:"child"`
}
handler := &Handler{}
fields := handler.buildDetailFields(parent{})
for _, f := range fields {
if f.SQLName == "children" || f.SQLName == "child" {
t.Errorf("relation field %q should not appear in detail fields", f.SQLName)
}
}
if len(fields) != 2 {
t.Errorf("expected 2 scalar fields (id, name), got %d", len(fields))
}
}
+1 -1
View File
@@ -95,7 +95,7 @@ func TestSendFormattedResponse_NoDataFoundHeader(t *testing.T) {
// Test with empty data // Test with empty data
emptyData := []interface{}{} emptyData := []interface{}{}
handler.sendFormattedResponse(mockWriter, emptyData, metadata, options) handler.sendFormattedResponse(mockWriter, emptyData, metadata, "", nil, options)
// Check if X-No-Data-Found header was set // Check if X-No-Data-Found header was set
if mockWriter.headers["X-No-Data-Found"] != "true" { if mockWriter.headers["X-No-Data-Found"] != "true" {
+185 -38
View File
@@ -289,7 +289,8 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
Limit: 0, Limit: 0,
Offset: 0, Offset: 0,
} }
h.sendFormattedResponse(w, tableMetadata, responseMetadata, options) tableName := h.getTableName(schema, entity, model)
h.sendFormattedResponse(w, tableMetadata, responseMetadata, tableName, model, options)
} }
// handleMeta processes meta operation requests // handleMeta processes meta operation requests
@@ -348,7 +349,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Validate and unwrap model type to get base struct // Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -848,7 +849,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
return return
} }
h.sendFormattedResponse(w, modelPtr, metadata, options) h.sendFormattedResponse(w, modelPtr, metadata, tableName, model, options)
} }
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading // applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
@@ -1218,8 +1219,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" { if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName) query = query.Table(tableName)
} }
fields := reflection.GetSQLModelColumns(model)
query = query.Returning("*") query = query.Returning(fields...)
// Execute BeforeScan hooks - pass query chain so hooks can modify it // Execute BeforeScan hooks - pass query chain so hooks can modify it
itemHookCtx := &HookContext{ itemHookCtx := &HookContext{
@@ -1480,18 +1481,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
} }
} }
// Fetch the updated record to return the new values _ = result
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil {
return fmt.Errorf("failed to fetch updated record: %w", err)
}
updatedRecord = modelValue
// Store result for hooks
hookCtx.Result = updatedRecord
_ = result // Keep result variable for potential future use
return nil return nil
}) })
@@ -1501,6 +1491,16 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
return return
} }
// Fetch the updated record after the transaction commits to capture any trigger changes
fetchedRecord := reflect.New(reflect.TypeOf(model)).Interface()
selectQuery := h.db.NewSelect().Model(fetchedRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record: %v", err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
updatedRecord = fetchedRecord
// Merge the updated record with the original request data // Merge the updated record with the original request data
// This preserves extra keys from the request and updates values from the database // This preserves extra keys from the request and updates values from the database
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap) mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
@@ -1892,7 +1892,7 @@ func (h *Handler) extractNestedRelations(
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) { ) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
// Get model type for reflection // Get model type for reflection
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -1934,7 +1934,7 @@ func (h *Handler) processChildRelationsWithParentID(
) error { ) error {
// Get model type for reflection // Get model type for reflection
modelType := reflect.TypeOf(parentModel) modelType := reflect.TypeOf(parentModel)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -1990,7 +1990,7 @@ func (h *Handler) processChildRelationsForField(
if relatedModelType.Kind() == reflect.Slice { if relatedModelType.Kind() == reflect.Slice {
relatedModelType = relatedModelType.Elem() relatedModelType = relatedModelType.Elem()
} }
if relatedModelType.Kind() == reflect.Ptr { if relatedModelType.Kind() == reflect.Pointer {
relatedModelType = relatedModelType.Elem() relatedModelType = relatedModelType.Elem()
} }
@@ -2012,11 +2012,15 @@ func (h *Handler) processChildRelationsForField(
// Priority: Use foreign key field name if specified, otherwise use parent's PK name // Priority: Use foreign key field name if specified, otherwise use parent's PK name
var foreignKeyFieldName string var foreignKeyFieldName string
if relInfo.ForeignKey != "" { if relInfo.ForeignKey != "" {
// Get the JSON name for the foreign key field in the child model // For has-many/has-one: join:parentCol=childCol
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) // ForeignKey = parent side, References = child side (where we actually set the value)
childField := relInfo.ForeignKey
if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
childField = relInfo.References
}
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
if foreignKeyFieldName == "" { if foreignKeyFieldName == "" {
// Fallback to lowercase field name foreignKeyFieldName = strings.ToLower(childField)
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
} }
} else { } else {
// Fallback: use parent's primary key name // Fallback: use parent's primary key name
@@ -2040,7 +2044,10 @@ func (h *Handler) processChildRelationsForField(
// Process based on relation type and data structure // Process based on relation type and data structure
switch v := relationValue.(type) { switch v := relationValue.(type) {
case map[string]interface{}: case map[string]interface{}:
// Single related object - add parent ID to foreign key field if !isValidNestedRequest(v) {
logger.Debug("Skipping single relation %s - missing or invalid _request value", relationName)
return nil
}
// IMPORTANT: In recursive relationships, don't overwrite the primary key // IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
v[foreignKeyFieldName] = parentID v[foreignKeyFieldName] = parentID
@@ -2057,7 +2064,10 @@ func (h *Handler) processChildRelationsForField(
// Multiple related objects // Multiple related objects
for i, item := range v { for i, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok { if itemMap, ok := item.(map[string]interface{}); ok {
// Add parent ID to foreign key field if !isValidNestedRequest(itemMap) {
logger.Debug("Skipping relation array[%d] %s - missing or invalid _request value", i, relationName)
continue
}
// IMPORTANT: In recursive relationships, don't overwrite the primary key // IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID itemMap[foreignKeyFieldName] = parentID
@@ -2075,7 +2085,10 @@ func (h *Handler) processChildRelationsForField(
case []map[string]interface{}: case []map[string]interface{}:
// Multiple related objects (typed slice) // Multiple related objects (typed slice)
for i, itemMap := range v { for i, itemMap := range v {
// Add parent ID to foreign key field if !isValidNestedRequest(itemMap) {
logger.Debug("Skipping relation typed array[%d] %s - missing or invalid _request value", i, relationName)
continue
}
// IMPORTANT: In recursive relationships, don't overwrite the primary key // IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID itemMap[foreignKeyFieldName] = parentID
@@ -2096,6 +2109,24 @@ func (h *Handler) processChildRelationsForField(
return nil return nil
} }
// isValidNestedRequest returns true only when the item carries a _request key
// whose value is one of the recognised mutation verbs.
func isValidNestedRequest(item map[string]interface{}) bool {
raw, ok := item["_request"]
if !ok {
return false
}
s, ok := raw.(string)
if !ok {
return false
}
switch strings.ToLower(strings.TrimSpace(s)) {
case "insert", "add", "change", "update", "delete", "remove":
return true
}
return false
}
// getTableNameForRelatedModel gets the table name for a related model. // getTableNameForRelatedModel gets the table name for a related model.
// If the model's TableName() is schema-qualified (e.g. "public.users") the // If the model's TableName() is schema-qualified (e.g. "public.users") the
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise. // separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
@@ -2382,7 +2413,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type // Unwrap pointers, slices, and arrays to get to the base struct type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array { for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -2431,7 +2462,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
// Check if this is a relation field (slice or struct, but not time.Time) // Check if this is a relation field (slice or struct, but not time.Time)
if field.Type.Kind() == reflect.Slice || if field.Type.Kind() == reflect.Slice ||
(field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") || (field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") ||
(field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct && field.Type.Elem().Name() != "Time") { (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct && field.Type.Elem().Name() != "Time") {
metadata.Relations = append(metadata.Relations, jsonName) metadata.Relations = append(metadata.Relations, jsonName)
continue continue
} }
@@ -2477,7 +2508,7 @@ func (h *Handler) getColumnType(t reflect.Type) string {
return "float" return "float"
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
case reflect.Ptr: case reflect.Pointer:
return h.getColumnType(t.Elem()) return h.getColumnType(t.Elem())
default: default:
return "unknown" return "unknown"
@@ -2485,7 +2516,7 @@ func (h *Handler) getColumnType(t reflect.Type) string {
} }
func (h *Handler) isNullable(field reflect.StructField) bool { func (h *Handler) isNullable(field reflect.StructField) bool {
return field.Type.Kind() == reflect.Ptr return field.Type.Kind() == reflect.Pointer
} }
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) { func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
@@ -2530,7 +2561,7 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
// Use reflection to check if data is a slice or array // Use reflection to check if data is a slice or array
dataValue := reflect.ValueOf(data) dataValue := reflect.ValueOf(data)
if dataValue.Kind() == reflect.Ptr { if dataValue.Kind() == reflect.Pointer {
dataValue = dataValue.Elem() dataValue = dataValue.Elem()
} }
@@ -2555,8 +2586,103 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
return data return data
} }
// sendFormattedResponse sends response with formatting options // buildDetailFields returns the field metadata list for the detail API format,
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) { // containing only non-relation scalar columns derived from the model's struct tags.
func (h *Handler) buildDetailFields(model interface{}) []reflection.ModelFieldDetail {
if model == nil {
return []reflection.ModelFieldDetail{}
}
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return []reflection.ModelFieldDetail{}
}
fields := make([]reflection.ModelFieldDetail, 0, modelType.NumField())
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if !field.IsExported() {
continue
}
jsonTag := field.Tag.Get("json")
if jsonTag == "-" {
continue
}
// Skip relation fields (slices, structs that aren't time.Time, ptrs to struct)
ft := field.Type
if ft.Kind() == reflect.Pointer {
ft = ft.Elem()
}
if ft.Kind() == reflect.Slice ||
(ft.Kind() == reflect.Struct && ft.Name() != "Time") {
continue
}
jsonName := strings.Split(jsonTag, ",")[0]
if jsonName == "" {
jsonName = field.Name
}
gormTag := field.Tag.Get("gorm")
sqlName := fnFindTagVal(gormTag, "column:")
if sqlName == "" {
sqlName = jsonName
}
sqlDataType := fnFindTagVal(gormTag, "type:")
var sqlKey string
gormLower := strings.ToLower(gormTag)
switch {
case strings.Contains(gormLower, "identity") || strings.Contains(gormLower, "primary_key") || strings.Contains(gormLower, "primarykey"):
sqlKey = "primary_key"
case strings.Contains(gormLower, "uniqueindex"):
sqlKey = "uniqueindex"
case strings.Contains(gormLower, "unique"):
sqlKey = "unique"
}
nullable := field.Type.Kind() == reflect.Pointer
if strings.Contains(gormLower, "not null") {
nullable = false
} else if strings.Contains(gormLower, "nullable") || strings.Contains(gormLower, ",null") {
nullable = true
}
fields = append(fields, reflection.ModelFieldDetail{
Name: jsonName,
DataType: h.getColumnType(field.Type),
SQLName: sqlName,
SQLDataType: sqlDataType,
SQLKey: sqlKey,
Nullable: nullable,
})
}
return fields
}
// fnFindTagVal extracts a value from a semicolon-separated struct tag string.
func fnFindTagVal(tag, key string) string {
lower := strings.ToLower(tag)
idx := strings.Index(lower, strings.ToLower(key))
if idx < 0 {
return ""
}
val := tag[idx+len(key):]
if end := strings.Index(val, ";"); end >= 0 {
val = val[:end]
}
return val
}
// sendFormattedResponse sends response with formatting options.
// model is used when ResponseFormat is "detail" to generate the fields metadata list.
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, tableName string, model interface{}, options ExtendedRequestOptions) {
// Handle nil data - convert to empty array // Handle nil data - convert to empty array
if data == nil { if data == nil {
data = []interface{}{} data = []interface{}{}
@@ -2609,8 +2735,29 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
if err := w.WriteJSON(response); err != nil { if err := w.WriteJSON(response); err != nil {
logger.Error("Failed to write JSON response: %v", err) logger.Error("Failed to write JSON response: %v", err)
} }
case "detail":
// Detail format: { count, fields, items, tablename, tableprefix, total }
var count, total int64
if metadata != nil {
count = metadata.Count
total = metadata.Total
}
tablePrefix := reflection.ExtractTableNameOnly(tableName)
fieldList := h.buildDetailFields(model)
response := map[string]interface{}{
"count": strconv.FormatInt(count, 10),
"fields": fieldList,
"items": data,
"tablename": tableName,
"tableprefix": tablePrefix,
"total": strconv.FormatInt(total, 10),
}
w.WriteHeader(http.StatusOK)
if err := w.WriteJSON(response); err != nil {
logger.Error("Failed to write JSON response: %v", err)
}
default: default:
// Default/detail format: standard response with metadata // Default format: standard response with metadata
response := common.Response{ response := common.Response{
Success: true, Success: true,
Data: data, Data: data,
@@ -2868,7 +3015,7 @@ func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string)
func (h *Handler) setRowNumbersOnRecords(records any, offset int) { func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
// Get the reflect value of the records // Get the reflect value of the records
recordsValue := reflect.ValueOf(records) recordsValue := reflect.ValueOf(records)
if recordsValue.Kind() == reflect.Ptr { if recordsValue.Kind() == reflect.Pointer {
recordsValue = recordsValue.Elem() recordsValue = recordsValue.Elem()
} }
@@ -2883,7 +3030,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
record := recordsValue.Index(i) record := recordsValue.Index(i)
// Dereference if it's a pointer // Dereference if it's a pointer
if record.Kind() == reflect.Ptr { if record.Kind() == reflect.Pointer {
if record.IsNil() { if record.IsNil() {
continue continue
} }
@@ -2938,7 +3085,7 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
// Filter Expand columns using the expand relation's model // Filter Expand columns using the expand relation's model
filteredExpands := make([]ExpandOption, 0, len(options.Expand)) filteredExpands := make([]ExpandOption, 0, len(options.Expand))
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
+39
View File
@@ -352,6 +352,45 @@ func (m *mockRegistry) GetAllModels() map[string]interface{} {
return m.models return m.models
} }
// TestIsValidNestedRequest verifies that only the allowed _request verbs are accepted
// and that items missing the key are rejected.
func TestIsValidNestedRequest(t *testing.T) {
tests := []struct {
name string
item map[string]interface{}
expected bool
}{
// Valid verbs
{name: "insert", item: map[string]interface{}{"_request": "insert"}, expected: true},
{name: "add", item: map[string]interface{}{"_request": "add"}, expected: true},
{name: "update", item: map[string]interface{}{"_request": "update"}, expected: true},
{name: "change", item: map[string]interface{}{"_request": "change"}, expected: true},
{name: "delete", item: map[string]interface{}{"_request": "delete"}, expected: true},
{name: "remove", item: map[string]interface{}{"_request": "remove"}, expected: true},
// Case-insensitive
{name: "INSERT uppercase", item: map[string]interface{}{"_request": "INSERT"}, expected: true},
{name: "Remove mixed case", item: map[string]interface{}{"_request": "Remove"}, expected: true},
// Whitespace trimmed
{name: "insert with spaces", item: map[string]interface{}{"_request": " insert "}, expected: true},
// Invalid / missing
{name: "missing _request", item: map[string]interface{}{"name": "foo"}, expected: false},
{name: "empty string", item: map[string]interface{}{"_request": ""}, expected: false},
{name: "unknown verb", item: map[string]interface{}{"_request": "create"}, expected: false},
{name: "unknown verb modify", item: map[string]interface{}{"_request": "modify"}, expected: false},
{name: "non-string value", item: map[string]interface{}{"_request": 42}, expected: false},
{name: "nil value", item: map[string]interface{}{"_request": nil}, expected: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isValidNestedRequest(tt.item)
if got != tt.expected {
t.Errorf("isValidNestedRequest(%v) = %v, want %v", tt.item, got, tt.expected)
}
})
}
}
// TestMultiLevelRelationExtraction tests extracting deeply nested relations // TestMultiLevelRelationExtraction tests extracting deeply nested relations
func TestMultiLevelRelationExtraction(t *testing.T) { func TestMultiLevelRelationExtraction(t *testing.T) {
registry := &mockRegistry{ registry := &mockRegistry{
+6 -6
View File
@@ -977,7 +977,7 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
} }
// Dereference pointer if needed // Dereference pointer if needed
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -1012,13 +1012,13 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
var targetType reflect.Type var targetType reflect.Type
if fieldType.Kind() == reflect.Slice { if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem() targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr { } else if fieldType.Kind() == reflect.Pointer {
targetType = fieldType.Elem() targetType = fieldType.Elem()
} }
if targetType != nil { if targetType != nil {
// Dereference pointer if the slice contains pointers // Dereference pointer if the slice contains pointers
if targetType.Kind() == reflect.Ptr { if targetType.Kind() == reflect.Pointer {
targetType = targetType.Elem() targetType = targetType.Elem()
} }
@@ -1062,7 +1062,7 @@ func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable
if modelType == nil { if modelType == nil {
return nameOrTable return nameOrTable
} }
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
if modelType == nil || modelType.Kind() != reflect.Struct { if modelType == nil || modelType.Kind() != reflect.Struct {
@@ -1089,10 +1089,10 @@ func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable
var targetType reflect.Type var targetType reflect.Type
if fieldType.Kind() == reflect.Slice { if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem() targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr { } else if fieldType.Kind() == reflect.Pointer {
targetType = fieldType.Elem() targetType = fieldType.Elem()
} }
if targetType != nil && targetType.Kind() == reflect.Ptr { if targetType != nil && targetType.Kind() == reflect.Pointer {
targetType = targetType.Elem() targetType = targetType.Elem()
} }
if targetType == nil || targetType.Kind() != reflect.Struct { if targetType == nil || targetType.Kind() != reflect.Struct {
+53 -18
View File
@@ -13,6 +13,9 @@ CREATE TABLE IF NOT EXISTS users (
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_login_at TIMESTAMP, last_login_at TIMESTAMP,
-- Program-level user mapping
program_user_id INTEGER DEFAULT 0,
program_user_table VARCHAR(255) DEFAULT '',
-- OAuth2 fields -- OAuth2 fields
remote_id VARCHAR(255), -- Provider's user ID (e.g., Google sub, GitHub id) remote_id VARCHAR(255), -- Provider's user ID (e.g., Google sub, GitHub id)
auth_provider VARCHAR(50), -- 'local', 'google', 'github', 'microsoft', 'facebook', etc. auth_provider VARCHAR(50), -- 'local', 'google', 'github', 'microsoft', 'facebook', etc.
@@ -99,6 +102,8 @@ DECLARE
v_expires_at TIMESTAMP; v_expires_at TIMESTAMP;
v_ip_address TEXT; v_ip_address TEXT;
v_user_agent TEXT; v_user_agent TEXT;
v_program_user_id INTEGER;
v_program_user_table TEXT;
BEGIN BEGIN
-- Extract login request fields -- Extract login request fields
v_username := p_request->>'username'; v_username := p_request->>'username';
@@ -106,8 +111,8 @@ BEGIN
v_user_agent := p_request->'claims'->>'user_agent'; v_user_agent := p_request->'claims'->>'user_agent';
-- Validate user credentials -- Validate user credentials
SELECT id, username, email, password, user_level, roles SELECT id, username, email, password, user_level, roles, program_user_id, program_user_table
INTO v_user_id, v_username, v_email, v_password_hash, v_user_level, v_roles INTO v_user_id, v_username, v_email, v_password_hash, v_user_level, v_roles, v_program_user_id, v_program_user_table
FROM users FROM users
WHERE username = v_username AND is_active = true; WHERE username = v_username AND is_active = true;
@@ -146,7 +151,9 @@ BEGIN
'email', v_email, 'email', v_email,
'user_level', v_user_level, 'user_level', v_user_level,
'roles', string_to_array(COALESCE(v_roles, ''), ','), 'roles', string_to_array(COALESCE(v_roles, ''), ','),
'session_id', v_session_token 'session_id', v_session_token,
'program_user_id', COALESCE(v_program_user_id, 0),
'program_user_table', COALESCE(v_program_user_table, '')
), ),
'expires_in', 86400 -- 24 hours in seconds 'expires_in', 86400 -- 24 hours in seconds
); );
@@ -195,12 +202,16 @@ DECLARE
v_user_level INTEGER; v_user_level INTEGER;
v_roles TEXT; v_roles TEXT;
v_session_id TEXT; v_session_id TEXT;
v_program_user_id INTEGER;
v_program_user_table TEXT;
BEGIN BEGIN
-- Query session and user data -- Query session and user data
SELECT SELECT
s.user_id, u.username, u.email, u.user_level, u.roles, s.session_token s.user_id, u.username, u.email, u.user_level, u.roles, s.session_token,
u.program_user_id, u.program_user_table
INTO INTO
v_user_id, v_username, v_email, v_user_level, v_roles, v_session_id v_user_id, v_username, v_email, v_user_level, v_roles, v_session_id,
v_program_user_id, v_program_user_table
FROM user_sessions s FROM user_sessions s
JOIN users u ON s.user_id = u.id JOIN users u ON s.user_id = u.id
WHERE s.session_token = p_session_token WHERE s.session_token = p_session_token
@@ -222,7 +233,9 @@ BEGIN
'email', v_email, 'email', v_email,
'user_level', v_user_level, 'user_level', v_user_level,
'session_id', v_session_id, 'session_id', v_session_id,
'roles', string_to_array(COALESCE(v_roles, ''), ',') 'roles', string_to_array(COALESCE(v_roles, ''), ','),
'program_user_id', COALESCE(v_program_user_id, 0),
'program_user_table', COALESCE(v_program_user_table, '')
); );
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
@@ -266,10 +279,14 @@ DECLARE
v_expires_at TIMESTAMP; v_expires_at TIMESTAMP;
v_ip_address TEXT; v_ip_address TEXT;
v_user_agent TEXT; v_user_agent TEXT;
v_program_user_id INTEGER;
v_program_user_table TEXT;
BEGIN BEGIN
-- Verify old session exists and is valid -- Verify old session exists and is valid
SELECT s.user_id, u.username, u.email, u.user_level, u.roles, s.ip_address, s.user_agent SELECT s.user_id, u.username, u.email, u.user_level, u.roles, s.ip_address, s.user_agent,
INTO v_user_id, v_username, v_email, v_user_level, v_roles, v_ip_address, v_user_agent u.program_user_id, u.program_user_table
INTO v_user_id, v_username, v_email, v_user_level, v_roles, v_ip_address, v_user_agent,
v_program_user_id, v_program_user_table
FROM user_sessions s FROM user_sessions s
JOIN users u ON s.user_id = u.id JOIN users u ON s.user_id = u.id
WHERE s.session_token = p_old_session_token WHERE s.session_token = p_old_session_token
@@ -302,7 +319,9 @@ BEGIN
'email', v_email, 'email', v_email,
'user_level', v_user_level, 'user_level', v_user_level,
'session_id', v_new_session_token, 'session_id', v_new_session_token,
'roles', string_to_array(COALESCE(v_roles, ''), ',') 'roles', string_to_array(COALESCE(v_roles, ''), ','),
'program_user_id', COALESCE(v_program_user_id, 0),
'program_user_table', COALESCE(v_program_user_table, '')
); );
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
@@ -439,6 +458,8 @@ DECLARE
v_ip_address TEXT; v_ip_address TEXT;
v_user_agent TEXT; v_user_agent TEXT;
v_roles_array TEXT[]; v_roles_array TEXT[];
v_program_user_id INTEGER;
v_program_user_table TEXT;
BEGIN BEGIN
-- Extract registration request fields -- Extract registration request fields
v_username := p_request->>'username'; v_username := p_request->>'username';
@@ -447,6 +468,8 @@ BEGIN
v_user_level := COALESCE((p_request->>'user_level')::integer, 0); v_user_level := COALESCE((p_request->>'user_level')::integer, 0);
v_ip_address := p_request->'claims'->>'ip_address'; v_ip_address := p_request->'claims'->>'ip_address';
v_user_agent := p_request->'claims'->>'user_agent'; v_user_agent := p_request->'claims'->>'user_agent';
v_program_user_id := COALESCE((p_request->>'program_user_id')::integer, 0);
v_program_user_table := COALESCE(p_request->>'program_user_table', '');
-- Convert roles array from JSON to comma-separated string -- Convert roles array from JSON to comma-separated string
SELECT array_to_string(ARRAY(SELECT jsonb_array_elements_text(p_request->'roles')), ',') SELECT array_to_string(ARRAY(SELECT jsonb_array_elements_text(p_request->'roles')), ',')
@@ -485,8 +508,8 @@ BEGIN
-- v_password := crypt(v_password, gen_salt('bf')); -- v_password := crypt(v_password, gen_salt('bf'));
-- Create new user -- Create new user
INSERT INTO users (username, email, password, user_level, roles, is_active, created_at, updated_at) INSERT INTO users (username, email, password, user_level, roles, is_active, created_at, updated_at, program_user_id, program_user_table)
VALUES (v_username, v_email, v_password, v_user_level, v_roles, true, now(), now()) VALUES (v_username, v_email, v_password, v_user_level, v_roles, true, now(), now(), v_program_user_id, v_program_user_table)
RETURNING id INTO v_user_id; RETURNING id INTO v_user_id;
-- Generate session token -- Generate session token
@@ -512,7 +535,9 @@ BEGIN
'email', v_email, 'email', v_email,
'user_level', v_user_level, 'user_level', v_user_level,
'roles', string_to_array(COALESCE(v_roles, ''), ','), 'roles', string_to_array(COALESCE(v_roles, ''), ','),
'session_id', v_session_token 'session_id', v_session_token,
'program_user_id', v_program_user_id,
'program_user_table', v_program_user_table
), ),
'expires_in', 86400 -- 24 hours in seconds 'expires_in', 86400 -- 24 hours in seconds
); );
@@ -671,12 +696,16 @@ DECLARE
v_user_level INTEGER; v_user_level INTEGER;
v_roles TEXT; v_roles TEXT;
v_expires_at TIMESTAMP; v_expires_at TIMESTAMP;
v_program_user_id INTEGER;
v_program_user_table TEXT;
BEGIN BEGIN
-- Query session and user data from user_sessions table -- Query session and user data from user_sessions table
SELECT SELECT
s.user_id, u.username, u.email, u.user_level, u.roles, s.expires_at s.user_id, u.username, u.email, u.user_level, u.roles, s.expires_at,
u.program_user_id, u.program_user_table
INTO INTO
v_user_id, v_username, v_email, v_user_level, v_roles, v_expires_at v_user_id, v_username, v_email, v_user_level, v_roles, v_expires_at,
v_program_user_id, v_program_user_table
FROM user_sessions s FROM user_sessions s
JOIN users u ON s.user_id = u.id JOIN users u ON s.user_id = u.id
WHERE s.session_token = p_session_token WHERE s.session_token = p_session_token
@@ -698,7 +727,9 @@ BEGIN
'email', v_email, 'email', v_email,
'user_level', v_user_level, 'user_level', v_user_level,
'session_id', p_session_token, 'session_id', p_session_token,
'roles', string_to_array(COALESCE(v_roles, ''), ',') 'roles', string_to_array(COALESCE(v_roles, ''), ','),
'program_user_id', COALESCE(v_program_user_id, 0),
'program_user_table', COALESCE(v_program_user_table, '')
); );
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
@@ -815,10 +846,12 @@ DECLARE
v_email TEXT; v_email TEXT;
v_user_level INTEGER; v_user_level INTEGER;
v_roles TEXT; v_roles TEXT;
v_program_user_id INTEGER;
v_program_user_table TEXT;
BEGIN BEGIN
-- Query user data -- Query user data
SELECT username, email, user_level, roles SELECT username, email, user_level, roles, program_user_id, program_user_table
INTO v_username, v_email, v_user_level, v_roles INTO v_username, v_email, v_user_level, v_roles, v_program_user_id, v_program_user_table
FROM users FROM users
WHERE id = p_user_id WHERE id = p_user_id
AND is_active = true; AND is_active = true;
@@ -837,7 +870,9 @@ BEGIN
'user_name', v_username, 'user_name', v_username,
'email', v_email, 'email', v_email,
'user_level', v_user_level, 'user_level', v_user_level,
'roles', string_to_array(COALESCE(v_roles, ''), ',') 'roles', string_to_array(COALESCE(v_roles, ''), ','),
'program_user_id', COALESCE(v_program_user_id, 0),
'program_user_table', COALESCE(v_program_user_table, '')
); );
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
+3 -3
View File
@@ -90,7 +90,7 @@ func applyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error
// Get primary key name from model // Get primary key name from model
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
@@ -155,13 +155,13 @@ func applyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
// Get model type // Get model type
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr { if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
// Apply column security masking // Apply column security masking
resultValue := reflect.ValueOf(result) resultValue := reflect.ValueOf(result)
if resultValue.Kind() == reflect.Ptr { if resultValue.Kind() == reflect.Pointer {
resultValue = resultValue.Elem() resultValue = resultValue.Elem()
} }
+2
View File
@@ -18,6 +18,8 @@ type UserContext struct {
Claims map[string]any `json:"claims"` Claims map[string]any `json:"claims"`
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
ProgramUserID int `json:"program_user_id"`
ProgramUserTable string `json:"program_user_table"`
} }
// LoginRequest contains credentials for login // LoginRequest contains credentials for login
+19 -10
View File
@@ -70,6 +70,25 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Try to open the file // Try to open the file
file, err := m.provider.Open(strings.TrimPrefix(filePath, "/")) file, err := m.provider.Open(strings.TrimPrefix(filePath, "/"))
if err != nil { if err != nil {
// For extensionless paths, also try path/index.html
if path.Ext(filePath) == "" {
indexFallback := path.Join(filePath, "index.html")
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
if err == nil {
defer file.Close()
m.serveFile(w, r, indexFallback, file)
return
}
indexFallback = fmt.Sprintf("%s.html", filePath)
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
if err == nil {
defer file.Close()
m.serveFile(w, r, indexFallback, file)
return
}
}
// File doesn't exist - check if we should use fallback // File doesn't exist - check if we should use fallback
if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) { if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) {
fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath) fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath)
@@ -80,16 +99,6 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// For extensionless paths, also try path/index.html
if path.Ext(filePath) == "" {
indexFallback := path.Join(filePath, "index.html")
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
if err == nil {
defer file.Close()
m.serveFile(w, r, indexFallback, file)
return
}
}
} }
// No fallback or fallback failed - return 404 // No fallback or fallback failed - return 404
+8 -2
View File
@@ -671,6 +671,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
return nil, fmt.Errorf("failed to create record: %w", err) return nil, fmt.Errorf("failed to create record: %w", err)
} }
// Re-fetch the created record to capture DB-generated defaults/triggers.
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
hookCtx.ID = fmt.Sprintf("%v", pkVal)
return h.readByID(hookCtx)
}
return hookCtx.ModelPtr, nil return hookCtx.ModelPtr, nil
} }
@@ -834,7 +840,7 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) { func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
// Get the reflect value of the records // Get the reflect value of the records
recordsValue := reflect.ValueOf(records) recordsValue := reflect.ValueOf(records)
if recordsValue.Kind() == reflect.Ptr { if recordsValue.Kind() == reflect.Pointer {
recordsValue = recordsValue.Elem() recordsValue = recordsValue.Elem()
} }
@@ -849,7 +855,7 @@ func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
record := recordsValue.Index(i) record := recordsValue.Index(i)
// Dereference if it's a pointer // Dereference if it's a pointer
if record.Kind() == reflect.Ptr { if record.Kind() == reflect.Pointer {
if record.IsNil() { if record.IsNil() {
continue continue
} }