feat(reflection): add ExtractTagValue and GetRelationshipInfo functions

* Implement ExtractTagValue to handle struct tag parsing.
* Introduce GetRelationshipInfo for extracting relationship metadata.
* Update tests to validate new functionality.
* Refactor related code for improved clarity and maintainability.
This commit is contained in:
Hein
2026-01-07 11:54:12 +02:00
parent e220ab3d34
commit bf7125efc3
8 changed files with 437 additions and 273 deletions

View File

@@ -1453,30 +1453,7 @@ func isNullable(field reflect.StructField) bool {
// GetRelationshipInfo implements common.RelationshipInfoProvider interface
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
info := h.getRelationshipInfo(modelType, relationName)
if info == nil {
return nil
}
// Convert internal type to common type
return &common.RelationshipInfo{
FieldName: info.fieldName,
JSONName: info.jsonName,
RelationType: info.relationType,
ForeignKey: info.foreignKey,
References: info.references,
JoinTable: info.joinTable,
RelatedModel: info.relatedModel,
}
}
type relationshipInfo struct {
fieldName string
jsonName string
relationType string // "belongsTo", "hasMany", "hasOne", "many2many"
foreignKey string
references string
joinTable string
relatedModel interface{}
return common.GetRelationshipInfo(modelType, relationName)
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
@@ -1496,7 +1473,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
for idx := range preloads {
preload := preloads[idx]
logger.Debug("Processing preload for relation: %s", preload.Relation)
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
relInfo := common.GetRelationshipInfo(modelType, preload.Relation)
if relInfo == nil {
logger.Warn("Relation %s not found in model", preload.Relation)
continue
@@ -1504,7 +1481,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// Use the field name (capitalized) for ORM preloading
// ORMs like GORM and Bun expect the struct field name, not the JSON name
relationFieldName := relInfo.fieldName
relationFieldName := relInfo.FieldName
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
@@ -1547,13 +1524,13 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
copy(columns, preload.Columns)
// Add foreign key if not already present
if relInfo.foreignKey != "" {
if relInfo.ForeignKey != "" {
// Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id)
foreignKeyColumn := toSnakeCase(relInfo.foreignKey)
foreignKeyColumn := toSnakeCase(relInfo.ForeignKey)
hasForeignKey := false
for _, col := range columns {
if col == foreignKeyColumn || col == relInfo.foreignKey {
if col == foreignKeyColumn || col == relInfo.ForeignKey {
hasForeignKey = true
break
}
@@ -1599,58 +1576,6 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
return query, nil
}
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
// Ensure we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
return nil
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
jsonTag := field.Tag.Get("json")
jsonName := strings.Split(jsonTag, ",")[0]
if jsonName == relationName {
gormTag := field.Tag.Get("gorm")
info := &relationshipInfo{
fieldName: field.Name,
jsonName: jsonName,
}
// Parse GORM tag to determine relationship type and keys
if strings.Contains(gormTag, "foreignKey") {
info.foreignKey = h.extractTagValue(gormTag, "foreignKey")
info.references = h.extractTagValue(gormTag, "references")
// Determine if it's belongsTo or hasMany/hasOne
if field.Type.Kind() == reflect.Slice {
info.relationType = "hasMany"
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
info.relationType = "belongsTo"
}
} else if strings.Contains(gormTag, "many2many") {
info.relationType = "many2many"
info.joinTable = h.extractTagValue(gormTag, "many2many")
}
return info
}
}
return nil
}
func (h *Handler) extractTagValue(tag, key string) string {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, key+":") {
return strings.TrimPrefix(part, key+":")
}
}
return ""
}
// toSnakeCase converts a PascalCase or camelCase string to snake_case
func toSnakeCase(s string) string {
var result strings.Builder