package gorm import ( "fmt" "go/format" "os" "os/exec" "path/filepath" "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/writers" ) // Writer implements the writers.Writer interface for GORM models type Writer struct { options *writers.WriterOptions typeMapper *TypeMapper templates *Templates config *MethodConfig } // NewWriter creates a new GORM writer with the given options func NewWriter(options *writers.WriterOptions) *Writer { w := &Writer{ options: options, typeMapper: NewTypeMapper(), config: LoadMethodConfigFromMetadata(options.Metadata), } // Initialize templates tmpl, err := NewTemplates() if err != nil { // Should not happen with embedded templates panic(fmt.Sprintf("failed to initialize templates: %v", err)) } w.templates = tmpl return w } // WriteDatabase writes a complete database as GORM models func (w *Writer) WriteDatabase(db *models.Database) error { // Check if multi-file mode is enabled multiFile := w.shouldUseMultiFile() if multiFile { return w.writeMultiFile(db) } return w.writeSingleFile(db) } // WriteSchema writes a schema as GORM models func (w *Writer) WriteSchema(schema *models.Schema) error { // Create a temporary database with just this schema db := models.InitDatabase(schema.Name) db.Schemas = []*models.Schema{schema} return w.WriteDatabase(db) } // WriteTable writes a single table as a GORM model func (w *Writer) WriteTable(table *models.Table) error { // Create a temporary schema and database schema := models.InitSchema(table.Schema) schema.Tables = []*models.Table{table} db := models.InitDatabase(schema.Name) db.Schemas = []*models.Schema{schema} return w.WriteDatabase(db) } // writeSingleFile writes all models to a single file func (w *Writer) writeSingleFile(db *models.Database) error { packageName := w.getPackageName() templateData := NewTemplateData(packageName, w.config) // Add sql_types import (always needed for nullable types) templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport())) // Collect all models for _, schema := range db.Schemas { for _, table := range schema.Tables { modelData := NewModelData(table, schema.Name, w.typeMapper) // Add relationship fields w.addRelationshipFields(modelData, table, schema, db) templateData.AddModel(modelData) // Check if we need time import for _, field := range modelData.Fields { if w.typeMapper.NeedsTimeImport(field.Type) { templateData.AddImport("\"time\"") } } } } // Add fmt import if GetIDStr is enabled if w.config.GenerateGetIDStr { templateData.AddImport("\"fmt\"") } // Finalize imports templateData.FinalizeImports() // Generate code code, err := w.templates.GenerateCode(templateData) if err != nil { return fmt.Errorf("failed to generate code: %w", err) } // Format code formatted, err := w.formatCode(code) if err != nil { // Return unformatted code with warning fmt.Fprintf(os.Stderr, "Warning: failed to format code: %v\n", err) formatted = code } // Write output if err := w.writeOutput(formatted); err != nil { return err } // Run go fmt on the output file if w.options.OutputPath != "" { w.runGoFmt(w.options.OutputPath) } return nil } // writeMultiFile writes each table to a separate file func (w *Writer) writeMultiFile(db *models.Database) error { packageName := w.getPackageName() // Check if populate_refs is enabled populateRefs := false if w.options.Metadata != nil { if pr, ok := w.options.Metadata["populate_refs"].(bool); ok { populateRefs = pr } } // Ensure output path is a directory if w.options.OutputPath == "" { return fmt.Errorf("output path is required for multi-file mode") } // Create output directory if it doesn't exist if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil { return fmt.Errorf("failed to create output directory: %w", err) } // Generate a file for each table for _, schema := range db.Schemas { // Populate RefDatabase for schema if enabled if populateRefs && schema.RefDatabase == nil { schema.RefDatabase = w.createDatabaseRef(db) } for _, table := range schema.Tables { // Populate RefSchema for table if enabled if populateRefs && table.RefSchema == nil { table.RefSchema = w.createSchemaRef(schema, db) } // Create template data for this single table templateData := NewTemplateData(packageName, w.config) // Add sql_types import templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport())) // Create model data modelData := NewModelData(table, schema.Name, w.typeMapper) // Add relationship fields w.addRelationshipFields(modelData, table, schema, db) templateData.AddModel(modelData) // Check if we need time import for _, field := range modelData.Fields { if w.typeMapper.NeedsTimeImport(field.Type) { templateData.AddImport("\"time\"") } } // Add fmt import if GetIDStr is enabled if w.config.GenerateGetIDStr { templateData.AddImport("\"fmt\"") } // Finalize imports templateData.FinalizeImports() // Generate code code, err := w.templates.GenerateCode(templateData) if err != nil { return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err) } // Format code formatted, err := w.formatCode(code) if err != nil { fmt.Fprintf(os.Stderr, "Warning: failed to format code for %s: %v\n", table.Name, err) formatted = code } // Generate filename: sql_{schema}_{table}.go // Sanitize schema and table names to remove quotes, comments, and invalid characters safeSchemaName := writers.SanitizeFilename(schema.Name) safeTableName := writers.SanitizeFilename(table.Name) filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName) filepath := filepath.Join(w.options.OutputPath, filename) // Write file if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil { return fmt.Errorf("failed to write file %s: %w", filename, err) } // Run go fmt on the generated file w.runGoFmt(filepath) } } return nil } // addRelationshipFields adds relationship fields to the model based on foreign keys func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) { // Track used field names to detect duplicates usedFieldNames := make(map[string]int) // For each foreign key in this table, add a belongs-to relationship for _, constraint := range table.Constraints { if constraint.Type != models.ForeignKeyConstraint { continue } // Find the referenced table refTable := w.findTable(constraint.ReferencedSchema, constraint.ReferencedTable, db) if refTable == nil { continue } // Create relationship field (belongs-to) refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable) fieldName := w.generateBelongsToFieldName(constraint) fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames) relationTag := w.typeMapper.BuildRelationshipTag(constraint, false) modelData.AddRelationshipField(&FieldData{ Name: fieldName, Type: "*" + refModelName, // Pointer type GormTag: relationTag, JSONTag: strings.ToLower(fieldName) + ",omitempty", Comment: fmt.Sprintf("Belongs to %s", refModelName), }) } // For each table that references this table, add a has-many relationship for _, otherSchema := range db.Schemas { for _, otherTable := range otherSchema.Tables { if otherTable.Name == table.Name && otherSchema.Name == schema.Name { continue // Skip self } for _, constraint := range otherTable.Constraints { if constraint.Type != models.ForeignKeyConstraint { continue } // Check if this constraint references our table if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name { // Add has-many relationship otherModelName := w.getModelName(otherSchema.Name, otherTable.Name) fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name) fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames) relationTag := w.typeMapper.BuildRelationshipTag(constraint, true) modelData.AddRelationshipField(&FieldData{ Name: fieldName, Type: "[]*" + otherModelName, // Slice of pointers GormTag: relationTag, JSONTag: strings.ToLower(fieldName) + ",omitempty", Comment: fmt.Sprintf("Has many %s", otherModelName), }) } } } } } // findTable finds a table by schema and name in the database func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table { for _, schema := range db.Schemas { if schema.Name != schemaName { continue } for _, table := range schema.Tables { if table.Name == tableName { return table } } } return nil } // getModelName generates the model name from schema and table name func (w *Writer) getModelName(schemaName, tableName string) string { singular := Singularize(tableName) tablePart := SnakeCaseToPascalCase(singular) // Include schema name in model name var modelName string if schemaName != "" { schemaPart := SnakeCaseToPascalCase(schemaName) modelName = "Model" + schemaPart + tablePart } else { modelName = "Model" + tablePart } return modelName } // generateBelongsToFieldName generates a field name for belongs-to relationships // Uses the foreign key column name for uniqueness func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) string { // Use the foreign key column name to ensure uniqueness // If there are multiple columns, use the first one if len(constraint.Columns) > 0 { columnName := constraint.Columns[0] // Convert to PascalCase for proper Go field naming // e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest" return "Rel" + SnakeCaseToPascalCase(columnName) } // Fallback to table-based prefix if no columns defined return "Rel" + GeneratePrefix(constraint.ReferencedTable) } // generateHasManyFieldName generates a field name for has-many relationships // Uses the foreign key column name + source table name to avoid duplicates func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string { // For has-many, we need to include the source table name to avoid duplicates // e.g., multiple tables referencing the same column on this table if len(constraint.Columns) > 0 { columnName := constraint.Columns[0] // Get the model name for the source table (pluralized) sourceModelName := w.getModelName(sourceSchemaName, sourceTableName) // Remove "Model" prefix if present sourceModelName = strings.TrimPrefix(sourceModelName, "Model") // Convert column to PascalCase and combine with source table // e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins" columnPart := SnakeCaseToPascalCase(columnName) return "Rel" + columnPart + Pluralize(sourceModelName) } // Fallback to table-based naming sourceModelName := w.getModelName(sourceSchemaName, sourceTableName) sourceModelName = strings.TrimPrefix(sourceModelName, "Model") return "Rel" + Pluralize(sourceModelName) } // ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string { originalName := fieldName count := usedNames[originalName] if count > 0 { // Name is already used, add numeric suffix fieldName = fmt.Sprintf("%s%d", originalName, count+1) } // Increment the counter for this base name usedNames[originalName]++ return fieldName } // getPackageName returns the package name from options or defaults to "models" func (w *Writer) getPackageName() string { if w.options.PackageName != "" { return w.options.PackageName } return "models" } // formatCode formats Go code using gofmt func (w *Writer) formatCode(code string) (string, error) { formatted, err := format.Source([]byte(code)) if err != nil { return "", fmt.Errorf("format error: %w", err) } return string(formatted), nil } // writeOutput writes the content to file or stdout func (w *Writer) writeOutput(content string) error { if w.options.OutputPath != "" { return os.WriteFile(w.options.OutputPath, []byte(content), 0644) } // Print to stdout fmt.Print(content) return nil } // runGoFmt runs go fmt on the specified file func (w *Writer) runGoFmt(filepath string) { cmd := exec.Command("gofmt", "-w", filepath) if err := cmd.Run(); err != nil { // Don't fail the whole operation if gofmt fails, just warn fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err) } } // shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path func (w *Writer) shouldUseMultiFile() bool { // Check if multi_file is explicitly set in metadata if w.options.Metadata != nil { if mf, ok := w.options.Metadata["multi_file"].(bool); ok { return mf } } // Auto-detect based on output path if w.options.OutputPath == "" { // No output path means stdout (single file) return false } // Check if path ends with .go (explicit file) if strings.HasSuffix(w.options.OutputPath, ".go") { return false } // Check if path ends with directory separator if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") { return true } // Check if path exists and is a directory info, err := os.Stat(w.options.OutputPath) if err == nil && info.IsDir() { return true } // Default to single file for ambiguous cases return false } // createDatabaseRef creates a shallow copy of database without schemas to avoid circular references func (w *Writer) createDatabaseRef(db *models.Database) *models.Database { return &models.Database{ Name: db.Name, Description: db.Description, Comment: db.Comment, DatabaseType: db.DatabaseType, DatabaseVersion: db.DatabaseVersion, SourceFormat: db.SourceFormat, Schemas: nil, // Don't include schemas to avoid circular reference GUID: db.GUID, } } // createSchemaRef creates a shallow copy of schema without tables to avoid circular references func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *models.Schema { return &models.Schema{ Name: schema.Name, Description: schema.Description, Owner: schema.Owner, Permissions: schema.Permissions, Comment: schema.Comment, Metadata: schema.Metadata, Scripts: schema.Scripts, Sequence: schema.Sequence, RefDatabase: w.createDatabaseRef(db), // Include database ref Tables: nil, // Don't include tables to avoid circular reference GUID: schema.GUID, } }