diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index caecdd8..524c40c 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -98,8 +98,8 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( } } - // Filter regularData to only include fields that exist in the model - // Use MapToStruct to validate and filter fields + // Filter regularData to only include fields that exist in the model, + // and translate JSON keys to their actual database column names. regularData = p.filterValidFields(regularData, model) // Inject parent IDs for foreign key resolution @@ -191,14 +191,15 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str return "" } -// filterValidFields filters input data to only include fields that exist in the model -// Uses reflection.MapToStruct to validate fields and extract only those that match the model +// filterValidFields filters input data to only include fields that exist in the model, +// and translates JSON key names to their actual database column names. +// For example, a field tagged `json:"_changed_date" bun:"changed_date"` will be +// included in the result as "changed_date", not "_changed_date". func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} { if len(data) == 0 { return data } - // Create a new instance of the model to use with MapToStruct modelType := reflect.TypeOf(model) for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { modelType = modelType.Elem() @@ -208,25 +209,16 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode return data } - // Create a new instance of the model - tempModel := reflect.New(modelType).Interface() + // Build a mapping from JSON key -> DB column name for all writable fields. + // This both validates which fields belong to the model and translates their names + // to the correct column names for use in SQL insert/update queries. + jsonToDBCol := reflection.BuildJSONToDBColumnMap(modelType) - // Use MapToStruct to map the data - this will only map valid fields - err := reflection.MapToStruct(data, tempModel) - if err != nil { - logger.Debug("Error mapping data to model: %v", err) - return data - } - - // Extract the mapped fields back into a map - // This effectively filters out any fields that don't exist in the model filteredData := make(map[string]interface{}) - tempModelValue := reflect.ValueOf(tempModel).Elem() - for key, value := range data { - // Check if the field was successfully mapped - if fieldWasMapped(tempModelValue, modelType, key) { - filteredData[key] = value + dbColName, exists := jsonToDBCol[key] + if exists { + filteredData[dbColName] = value } else { logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType) } @@ -235,72 +227,9 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode return filteredData } -// fieldWasMapped checks if a field with the given key was mapped to the model -func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool { - // Look for the field by JSON tag or field name - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - // Skip unexported fields - if !field.IsExported() { - continue - } - - // Check JSON tag - jsonTag := field.Tag.Get("json") - if jsonTag != "" && jsonTag != "-" { - parts := strings.Split(jsonTag, ",") - if len(parts) > 0 && parts[0] == key { - return true - } - } - - // Check bun tag - bunTag := field.Tag.Get("bun") - if bunTag != "" && bunTag != "-" { - if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key { - return true - } - } - - // Check gorm tag - gormTag := field.Tag.Get("gorm") - if gormTag != "" && gormTag != "-" { - if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key { - return true - } - } - - // Check lowercase field name - if strings.EqualFold(field.Name, key) { - return true - } - - // Handle embedded structs recursively - if field.Anonymous { - fieldType := field.Type - if fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - if fieldType.Kind() == reflect.Struct { - embeddedValue := modelValue.Field(i) - if embeddedValue.Kind() == reflect.Ptr { - if embeddedValue.IsNil() { - continue - } - embeddedValue = embeddedValue.Elem() - } - if fieldWasMapped(embeddedValue, fieldType, key) { - return true - } - } - } - } - - return false -} - -// injectForeignKeys injects parent IDs into data for foreign key fields +// injectForeignKeys injects parent IDs into data for foreign key fields. +// data is expected to be keyed by DB column names (as returned by filterValidFields). func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) { if len(parentIDs) == 0 { return @@ -319,10 +248,11 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode if strings.EqualFold(jsonName, parentKey+"_id") || strings.EqualFold(jsonName, parentKey+"id") || strings.EqualFold(field.Name, parentKey+"ID") { - // Only inject if not already present - if _, exists := data[jsonName]; !exists { - logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID) - data[jsonName] = parentID + // Use the DB column name as the key, since data is keyed by DB column names + dbColName := reflection.GetColumnName(field) + if _, exists := data[dbColName]; !exists { + logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID) + data[dbColName] = parentID } } } diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 350fd72..672aee0 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -196,6 +196,92 @@ func collectColumnsFromType(typ reflect.Type, columns *[]string) { } } +// GetColumnName extracts the database column name from a struct field. +// Priority: bun tag -> gorm tag -> json tag -> lowercase field name. +// This is the exported version for use by other packages. +func GetColumnName(field reflect.StructField) string { + return getColumnNameFromField(field) +} + +// BuildJSONToDBColumnMap returns a map from JSON key names to database column names +// for the given model type. Only writable, non-relation fields are included. +// This is used to translate incoming request data (keyed by JSON names) into +// properly named database columns before insert/update operations. +func BuildJSONToDBColumnMap(modelType reflect.Type) map[string]string { + result := make(map[string]string) + buildJSONToDBMap(modelType, result, false) + return result +} + +func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly bool) { + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + if !field.IsExported() { + continue + } + + bunTag := field.Tag.Get("bun") + gormTag := field.Tag.Get("gorm") + + // Handle embedded structs + if field.Anonymous { + ft := field.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + isScanOnly := scanOnly + if bunTag != "" && isBunFieldScanOnly(bunTag) { + isScanOnly = true + } + if ft.Kind() == reflect.Struct { + buildJSONToDBMap(ft, result, isScanOnly) + continue + } + } + + if scanOnly { + continue + } + + // Skip explicitly excluded fields + if bunTag == "-" || gormTag == "-" { + continue + } + + // Skip scan-only fields + if bunTag != "" && isBunFieldScanOnly(bunTag) { + continue + } + + // Skip bun relation fields + if bunTag != "" && (strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") || strings.Contains(bunTag, "m2m:")) { + continue + } + + // Skip gorm relation fields + if gormTag != "" && (strings.Contains(gormTag, "foreignKey:") || strings.Contains(gormTag, "references:") || strings.Contains(gormTag, "many2many:")) { + continue + } + + // Get JSON key (how the field appears in incoming request data) + jsonKey := "" + if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" { + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" { + jsonKey = parts[0] + } + } + if jsonKey == "" { + jsonKey = strings.ToLower(field.Name) + } + + // Get the actual DB column name (bun > gorm > json > field name) + dbColName := getColumnNameFromField(field) + + result[jsonKey] = dbColName + } +} + // getColumnNameFromField extracts the column name from a struct field // Priority: bun tag -> gorm tag -> json tag -> lowercase field name func getColumnNameFromField(field reflect.StructField) string { diff --git a/pkg/reflection/model_utils_test.go b/pkg/reflection/model_utils_test.go index 41b6529..0041a81 100644 --- a/pkg/reflection/model_utils_test.go +++ b/pkg/reflection/model_utils_test.go @@ -823,12 +823,12 @@ func TestToSnakeCase(t *testing.T) { { name: "UserID", input: "UserID", - expected: "user_i_d", + expected: "user_id", }, { name: "HTTPServer", input: "HTTPServer", - expected: "h_t_t_p_server", + expected: "http_server", }, { name: "lowercase", @@ -838,7 +838,7 @@ func TestToSnakeCase(t *testing.T) { { name: "UPPERCASE", input: "UPPERCASE", - expected: "u_p_p_e_r_c_a_s_e", + expected: "uppercase", }, { name: "Single",