mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-07-02 09:27:39 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3dac55cb19 | |||
| bbb2c6d127 | |||
| 3fec7b1a90 | |||
| 910390f62d | |||
| b9bed67bd7 | |||
| 11ef16f75a | |||
| 48b72a7631 | |||
| 4c512acf25 | |||
| 07a402634e | |||
| 0e8f8925c6 | |||
| 5a359a160b | |||
| a2799fa224 | |||
| 1419542650 | |||
| c120b49529 |
@@ -39,7 +39,7 @@ func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent)
|
||||
// This helps identify which specific field is causing scanning issues
|
||||
func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
if v.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
logger.Debug(" Slice element type: %s", elemType)
|
||||
|
||||
// If slice of pointers, get the underlying type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
if elemType.Kind() == reflect.Pointer {
|
||||
structType = elemType.Elem()
|
||||
} else {
|
||||
structType = elemType
|
||||
@@ -747,7 +747,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
||||
|
||||
// Get the first parent to check the relation field
|
||||
firstParent := parents.Index(0)
|
||||
if firstParent.Kind() == reflect.Ptr {
|
||||
if firstParent.Kind() == reflect.Pointer {
|
||||
firstParent = firstParent.Elem()
|
||||
}
|
||||
|
||||
@@ -762,7 +762,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
||||
// Check if any parent has a non-empty slice
|
||||
for i := 0; i < parents.Len(); i++ {
|
||||
parent := parents.Index(i)
|
||||
if parent.Kind() == reflect.Ptr {
|
||||
if parent.Kind() == reflect.Pointer {
|
||||
parent = parent.Elem()
|
||||
}
|
||||
field := parent.FieldByName(relationName)
|
||||
@@ -771,7 +771,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
||||
allRelated := reflect.MakeSlice(field.Type(), 0, field.Len()*parents.Len())
|
||||
for j := 0; j < parents.Len(); j++ {
|
||||
p := parents.Index(j)
|
||||
if p.Kind() == reflect.Ptr {
|
||||
if p.Kind() == reflect.Pointer {
|
||||
p = p.Elem()
|
||||
}
|
||||
f := p.FieldByName(relationName)
|
||||
@@ -784,7 +784,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
||||
return allRelated, true
|
||||
}
|
||||
}
|
||||
} else if relationField.Kind() == reflect.Ptr {
|
||||
} else if relationField.Kind() == reflect.Pointer {
|
||||
// Check if it's a pointer (has-one/belongs-to)
|
||||
if !relationField.IsNil() {
|
||||
// Already loaded! Collect all related records from all parents
|
||||
@@ -792,7 +792,7 @@ func checkIfRelationAlreadyLoaded(parents reflect.Value, relationName string) (r
|
||||
allRelated := reflect.MakeSlice(reflect.SliceOf(relatedType), 0, parents.Len())
|
||||
for j := 0; j < parents.Len(); j++ {
|
||||
p := parents.Index(j)
|
||||
if p.Kind() == reflect.Ptr {
|
||||
if p.Kind() == reflect.Pointer {
|
||||
p = p.Elem()
|
||||
}
|
||||
f := p.FieldByName(relationName)
|
||||
@@ -816,7 +816,7 @@ func (b *BunSelectQuery) loadCustomPreloads(ctx context.Context) error {
|
||||
|
||||
// Get the actual data from the model
|
||||
modelValue := reflect.ValueOf(model.Value())
|
||||
if modelValue.Kind() == reflect.Ptr {
|
||||
if modelValue.Kind() == reflect.Pointer {
|
||||
modelValue = modelValue.Elem()
|
||||
}
|
||||
|
||||
@@ -884,7 +884,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
|
||||
|
||||
// Get the first record to inspect the struct type
|
||||
firstRecord := parentRecords.Index(0)
|
||||
if firstRecord.Kind() == reflect.Ptr {
|
||||
if firstRecord.Kind() == reflect.Pointer {
|
||||
firstRecord = firstRecord.Elem()
|
||||
}
|
||||
|
||||
@@ -930,7 +930,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
|
||||
if isSlice {
|
||||
relatedType = relatedType.Elem()
|
||||
}
|
||||
if relatedType.Kind() == reflect.Ptr {
|
||||
if relatedType.Kind() == reflect.Pointer {
|
||||
relatedType = relatedType.Elem()
|
||||
}
|
||||
|
||||
@@ -1018,7 +1018,7 @@ func extractForeignKeyValues(records reflect.Value, fkFieldName string) ([]inter
|
||||
|
||||
for i := 0; i < records.Len(); i++ {
|
||||
record := records.Index(i)
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.Kind() == reflect.Pointer {
|
||||
record = record.Elem()
|
||||
}
|
||||
|
||||
@@ -1083,7 +1083,7 @@ func associateRelatedRecords(parents, related reflect.Value, fieldName string, r
|
||||
for i := 0; i < related.Len(); i++ {
|
||||
relRecord := related.Index(i)
|
||||
relRecordElem := relRecord
|
||||
if relRecordElem.Kind() == reflect.Ptr {
|
||||
if relRecordElem.Kind() == reflect.Pointer {
|
||||
relRecordElem = relRecordElem.Elem()
|
||||
}
|
||||
|
||||
@@ -1109,7 +1109,7 @@ func associateRelatedRecords(parents, related reflect.Value, fieldName string, r
|
||||
for i := 0; i < parents.Len(); i++ {
|
||||
parentPtr := parents.Index(i)
|
||||
parent := parentPtr
|
||||
if parent.Kind() == reflect.Ptr {
|
||||
if parent.Kind() == reflect.Pointer {
|
||||
parent = parent.Elem()
|
||||
}
|
||||
|
||||
@@ -1332,11 +1332,11 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
||||
|
||||
v := reflect.ValueOf(modelValue)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.Kind() == reflect.Pointer {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Slice {
|
||||
if v.Type().Elem().Kind() == reflect.Ptr {
|
||||
if v.Type().Elem().Kind() == reflect.Pointer {
|
||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
|
||||
} else {
|
||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
|
||||
|
||||
@@ -800,7 +800,7 @@ func (g *GormInsertQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
col := g.returningColumns[0]
|
||||
if g.model != nil {
|
||||
val := reflect.ValueOf(g.model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
if val.Kind() == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
if val.Kind() == reflect.Struct {
|
||||
|
||||
@@ -1195,7 +1195,7 @@ func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest inter
|
||||
|
||||
// Use reflection to process the destination
|
||||
destValue := reflect.ValueOf(dest)
|
||||
if destValue.Kind() != reflect.Ptr {
|
||||
if destValue.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
@@ -1222,7 +1222,7 @@ func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest inter
|
||||
|
||||
// loadPreloadsForRecord loads all preload relationships for a single record
|
||||
func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error {
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.Kind() == reflect.Pointer {
|
||||
if record.IsNil() {
|
||||
return nil
|
||||
}
|
||||
@@ -1299,7 +1299,7 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
|
||||
} else {
|
||||
// Single struct - create a pointer if needed
|
||||
var target reflect.Value
|
||||
if field.Kind() == reflect.Ptr {
|
||||
if field.Kind() == reflect.Pointer {
|
||||
target = reflect.New(field.Type().Elem())
|
||||
} else {
|
||||
target = reflect.New(field.Type())
|
||||
@@ -1312,7 +1312,7 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
|
||||
}
|
||||
|
||||
// Set the field
|
||||
if field.Kind() == reflect.Ptr {
|
||||
if field.Kind() == reflect.Pointer {
|
||||
field.Set(target)
|
||||
} else {
|
||||
field.Set(target.Elem())
|
||||
@@ -1329,7 +1329,7 @@ func (p *PgSQLSelectQuery) getRelationMetadata(fieldName string) *relationMetada
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(p.model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -1378,7 +1378,7 @@ func (p *PgSQLSelectQuery) getRelationMetadataFromField(modelType reflect.Type,
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
@@ -1411,7 +1411,7 @@ func scanRows(rows *sql.Rows, dest interface{}) error {
|
||||
|
||||
// Get destination type
|
||||
destValue := reflect.ValueOf(dest)
|
||||
if destValue.Kind() != reflect.Ptr {
|
||||
if destValue.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
@@ -1466,7 +1466,7 @@ func scanRowsToMapSlice(rows *sql.Rows, columns []string, destValue reflect.Valu
|
||||
// scanRowsToStructSlice scans rows into a slice of structs
|
||||
func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
|
||||
elemType := destValue.Type().Elem()
|
||||
isPtr := elemType.Kind() == reflect.Ptr
|
||||
isPtr := elemType.Kind() == reflect.Pointer
|
||||
|
||||
if isPtr {
|
||||
elemType = elemType.Elem()
|
||||
|
||||
@@ -71,7 +71,7 @@ func entityNameFromModel(model interface{}, table string) string {
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bo
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
|
||||
@@ -174,7 +174,9 @@ func (h *HTTPResponseWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
|
||||
h.SetHeader("Content-Type", "application/json")
|
||||
return json.NewEncoder(h.resp).Encode(data)
|
||||
enc := json.NewEncoder(h.resp)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(data)
|
||||
}
|
||||
|
||||
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
|
||||
|
||||
@@ -25,7 +25,7 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e
|
||||
originalType := modelType
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -126,15 +126,15 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
|
||||
// Get related model type
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
if elemType.Kind() == reflect.Pointer {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||
} else if field.Type.Kind() == reflect.Pointer || field.Type.Kind() == reflect.Struct {
|
||||
elemType := field.Type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
if elemType.Kind() == reflect.Pointer {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
@@ -155,16 +155,16 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
|
||||
info.RelationType = "hasMany"
|
||||
// Get the element type for slice
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
if elemType.Kind() == reflect.Pointer {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
info.RelatedModel = reflect.New(elemType).Elem().Interface()
|
||||
}
|
||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||
} else if field.Type.Kind() == reflect.Pointer || field.Type.Kind() == reflect.Struct {
|
||||
info.RelationType = "belongsTo"
|
||||
elemType := field.Type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
if elemType.Kind() == reflect.Pointer {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
@@ -177,7 +177,7 @@ func GetRelationshipInfo(modelType reflect.Type, relationName string) *Relations
|
||||
// Get the element type for many2many (always slice)
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elemType := field.Type.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
if elemType.Kind() == reflect.Pointer {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
@@ -239,7 +239,7 @@ func GetTableNameFromModel(model interface{}) string {
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers
|
||||
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
|
||||
@@ -178,7 +178,9 @@ func (s *StandardResponseWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
||||
s.SetHeader("Content-Type", "application/json")
|
||||
return json.NewEncoder(s.w).Encode(data)
|
||||
enc := json.NewEncoder(s.w)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(data)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||
|
||||
@@ -69,7 +69,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
|
||||
// Get model type for reflection
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -224,7 +224,7 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -410,7 +410,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
if relatedModelType.Kind() == reflect.Slice {
|
||||
relatedModelType = relatedModelType.Elem()
|
||||
}
|
||||
if relatedModelType.Kind() == reflect.Ptr {
|
||||
if relatedModelType.Kind() == reflect.Pointer {
|
||||
relatedModelType = relatedModelType.Elem()
|
||||
}
|
||||
|
||||
@@ -590,7 +590,7 @@ func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{
|
||||
|
||||
// Get model type
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
|
||||
+52
-18
@@ -446,18 +446,36 @@ func containsTopLevelOR(clause string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive).
|
||||
// It is parenthesis-aware (won't split inside subqueries), quote-aware
|
||||
// (won't split on AND inside single-quoted strings), and BETWEEN-aware
|
||||
// (won't split on the AND that separates the two operands of BETWEEN x AND y).
|
||||
func splitByAND(where string) []string {
|
||||
conditions := []string{}
|
||||
currentCondition := strings.Builder{}
|
||||
depth := 0 // Track parenthesis depth
|
||||
depth := 0 // parenthesis nesting depth
|
||||
inSingleQuote := false
|
||||
afterBetween := false // true after seeing BETWEEN at depth 0; next AND belongs to it
|
||||
i := 0
|
||||
|
||||
for i < len(where) {
|
||||
ch := where[i]
|
||||
|
||||
// Track parenthesis depth
|
||||
// Track single-quote state so we never split on AND inside string literals.
|
||||
if ch == '\'' {
|
||||
inSingleQuote = !inSingleQuote
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if inSingleQuote {
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Track parenthesis depth (outside quotes only).
|
||||
if ch == '(' {
|
||||
depth++
|
||||
currentCondition.WriteByte(ch)
|
||||
@@ -470,32 +488,39 @@ func splitByAND(where string) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only look for AND operators at depth 0 (not inside parentheses)
|
||||
// All keyword checks only apply at depth 0 (not inside subqueries).
|
||||
if depth == 0 {
|
||||
// Check if we're at an AND operator (case-insensitive)
|
||||
// We need at least " AND " (5 chars) or " and " (5 chars)
|
||||
if i+5 <= len(where) {
|
||||
substring := where[i : i+5]
|
||||
lowerSubstring := strings.ToLower(substring)
|
||||
// Detect " BETWEEN " (9 chars, case-insensitive) so the very next
|
||||
// top-level AND is recognised as part of the BETWEEN syntax.
|
||||
if i+9 <= len(where) && strings.ToLower(where[i:i+9]) == " between " {
|
||||
afterBetween = true
|
||||
currentCondition.WriteString(where[i : i+9])
|
||||
i += 9
|
||||
continue
|
||||
}
|
||||
|
||||
if lowerSubstring == " and " {
|
||||
// Found an AND operator at the top level
|
||||
// Add the current condition to the list
|
||||
conditions = append(conditions, currentCondition.String())
|
||||
currentCondition.Reset()
|
||||
// Skip past the AND operator
|
||||
// Detect " AND " (5 chars, case-insensitive).
|
||||
if i+5 <= len(where) && strings.ToLower(where[i:i+5]) == " and " {
|
||||
if afterBetween {
|
||||
// This AND closes a BETWEEN expression — do NOT split.
|
||||
afterBetween = false
|
||||
currentCondition.WriteString(where[i : i+5])
|
||||
i += 5
|
||||
continue
|
||||
}
|
||||
// Regular conjunction — split here.
|
||||
conditions = append(conditions, currentCondition.String())
|
||||
currentCondition.Reset()
|
||||
i += 5
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Not an AND operator or we're inside parentheses, just add the character
|
||||
currentCondition.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
// Add the last condition
|
||||
// Add the last condition.
|
||||
if currentCondition.Len() > 0 {
|
||||
conditions = append(conditions, currentCondition.String())
|
||||
}
|
||||
@@ -614,6 +639,15 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// If the left side is a parenthesized subquery (starts with '(' and contains SQL keywords),
|
||||
// don't attempt prefix extraction from inside it.
|
||||
if len(columnRef) > 0 && columnRef[0] == '(' {
|
||||
lowerRef := strings.ToLower(columnRef)
|
||||
if strings.Contains(lowerRef, "select ") || strings.Contains(lowerRef, " from ") || strings.Contains(lowerRef, " where ") {
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there's a function call (contains opening parenthesis)
|
||||
openParenIdx := strings.Index(columnRef, "(")
|
||||
|
||||
|
||||
@@ -520,6 +520,38 @@ func TestSplitByAND(t *testing.T) {
|
||||
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
|
||||
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
|
||||
},
|
||||
// BETWEEN-aware cases: the AND inside BETWEEN x AND y must not cause a split.
|
||||
{
|
||||
name: "BETWEEN does not split on its AND",
|
||||
input: "col between '2025-08-31' and '1970-01-01'",
|
||||
expected: []string{"col between '2025-08-31' and '1970-01-01'"},
|
||||
},
|
||||
{
|
||||
name: "BETWEEN uppercase AND",
|
||||
input: "col BETWEEN '2025-08-31' AND '1970-01-01'",
|
||||
expected: []string{"col BETWEEN '2025-08-31' AND '1970-01-01'"},
|
||||
},
|
||||
{
|
||||
name: "BETWEEN followed by a regular AND conjunction",
|
||||
input: "col between 1 and 5 and other = 'x'",
|
||||
expected: []string{"col between 1 and 5", "other = 'x'"},
|
||||
},
|
||||
{
|
||||
name: "two BETWEEN conditions joined by AND",
|
||||
input: "col1 between 1 and 5 and col2 between 10 and 20",
|
||||
expected: []string{"col1 between 1 and 5", "col2 between 10 and 20"},
|
||||
},
|
||||
{
|
||||
name: "complex OR block with multiple BETWEENs (real-world case)",
|
||||
input: "tbl.applicationdate between '2025-08-31' and '1970-01-01'\n or tbl.capturedate between '2025-08-31' and '1970-01-01'\n or tbl.startdate between '2025-08-31' AND '1970-01-01'",
|
||||
expected: []string{"tbl.applicationdate between '2025-08-31' and '1970-01-01'\n or tbl.capturedate between '2025-08-31' and '1970-01-01'\n or tbl.startdate between '2025-08-31' AND '1970-01-01'"},
|
||||
},
|
||||
// Quote-aware cases: AND inside a string literal must not split.
|
||||
{
|
||||
name: "AND inside single-quoted string is not a split point",
|
||||
input: "comment = 'this and that' and status = 'active'",
|
||||
expected: []string{"comment = 'this and that'", "status = 'active'"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -917,6 +949,25 @@ where: "(true AND status = 'active')",
|
||||
tableName: "unregistered_table",
|
||||
expected: "(true AND unregistered_table.status = 'active')",
|
||||
},
|
||||
// BETWEEN regression: date literals inside BETWEEN must not be prefixed as columns.
|
||||
{
|
||||
name: "BETWEEN date range - second date must not be prefixed",
|
||||
where: "applicationdate between '2025-08-31' and '1970-01-01'",
|
||||
tableName: "unregistered_table",
|
||||
expected: "unregistered_table.applicationdate between '2025-08-31' and '1970-01-01'",
|
||||
},
|
||||
{
|
||||
name: "Already-prefixed BETWEEN column - unchanged",
|
||||
where: `"v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01'`,
|
||||
tableName: "v_webui_clients",
|
||||
expected: `"v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01'`,
|
||||
},
|
||||
{
|
||||
name: "Complex OR block with multiple BETWEENs - date values must not be prefixed",
|
||||
where: `("v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01' or "v_webui_clients".clientcapturedate between '2025-08-31' and '1970-01-01' or "v_webui_clients".startdate between '2025-08-31' AND '1970-01-01')`,
|
||||
tableName: "v_webui_clients",
|
||||
expected: `("v_webui_clients".applicationdate between '2025-08-31' and '1970-01-01' or "v_webui_clients".clientcapturedate between '2025-08-31' and '1970-01-01' or "v_webui_clients".startdate between '2025-08-31' AND '1970-01-01')`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -31,7 +31,7 @@ func (v *ColumnValidator) buildValidColumns() {
|
||||
modelType := reflect.TypeOf(v.model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -290,7 +290,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
// Filter Preload columns
|
||||
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||
modelType := reflect.TypeOf(v.model)
|
||||
if modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||
if modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
for idx := range options.Preload {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
@@ -367,13 +368,17 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
||||
}
|
||||
|
||||
case "detail":
|
||||
// Detail format: complex API with metadata
|
||||
// Detail format: { count, fields, items, tablename, tableprefix, total }
|
||||
tableName := r.URL.Path
|
||||
tablePrefix := reflection.ExtractTableNameOnly(tableName)
|
||||
fields := buildDetailFieldsFromRows(dbobjlist)
|
||||
metaobj := map[string]interface{}{
|
||||
"items": dbobjlist,
|
||||
"count": fmt.Sprintf("%d", len(dbobjlist)),
|
||||
"fields": fields,
|
||||
"items": dbobjlist,
|
||||
"tablename": tableName,
|
||||
"tableprefix": tablePrefix,
|
||||
"total": fmt.Sprintf("%d", total),
|
||||
"tablename": r.URL.Path,
|
||||
"tableprefix": "gsql",
|
||||
}
|
||||
data, err := json.Marshal(metaobj)
|
||||
if err != nil {
|
||||
@@ -1079,6 +1084,49 @@ func getReplacementForBlankParam(sqlquery, param string) string {
|
||||
// return result
|
||||
// }
|
||||
|
||||
// buildDetailFieldsFromRows builds a field metadata list from the column names and value types
|
||||
// of a raw SQL result set. Used when no model struct is available (funcspec raw queries).
|
||||
func buildDetailFieldsFromRows(rows []map[string]interface{}) []reflection.ModelFieldDetail {
|
||||
if len(rows) == 0 {
|
||||
return []reflection.ModelFieldDetail{}
|
||||
}
|
||||
first := rows[0]
|
||||
fields := make([]reflection.ModelFieldDetail, 0, len(first))
|
||||
for colName, val := range first {
|
||||
dataType := inferGoType(val)
|
||||
fields = append(fields, reflection.ModelFieldDetail{
|
||||
Name: colName,
|
||||
DataType: dataType,
|
||||
SQLName: colName,
|
||||
SQLDataType: "",
|
||||
SQLKey: "",
|
||||
Nullable: val == nil,
|
||||
})
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// inferGoType returns a simple type name for a value, used for detail field metadata.
|
||||
func inferGoType(val interface{}) string {
|
||||
if val == nil {
|
||||
return "interface{}"
|
||||
}
|
||||
switch val.(type) {
|
||||
case bool:
|
||||
return "bool"
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return "int64"
|
||||
case float32, float64:
|
||||
return "float64"
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
default:
|
||||
return "interface{}"
|
||||
}
|
||||
}
|
||||
|
||||
// getIPAddress extracts the real IP address from the request
|
||||
func getIPAddress(r *http.Request) string {
|
||||
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
|
||||
|
||||
@@ -617,6 +617,91 @@ func TestSqlQueryList(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "x-detailapi header returns detail format",
|
||||
sqlQuery: "SELECT * FROM myschema.myentity",
|
||||
noCount: false,
|
||||
blankParams: false,
|
||||
allowFilter: false,
|
||||
headers: map[string]string{"x-detailapi": "true"},
|
||||
setupDB: func() *MockDatabase {
|
||||
return &MockDatabase{
|
||||
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||
db := &MockDatabase{
|
||||
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
if strings.Contains(query, "COUNT") {
|
||||
dest.(*struct{ Count int64 }).Count = 3
|
||||
return nil
|
||||
}
|
||||
*dest.(*[]map[string]interface{}) = []map[string]interface{}{
|
||||
{"id": float64(1), "name": "Alice"},
|
||||
{"id": float64(2), "name": "Bob"},
|
||||
{"id": float64(3), "name": "Carol"},
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return fn(db)
|
||||
},
|
||||
}
|
||||
},
|
||||
expectedStatus: 200,
|
||||
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
|
||||
var resp map[string]json.RawMessage
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("expected JSON object, got: %s", w.Body.String())
|
||||
}
|
||||
|
||||
for _, key := range []string{"count", "fields", "items", "tablename", "tableprefix", "total"} {
|
||||
if _, ok := resp[key]; !ok {
|
||||
t.Errorf("missing key %q in detail response", key)
|
||||
}
|
||||
}
|
||||
|
||||
var count, total string
|
||||
json.Unmarshal(resp["count"], &count)
|
||||
json.Unmarshal(resp["total"], &total)
|
||||
if count != "3" {
|
||||
t.Errorf("expected count %q, got %q", "3", count)
|
||||
}
|
||||
if total != "3" {
|
||||
t.Errorf("expected total %q, got %q", "3", total)
|
||||
}
|
||||
|
||||
var items []map[string]interface{}
|
||||
if err := json.Unmarshal(resp["items"], &items); err != nil {
|
||||
t.Fatalf("items is not an array: %v", err)
|
||||
}
|
||||
if len(items) != 3 {
|
||||
t.Errorf("expected 3 items, got %d", len(items))
|
||||
}
|
||||
|
||||
var fields []map[string]interface{}
|
||||
if err := json.Unmarshal(resp["fields"], &fields); err != nil {
|
||||
t.Fatalf("fields is not an array: %v", err)
|
||||
}
|
||||
if len(fields) == 0 {
|
||||
t.Error("expected non-empty fields list")
|
||||
}
|
||||
for _, f := range fields {
|
||||
for _, key := range []string{"name", "datatype", "sqlname", "sqldatatype", "sqlkey", "nullable"} {
|
||||
if _, ok := f[key]; !ok {
|
||||
t.Errorf("field %v missing key %q", f, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var tablename, tableprefix string
|
||||
json.Unmarshal(resp["tablename"], &tablename)
|
||||
json.Unmarshal(resp["tableprefix"], &tableprefix)
|
||||
if tablename == "" {
|
||||
t.Error("expected non-empty tablename")
|
||||
}
|
||||
if tableprefix == "" {
|
||||
t.Error("expected non-empty tableprefix")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "List query with noCount",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
|
||||
@@ -107,7 +107,7 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
||||
originalType := modelType
|
||||
|
||||
// Unwrap pointers, slices, and arrays to check the underlying type
|
||||
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||
for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
||||
|
||||
// Additional check: ensure model is not a pointer
|
||||
finalType := reflect.TypeOf(model)
|
||||
if finalType.Kind() == reflect.Ptr {
|
||||
if finalType.Kind() == reflect.Pointer {
|
||||
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
|
||||
}
|
||||
|
||||
|
||||
@@ -781,6 +781,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
// Re-fetch the created record to capture DB-generated defaults/triggers.
|
||||
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
|
||||
hookCtx.ID = fmt.Sprintf("%v", pkVal)
|
||||
return h.readByID(hookCtx)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -387,7 +387,7 @@ func (g *Generator) generateModelSchema(model interface{}) Schema {
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
@@ -418,7 +418,7 @@ func (g *Generator) generateModelSchema(model interface{}) Schema {
|
||||
schema.Properties[fieldName] = propSchema
|
||||
|
||||
// Check if field is required (not a pointer and no omitempty)
|
||||
if field.Type.Kind() != reflect.Ptr && !strings.Contains(jsonTag, "omitempty") {
|
||||
if field.Type.Kind() != reflect.Pointer && !strings.Contains(jsonTag, "omitempty") {
|
||||
schema.Required = append(schema.Required, fieldName)
|
||||
}
|
||||
}
|
||||
@@ -431,7 +431,7 @@ func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
|
||||
schema := &Schema{}
|
||||
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
@@ -453,7 +453,7 @@ func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
|
||||
case reflect.Slice, reflect.Array:
|
||||
schema.Type = "array"
|
||||
elemType := fieldType.Elem()
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
if elemType.Kind() == reflect.Pointer {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
|
||||
@@ -9,7 +9,7 @@ func Len(v any) int {
|
||||
val := reflect.ValueOf(v)
|
||||
valKind := val.Kind()
|
||||
|
||||
if valKind == reflect.Ptr {
|
||||
if valKind == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func IsEmptyValue(v any) bool {
|
||||
return true
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Ptr {
|
||||
if rv.Kind() == reflect.Pointer {
|
||||
if rv.IsNil() {
|
||||
return true
|
||||
}
|
||||
@@ -80,12 +80,12 @@ func IsEmptyValue(v any) bool {
|
||||
// If the type is a slice of pointers, it returns the element type of the pointer within the slice.
|
||||
// If neither condition is met, it returns the original type.
|
||||
func GetPointerElement(v reflect.Type) reflect.Type {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.Kind() == reflect.Pointer {
|
||||
return v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr {
|
||||
if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Pointer {
|
||||
subElem := v.Elem()
|
||||
if subElem.Elem().Kind() == reflect.Ptr {
|
||||
if subElem.Elem().Kind() == reflect.Pointer {
|
||||
return subElem.Elem().Elem()
|
||||
}
|
||||
return v.Elem()
|
||||
@@ -104,7 +104,7 @@ func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
|
||||
// Unwrap pointer and slice indirections to reach the struct type
|
||||
for {
|
||||
switch modelType.Kind() {
|
||||
case reflect.Ptr, reflect.Slice:
|
||||
case reflect.Pointer, reflect.Slice:
|
||||
modelType = modelType.Elem()
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly
|
||||
// Handle embedded structs
|
||||
if field.Anonymous {
|
||||
ft := field.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
if ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
isScanOnly := scanOnly
|
||||
@@ -544,7 +544,7 @@ func IsColumnWritable(model any, columnName string) bool {
|
||||
// Unwrap pointers and slices to get to the base struct type
|
||||
for modelType != nil {
|
||||
switch modelType.Kind() {
|
||||
case reflect.Ptr, reflect.Slice:
|
||||
case reflect.Pointer, reflect.Slice:
|
||||
modelType = modelType.Elem()
|
||||
continue
|
||||
}
|
||||
@@ -709,7 +709,7 @@ func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
// Dereference pointer if needed
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -886,7 +886,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
||||
// Unwrap pointer → slice → pointer chains to reach the underlying struct
|
||||
for {
|
||||
switch modelType.Kind() {
|
||||
case reflect.Ptr, reflect.Slice:
|
||||
case reflect.Pointer, reflect.Slice:
|
||||
modelType = modelType.Elem()
|
||||
continue
|
||||
}
|
||||
@@ -947,7 +947,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
||||
// Slice indicates has-many or many-to-many
|
||||
return RelationHasMany
|
||||
}
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
// Pointer to single struct usually indicates belongs-to or has-one
|
||||
// Check if it has foreignKey (belongs-to) or references (has-one)
|
||||
if strings.Contains(gormTag, "foreignKey:") {
|
||||
@@ -963,7 +963,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
||||
// Slice of structs → has-many
|
||||
return RelationHasMany
|
||||
}
|
||||
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
|
||||
if fieldType.Kind() == reflect.Pointer || fieldType.Kind() == reflect.Struct {
|
||||
// Single struct → belongs-to (default assumption for safety)
|
||||
// Using belongs-to as default ensures we use JOIN, which is safer
|
||||
return RelationBelongsTo
|
||||
@@ -990,7 +990,7 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
||||
// Strategy 1 is skipped if the matched field is a declared relation (rel:) or
|
||||
// has a GORM tag but carries no explicit FK — callers should use convention.
|
||||
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
|
||||
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice {
|
||||
for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
@@ -1123,7 +1123,7 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
||||
}
|
||||
|
||||
targetValue := reflect.ValueOf(target)
|
||||
if targetValue.Kind() != reflect.Ptr {
|
||||
if targetValue.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("target must be a pointer to a struct")
|
||||
}
|
||||
|
||||
@@ -1226,8 +1226,8 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
}
|
||||
|
||||
// Handle pointer fields
|
||||
if field.Kind() == reflect.Ptr {
|
||||
if valueReflect.Kind() != reflect.Ptr {
|
||||
if field.Kind() == reflect.Pointer {
|
||||
if valueReflect.Kind() != reflect.Pointer {
|
||||
// Create a new pointer and set its value
|
||||
newPtr := reflect.New(field.Type().Elem())
|
||||
if err := setFieldValue(newPtr.Elem(), value); err != nil {
|
||||
@@ -1418,14 +1418,14 @@ func convertSlice(targetSlice reflect.Value, sourceSlice reflect.Value) error {
|
||||
// Handle nil elements
|
||||
if sourceValue == nil {
|
||||
// For pointer types, nil is valid
|
||||
if targetElemType.Kind() == reflect.Ptr {
|
||||
if targetElemType.Kind() == reflect.Pointer {
|
||||
targetElem.Set(reflect.Zero(targetElemType))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// If target element type is a pointer to struct, we need to create new instances
|
||||
if targetElemType.Kind() == reflect.Ptr {
|
||||
if targetElemType.Kind() == reflect.Pointer {
|
||||
// Create a new instance of the pointed-to type
|
||||
newElemPtr := reflect.New(targetElemType.Elem())
|
||||
|
||||
@@ -1588,7 +1588,7 @@ func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
||||
// Unwrap pointers and slices to get to the base struct type
|
||||
for modelType != nil {
|
||||
switch modelType.Kind() {
|
||||
case reflect.Ptr, reflect.Slice:
|
||||
case reflect.Pointer, reflect.Slice:
|
||||
modelType = modelType.Elem()
|
||||
continue
|
||||
}
|
||||
@@ -1616,7 +1616,7 @@ func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
|
||||
// Check for embedded structs
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
@@ -1655,7 +1655,7 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
|
||||
|
||||
for {
|
||||
switch modelType.Kind() {
|
||||
case reflect.Ptr, reflect.Slice:
|
||||
case reflect.Pointer, reflect.Slice:
|
||||
modelType = modelType.Elem()
|
||||
continue
|
||||
}
|
||||
@@ -1724,7 +1724,7 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
|
||||
|
||||
for {
|
||||
switch targetType.Kind() {
|
||||
case reflect.Ptr, reflect.Slice:
|
||||
case reflect.Pointer, reflect.Slice:
|
||||
targetType = targetType.Elem()
|
||||
if targetType == nil {
|
||||
return nil
|
||||
|
||||
+102
-7
@@ -428,14 +428,36 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
||||
// Use potentially modified data
|
||||
data = hookCtx.Data
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
if pkName != "" {
|
||||
var insertedID interface{}
|
||||
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
// Re-fetch after insert to capture DB-generated defaults/triggers.
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
fetchedRecord := reflect.New(modelType).Interface()
|
||||
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
|
||||
ScanModel(ctx); err == nil {
|
||||
v = mergeWithInput(fetchedRecord, v)
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, err)
|
||||
}
|
||||
} else {
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
}
|
||||
hookCtx.Result = v
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
@@ -444,7 +466,12 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
||||
return v, nil
|
||||
|
||||
case []interface{}:
|
||||
results := make([]interface{}, 0, len(v))
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
originals := make([]map[string]interface{}, 0, len(v))
|
||||
insertedIDs := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
@@ -455,16 +482,43 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
|
||||
for key, value := range itemMap {
|
||||
q = q.Value(key, value)
|
||||
}
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
if pkName == "" {
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, nil)
|
||||
continue
|
||||
}
|
||||
var returnedID interface{}
|
||||
if err := q.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, item)
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, returnedID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch create error: %w", err)
|
||||
}
|
||||
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||
results := make([]interface{}, 0, len(insertedIDs))
|
||||
for i, pkVal := range insertedIDs {
|
||||
if pkVal == nil {
|
||||
results = append(results, originals[i])
|
||||
continue
|
||||
}
|
||||
fetchedRecord := reflect.New(modelType).Interface()
|
||||
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||
ScanModel(ctx); err == nil {
|
||||
results = append(results, mergeWithInput(fetchedRecord, originals[i]))
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, err)
|
||||
results = append(results, originals[i])
|
||||
}
|
||||
}
|
||||
hookCtx.Result = results
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
@@ -513,7 +567,7 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Read existing record
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
existingRecord := reflect.New(modelType).Interface()
|
||||
@@ -584,6 +638,25 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Re-fetch the record after transaction commits to capture DB-generated changes.
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
fetchedRecord := reflect.New(modelType).Interface()
|
||||
if err := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||
ScanModel(ctx); err == nil {
|
||||
jsonData, marshalErr := json.Marshal(fetchedRecord)
|
||||
if marshalErr == nil {
|
||||
var fetchedMap map[string]interface{}
|
||||
if json.Unmarshal(jsonData, &fetchedMap) == nil {
|
||||
updateResult = fetchedMap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return updateResult, nil
|
||||
}
|
||||
|
||||
@@ -628,7 +701,7 @@ func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string)
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -749,6 +822,28 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition st
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// mergeWithInput merges a database record with the original request data.
|
||||
// DB values take precedence (capturing triggers/defaults), while extra
|
||||
// input keys that have no DB column are preserved in the response.
|
||||
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{}, len(input))
|
||||
for k, v := range input {
|
||||
result[k] = v
|
||||
}
|
||||
jsonData, err := json.Marshal(dbRecord)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
var dbMap map[string]interface{}
|
||||
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||
return result
|
||||
}
|
||||
for k, v := range dbMap {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for i := range preloads {
|
||||
preload := &preloads[i]
|
||||
|
||||
@@ -67,7 +67,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||
|
||||
// Unwrap to base struct type
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
@@ -87,7 +87,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||
fieldType, found := modelType.FieldByName(d.Name)
|
||||
if found {
|
||||
ft := fieldType.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
if ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||
@@ -106,7 +106,7 @@ func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||
goType := d.DataType
|
||||
if goType == "" && found {
|
||||
ft := fieldType.Type
|
||||
for ft.Kind() == reflect.Ptr {
|
||||
for ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
goType = ft.Name()
|
||||
|
||||
+131
-24
@@ -243,7 +243,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
|
||||
// Validate and unwrap model type to get base struct
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -602,23 +602,44 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Standard processing without nested relations
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Error creating record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||
return
|
||||
var responseData interface{} = v
|
||||
if pkName == "" {
|
||||
// No PK on model — insert and return input as-is.
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Error creating record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||
} else {
|
||||
var insertedID interface{}
|
||||
if err := query.Returning(pkName).Scan(ctx, &insertedID); err != nil {
|
||||
logger.Error("Error creating record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully created record with %s: %v", pkName, insertedID)
|
||||
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), insertedID).
|
||||
ScanModel(ctx); fetchErr == nil {
|
||||
responseData = mergeWithInput(fetchedRecord, v)
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, insertedID, fetchErr)
|
||||
}
|
||||
}
|
||||
logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, v, nil)
|
||||
h.sendResponse(w, responseData, nil)
|
||||
|
||||
case []map[string]interface{}:
|
||||
// Check if any item needs nested processing
|
||||
@@ -666,15 +687,30 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Standard batch insert without nested relations
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
|
||||
originals := make([]map[string]interface{}, 0, len(v))
|
||||
insertedIDs := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
txQuery := tx.NewInsert().Table(tableName)
|
||||
for key, value := range item {
|
||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
if pkName == "" {
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, item)
|
||||
insertedIDs = append(insertedIDs, nil)
|
||||
continue
|
||||
}
|
||||
var returnedID interface{}
|
||||
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, item)
|
||||
insertedIDs = append(insertedIDs, returnedID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -689,7 +725,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, v, nil)
|
||||
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||
responseItems := make([]interface{}, 0, len(insertedIDs))
|
||||
for i, pkVal := range insertedIDs {
|
||||
if pkVal == nil {
|
||||
responseItems = append(responseItems, originals[i])
|
||||
continue
|
||||
}
|
||||
fetchedRecord := reflect.New(modelElemType).Interface()
|
||||
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||
ScanModel(ctx); fetchErr == nil {
|
||||
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
|
||||
responseItems = append(responseItems, originals[i])
|
||||
}
|
||||
}
|
||||
h.sendResponse(w, responseItems, nil)
|
||||
|
||||
case []interface{}:
|
||||
// Handle []interface{} type from JSON unmarshaling
|
||||
@@ -742,19 +795,34 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Standard batch insert without nested relations
|
||||
list := make([]interface{}, 0)
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelElemType := reflection.GetPointerElement(reflect.TypeOf(model))
|
||||
originals := make([]map[string]interface{}, 0, len(v))
|
||||
insertedIDs := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
txQuery := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
txQuery := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||
}
|
||||
if pkName == "" {
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
list = append(list, item)
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, nil)
|
||||
continue
|
||||
}
|
||||
var returnedID interface{}
|
||||
if err := txQuery.Returning(pkName).Scan(ctx, &returnedID); err != nil {
|
||||
return err
|
||||
}
|
||||
originals = append(originals, itemMap)
|
||||
insertedIDs = append(insertedIDs, returnedID)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@@ -769,7 +837,24 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, list, nil)
|
||||
// Re-fetch each record after transaction commits; fall back to input on failure.
|
||||
responseItems := make([]interface{}, 0, len(insertedIDs))
|
||||
for i, pkVal := range insertedIDs {
|
||||
if pkVal == nil {
|
||||
responseItems = append(responseItems, originals[i])
|
||||
continue
|
||||
}
|
||||
fetchedRecord := reflect.New(modelElemType).Interface()
|
||||
if fetchErr := h.db.NewSelect().Model(fetchedRecord).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), pkVal).
|
||||
ScanModel(ctx); fetchErr == nil {
|
||||
responseItems = append(responseItems, mergeWithInput(fetchedRecord, originals[i]))
|
||||
} else {
|
||||
logger.Warn("Failed to re-fetch created record with %s=%v: %v", pkName, pkVal, fetchErr)
|
||||
responseItems = append(responseItems, originals[i])
|
||||
}
|
||||
}
|
||||
h.sendResponse(w, responseItems, nil)
|
||||
|
||||
default:
|
||||
logger.Error("Invalid data type for create operation: %T", data)
|
||||
@@ -1462,7 +1547,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
// First, fetch the record that will be deleted
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
recordToDelete := reflect.New(modelType).Interface()
|
||||
@@ -1737,7 +1822,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -1881,7 +1966,7 @@ func getColumnType(field reflect.StructField) string {
|
||||
|
||||
func isNullable(field reflect.StructField) bool {
|
||||
// Check if it's a pointer type
|
||||
if field.Type.Kind() == reflect.Ptr {
|
||||
if field.Type.Kind() == reflect.Pointer {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1907,7 +1992,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -2055,7 +2140,7 @@ func toSnakeCase(s string) string {
|
||||
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||
// Get the reflect value of the records
|
||||
recordsValue := reflect.ValueOf(records)
|
||||
if recordsValue.Kind() == reflect.Ptr {
|
||||
if recordsValue.Kind() == reflect.Pointer {
|
||||
recordsValue = recordsValue.Elem()
|
||||
}
|
||||
|
||||
@@ -2070,7 +2155,7 @@ func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||
record := recordsValue.Index(i)
|
||||
|
||||
// Dereference if it's a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.Kind() == reflect.Pointer {
|
||||
if record.IsNil() {
|
||||
continue
|
||||
}
|
||||
@@ -2122,3 +2207,25 @@ func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
||||
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
|
||||
h.openAPIGenerator = generator
|
||||
}
|
||||
|
||||
// mergeWithInput merges a database record with the original request data.
|
||||
// DB values take precedence (capturing triggers/defaults), while extra
|
||||
// input keys that have no DB column are preserved in the response.
|
||||
func mergeWithInput(dbRecord interface{}, input map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{}, len(input))
|
||||
for k, v := range input {
|
||||
result[k] = v
|
||||
}
|
||||
jsonData, err := json.Marshal(dbRecord)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
var dbMap map[string]interface{}
|
||||
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||
return result
|
||||
}
|
||||
for k, v := range dbMap {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// detailTestModel is a simple model with gorm column/type tags for detail format tests.
|
||||
type detailTestModel struct {
|
||||
ID int64 `bun:"rid,pk" gorm:"column:rid;primaryKey" json:"rid"`
|
||||
Name string `bun:"name" gorm:"column:name;type:citext" json:"name"`
|
||||
Description *string `bun:"description" gorm:"column:description;type:text;nullable" json:"description"`
|
||||
Score float64 `bun:"score" gorm:"column:score;type:numeric" json:"score"`
|
||||
Active bool `bun:"active" gorm:"column:active;type:boolean;not null" json:"active"`
|
||||
}
|
||||
|
||||
func TestSendFormattedResponse_DetailFormat(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
name := "hello"
|
||||
items := []*detailTestModel{
|
||||
{ID: 1, Name: "first", Description: &name, Score: 1.5, Active: true},
|
||||
{ID: 2, Name: "second", Description: nil, Score: 2.0, Active: false},
|
||||
}
|
||||
metadata := &common.Metadata{
|
||||
Total: 36,
|
||||
Count: 2,
|
||||
Filtered: 36,
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
}
|
||||
options := ExtendedRequestOptions{
|
||||
ResponseFormat: "detail",
|
||||
}
|
||||
|
||||
mockWriter := &MockTestResponseWriter{headers: make(map[string]string)}
|
||||
handler.sendFormattedResponse(mockWriter, items, metadata, "myschema.myentity", detailTestModel{}, options)
|
||||
|
||||
if mockWriter.statusCode != 200 {
|
||||
t.Fatalf("expected status 200, got %d", mockWriter.statusCode)
|
||||
}
|
||||
|
||||
body, err := json.Marshal(mockWriter.body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal body: %v", err)
|
||||
}
|
||||
|
||||
var resp map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
t.Run("top-level keys", func(t *testing.T) {
|
||||
for _, key := range []string{"count", "fields", "items", "tablename", "tableprefix", "total"} {
|
||||
if _, ok := resp[key]; !ok {
|
||||
t.Errorf("missing key %q in detail response", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("count and total are string", func(t *testing.T) {
|
||||
var count, total string
|
||||
if err := json.Unmarshal(resp["count"], &count); err != nil {
|
||||
t.Errorf("count is not a string: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(resp["total"], &total); err != nil {
|
||||
t.Errorf("total is not a string: %v", err)
|
||||
}
|
||||
if count != "2" {
|
||||
t.Errorf("expected count %q, got %q", "2", count)
|
||||
}
|
||||
if total != "36" {
|
||||
t.Errorf("expected total %q, got %q", "36", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tablename and tableprefix", func(t *testing.T) {
|
||||
var tablename, tableprefix string
|
||||
json.Unmarshal(resp["tablename"], &tablename)
|
||||
json.Unmarshal(resp["tableprefix"], &tableprefix)
|
||||
if tablename != "myschema.myentity" {
|
||||
t.Errorf("expected tablename %q, got %q", "myschema.myentity", tablename)
|
||||
}
|
||||
if tableprefix != "myentity" {
|
||||
t.Errorf("expected tableprefix %q, got %q", "myentity", tableprefix)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("items contains data", func(t *testing.T) {
|
||||
var itemSlice []map[string]interface{}
|
||||
if err := json.Unmarshal(resp["items"], &itemSlice); err != nil {
|
||||
t.Fatalf("items is not an array: %v", err)
|
||||
}
|
||||
if len(itemSlice) != 2 {
|
||||
t.Errorf("expected 2 items, got %d", len(itemSlice))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fields contains column metadata", func(t *testing.T) {
|
||||
var fields []map[string]interface{}
|
||||
if err := json.Unmarshal(resp["fields"], &fields); err != nil {
|
||||
t.Fatalf("fields is not an array: %v", err)
|
||||
}
|
||||
if len(fields) == 0 {
|
||||
t.Fatal("expected fields to be non-empty")
|
||||
}
|
||||
|
||||
bySQL := make(map[string]map[string]interface{}, len(fields))
|
||||
for _, f := range fields {
|
||||
if sqlname, ok := f["sqlname"].(string); ok {
|
||||
bySQL[sqlname] = f
|
||||
}
|
||||
}
|
||||
|
||||
// Check required field keys are present
|
||||
for _, f := range fields {
|
||||
for _, key := range []string{"name", "datatype", "sqlname", "sqldatatype", "sqlkey", "nullable"} {
|
||||
if _, ok := f[key]; !ok {
|
||||
t.Errorf("field %v missing key %q", f, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate specific columns
|
||||
if col, ok := bySQL["rid"]; ok {
|
||||
if col["sqlkey"] != "primary_key" {
|
||||
t.Errorf("rid: expected sqlkey %q, got %v", "primary_key", col["sqlkey"])
|
||||
}
|
||||
} else {
|
||||
t.Error("expected column 'rid' in fields")
|
||||
}
|
||||
|
||||
if col, ok := bySQL["name"]; ok {
|
||||
if col["sqldatatype"] != "citext" {
|
||||
t.Errorf("name: expected sqldatatype %q, got %v", "citext", col["sqldatatype"])
|
||||
}
|
||||
if col["nullable"] != false {
|
||||
t.Errorf("name: expected nullable false, got %v", col["nullable"])
|
||||
}
|
||||
} else {
|
||||
t.Error("expected column 'name' in fields")
|
||||
}
|
||||
|
||||
if col, ok := bySQL["description"]; ok {
|
||||
if col["sqldatatype"] != "text" {
|
||||
t.Errorf("description: expected sqldatatype %q, got %v", "text", col["sqldatatype"])
|
||||
}
|
||||
if col["nullable"] != true {
|
||||
t.Errorf("description: expected nullable true, got %v", col["nullable"])
|
||||
}
|
||||
} else {
|
||||
t.Error("expected column 'description' in fields")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendFormattedResponse_DetailFormat_EmptyItems(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
metadata := &common.Metadata{Total: 0, Count: 0, Filtered: 0}
|
||||
options := ExtendedRequestOptions{ResponseFormat: "detail"}
|
||||
|
||||
mockWriter := &MockTestResponseWriter{headers: make(map[string]string)}
|
||||
handler.sendFormattedResponse(mockWriter, []*detailTestModel{}, metadata, "s.t", detailTestModel{}, options)
|
||||
|
||||
body, _ := json.Marshal(mockWriter.body)
|
||||
var resp map[string]json.RawMessage
|
||||
json.Unmarshal(body, &resp)
|
||||
|
||||
var count, total string
|
||||
json.Unmarshal(resp["count"], &count)
|
||||
json.Unmarshal(resp["total"], &total)
|
||||
|
||||
if count != "0" || total != "0" {
|
||||
t.Errorf("expected count/total both %q, got count=%q total=%q", "0", count, total)
|
||||
}
|
||||
|
||||
var fields []interface{}
|
||||
json.Unmarshal(resp["fields"], &fields)
|
||||
if len(fields) == 0 {
|
||||
t.Error("fields should still list column metadata even when items is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDetailFields_SkipsRelations(t *testing.T) {
|
||||
type child struct {
|
||||
ID int64 `bun:"id,pk" gorm:"column:id;primaryKey" json:"id"`
|
||||
}
|
||||
type parent struct {
|
||||
ID int64 `bun:"id,pk" gorm:"column:id;primaryKey" json:"id"`
|
||||
Name string `bun:"name" gorm:"column:name" json:"name"`
|
||||
Children []child `bun:"rel:has-many" json:"children"`
|
||||
Child *child `bun:"rel:has-one" json:"child"`
|
||||
}
|
||||
|
||||
handler := &Handler{}
|
||||
fields := handler.buildDetailFields(parent{})
|
||||
|
||||
for _, f := range fields {
|
||||
if f.SQLName == "children" || f.SQLName == "child" {
|
||||
t.Errorf("relation field %q should not appear in detail fields", f.SQLName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(fields) != 2 {
|
||||
t.Errorf("expected 2 scalar fields (id, name), got %d", len(fields))
|
||||
}
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func TestSendFormattedResponse_NoDataFoundHeader(t *testing.T) {
|
||||
|
||||
// Test with empty data
|
||||
emptyData := []interface{}{}
|
||||
handler.sendFormattedResponse(mockWriter, emptyData, metadata, options)
|
||||
handler.sendFormattedResponse(mockWriter, emptyData, metadata, "", nil, options)
|
||||
|
||||
// Check if X-No-Data-Found header was set
|
||||
if mockWriter.headers["X-No-Data-Found"] != "true" {
|
||||
|
||||
+138
-18
@@ -289,7 +289,8 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
||||
Limit: 0,
|
||||
Offset: 0,
|
||||
}
|
||||
h.sendFormattedResponse(w, tableMetadata, responseMetadata, options)
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
h.sendFormattedResponse(w, tableMetadata, responseMetadata, tableName, model, options)
|
||||
}
|
||||
|
||||
// handleMeta processes meta operation requests
|
||||
@@ -348,7 +349,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
|
||||
// Validate and unwrap model type to get base struct
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -848,7 +849,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
return
|
||||
}
|
||||
|
||||
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
||||
h.sendFormattedResponse(w, modelPtr, metadata, tableName, model, options)
|
||||
}
|
||||
|
||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||
@@ -1891,7 +1892,7 @@ func (h *Handler) extractNestedRelations(
|
||||
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
|
||||
// Get model type for reflection
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -1933,7 +1934,7 @@ func (h *Handler) processChildRelationsWithParentID(
|
||||
) error {
|
||||
// Get model type for reflection
|
||||
modelType := reflect.TypeOf(parentModel)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -1989,7 +1990,7 @@ func (h *Handler) processChildRelationsForField(
|
||||
if relatedModelType.Kind() == reflect.Slice {
|
||||
relatedModelType = relatedModelType.Elem()
|
||||
}
|
||||
if relatedModelType.Kind() == reflect.Ptr {
|
||||
if relatedModelType.Kind() == reflect.Pointer {
|
||||
relatedModelType = relatedModelType.Elem()
|
||||
}
|
||||
|
||||
@@ -2412,7 +2413,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||
for modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -2461,7 +2462,7 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
|
||||
// Check if this is a relation field (slice or struct, but not time.Time)
|
||||
if field.Type.Kind() == reflect.Slice ||
|
||||
(field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") ||
|
||||
(field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct && field.Type.Elem().Name() != "Time") {
|
||||
(field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct && field.Type.Elem().Name() != "Time") {
|
||||
metadata.Relations = append(metadata.Relations, jsonName)
|
||||
continue
|
||||
}
|
||||
@@ -2507,7 +2508,7 @@ func (h *Handler) getColumnType(t reflect.Type) string {
|
||||
return "float"
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Ptr:
|
||||
case reflect.Pointer:
|
||||
return h.getColumnType(t.Elem())
|
||||
default:
|
||||
return "unknown"
|
||||
@@ -2515,7 +2516,7 @@ func (h *Handler) getColumnType(t reflect.Type) string {
|
||||
}
|
||||
|
||||
func (h *Handler) isNullable(field reflect.StructField) bool {
|
||||
return field.Type.Kind() == reflect.Ptr
|
||||
return field.Type.Kind() == reflect.Pointer
|
||||
}
|
||||
|
||||
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
|
||||
@@ -2560,7 +2561,7 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||
|
||||
// Use reflection to check if data is a slice or array
|
||||
dataValue := reflect.ValueOf(data)
|
||||
if dataValue.Kind() == reflect.Ptr {
|
||||
if dataValue.Kind() == reflect.Pointer {
|
||||
dataValue = dataValue.Elem()
|
||||
}
|
||||
|
||||
@@ -2585,8 +2586,103 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||
return data
|
||||
}
|
||||
|
||||
// sendFormattedResponse sends response with formatting options
|
||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
||||
// buildDetailFields returns the field metadata list for the detail API format,
|
||||
// containing only non-relation scalar columns derived from the model's struct tags.
|
||||
func (h *Handler) buildDetailFields(model interface{}) []reflection.ModelFieldDetail {
|
||||
if model == nil {
|
||||
return []reflection.ModelFieldDetail{}
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return []reflection.ModelFieldDetail{}
|
||||
}
|
||||
|
||||
fields := make([]reflection.ModelFieldDetail, 0, modelType.NumField())
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (slices, structs that aren't time.Time, ptrs to struct)
|
||||
ft := field.Type
|
||||
if ft.Kind() == reflect.Pointer {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
if ft.Kind() == reflect.Slice ||
|
||||
(ft.Kind() == reflect.Struct && ft.Name() != "Time") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonName := strings.Split(jsonTag, ",")[0]
|
||||
if jsonName == "" {
|
||||
jsonName = field.Name
|
||||
}
|
||||
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
sqlName := fnFindTagVal(gormTag, "column:")
|
||||
if sqlName == "" {
|
||||
sqlName = jsonName
|
||||
}
|
||||
|
||||
sqlDataType := fnFindTagVal(gormTag, "type:")
|
||||
|
||||
var sqlKey string
|
||||
gormLower := strings.ToLower(gormTag)
|
||||
switch {
|
||||
case strings.Contains(gormLower, "identity") || strings.Contains(gormLower, "primary_key") || strings.Contains(gormLower, "primarykey"):
|
||||
sqlKey = "primary_key"
|
||||
case strings.Contains(gormLower, "uniqueindex"):
|
||||
sqlKey = "uniqueindex"
|
||||
case strings.Contains(gormLower, "unique"):
|
||||
sqlKey = "unique"
|
||||
}
|
||||
|
||||
nullable := field.Type.Kind() == reflect.Pointer
|
||||
if strings.Contains(gormLower, "not null") {
|
||||
nullable = false
|
||||
} else if strings.Contains(gormLower, "nullable") || strings.Contains(gormLower, ",null") {
|
||||
nullable = true
|
||||
}
|
||||
|
||||
fields = append(fields, reflection.ModelFieldDetail{
|
||||
Name: jsonName,
|
||||
DataType: h.getColumnType(field.Type),
|
||||
SQLName: sqlName,
|
||||
SQLDataType: sqlDataType,
|
||||
SQLKey: sqlKey,
|
||||
Nullable: nullable,
|
||||
})
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// fnFindTagVal extracts a value from a semicolon-separated struct tag string.
|
||||
func fnFindTagVal(tag, key string) string {
|
||||
lower := strings.ToLower(tag)
|
||||
idx := strings.Index(lower, strings.ToLower(key))
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
val := tag[idx+len(key):]
|
||||
if end := strings.Index(val, ";"); end >= 0 {
|
||||
val = val[:end]
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// sendFormattedResponse sends response with formatting options.
|
||||
// model is used when ResponseFormat is "detail" to generate the fields metadata list.
|
||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, tableName string, model interface{}, options ExtendedRequestOptions) {
|
||||
// Handle nil data - convert to empty array
|
||||
if data == nil {
|
||||
data = []interface{}{}
|
||||
@@ -2615,9 +2711,12 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
||||
}
|
||||
|
||||
w.SetHeader("Content-Type", "application/json")
|
||||
w.SetHeader("Content-Range", fmt.Sprintf("%d-%d/%d", metadata.Offset, int64(metadata.Offset)+metadata.Count, metadata.Filtered))
|
||||
w.SetHeader("Content-Range", fmt.Sprintf("items %d-%d/%d", metadata.Offset, int64(metadata.Offset)+metadata.Count, metadata.Filtered))
|
||||
w.SetHeader("X-Api-Range-Total", fmt.Sprintf("%d", metadata.Filtered))
|
||||
w.SetHeader("X-Api-Range-Size", fmt.Sprintf("%d", metadata.Count))
|
||||
w.SetHeader("X-Api-Range-From", fmt.Sprintf("%d", metadata.Offset))
|
||||
w.SetHeader("X-Api-Range-Etotal", fmt.Sprintf("%d", metadata.Filtered))
|
||||
w.SetHeader("X-Api-Modelname", tableName)
|
||||
|
||||
// Format response based on response format option
|
||||
switch options.ResponseFormat {
|
||||
@@ -2639,8 +2738,29 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
||||
if err := w.WriteJSON(response); err != nil {
|
||||
logger.Error("Failed to write JSON response: %v", err)
|
||||
}
|
||||
case "detail":
|
||||
// Detail format: { count, fields, items, tablename, tableprefix, total }
|
||||
var count, total int64
|
||||
if metadata != nil {
|
||||
count = metadata.Count
|
||||
total = metadata.Total
|
||||
}
|
||||
tablePrefix := reflection.ExtractTableNameOnly(tableName)
|
||||
fieldList := h.buildDetailFields(model)
|
||||
response := map[string]interface{}{
|
||||
"count": strconv.FormatInt(count, 10),
|
||||
"fields": fieldList,
|
||||
"items": data,
|
||||
"tablename": tableName,
|
||||
"tableprefix": tablePrefix,
|
||||
"total": strconv.FormatInt(total, 10),
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := w.WriteJSON(response); err != nil {
|
||||
logger.Error("Failed to write JSON response: %v", err)
|
||||
}
|
||||
default:
|
||||
// Default/detail format: standard response with metadata
|
||||
// Default format: standard response with metadata
|
||||
response := common.Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
@@ -2898,7 +3018,7 @@ func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string)
|
||||
func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
||||
// Get the reflect value of the records
|
||||
recordsValue := reflect.ValueOf(records)
|
||||
if recordsValue.Kind() == reflect.Ptr {
|
||||
if recordsValue.Kind() == reflect.Pointer {
|
||||
recordsValue = recordsValue.Elem()
|
||||
}
|
||||
|
||||
@@ -2913,7 +3033,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
||||
record := recordsValue.Index(i)
|
||||
|
||||
// Dereference if it's a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.Kind() == reflect.Pointer {
|
||||
if record.IsNil() {
|
||||
continue
|
||||
}
|
||||
@@ -2968,7 +3088,7 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
||||
// Filter Expand columns using the expand relation's model
|
||||
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
|
||||
+11
-10
@@ -225,12 +225,13 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
limitValueParts := strings.Split(limitValue, ",")
|
||||
|
||||
if len(limitValueParts) > 1 {
|
||||
if offset, err := strconv.Atoi(limitValueParts[0]); err == nil {
|
||||
options.Offset = &offset
|
||||
}
|
||||
if limit, err := strconv.Atoi(limitValueParts[1]); err == nil {
|
||||
if limit, err := strconv.Atoi(limitValueParts[0]); err == nil {
|
||||
options.Limit = &limit
|
||||
}
|
||||
if offset, err := strconv.Atoi(limitValueParts[1]); err == nil {
|
||||
options.Offset = &offset
|
||||
}
|
||||
|
||||
} else {
|
||||
if limit, err := strconv.Atoi(limitValueParts[0]); err == nil {
|
||||
options.Limit = &limit
|
||||
@@ -977,7 +978,7 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
||||
}
|
||||
|
||||
// Dereference pointer if needed
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -1012,13 +1013,13 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
||||
var targetType reflect.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
targetType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Ptr {
|
||||
} else if fieldType.Kind() == reflect.Pointer {
|
||||
targetType = fieldType.Elem()
|
||||
}
|
||||
|
||||
if targetType != nil {
|
||||
// Dereference pointer if the slice contains pointers
|
||||
if targetType.Kind() == reflect.Ptr {
|
||||
if targetType.Kind() == reflect.Pointer {
|
||||
targetType = targetType.Elem()
|
||||
}
|
||||
|
||||
@@ -1062,7 +1063,7 @@ func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable
|
||||
if modelType == nil {
|
||||
return nameOrTable
|
||||
}
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
@@ -1089,10 +1090,10 @@ func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable
|
||||
var targetType reflect.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
targetType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Ptr {
|
||||
} else if fieldType.Kind() == reflect.Pointer {
|
||||
targetType = fieldType.Elem()
|
||||
}
|
||||
if targetType != nil && targetType.Kind() == reflect.Ptr {
|
||||
if targetType != nil && targetType.Kind() == reflect.Pointer {
|
||||
targetType = targetType.Elem()
|
||||
}
|
||||
if targetType == nil || targetType.Kind() != reflect.Struct {
|
||||
|
||||
@@ -13,6 +13,9 @@ CREATE TABLE IF NOT EXISTS users (
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
-- Program-level user mapping
|
||||
program_user_id INTEGER DEFAULT 0,
|
||||
program_user_table VARCHAR(255) DEFAULT '',
|
||||
-- OAuth2 fields
|
||||
remote_id VARCHAR(255), -- Provider's user ID (e.g., Google sub, GitHub id)
|
||||
auth_provider VARCHAR(50), -- 'local', 'google', 'github', 'microsoft', 'facebook', etc.
|
||||
@@ -99,6 +102,8 @@ DECLARE
|
||||
v_expires_at TIMESTAMP;
|
||||
v_ip_address TEXT;
|
||||
v_user_agent TEXT;
|
||||
v_program_user_id INTEGER;
|
||||
v_program_user_table TEXT;
|
||||
BEGIN
|
||||
-- Extract login request fields
|
||||
v_username := p_request->>'username';
|
||||
@@ -106,8 +111,8 @@ BEGIN
|
||||
v_user_agent := p_request->'claims'->>'user_agent';
|
||||
|
||||
-- Validate user credentials
|
||||
SELECT id, username, email, password, user_level, roles
|
||||
INTO v_user_id, v_username, v_email, v_password_hash, v_user_level, v_roles
|
||||
SELECT id, username, email, password, user_level, roles, program_user_id, program_user_table
|
||||
INTO v_user_id, v_username, v_email, v_password_hash, v_user_level, v_roles, v_program_user_id, v_program_user_table
|
||||
FROM users
|
||||
WHERE username = v_username AND is_active = true;
|
||||
|
||||
@@ -146,7 +151,9 @@ BEGIN
|
||||
'email', v_email,
|
||||
'user_level', v_user_level,
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ','),
|
||||
'session_id', v_session_token
|
||||
'session_id', v_session_token,
|
||||
'program_user_id', COALESCE(v_program_user_id, 0),
|
||||
'program_user_table', COALESCE(v_program_user_table, '')
|
||||
),
|
||||
'expires_in', 86400 -- 24 hours in seconds
|
||||
);
|
||||
@@ -195,12 +202,16 @@ DECLARE
|
||||
v_user_level INTEGER;
|
||||
v_roles TEXT;
|
||||
v_session_id TEXT;
|
||||
v_program_user_id INTEGER;
|
||||
v_program_user_table TEXT;
|
||||
BEGIN
|
||||
-- Query session and user data
|
||||
SELECT
|
||||
s.user_id, u.username, u.email, u.user_level, u.roles, s.session_token
|
||||
s.user_id, u.username, u.email, u.user_level, u.roles, s.session_token,
|
||||
u.program_user_id, u.program_user_table
|
||||
INTO
|
||||
v_user_id, v_username, v_email, v_user_level, v_roles, v_session_id
|
||||
v_user_id, v_username, v_email, v_user_level, v_roles, v_session_id,
|
||||
v_program_user_id, v_program_user_table
|
||||
FROM user_sessions s
|
||||
JOIN users u ON s.user_id = u.id
|
||||
WHERE s.session_token = p_session_token
|
||||
@@ -222,7 +233,9 @@ BEGIN
|
||||
'email', v_email,
|
||||
'user_level', v_user_level,
|
||||
'session_id', v_session_id,
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ',')
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ','),
|
||||
'program_user_id', COALESCE(v_program_user_id, 0),
|
||||
'program_user_table', COALESCE(v_program_user_table, '')
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -266,10 +279,14 @@ DECLARE
|
||||
v_expires_at TIMESTAMP;
|
||||
v_ip_address TEXT;
|
||||
v_user_agent TEXT;
|
||||
v_program_user_id INTEGER;
|
||||
v_program_user_table TEXT;
|
||||
BEGIN
|
||||
-- Verify old session exists and is valid
|
||||
SELECT s.user_id, u.username, u.email, u.user_level, u.roles, s.ip_address, s.user_agent
|
||||
INTO v_user_id, v_username, v_email, v_user_level, v_roles, v_ip_address, v_user_agent
|
||||
SELECT s.user_id, u.username, u.email, u.user_level, u.roles, s.ip_address, s.user_agent,
|
||||
u.program_user_id, u.program_user_table
|
||||
INTO v_user_id, v_username, v_email, v_user_level, v_roles, v_ip_address, v_user_agent,
|
||||
v_program_user_id, v_program_user_table
|
||||
FROM user_sessions s
|
||||
JOIN users u ON s.user_id = u.id
|
||||
WHERE s.session_token = p_old_session_token
|
||||
@@ -302,7 +319,9 @@ BEGIN
|
||||
'email', v_email,
|
||||
'user_level', v_user_level,
|
||||
'session_id', v_new_session_token,
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ',')
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ','),
|
||||
'program_user_id', COALESCE(v_program_user_id, 0),
|
||||
'program_user_table', COALESCE(v_program_user_table, '')
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -439,6 +458,8 @@ DECLARE
|
||||
v_ip_address TEXT;
|
||||
v_user_agent TEXT;
|
||||
v_roles_array TEXT[];
|
||||
v_program_user_id INTEGER;
|
||||
v_program_user_table TEXT;
|
||||
BEGIN
|
||||
-- Extract registration request fields
|
||||
v_username := p_request->>'username';
|
||||
@@ -447,6 +468,8 @@ BEGIN
|
||||
v_user_level := COALESCE((p_request->>'user_level')::integer, 0);
|
||||
v_ip_address := p_request->'claims'->>'ip_address';
|
||||
v_user_agent := p_request->'claims'->>'user_agent';
|
||||
v_program_user_id := COALESCE((p_request->>'program_user_id')::integer, 0);
|
||||
v_program_user_table := COALESCE(p_request->>'program_user_table', '');
|
||||
|
||||
-- Convert roles array from JSON to comma-separated string
|
||||
SELECT array_to_string(ARRAY(SELECT jsonb_array_elements_text(p_request->'roles')), ',')
|
||||
@@ -485,8 +508,8 @@ BEGIN
|
||||
-- v_password := crypt(v_password, gen_salt('bf'));
|
||||
|
||||
-- Create new user
|
||||
INSERT INTO users (username, email, password, user_level, roles, is_active, created_at, updated_at)
|
||||
VALUES (v_username, v_email, v_password, v_user_level, v_roles, true, now(), now())
|
||||
INSERT INTO users (username, email, password, user_level, roles, is_active, created_at, updated_at, program_user_id, program_user_table)
|
||||
VALUES (v_username, v_email, v_password, v_user_level, v_roles, true, now(), now(), v_program_user_id, v_program_user_table)
|
||||
RETURNING id INTO v_user_id;
|
||||
|
||||
-- Generate session token
|
||||
@@ -512,7 +535,9 @@ BEGIN
|
||||
'email', v_email,
|
||||
'user_level', v_user_level,
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ','),
|
||||
'session_id', v_session_token
|
||||
'session_id', v_session_token,
|
||||
'program_user_id', v_program_user_id,
|
||||
'program_user_table', v_program_user_table
|
||||
),
|
||||
'expires_in', 86400 -- 24 hours in seconds
|
||||
);
|
||||
@@ -671,12 +696,16 @@ DECLARE
|
||||
v_user_level INTEGER;
|
||||
v_roles TEXT;
|
||||
v_expires_at TIMESTAMP;
|
||||
v_program_user_id INTEGER;
|
||||
v_program_user_table TEXT;
|
||||
BEGIN
|
||||
-- Query session and user data from user_sessions table
|
||||
SELECT
|
||||
s.user_id, u.username, u.email, u.user_level, u.roles, s.expires_at
|
||||
s.user_id, u.username, u.email, u.user_level, u.roles, s.expires_at,
|
||||
u.program_user_id, u.program_user_table
|
||||
INTO
|
||||
v_user_id, v_username, v_email, v_user_level, v_roles, v_expires_at
|
||||
v_user_id, v_username, v_email, v_user_level, v_roles, v_expires_at,
|
||||
v_program_user_id, v_program_user_table
|
||||
FROM user_sessions s
|
||||
JOIN users u ON s.user_id = u.id
|
||||
WHERE s.session_token = p_session_token
|
||||
@@ -698,7 +727,9 @@ BEGIN
|
||||
'email', v_email,
|
||||
'user_level', v_user_level,
|
||||
'session_id', p_session_token,
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ',')
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ','),
|
||||
'program_user_id', COALESCE(v_program_user_id, 0),
|
||||
'program_user_table', COALESCE(v_program_user_table, '')
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -815,10 +846,12 @@ DECLARE
|
||||
v_email TEXT;
|
||||
v_user_level INTEGER;
|
||||
v_roles TEXT;
|
||||
v_program_user_id INTEGER;
|
||||
v_program_user_table TEXT;
|
||||
BEGIN
|
||||
-- Query user data
|
||||
SELECT username, email, user_level, roles
|
||||
INTO v_username, v_email, v_user_level, v_roles
|
||||
SELECT username, email, user_level, roles, program_user_id, program_user_table
|
||||
INTO v_username, v_email, v_user_level, v_roles, v_program_user_id, v_program_user_table
|
||||
FROM users
|
||||
WHERE id = p_user_id
|
||||
AND is_active = true;
|
||||
@@ -837,7 +870,9 @@ BEGIN
|
||||
'user_name', v_username,
|
||||
'email', v_email,
|
||||
'user_level', v_user_level,
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ',')
|
||||
'roles', string_to_array(COALESCE(v_roles, ''), ','),
|
||||
'program_user_id', COALESCE(v_program_user_id, 0),
|
||||
'program_user_table', COALESCE(v_program_user_table, '')
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
@@ -90,7 +90,7 @@ func applyRowSecurity(secCtx SecurityContext, securityList *SecurityList) error
|
||||
|
||||
// Get primary key name from model
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
@@ -155,13 +155,13 @@ func applyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
|
||||
|
||||
// Get model type
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
if modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Apply column security masking
|
||||
resultValue := reflect.ValueOf(result)
|
||||
if resultValue.Kind() == reflect.Ptr {
|
||||
if resultValue.Kind() == reflect.Pointer {
|
||||
resultValue = resultValue.Elem()
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ type UserContext struct {
|
||||
Claims map[string]any `json:"claims"`
|
||||
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
||||
TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
|
||||
ProgramUserID int `json:"program_user_id"`
|
||||
ProgramUserTable string `json:"program_user_table"`
|
||||
}
|
||||
|
||||
// LoginRequest contains credentials for login
|
||||
|
||||
@@ -19,6 +19,10 @@ type Config struct {
|
||||
// GZIP compression support
|
||||
GZIP bool
|
||||
|
||||
// HTTP2 enables HTTP/2 with the Extended CONNECT protocol (RFC 8441) for WebSocket support.
|
||||
// Requires TLS; pair with SSLCert/SSLKey, SelfSignedSSL, or AutoTLS.
|
||||
HTTP2 bool
|
||||
|
||||
// TLS/HTTPS configuration options (mutually exclusive)
|
||||
// Option 1: Provide certificate and key files directly
|
||||
SSLCert string
|
||||
|
||||
+32
-8
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
@@ -461,15 +462,38 @@ func newInstance(cfg Config) (*serverInstance, error) {
|
||||
}
|
||||
|
||||
// Create gracefulServer
|
||||
httpServer := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
IdleTimeout: cfg.IdleTimeout,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
// Enable HTTP/2 with Extended CONNECT (RFC 8441) for WebSocket-over-H2 support.
|
||||
// The GODEBUG=http2xconnect=1 flag is read by net/http's init(); setting it here
|
||||
// ensures it propagates to subprocesses and any future process restarts.
|
||||
// For the current process, set GODEBUG=http2xconnect=1 in the environment before launch.
|
||||
if cfg.HTTP2 {
|
||||
if existing := os.Getenv("GODEBUG"); !strings.Contains(existing, "http2xconnect=1") {
|
||||
if existing == "" {
|
||||
os.Setenv("GODEBUG", "http2xconnect=1")
|
||||
} else {
|
||||
os.Setenv("GODEBUG", existing+",http2xconnect=1")
|
||||
}
|
||||
}
|
||||
if httpServer.HTTP2 == nil {
|
||||
httpServer.HTTP2 = &http.HTTP2Config{}
|
||||
}
|
||||
httpServer.Protocols.SetHTTP2(true)
|
||||
httpServer.Protocols.SetUnencryptedHTTP2(true)
|
||||
} else {
|
||||
httpServer.Protocols.SetHTTP2(false)
|
||||
}
|
||||
|
||||
gracefulSrv := &gracefulServer{
|
||||
server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
IdleTimeout: cfg.IdleTimeout,
|
||||
TLSConfig: tlsConfig,
|
||||
},
|
||||
server: httpServer,
|
||||
shutdownTimeout: cfg.ShutdownTimeout,
|
||||
drainTimeout: cfg.DrainTimeout,
|
||||
shutdownComplete: make(chan struct{}),
|
||||
|
||||
@@ -671,6 +671,12 @@ func (h *Handler) create(hookCtx *HookContext) (interface{}, error) {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
// Re-fetch the created record to capture DB-generated defaults/triggers.
|
||||
if pkVal := reflection.GetPrimaryKeyValue(hookCtx.ModelPtr); pkVal != nil {
|
||||
hookCtx.ID = fmt.Sprintf("%v", pkVal)
|
||||
return h.readByID(hookCtx)
|
||||
}
|
||||
|
||||
return hookCtx.ModelPtr, nil
|
||||
}
|
||||
|
||||
@@ -834,7 +840,7 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
||||
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||
// Get the reflect value of the records
|
||||
recordsValue := reflect.ValueOf(records)
|
||||
if recordsValue.Kind() == reflect.Ptr {
|
||||
if recordsValue.Kind() == reflect.Pointer {
|
||||
recordsValue = recordsValue.Elem()
|
||||
}
|
||||
|
||||
@@ -849,7 +855,7 @@ func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||
record := recordsValue.Index(i)
|
||||
|
||||
// Dereference if it's a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.Kind() == reflect.Pointer {
|
||||
if record.IsNil() {
|
||||
continue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user