407 lines
12 KiB
Go
407 lines
12 KiB
Go
package bun
|
|
|
|
import (
|
|
"fmt"
|
|
"go/format"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
// Writer implements the writers.Writer interface for Bun models
|
|
type Writer struct {
|
|
options *writers.WriterOptions
|
|
typeMapper *TypeMapper
|
|
templates *Templates
|
|
config *MethodConfig
|
|
}
|
|
|
|
// NewWriter creates a new Bun writer with the given options
|
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
|
w := &Writer{
|
|
options: options,
|
|
typeMapper: NewTypeMapper(),
|
|
config: LoadMethodConfigFromMetadata(options.Metadata),
|
|
}
|
|
|
|
// Initialize templates
|
|
tmpl, err := NewTemplates()
|
|
if err != nil {
|
|
// Should not happen with embedded templates
|
|
panic(fmt.Sprintf("failed to initialize templates: %v", err))
|
|
}
|
|
w.templates = tmpl
|
|
|
|
return w
|
|
}
|
|
|
|
// WriteDatabase writes a complete database as Bun models
|
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
|
// Check if multi-file mode is enabled
|
|
multiFile := w.shouldUseMultiFile()
|
|
|
|
if multiFile {
|
|
return w.writeMultiFile(db)
|
|
}
|
|
|
|
return w.writeSingleFile(db)
|
|
}
|
|
|
|
// WriteSchema writes a schema as Bun models
|
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|
// Create a temporary database with just this schema
|
|
db := models.InitDatabase(schema.Name)
|
|
db.Schemas = []*models.Schema{schema}
|
|
|
|
return w.WriteDatabase(db)
|
|
}
|
|
|
|
// WriteTable writes a single table as a Bun model
|
|
func (w *Writer) WriteTable(table *models.Table) error {
|
|
// Create a temporary schema and database
|
|
schema := models.InitSchema(table.Schema)
|
|
schema.Tables = []*models.Table{table}
|
|
|
|
db := models.InitDatabase(schema.Name)
|
|
db.Schemas = []*models.Schema{schema}
|
|
|
|
return w.WriteDatabase(db)
|
|
}
|
|
|
|
// writeSingleFile writes all models to a single file
|
|
func (w *Writer) writeSingleFile(db *models.Database) error {
|
|
packageName := w.getPackageName()
|
|
templateData := NewTemplateData(packageName, w.config)
|
|
|
|
// Add bun import (always needed)
|
|
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
|
|
|
// Add resolvespec_common import (always needed for nullable types)
|
|
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
|
|
|
// Collect all models
|
|
for _, schema := range db.Schemas {
|
|
for _, table := range schema.Tables {
|
|
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
|
|
|
// Add relationship fields
|
|
w.addRelationshipFields(modelData, table, schema, db)
|
|
|
|
templateData.AddModel(modelData)
|
|
|
|
// Check if we need time import
|
|
for _, field := range modelData.Fields {
|
|
if w.typeMapper.NeedsTimeImport(field.Type) {
|
|
templateData.AddImport("\"time\"")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add fmt import if GetIDStr is enabled
|
|
if w.config.GenerateGetIDStr {
|
|
templateData.AddImport("\"fmt\"")
|
|
}
|
|
|
|
// Finalize imports
|
|
templateData.FinalizeImports()
|
|
|
|
// Generate code
|
|
code, err := w.templates.GenerateCode(templateData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate code: %w", err)
|
|
}
|
|
|
|
// Format code
|
|
formatted, err := w.formatCode(code)
|
|
if err != nil {
|
|
// Return unformatted code with warning
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to format code: %v\n", err)
|
|
formatted = code
|
|
}
|
|
|
|
// Write output
|
|
return w.writeOutput(formatted)
|
|
}
|
|
|
|
// writeMultiFile writes each table to a separate file
|
|
func (w *Writer) writeMultiFile(db *models.Database) error {
|
|
packageName := w.getPackageName()
|
|
|
|
// Check if populate_refs is enabled
|
|
populateRefs := false
|
|
if w.options.Metadata != nil {
|
|
if pr, ok := w.options.Metadata["populate_refs"].(bool); ok {
|
|
populateRefs = pr
|
|
}
|
|
}
|
|
|
|
// Ensure output path is a directory
|
|
if w.options.OutputPath == "" {
|
|
return fmt.Errorf("output path is required for multi-file mode")
|
|
}
|
|
|
|
// Create output directory if it doesn't exist
|
|
if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
|
|
return fmt.Errorf("failed to create output directory: %w", err)
|
|
}
|
|
|
|
// Generate a file for each table
|
|
for _, schema := range db.Schemas {
|
|
// Populate RefDatabase for schema if enabled
|
|
if populateRefs && schema.RefDatabase == nil {
|
|
schema.RefDatabase = w.createDatabaseRef(db)
|
|
}
|
|
|
|
for _, table := range schema.Tables {
|
|
// Populate RefSchema for table if enabled
|
|
if populateRefs && table.RefSchema == nil {
|
|
table.RefSchema = w.createSchemaRef(schema, db)
|
|
}
|
|
// Create template data for this single table
|
|
templateData := NewTemplateData(packageName, w.config)
|
|
|
|
// Add bun import
|
|
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
|
|
|
// Add resolvespec_common import
|
|
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
|
|
|
// Create model data
|
|
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
|
|
|
// Add relationship fields
|
|
w.addRelationshipFields(modelData, table, schema, db)
|
|
|
|
templateData.AddModel(modelData)
|
|
|
|
// Check if we need time import
|
|
for _, field := range modelData.Fields {
|
|
if w.typeMapper.NeedsTimeImport(field.Type) {
|
|
templateData.AddImport("\"time\"")
|
|
}
|
|
}
|
|
|
|
// Add fmt import if GetIDStr is enabled
|
|
if w.config.GenerateGetIDStr {
|
|
templateData.AddImport("\"fmt\"")
|
|
}
|
|
|
|
// Finalize imports
|
|
templateData.FinalizeImports()
|
|
|
|
// Generate code
|
|
code, err := w.templates.GenerateCode(templateData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
|
|
}
|
|
|
|
// Format code
|
|
formatted, err := w.formatCode(code)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to format code for %s: %v\n", table.Name, err)
|
|
formatted = code
|
|
}
|
|
|
|
// Generate filename: sql_{schema}_{table}.go
|
|
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
|
|
filepath := filepath.Join(w.options.OutputPath, filename)
|
|
|
|
// Write file
|
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
|
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
|
// For each foreign key in this table, add a belongs-to/has-one relationship
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type != models.ForeignKeyConstraint {
|
|
continue
|
|
}
|
|
|
|
// Find the referenced table
|
|
refTable := w.findTable(constraint.ReferencedSchema, constraint.ReferencedTable, db)
|
|
if refTable == nil {
|
|
continue
|
|
}
|
|
|
|
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
|
refModelName := w.getModelName(constraint.ReferencedTable)
|
|
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
|
|
|
modelData.AddRelationshipField(&FieldData{
|
|
Name: fieldName,
|
|
Type: "*" + refModelName, // Pointer type
|
|
BunTag: relationTag,
|
|
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
|
Comment: fmt.Sprintf("Has one %s", refModelName),
|
|
})
|
|
}
|
|
|
|
// For each table that references this table, add a has-many relationship
|
|
for _, otherSchema := range db.Schemas {
|
|
for _, otherTable := range otherSchema.Tables {
|
|
if otherTable.Name == table.Name && otherSchema.Name == schema.Name {
|
|
continue // Skip self
|
|
}
|
|
|
|
for _, constraint := range otherTable.Constraints {
|
|
if constraint.Type != models.ForeignKeyConstraint {
|
|
continue
|
|
}
|
|
|
|
// Check if this constraint references our table
|
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
|
// Add has-many relationship
|
|
otherModelName := w.getModelName(otherTable.Name)
|
|
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
|
|
|
modelData.AddRelationshipField(&FieldData{
|
|
Name: fieldName,
|
|
Type: "[]*" + otherModelName, // Slice of pointers
|
|
BunTag: relationTag,
|
|
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
|
Comment: fmt.Sprintf("Has many %s", otherModelName),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// findTable finds a table by schema and name in the database
|
|
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
|
for _, schema := range db.Schemas {
|
|
if schema.Name != schemaName {
|
|
continue
|
|
}
|
|
for _, table := range schema.Tables {
|
|
if table.Name == tableName {
|
|
return table
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getModelName generates the model name from a table name
|
|
func (w *Writer) getModelName(tableName string) string {
|
|
singular := Singularize(tableName)
|
|
modelName := SnakeCaseToPascalCase(singular)
|
|
|
|
if !hasModelPrefix(modelName) {
|
|
modelName = "Model" + modelName
|
|
}
|
|
|
|
return modelName
|
|
}
|
|
|
|
// generateRelationshipFieldName generates a field name for a relationship
|
|
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
|
// Use just the prefix (3 letters) for relationship fields
|
|
return GeneratePrefix(tableName)
|
|
}
|
|
|
|
// getPackageName returns the package name from options or defaults to "models"
|
|
func (w *Writer) getPackageName() string {
|
|
if w.options.PackageName != "" {
|
|
return w.options.PackageName
|
|
}
|
|
return "models"
|
|
}
|
|
|
|
// formatCode formats Go code using gofmt
|
|
func (w *Writer) formatCode(code string) (string, error) {
|
|
formatted, err := format.Source([]byte(code))
|
|
if err != nil {
|
|
return "", fmt.Errorf("format error: %w", err)
|
|
}
|
|
return string(formatted), nil
|
|
}
|
|
|
|
// writeOutput writes the content to file or stdout
|
|
func (w *Writer) writeOutput(content string) error {
|
|
if w.options.OutputPath != "" {
|
|
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
|
}
|
|
|
|
// Print to stdout
|
|
fmt.Print(content)
|
|
return nil
|
|
}
|
|
|
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
|
func (w *Writer) shouldUseMultiFile() bool {
|
|
// Check if multi_file is explicitly set in metadata
|
|
if w.options.Metadata != nil {
|
|
if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
|
|
return mf
|
|
}
|
|
}
|
|
|
|
// Auto-detect based on output path
|
|
if w.options.OutputPath == "" {
|
|
// No output path means stdout (single file)
|
|
return false
|
|
}
|
|
|
|
// Check if path ends with .go (explicit file)
|
|
if strings.HasSuffix(w.options.OutputPath, ".go") {
|
|
return false
|
|
}
|
|
|
|
// Check if path ends with directory separator
|
|
if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") {
|
|
return true
|
|
}
|
|
|
|
// Check if path exists and is a directory
|
|
info, err := os.Stat(w.options.OutputPath)
|
|
if err == nil && info.IsDir() {
|
|
return true
|
|
}
|
|
|
|
// Default to single file for ambiguous cases
|
|
return false
|
|
}
|
|
|
|
// createDatabaseRef creates a shallow copy of database without schemas to avoid circular references
|
|
func (w *Writer) createDatabaseRef(db *models.Database) *models.Database {
|
|
return &models.Database{
|
|
Name: db.Name,
|
|
Description: db.Description,
|
|
Comment: db.Comment,
|
|
DatabaseType: db.DatabaseType,
|
|
DatabaseVersion: db.DatabaseVersion,
|
|
SourceFormat: db.SourceFormat,
|
|
Schemas: nil, // Don't include schemas to avoid circular reference
|
|
}
|
|
}
|
|
|
|
// createSchemaRef creates a shallow copy of schema without tables to avoid circular references
|
|
func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *models.Schema {
|
|
return &models.Schema{
|
|
Name: schema.Name,
|
|
Description: schema.Description,
|
|
Owner: schema.Owner,
|
|
Permissions: schema.Permissions,
|
|
Comment: schema.Comment,
|
|
Metadata: schema.Metadata,
|
|
Scripts: schema.Scripts,
|
|
Sequence: schema.Sequence,
|
|
RefDatabase: w.createDatabaseRef(db), // Include database ref
|
|
Tables: nil, // Don't include tables to avoid circular reference
|
|
}
|
|
}
|