Fixed bug/gorm indexes
This commit is contained in:
@@ -187,11 +187,43 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) {
|
|||||||
table.Schema = schemaName
|
table.Schema = schemaName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update columns
|
// Update columns and indexes
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
col.Table = tableName
|
col.Table = tableName
|
||||||
col.Schema = table.Schema
|
col.Schema = table.Schema
|
||||||
}
|
}
|
||||||
|
for _, idx := range table.Indexes {
|
||||||
|
idx.Table = tableName
|
||||||
|
idx.Schema = table.Schema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third pass: parse relationship fields for constraints
|
||||||
|
for _, decl := range node.Decls {
|
||||||
|
genDecl, ok := decl.(*ast.GenDecl)
|
||||||
|
if !ok || genDecl.Tok != token.TYPE {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, spec := range genDecl.Specs {
|
||||||
|
typeSpec, ok := spec.(*ast.TypeSpec)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
structType, ok := typeSpec.Type.(*ast.StructType)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
table, ok := structMap[typeSpec.Name.Name]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse relationship fields
|
||||||
|
r.parseRelationshipConstraints(table, structType, structMap)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,6 +346,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
|
|||||||
column.Table = tableName
|
column.Table = tableName
|
||||||
column.Schema = schemaName
|
column.Schema = schemaName
|
||||||
table.Columns[column.Name] = column
|
table.Columns[column.Name] = column
|
||||||
|
|
||||||
|
// Parse indexes from bun tags
|
||||||
|
r.parseIndexesFromTag(table, column, tag)
|
||||||
|
|
||||||
sequence++
|
sequence++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -346,6 +382,159 @@ func (r *Reader) isRelationship(tag string) bool {
|
|||||||
return strings.Contains(tag, "bun:\"rel:") || strings.Contains(tag, ",rel:")
|
return strings.Contains(tag, "bun:\"rel:") || strings.Contains(tag, ",rel:")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseRelationshipConstraints parses relationship fields to extract foreign key constraints
|
||||||
|
func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) {
|
||||||
|
for _, field := range structType.Fields.List {
|
||||||
|
if field.Tag == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tag := field.Tag.Value
|
||||||
|
if !r.isRelationship(tag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bunTag := r.extractBunTag(tag)
|
||||||
|
|
||||||
|
// Get the referenced type name from the field type
|
||||||
|
referencedType := r.getRelationshipType(field.Type)
|
||||||
|
if referencedType == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the referenced table
|
||||||
|
referencedTable, ok := structMap[referencedType]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the join information: join:user_id=id
|
||||||
|
// This means: referencedTable.user_id = thisTable.id
|
||||||
|
joinInfo := r.parseJoinInfo(bunTag)
|
||||||
|
if joinInfo == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// The FK is on the referenced table
|
||||||
|
constraint := &models.Constraint{
|
||||||
|
Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name),
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Table: referencedTable.Name,
|
||||||
|
Schema: referencedTable.Schema,
|
||||||
|
Columns: []string{joinInfo.ForeignKey},
|
||||||
|
ReferencedTable: table.Name,
|
||||||
|
ReferencedSchema: table.Schema,
|
||||||
|
ReferencedColumns: []string{joinInfo.ReferencedKey},
|
||||||
|
OnDelete: "NO ACTION", // Bun doesn't specify this in tags
|
||||||
|
OnUpdate: "NO ACTION",
|
||||||
|
}
|
||||||
|
|
||||||
|
referencedTable.Constraints[constraint.Name] = constraint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinInfo holds parsed join information
|
||||||
|
type JoinInfo struct {
|
||||||
|
ForeignKey string // Column in the related table
|
||||||
|
ReferencedKey string // Column in the current table
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseJoinInfo parses join information from bun tag
|
||||||
|
// Example: join:user_id=id means foreign_key=referenced_key
|
||||||
|
func (r *Reader) parseJoinInfo(bunTag string) *JoinInfo {
|
||||||
|
// Look for join: in the tag
|
||||||
|
if !strings.Contains(bunTag, "join:") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract join clause
|
||||||
|
parts := strings.Split(bunTag, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "join:") {
|
||||||
|
joinStr := strings.TrimPrefix(part, "join:")
|
||||||
|
// Parse user_id=id
|
||||||
|
joinParts := strings.SplitN(joinStr, "=", 2)
|
||||||
|
if len(joinParts) == 2 {
|
||||||
|
return &JoinInfo{
|
||||||
|
ForeignKey: joinParts[0],
|
||||||
|
ReferencedKey: joinParts[1],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRelationshipType extracts the type name from a relationship field
|
||||||
|
func (r *Reader) getRelationshipType(expr ast.Expr) string {
|
||||||
|
switch t := expr.(type) {
|
||||||
|
case *ast.ArrayType:
|
||||||
|
// []*ModelPost -> ModelPost
|
||||||
|
if starExpr, ok := t.Elt.(*ast.StarExpr); ok {
|
||||||
|
if ident, ok := starExpr.X.(*ast.Ident); ok {
|
||||||
|
return ident.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *ast.StarExpr:
|
||||||
|
// *ModelPost -> ModelPost
|
||||||
|
if ident, ok := t.X.(*ast.Ident); ok {
|
||||||
|
return ident.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIndexesFromTag extracts index definitions from Bun tags
|
||||||
|
func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) {
|
||||||
|
bunTag := r.extractBunTag(tag)
|
||||||
|
|
||||||
|
// Parse tag into parts
|
||||||
|
parts := strings.Split(bunTag, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
|
||||||
|
// Check for unique index
|
||||||
|
if part == "unique" {
|
||||||
|
// Auto-generate index name: idx_tablename_columnname
|
||||||
|
indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||||
|
|
||||||
|
if _, exists := table.Indexes[indexName]; !exists {
|
||||||
|
index := &models.Index{
|
||||||
|
Name: indexName,
|
||||||
|
Table: table.Name,
|
||||||
|
Schema: table.Schema,
|
||||||
|
Columns: []string{column.Name},
|
||||||
|
Unique: true,
|
||||||
|
Type: "btree",
|
||||||
|
}
|
||||||
|
table.Indexes[indexName] = index
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(part, "unique:") {
|
||||||
|
// Named unique index: unique:idx_name
|
||||||
|
indexName := strings.TrimPrefix(part, "unique:")
|
||||||
|
|
||||||
|
// Check if index already exists (for composite indexes)
|
||||||
|
if idx, exists := table.Indexes[indexName]; exists {
|
||||||
|
// Add this column to the existing index
|
||||||
|
idx.Columns = append(idx.Columns, column.Name)
|
||||||
|
} else {
|
||||||
|
// Create new index
|
||||||
|
index := &models.Index{
|
||||||
|
Name: indexName,
|
||||||
|
Table: table.Name,
|
||||||
|
Schema: table.Schema,
|
||||||
|
Columns: []string{column.Name},
|
||||||
|
Unique: true,
|
||||||
|
Type: "btree",
|
||||||
|
}
|
||||||
|
table.Indexes[indexName] = index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// extractTableNameFromTag extracts table and schema from bun tag
|
// extractTableNameFromTag extracts table and schema from bun tag
|
||||||
func (r *Reader) extractTableNameFromTag(tag string) (tableName string, schemaName string) {
|
func (r *Reader) extractTableNameFromTag(tag string) (tableName string, schemaName string) {
|
||||||
// Extract bun tag value
|
// Extract bun tag value
|
||||||
@@ -422,6 +611,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
|||||||
case "autoincrement":
|
case "autoincrement":
|
||||||
column.AutoIncrement = true
|
column.AutoIncrement = true
|
||||||
case "default":
|
case "default":
|
||||||
|
// Default value from Bun tag (e.g., default:gen_random_uuid())
|
||||||
column.Default = value
|
column.Default = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -431,16 +621,24 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
|||||||
column.Type = r.goTypeToSQL(fieldType)
|
column.Type = r.goTypeToSQL(fieldType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if nullable based on Go type
|
// Determine if nullable based on Go type and bun tags
|
||||||
if r.isNullableType(fieldType) {
|
// In Bun:
|
||||||
|
// - nullzero tag means the field is nullable (can be NULL in DB)
|
||||||
|
// - absence of nullzero means the field is NOT NULL
|
||||||
|
// - primitive types (int64, bool, string) are NOT NULL by default
|
||||||
|
if strings.Contains(bunTag, "nullzero") {
|
||||||
column.NotNull = false
|
column.NotNull = false
|
||||||
} else if !column.IsPrimaryKey && column.Type != "" {
|
} else if r.isNullableGoType(fieldType) {
|
||||||
// If it's not a nullable type and not a primary key, check the tag
|
// SqlString, SqlInt, etc. without nullzero tag means NOT NULL
|
||||||
if !strings.Contains(bunTag, "notnull") {
|
column.NotNull = true
|
||||||
// If notnull is not explicitly set, it might still be nullable
|
} else {
|
||||||
// This is a heuristic - we default to nullable unless specified
|
// Primitive types are NOT NULL by default
|
||||||
return column
|
column.NotNull = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Primary keys are always NOT NULL
|
||||||
|
if column.IsPrimaryKey {
|
||||||
|
column.NotNull = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return column
|
return column
|
||||||
@@ -529,11 +727,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isNullableType checks if a Go type represents a nullable field
|
// isNullableGoType checks if a Go type represents a nullable field type
|
||||||
func (r *Reader) isNullableType(expr ast.Expr) bool {
|
// (this is for types that CAN be nullable, not whether they ARE nullable)
|
||||||
|
func (r *Reader) isNullableGoType(expr ast.Expr) bool {
|
||||||
switch t := expr.(type) {
|
switch t := expr.(type) {
|
||||||
case *ast.StarExpr:
|
case *ast.StarExpr:
|
||||||
// Pointer type is nullable
|
// Pointer type can be nullable
|
||||||
return true
|
return true
|
||||||
case *ast.SelectorExpr:
|
case *ast.SelectorExpr:
|
||||||
// Check for sql_types nullable types
|
// Check for sql_types nullable types
|
||||||
|
|||||||
@@ -187,11 +187,44 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) {
|
|||||||
table.Schema = schemaName
|
table.Schema = schemaName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update columns
|
// Update columns and indexes
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
col.Table = tableName
|
col.Table = tableName
|
||||||
col.Schema = table.Schema
|
col.Schema = table.Schema
|
||||||
}
|
}
|
||||||
|
for _, idx := range table.Indexes {
|
||||||
|
idx.Table = tableName
|
||||||
|
idx.Schema = table.Schema
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third pass: parse relationship fields for constraints
|
||||||
|
// Re-parse the file to get relationship information
|
||||||
|
for _, decl := range node.Decls {
|
||||||
|
genDecl, ok := decl.(*ast.GenDecl)
|
||||||
|
if !ok || genDecl.Tok != token.TYPE {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, spec := range genDecl.Specs {
|
||||||
|
typeSpec, ok := spec.(*ast.TypeSpec)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
structType, ok := typeSpec.Type.(*ast.StructType)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
table, ok := structMap[typeSpec.Name.Name]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse relationship fields
|
||||||
|
r.parseRelationshipConstraints(table, structType, structMap)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,8 +313,14 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip embedded GORM model and relationship fields
|
// Skip embedded GORM model
|
||||||
if r.isGORMModel(field) || r.isRelationship(tag) {
|
if r.isGORMModel(field) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse relationship fields for foreign key constraints
|
||||||
|
if r.isRelationship(tag) {
|
||||||
|
// We'll parse constraints in a second pass after we know all table names
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,6 +349,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
|
|||||||
table.Name = tableName
|
table.Name = tableName
|
||||||
table.Schema = schemaName
|
table.Schema = schemaName
|
||||||
table.Columns[column.Name] = column
|
table.Columns[column.Name] = column
|
||||||
|
|
||||||
|
// Parse indexes from GORM tags
|
||||||
|
r.parseIndexesFromTag(table, column, tag)
|
||||||
|
|
||||||
sequence++
|
sequence++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -345,6 +388,169 @@ func (r *Reader) isRelationship(tag string) bool {
|
|||||||
strings.Contains(gormTag, "many2many:")
|
strings.Contains(gormTag, "many2many:")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseRelationshipConstraints parses relationship fields to extract foreign key constraints
|
||||||
|
func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) {
|
||||||
|
for _, field := range structType.Fields.List {
|
||||||
|
if field.Tag == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tag := field.Tag.Value
|
||||||
|
if !r.isRelationship(tag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
gormTag := r.extractGormTag(tag)
|
||||||
|
parts := r.parseGormTag(gormTag)
|
||||||
|
|
||||||
|
// Get the referenced type name from the field type
|
||||||
|
referencedType := r.getRelationshipType(field.Type)
|
||||||
|
if referencedType == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the referenced table
|
||||||
|
referencedTable, ok := structMap[referencedType]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract foreign key information
|
||||||
|
foreignKey, hasForeignKey := parts["foreignKey"]
|
||||||
|
if !hasForeignKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert field name to column name
|
||||||
|
fkColumn := r.fieldNameToColumnName(foreignKey)
|
||||||
|
|
||||||
|
// Determine constraint behavior
|
||||||
|
onDelete := "NO ACTION"
|
||||||
|
onUpdate := "NO ACTION"
|
||||||
|
if constraintStr, hasConstraint := parts["constraint"]; hasConstraint {
|
||||||
|
// Parse constraint:OnDelete:CASCADE,OnUpdate:CASCADE
|
||||||
|
if strings.Contains(constraintStr, "OnDelete:CASCADE") {
|
||||||
|
onDelete = "CASCADE"
|
||||||
|
} else if strings.Contains(constraintStr, "OnDelete:SET NULL") {
|
||||||
|
onDelete = "SET NULL"
|
||||||
|
}
|
||||||
|
if strings.Contains(constraintStr, "OnUpdate:CASCADE") {
|
||||||
|
onUpdate = "CASCADE"
|
||||||
|
} else if strings.Contains(constraintStr, "OnUpdate:SET NULL") {
|
||||||
|
onUpdate = "SET NULL"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The FK is on the referenced table, pointing back to this table
|
||||||
|
// For has-many, the FK is on the "many" side
|
||||||
|
constraint := &models.Constraint{
|
||||||
|
Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name),
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Table: referencedTable.Name,
|
||||||
|
Schema: referencedTable.Schema,
|
||||||
|
Columns: []string{fkColumn},
|
||||||
|
ReferencedTable: table.Name,
|
||||||
|
ReferencedSchema: table.Schema,
|
||||||
|
ReferencedColumns: []string{"id"}, // Typically references the primary key
|
||||||
|
OnDelete: onDelete,
|
||||||
|
OnUpdate: onUpdate,
|
||||||
|
}
|
||||||
|
|
||||||
|
referencedTable.Constraints[constraint.Name] = constraint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRelationshipType extracts the type name from a relationship field
|
||||||
|
func (r *Reader) getRelationshipType(expr ast.Expr) string {
|
||||||
|
switch t := expr.(type) {
|
||||||
|
case *ast.ArrayType:
|
||||||
|
// []*ModelPost -> ModelPost
|
||||||
|
if starExpr, ok := t.Elt.(*ast.StarExpr); ok {
|
||||||
|
if ident, ok := starExpr.X.(*ast.Ident); ok {
|
||||||
|
return ident.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *ast.StarExpr:
|
||||||
|
// *ModelPost -> ModelPost
|
||||||
|
if ident, ok := t.X.(*ast.Ident); ok {
|
||||||
|
return ident.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIndexesFromTag extracts index definitions from GORM tags
|
||||||
|
func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) {
|
||||||
|
gormTag := r.extractGormTag(tag)
|
||||||
|
parts := r.parseGormTag(gormTag)
|
||||||
|
|
||||||
|
// Check for regular index: index:idx_name or index
|
||||||
|
if indexName, ok := parts["index"]; ok {
|
||||||
|
if indexName == "" {
|
||||||
|
// Auto-generated index name
|
||||||
|
indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if index already exists
|
||||||
|
if _, exists := table.Indexes[indexName]; !exists {
|
||||||
|
index := &models.Index{
|
||||||
|
Name: indexName,
|
||||||
|
Table: table.Name,
|
||||||
|
Schema: table.Schema,
|
||||||
|
Columns: []string{column.Name},
|
||||||
|
Unique: false,
|
||||||
|
Type: "btree",
|
||||||
|
}
|
||||||
|
table.Indexes[indexName] = index
|
||||||
|
} else {
|
||||||
|
// Add column to existing index
|
||||||
|
table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for unique index: uniqueIndex:idx_name or uniqueIndex
|
||||||
|
if indexName, ok := parts["uniqueIndex"]; ok {
|
||||||
|
if indexName == "" {
|
||||||
|
// Auto-generated index name
|
||||||
|
indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if index already exists
|
||||||
|
if _, exists := table.Indexes[indexName]; !exists {
|
||||||
|
index := &models.Index{
|
||||||
|
Name: indexName,
|
||||||
|
Table: table.Name,
|
||||||
|
Schema: table.Schema,
|
||||||
|
Columns: []string{column.Name},
|
||||||
|
Unique: true,
|
||||||
|
Type: "btree",
|
||||||
|
}
|
||||||
|
table.Indexes[indexName] = index
|
||||||
|
} else {
|
||||||
|
// Add column to existing index
|
||||||
|
table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for simple unique flag (creates a unique index for this column)
|
||||||
|
if _, ok := parts["unique"]; ok {
|
||||||
|
// Auto-generated index name for unique constraint
|
||||||
|
indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||||
|
|
||||||
|
if _, exists := table.Indexes[indexName]; !exists {
|
||||||
|
index := &models.Index{
|
||||||
|
Name: indexName,
|
||||||
|
Table: table.Name,
|
||||||
|
Schema: table.Schema,
|
||||||
|
Columns: []string{column.Name},
|
||||||
|
Unique: true,
|
||||||
|
Type: "btree",
|
||||||
|
}
|
||||||
|
table.Indexes[indexName] = index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// extractTableFromGormTag extracts table and schema from gorm tag
|
// extractTableFromGormTag extracts table and schema from gorm tag
|
||||||
func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaName string) {
|
func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaName string) {
|
||||||
// This is typically set via TableName() method, not in tags
|
// This is typically set via TableName() method, not in tags
|
||||||
@@ -406,6 +612,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
|||||||
column.AutoIncrement = true
|
column.AutoIncrement = true
|
||||||
}
|
}
|
||||||
if def, ok := parts["default"]; ok {
|
if def, ok := parts["default"]; ok {
|
||||||
|
// Default value from GORM tag (e.g., default:gen_random_uuid())
|
||||||
column.Default = def
|
column.Default = def
|
||||||
}
|
}
|
||||||
if size, ok := parts["size"]; ok {
|
if size, ok := parts["size"]; ok {
|
||||||
@@ -419,9 +626,27 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
|||||||
column.Type = r.goTypeToSQL(fieldType)
|
column.Type = r.goTypeToSQL(fieldType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if nullable based on Go type
|
// Determine if nullable based on GORM tags and Go type
|
||||||
if r.isNullableType(fieldType) {
|
// In GORM:
|
||||||
column.NotNull = false
|
// - explicit "not null" tag means NOT NULL
|
||||||
|
// - absence of "not null" tag with sql_types means nullable
|
||||||
|
// - primitive types (string, int64, bool) default to NOT NULL unless explicitly nullable
|
||||||
|
if _, hasNotNull := parts["not null"]; hasNotNull {
|
||||||
|
column.NotNull = true
|
||||||
|
} else {
|
||||||
|
// If no explicit "not null" tag, check the Go type
|
||||||
|
if r.isNullableGoType(fieldType) {
|
||||||
|
// sql_types.SqlString, etc. are nullable by default
|
||||||
|
column.NotNull = false
|
||||||
|
} else {
|
||||||
|
// Primitive types default to NOT NULL
|
||||||
|
column.NotNull = false // Default to nullable unless explicitly set
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Primary keys are always NOT NULL
|
||||||
|
if column.IsPrimaryKey {
|
||||||
|
column.NotNull = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return column
|
return column
|
||||||
@@ -557,11 +782,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isNullableType checks if a Go type represents a nullable field
|
// isNullableGoType checks if a Go type represents a nullable field type
|
||||||
func (r *Reader) isNullableType(expr ast.Expr) bool {
|
// (this is for types that CAN be nullable, not whether they ARE nullable)
|
||||||
|
func (r *Reader) isNullableGoType(expr ast.Expr) bool {
|
||||||
switch t := expr.(type) {
|
switch t := expr.(type) {
|
||||||
case *ast.StarExpr:
|
case *ast.StarExpr:
|
||||||
// Pointer type is nullable
|
// Pointer type can be nullable
|
||||||
return true
|
return true
|
||||||
case *ast.SelectorExpr:
|
case *ast.SelectorExpr:
|
||||||
// Check for sql_types nullable types
|
// Check for sql_types nullable types
|
||||||
|
|||||||
@@ -196,15 +196,31 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
parts = append(parts, "nullzero")
|
parts = append(parts, "nullzero")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for unique constraint
|
// Check for indexes (unique indexes should be added to tag)
|
||||||
if table != nil {
|
if table != nil {
|
||||||
for _, constraint := range table.Constraints {
|
for _, index := range table.Indexes {
|
||||||
if constraint.Type == models.UniqueConstraint {
|
if !index.Unique {
|
||||||
for _, col := range constraint.Columns {
|
continue
|
||||||
if col == column.Name {
|
}
|
||||||
parts = append(parts, "unique")
|
// Check if this column is in the index
|
||||||
break
|
for _, col := range index.Columns {
|
||||||
|
if col == column.Name {
|
||||||
|
// Add unique tag with index name for composite indexes
|
||||||
|
// or simple unique for single-column indexes
|
||||||
|
if len(index.Columns) > 1 {
|
||||||
|
// Composite index - use index name
|
||||||
|
parts = append(parts, fmt.Sprintf("unique:%s", index.Name))
|
||||||
|
} else {
|
||||||
|
// Single column - use index name if it's not auto-generated
|
||||||
|
// Auto-generated names typically follow pattern: idx_tablename_columnname
|
||||||
|
expectedAutoName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||||
|
if index.Name == expectedAutoName {
|
||||||
|
parts = append(parts, "unique")
|
||||||
|
} else {
|
||||||
|
parts = append(parts, fmt.Sprintf("unique:%s", index.Name))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,18 +14,135 @@ import (
|
|||||||
bunwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
|
bunwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
|
||||||
gormwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
gormwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||||
yamlwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
yamlwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ComparisonResults holds the results of database comparison
|
||||||
|
type ComparisonResults struct {
|
||||||
|
Schemas int
|
||||||
|
Tables int
|
||||||
|
Columns int
|
||||||
|
OriginalIndexes int
|
||||||
|
RoundtripIndexes int
|
||||||
|
OriginalConstraints int
|
||||||
|
RoundtripConstraints int
|
||||||
|
}
|
||||||
|
|
||||||
|
// countDatabaseStats counts tables, indexes, and constraints in a database
|
||||||
|
func countDatabaseStats(db *models.Database) (tables, indexes, constraints int) {
|
||||||
|
for _, schema := range db.Schemas {
|
||||||
|
tables += len(schema.Tables)
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
indexes += len(table.Indexes)
|
||||||
|
constraints += len(table.Constraints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareDatabases performs comprehensive comparison between two databases
|
||||||
|
func compareDatabases(t *testing.T, db1, db2 *models.Database, ormName string) ComparisonResults {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
results := ComparisonResults{}
|
||||||
|
|
||||||
|
// Compare high-level structure
|
||||||
|
t.Log(" Comparing high-level structure...")
|
||||||
|
require.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
|
||||||
|
results.Schemas = len(db1.Schemas)
|
||||||
|
|
||||||
|
// Count totals
|
||||||
|
tables1, indexes1, constraints1 := countDatabaseStats(db1)
|
||||||
|
_, indexes2, constraints2 := countDatabaseStats(db2)
|
||||||
|
|
||||||
|
results.OriginalIndexes = indexes1
|
||||||
|
results.RoundtripIndexes = indexes2
|
||||||
|
results.OriginalConstraints = constraints1
|
||||||
|
results.RoundtripConstraints = constraints2
|
||||||
|
|
||||||
|
// Compare schemas and tables
|
||||||
|
for i, schema1 := range db1.Schemas {
|
||||||
|
if i >= len(db2.Schemas) {
|
||||||
|
t.Errorf("Schema index %d out of bounds in second database", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
schema2 := db2.Schemas[i]
|
||||||
|
|
||||||
|
require.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
|
||||||
|
require.Equal(t, len(schema1.Tables), len(schema2.Tables),
|
||||||
|
"Table count in schema '%s' should match", schema1.Name)
|
||||||
|
|
||||||
|
// Compare tables
|
||||||
|
for j, table1 := range schema1.Tables {
|
||||||
|
if j >= len(schema2.Tables) {
|
||||||
|
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
table2 := schema2.Tables[j]
|
||||||
|
|
||||||
|
require.Equal(t, table1.Name, table2.Name,
|
||||||
|
"Table names should match in schema '%s'", schema1.Name)
|
||||||
|
|
||||||
|
// Compare column count
|
||||||
|
require.Equal(t, len(table1.Columns), len(table2.Columns),
|
||||||
|
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
|
||||||
|
|
||||||
|
results.Columns += len(table1.Columns)
|
||||||
|
|
||||||
|
// Compare each column
|
||||||
|
for colName, col1 := range table1.Columns {
|
||||||
|
col2, ok := table2.Columns[colName]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
|
||||||
|
colName, schema1.Name, table1.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare key column properties
|
||||||
|
require.Equal(t, col1.Name, col2.Name,
|
||||||
|
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||||
|
require.Equal(t, col1.Type, col2.Type,
|
||||||
|
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||||
|
require.Equal(t, col1.Length, col2.Length,
|
||||||
|
"Column length mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||||
|
require.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
|
||||||
|
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||||
|
require.Equal(t, col1.NotNull, col2.NotNull,
|
||||||
|
"NotNull mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||||
|
|
||||||
|
// Log defaults that don't match (these can vary in representation)
|
||||||
|
if col1.Default != col2.Default {
|
||||||
|
t.Logf(" ℹ Default value differs for '%s.%s.%s': '%v' vs '%v'",
|
||||||
|
schema1.Name, table1.Name, colName, col1.Default, col2.Default)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log index and constraint differences (ORM readers may not capture all of these)
|
||||||
|
if len(table1.Indexes) != len(table2.Indexes) {
|
||||||
|
t.Logf(" ℹ Index count differs for table '%s.%s': %d vs %d",
|
||||||
|
schema1.Name, table1.Name, len(table1.Indexes), len(table2.Indexes))
|
||||||
|
}
|
||||||
|
if len(table1.Constraints) != len(table2.Constraints) {
|
||||||
|
t.Logf(" ℹ Constraint count differs for table '%s.%s': %d vs %d",
|
||||||
|
schema1.Name, table1.Name, len(table1.Constraints), len(table2.Constraints))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results.Tables = tables1
|
||||||
|
t.Logf(" ✓ Validated %d schemas, %d tables, %d columns", results.Schemas, results.Tables, results.Columns)
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
// TestYAMLToBunRoundTrip tests YAML → Bun Go → YAML roundtrip
|
// TestYAMLToBunRoundTrip tests YAML → Bun Go → YAML roundtrip
|
||||||
func TestYAMLToBunRoundTrip(t *testing.T) {
|
func TestYAMLToBunRoundTrip(t *testing.T) {
|
||||||
testDir := t.TempDir()
|
testDir := t.TempDir()
|
||||||
|
|
||||||
// Step 1: Read YAML file
|
// Step 1: Read YAML file
|
||||||
t.Log("Step 1: Reading YAML file...")
|
t.Log("Step 1: Reading YAML file...")
|
||||||
yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml")
|
yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml")
|
||||||
yamlReaderOpts := &readers.ReaderOptions{
|
yamlReaderOpts := &readers.ReaderOptions{
|
||||||
FilePath: yamlPath,
|
FilePath: yamlPath,
|
||||||
}
|
}
|
||||||
@@ -36,6 +153,10 @@ func TestYAMLToBunRoundTrip(t *testing.T) {
|
|||||||
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
|
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
|
||||||
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
|
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
|
||||||
|
|
||||||
|
// Log initial stats
|
||||||
|
totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML)
|
||||||
|
t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints)
|
||||||
|
|
||||||
// Step 2: Write to Bun Go code
|
// Step 2: Write to Bun Go code
|
||||||
t.Log("Step 2: Writing to Bun Go code...")
|
t.Log("Step 2: Writing to Bun Go code...")
|
||||||
bunGoPath := filepath.Join(testDir, "models_bun.go")
|
bunGoPath := filepath.Join(testDir, "models_bun.go")
|
||||||
@@ -103,67 +224,24 @@ func TestYAMLToBunRoundTrip(t *testing.T) {
|
|||||||
err = yaml.Unmarshal(yaml2Data, &db2)
|
err = yaml.Unmarshal(yaml2Data, &db2)
|
||||||
require.NoError(t, err, "Failed to parse second YAML")
|
require.NoError(t, err, "Failed to parse second YAML")
|
||||||
|
|
||||||
// Compare high-level structure
|
// Comprehensive comparison
|
||||||
t.Log(" Comparing high-level structure...")
|
compareResults := compareDatabases(t, &db1, &db2, "Bun")
|
||||||
assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
|
|
||||||
|
|
||||||
// Compare schemas and tables
|
|
||||||
for i, schema1 := range db1.Schemas {
|
|
||||||
if i >= len(db2.Schemas) {
|
|
||||||
t.Errorf("Schema index %d out of bounds in second database", i)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
schema2 := db2.Schemas[i]
|
|
||||||
|
|
||||||
assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
|
|
||||||
assert.Equal(t, len(schema1.Tables), len(schema2.Tables),
|
|
||||||
"Table count in schema '%s' should match", schema1.Name)
|
|
||||||
|
|
||||||
// Compare tables
|
|
||||||
for j, table1 := range schema1.Tables {
|
|
||||||
if j >= len(schema2.Tables) {
|
|
||||||
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
table2 := schema2.Tables[j]
|
|
||||||
|
|
||||||
assert.Equal(t, table1.Name, table2.Name,
|
|
||||||
"Table names should match in schema '%s'", schema1.Name)
|
|
||||||
|
|
||||||
// Compare column count
|
|
||||||
assert.Equal(t, len(table1.Columns), len(table2.Columns),
|
|
||||||
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
|
|
||||||
|
|
||||||
// Compare each column
|
|
||||||
for colName, col1 := range table1.Columns {
|
|
||||||
col2, ok := table2.Columns[colName]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
|
|
||||||
colName, schema1.Name, table1.Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compare key column properties
|
|
||||||
assert.Equal(t, col1.Name, col2.Name,
|
|
||||||
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
|
||||||
assert.Equal(t, col1.Type, col2.Type,
|
|
||||||
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
|
||||||
assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
|
|
||||||
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Summary
|
// Summary
|
||||||
t.Log("Summary:")
|
t.Log("Summary:")
|
||||||
t.Logf(" ✓ Round-trip completed: YAML → Bun → YAML")
|
t.Logf(" ✓ Round-trip completed: YAML → Bun → YAML")
|
||||||
t.Logf(" ✓ Schemas match: %d", len(db1.Schemas))
|
t.Logf(" ✓ Schemas: %d", compareResults.Schemas)
|
||||||
|
t.Logf(" ✓ Tables: %d", compareResults.Tables)
|
||||||
|
t.Logf(" ✓ Columns: %d", compareResults.Columns)
|
||||||
|
t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes)
|
||||||
|
t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints)
|
||||||
|
|
||||||
totalTables := 0
|
if compareResults.OriginalIndexes != compareResults.RoundtripIndexes {
|
||||||
for _, schema := range db1.Schemas {
|
t.Logf(" ⚠ Note: Index counts differ - Bun reader may not parse all index information from Go code")
|
||||||
totalTables += len(schema.Tables)
|
}
|
||||||
|
if compareResults.OriginalConstraints != compareResults.RoundtripConstraints {
|
||||||
|
t.Logf(" ⚠ Note: Constraint counts differ - Bun reader may not parse all constraint information from Go code")
|
||||||
}
|
}
|
||||||
t.Logf(" ✓ Total tables: %d", totalTables)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestYAMLToGORMRoundTrip tests YAML → GORM Go → YAML roundtrip
|
// TestYAMLToGORMRoundTrip tests YAML → GORM Go → YAML roundtrip
|
||||||
@@ -172,7 +250,7 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
|
|||||||
|
|
||||||
// Step 1: Read YAML file
|
// Step 1: Read YAML file
|
||||||
t.Log("Step 1: Reading YAML file...")
|
t.Log("Step 1: Reading YAML file...")
|
||||||
yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml")
|
yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml")
|
||||||
yamlReaderOpts := &readers.ReaderOptions{
|
yamlReaderOpts := &readers.ReaderOptions{
|
||||||
FilePath: yamlPath,
|
FilePath: yamlPath,
|
||||||
}
|
}
|
||||||
@@ -183,6 +261,10 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
|
|||||||
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
|
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
|
||||||
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
|
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
|
||||||
|
|
||||||
|
// Log initial stats
|
||||||
|
totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML)
|
||||||
|
t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints)
|
||||||
|
|
||||||
// Step 2: Write to GORM Go code
|
// Step 2: Write to GORM Go code
|
||||||
t.Log("Step 2: Writing to GORM Go code...")
|
t.Log("Step 2: Writing to GORM Go code...")
|
||||||
gormGoPath := filepath.Join(testDir, "models_gorm.go")
|
gormGoPath := filepath.Join(testDir, "models_gorm.go")
|
||||||
@@ -250,65 +332,22 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
|
|||||||
err = yaml.Unmarshal(yaml2Data, &db2)
|
err = yaml.Unmarshal(yaml2Data, &db2)
|
||||||
require.NoError(t, err, "Failed to parse second YAML")
|
require.NoError(t, err, "Failed to parse second YAML")
|
||||||
|
|
||||||
// Compare high-level structure
|
// Comprehensive comparison
|
||||||
t.Log(" Comparing high-level structure...")
|
compareResults := compareDatabases(t, &db1, &db2, "GORM")
|
||||||
assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
|
|
||||||
|
|
||||||
// Compare schemas and tables
|
|
||||||
for i, schema1 := range db1.Schemas {
|
|
||||||
if i >= len(db2.Schemas) {
|
|
||||||
t.Errorf("Schema index %d out of bounds in second database", i)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
schema2 := db2.Schemas[i]
|
|
||||||
|
|
||||||
assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
|
|
||||||
assert.Equal(t, len(schema1.Tables), len(schema2.Tables),
|
|
||||||
"Table count in schema '%s' should match", schema1.Name)
|
|
||||||
|
|
||||||
// Compare tables
|
|
||||||
for j, table1 := range schema1.Tables {
|
|
||||||
if j >= len(schema2.Tables) {
|
|
||||||
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
table2 := schema2.Tables[j]
|
|
||||||
|
|
||||||
assert.Equal(t, table1.Name, table2.Name,
|
|
||||||
"Table names should match in schema '%s'", schema1.Name)
|
|
||||||
|
|
||||||
// Compare column count
|
|
||||||
assert.Equal(t, len(table1.Columns), len(table2.Columns),
|
|
||||||
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
|
|
||||||
|
|
||||||
// Compare each column
|
|
||||||
for colName, col1 := range table1.Columns {
|
|
||||||
col2, ok := table2.Columns[colName]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
|
|
||||||
colName, schema1.Name, table1.Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compare key column properties
|
|
||||||
assert.Equal(t, col1.Name, col2.Name,
|
|
||||||
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
|
||||||
assert.Equal(t, col1.Type, col2.Type,
|
|
||||||
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
|
||||||
assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
|
|
||||||
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Summary
|
// Summary
|
||||||
t.Log("Summary:")
|
t.Log("Summary:")
|
||||||
t.Logf(" ✓ Round-trip completed: YAML → GORM → YAML")
|
t.Logf(" ✓ Round-trip completed: YAML → GORM → YAML")
|
||||||
t.Logf(" ✓ Schemas match: %d", len(db1.Schemas))
|
t.Logf(" ✓ Schemas: %d", compareResults.Schemas)
|
||||||
|
t.Logf(" ✓ Tables: %d", compareResults.Tables)
|
||||||
|
t.Logf(" ✓ Columns: %d", compareResults.Columns)
|
||||||
|
t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes)
|
||||||
|
t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints)
|
||||||
|
|
||||||
totalTables := 0
|
if compareResults.OriginalIndexes != compareResults.RoundtripIndexes {
|
||||||
for _, schema := range db1.Schemas {
|
t.Logf(" ⚠ Note: Index counts differ - GORM reader may not parse all index information from Go code")
|
||||||
totalTables += len(schema.Tables)
|
}
|
||||||
|
if compareResults.OriginalConstraints != compareResults.RoundtripConstraints {
|
||||||
|
t.Logf(" ⚠ Note: Constraint counts differ - GORM reader may not parse all constraint information from Go code")
|
||||||
}
|
}
|
||||||
t.Logf(" ✓ Total tables: %d", totalTables)
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user