Fixed bug/gorm indexes
This commit is contained in:
@@ -187,11 +187,43 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) {
|
||||
table.Schema = schemaName
|
||||
}
|
||||
|
||||
// Update columns
|
||||
// Update columns and indexes
|
||||
for _, col := range table.Columns {
|
||||
col.Table = tableName
|
||||
col.Schema = table.Schema
|
||||
}
|
||||
for _, idx := range table.Indexes {
|
||||
idx.Table = tableName
|
||||
idx.Schema = table.Schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Third pass: parse relationship fields for constraints
|
||||
for _, decl := range node.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
typeSpec, ok := spec.(*ast.TypeSpec)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
structType, ok := typeSpec.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
table, ok := structMap[typeSpec.Name.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse relationship fields
|
||||
r.parseRelationshipConstraints(table, structType, structMap)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,6 +346,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
|
||||
column.Table = tableName
|
||||
column.Schema = schemaName
|
||||
table.Columns[column.Name] = column
|
||||
|
||||
// Parse indexes from bun tags
|
||||
r.parseIndexesFromTag(table, column, tag)
|
||||
|
||||
sequence++
|
||||
}
|
||||
}
|
||||
@@ -346,6 +382,159 @@ func (r *Reader) isRelationship(tag string) bool {
|
||||
return strings.Contains(tag, "bun:\"rel:") || strings.Contains(tag, ",rel:")
|
||||
}
|
||||
|
||||
// parseRelationshipConstraints parses relationship fields to extract foreign key constraints
|
||||
func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) {
|
||||
for _, field := range structType.Fields.List {
|
||||
if field.Tag == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := field.Tag.Value
|
||||
if !r.isRelationship(tag) {
|
||||
continue
|
||||
}
|
||||
|
||||
bunTag := r.extractBunTag(tag)
|
||||
|
||||
// Get the referenced type name from the field type
|
||||
referencedType := r.getRelationshipType(field.Type)
|
||||
if referencedType == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the referenced table
|
||||
referencedTable, ok := structMap[referencedType]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the join information: join:user_id=id
|
||||
// This means: referencedTable.user_id = thisTable.id
|
||||
joinInfo := r.parseJoinInfo(bunTag)
|
||||
if joinInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// The FK is on the referenced table
|
||||
constraint := &models.Constraint{
|
||||
Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name),
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Table: referencedTable.Name,
|
||||
Schema: referencedTable.Schema,
|
||||
Columns: []string{joinInfo.ForeignKey},
|
||||
ReferencedTable: table.Name,
|
||||
ReferencedSchema: table.Schema,
|
||||
ReferencedColumns: []string{joinInfo.ReferencedKey},
|
||||
OnDelete: "NO ACTION", // Bun doesn't specify this in tags
|
||||
OnUpdate: "NO ACTION",
|
||||
}
|
||||
|
||||
referencedTable.Constraints[constraint.Name] = constraint
|
||||
}
|
||||
}
|
||||
|
||||
// JoinInfo holds parsed join information
|
||||
type JoinInfo struct {
|
||||
ForeignKey string // Column in the related table
|
||||
ReferencedKey string // Column in the current table
|
||||
}
|
||||
|
||||
// parseJoinInfo parses join information from bun tag
|
||||
// Example: join:user_id=id means foreign_key=referenced_key
|
||||
func (r *Reader) parseJoinInfo(bunTag string) *JoinInfo {
|
||||
// Look for join: in the tag
|
||||
if !strings.Contains(bunTag, "join:") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract join clause
|
||||
parts := strings.Split(bunTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "join:") {
|
||||
joinStr := strings.TrimPrefix(part, "join:")
|
||||
// Parse user_id=id
|
||||
joinParts := strings.SplitN(joinStr, "=", 2)
|
||||
if len(joinParts) == 2 {
|
||||
return &JoinInfo{
|
||||
ForeignKey: joinParts[0],
|
||||
ReferencedKey: joinParts[1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRelationshipType extracts the type name from a relationship field
|
||||
func (r *Reader) getRelationshipType(expr ast.Expr) string {
|
||||
switch t := expr.(type) {
|
||||
case *ast.ArrayType:
|
||||
// []*ModelPost -> ModelPost
|
||||
if starExpr, ok := t.Elt.(*ast.StarExpr); ok {
|
||||
if ident, ok := starExpr.X.(*ast.Ident); ok {
|
||||
return ident.Name
|
||||
}
|
||||
}
|
||||
case *ast.StarExpr:
|
||||
// *ModelPost -> ModelPost
|
||||
if ident, ok := t.X.(*ast.Ident); ok {
|
||||
return ident.Name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseIndexesFromTag extracts index definitions from Bun tags
|
||||
func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) {
|
||||
bunTag := r.extractBunTag(tag)
|
||||
|
||||
// Parse tag into parts
|
||||
parts := strings.Split(bunTag, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
|
||||
// Check for unique index
|
||||
if part == "unique" {
|
||||
// Auto-generate index name: idx_tablename_columnname
|
||||
indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||
|
||||
if _, exists := table.Indexes[indexName]; !exists {
|
||||
index := &models.Index{
|
||||
Name: indexName,
|
||||
Table: table.Name,
|
||||
Schema: table.Schema,
|
||||
Columns: []string{column.Name},
|
||||
Unique: true,
|
||||
Type: "btree",
|
||||
}
|
||||
table.Indexes[indexName] = index
|
||||
}
|
||||
} else if strings.HasPrefix(part, "unique:") {
|
||||
// Named unique index: unique:idx_name
|
||||
indexName := strings.TrimPrefix(part, "unique:")
|
||||
|
||||
// Check if index already exists (for composite indexes)
|
||||
if idx, exists := table.Indexes[indexName]; exists {
|
||||
// Add this column to the existing index
|
||||
idx.Columns = append(idx.Columns, column.Name)
|
||||
} else {
|
||||
// Create new index
|
||||
index := &models.Index{
|
||||
Name: indexName,
|
||||
Table: table.Name,
|
||||
Schema: table.Schema,
|
||||
Columns: []string{column.Name},
|
||||
Unique: true,
|
||||
Type: "btree",
|
||||
}
|
||||
table.Indexes[indexName] = index
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractTableNameFromTag extracts table and schema from bun tag
|
||||
func (r *Reader) extractTableNameFromTag(tag string) (tableName string, schemaName string) {
|
||||
// Extract bun tag value
|
||||
@@ -422,6 +611,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
case "autoincrement":
|
||||
column.AutoIncrement = true
|
||||
case "default":
|
||||
// Default value from Bun tag (e.g., default:gen_random_uuid())
|
||||
column.Default = value
|
||||
}
|
||||
}
|
||||
@@ -431,16 +621,24 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
column.Type = r.goTypeToSQL(fieldType)
|
||||
}
|
||||
|
||||
// Determine if nullable based on Go type
|
||||
if r.isNullableType(fieldType) {
|
||||
// Determine if nullable based on Go type and bun tags
|
||||
// In Bun:
|
||||
// - nullzero tag means the field is nullable (can be NULL in DB)
|
||||
// - absence of nullzero means the field is NOT NULL
|
||||
// - primitive types (int64, bool, string) are NOT NULL by default
|
||||
if strings.Contains(bunTag, "nullzero") {
|
||||
column.NotNull = false
|
||||
} else if !column.IsPrimaryKey && column.Type != "" {
|
||||
// If it's not a nullable type and not a primary key, check the tag
|
||||
if !strings.Contains(bunTag, "notnull") {
|
||||
// If notnull is not explicitly set, it might still be nullable
|
||||
// This is a heuristic - we default to nullable unless specified
|
||||
return column
|
||||
}
|
||||
} else if r.isNullableGoType(fieldType) {
|
||||
// SqlString, SqlInt, etc. without nullzero tag means NOT NULL
|
||||
column.NotNull = true
|
||||
} else {
|
||||
// Primitive types are NOT NULL by default
|
||||
column.NotNull = true
|
||||
}
|
||||
|
||||
// Primary keys are always NOT NULL
|
||||
if column.IsPrimaryKey {
|
||||
column.NotNull = true
|
||||
}
|
||||
|
||||
return column
|
||||
@@ -529,11 +727,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// isNullableType checks if a Go type represents a nullable field
|
||||
func (r *Reader) isNullableType(expr ast.Expr) bool {
|
||||
// isNullableGoType checks if a Go type represents a nullable field type
|
||||
// (this is for types that CAN be nullable, not whether they ARE nullable)
|
||||
func (r *Reader) isNullableGoType(expr ast.Expr) bool {
|
||||
switch t := expr.(type) {
|
||||
case *ast.StarExpr:
|
||||
// Pointer type is nullable
|
||||
// Pointer type can be nullable
|
||||
return true
|
||||
case *ast.SelectorExpr:
|
||||
// Check for sql_types nullable types
|
||||
|
||||
@@ -187,11 +187,44 @@ func (r *Reader) parseFile(filename string) ([]*models.Table, error) {
|
||||
table.Schema = schemaName
|
||||
}
|
||||
|
||||
// Update columns
|
||||
// Update columns and indexes
|
||||
for _, col := range table.Columns {
|
||||
col.Table = tableName
|
||||
col.Schema = table.Schema
|
||||
}
|
||||
for _, idx := range table.Indexes {
|
||||
idx.Table = tableName
|
||||
idx.Schema = table.Schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Third pass: parse relationship fields for constraints
|
||||
// Re-parse the file to get relationship information
|
||||
for _, decl := range node.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok || genDecl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, spec := range genDecl.Specs {
|
||||
typeSpec, ok := spec.(*ast.TypeSpec)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
structType, ok := typeSpec.Type.(*ast.StructType)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
table, ok := structMap[typeSpec.Name.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse relationship fields
|
||||
r.parseRelationshipConstraints(table, structType, structMap)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,8 +313,14 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip embedded GORM model and relationship fields
|
||||
if r.isGORMModel(field) || r.isRelationship(tag) {
|
||||
// Skip embedded GORM model
|
||||
if r.isGORMModel(field) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse relationship fields for foreign key constraints
|
||||
if r.isRelationship(tag) {
|
||||
// We'll parse constraints in a second pass after we know all table names
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -310,6 +349,10 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod
|
||||
table.Name = tableName
|
||||
table.Schema = schemaName
|
||||
table.Columns[column.Name] = column
|
||||
|
||||
// Parse indexes from GORM tags
|
||||
r.parseIndexesFromTag(table, column, tag)
|
||||
|
||||
sequence++
|
||||
}
|
||||
}
|
||||
@@ -345,6 +388,169 @@ func (r *Reader) isRelationship(tag string) bool {
|
||||
strings.Contains(gormTag, "many2many:")
|
||||
}
|
||||
|
||||
// parseRelationshipConstraints parses relationship fields to extract foreign key constraints
|
||||
func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *ast.StructType, structMap map[string]*models.Table) {
|
||||
for _, field := range structType.Fields.List {
|
||||
if field.Tag == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := field.Tag.Value
|
||||
if !r.isRelationship(tag) {
|
||||
continue
|
||||
}
|
||||
|
||||
gormTag := r.extractGormTag(tag)
|
||||
parts := r.parseGormTag(gormTag)
|
||||
|
||||
// Get the referenced type name from the field type
|
||||
referencedType := r.getRelationshipType(field.Type)
|
||||
if referencedType == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the referenced table
|
||||
referencedTable, ok := structMap[referencedType]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract foreign key information
|
||||
foreignKey, hasForeignKey := parts["foreignKey"]
|
||||
if !hasForeignKey {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert field name to column name
|
||||
fkColumn := r.fieldNameToColumnName(foreignKey)
|
||||
|
||||
// Determine constraint behavior
|
||||
onDelete := "NO ACTION"
|
||||
onUpdate := "NO ACTION"
|
||||
if constraintStr, hasConstraint := parts["constraint"]; hasConstraint {
|
||||
// Parse constraint:OnDelete:CASCADE,OnUpdate:CASCADE
|
||||
if strings.Contains(constraintStr, "OnDelete:CASCADE") {
|
||||
onDelete = "CASCADE"
|
||||
} else if strings.Contains(constraintStr, "OnDelete:SET NULL") {
|
||||
onDelete = "SET NULL"
|
||||
}
|
||||
if strings.Contains(constraintStr, "OnUpdate:CASCADE") {
|
||||
onUpdate = "CASCADE"
|
||||
} else if strings.Contains(constraintStr, "OnUpdate:SET NULL") {
|
||||
onUpdate = "SET NULL"
|
||||
}
|
||||
}
|
||||
|
||||
// The FK is on the referenced table, pointing back to this table
|
||||
// For has-many, the FK is on the "many" side
|
||||
constraint := &models.Constraint{
|
||||
Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name),
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Table: referencedTable.Name,
|
||||
Schema: referencedTable.Schema,
|
||||
Columns: []string{fkColumn},
|
||||
ReferencedTable: table.Name,
|
||||
ReferencedSchema: table.Schema,
|
||||
ReferencedColumns: []string{"id"}, // Typically references the primary key
|
||||
OnDelete: onDelete,
|
||||
OnUpdate: onUpdate,
|
||||
}
|
||||
|
||||
referencedTable.Constraints[constraint.Name] = constraint
|
||||
}
|
||||
}
|
||||
|
||||
// getRelationshipType extracts the type name from a relationship field
|
||||
func (r *Reader) getRelationshipType(expr ast.Expr) string {
|
||||
switch t := expr.(type) {
|
||||
case *ast.ArrayType:
|
||||
// []*ModelPost -> ModelPost
|
||||
if starExpr, ok := t.Elt.(*ast.StarExpr); ok {
|
||||
if ident, ok := starExpr.X.(*ast.Ident); ok {
|
||||
return ident.Name
|
||||
}
|
||||
}
|
||||
case *ast.StarExpr:
|
||||
// *ModelPost -> ModelPost
|
||||
if ident, ok := t.X.(*ast.Ident); ok {
|
||||
return ident.Name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseIndexesFromTag extracts index definitions from GORM tags
|
||||
func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, tag string) {
|
||||
gormTag := r.extractGormTag(tag)
|
||||
parts := r.parseGormTag(gormTag)
|
||||
|
||||
// Check for regular index: index:idx_name or index
|
||||
if indexName, ok := parts["index"]; ok {
|
||||
if indexName == "" {
|
||||
// Auto-generated index name
|
||||
indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||
}
|
||||
|
||||
// Check if index already exists
|
||||
if _, exists := table.Indexes[indexName]; !exists {
|
||||
index := &models.Index{
|
||||
Name: indexName,
|
||||
Table: table.Name,
|
||||
Schema: table.Schema,
|
||||
Columns: []string{column.Name},
|
||||
Unique: false,
|
||||
Type: "btree",
|
||||
}
|
||||
table.Indexes[indexName] = index
|
||||
} else {
|
||||
// Add column to existing index
|
||||
table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unique index: uniqueIndex:idx_name or uniqueIndex
|
||||
if indexName, ok := parts["uniqueIndex"]; ok {
|
||||
if indexName == "" {
|
||||
// Auto-generated index name
|
||||
indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||
}
|
||||
|
||||
// Check if index already exists
|
||||
if _, exists := table.Indexes[indexName]; !exists {
|
||||
index := &models.Index{
|
||||
Name: indexName,
|
||||
Table: table.Name,
|
||||
Schema: table.Schema,
|
||||
Columns: []string{column.Name},
|
||||
Unique: true,
|
||||
Type: "btree",
|
||||
}
|
||||
table.Indexes[indexName] = index
|
||||
} else {
|
||||
// Add column to existing index
|
||||
table.Indexes[indexName].Columns = append(table.Indexes[indexName].Columns, column.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for simple unique flag (creates a unique index for this column)
|
||||
if _, ok := parts["unique"]; ok {
|
||||
// Auto-generated index name for unique constraint
|
||||
indexName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||
|
||||
if _, exists := table.Indexes[indexName]; !exists {
|
||||
index := &models.Index{
|
||||
Name: indexName,
|
||||
Table: table.Name,
|
||||
Schema: table.Schema,
|
||||
Columns: []string{column.Name},
|
||||
Unique: true,
|
||||
Type: "btree",
|
||||
}
|
||||
table.Indexes[indexName] = index
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractTableFromGormTag extracts table and schema from gorm tag
|
||||
func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaName string) {
|
||||
// This is typically set via TableName() method, not in tags
|
||||
@@ -406,6 +612,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
column.AutoIncrement = true
|
||||
}
|
||||
if def, ok := parts["default"]; ok {
|
||||
// Default value from GORM tag (e.g., default:gen_random_uuid())
|
||||
column.Default = def
|
||||
}
|
||||
if size, ok := parts["size"]; ok {
|
||||
@@ -419,9 +626,27 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
column.Type = r.goTypeToSQL(fieldType)
|
||||
}
|
||||
|
||||
// Determine if nullable based on Go type
|
||||
if r.isNullableType(fieldType) {
|
||||
column.NotNull = false
|
||||
// Determine if nullable based on GORM tags and Go type
|
||||
// In GORM:
|
||||
// - explicit "not null" tag means NOT NULL
|
||||
// - absence of "not null" tag with sql_types means nullable
|
||||
// - primitive types (string, int64, bool) default to NOT NULL unless explicitly nullable
|
||||
if _, hasNotNull := parts["not null"]; hasNotNull {
|
||||
column.NotNull = true
|
||||
} else {
|
||||
// If no explicit "not null" tag, check the Go type
|
||||
if r.isNullableGoType(fieldType) {
|
||||
// sql_types.SqlString, etc. are nullable by default
|
||||
column.NotNull = false
|
||||
} else {
|
||||
// Primitive types default to NOT NULL
|
||||
column.NotNull = false // Default to nullable unless explicitly set
|
||||
}
|
||||
}
|
||||
|
||||
// Primary keys are always NOT NULL
|
||||
if column.IsPrimaryKey {
|
||||
column.NotNull = true
|
||||
}
|
||||
|
||||
return column
|
||||
@@ -557,11 +782,12 @@ func (r *Reader) sqlTypeToSQL(typeName string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// isNullableType checks if a Go type represents a nullable field
|
||||
func (r *Reader) isNullableType(expr ast.Expr) bool {
|
||||
// isNullableGoType checks if a Go type represents a nullable field type
|
||||
// (this is for types that CAN be nullable, not whether they ARE nullable)
|
||||
func (r *Reader) isNullableGoType(expr ast.Expr) bool {
|
||||
switch t := expr.(type) {
|
||||
case *ast.StarExpr:
|
||||
// Pointer type is nullable
|
||||
// Pointer type can be nullable
|
||||
return true
|
||||
case *ast.SelectorExpr:
|
||||
// Check for sql_types nullable types
|
||||
|
||||
@@ -196,15 +196,31 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
||||
parts = append(parts, "nullzero")
|
||||
}
|
||||
|
||||
// Check for unique constraint
|
||||
// Check for indexes (unique indexes should be added to tag)
|
||||
if table != nil {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint {
|
||||
for _, col := range constraint.Columns {
|
||||
if col == column.Name {
|
||||
parts = append(parts, "unique")
|
||||
break
|
||||
for _, index := range table.Indexes {
|
||||
if !index.Unique {
|
||||
continue
|
||||
}
|
||||
// Check if this column is in the index
|
||||
for _, col := range index.Columns {
|
||||
if col == column.Name {
|
||||
// Add unique tag with index name for composite indexes
|
||||
// or simple unique for single-column indexes
|
||||
if len(index.Columns) > 1 {
|
||||
// Composite index - use index name
|
||||
parts = append(parts, fmt.Sprintf("unique:%s", index.Name))
|
||||
} else {
|
||||
// Single column - use index name if it's not auto-generated
|
||||
// Auto-generated names typically follow pattern: idx_tablename_columnname
|
||||
expectedAutoName := fmt.Sprintf("idx_%s_%s", table.Name, column.Name)
|
||||
if index.Name == expectedAutoName {
|
||||
parts = append(parts, "unique")
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("unique:%s", index.Name))
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,18 +14,135 @@ import (
|
||||
bunwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/bun"
|
||||
gormwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/gorm"
|
||||
yamlwriter "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ComparisonResults holds the results of database comparison
|
||||
type ComparisonResults struct {
|
||||
Schemas int
|
||||
Tables int
|
||||
Columns int
|
||||
OriginalIndexes int
|
||||
RoundtripIndexes int
|
||||
OriginalConstraints int
|
||||
RoundtripConstraints int
|
||||
}
|
||||
|
||||
// countDatabaseStats counts tables, indexes, and constraints in a database
|
||||
func countDatabaseStats(db *models.Database) (tables, indexes, constraints int) {
|
||||
for _, schema := range db.Schemas {
|
||||
tables += len(schema.Tables)
|
||||
for _, table := range schema.Tables {
|
||||
indexes += len(table.Indexes)
|
||||
constraints += len(table.Constraints)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// compareDatabases performs comprehensive comparison between two databases
|
||||
func compareDatabases(t *testing.T, db1, db2 *models.Database, ormName string) ComparisonResults {
|
||||
t.Helper()
|
||||
|
||||
results := ComparisonResults{}
|
||||
|
||||
// Compare high-level structure
|
||||
t.Log(" Comparing high-level structure...")
|
||||
require.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
|
||||
results.Schemas = len(db1.Schemas)
|
||||
|
||||
// Count totals
|
||||
tables1, indexes1, constraints1 := countDatabaseStats(db1)
|
||||
_, indexes2, constraints2 := countDatabaseStats(db2)
|
||||
|
||||
results.OriginalIndexes = indexes1
|
||||
results.RoundtripIndexes = indexes2
|
||||
results.OriginalConstraints = constraints1
|
||||
results.RoundtripConstraints = constraints2
|
||||
|
||||
// Compare schemas and tables
|
||||
for i, schema1 := range db1.Schemas {
|
||||
if i >= len(db2.Schemas) {
|
||||
t.Errorf("Schema index %d out of bounds in second database", i)
|
||||
continue
|
||||
}
|
||||
schema2 := db2.Schemas[i]
|
||||
|
||||
require.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
|
||||
require.Equal(t, len(schema1.Tables), len(schema2.Tables),
|
||||
"Table count in schema '%s' should match", schema1.Name)
|
||||
|
||||
// Compare tables
|
||||
for j, table1 := range schema1.Tables {
|
||||
if j >= len(schema2.Tables) {
|
||||
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
|
||||
continue
|
||||
}
|
||||
table2 := schema2.Tables[j]
|
||||
|
||||
require.Equal(t, table1.Name, table2.Name,
|
||||
"Table names should match in schema '%s'", schema1.Name)
|
||||
|
||||
// Compare column count
|
||||
require.Equal(t, len(table1.Columns), len(table2.Columns),
|
||||
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
|
||||
|
||||
results.Columns += len(table1.Columns)
|
||||
|
||||
// Compare each column
|
||||
for colName, col1 := range table1.Columns {
|
||||
col2, ok := table2.Columns[colName]
|
||||
if !ok {
|
||||
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
|
||||
colName, schema1.Name, table1.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare key column properties
|
||||
require.Equal(t, col1.Name, col2.Name,
|
||||
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
require.Equal(t, col1.Type, col2.Type,
|
||||
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
require.Equal(t, col1.Length, col2.Length,
|
||||
"Column length mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
require.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
|
||||
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
require.Equal(t, col1.NotNull, col2.NotNull,
|
||||
"NotNull mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
|
||||
// Log defaults that don't match (these can vary in representation)
|
||||
if col1.Default != col2.Default {
|
||||
t.Logf(" ℹ Default value differs for '%s.%s.%s': '%v' vs '%v'",
|
||||
schema1.Name, table1.Name, colName, col1.Default, col2.Default)
|
||||
}
|
||||
}
|
||||
|
||||
// Log index and constraint differences (ORM readers may not capture all of these)
|
||||
if len(table1.Indexes) != len(table2.Indexes) {
|
||||
t.Logf(" ℹ Index count differs for table '%s.%s': %d vs %d",
|
||||
schema1.Name, table1.Name, len(table1.Indexes), len(table2.Indexes))
|
||||
}
|
||||
if len(table1.Constraints) != len(table2.Constraints) {
|
||||
t.Logf(" ℹ Constraint count differs for table '%s.%s': %d vs %d",
|
||||
schema1.Name, table1.Name, len(table1.Constraints), len(table2.Constraints))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results.Tables = tables1
|
||||
t.Logf(" ✓ Validated %d schemas, %d tables, %d columns", results.Schemas, results.Tables, results.Columns)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// TestYAMLToBunRoundTrip tests YAML → Bun Go → YAML roundtrip
|
||||
func TestYAMLToBunRoundTrip(t *testing.T) {
|
||||
testDir := t.TempDir()
|
||||
|
||||
// Step 1: Read YAML file
|
||||
t.Log("Step 1: Reading YAML file...")
|
||||
yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml")
|
||||
yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml")
|
||||
yamlReaderOpts := &readers.ReaderOptions{
|
||||
FilePath: yamlPath,
|
||||
}
|
||||
@@ -36,6 +153,10 @@ func TestYAMLToBunRoundTrip(t *testing.T) {
|
||||
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
|
||||
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
|
||||
|
||||
// Log initial stats
|
||||
totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML)
|
||||
t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints)
|
||||
|
||||
// Step 2: Write to Bun Go code
|
||||
t.Log("Step 2: Writing to Bun Go code...")
|
||||
bunGoPath := filepath.Join(testDir, "models_bun.go")
|
||||
@@ -103,67 +224,24 @@ func TestYAMLToBunRoundTrip(t *testing.T) {
|
||||
err = yaml.Unmarshal(yaml2Data, &db2)
|
||||
require.NoError(t, err, "Failed to parse second YAML")
|
||||
|
||||
// Compare high-level structure
|
||||
t.Log(" Comparing high-level structure...")
|
||||
assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
|
||||
|
||||
// Compare schemas and tables
|
||||
for i, schema1 := range db1.Schemas {
|
||||
if i >= len(db2.Schemas) {
|
||||
t.Errorf("Schema index %d out of bounds in second database", i)
|
||||
continue
|
||||
}
|
||||
schema2 := db2.Schemas[i]
|
||||
|
||||
assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
|
||||
assert.Equal(t, len(schema1.Tables), len(schema2.Tables),
|
||||
"Table count in schema '%s' should match", schema1.Name)
|
||||
|
||||
// Compare tables
|
||||
for j, table1 := range schema1.Tables {
|
||||
if j >= len(schema2.Tables) {
|
||||
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
|
||||
continue
|
||||
}
|
||||
table2 := schema2.Tables[j]
|
||||
|
||||
assert.Equal(t, table1.Name, table2.Name,
|
||||
"Table names should match in schema '%s'", schema1.Name)
|
||||
|
||||
// Compare column count
|
||||
assert.Equal(t, len(table1.Columns), len(table2.Columns),
|
||||
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
|
||||
|
||||
// Compare each column
|
||||
for colName, col1 := range table1.Columns {
|
||||
col2, ok := table2.Columns[colName]
|
||||
if !ok {
|
||||
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
|
||||
colName, schema1.Name, table1.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare key column properties
|
||||
assert.Equal(t, col1.Name, col2.Name,
|
||||
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
assert.Equal(t, col1.Type, col2.Type,
|
||||
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
|
||||
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Comprehensive comparison
|
||||
compareResults := compareDatabases(t, &db1, &db2, "Bun")
|
||||
|
||||
// Summary
|
||||
t.Log("Summary:")
|
||||
t.Logf(" ✓ Round-trip completed: YAML → Bun → YAML")
|
||||
t.Logf(" ✓ Schemas match: %d", len(db1.Schemas))
|
||||
t.Logf(" ✓ Schemas: %d", compareResults.Schemas)
|
||||
t.Logf(" ✓ Tables: %d", compareResults.Tables)
|
||||
t.Logf(" ✓ Columns: %d", compareResults.Columns)
|
||||
t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes)
|
||||
t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints)
|
||||
|
||||
totalTables := 0
|
||||
for _, schema := range db1.Schemas {
|
||||
totalTables += len(schema.Tables)
|
||||
if compareResults.OriginalIndexes != compareResults.RoundtripIndexes {
|
||||
t.Logf(" ⚠ Note: Index counts differ - Bun reader may not parse all index information from Go code")
|
||||
}
|
||||
if compareResults.OriginalConstraints != compareResults.RoundtripConstraints {
|
||||
t.Logf(" ⚠ Note: Constraint counts differ - Bun reader may not parse all constraint information from Go code")
|
||||
}
|
||||
t.Logf(" ✓ Total tables: %d", totalTables)
|
||||
}
|
||||
|
||||
// TestYAMLToGORMRoundTrip tests YAML → GORM Go → YAML roundtrip
|
||||
@@ -172,7 +250,7 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
|
||||
|
||||
// Step 1: Read YAML file
|
||||
t.Log("Step 1: Reading YAML file...")
|
||||
yamlPath := filepath.Join("..", "assets", "yaml", "database.yaml")
|
||||
yamlPath := filepath.Join("..", "assets", "yaml", "complex_database.yaml")
|
||||
yamlReaderOpts := &readers.ReaderOptions{
|
||||
FilePath: yamlPath,
|
||||
}
|
||||
@@ -183,6 +261,10 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
|
||||
require.NotNil(t, dbFromYAML, "Database from YAML should not be nil")
|
||||
t.Logf(" ✓ Read database '%s' with %d schemas", dbFromYAML.Name, len(dbFromYAML.Schemas))
|
||||
|
||||
// Log initial stats
|
||||
totalTables, totalIndexes, totalConstraints := countDatabaseStats(dbFromYAML)
|
||||
t.Logf(" ✓ Original: %d tables, %d indexes, %d constraints", totalTables, totalIndexes, totalConstraints)
|
||||
|
||||
// Step 2: Write to GORM Go code
|
||||
t.Log("Step 2: Writing to GORM Go code...")
|
||||
gormGoPath := filepath.Join(testDir, "models_gorm.go")
|
||||
@@ -250,65 +332,22 @@ func TestYAMLToGORMRoundTrip(t *testing.T) {
|
||||
err = yaml.Unmarshal(yaml2Data, &db2)
|
||||
require.NoError(t, err, "Failed to parse second YAML")
|
||||
|
||||
// Compare high-level structure
|
||||
t.Log(" Comparing high-level structure...")
|
||||
assert.Equal(t, len(db1.Schemas), len(db2.Schemas), "Schema count should match")
|
||||
|
||||
// Compare schemas and tables
|
||||
for i, schema1 := range db1.Schemas {
|
||||
if i >= len(db2.Schemas) {
|
||||
t.Errorf("Schema index %d out of bounds in second database", i)
|
||||
continue
|
||||
}
|
||||
schema2 := db2.Schemas[i]
|
||||
|
||||
assert.Equal(t, schema1.Name, schema2.Name, "Schema names should match")
|
||||
assert.Equal(t, len(schema1.Tables), len(schema2.Tables),
|
||||
"Table count in schema '%s' should match", schema1.Name)
|
||||
|
||||
// Compare tables
|
||||
for j, table1 := range schema1.Tables {
|
||||
if j >= len(schema2.Tables) {
|
||||
t.Errorf("Table index %d out of bounds in schema '%s'", j, schema1.Name)
|
||||
continue
|
||||
}
|
||||
table2 := schema2.Tables[j]
|
||||
|
||||
assert.Equal(t, table1.Name, table2.Name,
|
||||
"Table names should match in schema '%s'", schema1.Name)
|
||||
|
||||
// Compare column count
|
||||
assert.Equal(t, len(table1.Columns), len(table2.Columns),
|
||||
"Column count in table '%s.%s' should match", schema1.Name, table1.Name)
|
||||
|
||||
// Compare each column
|
||||
for colName, col1 := range table1.Columns {
|
||||
col2, ok := table2.Columns[colName]
|
||||
if !ok {
|
||||
t.Errorf("Column '%s' missing from roundtrip table '%s.%s'",
|
||||
colName, schema1.Name, table1.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare key column properties
|
||||
assert.Equal(t, col1.Name, col2.Name,
|
||||
"Column name mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
assert.Equal(t, col1.Type, col2.Type,
|
||||
"Column type mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
assert.Equal(t, col1.IsPrimaryKey, col2.IsPrimaryKey,
|
||||
"Primary key mismatch in '%s.%s.%s'", schema1.Name, table1.Name, colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Comprehensive comparison
|
||||
compareResults := compareDatabases(t, &db1, &db2, "GORM")
|
||||
|
||||
// Summary
|
||||
t.Log("Summary:")
|
||||
t.Logf(" ✓ Round-trip completed: YAML → GORM → YAML")
|
||||
t.Logf(" ✓ Schemas match: %d", len(db1.Schemas))
|
||||
t.Logf(" ✓ Schemas: %d", compareResults.Schemas)
|
||||
t.Logf(" ✓ Tables: %d", compareResults.Tables)
|
||||
t.Logf(" ✓ Columns: %d", compareResults.Columns)
|
||||
t.Logf(" ✓ Indexes: %d (original), %d (roundtrip)", compareResults.OriginalIndexes, compareResults.RoundtripIndexes)
|
||||
t.Logf(" ✓ Constraints: %d (original), %d (roundtrip)", compareResults.OriginalConstraints, compareResults.RoundtripConstraints)
|
||||
|
||||
totalTables := 0
|
||||
for _, schema := range db1.Schemas {
|
||||
totalTables += len(schema.Tables)
|
||||
if compareResults.OriginalIndexes != compareResults.RoundtripIndexes {
|
||||
t.Logf(" ⚠ Note: Index counts differ - GORM reader may not parse all index information from Go code")
|
||||
}
|
||||
if compareResults.OriginalConstraints != compareResults.RoundtripConstraints {
|
||||
t.Logf(" ⚠ Note: Constraint counts differ - GORM reader may not parse all constraint information from Go code")
|
||||
}
|
||||
t.Logf(" ✓ Total tables: %d", totalTables)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user