Fixed bug/gorm indexes
Some checks are pending
CI / Build (push) Waiting to run
CI / Test (1.23) (push) Waiting to run
CI / Test (1.24) (push) Waiting to run
CI / Test (1.25) (push) Waiting to run
CI / Lint (push) Waiting to run

This commit is contained in:
2025-12-18 19:15:22 +02:00
parent b7950057eb
commit d93a4b6f08
4 changed files with 622 additions and 142 deletions

View File

@@ -187,11 +187,43 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) {
table.Schema = schemaName
}
// Update columns
// Update columns and indexes
for _, col := range table.Columns {
col.Table = tableName
col.Schema = table.Schema
}
for _, idx := range table.Indexes {
idx.Table = tableName
idx.Schema = table.Schema
}
}
}
// Third pass: parse relationship fields for constraints
for _, decl := range node.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}
table, ok := structMap[typeSpec.Name.Name]
if !ok {
continue
}
// Parse relationship fields
r.parseRelationshipConstraints(table, structType, structMap)
}
}
@@ -314,6 +346,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
column.Table = tableName
column.Schema = schemaName
table.Columns[column.Name] = column
// Parse indexes from bun tags
r.parseIndexesFromTag(table, column, tag)
sequence++
}
}
@@ -346,6 +382,159 @@ func (r *Reader) isRelationship(tag string) bool {
return strings.Contains(tag, "bun:\"rel:") || strings.Contains(tag, ",rel:")
}
// parseRelationshipConstraints inspects the relationship fields of a struct
// and records the implied foreign key constraints on the related tables.
// For a has-many relation the FK column lives on the "many" side, so the
// constraint is attached to the referenced table, pointing back at table.
func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) {
	for _, field := range structType.Fields.List {
		if field.Tag == nil {
			continue
		}
		tag := field.Tag.Value
		if !r.isRelationship(tag) {
			continue
		}
		bunTag := r.extractBunTag(tag)

		// Resolve the related struct's table from the field type.
		relatedName := r.getRelationshipType(field.Type)
		if relatedName == "" {
			continue
		}
		related, found := structMap[relatedName]
		if !found {
			continue
		}

		// join:user_id=id — the column before '=' belongs to the related
		// table, the column after it to the current table.
		join := r.parseJoinInfo(bunTag)
		if join == nil {
			continue
		}

		fk := &models.Constraint{
			Name:              fmt.Sprintf("fk_%s_%s", related.Name, table.Name),
			Type:              models.ForeignKeyConstraint,
			Table:             related.Name,
			Schema:            related.Schema,
			Columns:           []string{join.ForeignKey},
			ReferencedTable:   table.Name,
			ReferencedSchema:  table.Schema,
			ReferencedColumns: []string{join.ReferencedKey},
			OnDelete:          "NO ACTION", // Bun tags carry no ON DELETE/UPDATE behavior
			OnUpdate:          "NO ACTION",
		}
		related.Constraints[fk.Name] = fk
	}
}
// JoinInfo holds the two column names parsed from a bun "join:" clause,
// e.g. join:user_id=id.
type JoinInfo struct {
	ForeignKey    string // Column on the related ("many") table that holds the FK
	ReferencedKey string // Column on the current table that the FK references
}
// parseJoinInfo extracts join information from a bun tag.
// Example: "join:user_id=id" yields {ForeignKey: "user_id", ReferencedKey: "id"}.
// Returns nil when no well-formed join clause (prefix "join:" with an '=')
// is present in any comma-separated tag option.
func (r *Reader) parseJoinInfo(bunTag string) *JoinInfo {
	// The original pre-scan with strings.Contains was redundant: the loop
	// below already detects (or misses) the "join:" prefix on its own.
	for _, part := range strings.Split(bunTag, ",") {
		joinStr, hasPrefix := strings.CutPrefix(strings.TrimSpace(part), "join:")
		if !hasPrefix {
			continue
		}
		// Parse "user_id=id"; a join clause without '=' is skipped,
		// allowing a later well-formed clause to match.
		fk, ref, ok := strings.Cut(joinStr, "=")
		if ok {
			return &JoinInfo{
				ForeignKey:    fk,
				ReferencedKey: ref,
			}
		}
	}
	return nil
}
// getRelationshipType extracts the bare type name from a relationship field.
// Handles []*ModelPost, []ModelPost, and *ModelPost; returns "" for anything
// else (e.g. selector or map types), which callers treat as "skip this field".
// The original version missed the []ModelPost (value-element slice) form,
// silently dropping such has-many relations.
func (r *Reader) getRelationshipType(expr ast.Expr) string {
	switch t := expr.(type) {
	case *ast.ArrayType:
		// []*ModelPost -> ModelPost, []ModelPost -> ModelPost
		switch elt := t.Elt.(type) {
		case *ast.StarExpr:
			if ident, ok := elt.X.(*ast.Ident); ok {
				return ident.Name
			}
		case *ast.Ident:
			return elt.Name
		}
	case *ast.StarExpr:
		// *ModelPost -> ModelPost
		if ident, ok := t.X.(*ast.Ident); ok {
			return ident.Name
		}
	}
	return ""
}
// parseIndexesFromTag extracts index definitions from Bun struct tags.
// Supported options: a bare "unique" (auto-named single-column unique index)
// and "unique:name" (named unique index; repeating the same name across
// columns builds a composite index).
func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) {
	bunTag := r.extractBunTag(tag)
	for _, raw := range strings.Split(bunTag, ",") {
		opt := strings.TrimSpace(raw)
		switch {
		case opt == "unique":
			// Anonymous unique index: derive name as idx_<table>_<column>.
			name := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
			if _, ok := table.Indexes[name]; ok {
				break
			}
			table.Indexes[name] = &models.Index{
				Name:    name,
				Table:   table.Name,
				Schema:  table.Schema,
				Columns: []string{column.Name},
				Unique:  true,
				Type:    "btree",
			}
		case strings.HasPrefix(opt, "unique:"):
			// Named unique index; an existing entry means another column
			// already claimed this name, so extend it into a composite index.
			name := strings.TrimPrefix(opt, "unique:")
			if existing, ok := table.Indexes[name]; ok {
				existing.Columns = append(existing.Columns, column.Name)
				break
			}
			table.Indexes[name] = &models.Index{
				Name:    name,
				Table:   table.Name,
				Schema:  table.Schema,
				Columns: []string{column.Name},
				Unique:  true,
				Type:    "btree",
			}
		}
	}
}
// extractTableNameFromTag extracts table and schema from bun tag
func (r *Reader) extractTableNameFromTag(tag string) (tableName string, schemaName string) {
// Extract bun tag value
@@ -422,6 +611,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
case "autoincrement":
column.AutoIncrement = true
case "default":
// Default value from Bun tag (e.g., default:gen_random_uuid())
column.Default = value
}
}
@@ -431,16 +621,24 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
column.Type = r.goTypeToSQL(fieldType)
}
// Determine if nullable based on Go type
if r.isNullableType(fieldType) {
// Determine if nullable based on Go type and bun tags
// In Bun:
// - nullzero tag means the field is nullable (can be NULL in DB)
// - absence of nullzero means the field is NOT NULL
// - primitive types (int64, bool, string) are NOT NULL by default
if strings.Contains(bunTag, "nullzero") {
column.NotNull = false
} else if !column.IsPrimaryKey && column.Type != "" {
// If it's not a nullable type and not a primary key, check the tag
if !strings.Contains(bunTag, "notnull") {
// If notnull is not explicitly set, it might still be nullable
// This is a heuristic - we default to nullable unless specified
return column
}
} else if r.isNullableGoType(fieldType) {
// SqlString, SqlInt, etc. without nullzero tag means NOT NULL
column.NotNull = true
} else {
// Primitive types are NOT NULL by default
column.NotNull = true
}
// Primary keys are always NOT NULL
if column.IsPrimaryKey {
column.NotNull = true
}
return column
@@ -529,11 +727,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string {
}
}
// isNullableType checks if a Go type represents a nullable field
func (r *Reader) isNullableType(expr ast.Expr) bool {
// isNullableGoType checks if a Go type represents a nullable field type
// (this is for types that CAN be nullable, not whether they ARE nullable)
func (r *Reader) isNullableGoType(expr ast.Expr) bool {
switch t := expr.(type) {
case *ast.StarExpr:
// Pointer type is nullable
// Pointer type can be nullable
return true
case *ast.SelectorExpr:
// Check for sql_types nullable types

View File

@@ -187,11 +187,44 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) {
table.Schema = schemaName
}
// Update columns
// Update columns and indexes
for _, col := range table.Columns {
col.Table = tableName
col.Schema = table.Schema
}
for _, idx := range table.Indexes {
idx.Table = tableName
idx.Schema = table.Schema
}
}
}
// Third pass: parse relationship fields for constraints
// Re-parse the file to get relationship information
for _, decl := range node.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}
table, ok := structMap[typeSpec.Name.Name]
if !ok {
continue
}
// Parse relationship fields
r.parseRelationshipConstraints(table, structType, structMap)
}
}
@@ -280,8 +313,14 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
continue
}
// Skip embedded GORM model and relationship fields
if r.isGORMModel(field) || r.isRelationship(tag) {
// Skip embedded GORM model
if r.isGORMModel(field) {
continue
}
// Parse relationship fields for foreign key constraints
if r.isRelationship(tag) {
// We'll parse constraints in a second pass after we know all table names
continue
}
@@ -310,6 +349,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
table.Name = tableName
table.Schema = schemaName
table.Columns[column.Name] = column
// Parse indexes from GORM tags
r.parseIndexesFromTag(table, column, tag)
sequence++
}
}
@@ -345,6 +388,169 @@ func (r *Reader) isRelationship(tag string) bool {
strings.Contains(gormTag, "many2many:")
}
// parseRelationshipConstraints parses relationship fields to extract foreign
// key constraints from GORM tags. For a has-many relation the FK column is on
// the "many" side, so the constraint is attached to the referenced table.
// Recognized tag options: foreignKey (required), references (referenced
// column; defaults to "id" when absent), and constraint (OnDelete/OnUpdate
// behavior). The previous version hardcoded the referenced column to "id",
// ignoring GORM's references: option.
func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) {
	for _, field := range structType.Fields.List {
		if field.Tag == nil {
			continue
		}
		tag := field.Tag.Value
		if !r.isRelationship(tag) {
			continue
		}
		gormTag := r.extractGormTag(tag)
		parts := r.parseGormTag(gormTag)

		// Get the referenced type name from the field type.
		referencedType := r.getRelationshipType(field.Type)
		if referencedType == "" {
			continue
		}
		referencedTable, ok := structMap[referencedType]
		if !ok {
			continue
		}

		// A foreignKey tag option is required to emit a constraint.
		foreignKey, hasForeignKey := parts["foreignKey"]
		if !hasForeignKey {
			continue
		}
		fkColumn := r.fieldNameToColumnName(foreignKey)

		// Honor an explicit references:Field option; otherwise assume the
		// FK points at the conventional "id" primary key.
		referencedColumn := "id"
		if refField, hasRef := parts["references"]; hasRef && refField != "" {
			referencedColumn = r.fieldNameToColumnName(refField)
		}

		// Determine constraint behavior from e.g.
		// constraint:OnDelete:CASCADE,OnUpdate:CASCADE
		onDelete := "NO ACTION"
		onUpdate := "NO ACTION"
		if constraintStr, hasConstraint := parts["constraint"]; hasConstraint {
			if strings.Contains(constraintStr, "OnDelete:CASCADE") {
				onDelete = "CASCADE"
			} else if strings.Contains(constraintStr, "OnDelete:SET NULL") {
				onDelete = "SET NULL"
			}
			if strings.Contains(constraintStr, "OnUpdate:CASCADE") {
				onUpdate = "CASCADE"
			} else if strings.Contains(constraintStr, "OnUpdate:SET NULL") {
				onUpdate = "SET NULL"
			}
		}

		// The FK is on the referenced table, pointing back to this table.
		constraint := &models.Constraint{
			Name:              fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name),
			Type:              models.ForeignKeyConstraint,
			Table:             referencedTable.Name,
			Schema:            referencedTable.Schema,
			Columns:           []string{fkColumn},
			ReferencedTable:   table.Name,
			ReferencedSchema:  table.Schema,
			ReferencedColumns: []string{referencedColumn},
			OnDelete:          onDelete,
			OnUpdate:          onUpdate,
		}
		referencedTable.Constraints[constraint.Name] = constraint
	}
}
// getRelationshipType extracts the bare type name from a relationship field.
// Handles []*ModelPost, []ModelPost, and *ModelPost; returns "" for anything
// else (e.g. selector or map types), which callers treat as "skip this field".
// The original version missed the []ModelPost (value-element slice) form,
// silently dropping such has-many relations.
func (r *Reader) getRelationshipType(expr ast.Expr) string {
	switch t := expr.(type) {
	case *ast.ArrayType:
		// []*ModelPost -> ModelPost, []ModelPost -> ModelPost
		switch elt := t.Elt.(type) {
		case *ast.StarExpr:
			if ident, ok := elt.X.(*ast.Ident); ok {
				return ident.Name
			}
		case *ast.Ident:
			return elt.Name
		}
	case *ast.StarExpr:
		// *ModelPost -> ModelPost
		if ident, ok := t.X.(*ast.Ident); ok {
			return ident.Name
		}
	}
	return ""
}
// addIndexColumn registers column under the named index on table, creating
// the index when it does not exist yet. When the index already exists and
// appendColumn is true (composite indexes built by repeating the same name
// across fields), the column is appended to the existing definition;
// otherwise the existing index is left untouched.
func (r *Reader) addIndexColumn(table *models.Table, indexName string, column *models.Column, unique, appendColumn bool) {
	if idx, exists := table.Indexes[indexName]; exists {
		if appendColumn {
			idx.Columns = append(idx.Columns, column.Name)
		}
		return
	}
	table.Indexes[indexName] = &models.Index{
		Name:    indexName,
		Table:   table.Name,
		Schema:  table.Schema,
		Columns: []string{column.Name},
		Unique:  unique,
		Type:    "btree",
	}
}

// parseIndexesFromTag extracts index definitions from GORM struct tags.
// Supported options: "index"/"index:name" (regular index), "uniqueIndex"/
// "uniqueIndex:name" (unique index), and the bare "unique" flag (single-column
// unique index, never merged into a composite). Named index/uniqueIndex
// options repeated across columns build composite indexes. The three
// previously duplicated create-or-append stanzas are now one helper.
func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) {
	gormTag := r.extractGormTag(tag)
	parts := r.parseGormTag(gormTag)

	// Regular index: index:idx_name or bare index (auto-generated name).
	if indexName, ok := parts["index"]; ok {
		if indexName == "" {
			indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
		}
		r.addIndexColumn(table, indexName, column, false, true)
	}

	// Unique index: uniqueIndex:idx_name or bare uniqueIndex.
	if indexName, ok := parts["uniqueIndex"]; ok {
		if indexName == "" {
			indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
		}
		r.addIndexColumn(table, indexName, column, true, true)
	}

	// Bare unique flag: single-column unique index with an auto-generated
	// name; an existing index with that name is intentionally not extended.
	if _, ok := parts["unique"]; ok {
		indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
		r.addIndexColumn(table, indexName, column, true, false)
	}
}
// extractTableFromGormTag extracts table and schema from gorm tag
func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaName string) {
// This is typically set via TableName() method, not in tags
@@ -406,6 +612,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
column.AutoIncrement = true
}
if def, ok := parts["default"]; ok {
// Default value from GORM tag (e.g., default:gen_random_uuid())
column.Default = def
}
if size, ok := parts["size"]; ok {
@@ -419,9 +626,27 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
column.Type = r.goTypeToSQL(fieldType)
}
// Determine if nullable based on Go type
if r.isNullableType(fieldType) {
column.NotNull = false
// Determine if nullable based on GORM tags and Go type
// In GORM:
// - explicit "not null" tag means NOT NULL
// - absence of "not null" tag with sql_types means nullable
// - primitive types (string, int64, bool) default to NOT NULL unless explicitly nullable
if _, hasNotNull := parts["not null"]; hasNotNull {
column.NotNull = true
} else {
// If no explicit "not null" tag, check the Go type
if r.isNullableGoType(fieldType) {
// sql_types.SqlString, etc. are nullable by default
column.NotNull = false
} else {
// Primitive types default to NOT NULL
column.NotNull = false // Default to nullable unless explicitly set
}
}
// Primary keys are always NOT NULL
if column.IsPrimaryKey {
column.NotNull = true
}
return column
@@ -557,11 +782,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string {
}
}
// isNullableType checks if a Go type represents a nullable field
func (r *Reader) isNullableType(expr ast.Expr) bool {
// isNullableGoType checks if a Go type represents a nullable field type
// (this is for types that CAN be nullable, not whether they ARE nullable)
func (r *Reader) isNullableGoType(expr ast.Expr) bool {
switch t := expr.(type) {
case *ast.StarExpr:
// Pointer type is nullable
// Pointer type can be nullable
return true
case *ast.SelectorExpr:
// Check for sql_types nullable types

View File

@@ -196,15 +196,31 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
parts = append(parts, "nullzero")
}
// Check for unique constraint
// Check for indexes (unique indexes should be added to tag)
if table != nil {
for _, constraint := range table.Constraints {
if constraint.Type == models.UniqueConstraint {
for _, col := range constraint.Columns {
if col == column.Name {
parts = append(parts, "unique")
break
for _, index := range table.Indexes {
if !index.Unique {
continue
}
// Check if this column is in the index
for _, col := range index.Columns {
if col == column.Name {
// Add unique tag with index name for composite indexes
// or simple unique for single-column indexes
if len(index.Columns) > 1 {
// Composite index - use index name
parts = append(parts, fmt.Sprintf("unique:%s", index.Name))
} else {
// Single column - use index name if it's not auto-generated
// Auto-generated names typically follow pattern: idx_tablename_columnname
expectedAutoName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
if index.Name == expectedAutoName {
parts = append(parts, "unique")
} else {
parts = append(parts, fmt.Sprintf("unique:%s", index.Name))
}
}
break
}
}
}

View File

@@ -14,18 +14,135 @@ import (
bunwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
gormwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
yamlwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
// ComparisonResults holds the aggregate counts produced by a database
// roundtrip comparison, so tests can summarize original-vs-roundtrip fidelity.
type ComparisonResults struct {
	Schemas              int // Number of schemas (equal in both databases by the time this is returned)
	Tables               int // Total tables counted in the original database
	Columns              int // Total columns compared across all tables
	OriginalIndexes      int // Index count in the original database
	RoundtripIndexes     int // Index count after the roundtrip (may be lower if the ORM reader drops some)
	OriginalConstraints  int // Constraint count in the original database
	RoundtripConstraints int // Constraint count after the roundtrip
}
// countDatabaseStats tallies the total number of tables, indexes, and
// constraints across every schema in db.
func countDatabaseStats(db *models.Database) (tables, indexes, constraints int) {
	var tableTotal, indexTotal, constraintTotal int
	for _, schema := range db.Schemas {
		tableTotal += len(schema.Tables)
		for _, tbl := range schema.Tables {
			indexTotal += len(tbl.Indexes)
			constraintTotal += len(tbl.Constraints)
		}
	}
	return tableTotal, indexTotal, constraintTotal
}
// compareDatabases performs a comprehensive structural comparison between two
// databases produced by a roundtrip through the named ORM. Hard mismatches
// (schema, table, and column structure) fail the test via require; properties
// that legitimately vary across ORM readers (default-value representation,
// index and constraint counts) are only logged. ormName appears in log output
// so interleaved test logs are attributable (it was previously unused).
// Returns aggregate counts for the caller's summary.
func compareDatabases(t *testing.T, db1, db2 *models.Database, ormName string) ComparisonResults {
	t.Helper()
	results := ComparisonResults{}

	// Compare high-level structure
	t.Logf("  Comparing high-level structure (%s roundtrip)...", ormName)
	require.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
	results.Schemas = len(db1.Schemas)

	// Count totals up front so the summary can report original vs roundtrip.
	tables1, indexes1, constraints1 := countDatabaseStats(db1)
	_, indexes2, constraints2 := countDatabaseStats(db2)
	results.OriginalIndexes = indexes1
	results.RoundtripIndexes = indexes2
	results.OriginalConstraints = constraints1
	results.RoundtripConstraints = constraints2

	// Compare schemas and tables pairwise by position.
	for i, schema1 := range db1.Schemas {
		if i >= len(db2.Schemas) {
			t.Errorf("Schema index %d out of bounds in second database", i)
			continue
		}
		schema2 := db2.Schemas[i]
		require.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
		require.Equal(t, len(schema1.Tables), len(schema2.Tables),
			"Table count in schema '%s' should match", schema1.Name)

		// Compare tables
		for j, table1 := range schema1.Tables {
			if j >= len(schema2.Tables) {
				t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
				continue
			}
			table2 := schema2.Tables[j]
			require.Equal(t, table1.Name, table2.Name,
				"Table names should match in schema '%s'", schema1.Name)

			// Compare column count
			require.Equal(t, len(table1.Columns), len(table2.Columns),
				"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
			results.Columns += len(table1.Columns)

			// Compare each column's key properties.
			for colName, col1 := range table1.Columns {
				col2, ok := table2.Columns[colName]
				if !ok {
					t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
						colName, schema1.Name, table1.Name)
					continue
				}
				require.Equal(t, col1.Name, col2.Name,
					"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
				require.Equal(t, col1.Type, col2.Type,
					"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
				require.Equal(t, col1.Length, col2.Length,
					"Column length mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
				require.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
					"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
				require.Equal(t, col1.NotNull, col2.NotNull,
					"NotNull mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
				// Defaults can vary in representation across writers/readers;
				// log the difference instead of failing.
				if col1.Default != col2.Default {
					t.Logf("  Default value differs for '%s.%s.%s': '%v' vs '%v'",
						schema1.Name, table1.Name, colName, col1.Default, col2.Default)
				}
			}

			// ORM readers may not capture every index/constraint; log only.
			if len(table1.Indexes) != len(table2.Indexes) {
				t.Logf("  Index count differs for table '%s.%s': %d vs %d",
					schema1.Name, table1.Name, len(table1.Indexes), len(table2.Indexes))
			}
			if len(table1.Constraints) != len(table2.Constraints) {
				t.Logf("  Constraint count differs for table '%s.%s': %d vs %d",
					schema1.Name, table1.Name, len(table1.Constraints), len(table2.Constraints))
			}
		}
	}
	results.Tables = tables1
	t.Logf("  ✓ Validated %d schemas, %d tables, %d columns", results.Schemas, results.Tables, results.Columns)
	return results
}
// TestYAMLToBunRoundTrip tests YAML → Bun Go → YAML roundtrip
func TestYAMLToBunRoundTrip(t *testing.T) {
testDir := t.TempDir()
// Step 1: Read YAML file
t.Log("Step 1: Reading YAML file...")
yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml")
yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml")
yamlReaderOpts := &readers.ReaderOptions{
FilePath: yamlPath,
}
@@ -36,6 +153,10 @@ func TestYAMLToBunRoundTrip(t *testing.T) {
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
// Log initial stats
totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML)
t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints)
// Step 2: Write to Bun Go code
t.Log("Step 2: Writing to Bun Go code...")
bunGoPath := filepath.Join(testDir, "models_bun.go")
@@ -103,67 +224,24 @@ func TestYAMLToBunRoundTrip(t *testing.T) {
err = yaml.Unmarshal(yaml2Data, &db2)
require.NoError(t, err, "Failed to parse second YAML")
// Compare high-level structure
t.Log(" Comparing high-level structure...")
assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
// Compare schemas and tables
for i, schema1 := range db1.Schemas {
if i >= len(db2.Schemas) {
t.Errorf("Schema index %d out of bounds in second database", i)
continue
}
schema2 := db2.Schemas[i]
assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
assert.Equal(t, len(schema1.Tables), len(schema2.Tables),
"Table count in schema '%s' should match", schema1.Name)
// Compare tables
for j, table1 := range schema1.Tables {
if j >= len(schema2.Tables) {
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
continue
}
table2 := schema2.Tables[j]
assert.Equal(t, table1.Name, table2.Name,
"Table names should match in schema '%s'", schema1.Name)
// Compare column count
assert.Equal(t, len(table1.Columns), len(table2.Columns),
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
// Compare each column
for colName, col1 := range table1.Columns {
col2, ok := table2.Columns[colName]
if !ok {
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
colName, schema1.Name, table1.Name)
continue
}
// Compare key column properties
assert.Equal(t, col1.Name, col2.Name,
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
assert.Equal(t, col1.Type, col2.Type,
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
}
}
}
// Comprehensive comparison
compareResults := compareDatabases(t, &db1, &db2, "Bun")
// Summary
t.Log("Summary:")
t.Logf(" ✓ Round-trip completed: YAML → Bun → YAML")
t.Logf(" ✓ Schemas match: %d", len(db1.Schemas))
t.Logf(" ✓ Schemas: %d", compareResults.Schemas)
t.Logf(" ✓ Tables: %d", compareResults.Tables)
t.Logf(" ✓ Columns: %d", compareResults.Columns)
t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes)
t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints)
totalTables := 0
for _, schema := range db1.Schemas {
totalTables += len(schema.Tables)
if compareResults.OriginalIndexes != compareResults.RoundtripIndexes {
t.Logf(" ⚠ Note: Index counts differ - Bun reader may not parse all index information from Go code")
}
if compareResults.OriginalConstraints != compareResults.RoundtripConstraints {
t.Logf(" ⚠ Note: Constraint counts differ - Bun reader may not parse all constraint information from Go code")
}
t.Logf(" ✓ Total tables: %d", totalTables)
}
// TestYAMLToGORMRoundTrip tests YAML → GORM Go → YAML roundtrip
@@ -172,7 +250,7 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
// Step 1: Read YAML file
t.Log("Step 1: Reading YAML file...")
yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml")
yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml")
yamlReaderOpts := &readers.ReaderOptions{
FilePath: yamlPath,
}
@@ -183,6 +261,10 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
// Log initial stats
totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML)
t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints)
// Step 2: Write to GORM Go code
t.Log("Step 2: Writing to GORM Go code...")
gormGoPath := filepath.Join(testDir, "models_gorm.go")
@@ -250,65 +332,22 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
err = yaml.Unmarshal(yaml2Data, &db2)
require.NoError(t, err, "Failed to parse second YAML")
// Compare high-level structure
t.Log(" Comparing high-level structure...")
assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
// Compare schemas and tables
for i, schema1 := range db1.Schemas {
if i >= len(db2.Schemas) {
t.Errorf("Schema index %d out of bounds in second database", i)
continue
}
schema2 := db2.Schemas[i]
assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
assert.Equal(t, len(schema1.Tables), len(schema2.Tables),
"Table count in schema '%s' should match", schema1.Name)
// Compare tables
for j, table1 := range schema1.Tables {
if j >= len(schema2.Tables) {
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
continue
}
table2 := schema2.Tables[j]
assert.Equal(t, table1.Name, table2.Name,
"Table names should match in schema '%s'", schema1.Name)
// Compare column count
assert.Equal(t, len(table1.Columns), len(table2.Columns),
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
// Compare each column
for colName, col1 := range table1.Columns {
col2, ok := table2.Columns[colName]
if !ok {
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
colName, schema1.Name, table1.Name)
continue
}
// Compare key column properties
assert.Equal(t, col1.Name, col2.Name,
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
assert.Equal(t, col1.Type, col2.Type,
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
}
}
}
// Comprehensive comparison
compareResults := compareDatabases(t, &db1, &db2, "GORM")
// Summary
t.Log("Summary:")
t.Logf(" ✓ Round-trip completed: YAML → GORM → YAML")
t.Logf(" ✓ Schemas match: %d", len(db1.Schemas))
t.Logf(" ✓ Schemas: %d", compareResults.Schemas)
t.Logf(" ✓ Tables: %d", compareResults.Tables)
t.Logf(" ✓ Columns: %d", compareResults.Columns)
t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes)
t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints)
totalTables := 0
for _, schema := range db1.Schemas {
totalTables += len(schema.Tables)
if compareResults.OriginalIndexes != compareResults.RoundtripIndexes {
t.Logf(" ⚠ Note: Index counts differ - GORM reader may not parse all index information from Go code")
}
if compareResults.OriginalConstraints != compareResults.RoundtripConstraints {
t.Logf(" ⚠ Note: Constraint counts differ - GORM reader may not parse all constraint information from Go code")
}
t.Logf(" ✓ Total tables: %d", totalTables)
}