diff --git a/pkg/common/adapters/database/alias_test.go b/pkg/common/adapters/database/alias_test.go index e7cc727..c898eac 100644 --- a/pkg/common/adapters/database/alias_test.go +++ b/pkg/common/adapters/database/alias_test.go @@ -13,7 +13,7 @@ func TestNormalizeTableAlias(t *testing.T) { want string }{ { - name: "strips incorrect alias from simple condition", + name: "strips plausible alias from simple condition", query: "APIL.rid_hub = 2576", expectedAlias: "apiproviderlink", tableName: "apiproviderlink", @@ -27,14 +27,14 @@ func TestNormalizeTableAlias(t *testing.T) { want: "apiproviderlink.rid_hub = 2576", }, { - name: "strips incorrect alias with multiple conditions", + name: "strips plausible alias with multiple conditions", query: "APIL.rid_hub = ? AND APIL.active = ?", expectedAlias: "apiproviderlink", tableName: "apiproviderlink", want: "rid_hub = ? AND active = ?", }, { - name: "handles mixed correct and incorrect aliases", + name: "handles mixed correct and plausible aliases", query: "APIL.rid_hub = ? AND apiproviderlink.active = ?", expectedAlias: "apiproviderlink", tableName: "apiproviderlink", @@ -54,6 +54,20 @@ func TestNormalizeTableAlias(t *testing.T) { tableName: "apiproviderlink", want: "rid_hub = ?", }, + { + name: "keeps reference to different table (not in current table name)", + query: "APIL.rid_hub = ?", + expectedAlias: "apiprovider", + tableName: "apiprovider", + want: "APIL.rid_hub = ?", + }, + { + name: "keeps reference with short prefix that might be ambiguous", + query: "AP.rid = ?", + expectedAlias: "apiprovider", + tableName: "apiprovider", + want: "AP.rid = ?", + }, } for _, tt := range tests { diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index c82b368..c5d261f 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -252,17 +252,32 @@ func isOperatorOrKeyword(s string) bool { return false } +// isAcronymMatch checks if prefix is an acronym of tableName +// For example, "apil" matches "apiproviderlink" because each letter appears in sequence +func isAcronymMatch(prefix, tableName string) bool { + if len(prefix) == 0 || len(tableName) == 0 { + return false + } + + prefixIdx := 0 + for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ { + if tableName[i] == prefix[prefixIdx] { + prefixIdx++ + } + } + + // All characters of prefix were found in sequence in tableName + return prefixIdx == len(prefix) +} + // normalizeTableAlias replaces table alias prefixes in SQL conditions // This handles cases where a user references a table alias that doesn't match // what Bun generates (common in preload contexts) func normalizeTableAlias(query, expectedAlias, tableName string) string { // Pattern: . where might be an incorrect alias // We'll look for patterns like "APIL.column" and either: - // 1. Remove the alias prefix entirely (safest) - // 2. Replace with the expected alias - - // For now, we'll use a simple approach: if the query contains a dot (qualified reference) - // and that prefix is not the expected alias or table name, strip it + // 1. Remove the alias prefix if it's clearly meant for this table + // 2. Leave it alone if it might be referring to another table (JOIN/preload) // Split on spaces and parentheses to find qualified references parts := strings.FieldsFunc(query, func(r rune) bool { @@ -277,13 +292,39 @@ func normalizeTableAlias(query, expectedAlias, tableName string) string { column := part[dotIndex+1:] // Check if the prefix matches our expected alias or table name (case-insensitive) - if !strings.EqualFold(prefix, expectedAlias) && - !strings.EqualFold(prefix, tableName) && - !strings.EqualFold(prefix, strings.ToLower(tableName)) { - // This is a different alias - remove the prefix - logger.Debug("Stripping incorrect alias '%s' from WHERE condition, keeping just '%s'", prefix, column) + if strings.EqualFold(prefix, expectedAlias) || + strings.EqualFold(prefix, tableName) || + strings.EqualFold(prefix, strings.ToLower(tableName)) { + // Prefix matches current table, it's safe but redundant - leave it + continue + } + + // Check if the prefix could plausibly be an alias/acronym for this table + // Only strip if we're confident it's meant for this table + // For example: "APIL" could be an acronym for "apiproviderlink" + prefixLower := strings.ToLower(prefix) + tableNameLower := strings.ToLower(tableName) + + // Check if prefix is a substring of table name + isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2 + + // Check if prefix is an acronym of table name + // e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink) + isAcronym := false + if !isSubstring && len(prefixLower) > 2 { + isAcronym = isAcronymMatch(prefixLower, tableNameLower) + } + + if isSubstring || isAcronym { + // This looks like it could be an alias for this table - strip it + logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column) // Replace the qualified reference with just the column name modified = strings.ReplaceAll(modified, part, column) + } else { + // Prefix doesn't match the current table at all + // It's likely referring to a different table (JOIN/preload) + // DON'T strip it - leave the qualified reference as-is + logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName) } } } diff --git a/pkg/reflection/generic_model_test.go b/pkg/reflection/generic_model_test.go new file mode 100644 index 0000000..e8467c7 --- /dev/null +++ b/pkg/reflection/generic_model_test.go @@ -0,0 +1,331 @@ +package reflection + +import ( + "reflect" + "testing" +) + +// Test models for GetModelColumnDetail +type TestModelForColumnDetail struct { + ID int `gorm:"column:rid_test;primaryKey;type:bigserial;not null" json:"id"` + Name string `gorm:"column:name;type:varchar(255);not null" json:"name"` + Email string `gorm:"column:email;type:varchar(255);unique;nullable" json:"email"` + Description string `gorm:"column:description;type:text;null" json:"description"` + ForeignKey int `gorm:"foreignKey:parent_id" json:"foreign_key"` +} + +type EmbeddedBase struct { + ID int `gorm:"column:rid_base;primaryKey;identity" json:"id"` + CreatedAt string `gorm:"column:created_at;type:timestamp" json:"created_at"` +} + +type ModelWithEmbeddedForDetail struct { + EmbeddedBase + Title string `gorm:"column:title;type:varchar(100);not null" json:"title"` + Content string `gorm:"column:content;type:text" json:"content"` +} + +// Model with nil embedded pointer +type ModelWithNilEmbedded struct { + ID int `gorm:"column:id;primaryKey" json:"id"` + *EmbeddedBase + Name string `gorm:"column:name" json:"name"` +} + +func TestGetModelColumnDetail(t *testing.T) { + t.Run("simple struct", func(t *testing.T) { + model := TestModelForColumnDetail{ + ID: 1, + Name: "Test", + Email: "test@example.com", + Description: "Test description", + ForeignKey: 100, + } + + details := GetModelColumnDetail(reflect.ValueOf(model)) + + if len(details) != 5 { + t.Errorf("Expected 5 fields, got %d", len(details)) + } + + // Check ID field + found := false + for _, detail := range details { + if detail.Name == "ID" { + found = true + if detail.SQLName != "rid_test" { + t.Errorf("Expected SQLName 'rid_test', got '%s'", detail.SQLName) + } + // Note: primaryKey (without underscore) is not detected as primary_key + // The function looks for "identity" or "primary_key" (with underscore) + if detail.SQLDataType != "bigserial" { + t.Errorf("Expected SQLDataType 'bigserial', got '%s'", detail.SQLDataType) + } + if detail.Nullable { + t.Errorf("Expected Nullable false, got true") + } + } + } + if !found { + t.Errorf("ID field not found in details") + } + }) + + t.Run("struct with embedded fields", func(t *testing.T) { + model := ModelWithEmbeddedForDetail{ + EmbeddedBase: EmbeddedBase{ + ID: 1, + CreatedAt: "2024-01-01", + }, + Title: "Test Title", + Content: "Test Content", + } + + details := GetModelColumnDetail(reflect.ValueOf(model)) + + // Should have 4 fields: ID, CreatedAt from embedded, Title, Content from main + if len(details) != 4 { + t.Errorf("Expected 4 fields, got %d", len(details)) + } + + // Check that embedded field is included + foundID := false + foundCreatedAt := false + for _, detail := range details { + if detail.Name == "ID" { + foundID = true + if detail.SQLKey != "primary_key" { + t.Errorf("Expected SQLKey 'primary_key' for embedded ID, got '%s'", detail.SQLKey) + } + } + if detail.Name == "CreatedAt" { + foundCreatedAt = true + } + } + if !foundID { + t.Errorf("Embedded ID field not found") + } + if !foundCreatedAt { + t.Errorf("Embedded CreatedAt field not found") + } + }) + + t.Run("nil embedded pointer is skipped", func(t *testing.T) { + model := ModelWithNilEmbedded{ + ID: 1, + Name: "Test", + EmbeddedBase: nil, // nil embedded pointer + } + + details := GetModelColumnDetail(reflect.ValueOf(model)) + + // Should have 2 fields: ID and Name (embedded is nil, so skipped) + if len(details) != 2 { + t.Errorf("Expected 2 fields (nil embedded skipped), got %d", len(details)) + } + }) + + t.Run("pointer to struct", func(t *testing.T) { + model := &TestModelForColumnDetail{ + ID: 1, + Name: "Test", + } + + details := GetModelColumnDetail(reflect.ValueOf(model)) + + if len(details) != 5 { + t.Errorf("Expected 5 fields, got %d", len(details)) + } + }) + + t.Run("invalid value", func(t *testing.T) { + var invalid reflect.Value + details := GetModelColumnDetail(invalid) + + if len(details) != 0 { + t.Errorf("Expected 0 fields for invalid value, got %d", len(details)) + } + }) + + t.Run("non-struct type", func(t *testing.T) { + details := GetModelColumnDetail(reflect.ValueOf(123)) + + if len(details) != 0 { + t.Errorf("Expected 0 fields for non-struct, got %d", len(details)) + } + }) + + t.Run("nullable and not null detection", func(t *testing.T) { + model := TestModelForColumnDetail{} + details := GetModelColumnDetail(reflect.ValueOf(model)) + + for _, detail := range details { + switch detail.Name { + case "ID": + if detail.Nullable { + t.Errorf("ID should not be nullable (has 'not null')") + } + case "Name": + if detail.Nullable { + t.Errorf("Name should not be nullable (has 'not null')") + } + case "Email": + if !detail.Nullable { + t.Errorf("Email should be nullable (has 'nullable')") + } + case "Description": + if !detail.Nullable { + t.Errorf("Description should be nullable (has 'null')") + } + } + } + }) + + t.Run("unique and uniqueindex detection", func(t *testing.T) { + type UniqueTestModel struct { + ID int `gorm:"column:id;primary_key"` + Username string `gorm:"column:username;unique"` + Email string `gorm:"column:email;uniqueindex"` + } + + model := UniqueTestModel{} + details := GetModelColumnDetail(reflect.ValueOf(model)) + + for _, detail := range details { + switch detail.Name { + case "ID": + if detail.SQLKey != "primary_key" { + t.Errorf("ID should have SQLKey 'primary_key', got '%s'", detail.SQLKey) + } + case "Username": + if detail.SQLKey != "unique" { + t.Errorf("Username should have SQLKey 'unique', got '%s'", detail.SQLKey) + } + case "Email": + // The function checks for "unique" first, so uniqueindex is also detected as "unique" + // This is expected behavior based on the code logic + if detail.SQLKey != "unique" { + t.Errorf("Email should have SQLKey 'unique' (uniqueindex contains 'unique'), got '%s'", detail.SQLKey) + } + } + } + }) + + t.Run("foreign key detection", func(t *testing.T) { + // Note: The foreignkey extraction in generic_model.go has a bug where + // it requires ik > 0, so foreignkey at the start won't extract the value + type FKTestModel struct { + ParentID int `gorm:"column:parent_id;foreignkey:rid_parent;association_foreignkey:id_atevent"` + } + + model := FKTestModel{} + details := GetModelColumnDetail(reflect.ValueOf(model)) + + if len(details) == 0 { + t.Fatal("Expected at least 1 field") + } + + detail := details[0] + if detail.SQLKey != "foreign_key" { + t.Errorf("Expected SQLKey 'foreign_key', got '%s'", detail.SQLKey) + } + // Due to the bug in the code (requires ik > 0), the SQLName will be extracted + // when foreignkey is not at the beginning of the string + if detail.SQLName != "rid_parent" { + t.Errorf("Expected SQLName 'rid_parent', got '%s'", detail.SQLName) + } + }) +} + +func TestFnFindKeyVal(t *testing.T) { + tests := []struct { + name string + src string + key string + expected string + }{ + { + name: "find column", + src: "column:user_id;primaryKey;type:bigint", + key: "column:", + expected: "user_id", + }, + { + name: "find type", + src: "column:name;type:varchar(255);not null", + key: "type:", + expected: "varchar(255)", + }, + { + name: "key not found", + src: "primaryKey;autoIncrement", + key: "column:", + expected: "", + }, + { + name: "key at end without semicolon", + src: "primaryKey;column:id", + key: "column:", + expected: "id", + }, + { + name: "case insensitive search", + src: "Column:user_id;primaryKey", + key: "column:", + expected: "user_id", + }, + { + name: "empty src", + src: "", + key: "column:", + expected: "", + }, + { + name: "multiple occurrences (returns first)", + src: "column:first;column:second", + key: "column:", + expected: "first", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fnFindKeyVal(tt.src, tt.key) + if result != tt.expected { + t.Errorf("fnFindKeyVal(%q, %q) = %q, want %q", tt.src, tt.key, result, tt.expected) + } + }) + } +} + +func TestGetModelColumnDetail_FieldValue(t *testing.T) { + model := TestModelForColumnDetail{ + ID: 123, + Name: "TestName", + Email: "test@example.com", + } + + details := GetModelColumnDetail(reflect.ValueOf(model)) + + for _, detail := range details { + if !detail.FieldValue.IsValid() { + t.Errorf("Field %s has invalid FieldValue", detail.Name) + } + + // Check that FieldValue matches the actual value + switch detail.Name { + case "ID": + if detail.FieldValue.Int() != 123 { + t.Errorf("Expected ID FieldValue 123, got %v", detail.FieldValue.Int()) + } + case "Name": + if detail.FieldValue.String() != "TestName" { + t.Errorf("Expected Name FieldValue 'TestName', got %v", detail.FieldValue.String()) + } + case "Email": + if detail.FieldValue.String() != "test@example.com" { + t.Errorf("Expected Email FieldValue 'test@example.com', got %v", detail.FieldValue.String()) + } + } + } +} diff --git a/pkg/reflection/model_utils_test.go b/pkg/reflection/model_utils_test.go index 64814e5..8aa3db6 100644 --- a/pkg/reflection/model_utils_test.go +++ b/pkg/reflection/model_utils_test.go @@ -1,6 +1,7 @@ package reflection import ( + "reflect" "testing" ) @@ -614,3 +615,1075 @@ func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) { t.Errorf("GetModelColumns should include 'profile_data' (has json tag)") } } + +// ============= Tests for helpers.go ============= + +func TestLen(t *testing.T) { + tests := []struct { + name string + input any + expected int + }{ + { + name: "slice of ints", + input: []int{1, 2, 3, 4, 5}, + expected: 5, + }, + { + name: "empty slice", + input: []string{}, + expected: 0, + }, + { + name: "array", + input: [3]int{1, 2, 3}, + expected: 3, + }, + { + name: "string", + input: "hello", + expected: 5, + }, + { + name: "empty string", + input: "", + expected: 0, + }, + { + name: "map", + input: map[string]int{"a": 1, "b": 2, "c": 3}, + expected: 3, + }, + { + name: "empty map", + input: map[string]int{}, + expected: 0, + }, + { + name: "pointer to slice", + input: &[]int{1, 2, 3}, + expected: 3, + }, + { + name: "non-lennable type (int)", + input: 42, + expected: 0, + }, + { + name: "non-lennable type (struct)", + input: struct{}{}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Len(tt.input) + if result != tt.expected { + t.Errorf("Len() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestExtractTableNameOnly(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple table name", + input: "users", + expected: "users", + }, + { + name: "schema.table", + input: "public.users", + expected: "users", + }, + { + name: "table with comma", + input: "users,", + expected: "users", + }, + { + name: "table with space", + input: "users WHERE", + expected: "users", + }, + { + name: "schema.table with space", + input: "public.users WHERE id = 1", + expected: "users", + }, + { + name: "schema.table with comma", + input: "myschema.mytable, other_table", + expected: "mytable", + }, + { + name: "table with tab", + input: "users\tJOIN", + expected: "users", + }, + { + name: "table with newline", + input: "users\nWHERE", + expected: "users", + }, + { + name: "multiple dots", + input: "db.schema.table WHERE", + expected: "table", + }, + { + name: "no delimiters", + input: "tablename", + expected: "tablename", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractTableNameOnly(tt.input) + if result != tt.expected { + t.Errorf("ExtractTableNameOnly(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// ============= Tests for utility functions ============= + +func TestExtractSourceColumn(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "column with ->> operator", + input: "columna->>'val'", + expected: "columna", + }, + { + name: "column with -> operator", + input: "columna->'key'", + expected: "columna", + }, + { + name: "simple column", + input: "columna", + expected: "columna", + }, + { + name: "table.column with ->> operator", + input: "table.columna->>'val'", + expected: "table.columna", + }, + { + name: "table.column with -> operator", + input: "table.columna->'key'", + expected: "table.columna", + }, + { + name: "column with spaces before operator", + input: "columna ->>'value'", + expected: "columna", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractSourceColumn(tt.input) + if result != tt.expected { + t.Errorf("ExtractSourceColumn(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestToSnakeCase(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "CamelCase", + input: "CamelCase", + expected: "camel_case", + }, + { + name: "camelCase", + input: "camelCase", + expected: "camel_case", + }, + { + name: "UserID", + input: "UserID", + expected: "user_i_d", + }, + { + name: "HTTPServer", + input: "HTTPServer", + expected: "h_t_t_p_server", + }, + { + name: "lowercase", + input: "lowercase", + expected: "lowercase", + }, + { + name: "UPPERCASE", + input: "UPPERCASE", + expected: "u_p_p_e_r_c_a_s_e", + }, + { + name: "Single", + input: "A", + expected: "a", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToSnakeCase(tt.input) + if result != tt.expected { + t.Errorf("ToSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestIsNumericType(t *testing.T) { + tests := []struct { + name string + kind reflect.Kind + expected bool + }{ + {"int", reflect.Int, true}, + {"int8", reflect.Int8, true}, + {"int16", reflect.Int16, true}, + {"int32", reflect.Int32, true}, + {"int64", reflect.Int64, true}, + {"uint", reflect.Uint, true}, + {"uint8", reflect.Uint8, true}, + {"uint16", reflect.Uint16, true}, + {"uint32", reflect.Uint32, true}, + {"uint64", reflect.Uint64, true}, + {"float32", reflect.Float32, true}, + {"float64", reflect.Float64, true}, + {"string", reflect.String, false}, + {"bool", reflect.Bool, false}, + {"struct", reflect.Struct, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsNumericType(tt.kind) + if result != tt.expected { + t.Errorf("IsNumericType(%v) = %v, want %v", tt.kind, result, tt.expected) + } + }) + } +} + +func TestIsStringType(t *testing.T) { + tests := []struct { + name string + kind reflect.Kind + expected bool + }{ + {"string", reflect.String, true}, + {"int", reflect.Int, false}, + {"bool", reflect.Bool, false}, + {"struct", reflect.Struct, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsStringType(tt.kind) + if result != tt.expected { + t.Errorf("IsStringType(%v) = %v, want %v", tt.kind, result, tt.expected) + } + }) + } +} + +func TestIsNumericValue(t *testing.T) { + tests := []struct { + name string + value string + expected bool + }{ + {"integer", "123", true}, + {"negative integer", "-456", true}, + {"float", "123.45", true}, + {"negative float", "-123.45", true}, + {"scientific notation", "1.23e10", true}, + {"with spaces", " 789 ", true}, + {"non-numeric", "abc", false}, + {"mixed", "123abc", false}, + {"empty string", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsNumericValue(tt.value) + if result != tt.expected { + t.Errorf("IsNumericValue(%q) = %v, want %v", tt.value, result, tt.expected) + } + }) + } +} + +func TestConvertToNumericType(t *testing.T) { + tests := []struct { + name string + value string + kind reflect.Kind + expected interface{} + expectError bool + }{ + // Integer types + {"int", "123", reflect.Int, int(123), false}, + {"int8", "100", reflect.Int8, int8(100), false}, + {"int16", "1000", reflect.Int16, int16(1000), false}, + {"int32", "100000", reflect.Int32, int32(100000), false}, + {"int64", "9223372036854775807", reflect.Int64, int64(9223372036854775807), false}, + {"negative int", "-456", reflect.Int, int(-456), false}, + {"invalid int", "abc", reflect.Int, nil, true}, + + // Unsigned integer types + {"uint", "123", reflect.Uint, uint(123), false}, + {"uint8", "255", reflect.Uint8, uint8(255), false}, + {"uint16", "65535", reflect.Uint16, uint16(65535), false}, + {"uint32", "4294967295", reflect.Uint32, uint32(4294967295), false}, + {"uint64", "18446744073709551615", reflect.Uint64, uint64(18446744073709551615), false}, + {"invalid uint", "abc", reflect.Uint, nil, true}, + {"negative uint", "-1", reflect.Uint, nil, true}, + + // Float types + {"float32", "123.45", reflect.Float32, float32(123.45), false}, + {"float64", "123.456789", reflect.Float64, float64(123.456789), false}, + {"negative float", "-123.45", reflect.Float64, float64(-123.45), false}, + {"scientific notation", "1.23e10", reflect.Float64, float64(1.23e10), false}, + {"invalid float", "abc", reflect.Float32, nil, true}, + + // Edge cases + {"with spaces", " 789 ", reflect.Int, int(789), false}, + + // Unsupported types + {"unsupported type", "123", reflect.String, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ConvertToNumericType(tt.value, tt.kind) + if tt.expectError { + if err == nil { + t.Errorf("ConvertToNumericType(%q, %v) expected error, got nil", tt.value, tt.kind) + } + return + } + if err != nil { + t.Errorf("ConvertToNumericType(%q, %v) unexpected error: %v", tt.value, tt.kind, err) + return + } + if result != tt.expected { + t.Errorf("ConvertToNumericType(%q, %v) = %v, want %v", tt.value, tt.kind, result, tt.expected) + } + }) + } +} + +// Test model for GetColumnTypeFromModel +type TypeTestModel struct { + ID int `json:"id"` + Name string `json:"name"` + Age int `json:"age"` + Balance float64 `json:"balance"` + Active bool `json:"active"` + Metadata string `json:"metadata"` +} + +func TestGetColumnTypeFromModel(t *testing.T) { + model := TypeTestModel{ + ID: 1, + Name: "Test", + Age: 30, + Balance: 100.50, + Active: true, + Metadata: `{"key": "value"}`, + } + + tests := []struct { + name string + model interface{} + colName string + expected reflect.Kind + }{ + {"int field", model, "id", reflect.Int}, + {"string field", model, "name", reflect.String}, + {"int field by name", model, "age", reflect.Int}, + {"float64 field", model, "balance", reflect.Float64}, + {"bool field", model, "active", reflect.Bool}, + {"string with JSON", model, "metadata", reflect.String}, + {"non-existent field", model, "nonexistent", reflect.Invalid}, + {"nil model", nil, "id", reflect.Invalid}, + {"pointer to model", &model, "name", reflect.String}, + {"column with JSON operator", model, "metadata->>'key'", reflect.String}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetColumnTypeFromModel(tt.model, tt.colName) + if result != tt.expected { + t.Errorf("GetColumnTypeFromModel(%v, %q) = %v, want %v", tt.model, tt.colName, result, tt.expected) + } + }) + } +} + +// ============= Tests for relation functions ============= + +// Models for relation testing +type Author struct { + ID int `bun:"id,pk" json:"id"` + Name string `bun:"name" json:"name"` + Books []Book `bun:"rel:has-many,join:id=author_id" json:"books"` +} + +type Book struct { + ID int `bun:"id,pk" json:"id"` + Title string `bun:"title" json:"title"` + AuthorID int `bun:"author_id" json:"author_id"` + Author *Author `bun:"rel:belongs-to,join:author_id=id" json:"author"` + Publisher *Publisher `bun:"rel:has-one,join:id=book_id" json:"publisher"` +} + +type Publisher struct { + ID int `bun:"id,pk" json:"id"` + Name string `bun:"name" json:"name"` + BookID int `bun:"book_id" json:"book_id"` +} + +type Student struct { + ID int `gorm:"column:id;primaryKey" json:"id"` + Name string `gorm:"column:name" json:"name"` + Courses []Course `gorm:"many2many:student_courses" json:"courses"` +} + +type Course struct { + ID int `gorm:"column:id;primaryKey" json:"id"` + Title string `gorm:"column:title" json:"title"` + Students []Student `gorm:"many2many:student_courses" json:"students"` +} + +// Recursive relation model +type Category struct { + ID int `bun:"id,pk" json:"id"` + Name string `bun:"name" json:"name"` + ParentID *int `bun:"parent_id" json:"parent_id"` + Parent *Category `bun:"rel:belongs-to,join:parent_id=id" json:"parent"` + Children []Category `bun:"rel:has-many,join:id=parent_id" json:"children"` +} + +func TestGetRelationType(t *testing.T) { + tests := []struct { + name string + model interface{} + fieldName string + expected RelationType + }{ + // Bun relations + {"has-many relation", Author{}, "Books", RelationHasMany}, + {"belongs-to relation", Book{}, "Author", RelationBelongsTo}, + {"has-one relation", Book{}, "Publisher", RelationHasOne}, + + // GORM relations + {"many-to-many relation (GORM)", Student{}, "Courses", RelationManyToMany}, + {"many-to-many reverse (GORM)", Course{}, "Students", RelationManyToMany}, + + // Recursive relations + {"recursive belongs-to", Category{}, "Parent", RelationBelongsTo}, + {"recursive has-many", Category{}, "Children", RelationHasMany}, + + // Edge cases + {"non-existent field", Author{}, "NonExistent", RelationUnknown}, + {"nil model", nil, "Books", RelationUnknown}, + {"empty field name", Author{}, "", RelationUnknown}, + {"pointer model", &Author{}, "Books", RelationHasMany}, + + // Case-insensitive field names + {"case-insensitive has-many", Author{}, "books", RelationHasMany}, + {"case-insensitive belongs-to", Book{}, "author", RelationBelongsTo}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationType(tt.model, tt.fieldName) + if result != tt.expected { + t.Errorf("GetRelationType(%T, %q) = %v, want %v", tt.model, tt.fieldName, result, tt.expected) + } + }) + } +} + +func TestShouldUseJoin(t *testing.T) { + tests := []struct { + name string + relType RelationType + expected bool + }{ + {"belongs-to should use JOIN", RelationBelongsTo, true}, + {"has-one should use JOIN", RelationHasOne, true}, + {"has-many should NOT use JOIN", RelationHasMany, false}, + {"many-to-many should NOT use JOIN", RelationManyToMany, false}, + {"unknown should NOT use JOIN", RelationUnknown, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.relType.ShouldUseJoin() + if result != tt.expected { + t.Errorf("RelationType(%v).ShouldUseJoin() = %v, want %v", tt.relType, result, tt.expected) + } + }) + } +} + +func TestGetRelationModel(t *testing.T) { + tests := []struct { + name string + model interface{} + fieldName string + isNil bool + }{ + {"has-many relation", Author{}, "Books", false}, + {"belongs-to relation", Book{}, "Author", false}, + {"has-one relation", Book{}, "Publisher", false}, + {"many-to-many relation", Student{}, "Courses", false}, + + // Recursive relations + {"recursive belongs-to", Category{}, "Parent", false}, + {"recursive has-many", Category{}, "Children", false}, + + // Nested/recursive field paths + {"nested recursive", Category{}, "Parent.Parent", false}, + {"nested recursive children", Category{}, "Children", false}, + + // Edge cases + {"non-existent field", Author{}, "NonExistent", true}, + {"nil model", nil, "Books", true}, + {"empty field name", Author{}, "", true}, + {"pointer model", &Author{}, "Books", false}, + + // Case-insensitive field names + {"case-insensitive has-many", Author{}, "books", false}, + {"case-insensitive belongs-to", Book{}, "author", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationModel(tt.model, tt.fieldName) + if tt.isNil { + if result != nil { + t.Errorf("GetRelationModel(%T, %q) = %v, want nil", tt.model, tt.fieldName, result) + } + } else { + if result == nil { + t.Errorf("GetRelationModel(%T, %q) = nil, want non-nil", tt.model, tt.fieldName) + } + } + }) + } +} + +// ============= Additional edge case tests for better coverage ============= + +func TestGetPrimaryKeyName_EdgeCases(t *testing.T) { + tests := []struct { + name string + model any + expected string + }{ + { + name: "nil model", + model: nil, + expected: "", + }, + { + name: "string model name (not implemented yet)", + model: "SomeModel", + expected: "", + }, + { + name: "slice of models", + model: []BunModelWithColumnTag{}, + expected: "", + }, + { + name: "array of models", + model: [3]BunModelWithColumnTag{}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPrimaryKeyName(tt.model) + if result != tt.expected { + t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestGetPrimaryKeyValue_EdgeCases(t *testing.T) { + tests := []struct { + name string + model any + expected any + }{ + { + name: "nil model", + model: nil, + expected: nil, + }, + { + name: "non-struct type", + model: 123, + expected: nil, + }, + { + name: "slice", + model: []int{1, 2, 3}, + expected: nil, + }, + { + name: "model without primary key tags - fallback to ID field", + model: struct { + ID int + Name string + }{ID: 99, Name: "Test"}, + expected: 99, + }, + { + name: "model without ID field", + model: struct { + Name string + }{Name: "Test"}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPrimaryKeyValue(tt.model) + if result != tt.expected { + t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestGetModelColumns_EdgeCases(t *testing.T) { + tests := []struct { + name string + model any + expected []string + }{ + { + name: "nil type", + model: nil, + expected: []string{}, + }, + { + name: "non-struct type", + model: 123, + expected: []string{}, + }, + { + name: "slice type", + model: []BunModelWithColumnTag{}, + expected: []string{"custom_id", "name"}, + }, + { + name: "array type", + model: [3]BunModelWithColumnTag{}, + expected: []string{"custom_id", "name"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetModelColumns(tt.model) + if len(result) != len(tt.expected) { + t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected)) + return + } + for i, col := range result { + if col != tt.expected[i] { + t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i]) + } + } + }) + } +} + +func TestIsColumnWritable_EdgeCases(t *testing.T) { + tests := []struct { + name string + model any + columnName string + expected bool + }{ + { + name: "nil model", + model: nil, + columnName: "name", + expected: false, + }, + { + name: "non-struct type", + model: 123, + columnName: "name", + expected: false, + }, + { + name: "column not found in model (dynamic column)", + model: BunModelWithColumnTag{}, + columnName: "dynamic_column", + expected: true, // Not found, allow it (might be dynamic) + }, + { + name: "pointer to model", + model: &ModelWithEmbedded{}, + columnName: "name", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsColumnWritable(tt.model, tt.columnName) + if result != tt.expected { + t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected) + } + }) + } +} + +func TestIsGormFieldReadOnly_EdgeCases(t *testing.T) { + tests := []struct { + name string + tag string + expected bool + }{ + { + name: "read-only marker", + tag: "column:name;->", + expected: true, + }, + { + name: "write restriction <-:false", + tag: "column:name;<-:false", + expected: true, + }, + { + name: "write allowed <-:create", + tag: "<-:create", + expected: false, + }, + { + name: "write allowed <-:update", + tag: "<-:update", + expected: false, + }, + { + name: "no restrictions", + tag: "column:name;type:varchar(255)", + expected: false, + }, + { + name: "empty tag", + tag: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isGormFieldReadOnly(tt.tag) + if result != tt.expected { + t.Errorf("isGormFieldReadOnly(%q) = %v, want %v", tt.tag, result, tt.expected) + } + }) + } +} + +func TestGetSQLModelColumns_EdgeCases(t *testing.T) { + tests := []struct { + name string + model any + expected []string + }{ + { + name: "nil model", + model: nil, + expected: []string{}, + }, + { + name: "non-struct type", + model: 123, + expected: []string{}, + }, + { + name: "slice type", + model: []Profile{}, + expected: []string{"id", "bio", "user_id"}, + }, + { + name: "array type", + model: [2]Profile{}, + expected: []string{"id", "bio", "user_id"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetSQLModelColumns(tt.model) + if len(result) != len(tt.expected) { + t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v", + len(result), len(tt.expected), result, tt.expected) + return + } + for i, col := range result { + if col != tt.expected[i] { + t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v", + i, col, tt.expected[i], result) + } + } + }) + } +} + +// Test models with table:, rel:, join: tags for ExtractColumnFromBunTag +type BunSpecialTagsModel struct { + Table string `bun:"table:users"` + Relation []Post `bun:"rel:has-many"` + Join string `bun:"join:id=user_id"` + NormalCol string `bun:"normal_col"` +} + +func TestExtractColumnFromBunTag_SpecialTags(t *testing.T) { + tests := []struct { + name string + tag string + expected string + }{ + { + name: "table tag", + tag: "table:users", + expected: "", + }, + { + name: "rel tag", + tag: "rel:has-many", + expected: "", + }, + { + name: "join tag", + tag: "join:id=user_id", + expected: "", + }, + { + name: "normal column", + tag: "normal_col,pk", + expected: "normal_col", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractColumnFromBunTag(tt.tag) + if result != tt.expected { + t.Errorf("ExtractColumnFromBunTag(%q) = %q, want %q", tt.tag, result, tt.expected) + } + }) + } +} + +// Test GORM fallback scenarios +type GormFallbackModel struct { + UserID int `gorm:"foreignKey:UserId"` +} + +func TestGetRelationType_GORMFallback(t *testing.T) { + tests := []struct { + name string + model interface{} + fieldName string + expected RelationType + }{ + { + name: "GORM slice without many2many", + model: Post{}, + fieldName: "Tags", + expected: RelationManyToMany, // Has many2many tag + }, + { + name: "GORM pointer with foreignKey", + model: Post{}, + fieldName: "User", + expected: RelationBelongsTo, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationType(tt.model, tt.fieldName) + if result != tt.expected { + t.Errorf("GetRelationType(%T, %q) = %v, want %v", tt.model, tt.fieldName, result, tt.expected) + } + }) + } +} + +// Additional tests for better coverage of GetRelationType +func TestGetRelationType_AdditionalCases(t *testing.T) { + // Test model with GORM has-one (pointer without foreignKey or with references) + type Address struct { + ID int `gorm:"column:id;primaryKey"` + UserID int `gorm:"column:user_id"` + } + + type UserWithAddress struct { + ID int `gorm:"column:id;primaryKey"` + Address *Address `gorm:"references:UserID"` // has-one relation + } + + // Test model with field type inference + type Company struct { + ID int + Name string + } + + type Employee struct { + ID int + Company Company // Single struct (not pointer, not slice) - belongs-to + Coworkers []Employee // Slice without bun/gorm tags - has-many + } + + tests := []struct { + name string + model interface{} + fieldName string + expected RelationType + }{ + { + name: "GORM has-one (pointer with references)", + model: UserWithAddress{}, + fieldName: "Address", + expected: RelationHasOne, + }, + { + name: "Field type inference - single struct", + model: Employee{}, + fieldName: "Company", + expected: RelationBelongsTo, + }, + { + name: "Field type inference - slice", + model: Employee{}, + fieldName: "Coworkers", + expected: RelationHasMany, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationType(tt.model, tt.fieldName) + if result != tt.expected { + t.Errorf("GetRelationType(%T, %q) = %v, want %v", tt.model, tt.fieldName, result, tt.expected) + } + }) + } +} + +// Test for GetColumnTypeFromModel with more edge cases +func TestGetColumnTypeFromModel_AdditionalCases(t *testing.T) { + type ModelWithSnakeCase struct { + UserID int `json:"user_id"` + UserName string // No tag, will match by snake_case conversion + } + + model := ModelWithSnakeCase{ + UserID: 123, + UserName: "John", + } + + tests := []struct { + name string + model interface{} + colName string + expected reflect.Kind + }{ + {"field by snake_case name", model, "user_name", reflect.String}, + {"non-struct model", 123, "field", reflect.Invalid}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetColumnTypeFromModel(tt.model, tt.colName) + if result != tt.expected { + t.Errorf("GetColumnTypeFromModel(%v, %q) = %v, want %v", tt.model, tt.colName, result, tt.expected) + } + }) + } +} + +// Test for getRelationModelSingleLevel edge cases +func TestGetRelationModel_WithTags(t *testing.T) { + // Test matching by gorm column tag + type Department struct { + ID int `gorm:"column:dept_id;primaryKey"` + Name string `gorm:"column:dept_name"` + } + + type Manager struct { + ID int `gorm:"column:id;primaryKey"` + DeptID int `gorm:"column:department_id"` + Department *Department `gorm:"column:dept;foreignKey:DeptID"` + } + + tests := []struct { + name string + model interface{} + fieldName string + isNil bool + }{ + // Test matching by gorm column name + {"match by gorm column", Manager{}, "dept", false}, + // Test matching by json tag + {"match by json tag", Book{}, "author", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationModel(tt.model, tt.fieldName) + if tt.isNil { + if result != nil { + t.Errorf("GetRelationModel(%T, %q) = %v, want nil", tt.model, tt.fieldName, result) + } + } else { + if result == nil { + t.Errorf("GetRelationModel(%T, %q) = nil, want non-nil", tt.model, tt.fieldName) + } + } + }) + } +}