diff --git a/pkg/common/handler_utils.go b/pkg/common/handler_utils.go index 61716fb..6e1ee12 100644 --- a/pkg/common/handler_utils.go +++ b/pkg/common/handler_utils.go @@ -3,6 +3,9 @@ package common import ( "fmt" "reflect" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/logger" ) // ValidateAndUnwrapModelResult contains the result of model validation @@ -45,3 +48,216 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e OriginalType: originalType, }, nil } + +// ExtractTagValue extracts the value for a given key from a struct tag string. +// It handles both semicolon and comma-separated tag formats (e.g., GORM and BUN tags). +// For tags like "json:name;validate:required" it will extract "name" for key "json". +// For tags like "rel:has-many,join:table" it will extract "table" for key "join". +func ExtractTagValue(tag, key string) string { + // Split by both semicolons and commas to handle different tag formats + // We need to be smart about this - commas can be part of values + // So we'll try semicolon first, then comma if needed + separators := []string{";", ","} + + for _, sep := range separators { + parts := strings.Split(tag, sep) + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, key+":") { + return strings.TrimPrefix(part, key+":") + } + } + } + return "" +} + +// GetRelationshipInfo analyzes a model type and extracts relationship metadata +// for a specific relation field identified by its JSON name. +// Returns nil if the field is not found or is not a valid relationship. +func GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo { + // Ensure we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Warn("Cannot get relationship info from non-struct type: %v", modelType) + return nil + } + + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + jsonTag := field.Tag.Get("json") + jsonName := strings.Split(jsonTag, ",")[0] + + if jsonName == relationName { + gormTag := field.Tag.Get("gorm") + bunTag := field.Tag.Get("bun") + info := &RelationshipInfo{ + FieldName: field.Name, + JSONName: jsonName, + } + + if strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") { + //bun:"rel:has-many,join:rid_hub=rid_hub_division" + if strings.Contains(bunTag, "has-many") { + info.RelationType = "hasMany" + } else if strings.Contains(bunTag, "has-one") { + info.RelationType = "hasOne" + } else if strings.Contains(bunTag, "belongs-to") { + info.RelationType = "belongsTo" + } else if strings.Contains(bunTag, "many-to-many") { + info.RelationType = "many2many" + } else { + info.RelationType = "hasOne" + } + + // Extract join info + joinPart := ExtractTagValue(bunTag, "join") + if joinPart != "" && info.RelationType == "many2many" { + // For many2many, the join part is the join table name + info.JoinTable = joinPart + } else if joinPart != "" { + // For other relations, parse foreignKey and references + joinParts := strings.Split(joinPart, "=") + if len(joinParts) == 2 { + info.ForeignKey = joinParts[0] + info.References = joinParts[1] + } + } + + // Get related model type + if field.Type.Kind() == reflect.Slice { + elemType := field.Type.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { + elemType := field.Type + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } + + return info + } + + // Parse GORM tag to determine relationship type and keys + if strings.Contains(gormTag, "foreignKey") { + info.ForeignKey = ExtractTagValue(gormTag, "foreignKey") + info.References = ExtractTagValue(gormTag, "references") + + // Determine if it's belongsTo or hasMany/hasOne + if field.Type.Kind() == reflect.Slice { + info.RelationType = "hasMany" + // Get the element type for slice + elemType := field.Type.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { + info.RelationType = "belongsTo" + elemType := field.Type + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } + } else if strings.Contains(gormTag, "many2many") { + info.RelationType = "many2many" + info.JoinTable = ExtractTagValue(gormTag, "many2many") + // Get the element type for many2many (always slice) + if field.Type.Kind() == reflect.Slice { + elemType := field.Type.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } + } else { + // Field has no GORM relationship tags, so it's not a relation + return nil + } + + return info + } + } + return nil +} + +// RelationPathToBunAlias converts a relation path (e.g., "Order.Customer") to a Bun alias format. +// It converts to lowercase and replaces dots with double underscores. +// For example: "Order.Customer" -> "order__customer" +func RelationPathToBunAlias(relationPath string) string { + if relationPath == "" { + return "" + } + // Convert to lowercase and replace dots with double underscores + alias := strings.ToLower(relationPath) + alias = strings.ReplaceAll(alias, ".", "__") + return alias +} + +// ReplaceTableReferencesInSQL replaces references to a base table name in a SQL expression +// with the appropriate alias for the current preload level. +// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal", +// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem" +func ReplaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string { + if sqlExpr == "" || baseTableName == "" || targetAlias == "" { + return sqlExpr + } + + // Replace both quoted and unquoted table references + // Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column" + + // Pattern 1: tablename.column (unquoted) + result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".") + + // Pattern 2: "tablename".column or "tablename"."column" (quoted table name) + result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".") + + return result +} + +// GetTableNameFromModel extracts the table name from a model. +// It checks the bun tag first, then falls back to converting the struct name to snake_case. +func GetTableNameFromModel(model interface{}) string { + if model == nil { + return "" + } + + modelType := reflect.TypeOf(model) + + // Unwrap pointers + for modelType != nil && modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return "" + } + + // Look for bun tag on embedded BaseModel + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + if field.Anonymous { + bunTag := field.Tag.Get("bun") + if strings.HasPrefix(bunTag, "table:") { + return strings.TrimPrefix(bunTag, "table:") + } + } + } + + // Fallback: convert struct name to lowercase (simple heuristic) + // This handles cases like "MasterTaskItem" -> "mastertaskitem" + return strings.ToLower(modelType.Name()) +} diff --git a/pkg/common/handler_utils_test.go b/pkg/common/handler_utils_test.go new file mode 100644 index 0000000..05d374f --- /dev/null +++ b/pkg/common/handler_utils_test.go @@ -0,0 +1,108 @@ +package common + +import ( + "testing" +) + +func TestExtractTagValue(t *testing.T) { + tests := []struct { + name string + tag string + key string + expected string + }{ + { + name: "Extract existing key", + tag: "json:name;validate:required", + key: "json", + expected: "name", + }, + { + name: "Extract key with spaces", + tag: "json:name ; validate:required", + key: "validate", + expected: "required", + }, + { + name: "Extract key at end", + tag: "json:name;validate:required;db:column_name", + key: "db", + expected: "column_name", + }, + { + name: "Extract key at beginning", + tag: "primary:true;json:id;db:user_id", + key: "primary", + expected: "true", + }, + { + name: "Key not found", + tag: "json:name;validate:required", + key: "db", + expected: "", + }, + { + name: "Empty tag", + tag: "", + key: "json", + expected: "", + }, + { + name: "Single key-value pair", + tag: "json:name", + key: "json", + expected: "name", + }, + { + name: "Key with empty value", + tag: "json:;validate:required", + key: "json", + expected: "", + }, + { + name: "Key with complex value", + tag: "json:user_name,omitempty;validate:required,min=3", + key: "json", + expected: "user_name,omitempty", + }, + { + name: "Multiple semicolons", + tag: "json:name;;validate:required", + key: "validate", + expected: "required", + }, + { + name: "BUN Tag with comma separator", + tag: "rel:has-many,join:rid_hub=rid_hub_child", + key: "join", + expected: "rid_hub=rid_hub_child", + }, + { + name: "Extract foreignKey", + tag: "foreignKey:UserID;references:ID", + key: "foreignKey", + expected: "UserID", + }, + { + name: "Extract references", + tag: "foreignKey:UserID;references:ID", + key: "references", + expected: "ID", + }, + { + name: "Extract many2many", + tag: "many2many:user_roles", + key: "many2many", + expected: "user_roles", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractTagValue(tt.tag, tt.key) + if result != tt.expected { + t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected) + } + }) + } +} diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index f7f06a7..13b6f89 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -20,17 +20,6 @@ type RelationshipInfoProvider interface { GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo } -// RelationshipInfo contains information about a model relationship -type RelationshipInfo struct { - FieldName string - JSONName string - RelationType string // "belongsTo", "hasMany", "hasOne", "many2many" - ForeignKey string - References string - JoinTable string - RelatedModel interface{} -} - // NestedCUDProcessor handles recursive processing of nested object graphs type NestedCUDProcessor struct { db Database diff --git a/pkg/common/types.go b/pkg/common/types.go index d8e54b2..b09b3db 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -111,3 +111,14 @@ type TableMetadata struct { Columns []Column `json:"columns"` Relations []string `json:"relations"` } + +// RelationshipInfo contains information about a model relationship +type RelationshipInfo struct { + FieldName string `json:"field_name"` + JSONName string `json:"json_name"` + RelationType string `json:"relation_type"` // "belongsTo", "hasMany", "hasOne", "many2many" + ForeignKey string `json:"foreign_key"` + References string `json:"references"` + JoinTable string `json:"join_table"` + RelatedModel interface{} `json:"related_model"` +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 4a1aea8..bf082e7 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1453,30 +1453,7 @@ func isNullable(field reflect.StructField) bool { // GetRelationshipInfo implements common.RelationshipInfoProvider interface func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo { - info := h.getRelationshipInfo(modelType, relationName) - if info == nil { - return nil - } - // Convert internal type to common type - return &common.RelationshipInfo{ - FieldName: info.fieldName, - JSONName: info.jsonName, - RelationType: info.relationType, - ForeignKey: info.foreignKey, - References: info.references, - JoinTable: info.joinTable, - RelatedModel: info.relatedModel, - } -} - -type relationshipInfo struct { - fieldName string - jsonName string - relationType string // "belongsTo", "hasMany", "hasOne", "many2many" - foreignKey string - references string - joinTable string - relatedModel interface{} + return common.GetRelationshipInfo(modelType, relationName) } func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { @@ -1496,7 +1473,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre for idx := range preloads { preload := preloads[idx] logger.Debug("Processing preload for relation: %s", preload.Relation) - relInfo := h.getRelationshipInfo(modelType, preload.Relation) + relInfo := common.GetRelationshipInfo(modelType, preload.Relation) if relInfo == nil { logger.Warn("Relation %s not found in model", preload.Relation) continue @@ -1504,7 +1481,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // Use the field name (capitalized) for ORM preloading // ORMs like GORM and Bun expect the struct field name, not the JSON name - relationFieldName := relInfo.fieldName + relationFieldName := relInfo.FieldName // Validate and fix WHERE clause to ensure it contains the relation prefix if len(preload.Where) > 0 { @@ -1547,13 +1524,13 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre copy(columns, preload.Columns) // Add foreign key if not already present - if relInfo.foreignKey != "" { + if relInfo.ForeignKey != "" { // Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id) - foreignKeyColumn := toSnakeCase(relInfo.foreignKey) + foreignKeyColumn := toSnakeCase(relInfo.ForeignKey) hasForeignKey := false for _, col := range columns { - if col == foreignKeyColumn || col == relInfo.foreignKey { + if col == foreignKeyColumn || col == relInfo.ForeignKey { hasForeignKey = true break } @@ -1599,58 +1576,6 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre return query, nil } -func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { - // Ensure we have a struct type - if modelType == nil || modelType.Kind() != reflect.Struct { - logger.Warn("Cannot get relationship info from non-struct type: %v", modelType) - return nil - } - - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - jsonTag := field.Tag.Get("json") - jsonName := strings.Split(jsonTag, ",")[0] - - if jsonName == relationName { - gormTag := field.Tag.Get("gorm") - info := &relationshipInfo{ - fieldName: field.Name, - jsonName: jsonName, - } - - // Parse GORM tag to determine relationship type and keys - if strings.Contains(gormTag, "foreignKey") { - info.foreignKey = h.extractTagValue(gormTag, "foreignKey") - info.references = h.extractTagValue(gormTag, "references") - - // Determine if it's belongsTo or hasMany/hasOne - if field.Type.Kind() == reflect.Slice { - info.relationType = "hasMany" - } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { - info.relationType = "belongsTo" - } - } else if strings.Contains(gormTag, "many2many") { - info.relationType = "many2many" - info.joinTable = h.extractTagValue(gormTag, "many2many") - } - - return info - } - } - return nil -} - -func (h *Handler) extractTagValue(tag, key string) string { - parts := strings.Split(tag, ";") - for _, part := range parts { - part = strings.TrimSpace(part) - if strings.HasPrefix(part, key+":") { - return strings.TrimPrefix(part, key+":") - } - } - return "" -} - // toSnakeCase converts a PascalCase or camelCase string to snake_case func toSnakeCase(s string) string { var result strings.Builder diff --git a/pkg/resolvespec/handler_test.go b/pkg/resolvespec/handler_test.go index ac36b6f..d57e49d 100644 --- a/pkg/resolvespec/handler_test.go +++ b/pkg/resolvespec/handler_test.go @@ -269,8 +269,6 @@ func TestToSnakeCase(t *testing.T) { } func TestExtractTagValue(t *testing.T) { - handler := NewHandler(nil, nil) - tests := []struct { name string tag string @@ -311,9 +309,9 @@ func TestExtractTagValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := handler.extractTagValue(tt.tag, tt.key) + result := common.ExtractTagValue(tt.tag, tt.key) if result != tt.expected { - t.Errorf("extractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected) + t.Errorf("ExtractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected) } }) } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 3a47fad..2f20ce1 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -766,7 +766,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply ComputedQL fields if any if len(preload.ComputedQL) > 0 { // Get the base table name from the related model - baseTableName := getTableNameFromModel(relatedModel) + baseTableName := common.GetTableNameFromModel(relatedModel) // Convert the preload relation path to the appropriate alias format // This is ORM-specific. Currently we only support Bun's format. @@ -777,7 +777,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB()) if strings.Contains(underlyingType, "bun.DB") { // Use Bun's alias format: lowercase with double underscores - preloadAlias = relationPathToBunAlias(preload.Relation) + preloadAlias = common.RelationPathToBunAlias(preload.Relation) } // For GORM: GORM doesn't use the same alias format, and this fix // may not be needed since GORM handles preloads differently @@ -792,7 +792,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // levels of recursive/nested preloads adjustedExpr := colExpr if baseTableName != "" && preloadAlias != "" { - adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias) + adjustedExpr = common.ReplaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias) if adjustedExpr != colExpr { logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'", colName, colExpr, adjustedExpr) @@ -903,73 +903,6 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co return query } -// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def" -// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores -func relationPathToBunAlias(relationPath string) string { - if relationPath == "" { - return "" - } - // Convert to lowercase and replace dots with double underscores - alias := strings.ToLower(relationPath) - alias = strings.ReplaceAll(alias, ".", "__") - return alias -} - -// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression -// with the appropriate alias for the current preload level -// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal", -// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem" -func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string { - if sqlExpr == "" || baseTableName == "" || targetAlias == "" { - return sqlExpr - } - - // Replace both quoted and unquoted table references - // Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column" - - // Pattern 1: tablename.column (unquoted) - result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".") - - // Pattern 2: "tablename".column or "tablename"."column" (quoted table name) - result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".") - - return result -} - -// getTableNameFromModel extracts the table name from a model -// It checks the bun tag first, then falls back to converting the struct name to snake_case -func getTableNameFromModel(model interface{}) string { - if model == nil { - return "" - } - - modelType := reflect.TypeOf(model) - - // Unwrap pointers - for modelType != nil && modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - - if modelType == nil || modelType.Kind() != reflect.Struct { - return "" - } - - // Look for bun tag on embedded BaseModel - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - if field.Anonymous { - bunTag := field.Tag.Get("bun") - if strings.HasPrefix(bunTag, "table:") { - return strings.TrimPrefix(bunTag, "table:") - } - } - } - - // Fallback: convert struct name to lowercase (simple heuristic) - // This handles cases like "MasterTaskItem" -> "mastertaskitem" - return strings.ToLower(modelType.Name()) -} - func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) { // Capture panics and return error response defer func() { @@ -2570,10 +2503,10 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio filteredExpand := expand // Get the relationship info for this expand relation - relInfo := h.getRelationshipInfo(modelType, expand.Relation) - if relInfo != nil && relInfo.relatedModel != nil { + relInfo := common.GetRelationshipInfo(modelType, expand.Relation) + if relInfo != nil && relInfo.RelatedModel != nil { // Create a validator for the related model - expandValidator := common.NewColumnValidator(relInfo.relatedModel) + expandValidator := common.NewColumnValidator(relInfo.RelatedModel) // Filter columns using the related model's validator filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns) @@ -2650,110 +2583,7 @@ func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model in // GetRelationshipInfo implements common.RelationshipInfoProvider interface func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo { - info := h.getRelationshipInfo(modelType, relationName) - if info == nil { - return nil - } - // Convert internal type to common type - return &common.RelationshipInfo{ - FieldName: info.fieldName, - JSONName: info.jsonName, - RelationType: info.relationType, - ForeignKey: info.foreignKey, - References: info.references, - JoinTable: info.joinTable, - RelatedModel: info.relatedModel, - } -} - -type relationshipInfo struct { - fieldName string - jsonName string - relationType string // "belongsTo", "hasMany", "hasOne", "many2many" - foreignKey string - references string - joinTable string - relatedModel interface{} -} - -func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { - // Ensure we have a struct type - if modelType == nil || modelType.Kind() != reflect.Struct { - logger.Warn("Cannot get relationship info from non-struct type: %v", modelType) - return nil - } - - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - jsonTag := field.Tag.Get("json") - jsonName := strings.Split(jsonTag, ",")[0] - - if jsonName == relationName { - gormTag := field.Tag.Get("gorm") - info := &relationshipInfo{ - fieldName: field.Name, - jsonName: jsonName, - } - - // Parse GORM tag to determine relationship type and keys - if strings.Contains(gormTag, "foreignKey") { - info.foreignKey = h.extractTagValue(gormTag, "foreignKey") - info.references = h.extractTagValue(gormTag, "references") - - // Determine if it's belongsTo or hasMany/hasOne - if field.Type.Kind() == reflect.Slice { - info.relationType = "hasMany" - // Get the element type for slice - elemType := field.Type.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Struct { - info.relatedModel = reflect.New(elemType).Elem().Interface() - } - } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { - info.relationType = "belongsTo" - elemType := field.Type - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Struct { - info.relatedModel = reflect.New(elemType).Elem().Interface() - } - } - } else if strings.Contains(gormTag, "many2many") { - info.relationType = "many2many" - info.joinTable = h.extractTagValue(gormTag, "many2many") - // Get the element type for many2many (always slice) - if field.Type.Kind() == reflect.Slice { - elemType := field.Type.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Struct { - info.relatedModel = reflect.New(elemType).Elem().Interface() - } - } - } else { - // Field has no GORM relationship tags, so it's not a relation - return nil - } - - return info - } - } - return nil -} - -func (h *Handler) extractTagValue(tag, key string) string { - parts := strings.Split(tag, ";") - for _, part := range parts { - part = strings.TrimSpace(part) - if strings.HasPrefix(part, key+":") { - return strings.TrimPrefix(part, key+":") - } - } - return "" + return common.GetRelationshipInfo(modelType, relationName) } // HandleOpenAPI generates and returns the OpenAPI specification diff --git a/pkg/restheadspec/restheadspec_test.go b/pkg/restheadspec/restheadspec_test.go index 355938b..53a1b23 100644 --- a/pkg/restheadspec/restheadspec_test.go +++ b/pkg/restheadspec/restheadspec_test.go @@ -2,6 +2,8 @@ package restheadspec import ( "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" ) func TestParseModelName(t *testing.T) { @@ -112,3 +114,88 @@ func TestNewStandardBunRouter(t *testing.T) { t.Error("Expected router to be created, got nil") } } + +func TestExtractTagValue(t *testing.T) { + tests := []struct { + name string + tag string + key string + expected string + }{ + { + name: "Extract existing key", + tag: "json:name;validate:required", + key: "json", + expected: "name", + }, + { + name: "Extract key with spaces", + tag: "json:name ; validate:required", + key: "validate", + expected: "required", + }, + { + name: "Extract key at end", + tag: "json:name;validate:required;db:column_name", + key: "db", + expected: "column_name", + }, + { + name: "Extract key at beginning", + tag: "primary:true;json:id;db:user_id", + key: "primary", + expected: "true", + }, + { + name: "Key not found", + tag: "json:name;validate:required", + key: "db", + expected: "", + }, + { + name: "Empty tag", + tag: "", + key: "json", + expected: "", + }, + { + name: "Single key-value pair", + tag: "json:name", + key: "json", + expected: "name", + }, + { + name: "Key with empty value", + tag: "json:;validate:required", + key: "json", + expected: "", + }, + { + name: "Key with complex value", + tag: "json:user_name,omitempty;validate:required,min=3", + key: "json", + expected: "user_name,omitempty", + }, + { + name: "Multiple semicolons", + tag: "json:name;;validate:required", + key: "validate", + expected: "required", + }, + { + name: "BUN Tag", + tag: "rel:has-many,join:rid_hub=rid_hub_child", + key: "join", + expected: "rid_hub=rid_hub_child", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := common.ExtractTagValue(tt.tag, tt.key) + if result != tt.expected { + t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected) + } + }) + } +}