Files
relspecgo/pkg/writers/bun/writer.go
Hein 6f55505444
All checks were successful
CI / Test (1.24) (push) Successful in -27m27s
CI / Test (1.25) (push) Successful in -27m17s
CI / Lint (push) Successful in -27m27s
CI / Build (push) Successful in -27m38s
Release / Build and Release (push) Successful in -27m24s
Integration Tests / Integration Tests (push) Successful in -27m16s
feat(writer): 🎉 Enhance model name generation and formatting
* Update model name generation to include schema name.
* Add gofmt execution after writing output files.
* Refactor relationship field naming to include schema.
* Update tests to reflect changes in model names and relationships.
2026-01-10 18:28:41 +02:00

494 lines
15 KiB
Go

package bun
import (
"fmt"
"go/format"
"os"
"os/exec"
"path/filepath"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for Bun models.
// It renders models.Database / Schema / Table values into Go source code,
// either as one combined file or as one file per table.
type Writer struct {
// options holds the caller-supplied settings read by this writer:
// OutputPath, PackageName and the free-form Metadata map.
options *writers.WriterOptions
// typeMapper supplies import paths, decides whether a field type needs the
// "time" import, and builds relationship tags.
typeMapper *TypeMapper
// templates renders the collected template data into Go source text.
templates *Templates
// config is loaded from options.Metadata and controls optional generated
// methods (for example GenerateGetIDStr).
config *MethodConfig
}
// NewWriter builds a Bun writer from the given options.
//
// The method configuration is loaded from options.Metadata, so options must
// be non-nil. Template initialization failure panics: the original author
// notes it should not happen with embedded templates.
func NewWriter(options *writers.WriterOptions) *Writer {
	mapper := NewTypeMapper()
	cfg := LoadMethodConfigFromMetadata(options.Metadata)

	tmpl, err := NewTemplates()
	if err != nil {
		// Should not happen with embedded templates.
		panic(fmt.Sprintf("failed to initialize templates: %v", err))
	}

	return &Writer{
		options:    options,
		typeMapper: mapper,
		templates:  tmpl,
		config:     cfg,
	}
}
// WriteDatabase renders a complete database as Bun models, dispatching to
// multi-file or single-file output depending on shouldUseMultiFile.
func (w *Writer) WriteDatabase(db *models.Database) error {
	if w.shouldUseMultiFile() {
		return w.writeMultiFile(db)
	}
	return w.writeSingleFile(db)
}
// WriteSchema renders a single schema as Bun models by wrapping it in a
// temporary database (named after the schema) and delegating to WriteDatabase.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	wrapper := models.InitDatabase(schema.Name)
	wrapper.Schemas = []*models.Schema{schema}

	return w.WriteDatabase(wrapper)
}
// WriteTable renders a single table as a Bun model. The table is wrapped in a
// temporary schema (created from table.Schema) and a temporary database named
// after that schema, then handed to WriteDatabase.
func (w *Writer) WriteTable(table *models.Table) error {
	holder := models.InitSchema(table.Schema)
	holder.Tables = []*models.Table{table}

	wrapper := models.InitDatabase(holder.Name)
	wrapper.Schemas = []*models.Schema{holder}

	return w.WriteDatabase(wrapper)
}
// writeSingleFile renders every table of every schema in db into one Go
// source text and hands it to writeOutput (a file when OutputPath is set,
// stdout otherwise).
//
// A failure of the in-process formatter is only reported as a warning on
// stderr; the unformatted code is written instead. When an output path is
// set, gofmt is additionally run on the written file.
func (w *Writer) writeSingleFile(db *models.Database) error {
	data := NewTemplateData(w.getPackageName(), w.config)

	// The bun import and the resolvespec_common import (nullable types) are
	// always added.
	data.AddImport(fmt.Sprintf(`"%s"`, w.typeMapper.GetBunImport()))
	data.AddImport(fmt.Sprintf(`resolvespec_common "%s"`, w.typeMapper.GetSQLTypesImport()))

	for _, schema := range db.Schemas {
		for _, table := range schema.Tables {
			model := NewModelData(table, schema.Name, w.typeMapper)
			w.addRelationshipFields(model, table, schema, db)
			data.AddModel(model)

			// Request the time import for each field whose type needs it.
			for _, field := range model.Fields {
				if !w.typeMapper.NeedsTimeImport(field.Type) {
					continue
				}
				data.AddImport(`"time"`)
			}
		}
	}

	// fmt is imported only when GetIDStr generation is enabled.
	if w.config.GenerateGetIDStr {
		data.AddImport(`"fmt"`)
	}
	data.FinalizeImports()

	code, err := w.templates.GenerateCode(data)
	if err != nil {
		return fmt.Errorf("failed to generate code: %w", err)
	}

	output, fmtErr := w.formatCode(code)
	if fmtErr != nil {
		// Fall back to the raw template output rather than failing.
		fmt.Fprintf(os.Stderr, "Warning: failed to format code: %v\n", fmtErr)
		output = code
	}

	if err := w.writeOutput(output); err != nil {
		return err
	}

	// Nothing to post-process when the code went to stdout.
	if w.options.OutputPath == "" {
		return nil
	}
	w.runGoFmt(w.options.OutputPath)
	return nil
}
// writeMultiFile writes one Go source file per table into the directory named
// by w.options.OutputPath, creating that directory if it does not exist.
//
// Files are named sql_{schema}_{table}.go, with both name parts passed through
// writers.SanitizeFilename. When the "populate_refs" metadata key holds a true
// bool, nil RefDatabase / RefSchema back-references on the schemas and tables
// of db are filled in before generation; note that this mutates the input.
//
// It returns an error when the output path is empty, the directory cannot be
// created, code generation fails, or a file cannot be written. A failure of
// the in-process formatter is only reported as a warning on stderr and the
// unformatted code is written instead.
func (w *Writer) writeMultiFile(db *models.Database) error {
	packageName := w.getPackageName()

	// populate_refs is opt-in and only honoured when it is a bool.
	populateRefs := false
	if w.options.Metadata != nil {
		if pr, ok := w.options.Metadata["populate_refs"].(bool); ok {
			populateRefs = pr
		}
	}

	// Multi-file mode needs a directory to write into.
	if w.options.OutputPath == "" {
		return fmt.Errorf("output path is required for multi-file mode")
	}
	if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
		return fmt.Errorf("failed to create output directory: %w", err)
	}

	for _, schema := range db.Schemas {
		if populateRefs && schema.RefDatabase == nil {
			schema.RefDatabase = w.createDatabaseRef(db)
		}

		for _, table := range schema.Tables {
			if populateRefs && table.RefSchema == nil {
				table.RefSchema = w.createSchemaRef(schema, db)
			}

			// Each table gets its own template data, and therefore its own
			// import set.
			templateData := NewTemplateData(packageName, w.config)
			templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
			templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))

			modelData := NewModelData(table, schema.Name, w.typeMapper)
			w.addRelationshipFields(modelData, table, schema, db)
			templateData.AddModel(modelData)

			// Request the time import for each field whose type needs it.
			for _, field := range modelData.Fields {
				if w.typeMapper.NeedsTimeImport(field.Type) {
					templateData.AddImport("\"time\"")
				}
			}

			// fmt is imported only when GetIDStr generation is enabled.
			if w.config.GenerateGetIDStr {
				templateData.AddImport("\"fmt\"")
			}
			templateData.FinalizeImports()

			code, err := w.templates.GenerateCode(templateData)
			if err != nil {
				return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
			}

			formatted, err := w.formatCode(code)
			if err != nil {
				// Fall back to the raw template output rather than failing.
				fmt.Fprintf(os.Stderr, "Warning: failed to format code for %s: %v\n", table.Name, err)
				formatted = code
			}

			// Generate filename: sql_{schema}_{table}.go, with sanitized
			// schema and table names.
			safeSchemaName := writers.SanitizeFilename(schema.Name)
			safeTableName := writers.SanitizeFilename(table.Name)
			filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)

			// Named outPath (not filepath) so the path/filepath package is not
			// shadowed for the remainder of the loop body.
			outPath := filepath.Join(w.options.OutputPath, filename)
			if err := os.WriteFile(outPath, []byte(formatted), 0644); err != nil {
				return fmt.Errorf("failed to write file %s: %w", filename, err)
			}

			// Best-effort gofmt pass on the file just written.
			w.runGoFmt(outPath)
		}
	}
	return nil
}
// addRelationshipFields appends relationship fields to modelData, derived
// from foreign-key constraints:
//
//   - for every foreign key declared on table whose referenced table can be
//     found in db, a pointer field tagged "has-one";
//   - for every foreign key on any other table in db that references table, a
//     slice-of-pointers field tagged "has-many".
//
// Foreign keys declared on table itself are not considered for the has-many
// pass. Field names are de-duplicated through ensureUniqueFieldName using one
// name registry shared by both passes.
//
// NOTE(review): numeric suffixes depend on the iteration order of the
// Constraints collections; if those are maps the suffix assignment may vary
// between runs - confirm against the models package.
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
	seen := make(map[string]int)

	// Pass 1: foreign keys owned by this table (has-one in Bun, similar to
	// belongs-to in GORM).
	for _, fk := range table.Constraints {
		if fk.Type != models.ForeignKeyConstraint {
			continue
		}
		if w.findTable(fk.ReferencedSchema, fk.ReferencedTable, db) == nil {
			// Referenced table is not part of this database; skip it.
			continue
		}

		target := w.getModelName(fk.ReferencedSchema, fk.ReferencedTable)
		name := w.ensureUniqueFieldName(w.generateHasOneFieldName(fk), seen)
		tag := w.typeMapper.BuildRelationshipTag(fk, "has-one")

		modelData.AddRelationshipField(&FieldData{
			Name:    name,
			Type:    "*" + target, // pointer to the referenced model
			BunTag:  tag,
			JSONTag: strings.ToLower(name) + ",omitempty",
			Comment: fmt.Sprintf("Has one %s", target),
		})
	}

	// Pass 2: foreign keys on other tables that point at this table.
	for _, srcSchema := range db.Schemas {
		for _, srcTable := range srcSchema.Tables {
			if srcTable.Name == table.Name && srcSchema.Name == schema.Name {
				continue // skip self
			}

			for _, fk := range srcTable.Constraints {
				if fk.Type != models.ForeignKeyConstraint {
					continue
				}
				if fk.ReferencedTable != table.Name || fk.ReferencedSchema != schema.Name {
					continue
				}

				source := w.getModelName(srcSchema.Name, srcTable.Name)
				name := w.ensureUniqueFieldName(w.generateHasManyFieldName(fk, srcSchema.Name, srcTable.Name), seen)
				tag := w.typeMapper.BuildRelationshipTag(fk, "has-many")

				modelData.AddRelationshipField(&FieldData{
					Name:    name,
					Type:    "[]*" + source, // slice of pointers to the referencing model
					BunTag:  tag,
					JSONTag: strings.ToLower(name) + ",omitempty",
					Comment: fmt.Sprintf("Has many %s", source),
				})
			}
		}
	}
}
// findTable looks up a table by exact schema name and table name in db.
// It returns the first match, or nil when no such table exists.
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
	for _, s := range db.Schemas {
		if s.Name == schemaName {
			for _, t := range s.Tables {
				if t.Name == tableName {
					return t
				}
			}
		}
	}
	return nil
}
// getModelName derives the Go model type name for a table:
// "Model" + PascalCase(schema) + PascalCase(singular table name).
// The schema part is omitted when schemaName is empty.
func (w *Writer) getModelName(schemaName, tableName string) string {
	name := "Model"
	if schemaName != "" {
		name += SnakeCaseToPascalCase(schemaName)
	}
	return name + SnakeCaseToPascalCase(Singularize(tableName))
}
// generateHasOneFieldName builds the field name for a has-one relationship.
//
// The name is "Rel" followed by the PascalCase form of the constraint's first
// column (e.g. "rid_filepointer_request" -> "RelRIDFilepointerRequest"), which
// keeps names distinct per foreign-key column. Constraints without columns
// fall back to "Rel" + GeneratePrefix(referenced table).
func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
	if len(constraint.Columns) == 0 {
		// No columns defined: fall back to a table-based prefix.
		return "Rel" + GeneratePrefix(constraint.ReferencedTable)
	}
	// Only the first column is used, even for multi-column keys.
	return "Rel" + SnakeCaseToPascalCase(constraint.Columns[0])
}
// generateHasManyFieldName builds the field name for a has-many relationship.
//
// The name is "Rel" + PascalCase(first foreign-key column) + the pluralized
// source model name (the model name of the referencing table with its "Model"
// prefix removed). Including the source table avoids clashes when several
// tables reference this one through identically named columns. When the
// constraint has no columns the column part is left out.
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
	// Both naming variants end in the pluralized, prefix-less source model name.
	source := strings.TrimPrefix(w.getModelName(sourceSchemaName, sourceTableName), "Model")
	suffix := Pluralize(source)

	if len(constraint.Columns) == 0 {
		// Fallback: table-based naming only.
		return "Rel" + suffix
	}
	// Only the first column is used, even for multi-column keys.
	return "Rel" + SnakeCaseToPascalCase(constraint.Columns[0]) + suffix
}
// ensureUniqueFieldName returns a field name that has not been handed out
// before for the given usedNames registry, and records it as used.
//
// The first request for a name returns it unchanged. Repeated requests get a
// numeric suffix starting at 2 (base, base2, base3, ...). Every returned name,
// including suffixed ones, is reserved in usedNames, so a later base name that
// happens to equal an earlier generated name (e.g. "RelFoo2") is not returned
// a second time; suffix candidates that are already taken are skipped.
//
// usedNames must be a non-nil map; it is mutated.
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
	count := usedNames[fieldName]
	if count == 0 {
		usedNames[fieldName] = 1
		return fieldName
	}

	for {
		count++
		candidate := fmt.Sprintf("%s%d", fieldName, count)
		if usedNames[candidate] != 0 {
			// Already taken, either as a base name or as a generated one.
			continue
		}
		// Remember the highest suffix tried for this base so the next
		// duplicate resumes from there, and reserve the generated name.
		usedNames[fieldName] = count
		usedNames[candidate] = 1
		return candidate
	}
}
// getPackageName returns the package name configured in the writer options,
// falling back to "models" when none is set.
func (w *Writer) getPackageName() string {
	if name := w.options.PackageName; name != "" {
		return name
	}
	return "models"
}
// formatCode formats Go source text in-process using go/format.
// On failure it returns an empty string and the underlying error wrapped
// with a "format error" prefix.
func (w *Writer) formatCode(code string) (string, error) {
	out, err := format.Source([]byte(code))
	if err == nil {
		return string(out), nil
	}
	return "", fmt.Errorf("format error: %w", err)
}
// writeOutput emits content to the configured output file (mode 0644), or
// prints it to stdout when no output path is set.
func (w *Writer) writeOutput(content string) error {
	if target := w.options.OutputPath; target != "" {
		return os.WriteFile(target, []byte(content), 0644)
	}

	// No output path: print to stdout.
	fmt.Print(content)
	return nil
}
// runGoFmt runs the external gofmt tool in write mode ("gofmt -w") on the
// given file.
//
// This is best-effort: a failure never fails the caller and is only reported
// as a warning on stderr. Whatever gofmt itself printed is included in the
// warning so the cause is visible rather than just the exit status.
func (w *Writer) runGoFmt(target string) {
	out, err := exec.Command("gofmt", "-w", target).CombinedOutput()
	if err == nil {
		return
	}

	// Don't fail the whole operation if gofmt fails, just warn.
	fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", target, err)
	if detail := strings.TrimSpace(string(out)); detail != "" {
		fmt.Fprintf(os.Stderr, "%s\n", detail)
	}
}
// shouldUseMultiFile decides between multi-file and single-file output.
//
// An explicit bool under the "multi_file" metadata key always wins. Otherwise
// the output path is inspected, in this order:
//   - empty path (stdout)                 -> single file
//   - path ending in ".go"                -> single file
//   - path ending in "/" or "\"           -> multi file
//   - path that exists and is a directory -> multi file
//   - anything else                       -> single file
func (w *Writer) shouldUseMultiFile() bool {
	if meta := w.options.Metadata; meta != nil {
		if explicit, ok := meta["multi_file"].(bool); ok {
			return explicit
		}
	}

	out := w.options.OutputPath
	switch {
	case out == "":
		// No output path means stdout, which is a single stream.
		return false
	case strings.HasSuffix(out, ".go"):
		// Explicit Go file.
		return false
	case strings.HasSuffix(out, "/"), strings.HasSuffix(out, "\\"):
		// Trailing separator marks a directory.
		return true
	}

	// Ambiguous path: multi-file only if it is an existing directory.
	info, err := os.Stat(out)
	return err == nil && info.IsDir()
}
// createDatabaseRef returns a new models.Database carrying db's descriptive
// fields (name, description, comment, type, version, source format, GUID) but
// no schemas, so it can be used as a back-reference without creating a
// circular structure.
func (w *Writer) createDatabaseRef(db *models.Database) *models.Database {
	ref := new(models.Database)

	ref.Name = db.Name
	ref.Description = db.Description
	ref.Comment = db.Comment
	ref.DatabaseType = db.DatabaseType
	ref.DatabaseVersion = db.DatabaseVersion
	ref.SourceFormat = db.SourceFormat
	ref.GUID = db.GUID
	// Schemas is deliberately left at its zero value (nil) to avoid a
	// circular reference.

	return ref
}
// createSchemaRef returns a new models.Schema carrying schema's descriptive
// fields but no tables, so it can be used as a back-reference without
// creating a circular structure. Field values are copied by plain assignment.
// RefDatabase is set to a fresh table-less, schema-less reference built from
// db via createDatabaseRef.
func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *models.Schema {
	ref := new(models.Schema)

	ref.Name = schema.Name
	ref.Description = schema.Description
	ref.Owner = schema.Owner
	ref.Permissions = schema.Permissions
	ref.Comment = schema.Comment
	ref.Metadata = schema.Metadata
	ref.Scripts = schema.Scripts
	ref.Sequence = schema.Sequence
	ref.RefDatabase = w.createDatabaseRef(db) // include database ref
	ref.GUID = schema.GUID
	// Tables is deliberately left at its zero value (nil) to avoid a
	// circular reference.

	return ref
}