mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-07-02 17:37:37 +00:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0e8f8925c6 | |||
| 5a359a160b | |||
| a2799fa224 | |||
| 1419542650 | |||
| c120b49529 | |||
| 66348dac97 | |||
| a87cd18b1b | |||
| 29449c93d5 | |||
| 3b6e5c75be | |||
| 549ccb8468 | |||
| 1af9c76337 | |||
| 938a2ef3d9 | |||
| 69cc3e2839 | |||
| 4018af0636 | |||
| c4e79d6950 | |||
| 982a0e62ac | |||
| 5d459c95a7 | |||
| e9f7726e43 | |||
| 3d2251317a | |||
| 1ce0ab1ab4 | |||
| 1f9b230f7f |
@@ -39,7 +39,7 @@ func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent)
|
|||||||
// This helps identify which specific field is causing scanning issues
|
// This helps identify which specific field is causing scanning issues
|
||||||
func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||||
v := reflect.ValueOf(dest)
|
v := reflect.ValueOf(dest)
|
||||||
if v.Kind() != reflect.Ptr {
|
if v.Kind() != reflect.Pointer {
|
||||||
return fmt.Errorf("dest must be a pointer")
|
return fmt.Errorf("dest must be a pointer")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
|||||||
logger.Debug(" Slice element type: %s", elemType)
|
logger.Debug(" Slice element type: %s", elemType)
|
||||||
|
|
||||||
// If slice of pointers, get the underlying type
|
// If slice of pointers, get the underlying type
|
||||||
if elemType.Kind() == reflect.Ptr {
|
if elemType.Kind() == reflect.Pointer {
|
||||||
structType = elemType.Elem()
|
structType = elemType.Elem()
|
||||||
} else {
|
} else {
|
||||||
structType = elemType
|
structType = elemType
|
||||||
@@ -747,7 +747,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
|||||||
|
|
||||||
// Get the first parent to check the relation field
|
// Get the first parent to check the relation field
|
||||||
firstParent := parents.Index(0)
|
firstParent := parents.Index(0)
|
||||||
if firstParent.Kind() == reflect.Ptr {
|
if firstParent.Kind() == reflect.Pointer {
|
||||||
firstParent = firstParent.Elem()
|
firstParent = firstParent.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -762,7 +762,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
|||||||
// Check if any parent has a non-empty slice
|
// Check if any parent has a non-empty slice
|
||||||
for i := 0; i < parents.Len(); i++ {
|
for i := 0; i < parents.Len(); i++ {
|
||||||
parent := parents.Index(i)
|
parent := parents.Index(i)
|
||||||
if parent.Kind() == reflect.Ptr {
|
if parent.Kind() == reflect.Pointer {
|
||||||
parent = parent.Elem()
|
parent = parent.Elem()
|
||||||
}
|
}
|
||||||
field := parent.FieldByName(relationName)
|
field := parent.FieldByName(relationName)
|
||||||
@@ -771,7 +771,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
|||||||
allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len())
|
allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len())
|
||||||
for j := 0; j < parents.Len(); j++ {
|
for j := 0; j < parents.Len(); j++ {
|
||||||
p := parents.Index(j)
|
p := parents.Index(j)
|
||||||
if p.Kind() == reflect.Ptr {
|
if p.Kind() == reflect.Pointer {
|
||||||
p = p.Elem()
|
p = p.Elem()
|
||||||
}
|
}
|
||||||
f := p.FieldByName(relationName)
|
f := p.FieldByName(relationName)
|
||||||
@@ -784,7 +784,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
|||||||
return allRelated, true
|
return allRelated, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if relationField.Kind() == reflect.Ptr {
|
} else if relationField.Kind() == reflect.Pointer {
|
||||||
// Check if it's a pointer (has-one/belongs-to)
|
// Check if it's a pointer (has-one/belongs-to)
|
||||||
if !relationField.IsNil() {
|
if !relationField.IsNil() {
|
||||||
// Already loaded! Collect all related records from all parents
|
// Already loaded! Collect all related records from all parents
|
||||||
@@ -792,7 +792,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
|||||||
allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len())
|
allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len())
|
||||||
for j := 0; j < parents.Len(); j++ {
|
for j := 0; j < parents.Len(); j++ {
|
||||||
p := parents.Index(j)
|
p := parents.Index(j)
|
||||||
if p.Kind() == reflect.Ptr {
|
if p.Kind() == reflect.Pointer {
|
||||||
p = p.Elem()
|
p = p.Elem()
|
||||||
}
|
}
|
||||||
f := p.FieldByName(relationName)
|
f := p.FieldByName(relationName)
|
||||||
@@ -816,7 +816,7 @@ func (b *BunSelectQuery) loadCustomPreloads(ctx context.Context) error {
|
|||||||
|
|
||||||
// Get the actual data from the model
|
// Get the actual data from the model
|
||||||
modelValue := reflect.ValueOf(model.Value())
|
modelValue := reflect.ValueOf(model.Value())
|
||||||
if modelValue.Kind() == reflect.Ptr {
|
if modelValue.Kind() == reflect.Pointer {
|
||||||
modelValue = modelValue.Elem()
|
modelValue = modelValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -884,7 +884,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
|
|||||||
|
|
||||||
// Get the first record to inspect the struct type
|
// Get the first record to inspect the struct type
|
||||||
firstRecord := parentRecords.Index(0)
|
firstRecord := parentRecords.Index(0)
|
||||||
if firstRecord.Kind() == reflect.Ptr {
|
if firstRecord.Kind() == reflect.Pointer {
|
||||||
firstRecord = firstRecord.Elem()
|
firstRecord = firstRecord.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -930,7 +930,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
|
|||||||
if isSlice {
|
if isSlice {
|
||||||
relatedType = relatedType.Elem()
|
relatedType = relatedType.Elem()
|
||||||
}
|
}
|
||||||
if relatedType.Kind() == reflect.Ptr {
|
if relatedType.Kind() == reflect.Pointer {
|
||||||
relatedType = relatedType.Elem()
|
relatedType = relatedType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1018,7 +1018,7 @@ func extractForeignKeyValues(records reflect.Value, fkFieldName string) ([]inter
|
|||||||
|
|
||||||
for i := 0; i < records.Len(); i++ {
|
for i := 0; i < records.Len(); i++ {
|
||||||
record := records.Index(i)
|
record := records.Index(i)
|
||||||
if record.Kind() == reflect.Ptr {
|
if record.Kind() == reflect.Pointer {
|
||||||
record = record.Elem()
|
record = record.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1083,7 +1083,7 @@ func associateRelatedRecords(parents, related reflect.Value, fieldName string, r
|
|||||||
for i := 0; i < related.Len(); i++ {
|
for i := 0; i < related.Len(); i++ {
|
||||||
relRecord := related.Index(i)
|
relRecord := related.Index(i)
|
||||||
relRecordElem := relRecord
|
relRecordElem := relRecord
|
||||||
if relRecordElem.Kind() == reflect.Ptr {
|
if relRecordElem.Kind() == reflect.Pointer {
|
||||||
relRecordElem = relRecordElem.Elem()
|
relRecordElem = relRecordElem.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1109,7 +1109,7 @@ func associateRelatedRecords(parents, related reflect.Value, fieldName string, r
|
|||||||
for i := 0; i < parents.Len(); i++ {
|
for i := 0; i < parents.Len(); i++ {
|
||||||
parentPtr := parents.Index(i)
|
parentPtr := parents.Index(i)
|
||||||
parent := parentPtr
|
parent := parentPtr
|
||||||
if parent.Kind() == reflect.Ptr {
|
if parent.Kind() == reflect.Pointer {
|
||||||
parent = parent.Elem()
|
parent = parent.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1332,11 +1332,11 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
||||||
|
|
||||||
v := reflect.ValueOf(modelValue)
|
v := reflect.ValueOf(modelValue)
|
||||||
if v.Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Pointer {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
}
|
}
|
||||||
if v.Kind() == reflect.Slice {
|
if v.Kind() == reflect.Slice {
|
||||||
if v.Type().Elem().Kind() == reflect.Ptr {
|
if v.Type().Elem().Kind() == reflect.Pointer {
|
||||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
|
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
|
||||||
} else {
|
} else {
|
||||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
|
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
|
||||||
@@ -1489,7 +1489,7 @@ func (b *BunInsertQuery) OnConflict(action string) common.InsertQuery {
|
|||||||
|
|
||||||
func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||||
if len(columns) > 0 {
|
if len(columns) > 0 {
|
||||||
b.query = b.query.Returning(columns[0])
|
b.query = b.query.Returning(strings.Join(columns, ", "))
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
@@ -1606,7 +1606,7 @@ func (b *BunUpdateQuery) Where(query string, args ...interface{}) common.UpdateQ
|
|||||||
|
|
||||||
func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||||
if len(columns) > 0 {
|
if len(columns) > 0 {
|
||||||
b.query = b.query.Returning(columns[0])
|
b.query = b.query.Returning(strings.Join(columns, ", "))
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -800,7 +800,7 @@ func (g *GormInsertQuery) Scan(ctx context.Context, dest interface{}) (err error
|
|||||||
col := g.returningColumns[0]
|
col := g.returningColumns[0]
|
||||||
if g.model != nil {
|
if g.model != nil {
|
||||||
val := reflect.ValueOf(g.model)
|
val := reflect.ValueOf(g.model)
|
||||||
if val.Kind() == reflect.Ptr {
|
if val.Kind() == reflect.Pointer {
|
||||||
val = val.Elem()
|
val = val.Elem()
|
||||||
}
|
}
|
||||||
if val.Kind() == reflect.Struct {
|
if val.Kind() == reflect.Struct {
|
||||||
|
|||||||
@@ -1195,7 +1195,7 @@ func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest inter
|
|||||||
|
|
||||||
// Use reflection to process the destination
|
// Use reflection to process the destination
|
||||||
destValue := reflect.ValueOf(dest)
|
destValue := reflect.ValueOf(dest)
|
||||||
if destValue.Kind() != reflect.Ptr {
|
if destValue.Kind() != reflect.Pointer {
|
||||||
return fmt.Errorf("dest must be a pointer")
|
return fmt.Errorf("dest must be a pointer")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1222,7 +1222,7 @@ func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest inter
|
|||||||
|
|
||||||
// loadPreloadsForRecord loads all preload relationships for a single record
|
// loadPreloadsForRecord loads all preload relationships for a single record
|
||||||
func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error {
|
func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error {
|
||||||
if record.Kind() == reflect.Ptr {
|
if record.Kind() == reflect.Pointer {
|
||||||
if record.IsNil() {
|
if record.IsNil() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1299,7 +1299,7 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
|
|||||||
} else {
|
} else {
|
||||||
// Single struct - create a pointer if needed
|
// Single struct - create a pointer if needed
|
||||||
var target reflect.Value
|
var target reflect.Value
|
||||||
if field.Kind() == reflect.Ptr {
|
if field.Kind() == reflect.Pointer {
|
||||||
target = reflect.New(field.Type().Elem())
|
target = reflect.New(field.Type().Elem())
|
||||||
} else {
|
} else {
|
||||||
target = reflect.New(field.Type())
|
target = reflect.New(field.Type())
|
||||||
@@ -1312,7 +1312,7 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set the field
|
// Set the field
|
||||||
if field.Kind() == reflect.Ptr {
|
if field.Kind() == reflect.Pointer {
|
||||||
field.Set(target)
|
field.Set(target)
|
||||||
} else {
|
} else {
|
||||||
field.Set(target.Elem())
|
field.Set(target.Elem())
|
||||||
@@ -1329,7 +1329,7 @@ func (p *PgSQLSelectQuery) getRelationMetadata(fieldName string) *relationMetada
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(p.model)
|
modelType := reflect.TypeOf(p.model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1378,7 +1378,7 @@ func (p *PgSQLSelectQuery) getRelationMetadataFromField(modelType reflect.Type,
|
|||||||
if fieldType.Kind() == reflect.Slice {
|
if fieldType.Kind() == reflect.Slice {
|
||||||
fieldType = fieldType.Elem()
|
fieldType = fieldType.Elem()
|
||||||
}
|
}
|
||||||
if fieldType.Kind() == reflect.Ptr {
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
fieldType = fieldType.Elem()
|
fieldType = fieldType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1411,7 +1411,7 @@ func scanRows(rows *sql.Rows, dest interface{}) error {
|
|||||||
|
|
||||||
// Get destination type
|
// Get destination type
|
||||||
destValue := reflect.ValueOf(dest)
|
destValue := reflect.ValueOf(dest)
|
||||||
if destValue.Kind() != reflect.Ptr {
|
if destValue.Kind() != reflect.Pointer {
|
||||||
return fmt.Errorf("dest must be a pointer")
|
return fmt.Errorf("dest must be a pointer")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1466,7 +1466,7 @@ func scanRowsToMapSlice(rows *sql.Rows, columns []string, destValue reflect.Valu
|
|||||||
// scanRowsToStructSlice scans rows into a slice of structs
|
// scanRowsToStructSlice scans rows into a slice of structs
|
||||||
func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
|
func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
|
||||||
elemType := destValue.Type().Elem()
|
elemType := destValue.Type().Elem()
|
||||||
isPtr := elemType.Kind() == reflect.Ptr
|
isPtr := elemType.Kind() == reflect.Pointer
|
||||||
|
|
||||||
if isPtr {
|
if isPtr {
|
||||||
elemType = elemType.Elem()
|
elemType = elemType.Elem()
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func entityNameFromModel(model interface{}, table string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -174,7 +174,9 @@ func (h *HTTPResponseWriter) Write(data []byte) (int, error) {
|
|||||||
|
|
||||||
func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
|
func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
|
||||||
h.SetHeader("Content-Type", "application/json")
|
h.SetHeader("Content-Type", "application/json")
|
||||||
return json.NewEncoder(h.resp).Encode(data)
|
enc := json.NewEncoder(h.resp)
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
return enc.Encode(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
|
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e
|
|||||||
originalType := modelType
|
originalType := modelType
|
||||||
|
|
||||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,15 +126,15 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
|
|||||||
// Get related model type
|
// Get related model type
|
||||||
if field.Type.Kind() == reflect.Slice {
|
if field.Type.Kind() == reflect.Slice {
|
||||||
elemType := field.Type.Elem()
|
elemType := field.Type.Elem()
|
||||||
if elemType.Kind() == reflect.Ptr {
|
if elemType.Kind() == reflect.Pointer {
|
||||||
elemType = elemType.Elem()
|
elemType = elemType.Elem()
|
||||||
}
|
}
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
}
|
}
|
||||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
} else if field.Type.Kind() == reflect.Pointer || field.Type.Kind() == reflect.Struct {
|
||||||
elemType := field.Type
|
elemType := field.Type
|
||||||
if elemType.Kind() == reflect.Ptr {
|
if elemType.Kind() == reflect.Pointer {
|
||||||
elemType = elemType.Elem()
|
elemType = elemType.Elem()
|
||||||
}
|
}
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
@@ -155,16 +155,16 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
|
|||||||
info.RelationType = "hasMany"
|
info.RelationType = "hasMany"
|
||||||
// Get the element type for slice
|
// Get the element type for slice
|
||||||
elemType := field.Type.Elem()
|
elemType := field.Type.Elem()
|
||||||
if elemType.Kind() == reflect.Ptr {
|
if elemType.Kind() == reflect.Pointer {
|
||||||
elemType = elemType.Elem()
|
elemType = elemType.Elem()
|
||||||
}
|
}
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
}
|
}
|
||||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
} else if field.Type.Kind() == reflect.Pointer || field.Type.Kind() == reflect.Struct {
|
||||||
info.RelationType = "belongsTo"
|
info.RelationType = "belongsTo"
|
||||||
elemType := field.Type
|
elemType := field.Type
|
||||||
if elemType.Kind() == reflect.Ptr {
|
if elemType.Kind() == reflect.Pointer {
|
||||||
elemType = elemType.Elem()
|
elemType = elemType.Elem()
|
||||||
}
|
}
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
@@ -177,7 +177,7 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
|
|||||||
// Get the element type for many2many (always slice)
|
// Get the element type for many2many (always slice)
|
||||||
if field.Type.Kind() == reflect.Slice {
|
if field.Type.Kind() == reflect.Slice {
|
||||||
elemType := field.Type.Elem()
|
elemType := field.Type.Elem()
|
||||||
if elemType.Kind() == reflect.Ptr {
|
if elemType.Kind() == reflect.Pointer {
|
||||||
elemType = elemType.Elem()
|
elemType = elemType.Elem()
|
||||||
}
|
}
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
@@ -239,7 +239,7 @@ func GetTableNameFromModel(model interface{}) string {
|
|||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
// Unwrap pointers
|
// Unwrap pointers
|
||||||
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,9 @@ func (s *StandardResponseWriter) Write(data []byte) (int, error) {
|
|||||||
|
|
||||||
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
||||||
s.SetHeader("Content-Type", "application/json")
|
s.SetHeader("Content-Type", "application/json")
|
||||||
return json.NewEncoder(s.w).Encode(data)
|
enc := json.NewEncoder(s.w)
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
return enc.Encode(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
|
|
||||||
// Get model type for reflection
|
// Get model type for reflection
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
|
|
||||||
// Process based on operation
|
// Process based on operation
|
||||||
switch strings.ToLower(operation) {
|
switch strings.ToLower(operation) {
|
||||||
case "insert", "create":
|
case "insert", "create", "add":
|
||||||
// Only perform insert if we have data to insert
|
// Only perform insert if we have data to insert
|
||||||
if hasData {
|
if hasData {
|
||||||
id, err := p.processInsert(ctx, regularData, tableName)
|
id, err := p.processInsert(ctx, regularData, tableName)
|
||||||
@@ -141,7 +141,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "update", "change":
|
case "update", "change", "modify":
|
||||||
// Only perform update if we have data to update
|
// Only perform update if we have data to update
|
||||||
if reflection.IsEmptyValue(data[pkName]) {
|
if reflection.IsEmptyValue(data[pkName]) {
|
||||||
logger.Warn("Skipping update for %s - no primary key", tableName)
|
logger.Warn("Skipping update for %s - no primary key", tableName)
|
||||||
@@ -174,7 +174,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
result.ID = data[pkName]
|
result.ID = data[pkName]
|
||||||
}
|
}
|
||||||
|
|
||||||
case "delete":
|
case "delete", "remove":
|
||||||
if reflection.IsEmptyValue(data[pkName]) {
|
if reflection.IsEmptyValue(data[pkName]) {
|
||||||
logger.Warn("Skipping delete for %s - no primary key", tableName)
|
logger.Warn("Skipping delete for %s - no primary key", tableName)
|
||||||
return result, nil
|
return result, nil
|
||||||
@@ -224,7 +224,7 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,7 +410,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
|||||||
if relatedModelType.Kind() == reflect.Slice {
|
if relatedModelType.Kind() == reflect.Slice {
|
||||||
relatedModelType = relatedModelType.Elem()
|
relatedModelType = relatedModelType.Elem()
|
||||||
}
|
}
|
||||||
if relatedModelType.Kind() == reflect.Ptr {
|
if relatedModelType.Kind() == reflect.Pointer {
|
||||||
relatedModelType = relatedModelType.Elem()
|
relatedModelType = relatedModelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -471,13 +471,17 @@ func (p *NestedCUDProcessor) processChildRelations(
|
|||||||
// Priority: Use foreign key field name if specified
|
// Priority: Use foreign key field name if specified
|
||||||
var foreignKeyFieldName string
|
var foreignKeyFieldName string
|
||||||
if relInfo.ForeignKey != "" {
|
if relInfo.ForeignKey != "" {
|
||||||
// Get the JSON name for the foreign key field in the child model
|
// For has-many/has-one: join:parentCol=childCol
|
||||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
// ForeignKey = parent side, References = child side (where we actually set the value)
|
||||||
if foreignKeyFieldName == "" {
|
childField := relInfo.ForeignKey
|
||||||
// Fallback to lowercase field name
|
if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
|
||||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
childField = relInfo.References
|
||||||
}
|
}
|
||||||
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey)
|
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
|
||||||
|
if foreignKeyFieldName == "" {
|
||||||
|
foreignKeyFieldName = strings.ToLower(childField)
|
||||||
|
}
|
||||||
|
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s -> child %s)", foreignKeyFieldName, relInfo.ForeignKey, childField)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||||
@@ -586,7 +590,7 @@ func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{
|
|||||||
|
|
||||||
// Get model type
|
// Get model type
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -713,6 +713,220 @@ func TestInjectForeignKeys(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Models for asymmetric join column tests (mirrors the bun has-many join:parentCol=childCol pattern).
|
||||||
|
// ActionOption has-many ActionOptionLinks via join:rid_actionoption=rid_actionoption_child.
|
||||||
|
// The child column ("rid_actionoption_child") differs from the parent column ("rid_actionoption").
|
||||||
|
type ActionOption struct {
|
||||||
|
RidActionoption int64 `json:"rid_actionoption" bun:"rid_actionoption,pk"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
Links []*ActionOptionLink `json:"aol_rid_actionoption_child,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a ActionOption) TableName() string { return "action_options" }
|
||||||
|
func (a ActionOption) GetIDName() string { return "RidActionoption" }
|
||||||
|
|
||||||
|
type ActionOptionLink struct {
|
||||||
|
RidActionoptionlink int64 `json:"rid_actionoptionlink" bun:"rid_actionoptionlink,pk"`
|
||||||
|
RidActionoptionChild int64 `json:"rid_actionoption_child" bun:"rid_actionoption_child"`
|
||||||
|
Label string `json:"label"`
|
||||||
|
// Note: no field named "rid_actionoption" — that is the parent's column.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a ActionOptionLink) TableName() string { return "action_option_links" }
|
||||||
|
func (a ActionOptionLink) GetIDName() string { return "RidActionoptionlink" }
|
||||||
|
|
||||||
|
// TestProcessNestedCUD_AsymmetricJoinColumns verifies that for a has-many relation with
|
||||||
|
// join:parentCol=childCol, the child rows are stamped with the child-side column (References),
|
||||||
|
// not the parent-side column (ForeignKey).
|
||||||
|
func TestProcessNestedCUD_AsymmetricJoinColumns(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
// Mirrors: bun:"rel:has-many,join:rid_actionoption=rid_actionoption_child"
|
||||||
|
relProvider.RegisterRelation("ActionOption", "aol_rid_actionoption_child", &RelationshipInfo{
|
||||||
|
FieldName: "Links",
|
||||||
|
JSONName: "aol_rid_actionoption_child",
|
||||||
|
RelationType: "hasMany",
|
||||||
|
ForeignKey: "rid_actionoption", // parent-side column (left of join:)
|
||||||
|
References: "rid_actionoption_child", // child-side column (right of join:)
|
||||||
|
RelatedModel: ActionOptionLink{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"label": "option-a",
|
||||||
|
"aol_rid_actionoption_child": []interface{}{
|
||||||
|
map[string]interface{}{"label": "link-1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"insert",
|
||||||
|
data,
|
||||||
|
ActionOption{},
|
||||||
|
nil,
|
||||||
|
"action_options",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(db.insertCalls) < 2 {
|
||||||
|
t.Fatalf("Expected at least 2 insert calls (parent + child), got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
childInsert := db.insertCalls[1]
|
||||||
|
|
||||||
|
// The fix: child must receive "rid_actionoption_child", NOT "rid_actionoption".
|
||||||
|
if childInsert["rid_actionoption_child"] == nil {
|
||||||
|
t.Error("Expected child to have rid_actionoption_child set (child-side FK column)")
|
||||||
|
}
|
||||||
|
if childInsert["rid_actionoption"] != nil {
|
||||||
|
t.Errorf("Child must not receive parent-side column rid_actionoption, got %v", childInsert["rid_actionoption"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessNestedCUD_BelongsToUnchanged verifies that the fix does not regress belongsTo
|
||||||
|
// relations, where ForeignKey is already the local (child) column.
|
||||||
|
func TestProcessNestedCUD_BelongsToUnchanged(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
// For belongsTo, ForeignKey is the column on the child; References is on the parent.
|
||||||
|
// The old and new code must behave identically here.
|
||||||
|
relProvider.RegisterRelation("Employee", "department", &RelationshipInfo{
|
||||||
|
FieldName: "Department",
|
||||||
|
JSONName: "department",
|
||||||
|
RelationType: "belongsTo",
|
||||||
|
ForeignKey: "DepartmentID", // child's own column
|
||||||
|
References: "ID", // parent's PK
|
||||||
|
RelatedModel: Department{},
|
||||||
|
})
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "Engineering",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{"name": "Alice"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(
|
||||||
|
context.Background(),
|
||||||
|
"insert",
|
||||||
|
data,
|
||||||
|
Department{},
|
||||||
|
nil,
|
||||||
|
"departments",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(db.insertCalls) < 2 {
|
||||||
|
t.Fatalf("Expected at least 2 inserts, got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Employees relation uses has_many (old-style) so it goes through the parentIDs injection path,
|
||||||
|
// not the foreignKeyFieldName path. Just confirm no panic and employee is inserted.
|
||||||
|
if db.insertCalls[0]["name"] != "Engineering" {
|
||||||
|
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_AddAlias(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"_request": "add",
|
||||||
|
"name": "New Department",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := processor.ProcessNestedCUD(context.Background(), "insert", data, Department{}, nil, "departments")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD with _request=add failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.ID == nil {
|
||||||
|
t.Error("Expected result.ID to be set after add")
|
||||||
|
}
|
||||||
|
if len(db.insertCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 insert call, got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_RemoveAlias(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"_request": "remove",
|
||||||
|
"ID": int64(42),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(context.Background(), "delete", data, Department{}, nil, "departments")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD with _request=remove failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(db.deleteCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 delete call, got %d", len(db.deleteCalls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNestedCUD_NestedAddRemoveAliases(t *testing.T) {
|
||||||
|
db := newMockDatabase()
|
||||||
|
registry := &mockModelRegistry{}
|
||||||
|
relProvider := newMockRelationshipProvider()
|
||||||
|
|
||||||
|
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||||
|
FieldName: "Employees",
|
||||||
|
JSONName: "employees",
|
||||||
|
RelationType: "has_many",
|
||||||
|
ForeignKey: "DepartmentID",
|
||||||
|
RelatedModel: Employee{},
|
||||||
|
})
|
||||||
|
|
||||||
|
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"ID": int64(1),
|
||||||
|
"name": "Engineering",
|
||||||
|
"employees": []interface{}{
|
||||||
|
map[string]interface{}{"_request": "add", "name": "Alice"},
|
||||||
|
map[string]interface{}{"_request": "remove", "ID": int64(5)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := processor.ProcessNestedCUD(context.Background(), "update", data, Department{}, nil, "departments")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessNestedCUD with nested add/remove failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(db.insertCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 insert (add alias) for employee, got %d", len(db.insertCalls))
|
||||||
|
}
|
||||||
|
if len(db.deleteCalls) != 1 {
|
||||||
|
t.Errorf("Expected 1 delete (remove alias) for employee, got %d", len(db.deleteCalls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetPrimaryKeyName(t *testing.T) {
|
func TestGetPrimaryKeyName(t *testing.T) {
|
||||||
dept := Department{}
|
dept := Department{}
|
||||||
pkName := reflection.GetPrimaryKeyName(dept)
|
pkName := reflection.GetPrimaryKeyName(dept)
|
||||||
|
|||||||
@@ -614,6 +614,15 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
|||||||
// Remove any quotes
|
// Remove any quotes
|
||||||
columnRef = strings.Trim(columnRef, "`\"'")
|
columnRef = strings.Trim(columnRef, "`\"'")
|
||||||
|
|
||||||
|
// If the left side is a parenthesized subquery (starts with '(' and contains SQL keywords),
|
||||||
|
// don't attempt prefix extraction from inside it.
|
||||||
|
if len(columnRef) > 0 && columnRef[0] == '(' {
|
||||||
|
lowerRef := strings.ToLower(columnRef)
|
||||||
|
if strings.Contains(lowerRef, "select ") || strings.Contains(lowerRef, " from ") || strings.Contains(lowerRef, " where ") {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if there's a function call (contains opening parenthesis)
|
// Check if there's a function call (contains opening parenthesis)
|
||||||
openParenIdx := strings.Index(columnRef, "(")
|
openParenIdx := strings.Index(columnRef, "(")
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -30,7 +31,7 @@ func (v *ColumnValidator) buildValidColumns() {
|
|||||||
modelType := reflect.TypeOf(v.model)
|
modelType := reflect.TypeOf(v.model)
|
||||||
|
|
||||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +44,7 @@ func (v *ColumnValidator) buildValidColumns() {
|
|||||||
for i := 0; i < modelType.NumField(); i++ {
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
field := modelType.Field(i)
|
field := modelType.Field(i)
|
||||||
|
|
||||||
if !field.IsExported() {
|
if !field.IsExported() || field.Anonymous {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,6 +126,16 @@ func (v *ColumnValidator) IsValidColumn(column string) bool {
|
|||||||
return v.ValidateColumn(column) == nil
|
return v.ValidateColumn(column) == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Columns returns all valid column names known to this validator
|
||||||
|
func (v *ColumnValidator) Columns() []string {
|
||||||
|
cols := make([]string, 0, len(v.validColumns))
|
||||||
|
for col := range v.validColumns {
|
||||||
|
cols = append(cols, col)
|
||||||
|
}
|
||||||
|
sort.Strings(cols)
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
|
||||||
// FilterValidColumns filters a list of columns, returning only valid ones
|
// FilterValidColumns filters a list of columns, returning only valid ones
|
||||||
// Logs warnings for any invalid columns
|
// Logs warnings for any invalid columns
|
||||||
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
||||||
@@ -224,7 +235,19 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
// Filter Filter columns
|
// Filter Filter columns
|
||||||
validFilters := make([]FilterOption, 0, len(options.Filters))
|
validFilters := make([]FilterOption, 0, len(options.Filters))
|
||||||
for _, filter := range options.Filters {
|
for _, filter := range options.Filters {
|
||||||
if v.IsValidColumn(filter.Column) {
|
if strings.EqualFold(filter.Column, "all") {
|
||||||
|
allCols := v.Columns()
|
||||||
|
if len(filtered.Columns) > 0 {
|
||||||
|
allCols = filtered.Columns
|
||||||
|
}
|
||||||
|
for _, col := range allCols {
|
||||||
|
expanded := filter
|
||||||
|
expanded.Column = col
|
||||||
|
expanded.LogicOperator = "OR"
|
||||||
|
|
||||||
|
validFilters = append(validFilters, expanded)
|
||||||
|
}
|
||||||
|
} else if v.IsValidColumn(filter.Column) {
|
||||||
validFilters = append(validFilters, filter)
|
validFilters = append(validFilters, filter)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
||||||
@@ -266,11 +289,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
|
|
||||||
// Filter Preload columns
|
// Filter Preload columns
|
||||||
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||||
|
modelType := reflect.TypeOf(v.model)
|
||||||
|
if modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
for idx := range options.Preload {
|
for idx := range options.Preload {
|
||||||
preload := options.Preload[idx]
|
preload := options.Preload[idx]
|
||||||
filteredPreload := preload
|
filteredPreload := preload
|
||||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
|
||||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
// Use the related model's validator for preload columns/filters/sorts
|
||||||
|
preloadValidator := v
|
||||||
|
if modelType != nil {
|
||||||
|
if relInfo := GetRelationshipInfo(modelType, preload.Relation); relInfo != nil && relInfo.RelatedModel != nil {
|
||||||
|
preloadValidator = NewColumnValidator(relInfo.RelatedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredPreload.Columns = preloadValidator.FilterValidColumns(preload.Columns)
|
||||||
|
filteredPreload.OmitColumns = preloadValidator.FilterValidColumns(preload.OmitColumns)
|
||||||
|
|
||||||
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||||
filteredPreload.SqlJoins = preload.SqlJoins
|
filteredPreload.SqlJoins = preload.SqlJoins
|
||||||
@@ -279,7 +315,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
// Filter preload filters
|
// Filter preload filters
|
||||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||||
for _, filter := range preload.Filters {
|
for _, filter := range preload.Filters {
|
||||||
if v.IsValidColumn(filter.Column) {
|
if preloadValidator.IsValidColumn(filter.Column) {
|
||||||
validPreloadFilters = append(validPreloadFilters, filter)
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
} else {
|
} else {
|
||||||
// Check if the filter column references a joined table alias
|
// Check if the filter column references a joined table alias
|
||||||
@@ -302,7 +338,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
// Filter preload sort columns
|
// Filter preload sort columns
|
||||||
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
||||||
for _, sort := range preload.Sort {
|
for _, sort := range preload.Sort {
|
||||||
if v.IsValidColumn(sort.Column) {
|
if preloadValidator.IsValidColumn(sort.Column) {
|
||||||
validPreloadSorts = append(validPreloadSorts, sort)
|
validPreloadSorts = append(validPreloadSorts, sort)
|
||||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
// Allow sort by expression/subquery, but validate for security
|
// Allow sort by expression/subquery, but validate for security
|
||||||
|
|||||||
@@ -464,3 +464,84 @@ func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
|
|||||||
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RelatedModel is used by PreloadParentModel to test preload column validation.
|
||||||
|
type RelatedModel struct {
|
||||||
|
RelatedID int64 `bun:"related_id,pk"`
|
||||||
|
Functionname string `bun:"functionname"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreloadParentModel has a has-one relation to RelatedModel. The json tag on
|
||||||
|
// the relation field is the name used in x-preload headers.
|
||||||
|
type PreloadParentModel struct {
|
||||||
|
ID int64 `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
RELATED *RelatedModel `json:"RELATED" bun:"rel:has-one,join:id=related_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel verifies
|
||||||
|
// that preload columns are validated against the related model's fields, not the
|
||||||
|
// parent model's fields. This is the fix for the bug where specifying a column
|
||||||
|
// that exists only on the relation (e.g. "functionname") was incorrectly filtered
|
||||||
|
// out because it doesn't exist on the parent model.
|
||||||
|
func TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel(t *testing.T) {
|
||||||
|
validator := NewColumnValidator(PreloadParentModel{})
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "RELATED",
|
||||||
|
// "functionname" exists on RelatedModel but NOT on PreloadParentModel.
|
||||||
|
// "name" exists on PreloadParentModel but NOT on RelatedModel.
|
||||||
|
// "nonexistent" exists on neither.
|
||||||
|
Columns: []string{"functionname", "name", "nonexistent"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
if len(filtered.Preload) != 1 {
|
||||||
|
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
|
||||||
|
}
|
||||||
|
|
||||||
|
cols := filtered.Preload[0].Columns
|
||||||
|
// Only "functionname" should survive: it belongs to RelatedModel.
|
||||||
|
if len(cols) != 1 {
|
||||||
|
t.Errorf("Expected 1 preload column, got %d: %v", len(cols), cols)
|
||||||
|
}
|
||||||
|
if len(cols) > 0 && cols[0] != "functionname" {
|
||||||
|
t.Errorf("Expected preload column 'functionname', got '%s'", cols[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFilterRequestOptions_PreloadColumnsParentModelFallback verifies that when
|
||||||
|
// a preload relation is not found on the parent model, column validation falls
|
||||||
|
// back to the parent model's validator (no panic, no silent pass-through).
|
||||||
|
func TestFilterRequestOptions_PreloadColumnsParentModelFallback(t *testing.T) {
|
||||||
|
validator := NewColumnValidator(PreloadParentModel{})
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "UNKNOWN_RELATION",
|
||||||
|
Columns: []string{"id", "functionname"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
if len(filtered.Preload) != 1 {
|
||||||
|
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
|
||||||
|
}
|
||||||
|
|
||||||
|
cols := filtered.Preload[0].Columns
|
||||||
|
// Falls back to parent model: only "id" is valid on PreloadParentModel.
|
||||||
|
if len(cols) != 1 {
|
||||||
|
t.Errorf("Expected 1 preload column (fallback to parent), got %d: %v", len(cols), cols)
|
||||||
|
}
|
||||||
|
if len(cols) > 0 && cols[0] != "id" {
|
||||||
|
t.Errorf("Expected preload column 'id', got '%s'", cols[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
varName := kw[1 : len(kw)-1] // strip [ and ]
|
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||||
if val, ok := variables[varName]; ok {
|
if val, ok := variables[varName]; ok {
|
||||||
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue"))
|
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -533,7 +533,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
varName := kw[1 : len(kw)-1] // strip [ and ]
|
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||||
if val, ok := variables[varName]; ok {
|
if val, ok := variables[varName]; ok {
|
||||||
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue"))
|
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1006,6 +1006,37 @@ func IsNumeric(s string) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isInsideDollarQuote reports whether the first occurrence of placeholder in sqlquery
|
||||||
|
// is immediately surrounded by dollar-sign characters (i.e. inside a $...$-quoted string).
|
||||||
|
// Dollar-quoted strings pass content through literally — no backslash processing — so
|
||||||
|
// values placed there must NOT have their backslashes escaped.
|
||||||
|
func isInsideDollarQuote(sqlquery, placeholder string) bool {
|
||||||
|
idx := strings.Index(sqlquery, placeholder)
|
||||||
|
if idx < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
endIdx := idx + len(placeholder)
|
||||||
|
charBefore := byte(0)
|
||||||
|
charAfter := byte(0)
|
||||||
|
if idx > 0 {
|
||||||
|
charBefore = sqlquery[idx-1]
|
||||||
|
}
|
||||||
|
if endIdx < len(sqlquery) {
|
||||||
|
charAfter = sqlquery[endIdx]
|
||||||
|
}
|
||||||
|
return charBefore == '$' || charAfter == '$'
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeSubstituteVar returns value sanitised for the quoting context that surrounds
|
||||||
|
// placeholder in sqlquery: raw (no backslash escaping) for dollar-quoted contexts,
|
||||||
|
// ValidSQL("colvalue") escaping for everything else.
|
||||||
|
func safeSubstituteVar(sqlquery, placeholder, value string) string {
|
||||||
|
if isInsideDollarQuote(sqlquery, placeholder) {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return ValidSQL(value, "colvalue")
|
||||||
|
}
|
||||||
|
|
||||||
// getReplacementForBlankParam determines the replacement value for an unused parameter
|
// getReplacementForBlankParam determines the replacement value for an unused parameter
|
||||||
// based on whether it appears within quotes in the SQL query.
|
// based on whether it appears within quotes in the SQL query.
|
||||||
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
|
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
|||||||
originalType := modelType
|
originalType := modelType
|
||||||
|
|
||||||
// Unwrap pointers, slices, and arrays to check the underlying type
|
// Unwrap pointers, slices, and arrays to check the underlying type
|
||||||
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
|||||||
|
|
||||||
// Additional check: ensure model is not a pointer
|
// Additional check: ensure model is not a pointer
|
||||||
finalType := reflect.TypeOf(model)
|
finalType := reflect.TypeOf(model)
|
||||||
if finalType.Kind() == reflect.Ptr {
|
if finalType.Kind() == reflect.Pointer {
|
||||||
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
|
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -781,6 +781,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
|||||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Re-fetch the created record to capture DB-generated defaults/triggers.
|
||||||
|
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
|
||||||
|
hookCtx.ID = fmt.Sprintf("%v", pkVal)
|
||||||
|
return h.readByID(hookCtx)
|
||||||
|
}
|
||||||
|
|
||||||
return hookCtx.ModelPtr, nil
|
return hookCtx.ModelPtr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -387,7 +387,7 @@ func (g *Generator) generateModelSchema(model interface{}) Schema {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
@@ -418,7 +418,7 @@ func (g *Generator) generateModelSchema(model interface{}) Schema {
|
|||||||
schema.Properties[fieldName] = propSchema
|
schema.Properties[fieldName] = propSchema
|
||||||
|
|
||||||
// Check if field is required (not a pointer and no omitempty)
|
// Check if field is required (not a pointer and no omitempty)
|
||||||
if field.Type.Kind() != reflect.Ptr && !strings.Contains(jsonTag, "omitempty") {
|
if field.Type.Kind() != reflect.Pointer && !strings.Contains(jsonTag, "omitempty") {
|
||||||
schema.Required = append(schema.Required, fieldName)
|
schema.Required = append(schema.Required, fieldName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -431,7 +431,7 @@ func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
|
|||||||
schema := &Schema{}
|
schema := &Schema{}
|
||||||
|
|
||||||
fieldType := field.Type
|
fieldType := field.Type
|
||||||
if fieldType.Kind() == reflect.Ptr {
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
fieldType = fieldType.Elem()
|
fieldType = fieldType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -453,7 +453,7 @@ func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
|
|||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
schema.Type = "array"
|
schema.Type = "array"
|
||||||
elemType := fieldType.Elem()
|
elemType := fieldType.Elem()
|
||||||
if elemType.Kind() == reflect.Ptr {
|
if elemType.Kind() == reflect.Pointer {
|
||||||
elemType = elemType.Elem()
|
elemType = elemType.Elem()
|
||||||
}
|
}
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ func Len(v any) int {
|
|||||||
val := reflect.ValueOf(v)
|
val := reflect.ValueOf(v)
|
||||||
valKind := val.Kind()
|
valKind := val.Kind()
|
||||||
|
|
||||||
if valKind == reflect.Ptr {
|
if valKind == reflect.Pointer {
|
||||||
val = val.Elem()
|
val = val.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ func IsEmptyValue(v any) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
if rv.Kind() == reflect.Ptr {
|
if rv.Kind() == reflect.Pointer {
|
||||||
if rv.IsNil() {
|
if rv.IsNil() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -80,12 +80,12 @@ func IsEmptyValue(v any) bool {
|
|||||||
// If the type is a slice of pointers, it returns the element type of the pointer within the slice.
|
// If the type is a slice of pointers, it returns the element type of the pointer within the slice.
|
||||||
// If neither condition is met, it returns the original type.
|
// If neither condition is met, it returns the original type.
|
||||||
func GetPointerElement(v reflect.Type) reflect.Type {
|
func GetPointerElement(v reflect.Type) reflect.Type {
|
||||||
if v.Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Pointer {
|
||||||
return v.Elem()
|
return v.Elem()
|
||||||
}
|
}
|
||||||
if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Pointer {
|
||||||
subElem := v.Elem()
|
subElem := v.Elem()
|
||||||
if subElem.Elem().Kind() == reflect.Ptr {
|
if subElem.Elem().Kind() == reflect.Pointer {
|
||||||
return subElem.Elem().Elem()
|
return subElem.Elem().Elem()
|
||||||
}
|
}
|
||||||
return v.Elem()
|
return v.Elem()
|
||||||
@@ -104,7 +104,7 @@ func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
|
|||||||
// Unwrap pointer and slice indirections to reach the struct type
|
// Unwrap pointer and slice indirections to reach the struct type
|
||||||
for {
|
for {
|
||||||
switch modelType.Kind() {
|
switch modelType.Kind() {
|
||||||
case reflect.Ptr, reflect.Slice:
|
case reflect.Pointer, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly
|
|||||||
// Handle embedded structs
|
// Handle embedded structs
|
||||||
if field.Anonymous {
|
if field.Anonymous {
|
||||||
ft := field.Type
|
ft := field.Type
|
||||||
if ft.Kind() == reflect.Ptr {
|
if ft.Kind() == reflect.Pointer {
|
||||||
ft = ft.Elem()
|
ft = ft.Elem()
|
||||||
}
|
}
|
||||||
isScanOnly := scanOnly
|
isScanOnly := scanOnly
|
||||||
@@ -544,7 +544,7 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
// Unwrap pointers and slices to get to the base struct type
|
// Unwrap pointers and slices to get to the base struct type
|
||||||
for modelType != nil {
|
for modelType != nil {
|
||||||
switch modelType.Kind() {
|
switch modelType.Kind() {
|
||||||
case reflect.Ptr, reflect.Slice:
|
case reflect.Pointer, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -709,7 +709,7 @@ func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
|||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
// Dereference pointer if needed
|
// Dereference pointer if needed
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -886,7 +886,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
|||||||
// Unwrap pointer → slice → pointer chains to reach the underlying struct
|
// Unwrap pointer → slice → pointer chains to reach the underlying struct
|
||||||
for {
|
for {
|
||||||
switch modelType.Kind() {
|
switch modelType.Kind() {
|
||||||
case reflect.Ptr, reflect.Slice:
|
case reflect.Pointer, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -947,7 +947,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
|||||||
// Slice indicates has-many or many-to-many
|
// Slice indicates has-many or many-to-many
|
||||||
return RelationHasMany
|
return RelationHasMany
|
||||||
}
|
}
|
||||||
if fieldType.Kind() == reflect.Ptr {
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
// Pointer to single struct usually indicates belongs-to or has-one
|
// Pointer to single struct usually indicates belongs-to or has-one
|
||||||
// Check if it has foreignKey (belongs-to) or references (has-one)
|
// Check if it has foreignKey (belongs-to) or references (has-one)
|
||||||
if strings.Contains(gormTag, "foreignKey:") {
|
if strings.Contains(gormTag, "foreignKey:") {
|
||||||
@@ -963,7 +963,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
|||||||
// Slice of structs → has-many
|
// Slice of structs → has-many
|
||||||
return RelationHasMany
|
return RelationHasMany
|
||||||
}
|
}
|
||||||
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
|
if fieldType.Kind() == reflect.Pointer || fieldType.Kind() == reflect.Struct {
|
||||||
// Single struct → belongs-to (default assumption for safety)
|
// Single struct → belongs-to (default assumption for safety)
|
||||||
// Using belongs-to as default ensures we use JOIN, which is safer
|
// Using belongs-to as default ensures we use JOIN, which is safer
|
||||||
return RelationBelongsTo
|
return RelationBelongsTo
|
||||||
@@ -990,7 +990,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
|||||||
// Strategy 1 is skipped if the matched field is a declared relation (rel:) or
|
// Strategy 1 is skipped if the matched field is a declared relation (rel:) or
|
||||||
// has a GORM tag but carries no explicit FK — callers should use convention.
|
// has a GORM tag but carries no explicit FK — callers should use convention.
|
||||||
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
|
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
|
||||||
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice {
|
for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
@@ -1123,7 +1123,7 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
targetValue := reflect.ValueOf(target)
|
targetValue := reflect.ValueOf(target)
|
||||||
if targetValue.Kind() != reflect.Ptr {
|
if targetValue.Kind() != reflect.Pointer {
|
||||||
return fmt.Errorf("target must be a pointer to a struct")
|
return fmt.Errorf("target must be a pointer to a struct")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1226,8 +1226,8 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle pointer fields
|
// Handle pointer fields
|
||||||
if field.Kind() == reflect.Ptr {
|
if field.Kind() == reflect.Pointer {
|
||||||
if valueReflect.Kind() != reflect.Ptr {
|
if valueReflect.Kind() != reflect.Pointer {
|
||||||
// Create a new pointer and set its value
|
// Create a new pointer and set its value
|
||||||
newPtr := reflect.New(field.Type().Elem())
|
newPtr := reflect.New(field.Type().Elem())
|
||||||
if err := setFieldValue(newPtr.Elem(), value); err != nil {
|
if err := setFieldValue(newPtr.Elem(), value); err != nil {
|
||||||
@@ -1418,14 +1418,14 @@ func convertSlice(targetSlice reflect.Value, sourceSlice reflect.Value) error {
|
|||||||
// Handle nil elements
|
// Handle nil elements
|
||||||
if sourceValue == nil {
|
if sourceValue == nil {
|
||||||
// For pointer types, nil is valid
|
// For pointer types, nil is valid
|
||||||
if targetElemType.Kind() == reflect.Ptr {
|
if targetElemType.Kind() == reflect.Pointer {
|
||||||
targetElem.Set(reflect.Zero(targetElemType))
|
targetElem.Set(reflect.Zero(targetElemType))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// If target element type is a pointer to struct, we need to create new instances
|
// If target element type is a pointer to struct, we need to create new instances
|
||||||
if targetElemType.Kind() == reflect.Ptr {
|
if targetElemType.Kind() == reflect.Pointer {
|
||||||
// Create a new instance of the pointed-to type
|
// Create a new instance of the pointed-to type
|
||||||
newElemPtr := reflect.New(targetElemType.Elem())
|
newElemPtr := reflect.New(targetElemType.Elem())
|
||||||
|
|
||||||
@@ -1588,7 +1588,7 @@ func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
|||||||
// Unwrap pointers and slices to get to the base struct type
|
// Unwrap pointers and slices to get to the base struct type
|
||||||
for modelType != nil {
|
for modelType != nil {
|
||||||
switch modelType.Kind() {
|
switch modelType.Kind() {
|
||||||
case reflect.Ptr, reflect.Slice:
|
case reflect.Pointer, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1616,7 +1616,7 @@ func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
|
|||||||
// Check for embedded structs
|
// Check for embedded structs
|
||||||
if field.Anonymous {
|
if field.Anonymous {
|
||||||
fieldType := field.Type
|
fieldType := field.Type
|
||||||
if fieldType.Kind() == reflect.Ptr {
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
fieldType = fieldType.Elem()
|
fieldType = fieldType.Elem()
|
||||||
}
|
}
|
||||||
if fieldType.Kind() == reflect.Struct {
|
if fieldType.Kind() == reflect.Struct {
|
||||||
@@ -1655,7 +1655,7 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
switch modelType.Kind() {
|
switch modelType.Kind() {
|
||||||
case reflect.Ptr, reflect.Slice:
|
case reflect.Pointer, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1724,7 +1724,7 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
switch targetType.Kind() {
|
switch targetType.Kind() {
|
||||||
case reflect.Ptr, reflect.Slice:
|
case reflect.Pointer, reflect.Slice:
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
if targetType == nil {
|
if targetType == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -428,15 +428,37 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
|||||||
// Use potentially modified data
|
// Use potentially modified data
|
||||||
data = hookCtx.Data
|
data = hookCtx.Data
|
||||||
|
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
switch v := data.(type) {
|
switch v := data.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
query := h.db.NewInsert().Table(tableName)
|
query := h.db.NewInsert().Table(tableName)
|
||||||
for key, value := range v {
|
for key, value := range v {
|
||||||
query = query.Value(key, value)
|
query = query.Value(key, value)
|
||||||
}
|
}
|
||||||
|
if pkName != "" {
|
||||||
|
var insertedID interface{}
|
||||||
|
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
|
||||||
|
return nil, fmt.Errorf("create error: %w", err)
|
||||||
|
}
|
||||||
|
// Re-fetch after insert to capture DB-generated defaults/triggers.
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Pointer {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
fetchedRecord := reflect.New(modelType).Interface()
|
||||||
|
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
|
||||||
|
ScanModel(ctx); err == nil {
|
||||||
|
v = mergeWithInput(fetchedRecord, v)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
return nil, fmt.Errorf("create error: %w", err)
|
return nil, fmt.Errorf("create error: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
hookCtx.Result = v
|
hookCtx.Result = v
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||||
@@ -444,7 +466,12 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
|||||||
return v, nil
|
return v, nil
|
||||||
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
results := make([]interface{}, 0, len(v))
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Pointer {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
originals := make([]map[string]interface{}, 0, len(v))
|
||||||
|
insertedIDs := make([]interface{}, 0, len(v))
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
itemMap, ok := item.(map[string]interface{})
|
itemMap, ok := item.(map[string]interface{})
|
||||||
@@ -455,16 +482,43 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
|||||||
for key, value := range itemMap {
|
for key, value := range itemMap {
|
||||||
q = q.Value(key, value)
|
q = q.Value(key, value)
|
||||||
}
|
}
|
||||||
|
if pkName == "" {
|
||||||
if _, err := q.Exec(ctx); err != nil {
|
if _, err := q.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
results = append(results, item)
|
originals = append(originals, itemMap)
|
||||||
|
insertedIDs = append(insertedIDs, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var returnedID interface{}
|
||||||
|
if err := q.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
originals = append(originals, itemMap)
|
||||||
|
insertedIDs = append(insertedIDs, returnedID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("batch create error: %w", err)
|
return nil, fmt.Errorf("batch create error: %w", err)
|
||||||
}
|
}
|
||||||
|
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||||
|
results := make([]interface{}, 0, len(insertedIDs))
|
||||||
|
for i, pkVal := range insertedIDs {
|
||||||
|
if pkVal == nil {
|
||||||
|
results = append(results, originals[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fetchedRecord := reflect.New(modelType).Interface()
|
||||||
|
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||||
|
ScanModel(ctx); err == nil {
|
||||||
|
results = append(results, mergeWithInput(fetchedRecord, originals[i]))
|
||||||
|
} else {
|
||||||
|
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, err)
|
||||||
|
results = append(results, originals[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
hookCtx.Result = results
|
hookCtx.Result = results
|
||||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||||
@@ -513,7 +567,7 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
|
|||||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
// Read existing record
|
// Read existing record
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
existingRecord := reflect.New(modelType).Interface()
|
existingRecord := reflect.New(modelType).Interface()
|
||||||
@@ -584,6 +638,25 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Re-fetch the record after transaction commits to capture DB-generated changes.
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Pointer {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
fetchedRecord := reflect.New(modelType).Interface()
|
||||||
|
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||||
|
ScanModel(ctx); err == nil {
|
||||||
|
jsonData, marshalErr := json.Marshal(fetchedRecord)
|
||||||
|
if marshalErr == nil {
|
||||||
|
var fetchedMap map[string]interface{}
|
||||||
|
if json.Unmarshal(jsonData, &fetchedMap) == nil {
|
||||||
|
updateResult = fetchedMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return updateResult, nil
|
return updateResult, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -628,7 +701,7 @@ func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -749,6 +822,28 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition st
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeWithInput merges a database record with the original request data.
|
||||||
|
// DB values take precedence (capturing triggers/defaults), while extra
|
||||||
|
// input keys that have no DB column are preserved in the response.
|
||||||
|
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
|
||||||
|
result := make(map[string]interface{}, len(input))
|
||||||
|
for k, v := range input {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(dbRecord)
|
||||||
|
if err != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
var dbMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
for k, v := range dbMap {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||||
for i := range preloads {
|
for i := range preloads {
|
||||||
preload := &preloads[i]
|
preload := &preloads[i]
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
|||||||
|
|
||||||
// Unwrap to base struct type
|
// Unwrap to base struct type
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
@@ -87,7 +87,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
|||||||
fieldType, found := modelType.FieldByName(d.Name)
|
fieldType, found := modelType.FieldByName(d.Name)
|
||||||
if found {
|
if found {
|
||||||
ft := fieldType.Type
|
ft := fieldType.Type
|
||||||
if ft.Kind() == reflect.Ptr {
|
if ft.Kind() == reflect.Pointer {
|
||||||
ft = ft.Elem()
|
ft = ft.Elem()
|
||||||
}
|
}
|
||||||
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||||
@@ -106,7 +106,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
|||||||
goType := d.DataType
|
goType := d.DataType
|
||||||
if goType == "" && found {
|
if goType == "" && found {
|
||||||
ft := fieldType.Type
|
ft := fieldType.Type
|
||||||
for ft.Kind() == reflect.Ptr {
|
for ft.Kind() == reflect.Pointer {
|
||||||
ft = ft.Elem()
|
ft = ft.Elem()
|
||||||
}
|
}
|
||||||
goType = ft.Name()
|
goType = ft.Name()
|
||||||
|
|||||||
+183
-21
@@ -243,7 +243,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
// Validate and unwrap model type to get base struct
|
// Validate and unwrap model type to get base struct
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -602,10 +602,14 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard processing without nested relations
|
// Standard processing without nested relations
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
query := h.db.NewInsert().Table(tableName)
|
query := h.db.NewInsert().Table(tableName)
|
||||||
for key, value := range v {
|
for key, value := range v {
|
||||||
query = query.Value(key, common.ConvertSliceForBun(value))
|
query = query.Value(key, common.ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
|
var responseData interface{} = v
|
||||||
|
if pkName == "" {
|
||||||
|
// No PK on model — insert and return input as-is.
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error creating record: %v", err)
|
logger.Error("Error creating record: %v", err)
|
||||||
@@ -613,12 +617,29 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||||
|
} else {
|
||||||
|
var insertedID interface{}
|
||||||
|
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
|
||||||
|
logger.Error("Error creating record: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("Successfully created record with %s: %v", pkName, insertedID)
|
||||||
|
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
|
||||||
|
ScanModel(ctx); fetchErr == nil {
|
||||||
|
responseData = mergeWithInput(fetchedRecord, v)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, fetchErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
// Invalidate cache for this table
|
// Invalidate cache for this table
|
||||||
cacheTags := buildCacheTags(schema, tableName)
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
}
|
}
|
||||||
h.sendResponse(w, v, nil)
|
h.sendResponse(w, responseData, nil)
|
||||||
|
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
// Check if any item needs nested processing
|
// Check if any item needs nested processing
|
||||||
@@ -666,15 +687,30 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard batch insert without nested relations
|
// Standard batch insert without nested relations
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
|
||||||
|
originals := make([]map[string]interface{}, 0, len(v))
|
||||||
|
insertedIDs := make([]interface{}, 0, len(v))
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
txQuery := tx.NewInsert().Table(tableName)
|
txQuery := tx.NewInsert().Table(tableName)
|
||||||
for key, value := range item {
|
for key, value := range item {
|
||||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
|
if pkName == "" {
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
originals = append(originals, item)
|
||||||
|
insertedIDs = append(insertedIDs, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var returnedID interface{}
|
||||||
|
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
originals = append(originals, item)
|
||||||
|
insertedIDs = append(insertedIDs, returnedID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -689,7 +725,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
}
|
}
|
||||||
h.sendResponse(w, v, nil)
|
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||||
|
responseItems := make([]interface{}, 0, len(insertedIDs))
|
||||||
|
for i, pkVal := range insertedIDs {
|
||||||
|
if pkVal == nil {
|
||||||
|
responseItems = append(responseItems, originals[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fetchedRecord := reflect.New(modelElemType).Interface()
|
||||||
|
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||||
|
ScanModel(ctx); fetchErr == nil {
|
||||||
|
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
|
||||||
|
} else {
|
||||||
|
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
|
||||||
|
responseItems = append(responseItems, originals[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.sendResponse(w, responseItems, nil)
|
||||||
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
// Handle []interface{} type from JSON unmarshaling
|
// Handle []interface{} type from JSON unmarshaling
|
||||||
@@ -742,19 +795,34 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard batch insert without nested relations
|
// Standard batch insert without nested relations
|
||||||
list := make([]interface{}, 0)
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
|
||||||
|
originals := make([]map[string]interface{}, 0, len(v))
|
||||||
|
insertedIDs := make([]interface{}, 0, len(v))
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
itemMap, ok := item.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
txQuery := tx.NewInsert().Table(tableName)
|
txQuery := tx.NewInsert().Table(tableName)
|
||||||
for key, value := range itemMap {
|
for key, value := range itemMap {
|
||||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
|
if pkName == "" {
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
list = append(list, item)
|
originals = append(originals, itemMap)
|
||||||
|
insertedIDs = append(insertedIDs, nil)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
var returnedID interface{}
|
||||||
|
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
originals = append(originals, itemMap)
|
||||||
|
insertedIDs = append(insertedIDs, returnedID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -769,7 +837,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
}
|
}
|
||||||
h.sendResponse(w, list, nil)
|
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||||
|
responseItems := make([]interface{}, 0, len(insertedIDs))
|
||||||
|
for i, pkVal := range insertedIDs {
|
||||||
|
if pkVal == nil {
|
||||||
|
responseItems = append(responseItems, originals[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fetchedRecord := reflect.New(modelElemType).Interface()
|
||||||
|
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||||
|
ScanModel(ctx); fetchErr == nil {
|
||||||
|
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
|
||||||
|
} else {
|
||||||
|
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
|
||||||
|
responseItems = append(responseItems, originals[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.sendResponse(w, responseItems, nil)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.Error("Invalid data type for create operation: %T", data)
|
logger.Error("Invalid data type for create operation: %T", data)
|
||||||
@@ -836,7 +921,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
// First, read the existing record from the database
|
// First, read the existing record from the database
|
||||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
|
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...)
|
||||||
|
|
||||||
// Apply conditions to select
|
// Apply conditions to select
|
||||||
if urlID != "" {
|
if urlID != "" {
|
||||||
@@ -955,13 +1040,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch the updated record after the transaction commits to capture any trigger changes
|
||||||
|
updatedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
fetchQuery := h.db.NewSelect().Model(updatedRecord).Column(reflection.GetSQLModelColumns(model)...)
|
||||||
|
if urlID != "" {
|
||||||
|
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||||
|
} else if reqID != nil {
|
||||||
|
switch id := reqID.(type) {
|
||||||
|
case string:
|
||||||
|
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
case []string:
|
||||||
|
if len(id) > 0 {
|
||||||
|
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||||
|
logger.Error("Failed to fetch updated record: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated record(s)")
|
logger.Info("Successfully updated record(s)")
|
||||||
// Invalidate cache for this table
|
// Invalidate cache for this table
|
||||||
cacheTags := buildCacheTags(schema, tableName)
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
}
|
}
|
||||||
h.sendResponse(w, data, nil)
|
h.sendResponse(w, updatedRecord, nil)
|
||||||
|
|
||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
// Batch update with array of objects
|
// Batch update with array of objects
|
||||||
@@ -1017,7 +1123,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
|
|
||||||
// First, read the existing record
|
// First, read the existing record
|
||||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
continue // Skip if record not found
|
continue // Skip if record not found
|
||||||
@@ -1089,13 +1195,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records", len(updates))
|
|
||||||
|
// Fetch updated records after the transaction commits to capture any trigger changes
|
||||||
|
fetchedUpdates := make([]interface{}, 0, len(updates))
|
||||||
|
for _, item := range updates {
|
||||||
|
if itemID, ok := item["id"]; ok && itemID != nil {
|
||||||
|
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
|
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||||
|
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fetchedUpdates = append(fetchedUpdates, fetchedRecord)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully updated %d records", len(fetchedUpdates))
|
||||||
// Invalidate cache for this table
|
// Invalidate cache for this table
|
||||||
cacheTags := buildCacheTags(schema, tableName)
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
}
|
}
|
||||||
h.sendResponse(w, updates, nil)
|
h.sendResponse(w, fetchedUpdates, nil)
|
||||||
|
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
// Batch update with []interface{}
|
// Batch update with []interface{}
|
||||||
@@ -1157,7 +1279,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
|
|
||||||
// First, read the existing record
|
// First, read the existing record
|
||||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
continue // Skip if record not found
|
continue // Skip if record not found
|
||||||
@@ -1232,13 +1354,31 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
|||||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Info("Successfully updated %d records", len(list))
|
|
||||||
|
// Fetch updated records after the transaction commits to capture any trigger changes
|
||||||
|
fetchedList := make([]interface{}, 0, len(list))
|
||||||
|
for _, item := range list {
|
||||||
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
if itemID, ok := itemMap["id"]; ok && itemID != nil {
|
||||||
|
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
|
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||||
|
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||||
|
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fetchedList = append(fetchedList, fetchedRecord)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully updated %d records", len(fetchedList))
|
||||||
// Invalidate cache for this table
|
// Invalidate cache for this table
|
||||||
cacheTags := buildCacheTags(schema, tableName)
|
cacheTags := buildCacheTags(schema, tableName)
|
||||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||||
}
|
}
|
||||||
h.sendResponse(w, list, nil)
|
h.sendResponse(w, fetchedList, nil)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.Error("Invalid data type for update operation: %T", data)
|
logger.Error("Invalid data type for update operation: %T", data)
|
||||||
@@ -1407,7 +1547,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
// First, fetch the record that will be deleted
|
// First, fetch the record that will be deleted
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
recordToDelete := reflect.New(modelType).Interface()
|
recordToDelete := reflect.New(modelType).Interface()
|
||||||
@@ -1682,7 +1822,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
|
|||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1826,7 +1966,7 @@ func getColumnType(field reflect.StructField) string {
|
|||||||
|
|
||||||
func isNullable(field reflect.StructField) bool {
|
func isNullable(field reflect.StructField) bool {
|
||||||
// Check if it's a pointer type
|
// Check if it's a pointer type
|
||||||
if field.Type.Kind() == reflect.Ptr {
|
if field.Type.Kind() == reflect.Pointer {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1852,7 +1992,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2000,7 +2140,7 @@ func toSnakeCase(s string) string {
|
|||||||
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||||
// Get the reflect value of the records
|
// Get the reflect value of the records
|
||||||
recordsValue := reflect.ValueOf(records)
|
recordsValue := reflect.ValueOf(records)
|
||||||
if recordsValue.Kind() == reflect.Ptr {
|
if recordsValue.Kind() == reflect.Pointer {
|
||||||
recordsValue = recordsValue.Elem()
|
recordsValue = recordsValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2015,7 +2155,7 @@ func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
|||||||
record := recordsValue.Index(i)
|
record := recordsValue.Index(i)
|
||||||
|
|
||||||
// Dereference if it's a pointer
|
// Dereference if it's a pointer
|
||||||
if record.Kind() == reflect.Ptr {
|
if record.Kind() == reflect.Pointer {
|
||||||
if record.IsNil() {
|
if record.IsNil() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2067,3 +2207,25 @@ func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
|||||||
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
|
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
|
||||||
h.openAPIGenerator = generator
|
h.openAPIGenerator = generator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeWithInput merges a database record with the original request data.
|
||||||
|
// DB values take precedence (capturing triggers/defaults), while extra
|
||||||
|
// input keys that have no DB column are preserved in the response.
|
||||||
|
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
|
||||||
|
result := make(map[string]interface{}, len(input))
|
||||||
|
for k, v := range input {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(dbRecord)
|
||||||
|
if err != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
var dbMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
for k, v := range dbMap {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,209 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// detailTestModel is a simple model with gorm column/type tags for detail format tests.
|
||||||
|
type detailTestModel struct {
|
||||||
|
ID int64 `bun:"rid,pk" gorm:"column:rid;primaryKey" json:"rid"`
|
||||||
|
Name string `bun:"name" gorm:"column:name;type:citext" json:"name"`
|
||||||
|
Description *string `bun:"description" gorm:"column:description;type:text;nullable" json:"description"`
|
||||||
|
Score float64 `bun:"score" gorm:"column:score;type:numeric" json:"score"`
|
||||||
|
Active bool `bun:"active" gorm:"column:active;type:boolean;not null" json:"active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendFormattedResponse_DetailFormat(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
name := "hello"
|
||||||
|
items := []*detailTestModel{
|
||||||
|
{ID: 1, Name: "first", Description: &name, Score: 1.5, Active: true},
|
||||||
|
{ID: 2, Name: "second", Description: nil, Score: 2.0, Active: false},
|
||||||
|
}
|
||||||
|
metadata := &common.Metadata{
|
||||||
|
Total: 36,
|
||||||
|
Count: 2,
|
||||||
|
Filtered: 36,
|
||||||
|
Limit: 10,
|
||||||
|
Offset: 0,
|
||||||
|
}
|
||||||
|
options := ExtendedRequestOptions{
|
||||||
|
ResponseFormat: "detail",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockWriter := &MockTestResponseWriter{headers: make(map[string]string)}
|
||||||
|
handler.sendFormattedResponse(mockWriter, items, metadata, "myschema.myentity", detailTestModel{}, options)
|
||||||
|
|
||||||
|
if mockWriter.statusCode != 200 {
|
||||||
|
t.Fatalf("expected status 200, got %d", mockWriter.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(mockWriter.body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("top-level keys", func(t *testing.T) {
|
||||||
|
for _, key := range []string{"count", "fields", "items", "tablename", "tableprefix", "total"} {
|
||||||
|
if _, ok := resp[key]; !ok {
|
||||||
|
t.Errorf("missing key %q in detail response", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("count and total are string", func(t *testing.T) {
|
||||||
|
var count, total string
|
||||||
|
if err := json.Unmarshal(resp["count"], &count); err != nil {
|
||||||
|
t.Errorf("count is not a string: %v", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(resp["total"], &total); err != nil {
|
||||||
|
t.Errorf("total is not a string: %v", err)
|
||||||
|
}
|
||||||
|
if count != "2" {
|
||||||
|
t.Errorf("expected count %q, got %q", "2", count)
|
||||||
|
}
|
||||||
|
if total != "36" {
|
||||||
|
t.Errorf("expected total %q, got %q", "36", total)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tablename and tableprefix", func(t *testing.T) {
|
||||||
|
var tablename, tableprefix string
|
||||||
|
json.Unmarshal(resp["tablename"], &tablename)
|
||||||
|
json.Unmarshal(resp["tableprefix"], &tableprefix)
|
||||||
|
if tablename != "myschema.myentity" {
|
||||||
|
t.Errorf("expected tablename %q, got %q", "myschema.myentity", tablename)
|
||||||
|
}
|
||||||
|
if tableprefix != "myentity" {
|
||||||
|
t.Errorf("expected tableprefix %q, got %q", "myentity", tableprefix)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("items contains data", func(t *testing.T) {
|
||||||
|
var itemSlice []map[string]interface{}
|
||||||
|
if err := json.Unmarshal(resp["items"], &itemSlice); err != nil {
|
||||||
|
t.Fatalf("items is not an array: %v", err)
|
||||||
|
}
|
||||||
|
if len(itemSlice) != 2 {
|
||||||
|
t.Errorf("expected 2 items, got %d", len(itemSlice))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("fields contains column metadata", func(t *testing.T) {
|
||||||
|
var fields []map[string]interface{}
|
||||||
|
if err := json.Unmarshal(resp["fields"], &fields); err != nil {
|
||||||
|
t.Fatalf("fields is not an array: %v", err)
|
||||||
|
}
|
||||||
|
if len(fields) == 0 {
|
||||||
|
t.Fatal("expected fields to be non-empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
bySQL := make(map[string]map[string]interface{}, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
if sqlname, ok := f["sqlname"].(string); ok {
|
||||||
|
bySQL[sqlname] = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check required field keys are present
|
||||||
|
for _, f := range fields {
|
||||||
|
for _, key := range []string{"name", "datatype", "sqlname", "sqldatatype", "sqlkey", "nullable"} {
|
||||||
|
if _, ok := f[key]; !ok {
|
||||||
|
t.Errorf("field %v missing key %q", f, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate specific columns
|
||||||
|
if col, ok := bySQL["rid"]; ok {
|
||||||
|
if col["sqlkey"] != "primary_key" {
|
||||||
|
t.Errorf("rid: expected sqlkey %q, got %v", "primary_key", col["sqlkey"])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("expected column 'rid' in fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
if col, ok := bySQL["name"]; ok {
|
||||||
|
if col["sqldatatype"] != "citext" {
|
||||||
|
t.Errorf("name: expected sqldatatype %q, got %v", "citext", col["sqldatatype"])
|
||||||
|
}
|
||||||
|
if col["nullable"] != false {
|
||||||
|
t.Errorf("name: expected nullable false, got %v", col["nullable"])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("expected column 'name' in fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
if col, ok := bySQL["description"]; ok {
|
||||||
|
if col["sqldatatype"] != "text" {
|
||||||
|
t.Errorf("description: expected sqldatatype %q, got %v", "text", col["sqldatatype"])
|
||||||
|
}
|
||||||
|
if col["nullable"] != true {
|
||||||
|
t.Errorf("description: expected nullable true, got %v", col["nullable"])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("expected column 'description' in fields")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendFormattedResponse_DetailFormat_EmptyItems(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
metadata := &common.Metadata{Total: 0, Count: 0, Filtered: 0}
|
||||||
|
options := ExtendedRequestOptions{ResponseFormat: "detail"}
|
||||||
|
|
||||||
|
mockWriter := &MockTestResponseWriter{headers: make(map[string]string)}
|
||||||
|
handler.sendFormattedResponse(mockWriter, []*detailTestModel{}, metadata, "s.t", detailTestModel{}, options)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(mockWriter.body)
|
||||||
|
var resp map[string]json.RawMessage
|
||||||
|
json.Unmarshal(body, &resp)
|
||||||
|
|
||||||
|
var count, total string
|
||||||
|
json.Unmarshal(resp["count"], &count)
|
||||||
|
json.Unmarshal(resp["total"], &total)
|
||||||
|
|
||||||
|
if count != "0" || total != "0" {
|
||||||
|
t.Errorf("expected count/total both %q, got count=%q total=%q", "0", count, total)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fields []interface{}
|
||||||
|
json.Unmarshal(resp["fields"], &fields)
|
||||||
|
if len(fields) == 0 {
|
||||||
|
t.Error("fields should still list column metadata even when items is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildDetailFields_SkipsRelations(t *testing.T) {
|
||||||
|
type child struct {
|
||||||
|
ID int64 `bun:"id,pk" gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
}
|
||||||
|
type parent struct {
|
||||||
|
ID int64 `bun:"id,pk" gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
Name string `bun:"name" gorm:"column:name" json:"name"`
|
||||||
|
Children []child `bun:"rel:has-many" json:"children"`
|
||||||
|
Child *child `bun:"rel:has-one" json:"child"`
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := &Handler{}
|
||||||
|
fields := handler.buildDetailFields(parent{})
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
if f.SQLName == "children" || f.SQLName == "child" {
|
||||||
|
t.Errorf("relation field %q should not appear in detail fields", f.SQLName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fields) != 2 {
|
||||||
|
t.Errorf("expected 2 scalar fields (id, name), got %d", len(fields))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -95,7 +95,7 @@ func TestSendFormattedResponse_NoDataFoundHeader(t *testing.T) {
|
|||||||
|
|
||||||
// Test with empty data
|
// Test with empty data
|
||||||
emptyData := []interface{}{}
|
emptyData := []interface{}{}
|
||||||
handler.sendFormattedResponse(mockWriter, emptyData, metadata, options)
|
handler.sendFormattedResponse(mockWriter, emptyData, metadata, "", nil, options)
|
||||||
|
|
||||||
// Check if X-No-Data-Found header was set
|
// Check if X-No-Data-Found header was set
|
||||||
if mockWriter.headers["X-No-Data-Found"] != "true" {
|
if mockWriter.headers["X-No-Data-Found"] != "true" {
|
||||||
|
|||||||
+185
-38
@@ -289,7 +289,8 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
Limit: 0,
|
Limit: 0,
|
||||||
Offset: 0,
|
Offset: 0,
|
||||||
}
|
}
|
||||||
h.sendFormattedResponse(w, tableMetadata, responseMetadata, options)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
h.sendFormattedResponse(w, tableMetadata, responseMetadata, tableName, model, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleMeta processes meta operation requests
|
// handleMeta processes meta operation requests
|
||||||
@@ -348,7 +349,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
// Validate and unwrap model type to get base struct
|
// Validate and unwrap model type to get base struct
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -848,7 +849,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
h.sendFormattedResponse(w, modelPtr, metadata, tableName, model, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||||
@@ -1218,8 +1219,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
query = query.Table(tableName)
|
query = query.Table(tableName)
|
||||||
}
|
}
|
||||||
|
fields := reflection.GetSQLModelColumns(model)
|
||||||
query = query.Returning("*")
|
query = query.Returning(fields...)
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
itemHookCtx := &HookContext{
|
itemHookCtx := &HookContext{
|
||||||
@@ -1480,18 +1481,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the updated record to return the new values
|
_ = result
|
||||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
|
||||||
selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
|
||||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
|
||||||
return fmt.Errorf("failed to fetch updated record: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedRecord = modelValue
|
|
||||||
|
|
||||||
// Store result for hooks
|
|
||||||
hookCtx.Result = updatedRecord
|
|
||||||
_ = result // Keep result variable for potential future use
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1501,6 +1491,16 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch the updated record after the transaction commits to capture any trigger changes
|
||||||
|
fetchedRecord := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
|
selectQuery := h.db.NewSelect().Model(fetchedRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
logger.Error("Failed to fetch updated record: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updatedRecord = fetchedRecord
|
||||||
|
|
||||||
// Merge the updated record with the original request data
|
// Merge the updated record with the original request data
|
||||||
// This preserves extra keys from the request and updates values from the database
|
// This preserves extra keys from the request and updates values from the database
|
||||||
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
||||||
@@ -1892,7 +1892,7 @@ func (h *Handler) extractNestedRelations(
|
|||||||
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
|
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
|
||||||
// Get model type for reflection
|
// Get model type for reflection
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1934,7 +1934,7 @@ func (h *Handler) processChildRelationsWithParentID(
|
|||||||
) error {
|
) error {
|
||||||
// Get model type for reflection
|
// Get model type for reflection
|
||||||
modelType := reflect.TypeOf(parentModel)
|
modelType := reflect.TypeOf(parentModel)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1990,7 +1990,7 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
if relatedModelType.Kind() == reflect.Slice {
|
if relatedModelType.Kind() == reflect.Slice {
|
||||||
relatedModelType = relatedModelType.Elem()
|
relatedModelType = relatedModelType.Elem()
|
||||||
}
|
}
|
||||||
if relatedModelType.Kind() == reflect.Ptr {
|
if relatedModelType.Kind() == reflect.Pointer {
|
||||||
relatedModelType = relatedModelType.Elem()
|
relatedModelType = relatedModelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2012,11 +2012,15 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
// Priority: Use foreign key field name if specified, otherwise use parent's PK name
|
// Priority: Use foreign key field name if specified, otherwise use parent's PK name
|
||||||
var foreignKeyFieldName string
|
var foreignKeyFieldName string
|
||||||
if relInfo.ForeignKey != "" {
|
if relInfo.ForeignKey != "" {
|
||||||
// Get the JSON name for the foreign key field in the child model
|
// For has-many/has-one: join:parentCol=childCol
|
||||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
// ForeignKey = parent side, References = child side (where we actually set the value)
|
||||||
|
childField := relInfo.ForeignKey
|
||||||
|
if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
|
||||||
|
childField = relInfo.References
|
||||||
|
}
|
||||||
|
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
|
||||||
if foreignKeyFieldName == "" {
|
if foreignKeyFieldName == "" {
|
||||||
// Fallback to lowercase field name
|
foreignKeyFieldName = strings.ToLower(childField)
|
||||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Fallback: use parent's primary key name
|
// Fallback: use parent's primary key name
|
||||||
@@ -2040,7 +2044,10 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
// Process based on relation type and data structure
|
// Process based on relation type and data structure
|
||||||
switch v := relationValue.(type) {
|
switch v := relationValue.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
// Single related object - add parent ID to foreign key field
|
if !isValidNestedRequest(v) {
|
||||||
|
logger.Debug("Skipping single relation %s - missing or invalid _request value", relationName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
v[foreignKeyFieldName] = parentID
|
v[foreignKeyFieldName] = parentID
|
||||||
@@ -2057,7 +2064,10 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
// Multiple related objects
|
// Multiple related objects
|
||||||
for i, item := range v {
|
for i, item := range v {
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
// Add parent ID to foreign key field
|
if !isValidNestedRequest(itemMap) {
|
||||||
|
logger.Debug("Skipping relation array[%d] %s - missing or invalid _request value", i, relationName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
itemMap[foreignKeyFieldName] = parentID
|
itemMap[foreignKeyFieldName] = parentID
|
||||||
@@ -2075,7 +2085,10 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
case []map[string]interface{}:
|
case []map[string]interface{}:
|
||||||
// Multiple related objects (typed slice)
|
// Multiple related objects (typed slice)
|
||||||
for i, itemMap := range v {
|
for i, itemMap := range v {
|
||||||
// Add parent ID to foreign key field
|
if !isValidNestedRequest(itemMap) {
|
||||||
|
logger.Debug("Skipping relation typed array[%d] %s - missing or invalid _request value", i, relationName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||||
itemMap[foreignKeyFieldName] = parentID
|
itemMap[foreignKeyFieldName] = parentID
|
||||||
@@ -2096,6 +2109,24 @@ func (h *Handler) processChildRelationsForField(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidNestedRequest returns true only when the item carries a _request key
|
||||||
|
// whose value is one of the recognised mutation verbs.
|
||||||
|
func isValidNestedRequest(item map[string]interface{}) bool {
|
||||||
|
raw, ok := item["_request"]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := raw.(string)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||||
|
case "insert", "add", "change", "update", "delete", "remove":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// getTableNameForRelatedModel gets the table name for a related model.
|
// getTableNameForRelatedModel gets the table name for a related model.
|
||||||
// If the model's TableName() is schema-qualified (e.g. "public.users") the
|
// If the model's TableName() is schema-qualified (e.g. "public.users") the
|
||||||
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
|
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
|
||||||
@@ -2382,7 +2413,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
|
|||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2431,7 +2462,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
|
|||||||
// Check if this is a relation field (slice or struct, but not time.Time)
|
// Check if this is a relation field (slice or struct, but not time.Time)
|
||||||
if field.Type.Kind() == reflect.Slice ||
|
if field.Type.Kind() == reflect.Slice ||
|
||||||
(field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") ||
|
(field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") ||
|
||||||
(field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct && field.Type.Elem().Name() != "Time") {
|
(field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct && field.Type.Elem().Name() != "Time") {
|
||||||
metadata.Relations = append(metadata.Relations, jsonName)
|
metadata.Relations = append(metadata.Relations, jsonName)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2477,7 +2508,7 @@ func (h *Handler) getColumnType(t reflect.Type) string {
|
|||||||
return "float"
|
return "float"
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case reflect.Ptr:
|
case reflect.Pointer:
|
||||||
return h.getColumnType(t.Elem())
|
return h.getColumnType(t.Elem())
|
||||||
default:
|
default:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
@@ -2485,7 +2516,7 @@ func (h *Handler) getColumnType(t reflect.Type) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) isNullable(field reflect.StructField) bool {
|
func (h *Handler) isNullable(field reflect.StructField) bool {
|
||||||
return field.Type.Kind() == reflect.Ptr
|
return field.Type.Kind() == reflect.Pointer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
|
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
|
||||||
@@ -2530,7 +2561,7 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
|||||||
|
|
||||||
// Use reflection to check if data is a slice or array
|
// Use reflection to check if data is a slice or array
|
||||||
dataValue := reflect.ValueOf(data)
|
dataValue := reflect.ValueOf(data)
|
||||||
if dataValue.Kind() == reflect.Ptr {
|
if dataValue.Kind() == reflect.Pointer {
|
||||||
dataValue = dataValue.Elem()
|
dataValue = dataValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2555,8 +2586,103 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendFormattedResponse sends response with formatting options
|
// buildDetailFields returns the field metadata list for the detail API format,
|
||||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
// containing only non-relation scalar columns derived from the model's struct tags.
|
||||||
|
func (h *Handler) buildDetailFields(model interface{}) []reflection.ModelFieldDetail {
|
||||||
|
if model == nil {
|
||||||
|
return []reflection.ModelFieldDetail{}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return []reflection.ModelFieldDetail{}
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := make([]reflection.ModelFieldDetail, 0, modelType.NumField())
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (slices, structs that aren't time.Time, ptrs to struct)
|
||||||
|
ft := field.Type
|
||||||
|
if ft.Kind() == reflect.Pointer {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
if ft.Kind() == reflect.Slice ||
|
||||||
|
(ft.Kind() == reflect.Struct && ft.Name() != "Time") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonName := strings.Split(jsonTag, ",")[0]
|
||||||
|
if jsonName == "" {
|
||||||
|
jsonName = field.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
sqlName := fnFindTagVal(gormTag, "column:")
|
||||||
|
if sqlName == "" {
|
||||||
|
sqlName = jsonName
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDataType := fnFindTagVal(gormTag, "type:")
|
||||||
|
|
||||||
|
var sqlKey string
|
||||||
|
gormLower := strings.ToLower(gormTag)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(gormLower, "identity") || strings.Contains(gormLower, "primary_key") || strings.Contains(gormLower, "primarykey"):
|
||||||
|
sqlKey = "primary_key"
|
||||||
|
case strings.Contains(gormLower, "uniqueindex"):
|
||||||
|
sqlKey = "uniqueindex"
|
||||||
|
case strings.Contains(gormLower, "unique"):
|
||||||
|
sqlKey = "unique"
|
||||||
|
}
|
||||||
|
|
||||||
|
nullable := field.Type.Kind() == reflect.Pointer
|
||||||
|
if strings.Contains(gormLower, "not null") {
|
||||||
|
nullable = false
|
||||||
|
} else if strings.Contains(gormLower, "nullable") || strings.Contains(gormLower, ",null") {
|
||||||
|
nullable = true
|
||||||
|
}
|
||||||
|
|
||||||
|
fields = append(fields, reflection.ModelFieldDetail{
|
||||||
|
Name: jsonName,
|
||||||
|
DataType: h.getColumnType(field.Type),
|
||||||
|
SQLName: sqlName,
|
||||||
|
SQLDataType: sqlDataType,
|
||||||
|
SQLKey: sqlKey,
|
||||||
|
Nullable: nullable,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// fnFindTagVal extracts a value from a semicolon-separated struct tag string.
|
||||||
|
func fnFindTagVal(tag, key string) string {
|
||||||
|
lower := strings.ToLower(tag)
|
||||||
|
idx := strings.Index(lower, strings.ToLower(key))
|
||||||
|
if idx < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
val := tag[idx+len(key):]
|
||||||
|
if end := strings.Index(val, ";"); end >= 0 {
|
||||||
|
val = val[:end]
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendFormattedResponse sends response with formatting options.
|
||||||
|
// model is used when ResponseFormat is "detail" to generate the fields metadata list.
|
||||||
|
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, tableName string, model interface{}, options ExtendedRequestOptions) {
|
||||||
// Handle nil data - convert to empty array
|
// Handle nil data - convert to empty array
|
||||||
if data == nil {
|
if data == nil {
|
||||||
data = []interface{}{}
|
data = []interface{}{}
|
||||||
@@ -2609,8 +2735,29 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
if err := w.WriteJSON(response); err != nil {
|
if err := w.WriteJSON(response); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
|
case "detail":
|
||||||
|
// Detail format: { count, fields, items, tablename, tableprefix, total }
|
||||||
|
var count, total int64
|
||||||
|
if metadata != nil {
|
||||||
|
count = metadata.Count
|
||||||
|
total = metadata.Total
|
||||||
|
}
|
||||||
|
tablePrefix := reflection.ExtractTableNameOnly(tableName)
|
||||||
|
fieldList := h.buildDetailFields(model)
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"count": strconv.FormatInt(count, 10),
|
||||||
|
"fields": fieldList,
|
||||||
|
"items": data,
|
||||||
|
"tablename": tableName,
|
||||||
|
"tableprefix": tablePrefix,
|
||||||
|
"total": strconv.FormatInt(total, 10),
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := w.WriteJSON(response); err != nil {
|
||||||
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// Default/detail format: standard response with metadata
|
// Default format: standard response with metadata
|
||||||
response := common.Response{
|
response := common.Response{
|
||||||
Success: true,
|
Success: true,
|
||||||
Data: data,
|
Data: data,
|
||||||
@@ -2868,7 +3015,7 @@ func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string)
|
|||||||
func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
||||||
// Get the reflect value of the records
|
// Get the reflect value of the records
|
||||||
recordsValue := reflect.ValueOf(records)
|
recordsValue := reflect.ValueOf(records)
|
||||||
if recordsValue.Kind() == reflect.Ptr {
|
if recordsValue.Kind() == reflect.Pointer {
|
||||||
recordsValue = recordsValue.Elem()
|
recordsValue = recordsValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2883,7 +3030,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
|||||||
record := recordsValue.Index(i)
|
record := recordsValue.Index(i)
|
||||||
|
|
||||||
// Dereference if it's a pointer
|
// Dereference if it's a pointer
|
||||||
if record.Kind() == reflect.Ptr {
|
if record.Kind() == reflect.Pointer {
|
||||||
if record.IsNil() {
|
if record.IsNil() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2938,7 +3085,7 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
|||||||
// Filter Expand columns using the expand relation's model
|
// Filter Expand columns using the expand relation's model
|
||||||
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -352,6 +352,45 @@ func (m *mockRegistry) GetAllModels() map[string]interface{} {
|
|||||||
return m.models
|
return m.models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestIsValidNestedRequest verifies that only the allowed _request verbs are accepted
|
||||||
|
// and that items missing the key are rejected.
|
||||||
|
func TestIsValidNestedRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
item map[string]interface{}
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
// Valid verbs
|
||||||
|
{name: "insert", item: map[string]interface{}{"_request": "insert"}, expected: true},
|
||||||
|
{name: "add", item: map[string]interface{}{"_request": "add"}, expected: true},
|
||||||
|
{name: "update", item: map[string]interface{}{"_request": "update"}, expected: true},
|
||||||
|
{name: "change", item: map[string]interface{}{"_request": "change"}, expected: true},
|
||||||
|
{name: "delete", item: map[string]interface{}{"_request": "delete"}, expected: true},
|
||||||
|
{name: "remove", item: map[string]interface{}{"_request": "remove"}, expected: true},
|
||||||
|
// Case-insensitive
|
||||||
|
{name: "INSERT uppercase", item: map[string]interface{}{"_request": "INSERT"}, expected: true},
|
||||||
|
{name: "Remove mixed case", item: map[string]interface{}{"_request": "Remove"}, expected: true},
|
||||||
|
// Whitespace trimmed
|
||||||
|
{name: "insert with spaces", item: map[string]interface{}{"_request": " insert "}, expected: true},
|
||||||
|
// Invalid / missing
|
||||||
|
{name: "missing _request", item: map[string]interface{}{"name": "foo"}, expected: false},
|
||||||
|
{name: "empty string", item: map[string]interface{}{"_request": ""}, expected: false},
|
||||||
|
{name: "unknown verb", item: map[string]interface{}{"_request": "create"}, expected: false},
|
||||||
|
{name: "unknown verb modify", item: map[string]interface{}{"_request": "modify"}, expected: false},
|
||||||
|
{name: "non-string value", item: map[string]interface{}{"_request": 42}, expected: false},
|
||||||
|
{name: "nil value", item: map[string]interface{}{"_request": nil}, expected: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isValidNestedRequest(tt.item)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("isValidNestedRequest(%v) = %v, want %v", tt.item, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
|
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
|
||||||
func TestMultiLevelRelationExtraction(t *testing.T) {
|
func TestMultiLevelRelationExtraction(t *testing.T) {
|
||||||
registry := &mockRegistry{
|
registry := &mockRegistry{
|
||||||
|
|||||||
+20
-13
@@ -6,9 +6,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -102,11 +102,6 @@ func DecodeParam(pStr string) (string, error) {
|
|||||||
|
|
||||||
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
|
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
|
||||||
code, _ = DecodeParam(code)
|
code, _ = DecodeParam(code)
|
||||||
} else {
|
|
||||||
strDat, err := base64.StdEncoding.DecodeString(code)
|
|
||||||
if err == nil && utf8.Valid(strDat) {
|
|
||||||
code = string(strDat)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return code, nil
|
return code, nil
|
||||||
@@ -146,9 +141,21 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
combinedParams[strings.ToLower(key)] = value
|
combinedParams[strings.ToLower(key)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sortedKeys := make([]string, 0, len(combinedParams))
|
||||||
|
for key := range combinedParams {
|
||||||
|
sortedKeys = append(sortedKeys, key)
|
||||||
|
}
|
||||||
|
sort.Slice(sortedKeys, func(i, j int) bool {
|
||||||
|
if sortedKeys[i] != sortedKeys[j] {
|
||||||
|
return sortedKeys[i] < sortedKeys[j]
|
||||||
|
}
|
||||||
|
return combinedParams[sortedKeys[i]] < combinedParams[sortedKeys[j]]
|
||||||
|
})
|
||||||
|
|
||||||
// Process each parameter (from both headers and query params)
|
// Process each parameter (from both headers and query params)
|
||||||
// Note: keys are already normalized to lowercase in combinedParams
|
// Note: keys are already normalized to lowercase in combinedParams
|
||||||
for key, value := range combinedParams {
|
for _, key := range sortedKeys {
|
||||||
|
value := combinedParams[key]
|
||||||
// Decode value if it's base64 encoded
|
// Decode value if it's base64 encoded
|
||||||
decodedValue := decodeHeaderValue(value)
|
decodedValue := decodeHeaderValue(value)
|
||||||
|
|
||||||
@@ -970,7 +977,7 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dereference pointer if needed
|
// Dereference pointer if needed
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1005,13 +1012,13 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
|||||||
var targetType reflect.Type
|
var targetType reflect.Type
|
||||||
if fieldType.Kind() == reflect.Slice {
|
if fieldType.Kind() == reflect.Slice {
|
||||||
targetType = fieldType.Elem()
|
targetType = fieldType.Elem()
|
||||||
} else if fieldType.Kind() == reflect.Ptr {
|
} else if fieldType.Kind() == reflect.Pointer {
|
||||||
targetType = fieldType.Elem()
|
targetType = fieldType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetType != nil {
|
if targetType != nil {
|
||||||
// Dereference pointer if the slice contains pointers
|
// Dereference pointer if the slice contains pointers
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Pointer {
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1055,7 +1062,7 @@ func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable
|
|||||||
if modelType == nil {
|
if modelType == nil {
|
||||||
return nameOrTable
|
return nameOrTable
|
||||||
}
|
}
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
@@ -1082,10 +1089,10 @@ func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable
|
|||||||
var targetType reflect.Type
|
var targetType reflect.Type
|
||||||
if fieldType.Kind() == reflect.Slice {
|
if fieldType.Kind() == reflect.Slice {
|
||||||
targetType = fieldType.Elem()
|
targetType = fieldType.Elem()
|
||||||
} else if fieldType.Kind() == reflect.Ptr {
|
} else if fieldType.Kind() == reflect.Pointer {
|
||||||
targetType = fieldType.Elem()
|
targetType = fieldType.Elem()
|
||||||
}
|
}
|
||||||
if targetType != nil && targetType.Kind() == reflect.Ptr {
|
if targetType != nil && targetType.Kind() == reflect.Pointer {
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
}
|
}
|
||||||
if targetType == nil || targetType.Kind() != reflect.Struct {
|
if targetType == nil || targetType.Kind() != reflect.Struct {
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ func applyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error
|
|||||||
|
|
||||||
// Get primary key name from model
|
// Get primary key name from model
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,13 +155,13 @@ func applyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
|
|||||||
|
|
||||||
// Get model type
|
// Get model type
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
if modelType.Kind() == reflect.Pointer {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply column security masking
|
// Apply column security masking
|
||||||
resultValue := reflect.ValueOf(result)
|
resultValue := reflect.ValueOf(result)
|
||||||
if resultValue.Kind() == reflect.Ptr {
|
if resultValue.Kind() == reflect.Pointer {
|
||||||
resultValue = resultValue.Elem()
|
resultValue = resultValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,25 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Try to open the file
|
// Try to open the file
|
||||||
file, err := m.provider.Open(strings.TrimPrefix(filePath, "/"))
|
file, err := m.provider.Open(strings.TrimPrefix(filePath, "/"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// For extensionless paths, also try path/index.html
|
||||||
|
if path.Ext(filePath) == "" {
|
||||||
|
indexFallback := path.Join(filePath, "index.html")
|
||||||
|
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
|
||||||
|
if err == nil {
|
||||||
|
defer file.Close()
|
||||||
|
m.serveFile(w, r, indexFallback, file)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
indexFallback = fmt.Sprintf("%s.html", filePath)
|
||||||
|
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
|
||||||
|
if err == nil {
|
||||||
|
defer file.Close()
|
||||||
|
m.serveFile(w, r, indexFallback, file)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// File doesn't exist - check if we should use fallback
|
// File doesn't exist - check if we should use fallback
|
||||||
if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) {
|
if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) {
|
||||||
fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath)
|
fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath)
|
||||||
@@ -80,16 +99,6 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For extensionless paths, also try path/index.html
|
|
||||||
if path.Ext(filePath) == "" {
|
|
||||||
indexFallback := path.Join(filePath, "index.html")
|
|
||||||
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
|
|
||||||
if err == nil {
|
|
||||||
defer file.Close()
|
|
||||||
m.serveFile(w, r, indexFallback, file)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// No fallback or fallback failed - return 404
|
// No fallback or fallback failed - return 404
|
||||||
|
|||||||
@@ -671,6 +671,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
|||||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Re-fetch the created record to capture DB-generated defaults/triggers.
|
||||||
|
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
|
||||||
|
hookCtx.ID = fmt.Sprintf("%v", pkVal)
|
||||||
|
return h.readByID(hookCtx)
|
||||||
|
}
|
||||||
|
|
||||||
return hookCtx.ModelPtr, nil
|
return hookCtx.ModelPtr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -834,7 +840,7 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
|||||||
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||||
// Get the reflect value of the records
|
// Get the reflect value of the records
|
||||||
recordsValue := reflect.ValueOf(records)
|
recordsValue := reflect.ValueOf(records)
|
||||||
if recordsValue.Kind() == reflect.Ptr {
|
if recordsValue.Kind() == reflect.Pointer {
|
||||||
recordsValue = recordsValue.Elem()
|
recordsValue = recordsValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -849,7 +855,7 @@ func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
|||||||
record := recordsValue.Index(i)
|
record := recordsValue.Index(i)
|
||||||
|
|
||||||
// Dereference if it's a pointer
|
// Dereference if it's a pointer
|
||||||
if record.Kind() == reflect.Ptr {
|
if record.Kind() == reflect.Pointer {
|
||||||
if record.IsNil() {
|
if record.IsNil() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user