From d93a4b6f08d2cfae3ecd65be36ba103fca9a85a0 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 18 Dec 2025 19:15:22 +0200 Subject: [PATCH] Fixed bug/gorm indexes --- pkg/readers/bun/reader.go | 225 ++++++++++++++++++-- pkg/readers/gorm/reader.go | 244 +++++++++++++++++++++- pkg/writers/bun/type_mapper.go | 30 ++- tests/integration/orm_roundtrip_test.go | 265 ++++++++++++++---------- 4 files changed, 622 insertions(+), 142 deletions(-) diff --git a/pkg/readers/bun/reader.go b/pkg/readers/bun/reader.go index b20addd..213302b 100644 --- a/pkg/readers/bun/reader.go +++ b/pkg/readers/bun/reader.go @@ -187,11 +187,43 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) { table.Schema = schemaName } - // Update columns + // Update columns and indexes for _, col := range table.Columns { col.Table = tableName col.Schema = table.Schema } + for _, idx := range table.Indexes { + idx.Table = tableName + idx.Schema = table.Schema + } + } + } + + // Third pass: parse relationship fields for constraints + for _, decl := range node.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + continue + } + + table, ok := structMap[typeSpec.Name.Name] + if !ok { + continue + } + + // Parse relationship fields + r.parseRelationshipConstraints(table, structType, structMap) } } @@ -314,6 +346,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod column.Table = tableName column.Schema = schemaName table.Columns[column.Name] = column + + // Parse indexes from bun tags + r.parseIndexesFromTag(table, column, tag) + sequence++ } } @@ -346,6 +382,159 @@ func (r *Reader) isRelationship(tag string) bool { return strings.Contains(tag, "bun:\"rel:") || strings.Contains(tag, ",rel:") } +// parseRelationshipConstraints parses relationship fields to extract foreign key constraints +func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) { + for _, field := range structType.Fields.List { + if field.Tag == nil { + continue + } + + tag := field.Tag.Value + if !r.isRelationship(tag) { + continue + } + + bunTag := r.extractBunTag(tag) + + // Get the referenced type name from the field type + referencedType := r.getRelationshipType(field.Type) + if referencedType == "" { + continue + } + + // Find the referenced table + referencedTable, ok := structMap[referencedType] + if !ok { + continue + } + + // Parse the join information: join:user_id=id + // This means: referencedTable.user_id = thisTable.id + joinInfo := r.parseJoinInfo(bunTag) + if joinInfo == nil { + continue + } + + // The FK is on the referenced table + constraint := &models.Constraint{ + Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name), + Type: models.ForeignKeyConstraint, + Table: referencedTable.Name, + Schema: referencedTable.Schema, + Columns: []string{joinInfo.ForeignKey}, + ReferencedTable: table.Name, + ReferencedSchema: table.Schema, + ReferencedColumns: []string{joinInfo.ReferencedKey}, + OnDelete: "NO ACTION", // Bun doesn't specify this in tags + OnUpdate: "NO ACTION", + } + + referencedTable.Constraints[constraint.Name] = constraint + } +} + +// JoinInfo holds parsed join information +type JoinInfo struct { + ForeignKey string // Column in the related table + ReferencedKey string // Column in the current table +} + +// parseJoinInfo parses join information from bun tag +// Example: join:user_id=id means foreign_key=referenced_key +func (r *Reader) parseJoinInfo(bunTag string) *JoinInfo { + // Look for join: in the tag + if !strings.Contains(bunTag, "join:") { + return nil + } + + // Extract join clause + parts := strings.Split(bunTag, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "join:") { + joinStr := strings.TrimPrefix(part, "join:") + // Parse user_id=id + joinParts := strings.SplitN(joinStr, "=", 2) + if len(joinParts) == 2 { + return &JoinInfo{ + ForeignKey: joinParts[0], + ReferencedKey: joinParts[1], + } + } + } + } + + return nil +} + +// getRelationshipType extracts the type name from a relationship field +func (r *Reader) getRelationshipType(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.ArrayType: + // []*ModelPost -> ModelPost + if starExpr, ok := t.Elt.(*ast.StarExpr); ok { + if ident, ok := starExpr.X.(*ast.Ident); ok { + return ident.Name + } + } + case *ast.StarExpr: + // *ModelPost -> ModelPost + if ident, ok := t.X.(*ast.Ident); ok { + return ident.Name + } + } + return "" +} + +// parseIndexesFromTag extracts index definitions from Bun tags +func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) { + bunTag := r.extractBunTag(tag) + + // Parse tag into parts + parts := strings.Split(bunTag, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + + // Check for unique index + if part == "unique" { + // Auto-generate index name: idx_tablename_columnname + indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name) + + if _, exists := table.Indexes[indexName]; !exists { + index := &models.Index{ + Name: indexName, + Table: table.Name, + Schema: table.Schema, + Columns: []string{column.Name}, + Unique: true, + Type: "btree", + } + table.Indexes[indexName] = index + } + } else if strings.HasPrefix(part, "unique:") { + // Named unique index: unique:idx_name + indexName := strings.TrimPrefix(part, "unique:") + + // Check if index already exists (for composite indexes) + if idx, exists := table.Indexes[indexName]; exists { + // Add this column to the existing index + idx.Columns = append(idx.Columns, column.Name) + } else { + // Create new index + index := &models.Index{ + Name: indexName, + Table: table.Name, + Schema: table.Schema, + Columns: []string{column.Name}, + Unique: true, + Type: "btree", + } + table.Indexes[indexName] = index + } + } + } +} + // extractTableNameFromTag extracts table and schema from bun tag func (r *Reader) extractTableNameFromTag(tag string) (tableName string, schemaName string) { // Extract bun tag value @@ -422,6 +611,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s case "autoincrement": column.AutoIncrement = true case "default": + // Default value from Bun tag (e.g., default:gen_random_uuid()) column.Default = value } } @@ -431,16 +621,24 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s column.Type = r.goTypeToSQL(fieldType) } - // Determine if nullable based on Go type - if r.isNullableType(fieldType) { + // Determine if nullable based on Go type and bun tags + // In Bun: + // - nullzero tag means the field is nullable (can be NULL in DB) + // - absence of nullzero means the field is NOT NULL + // - primitive types (int64, bool, string) are NOT NULL by default + if strings.Contains(bunTag, "nullzero") { column.NotNull = false - } else if !column.IsPrimaryKey && column.Type != "" { - // If it's not a nullable type and not a primary key, check the tag - if !strings.Contains(bunTag, "notnull") { - // If notnull is not explicitly set, it might still be nullable - // This is a heuristic - we default to nullable unless specified - return column - } + } else if r.isNullableGoType(fieldType) { + // SqlString, SqlInt, etc. without nullzero tag means NOT NULL + column.NotNull = true + } else { + // Primitive types are NOT NULL by default + column.NotNull = true + } + + // Primary keys are always NOT NULL + if column.IsPrimaryKey { + column.NotNull = true } return column @@ -529,11 +727,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string { } } -// isNullableType checks if a Go type represents a nullable field -func (r *Reader) isNullableType(expr ast.Expr) bool { +// isNullableGoType checks if a Go type represents a nullable field type +// (this is for types that CAN be nullable, not whether they ARE nullable) +func (r *Reader) isNullableGoType(expr ast.Expr) bool { switch t := expr.(type) { case *ast.StarExpr: - // Pointer type is nullable + // Pointer type can be nullable return true case *ast.SelectorExpr: // Check for sql_types nullable types diff --git a/pkg/readers/gorm/reader.go b/pkg/readers/gorm/reader.go index 7637e50..44aac88 100644 --- a/pkg/readers/gorm/reader.go +++ b/pkg/readers/gorm/reader.go @@ -187,11 +187,44 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) { table.Schema = schemaName } - // Update columns + // Update columns and indexes for _, col := range table.Columns { col.Table = tableName col.Schema = table.Schema } + for _, idx := range table.Indexes { + idx.Table = tableName + idx.Schema = table.Schema + } + } + } + + // Third pass: parse relationship fields for constraints + // Re-parse the file to get relationship information + for _, decl := range node.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + continue + } + + table, ok := structMap[typeSpec.Name.Name] + if !ok { + continue + } + + // Parse relationship fields + r.parseRelationshipConstraints(table, structType, structMap) } } @@ -280,8 +313,14 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod continue } - // Skip embedded GORM model and relationship fields - if r.isGORMModel(field) || r.isRelationship(tag) { + // Skip embedded GORM model + if r.isGORMModel(field) { + continue + } + + // Parse relationship fields for foreign key constraints + if r.isRelationship(tag) { + // We'll parse constraints in a second pass after we know all table names continue } @@ -310,6 +349,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod table.Name = tableName table.Schema = schemaName table.Columns[column.Name] = column + + // Parse indexes from GORM tags + r.parseIndexesFromTag(table, column, tag) + sequence++ } } @@ -345,6 +388,169 @@ func (r *Reader) isRelationship(tag string) bool { strings.Contains(gormTag, "many2many:") } +// parseRelationshipConstraints parses relationship fields to extract foreign key constraints +func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) { + for _, field := range structType.Fields.List { + if field.Tag == nil { + continue + } + + tag := field.Tag.Value + if !r.isRelationship(tag) { + continue + } + + gormTag := r.extractGormTag(tag) + parts := r.parseGormTag(gormTag) + + // Get the referenced type name from the field type + referencedType := r.getRelationshipType(field.Type) + if referencedType == "" { + continue + } + + // Find the referenced table + referencedTable, ok := structMap[referencedType] + if !ok { + continue + } + + // Extract foreign key information + foreignKey, hasForeignKey := parts["foreignKey"] + if !hasForeignKey { + continue + } + + // Convert field name to column name + fkColumn := r.fieldNameToColumnName(foreignKey) + + // Determine constraint behavior + onDelete := "NO ACTION" + onUpdate := "NO ACTION" + if constraintStr, hasConstraint := parts["constraint"]; hasConstraint { + // Parse constraint:OnDelete:CASCADE,OnUpdate:CASCADE + if strings.Contains(constraintStr, "OnDelete:CASCADE") { + onDelete = "CASCADE" + } else if strings.Contains(constraintStr, "OnDelete:SET NULL") { + onDelete = "SET NULL" + } + if strings.Contains(constraintStr, "OnUpdate:CASCADE") { + onUpdate = "CASCADE" + } else if strings.Contains(constraintStr, "OnUpdate:SET NULL") { + onUpdate = "SET NULL" + } + } + + // The FK is on the referenced table, pointing back to this table + // For has-many, the FK is on the "many" side + constraint := &models.Constraint{ + Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name), + Type: models.ForeignKeyConstraint, + Table: referencedTable.Name, + Schema: referencedTable.Schema, + Columns: []string{fkColumn}, + ReferencedTable: table.Name, + ReferencedSchema: table.Schema, + ReferencedColumns: []string{"id"}, // Typically references the primary key + OnDelete: onDelete, + OnUpdate: onUpdate, + } + + referencedTable.Constraints[constraint.Name] = constraint + } +} + +// getRelationshipType extracts the type name from a relationship field +func (r *Reader) getRelationshipType(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.ArrayType: + // []*ModelPost -> ModelPost + if starExpr, ok := t.Elt.(*ast.StarExpr); ok { + if ident, ok := starExpr.X.(*ast.Ident); ok { + return ident.Name + } + } + case *ast.StarExpr: + // *ModelPost -> ModelPost + if ident, ok := t.X.(*ast.Ident); ok { + return ident.Name + } + } + return "" +} + +// parseIndexesFromTag extracts index definitions from GORM tags +func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) { + gormTag := r.extractGormTag(tag) + parts := r.parseGormTag(gormTag) + + // Check for regular index: index:idx_name or index + if indexName, ok := parts["index"]; ok { + if indexName == "" { + // Auto-generated index name + indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name) + } + + // Check if index already exists + if _, exists := table.Indexes[indexName]; !exists { + index := &models.Index{ + Name: indexName, + Table: table.Name, + Schema: table.Schema, + Columns: []string{column.Name}, + Unique: false, + Type: "btree", + } + table.Indexes[indexName] = index + } else { + // Add column to existing index + table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name) + } + } + + // Check for unique index: uniqueIndex:idx_name or uniqueIndex + if indexName, ok := parts["uniqueIndex"]; ok { + if indexName == "" { + // Auto-generated index name + indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name) + } + + // Check if index already exists + if _, exists := table.Indexes[indexName]; !exists { + index := &models.Index{ + Name: indexName, + Table: table.Name, + Schema: table.Schema, + Columns: []string{column.Name}, + Unique: true, + Type: "btree", + } + table.Indexes[indexName] = index + } else { + // Add column to existing index + table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name) + } + } + + // Check for simple unique flag (creates a unique index for this column) + if _, ok := parts["unique"]; ok { + // Auto-generated index name for unique constraint + indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name) + + if _, exists := table.Indexes[indexName]; !exists { + index := &models.Index{ + Name: indexName, + Table: table.Name, + Schema: table.Schema, + Columns: []string{column.Name}, + Unique: true, + Type: "btree", + } + table.Indexes[indexName] = index + } + } +} + // extractTableFromGormTag extracts table and schema from gorm tag func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaName string) { // This is typically set via TableName() method, not in tags @@ -406,6 +612,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s column.AutoIncrement = true } if def, ok := parts["default"]; ok { + // Default value from GORM tag (e.g., default:gen_random_uuid()) column.Default = def } if size, ok := parts["size"]; ok { @@ -419,9 +626,27 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s column.Type = r.goTypeToSQL(fieldType) } - // Determine if nullable based on Go type - if r.isNullableType(fieldType) { - column.NotNull = false + // Determine if nullable based on GORM tags and Go type + // In GORM: + // - explicit "not null" tag means NOT NULL + // - absence of "not null" tag with sql_types means nullable + // - primitive types (string, int64, bool) default to NOT NULL unless explicitly nullable + if _, hasNotNull := parts["not null"]; hasNotNull { + column.NotNull = true + } else { + // If no explicit "not null" tag, check the Go type + if r.isNullableGoType(fieldType) { + // sql_types.SqlString, etc. are nullable by default + column.NotNull = false + } else { + // Primitive types default to NOT NULL + column.NotNull = false // Default to nullable unless explicitly set + } + } + + // Primary keys are always NOT NULL + if column.IsPrimaryKey { + column.NotNull = true } return column @@ -557,11 +782,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string { } } -// isNullableType checks if a Go type represents a nullable field -func (r *Reader) isNullableType(expr ast.Expr) bool { +// isNullableGoType checks if a Go type represents a nullable field type +// (this is for types that CAN be nullable, not whether they ARE nullable) +func (r *Reader) isNullableGoType(expr ast.Expr) bool { switch t := expr.(type) { case *ast.StarExpr: - // Pointer type is nullable + // Pointer type can be nullable return true case *ast.SelectorExpr: // Check for sql_types nullable types diff --git a/pkg/writers/bun/type_mapper.go b/pkg/writers/bun/type_mapper.go index 2a347a3..58dc830 100644 --- a/pkg/writers/bun/type_mapper.go +++ b/pkg/writers/bun/type_mapper.go @@ -196,15 +196,31 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st parts = append(parts, "nullzero") } - // Check for unique constraint + // Check for indexes (unique indexes should be added to tag) if table != nil { - for _, constraint := range table.Constraints { - if constraint.Type == models.UniqueConstraint { - for _, col := range constraint.Columns { - if col == column.Name { - parts = append(parts, "unique") - break + for _, index := range table.Indexes { + if !index.Unique { + continue + } + // Check if this column is in the index + for _, col := range index.Columns { + if col == column.Name { + // Add unique tag with index name for composite indexes + // or simple unique for single-column indexes + if len(index.Columns) > 1 { + // Composite index - use index name + parts = append(parts, fmt.Sprintf("unique:%s", index.Name)) + } else { + // Single column - use index name if it's not auto-generated + // Auto-generated names typically follow pattern: idx_tablename_columnname + expectedAutoName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name) + if index.Name == expectedAutoName { + parts = append(parts, "unique") + } else { + parts = append(parts, fmt.Sprintf("unique:%s", index.Name)) + } } + break } } } diff --git a/tests/integration/orm_roundtrip_test.go b/tests/integration/orm_roundtrip_test.go index 5db5210..fcb4c8e 100644 --- a/tests/integration/orm_roundtrip_test.go +++ b/tests/integration/orm_roundtrip_test.go @@ -14,18 +14,135 @@ import ( bunwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/bun" gormwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm" yamlwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" ) +// ComparisonResults holds the results of database comparison +type ComparisonResults struct { + Schemas int + Tables int + Columns int + OriginalIndexes int + RoundtripIndexes int + OriginalConstraints int + RoundtripConstraints int +} + +// countDatabaseStats counts tables, indexes, and constraints in a database +func countDatabaseStats(db *models.Database) (tables, indexes, constraints int) { + for _, schema := range db.Schemas { + tables += len(schema.Tables) + for _, table := range schema.Tables { + indexes += len(table.Indexes) + constraints += len(table.Constraints) + } + } + return +} + +// compareDatabases performs comprehensive comparison between two databases +func compareDatabases(t *testing.T, db1, db2 *models.Database, ormName string) ComparisonResults { + t.Helper() + + results := ComparisonResults{} + + // Compare high-level structure + t.Log(" Comparing high-level structure...") + require.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match") + results.Schemas = len(db1.Schemas) + + // Count totals + tables1, indexes1, constraints1 := countDatabaseStats(db1) + _, indexes2, constraints2 := countDatabaseStats(db2) + + results.OriginalIndexes = indexes1 + results.RoundtripIndexes = indexes2 + results.OriginalConstraints = constraints1 + results.RoundtripConstraints = constraints2 + + // Compare schemas and tables + for i, schema1 := range db1.Schemas { + if i >= len(db2.Schemas) { + t.Errorf("Schema index %d out of bounds in second database", i) + continue + } + schema2 := db2.Schemas[i] + + require.Equal(t, schema1.Name, schema2.Name, "Schema names should match") + require.Equal(t, len(schema1.Tables), len(schema2.Tables), + "Table count in schema '%s' should match", schema1.Name) + + // Compare tables + for j, table1 := range schema1.Tables { + if j >= len(schema2.Tables) { + t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name) + continue + } + table2 := schema2.Tables[j] + + require.Equal(t, table1.Name, table2.Name, + "Table names should match in schema '%s'", schema1.Name) + + // Compare column count + require.Equal(t, len(table1.Columns), len(table2.Columns), + "Column count in table '%s.%s' should match", schema1.Name, table1.Name) + + results.Columns += len(table1.Columns) + + // Compare each column + for colName, col1 := range table1.Columns { + col2, ok := table2.Columns[colName] + if !ok { + t.Errorf("Column '%s' missing from roundtrip table '%s.%s'", + colName, schema1.Name, table1.Name) + continue + } + + // Compare key column properties + require.Equal(t, col1.Name, col2.Name, + "Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) + require.Equal(t, col1.Type, col2.Type, + "Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) + require.Equal(t, col1.Length, col2.Length, + "Column length mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) + require.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey, + "Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) + require.Equal(t, col1.NotNull, col2.NotNull, + "NotNull mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) + + // Log defaults that don't match (these can vary in representation) + if col1.Default != col2.Default { + t.Logf(" ℹ Default value differs for '%s.%s.%s': '%v' vs '%v'", + schema1.Name, table1.Name, colName, col1.Default, col2.Default) + } + } + + // Log index and constraint differences (ORM readers may not capture all of these) + if len(table1.Indexes) != len(table2.Indexes) { + t.Logf(" ℹ Index count differs for table '%s.%s': %d vs %d", + schema1.Name, table1.Name, len(table1.Indexes), len(table2.Indexes)) + } + if len(table1.Constraints) != len(table2.Constraints) { + t.Logf(" ℹ Constraint count differs for table '%s.%s': %d vs %d", + schema1.Name, table1.Name, len(table1.Constraints), len(table2.Constraints)) + } + } + } + + results.Tables = tables1 + t.Logf(" ✓ Validated %d schemas, %d tables, %d columns", results.Schemas, results.Tables, results.Columns) + + return results +} + // TestYAMLToBunRoundTrip tests YAML → Bun Go → YAML roundtrip func TestYAMLToBunRoundTrip(t *testing.T) { testDir := t.TempDir() // Step 1: Read YAML file t.Log("Step 1: Reading YAML file...") - yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml") + yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml") yamlReaderOpts := &readers.ReaderOptions{ FilePath: yamlPath, } @@ -36,6 +153,10 @@ func TestYAMLToBunRoundTrip(t *testing.T) { require.NotNil(t, dbFromYAML, "Database from YAML should not be nil") t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas)) + // Log initial stats + totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML) + t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints) + // Step 2: Write to Bun Go code t.Log("Step 2: Writing to Bun Go code...") bunGoPath := filepath.Join(testDir, "models_bun.go") @@ -103,67 +224,24 @@ func TestYAMLToBunRoundTrip(t *testing.T) { err = yaml.Unmarshal(yaml2Data, &db2) require.NoError(t, err, "Failed to parse second YAML") - // Compare high-level structure - t.Log(" Comparing high-level structure...") - assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match") - - // Compare schemas and tables - for i, schema1 := range db1.Schemas { - if i >= len(db2.Schemas) { - t.Errorf("Schema index %d out of bounds in second database", i) - continue - } - schema2 := db2.Schemas[i] - - assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match") - assert.Equal(t, len(schema1.Tables), len(schema2.Tables), - "Table count in schema '%s' should match", schema1.Name) - - // Compare tables - for j, table1 := range schema1.Tables { - if j >= len(schema2.Tables) { - t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name) - continue - } - table2 := schema2.Tables[j] - - assert.Equal(t, table1.Name, table2.Name, - "Table names should match in schema '%s'", schema1.Name) - - // Compare column count - assert.Equal(t, len(table1.Columns), len(table2.Columns), - "Column count in table '%s.%s' should match", schema1.Name, table1.Name) - - // Compare each column - for colName, col1 := range table1.Columns { - col2, ok := table2.Columns[colName] - if !ok { - t.Errorf("Column '%s' missing from roundtrip table '%s.%s'", - colName, schema1.Name, table1.Name) - continue - } - - // Compare key column properties - assert.Equal(t, col1.Name, col2.Name, - "Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) - assert.Equal(t, col1.Type, col2.Type, - "Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) - assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey, - "Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) - } - } - } + // Comprehensive comparison + compareResults := compareDatabases(t, &db1, &db2, "Bun") // Summary t.Log("Summary:") t.Logf(" ✓ Round-trip completed: YAML → Bun → YAML") - t.Logf(" ✓ Schemas match: %d", len(db1.Schemas)) + t.Logf(" ✓ Schemas: %d", compareResults.Schemas) + t.Logf(" ✓ Tables: %d", compareResults.Tables) + t.Logf(" ✓ Columns: %d", compareResults.Columns) + t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes) + t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints) - totalTables := 0 - for _, schema := range db1.Schemas { - totalTables += len(schema.Tables) + if compareResults.OriginalIndexes != compareResults.RoundtripIndexes { + t.Logf(" ⚠ Note: Index counts differ - Bun reader may not parse all index information from Go code") + } + if compareResults.OriginalConstraints != compareResults.RoundtripConstraints { + t.Logf(" ⚠ Note: Constraint counts differ - Bun reader may not parse all constraint information from Go code") } - t.Logf(" ✓ Total tables: %d", totalTables) } // TestYAMLToGORMRoundTrip tests YAML → GORM Go → YAML roundtrip @@ -172,7 +250,7 @@ func TestYAMLToGORMRoundTrip(t *testing.T) { // Step 1: Read YAML file t.Log("Step 1: Reading YAML file...") - yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml") + yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml") yamlReaderOpts := &readers.ReaderOptions{ FilePath: yamlPath, } @@ -183,6 +261,10 @@ func TestYAMLToGORMRoundTrip(t *testing.T) { require.NotNil(t, dbFromYAML, "Database from YAML should not be nil") t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas)) + // Log initial stats + totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML) + t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints) + // Step 2: Write to GORM Go code t.Log("Step 2: Writing to GORM Go code...") gormGoPath := filepath.Join(testDir, "models_gorm.go") @@ -250,65 +332,22 @@ func TestYAMLToGORMRoundTrip(t *testing.T) { err = yaml.Unmarshal(yaml2Data, &db2) require.NoError(t, err, "Failed to parse second YAML") - // Compare high-level structure - t.Log(" Comparing high-level structure...") - assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match") - - // Compare schemas and tables - for i, schema1 := range db1.Schemas { - if i >= len(db2.Schemas) { - t.Errorf("Schema index %d out of bounds in second database", i) - continue - } - schema2 := db2.Schemas[i] - - assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match") - assert.Equal(t, len(schema1.Tables), len(schema2.Tables), - "Table count in schema '%s' should match", schema1.Name) - - // Compare tables - for j, table1 := range schema1.Tables { - if j >= len(schema2.Tables) { - t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name) - continue - } - table2 := schema2.Tables[j] - - assert.Equal(t, table1.Name, table2.Name, - "Table names should match in schema '%s'", schema1.Name) - - // Compare column count - assert.Equal(t, len(table1.Columns), len(table2.Columns), - "Column count in table '%s.%s' should match", schema1.Name, table1.Name) - - // Compare each column - for colName, col1 := range table1.Columns { - col2, ok := table2.Columns[colName] - if !ok { - t.Errorf("Column '%s' missing from roundtrip table '%s.%s'", - colName, schema1.Name, table1.Name) - continue - } - - // Compare key column properties - assert.Equal(t, col1.Name, col2.Name, - "Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) - assert.Equal(t, col1.Type, col2.Type, - "Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) - assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey, - "Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName) - } - } - } + // Comprehensive comparison + compareResults := compareDatabases(t, &db1, &db2, "GORM") // Summary t.Log("Summary:") t.Logf(" ✓ Round-trip completed: YAML → GORM → YAML") - t.Logf(" ✓ Schemas match: %d", len(db1.Schemas)) + t.Logf(" ✓ Schemas: %d", compareResults.Schemas) + t.Logf(" ✓ Tables: %d", compareResults.Tables) + t.Logf(" ✓ Columns: %d", compareResults.Columns) + t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes) + t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints) - totalTables := 0 - for _, schema := range db1.Schemas { - totalTables += len(schema.Tables) + if compareResults.OriginalIndexes != compareResults.RoundtripIndexes { + t.Logf(" ⚠ Note: Index counts differ - GORM reader may not parse all index information from Go code") + } + if compareResults.OriginalConstraints != compareResults.RoundtripConstraints { + t.Logf(" ⚠ Note: Constraint counts differ - GORM reader may not parse all constraint information from Go code") } - t.Logf(" ✓ Total tables: %d", totalTables) }