So far so good
This commit is contained in:
324
pkg/writers/gorm/writer.go
Normal file
324
pkg/writers/gorm/writer.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// Writer implements the writers.Writer interface for GORM models
|
||||
type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
typeMapper *TypeMapper
|
||||
templates *Templates
|
||||
config *MethodConfig
|
||||
}
|
||||
|
||||
// NewWriter creates a new GORM writer with the given options
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
w := &Writer{
|
||||
options: options,
|
||||
typeMapper: NewTypeMapper(),
|
||||
config: LoadMethodConfigFromMetadata(options.Metadata),
|
||||
}
|
||||
|
||||
// Initialize templates
|
||||
tmpl, err := NewTemplates()
|
||||
if err != nil {
|
||||
// Should not happen with embedded templates
|
||||
panic(fmt.Sprintf("failed to initialize templates: %v", err))
|
||||
}
|
||||
w.templates = tmpl
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
// WriteDatabase writes a complete database as GORM models
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
// Check if multi-file mode is enabled
|
||||
multiFile := false
|
||||
if w.options.Metadata != nil {
|
||||
if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
|
||||
multiFile = mf
|
||||
}
|
||||
}
|
||||
|
||||
if multiFile {
|
||||
return w.writeMultiFile(db)
|
||||
}
|
||||
|
||||
return w.writeSingleFile(db)
|
||||
}
|
||||
|
||||
// WriteSchema writes a schema as GORM models
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
// Create a temporary database with just this schema
|
||||
db := models.InitDatabase(schema.Name)
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// WriteTable writes a single table as a GORM model
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
// Create a temporary schema and database
|
||||
schema := models.InitSchema(table.Schema)
|
||||
schema.Tables = []*models.Table{table}
|
||||
|
||||
db := models.InitDatabase(schema.Name)
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// writeSingleFile writes all models to a single file
|
||||
func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
packageName := w.getPackageName()
|
||||
templateData := NewTemplateData(packageName, w.config)
|
||||
|
||||
// Add sql_types import (always needed for nullable types)
|
||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Collect all models
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
|
||||
templateData.AddModel(modelData)
|
||||
|
||||
// Check if we need time import
|
||||
for _, field := range modelData.Fields {
|
||||
if w.typeMapper.NeedsTimeImport(field.Type) {
|
||||
templateData.AddImport("\"time\"")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
// Finalize imports
|
||||
templateData.FinalizeImports()
|
||||
|
||||
// Generate code
|
||||
code, err := w.templates.GenerateCode(templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate code: %w", err)
|
||||
}
|
||||
|
||||
// Format code
|
||||
formatted, err := w.formatCode(code)
|
||||
if err != nil {
|
||||
// Return unformatted code with warning
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to format code: %v\n", err)
|
||||
formatted = code
|
||||
}
|
||||
|
||||
// Write output
|
||||
return w.writeOutput(formatted)
|
||||
}
|
||||
|
||||
// writeMultiFile writes each table to a separate file
|
||||
func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
packageName := w.getPackageName()
|
||||
|
||||
// Ensure output path is a directory
|
||||
if w.options.OutputPath == "" {
|
||||
return fmt.Errorf("output path is required for multi-file mode")
|
||||
}
|
||||
|
||||
// Create output directory if it doesn't exist
|
||||
if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate a file for each table
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
// Create template data for this single table
|
||||
templateData := NewTemplateData(packageName, w.config)
|
||||
|
||||
// Add sql_types import
|
||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Create model data
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
|
||||
templateData.AddModel(modelData)
|
||||
|
||||
// Check if we need time import
|
||||
for _, field := range modelData.Fields {
|
||||
if w.typeMapper.NeedsTimeImport(field.Type) {
|
||||
templateData.AddImport("\"time\"")
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
// Finalize imports
|
||||
templateData.FinalizeImports()
|
||||
|
||||
// Generate code
|
||||
code, err := w.templates.GenerateCode(templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
|
||||
}
|
||||
|
||||
// Format code
|
||||
formatted, err := w.formatCode(code)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to format code for %s: %v\n", table.Name, err)
|
||||
formatted = code
|
||||
}
|
||||
|
||||
// Generate filename: sql_{schema}_{table}.go
|
||||
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
|
||||
filepath := filepath.Join(w.options.OutputPath, filename)
|
||||
|
||||
// Write file
|
||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||
// For each foreign key in this table, add a belongs-to relationship
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the referenced table
|
||||
refTable := w.findTable(constraint.ReferencedSchema, constraint.ReferencedTable, db)
|
||||
if refTable == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create relationship field (belongs-to)
|
||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
Name: fieldName,
|
||||
Type: "*" + refModelName, // Pointer type
|
||||
GormTag: relationTag,
|
||||
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
||||
Comment: fmt.Sprintf("Belongs to %s", refModelName),
|
||||
})
|
||||
}
|
||||
|
||||
// For each table that references this table, add a has-many relationship
|
||||
for _, otherSchema := range db.Schemas {
|
||||
for _, otherTable := range otherSchema.Tables {
|
||||
if otherTable.Name == table.Name && otherSchema.Name == schema.Name {
|
||||
continue // Skip self
|
||||
}
|
||||
|
||||
for _, constraint := range otherTable.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this constraint references our table
|
||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||
// Add has-many relationship
|
||||
otherModelName := w.getModelName(otherTable.Name)
|
||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
Name: fieldName,
|
||||
Type: "[]*" + otherModelName, // Slice of pointers
|
||||
GormTag: relationTag,
|
||||
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
||||
Comment: fmt.Sprintf("Has many %s", otherModelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findTable finds a table by schema and name in the database
|
||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||
for _, schema := range db.Schemas {
|
||||
if schema.Name != schemaName {
|
||||
continue
|
||||
}
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == tableName {
|
||||
return table
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getModelName generates the model name from a table name
|
||||
func (w *Writer) getModelName(tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
modelName := SnakeCaseToPascalCase(singular)
|
||||
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
// generateRelationshipFieldName generates a field name for a relationship
|
||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
||||
// Use just the prefix (3 letters) for relationship fields
|
||||
return GeneratePrefix(tableName)
|
||||
}
|
||||
|
||||
// getPackageName returns the package name from options or defaults to "models"
|
||||
func (w *Writer) getPackageName() string {
|
||||
if w.options.PackageName != "" {
|
||||
return w.options.PackageName
|
||||
}
|
||||
return "models"
|
||||
}
|
||||
|
||||
// formatCode formats Go code using gofmt
|
||||
func (w *Writer) formatCode(code string) (string, error) {
|
||||
formatted, err := format.Source([]byte(code))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("format error: %w", err)
|
||||
}
|
||||
return string(formatted), nil
|
||||
}
|
||||
|
||||
// writeOutput writes the content to file or stdout
|
||||
func (w *Writer) writeOutput(content string) error {
|
||||
if w.options.OutputPath != "" {
|
||||
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// Print to stdout
|
||||
fmt.Print(content)
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user