From 5f1923233ef2b8c676d342bebca5ceb6c77eee6b Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 18 Dec 2025 20:00:59 +0200 Subject: [PATCH] Reverse reading bun/gorm models --- cmd/relspec/convert.go | 28 +++++- pkg/readers/gorm/reader.go | 194 ++++++++++++++++++++++++++++++++----- 2 files changed, 198 insertions(+), 24 deletions(-) diff --git a/cmd/relspec/convert.go b/cmd/relspec/convert.go index e95f8f7..a755aaa 100644 --- a/cmd/relspec/convert.go +++ b/cmd/relspec/convert.go @@ -8,9 +8,11 @@ import ( "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/readers" + "git.warky.dev/wdevs/relspecgo/pkg/readers/bun" "git.warky.dev/wdevs/relspecgo/pkg/readers/dbml" "git.warky.dev/wdevs/relspecgo/pkg/readers/dctx" "git.warky.dev/wdevs/relspecgo/pkg/readers/drawdb" + "git.warky.dev/wdevs/relspecgo/pkg/readers/gorm" "git.warky.dev/wdevs/relspecgo/pkg/readers/json" "git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/readers/yaml" @@ -51,6 +53,8 @@ Input formats: - drawdb: DrawDB JSON files - json: JSON database schema - yaml: YAML database schema + - gorm: GORM model files (Go, file or directory) + - bun: Bun model files (Go, file or directory) - pgsql: PostgreSQL database (live connection) Output formats: @@ -106,12 +110,20 @@ Examples: # Convert DBML to DCTX with specific schema relspec convert --from dbml --from-path schema.dbml \ - --to dctx --to-path output.dctx --schema public`, + --to dctx --to-path output.dctx --schema public + + # Convert GORM models directory to DBML + relspec convert --from gorm --from-path /path/to/models \ + --to dbml --to-path schema.dbml + + # Convert Bun models directory to JSON + relspec convert --from bun --from-path ./models \ + --to json --to-path schema.json`, RunE: runConvert, } func init() { - convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, pgsql)") + convertCmd.Flags().StringVar(&convertSourceType, "from", "", "Source format (dbml, dctx, drawdb, json, yaml, gorm, bun, pgsql)") convertCmd.Flags().StringVar(&convertSourcePath, "from-path", "", "Source file path (for file-based formats)") convertCmd.Flags().StringVar(&convertSourceConn, "from-conn", "", "Source connection string (for database formats)") @@ -215,6 +227,18 @@ func readDatabaseForConvert(dbType, filePath, connString string) (*models.Databa } reader = pgsql.NewReader(&readers.ReaderOptions{ConnectionString: connString}) + case "gorm": + if filePath == "" { + return nil, fmt.Errorf("file path is required for GORM format") + } + reader = gorm.NewReader(&readers.ReaderOptions{FilePath: filePath}) + + case "bun": + if filePath == "" { + return nil, fmt.Errorf("file path is required for Bun format") + } + reader = bun.NewReader(&readers.ReaderOptions{FilePath: filePath}) + default: return nil, fmt.Errorf("unsupported source format: %s", dbType) } diff --git a/pkg/readers/gorm/reader.go b/pkg/readers/gorm/reader.go index 44aac88..7f7abff 100644 --- a/pkg/readers/gorm/reader.go +++ b/pkg/readers/gorm/reader.go @@ -331,7 +331,7 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod } // Parse column from tag - column := r.parseColumn(fieldName, field.Type, tag, sequence) + column, inlineRef := r.parseColumn(fieldName, field.Type, tag, sequence) if column != nil { // Extract schema and table name from TableName() method if present if strings.Contains(tag, "gorm:") { @@ -353,6 +353,11 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod // Parse indexes from GORM tags r.parseIndexesFromTag(table, column, tag) + // Handle inline references (e.g., "bigint references mainaccount(id_mainaccount)") + if inlineRef != "" { + r.createInlineReferenceConstraint(table, column, inlineRef) + } + sequence++ } } @@ -360,6 +365,81 @@ func (r *Reader) parseStruct(structName string, structType *ast.StructType) *mod return table } +// createInlineReferenceConstraint creates a foreign key constraint from inline reference +// e.g., "mainaccount(id_mainaccount) ON DELETE CASCADE" creates an FK constraint +func (r *Reader) createInlineReferenceConstraint(table *models.Table, column *models.Column, refInfo string) { + // Parse refInfo: "mainaccount(id_mainaccount) ON DELETE CASCADE ON UPDATE RESTRICT" + // Extract table name, column name, and constraint actions + openParen := strings.Index(refInfo, "(") + closeParen := strings.Index(refInfo, ")") + + if openParen == -1 || closeParen == -1 || openParen >= closeParen { + // Invalid format, skip + return + } + + refTableFull := strings.TrimSpace(refInfo[:openParen]) + refColumn := strings.TrimSpace(refInfo[openParen+1 : closeParen]) + + // Extract ON DELETE/UPDATE clauses (everything after the closing paren) + constraintActions := "" + if closeParen+1 < len(refInfo) { + constraintActions = strings.TrimSpace(refInfo[closeParen+1:]) + } + + // Parse schema.table or just table + refSchema := "public" + refTable := refTableFull + if strings.Contains(refTableFull, ".") { + parts := strings.SplitN(refTableFull, ".", 2) + refSchema = parts[0] + refTable = parts[1] + } + + // Parse ON DELETE and ON UPDATE actions + onDelete := "NO ACTION" + onUpdate := "NO ACTION" + if constraintActions != "" { + constraintActionsLower := strings.ToLower(constraintActions) + + // Parse ON DELETE + if strings.Contains(constraintActionsLower, "on delete cascade") { + onDelete = "CASCADE" + } else if strings.Contains(constraintActionsLower, "on delete set null") { + onDelete = "SET NULL" + } else if strings.Contains(constraintActionsLower, "on delete restrict") { + onDelete = "RESTRICT" + } else if strings.Contains(constraintActionsLower, "on delete no action") { + onDelete = "NO ACTION" + } + + // Parse ON UPDATE + if strings.Contains(constraintActionsLower, "on update cascade") { + onUpdate = "CASCADE" + } else if strings.Contains(constraintActionsLower, "on update set null") { + onUpdate = "SET NULL" + } else if strings.Contains(constraintActionsLower, "on update restrict") { + onUpdate = "RESTRICT" + } else if strings.Contains(constraintActionsLower, "on update no action") { + onUpdate = "NO ACTION" + } + } + + // Create a foreign key constraint + constraintName := fmt.Sprintf("fk_%s_%s", table.Name, column.Name) + constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint) + constraint.Schema = table.Schema + constraint.Table = table.Name + constraint.Columns = []string{column.Name} + constraint.ReferencedSchema = refSchema + constraint.ReferencedTable = refTable + constraint.ReferencedColumns = []string{refColumn} + constraint.OnDelete = onDelete + constraint.OnUpdate = onUpdate + + table.Constraints[constraintName] = constraint +} + // isGORMModel checks if a field is gorm.Model func (r *Reader) isGORMModel(field *ast.Field) bool { if len(field.Names) > 0 { @@ -383,9 +463,11 @@ func (r *Reader) isGORMModel(field *ast.Field) bool { // isRelationship checks if a field is a relationship based on gorm tag func (r *Reader) isRelationship(tag string) bool { gormTag := r.extractGormTag(tag) - return strings.Contains(gormTag, "foreignKey:") || - strings.Contains(gormTag, "references:") || - strings.Contains(gormTag, "many2many:") + gormTagLower := strings.ToLower(gormTag) + return strings.Contains(gormTagLower, "foreignkey:") || + strings.Contains(gormTagLower, "references:") || + strings.Contains(gormTagLower, "many2many:") || + strings.Contains(gormTagLower, "preload") } // parseRelationshipConstraints parses relationship fields to extract foreign key constraints @@ -409,6 +491,15 @@ func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *a continue } + // Determine if this is a has-many or belongs-to relationship + isSlice := r.isSliceRelationship(field.Type) + + // Only process has-many relationships (slices) + // Belongs-to relationships (*Type) should be handled by the other side's has-many + if !isSlice { + continue + } + // Find the referenced table referencedTable, ok := structMap[referencedType] if !ok { @@ -416,42 +507,54 @@ func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *a } // Extract foreign key information - foreignKey, hasForeignKey := parts["foreignKey"] + foreignKey, hasForeignKey := parts["foreignkey"] if !hasForeignKey { continue } - // Convert field name to column name - fkColumn := r.fieldNameToColumnName(foreignKey) + // GORM's foreignkey tag contains the column name (already in snake_case), not field name + fkColumn := foreignKey + + // Extract association_foreignkey (the column being referenced) + assocForeignKey, hasAssocFK := parts["association_foreignkey"] + if !hasAssocFK { + // Default to "id" if not specified + assocForeignKey = "id" + } // Determine constraint behavior onDelete := "NO ACTION" onUpdate := "NO ACTION" if constraintStr, hasConstraint := parts["constraint"]; hasConstraint { // Parse constraint:OnDelete:CASCADE,OnUpdate:CASCADE - if strings.Contains(constraintStr, "OnDelete:CASCADE") { + constraintLower := strings.ToLower(constraintStr) + if strings.Contains(constraintLower, "ondelete:cascade") { onDelete = "CASCADE" - } else if strings.Contains(constraintStr, "OnDelete:SET NULL") { + } else if strings.Contains(constraintLower, "ondelete:set null") { onDelete = "SET NULL" + } else if strings.Contains(constraintLower, "ondelete:restrict") { + onDelete = "RESTRICT" } - if strings.Contains(constraintStr, "OnUpdate:CASCADE") { + if strings.Contains(constraintLower, "onupdate:cascade") { onUpdate = "CASCADE" - } else if strings.Contains(constraintStr, "OnUpdate:SET NULL") { + } else if strings.Contains(constraintLower, "onupdate:set null") { onUpdate = "SET NULL" + } else if strings.Contains(constraintLower, "onupdate:restrict") { + onUpdate = "RESTRICT" } } // The FK is on the referenced table, pointing back to this table // For has-many, the FK is on the "many" side constraint := &models.Constraint{ - Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, table.Name), + Name: fmt.Sprintf("fk_%s_%s", referencedTable.Name, fkColumn), Type: models.ForeignKeyConstraint, Table: referencedTable.Name, Schema: referencedTable.Schema, Columns: []string{fkColumn}, ReferencedTable: table.Name, ReferencedSchema: table.Schema, - ReferencedColumns: []string{"id"}, // Typically references the primary key + ReferencedColumns: []string{assocForeignKey}, OnDelete: onDelete, OnUpdate: onUpdate, } @@ -460,6 +563,19 @@ func (r *Reader) parseRelationshipConstraints(table *models.Table, structType *a } } +// isSliceRelationship checks if a relationship field is a slice (has-many) or pointer (belongs-to) +func (r *Reader) isSliceRelationship(expr ast.Expr) bool { + switch expr.(type) { + case *ast.ArrayType: + // []*Type is a has-many relationship + return true + case *ast.StarExpr: + // *Type is a belongs-to relationship + return false + } + return false +} + // getRelationshipType extracts the type name from a relationship field func (r *Reader) getRelationshipType(expr ast.Expr) string { switch t := expr.(type) { @@ -509,7 +625,7 @@ func (r *Reader) parseIndexesFromTag(table *models.Table, column *models.Column, } // Check for unique index: uniqueIndex:idx_name or uniqueIndex - if indexName, ok := parts["uniqueIndex"]; ok { + if indexName, ok := parts["uniqueindex"]; ok { if indexName == "" { // Auto-generated index name indexName = fmt.Sprintf("idx_%s_%s", table.Name, column.Name) @@ -576,11 +692,12 @@ func (r *Reader) deriveTableName(structName string) string { } // parseColumn parses a struct field into a Column model -func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) *models.Column { +// Returns the column and any inline reference information (e.g., "mainaccount(id_mainaccount)") +func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) (*models.Column, string) { // Extract gorm tag gormTag := r.extractGormTag(tag) if gormTag == "" { - return nil + return nil, "" } column := models.InitColumn("", "", "") @@ -598,17 +715,25 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s } // Parse tag attributes + var inlineRef string if typ, ok := parts["type"]; ok { // Parse type and extract length if present (e.g., varchar(255)) - column.Type, column.Length = r.parseTypeWithLength(typ) + // Also extract references if present (e.g., bigint references table(column)) + baseType, length, refInfo := r.parseTypeWithReferences(typ) + column.Type = baseType + column.Length = length + inlineRef = refInfo } - if _, ok := parts["primaryKey"]; ok { + if _, ok := parts["primarykey"]; ok { + column.IsPrimaryKey = true + } + if _, ok := parts["primary_key"]; ok { column.IsPrimaryKey = true } if _, ok := parts["not null"]; ok { column.NotNull = true } - if _, ok := parts["autoIncrement"]; ok { + if _, ok := parts["autoincrement"]; ok { column.AutoIncrement = true } if def, ok := parts["default"]; ok { @@ -649,7 +774,7 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s column.NotNull = true } - return column + return column, inlineRef } // extractGormTag extracts the gorm tag value from a struct tag @@ -688,6 +813,28 @@ func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length in return } +// parseTypeWithReferences parses a type string and extracts base type, length, and references +// e.g., "bigint references mainaccount(id_mainaccount) ON DELETE CASCADE" returns ("bigint", 0, "mainaccount(id_mainaccount) ON DELETE CASCADE") +func (r *Reader) parseTypeWithReferences(typeStr string) (baseType string, length int, refInfo string) { + // Check if the type contains "references" (case-insensitive) + lowerType := strings.ToLower(typeStr) + if strings.Contains(lowerType, " references ") { + // Split on "references" (case-insensitive) + idx := strings.Index(lowerType, " references ") + baseTypePart := strings.TrimSpace(typeStr[:idx]) + // Keep the entire reference info including ON DELETE/UPDATE clauses + refInfo = strings.TrimSpace(typeStr[idx+len(" references "):]) + + // Parse base type for length + baseType, length = r.parseTypeWithLength(baseTypePart) + return + } + + // No references, just parse type and length + baseType, length = r.parseTypeWithLength(typeStr) + return +} + // parseGormTag parses a gorm tag string into a map func (r *Reader) parseGormTag(gormTag string) map[string]string { result := make(map[string]string) @@ -703,10 +850,13 @@ func (r *Reader) parseGormTag(gormTag string) map[string]string { // Check for key:value pairs if strings.Contains(part, ":") { kv := strings.SplitN(part, ":", 2) - result[kv[0]] = kv[1] + // Normalize key to lowercase for case-insensitive matching + key := strings.ToLower(kv[0]) + result[key] = kv[1] } else { // Flags like "primaryKey", "not null", etc. - result[part] = "" + // Normalize to lowercase + result[strings.ToLower(part)] = "" } }