ResolveSpec/pkg/reflection/model_utils.go

443 lines
12 KiB
Go

package reflection
import (
"reflect"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
type PrimaryKeyNameProvider interface {
GetIDName() string
}
// GetPrimaryKeyName extracts the primary key column name from a model
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
func GetPrimaryKeyName(model any) string {
if reflect.TypeOf(model) == nil {
return ""
}
// If we are given a string model name, look up the model
if reflect.TypeOf(model).Kind() == reflect.String {
name := model.(string)
m, err := modelregistry.GetModelByName(name)
if err == nil {
model = m
}
}
// Check if model implements PrimaryKeyNameProvider
if provider, ok := model.(PrimaryKeyNameProvider); ok {
return provider.GetIDName()
}
// Try Bun tag first
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
return pkName
}
// Fall back to GORM tag
if pkName := getPrimaryKeyFromReflection(model, "gorm"); pkName != "" {
return pkName
}
return ""
}
// GetPrimaryKeyValue extracts the primary key value from a model instance
// Returns the value of the primary key field
func GetPrimaryKeyValue(model any) any {
if model == nil || reflect.TypeOf(model) == nil {
return nil
}
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil
}
// Try Bun tag first
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
return pkValue
}
// Fall back to GORM tag
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
return pkValue
}
// Last resort: look for field named "ID" or "Id"
if pkValue := findFieldByName(val, "id"); pkValue != nil {
return pkValue
}
return nil
}
// findPrimaryKeyValue recursively searches for a primary key field in the struct
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
return pkValue
}
continue
}
// Check for primary key tag
switch ormType {
case "bun":
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
return fieldValue.Interface()
}
case "gorm":
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
}
return nil
}
// findFieldByName recursively searches for a field by name in the struct
func findFieldByName(val reflect.Value, name string) any {
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// Check if this is an embedded struct
if field.Anonymous && field.Type.Kind() == reflect.Struct {
// Recursively search in embedded struct
if result := findFieldByName(fieldValue, name); result != nil {
return result
}
continue
}
// Check if field name matches
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
return nil
}
// GetModelColumns extracts all column names from a model using reflection
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
// This function recursively processes embedded structs to include their fields
func GetModelColumns(model any) []string {
var columns []string
modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return columns
}
collectColumnsFromType(modelType, &columns)
return columns
}
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively process embedded struct
if fieldType.Kind() == reflect.Struct {
collectColumnsFromType(fieldType, columns)
continue
}
}
// Get column name using the same logic as primary key extraction
columnName := getColumnNameFromField(field)
if columnName != "" {
*columns = append(*columns, columnName)
}
}
}
// getColumnNameFromField extracts the column name from a struct field
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
func getColumnNameFromField(field reflect.StructField) string {
// Try bun tag first
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
return colName
}
}
// Try gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" && gormTag != "-" {
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
return colName
}
}
// Fall back to json tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract just the field name before any options
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
}
// Last resort: use field name in lowercase
return strings.ToLower(field.Name)
}
// getPrimaryKeyFromReflection uses reflection to find the primary key field
// This function recursively searches embedded structs
func getPrimaryKeyFromReflection(model any, ormType string) string {
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return ""
}
typ := val.Type()
return findPrimaryKeyNameFromType(typ, ormType)
}
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
return pkName
}
}
continue
}
switch ormType {
case "gorm":
// Check for gorm tag with primaryKey
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") {
// Try to extract column name from gorm tag
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
case "bun":
// Check for bun tag with pk flag
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") {
// Extract column name from bun tag
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
}
}
return ""
}
// ExtractColumnFromGormTag extracts the column name from a gorm tag
// Example: "column:id;primaryKey" -> "id"
func ExtractColumnFromGormTag(tag string) string {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if colName, found := strings.CutPrefix(part, "column:"); found {
return colName
}
}
return ""
}
// ExtractColumnFromBunTag extracts the column name from a bun tag
// Example: "id,pk" -> "id"
// Example: ",pk" -> "" (will fall back to json tag)
func ExtractColumnFromBunTag(tag string) string {
parts := strings.Split(tag, ",")
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
return ""
}
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}
// IsColumnWritable checks if a column can be written to in the database
// For bun: returns false if the field has "scanonly" tag
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
// This function recursively searches embedded structs
func IsColumnWritable(model any, columnName string) bool {
modelType := reflect.TypeOf(model)
// Unwrap pointers to get to the base struct type
for modelType != nil && modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return false
}
found, writable := isColumnWritableInType(modelType, columnName)
if found {
return writable
}
// Column not found in model, allow it (might be a dynamic column)
return true
}
// isColumnWritableInType recursively searches for a column and checks if it's writable
// Returns (found, writable) where found indicates if the column was found
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Recursively search in embedded struct
if fieldType.Kind() == reflect.Struct {
if found, writable := isColumnWritableInType(fieldType, columnName); found {
return true, writable
}
}
continue
}
// Check if this field matches the column name
fieldColumnName := getColumnNameFromField(field)
if fieldColumnName != columnName {
continue
}
// Found the field, now check if it's writable
// Check bun tag for scanonly
bunTag := field.Tag.Get("bun")
if bunTag != "" {
if isBunFieldScanOnly(bunTag) {
return true, false
}
}
// Check gorm tag for write restrictions
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
if isGormFieldReadOnly(gormTag) {
return true, false
}
}
// Column is writable
return true, true
}
// Column not found
return false, false
}
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
// Example: "column_name,scanonly" -> true
func isBunFieldScanOnly(tag string) bool {
parts := strings.Split(tag, ",")
for _, part := range parts {
if strings.TrimSpace(part) == "scanonly" {
return true
}
}
return false
}
// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only
// Examples:
// - "<-:false" -> true (no writes allowed)
// - "->" -> true (read-only, common pattern)
// - "column:name;->" -> true
// - "<-:create" -> false (writes allowed on create)
func isGormFieldReadOnly(tag string) bool {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
// Check for read-only marker
if part == "->" {
return true
}
// Check for write restrictions
if value, found := strings.CutPrefix(part, "<-:"); found {
if value == "false" {
return true
}
}
}
return false
}