Compare commits

..

4 Commits

Author SHA1 Message Date
Hein
932f12ab0a Update handler fixes for Utils bug
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
2025-12-12 17:01:37 +02:00
Hein
b22792bad6 Optional check for bun 2025-12-12 14:49:52 +02:00
Hein
e8111c01aa Fixed for relation preloading 2025-12-12 11:45:04 +02:00
Hein
5862016031 Added ModelRules
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-12 10:13:11 +02:00
7 changed files with 1096 additions and 11 deletions

View File

@ -2,6 +2,7 @@ package common
import (
"fmt"
"regexp"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@ -207,6 +208,20 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
}
}
}
} else if tableName != "" && !hasTablePrefix(condToCheck) {
// If tableName is provided and the condition DOESN'T have a table prefix,
// qualify unambiguous column references to prevent "ambiguous column" errors
// when there are multiple joins on the same table (e.g., recursive preloads)
columnName := extractUnqualifiedColumnName(condToCheck)
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
// Qualify the column with the table name
// Be careful to only replace the column name, not other occurrences of the string
oldRef := columnName
newRef := tableName + "." + columnName
// Use word boundary matching to avoid replacing partial matches
cond = qualifyColumnInCondition(cond, oldRef, newRef)
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
}
}
validConditions = append(validConditions, cond)
@ -483,6 +498,86 @@ func extractTableAndColumn(cond string) (table string, column string) {
return "", ""
}
// extractUnqualifiedColumnName extracts the column name from an unqualified condition
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
// "status = 'active'" returns "status"
func extractUnqualifiedColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
// Find the column reference (left side of the operator)
minIdx := -1
for _, op := range operators {
idx := strings.Index(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
var columnRef string
if minIdx > 0 {
columnRef = strings.TrimSpace(cond[:minIdx])
} else {
// No operator found, might be a single column reference
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef = parts[0]
}
}
if columnRef == "" {
return ""
}
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Return empty if it contains a dot (already qualified) or function call
if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") {
return ""
}
return columnRef
}
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
// Uses word boundaries to avoid partial matches
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
// returns "table.rid_item is null"
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
// Use word boundary matching with Go's supported regex syntax
// \b matches word boundaries
escapedOld := regexp.QuoteMeta(oldRef)
pattern := `\b` + escapedOld + `\b`
re, err := regexp.Compile(pattern)
if err != nil {
// If regex fails, fall back to simple string replacement
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
return strings.Replace(cond, oldRef, newRef, 1)
}
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
result := cond
matches := re.FindAllStringIndex(cond, -1)
// Process matches in reverse order to maintain correct indices
for i := len(matches) - 1; i >= 0; i-- {
match := matches[i]
start := match[0]
// Check if preceded by a dot (already qualified)
if start > 0 && cond[start-1] == '.' {
continue
}
// Replace this occurrence
result = result[:start] + newRef + result[match[1]:]
}
return result
}
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
// Returns the index of the operator, or -1 if not found or only found inside parentheses
func findOperatorOutsideParentheses(s string, operator string) int {

View File

@ -33,16 +33,16 @@ func TestSanitizeWhereClause(t *testing.T) {
expected: "",
},
{
name: "valid condition with parentheses - no prefix added",
name: "valid condition with parentheses - prefix added to prevent ambiguity",
where: "(status = 'active')",
tableName: "users",
expected: "status = 'active'",
expected: "users.status = 'active'",
},
{
name: "mixed trivial and valid conditions - no prefix added",
name: "mixed trivial and valid conditions - prefix added",
where: "true AND status = 'active' AND 1=1",
tableName: "users",
expected: "status = 'active'",
expected: "users.status = 'active'",
},
{
name: "condition with correct table prefix - unchanged",
@ -63,10 +63,10 @@ func TestSanitizeWhereClause(t *testing.T) {
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "multiple valid conditions without prefix - no prefix added",
name: "multiple valid conditions without prefix - prefixes added",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "status = 'active' AND age > 18",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "no table name provided",
@ -90,13 +90,13 @@ func TestSanitizeWhereClause(t *testing.T) {
name: "mixed case AND operators",
where: "status = 'active' AND age > 18 and name = 'John'",
tableName: "users",
expected: "status = 'active' AND age > 18 AND name = 'John'",
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
},
{
name: "subquery with ORDER BY and LIMIT - allowed",
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
tableName: "users",
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
},
{
name: "dangerous DELETE keyword - blocked",

View File

@ -6,15 +6,37 @@ import (
"sync"
)
// ModelRules defines the permissions and security settings for a model
type ModelRules struct {
CanRead bool // Whether the model can be read (GET operations)
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
CanCreate bool // Whether the model can be created (POST operations)
CanDelete bool // Whether the model can be deleted (DELETE operations)
SecurityDisabled bool // Whether security checks are disabled for this model
}
// DefaultModelRules returns the default rules for a model (all operations allowed, security enabled)
func DefaultModelRules() ModelRules {
return ModelRules{
CanRead: true,
CanUpdate: true,
CanCreate: true,
CanDelete: true,
SecurityDisabled: false,
}
}
// DefaultModelRegistry implements ModelRegistry interface
type DefaultModelRegistry struct {
models map[string]interface{}
rules map[string]ModelRules
mutex sync.RWMutex
}
// Global default registry instance
var defaultRegistry = &DefaultModelRegistry{
models: make(map[string]interface{}),
rules: make(map[string]ModelRules),
}
// Global list of registries (searched in order)
@ -25,6 +47,7 @@ var registriesMutex sync.RWMutex
func NewModelRegistry() *DefaultModelRegistry {
return &DefaultModelRegistry{
models: make(map[string]interface{}),
rules: make(map[string]ModelRules),
}
}
@ -98,6 +121,10 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
}
r.models[name] = model
// Initialize with default rules if not already set
if _, exists := r.rules[name]; !exists {
r.rules[name] = DefaultModelRules()
}
return nil
}
@ -135,6 +162,54 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac
return r.GetModel(entity)
}
// SetModelRules sets the rules for a specific model
func (r *DefaultModelRegistry) SetModelRules(name string, rules ModelRules) error {
r.mutex.Lock()
defer r.mutex.Unlock()
// Check if model exists
if _, exists := r.models[name]; !exists {
return fmt.Errorf("model %s not found", name)
}
r.rules[name] = rules
return nil
}
// GetModelRules retrieves the rules for a specific model
// Returns default rules if model exists but rules are not set
func (r *DefaultModelRegistry) GetModelRules(name string) (ModelRules, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Check if model exists
if _, exists := r.models[name]; !exists {
return ModelRules{}, fmt.Errorf("model %s not found", name)
}
// Return rules if set, otherwise return default rules
if rules, exists := r.rules[name]; exists {
return rules, nil
}
return DefaultModelRules(), nil
}
// RegisterModelWithRules registers a model with specific rules
func (r *DefaultModelRegistry) RegisterModelWithRules(name string, model interface{}, rules ModelRules) error {
// First register the model
if err := r.RegisterModel(name, model); err != nil {
return err
}
// Then set the rules (we need to lock again for rules)
r.mutex.Lock()
defer r.mutex.Unlock()
r.rules[name] = rules
return nil
}
// Global convenience functions using the default registry
// RegisterModel registers a model with the default global registry
@ -190,3 +265,34 @@ func GetModels() []interface{} {
return models
}
// SetModelRules sets the rules for a specific model in the default registry
func SetModelRules(name string, rules ModelRules) error {
return defaultRegistry.SetModelRules(name, rules)
}
// GetModelRules retrieves the rules for a specific model from the default registry
func GetModelRules(name string) (ModelRules, error) {
return defaultRegistry.GetModelRules(name)
}
// GetModelRulesByName retrieves the rules for a model by searching through all registries in order
// Returns the first match found
func GetModelRulesByName(name string) (ModelRules, error) {
registriesMutex.RLock()
defer registriesMutex.RUnlock()
for _, registry := range registries {
if _, err := registry.GetModel(name); err == nil {
// Model found in this registry, get its rules
return registry.GetModelRules(name)
}
}
return ModelRules{}, fmt.Errorf("model %s not found in any registry", name)
}
// RegisterModelWithRules registers a model with specific rules in the default registry
func RegisterModelWithRules(model interface{}, name string, rules ModelRules) error {
return defaultRegistry.RegisterModelWithRules(name, model, rules)
}

View File

@ -1,6 +1,7 @@
package reflection
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
@ -897,6 +898,319 @@ func GetRelationModel(model interface{}, fieldName string) interface{} {
return currentModel
}
// MapToStruct populates a struct from a map while preserving custom types
// It uses reflection to set struct fields based on map keys, matching by:
// 1. Bun tag column name
// 2. Gorm tag column name
// 3. JSON tag name
// 4. Field name (case-insensitive)
// This preserves custom types that implement driver.Valuer like SqlJSONB
func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
if dataMap == nil || target == nil {
return fmt.Errorf("dataMap and target cannot be nil")
}
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
return fmt.Errorf("target must be a pointer to a struct")
}
targetValue = targetValue.Elem()
if targetValue.Kind() != reflect.Struct {
return fmt.Errorf("target must be a pointer to a struct")
}
targetType := targetValue.Type()
// Create a map of column names to field indices for faster lookup
columnToField := make(map[string]int)
for i := 0; i < targetType.NumField(); i++ {
field := targetType.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Build list of possible column names for this field
var columnNames []string
// 1. Bun tag
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
columnNames = append(columnNames, colName)
}
}
// 2. Gorm tag
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
columnNames = append(columnNames, colName)
}
}
// 3. JSON tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
columnNames = append(columnNames, parts[0])
}
}
// 4. Field name variations
columnNames = append(columnNames, field.Name)
columnNames = append(columnNames, strings.ToLower(field.Name))
columnNames = append(columnNames, ToSnakeCase(field.Name))
// Map all column name variations to this field index
for _, colName := range columnNames {
columnToField[strings.ToLower(colName)] = i
}
}
// Iterate through the map and set struct fields
for key, value := range dataMap {
// Find the field index for this key
fieldIndex, found := columnToField[strings.ToLower(key)]
if !found {
// Skip keys that don't map to any field
continue
}
field := targetValue.Field(fieldIndex)
if !field.CanSet() {
continue
}
// Set the value, preserving custom types
if err := setFieldValue(field, value); err != nil {
return fmt.Errorf("failed to set field %s: %w", targetType.Field(fieldIndex).Name, err)
}
}
return nil
}
// setFieldValue sets a reflect.Value from an interface{} value, handling type conversions
func setFieldValue(field reflect.Value, value interface{}) error {
if value == nil {
// Set zero value for nil
field.Set(reflect.Zero(field.Type()))
return nil
}
valueReflect := reflect.ValueOf(value)
// If types match exactly, just set it
if valueReflect.Type().AssignableTo(field.Type()) {
field.Set(valueReflect)
return nil
}
// Handle pointer fields
if field.Kind() == reflect.Ptr {
if valueReflect.Kind() != reflect.Ptr {
// Create a new pointer and set its value
newPtr := reflect.New(field.Type().Elem())
if err := setFieldValue(newPtr.Elem(), value); err != nil {
return err
}
field.Set(newPtr)
return nil
}
}
// Handle conversions for basic types
switch field.Kind() {
case reflect.String:
if str, ok := value.(string); ok {
field.SetString(str)
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if num, ok := convertToInt64(value); ok {
if field.OverflowInt(num) {
return fmt.Errorf("integer overflow")
}
field.SetInt(num)
return nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if num, ok := convertToUint64(value); ok {
if field.OverflowUint(num) {
return fmt.Errorf("unsigned integer overflow")
}
field.SetUint(num)
return nil
}
case reflect.Float32, reflect.Float64:
if num, ok := convertToFloat64(value); ok {
if field.OverflowFloat(num) {
return fmt.Errorf("float overflow")
}
field.SetFloat(num)
return nil
}
case reflect.Bool:
if b, ok := value.(bool); ok {
field.SetBool(b)
return nil
}
case reflect.Slice:
// Handle []byte specially (for types like SqlJSONB)
if field.Type().Elem().Kind() == reflect.Uint8 {
switch v := value.(type) {
case []byte:
field.SetBytes(v)
return nil
case string:
field.SetBytes([]byte(v))
return nil
case map[string]interface{}, []interface{}:
// Marshal complex types to JSON for SqlJSONB fields
jsonBytes, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("failed to marshal value to JSON: %w", err)
}
field.SetBytes(jsonBytes)
return nil
}
}
}
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
if field.Kind() == reflect.Struct {
// Try to find a "Val" field (for SqlNull types) and set it
valField := field.FieldByName("Val")
if valField.IsValid() && valField.CanSet() {
// Also set Valid field to true
validField := field.FieldByName("Valid")
if validField.IsValid() && validField.CanSet() && validField.Kind() == reflect.Bool {
// Set the Val field
if err := setFieldValue(valField, value); err != nil {
return err
}
// Set Valid to true
validField.SetBool(true)
return nil
}
}
}
// If we can convert the type, do it
if valueReflect.Type().ConvertibleTo(field.Type()) {
field.Set(valueReflect.Convert(field.Type()))
return nil
}
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
}
// convertToInt64 attempts to convert various types to int64
func convertToInt64(value interface{}) (int64, bool) {
switch v := value.(type) {
case int:
return int64(v), true
case int8:
return int64(v), true
case int16:
return int64(v), true
case int32:
return int64(v), true
case int64:
return v, true
case uint:
return int64(v), true
case uint8:
return int64(v), true
case uint16:
return int64(v), true
case uint32:
return int64(v), true
case uint64:
return int64(v), true
case float32:
return int64(v), true
case float64:
return int64(v), true
case string:
if num, err := strconv.ParseInt(v, 10, 64); err == nil {
return num, true
}
}
return 0, false
}
// convertToUint64 attempts to convert various types to uint64
func convertToUint64(value interface{}) (uint64, bool) {
switch v := value.(type) {
case int:
return uint64(v), true
case int8:
return uint64(v), true
case int16:
return uint64(v), true
case int32:
return uint64(v), true
case int64:
return uint64(v), true
case uint:
return uint64(v), true
case uint8:
return uint64(v), true
case uint16:
return uint64(v), true
case uint32:
return uint64(v), true
case uint64:
return v, true
case float32:
return uint64(v), true
case float64:
return uint64(v), true
case string:
if num, err := strconv.ParseUint(v, 10, 64); err == nil {
return num, true
}
}
return 0, false
}
// convertToFloat64 attempts to convert various types to float64
func convertToFloat64(value interface{}) (float64, bool) {
switch v := value.(type) {
case int:
return float64(v), true
case int8:
return float64(v), true
case int16:
return float64(v), true
case int32:
return float64(v), true
case int64:
return float64(v), true
case uint:
return float64(v), true
case uint8:
return float64(v), true
case uint16:
return float64(v), true
case uint32:
return float64(v), true
case uint64:
return float64(v), true
case float32:
return float64(v), true
case float64:
return v, true
case string:
if num, err := strconv.ParseFloat(v, 64); err == nil {
return num, true
}
}
return 0, false
}
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
// This is a helper function used by GetRelationModel to handle one level at a time
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {

View File

@ -0,0 +1,266 @@
package reflection_test
import (
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func TestMapToStruct_SqlJSONB_PreservesDriverValuer(t *testing.T) {
// Test that SqlJSONB type preserves driver.Valuer interface
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
Meta common.SqlJSONB `bun:"meta" json:"meta"`
}
dataMap := map[string]interface{}{
"id": int64(123),
"meta": map[string]interface{}{
"key": "value",
"num": 42,
},
}
var result TestModel
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
// Verify the field was set
if result.ID != 123 {
t.Errorf("ID = %v, want 123", result.ID)
}
// Verify SqlJSONB was populated
if len(result.Meta) == 0 {
t.Error("Meta is empty, want non-empty")
}
// Most importantly: verify driver.Valuer interface works
value, err := result.Meta.Value()
if err != nil {
t.Errorf("Meta.Value() error = %v, want nil", err)
}
// Value should return a string representation of the JSON
if value == nil {
t.Error("Meta.Value() returned nil, want non-nil")
}
// Check it's a valid JSON string
if str, ok := value.(string); ok {
if len(str) == 0 {
t.Error("Meta.Value() returned empty string, want valid JSON")
}
t.Logf("SqlJSONB.Value() returned: %s", str)
} else {
t.Errorf("Meta.Value() returned type %T, want string", value)
}
}
func TestMapToStruct_SqlJSONB_FromBytes(t *testing.T) {
// Test that SqlJSONB can be set from []byte directly
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
Meta common.SqlJSONB `bun:"meta" json:"meta"`
}
jsonBytes := []byte(`{"direct":"bytes"}`)
dataMap := map[string]interface{}{
"id": int64(456),
"meta": jsonBytes,
}
var result TestModel
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
if result.ID != 456 {
t.Errorf("ID = %v, want 456", result.ID)
}
if string(result.Meta) != string(jsonBytes) {
t.Errorf("Meta = %s, want %s", string(result.Meta), string(jsonBytes))
}
// Verify driver.Valuer works
value, err := result.Meta.Value()
if err != nil {
t.Errorf("Meta.Value() error = %v", err)
}
if value == nil {
t.Error("Meta.Value() returned nil")
}
}
func TestMapToStruct_AllSqlTypes(t *testing.T) {
// Test model with all SQL custom types
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
Name string `bun:"name" json:"name"`
CreatedAt common.SqlTimeStamp `bun:"created_at" json:"created_at"`
BirthDate common.SqlDate `bun:"birth_date" json:"birth_date"`
LoginTime common.SqlTime `bun:"login_time" json:"login_time"`
Meta common.SqlJSONB `bun:"meta" json:"meta"`
Tags common.SqlJSONB `bun:"tags" json:"tags"`
}
now := time.Now()
birthDate := time.Date(1990, 1, 15, 0, 0, 0, 0, time.UTC)
loginTime := time.Date(0, 1, 1, 14, 30, 0, 0, time.UTC)
dataMap := map[string]interface{}{
"id": int64(100),
"name": "Test User",
"created_at": now,
"birth_date": birthDate,
"login_time": loginTime,
"meta": map[string]interface{}{
"role": "admin",
"active": true,
},
"tags": []interface{}{"golang", "testing", "sql"},
}
var result TestModel
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
// Verify basic fields
if result.ID != 100 {
t.Errorf("ID = %v, want 100", result.ID)
}
if result.Name != "Test User" {
t.Errorf("Name = %v, want 'Test User'", result.Name)
}
// Verify SqlTimeStamp
if !result.CreatedAt.Valid {
t.Error("CreatedAt.Valid = false, want true")
}
if !result.CreatedAt.Val.Equal(now) {
t.Errorf("CreatedAt.Val = %v, want %v", result.CreatedAt.Val, now)
}
// Verify driver.Valuer for SqlTimeStamp
tsValue, err := result.CreatedAt.Value()
if err != nil {
t.Errorf("CreatedAt.Value() error = %v", err)
}
if tsValue == nil {
t.Error("CreatedAt.Value() returned nil")
}
// Verify SqlDate
if !result.BirthDate.Valid {
t.Error("BirthDate.Valid = false, want true")
}
if !result.BirthDate.Val.Equal(birthDate) {
t.Errorf("BirthDate.Val = %v, want %v", result.BirthDate.Val, birthDate)
}
// Verify driver.Valuer for SqlDate
dateValue, err := result.BirthDate.Value()
if err != nil {
t.Errorf("BirthDate.Value() error = %v", err)
}
if dateValue == nil {
t.Error("BirthDate.Value() returned nil")
}
// Verify SqlTime
if !result.LoginTime.Valid {
t.Error("LoginTime.Valid = false, want true")
}
// Verify driver.Valuer for SqlTime
timeValue, err := result.LoginTime.Value()
if err != nil {
t.Errorf("LoginTime.Value() error = %v", err)
}
if timeValue == nil {
t.Error("LoginTime.Value() returned nil")
}
// Verify SqlJSONB for Meta
if len(result.Meta) == 0 {
t.Error("Meta is empty")
}
metaValue, err := result.Meta.Value()
if err != nil {
t.Errorf("Meta.Value() error = %v", err)
}
if metaValue == nil {
t.Error("Meta.Value() returned nil")
}
// Verify SqlJSONB for Tags
if len(result.Tags) == 0 {
t.Error("Tags is empty")
}
tagsValue, err := result.Tags.Value()
if err != nil {
t.Errorf("Tags.Value() error = %v", err)
}
if tagsValue == nil {
t.Error("Tags.Value() returned nil")
}
t.Logf("All SQL types successfully preserved driver.Valuer interface:")
t.Logf(" - SqlTimeStamp: %v", tsValue)
t.Logf(" - SqlDate: %v", dateValue)
t.Logf(" - SqlTime: %v", timeValue)
t.Logf(" - SqlJSONB (Meta): %v", metaValue)
t.Logf(" - SqlJSONB (Tags): %v", tagsValue)
}
func TestMapToStruct_SqlNull_NilValues(t *testing.T) {
// Test that SqlNull types handle nil values correctly
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
UpdatedAt common.SqlTimeStamp `bun:"updated_at" json:"updated_at"`
DeletedAt common.SqlTimeStamp `bun:"deleted_at" json:"deleted_at"`
}
now := time.Now()
dataMap := map[string]interface{}{
"id": int64(200),
"updated_at": now,
"deleted_at": nil, // Explicitly nil
}
var result TestModel
err := reflection.MapToStruct(dataMap, &result)
if err != nil {
t.Fatalf("MapToStruct() error = %v", err)
}
// UpdatedAt should be valid
if !result.UpdatedAt.Valid {
t.Error("UpdatedAt.Valid = false, want true")
}
if !result.UpdatedAt.Val.Equal(now) {
t.Errorf("UpdatedAt.Val = %v, want %v", result.UpdatedAt.Val, now)
}
// DeletedAt should be invalid (null)
if result.DeletedAt.Valid {
t.Error("DeletedAt.Valid = true, want false (null)")
}
// Verify driver.Valuer for null SqlTimeStamp
deletedValue, err := result.DeletedAt.Value()
if err != nil {
t.Errorf("DeletedAt.Value() error = %v", err)
}
if deletedValue != nil {
t.Errorf("DeletedAt.Value() = %v, want nil", deletedValue)
}
}

View File

@ -1687,3 +1687,201 @@ func TestGetRelationModel_WithTags(t *testing.T) {
})
}
}
func TestMapToStruct(t *testing.T) {
// Test model with various field types
type TestModel struct {
ID int64 `bun:"id,pk" json:"id"`
Name string `bun:"name" json:"name"`
Age int `bun:"age" json:"age"`
Active bool `bun:"active" json:"active"`
Score float64 `bun:"score" json:"score"`
Data []byte `bun:"data" json:"data"`
MetaJSON []byte `bun:"meta_json" json:"meta_json"`
}
tests := []struct {
name string
dataMap map[string]interface{}
expected TestModel
wantErr bool
}{
{
name: "Basic types conversion",
dataMap: map[string]interface{}{
"id": int64(123),
"name": "Test User",
"age": 30,
"active": true,
"score": 95.5,
},
expected: TestModel{
ID: 123,
Name: "Test User",
Age: 30,
Active: true,
Score: 95.5,
},
wantErr: false,
},
{
name: "Byte slice (SqlJSONB-like) from []byte",
dataMap: map[string]interface{}{
"id": int64(456),
"name": "JSON Test",
"data": []byte(`{"key":"value"}`),
},
expected: TestModel{
ID: 456,
Name: "JSON Test",
Data: []byte(`{"key":"value"}`),
},
wantErr: false,
},
{
name: "Byte slice from string",
dataMap: map[string]interface{}{
"id": int64(789),
"data": "string data",
},
expected: TestModel{
ID: 789,
Data: []byte("string data"),
},
wantErr: false,
},
{
name: "Byte slice from map (JSON marshal)",
dataMap: map[string]interface{}{
"id": int64(999),
"meta_json": map[string]interface{}{
"field1": "value1",
"field2": 42,
},
},
expected: TestModel{
ID: 999,
MetaJSON: []byte(`{"field1":"value1","field2":42}`),
},
wantErr: false,
},
{
name: "Byte slice from slice (JSON marshal)",
dataMap: map[string]interface{}{
"id": int64(111),
"meta_json": []interface{}{"item1", "item2", 3},
},
expected: TestModel{
ID: 111,
MetaJSON: []byte(`["item1","item2",3]`),
},
wantErr: false,
},
{
name: "Field matching by bun tag",
dataMap: map[string]interface{}{
"id": int64(222),
"name": "Tagged Field",
},
expected: TestModel{
ID: 222,
Name: "Tagged Field",
},
wantErr: false,
},
{
name: "Nil values",
dataMap: map[string]interface{}{
"id": int64(333),
"data": nil,
},
expected: TestModel{
ID: 333,
Data: nil,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result TestModel
err := MapToStruct(tt.dataMap, &result)
if (err != nil) != tt.wantErr {
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
return
}
// Compare fields individually for better error messages
if result.ID != tt.expected.ID {
t.Errorf("ID = %v, want %v", result.ID, tt.expected.ID)
}
if result.Name != tt.expected.Name {
t.Errorf("Name = %v, want %v", result.Name, tt.expected.Name)
}
if result.Age != tt.expected.Age {
t.Errorf("Age = %v, want %v", result.Age, tt.expected.Age)
}
if result.Active != tt.expected.Active {
t.Errorf("Active = %v, want %v", result.Active, tt.expected.Active)
}
if result.Score != tt.expected.Score {
t.Errorf("Score = %v, want %v", result.Score, tt.expected.Score)
}
// For byte slices, compare as strings for JSON data
if tt.expected.Data != nil {
if string(result.Data) != string(tt.expected.Data) {
t.Errorf("Data = %s, want %s", string(result.Data), string(tt.expected.Data))
}
}
if tt.expected.MetaJSON != nil {
if string(result.MetaJSON) != string(tt.expected.MetaJSON) {
t.Errorf("MetaJSON = %s, want %s", string(result.MetaJSON), string(tt.expected.MetaJSON))
}
}
})
}
}
func TestMapToStruct_Errors(t *testing.T) {
type TestModel struct {
ID int `bun:"id" json:"id"`
}
tests := []struct {
name string
dataMap map[string]interface{}
target interface{}
wantErr bool
}{
{
name: "Nil dataMap",
dataMap: nil,
target: &TestModel{},
wantErr: true,
},
{
name: "Nil target",
dataMap: map[string]interface{}{"id": 1},
target: nil,
wantErr: true,
},
{
name: "Non-pointer target",
dataMap: map[string]interface{}{"id": 1},
target: TestModel{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := MapToStruct(tt.dataMap, tt.target)
if (err != nil) != tt.wantErr {
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@ -746,9 +746,42 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply ComputedQL fields if any
if len(preload.ComputedQL) > 0 {
// Get the base table name from the related model
baseTableName := getTableNameFromModel(relatedModel)
// Convert the preload relation path to the appropriate alias format
// This is ORM-specific. Currently we only support Bun's format.
// TODO: Add support for other ORMs if needed
preloadAlias := ""
if h.db.GetUnderlyingDB() != nil {
// Check if we're using Bun by checking the type name
underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB())
if strings.Contains(underlyingType, "bun.DB") {
// Use Bun's alias format: lowercase with double underscores
preloadAlias = relationPathToBunAlias(preload.Relation)
}
// For GORM: GORM doesn't use the same alias format, and this fix
// may not be needed since GORM handles preloads differently
}
logger.Debug("Applying computed columns to preload %s (alias: %s, base table: %s)",
preload.Relation, preloadAlias, baseTableName)
for colName, colExpr := range preload.ComputedQL {
// Replace table references in the expression with the preload alias
// This fixes the ambiguous column reference issue when there are multiple
// levels of recursive/nested preloads
adjustedExpr := colExpr
if baseTableName != "" && preloadAlias != "" {
adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
if adjustedExpr != colExpr {
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
colName, colExpr, adjustedExpr)
}
}
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", adjustedExpr, colName))
// Remove the computed column from selected columns to avoid duplication
for colIndex := range preload.Columns {
if preload.Columns[colIndex] == colName {
@ -841,6 +874,73 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
return query
}
// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def"
// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores
func relationPathToBunAlias(relationPath string) string {
if relationPath == "" {
return ""
}
// Convert to lowercase and replace dots with double underscores
alias := strings.ToLower(relationPath)
alias = strings.ReplaceAll(alias, ".", "__")
return alias
}
// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression
// with the appropriate alias for the current preload level
// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal",
// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem"
func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
return sqlExpr
}
// Replace both quoted and unquoted table references
// Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column"
// Pattern 1: tablename.column (unquoted)
result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".")
// Pattern 2: "tablename".column or "tablename"."column" (quoted table name)
result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".")
return result
}
// getTableNameFromModel extracts the table name from a model
// It checks the bun tag first, then falls back to converting the struct name to snake_case
func getTableNameFromModel(model interface{}) string {
if model == nil {
return ""
}
modelType := reflect.TypeOf(model)
// Unwrap pointers
for modelType != nil && modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return ""
}
// Look for bun tag on embedded BaseModel
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if field.Anonymous {
bunTag := field.Tag.Get("bun")
if strings.HasPrefix(bunTag, "table:") {
return strings.TrimPrefix(bunTag, "table:")
}
}
}
// Fallback: convert struct name to lowercase (simple heuristic)
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
return strings.ToLower(modelType.Name())
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
@ -1120,8 +1220,14 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
// Ensure ID is in the data map for the update
dataMap[pkName] = targetID
// Create update query
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
// Populate model instance from dataMap to preserve custom types (like SqlJSONB)
modelInstance := reflect.New(reflect.TypeOf(model).Elem()).Interface()
if err := reflection.MapToStruct(dataMap, modelInstance); err != nil {
return fmt.Errorf("failed to populate model from data: %w", err)
}
// Create update query using Model() to preserve custom types and driver.Valuer interfaces
query := tx.NewUpdate().Model(modelInstance).Table(tableName)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
// Execute BeforeScan hooks - pass query chain so hooks can modify it