All checks were successful
CI / Test (1.24) (push) Successful in -27m27s
CI / Test (1.25) (push) Successful in -27m17s
CI / Lint (push) Successful in -27m27s
CI / Build (push) Successful in -27m38s
Release / Build and Release (push) Successful in -27m24s
Integration Tests / Integration Tests (push) Successful in -27m16s
* Update model name generation to include schema name. * Add gofmt execution after writing output files. * Refactor relationship field naming to include schema. * Update tests to reflect changes in model names and relationships.
494 lines
15 KiB
Go
494 lines
15 KiB
Go
package bun
|
|
|
|
import (
|
|
"fmt"
|
|
"go/format"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
// Writer implements the writers.Writer interface for Bun models
|
|
type Writer struct {
|
|
options *writers.WriterOptions
|
|
typeMapper *TypeMapper
|
|
templates *Templates
|
|
config *MethodConfig
|
|
}
|
|
|
|
// NewWriter creates a new Bun writer with the given options
|
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
|
w := &Writer{
|
|
options: options,
|
|
typeMapper: NewTypeMapper(),
|
|
config: LoadMethodConfigFromMetadata(options.Metadata),
|
|
}
|
|
|
|
// Initialize templates
|
|
tmpl, err := NewTemplates()
|
|
if err != nil {
|
|
// Should not happen with embedded templates
|
|
panic(fmt.Sprintf("failed to initialize templates: %v", err))
|
|
}
|
|
w.templates = tmpl
|
|
|
|
return w
|
|
}
|
|
|
|
// WriteDatabase writes a complete database as Bun models
|
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
|
// Check if multi-file mode is enabled
|
|
multiFile := w.shouldUseMultiFile()
|
|
|
|
if multiFile {
|
|
return w.writeMultiFile(db)
|
|
}
|
|
|
|
return w.writeSingleFile(db)
|
|
}
|
|
|
|
// WriteSchema writes a schema as Bun models
|
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|
// Create a temporary database with just this schema
|
|
db := models.InitDatabase(schema.Name)
|
|
db.Schemas = []*models.Schema{schema}
|
|
|
|
return w.WriteDatabase(db)
|
|
}
|
|
|
|
// WriteTable writes a single table as a Bun model
|
|
func (w *Writer) WriteTable(table *models.Table) error {
|
|
// Create a temporary schema and database
|
|
schema := models.InitSchema(table.Schema)
|
|
schema.Tables = []*models.Table{table}
|
|
|
|
db := models.InitDatabase(schema.Name)
|
|
db.Schemas = []*models.Schema{schema}
|
|
|
|
return w.WriteDatabase(db)
|
|
}
|
|
|
|
// writeSingleFile writes all models to a single file
|
|
func (w *Writer) writeSingleFile(db *models.Database) error {
|
|
packageName := w.getPackageName()
|
|
templateData := NewTemplateData(packageName, w.config)
|
|
|
|
// Add bun import (always needed)
|
|
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
|
|
|
// Add resolvespec_common import (always needed for nullable types)
|
|
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
|
|
|
// Collect all models
|
|
for _, schema := range db.Schemas {
|
|
for _, table := range schema.Tables {
|
|
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
|
|
|
// Add relationship fields
|
|
w.addRelationshipFields(modelData, table, schema, db)
|
|
|
|
templateData.AddModel(modelData)
|
|
|
|
// Check if we need time import
|
|
for _, field := range modelData.Fields {
|
|
if w.typeMapper.NeedsTimeImport(field.Type) {
|
|
templateData.AddImport("\"time\"")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add fmt import if GetIDStr is enabled
|
|
if w.config.GenerateGetIDStr {
|
|
templateData.AddImport("\"fmt\"")
|
|
}
|
|
|
|
// Finalize imports
|
|
templateData.FinalizeImports()
|
|
|
|
// Generate code
|
|
code, err := w.templates.GenerateCode(templateData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate code: %w", err)
|
|
}
|
|
|
|
// Format code
|
|
formatted, err := w.formatCode(code)
|
|
if err != nil {
|
|
// Return unformatted code with warning
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to format code: %v\n", err)
|
|
formatted = code
|
|
}
|
|
|
|
// Write output
|
|
if err := w.writeOutput(formatted); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Run go fmt on the output file
|
|
if w.options.OutputPath != "" {
|
|
w.runGoFmt(w.options.OutputPath)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeMultiFile writes each table to a separate file
|
|
func (w *Writer) writeMultiFile(db *models.Database) error {
|
|
packageName := w.getPackageName()
|
|
|
|
// Check if populate_refs is enabled
|
|
populateRefs := false
|
|
if w.options.Metadata != nil {
|
|
if pr, ok := w.options.Metadata["populate_refs"].(bool); ok {
|
|
populateRefs = pr
|
|
}
|
|
}
|
|
|
|
// Ensure output path is a directory
|
|
if w.options.OutputPath == "" {
|
|
return fmt.Errorf("output path is required for multi-file mode")
|
|
}
|
|
|
|
// Create output directory if it doesn't exist
|
|
if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
|
|
return fmt.Errorf("failed to create output directory: %w", err)
|
|
}
|
|
|
|
// Generate a file for each table
|
|
for _, schema := range db.Schemas {
|
|
// Populate RefDatabase for schema if enabled
|
|
if populateRefs && schema.RefDatabase == nil {
|
|
schema.RefDatabase = w.createDatabaseRef(db)
|
|
}
|
|
|
|
for _, table := range schema.Tables {
|
|
// Populate RefSchema for table if enabled
|
|
if populateRefs && table.RefSchema == nil {
|
|
table.RefSchema = w.createSchemaRef(schema, db)
|
|
}
|
|
// Create template data for this single table
|
|
templateData := NewTemplateData(packageName, w.config)
|
|
|
|
// Add bun import
|
|
templateData.AddImport(fmt.Sprintf("\"%s\"", w.typeMapper.GetBunImport()))
|
|
|
|
// Add resolvespec_common import
|
|
templateData.AddImport(fmt.Sprintf("resolvespec_common \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
|
|
|
// Create model data
|
|
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
|
|
|
// Add relationship fields
|
|
w.addRelationshipFields(modelData, table, schema, db)
|
|
|
|
templateData.AddModel(modelData)
|
|
|
|
// Check if we need time import
|
|
for _, field := range modelData.Fields {
|
|
if w.typeMapper.NeedsTimeImport(field.Type) {
|
|
templateData.AddImport("\"time\"")
|
|
}
|
|
}
|
|
|
|
// Add fmt import if GetIDStr is enabled
|
|
if w.config.GenerateGetIDStr {
|
|
templateData.AddImport("\"fmt\"")
|
|
}
|
|
|
|
// Finalize imports
|
|
templateData.FinalizeImports()
|
|
|
|
// Generate code
|
|
code, err := w.templates.GenerateCode(templateData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
|
|
}
|
|
|
|
// Format code
|
|
formatted, err := w.formatCode(code)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to format code for %s: %v\n", table.Name, err)
|
|
formatted = code
|
|
}
|
|
|
|
// Generate filename: sql_{schema}_{table}.go
|
|
// Sanitize schema and table names to remove quotes, comments, and invalid characters
|
|
safeSchemaName := writers.SanitizeFilename(schema.Name)
|
|
safeTableName := writers.SanitizeFilename(table.Name)
|
|
filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)
|
|
filepath := filepath.Join(w.options.OutputPath, filename)
|
|
|
|
// Write file
|
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
|
}
|
|
|
|
// Run go fmt on the generated file
|
|
w.runGoFmt(filepath)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
|
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
|
// Track used field names to detect duplicates
|
|
usedFieldNames := make(map[string]int)
|
|
|
|
// For each foreign key in this table, add a belongs-to/has-one relationship
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type != models.ForeignKeyConstraint {
|
|
continue
|
|
}
|
|
|
|
// Find the referenced table
|
|
refTable := w.findTable(constraint.ReferencedSchema, constraint.ReferencedTable, db)
|
|
if refTable == nil {
|
|
continue
|
|
}
|
|
|
|
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
|
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
|
fieldName := w.generateHasOneFieldName(constraint)
|
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
|
|
|
modelData.AddRelationshipField(&FieldData{
|
|
Name: fieldName,
|
|
Type: "*" + refModelName, // Pointer type
|
|
BunTag: relationTag,
|
|
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
|
Comment: fmt.Sprintf("Has one %s", refModelName),
|
|
})
|
|
}
|
|
|
|
// For each table that references this table, add a has-many relationship
|
|
for _, otherSchema := range db.Schemas {
|
|
for _, otherTable := range otherSchema.Tables {
|
|
if otherTable.Name == table.Name && otherSchema.Name == schema.Name {
|
|
continue // Skip self
|
|
}
|
|
|
|
for _, constraint := range otherTable.Constraints {
|
|
if constraint.Type != models.ForeignKeyConstraint {
|
|
continue
|
|
}
|
|
|
|
// Check if this constraint references our table
|
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
|
// Add has-many relationship
|
|
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
|
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
|
|
|
modelData.AddRelationshipField(&FieldData{
|
|
Name: fieldName,
|
|
Type: "[]*" + otherModelName, // Slice of pointers
|
|
BunTag: relationTag,
|
|
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
|
Comment: fmt.Sprintf("Has many %s", otherModelName),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// findTable finds a table by schema and name in the database
|
|
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
|
for _, schema := range db.Schemas {
|
|
if schema.Name != schemaName {
|
|
continue
|
|
}
|
|
for _, table := range schema.Tables {
|
|
if table.Name == tableName {
|
|
return table
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getModelName generates the model name from schema and table name
|
|
func (w *Writer) getModelName(schemaName, tableName string) string {
|
|
singular := Singularize(tableName)
|
|
tablePart := SnakeCaseToPascalCase(singular)
|
|
|
|
// Include schema name in model name
|
|
var modelName string
|
|
if schemaName != "" {
|
|
schemaPart := SnakeCaseToPascalCase(schemaName)
|
|
modelName = "Model" + schemaPart + tablePart
|
|
} else {
|
|
modelName = "Model" + tablePart
|
|
}
|
|
|
|
return modelName
|
|
}
|
|
|
|
// generateHasOneFieldName generates a field name for has-one relationships
|
|
// Uses the foreign key column name for uniqueness
|
|
func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
|
|
// Use the foreign key column name to ensure uniqueness
|
|
// If there are multiple columns, use the first one
|
|
if len(constraint.Columns) > 0 {
|
|
columnName := constraint.Columns[0]
|
|
// Convert to PascalCase for proper Go field naming
|
|
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
|
|
return "Rel" + SnakeCaseToPascalCase(columnName)
|
|
}
|
|
|
|
// Fallback to table-based prefix if no columns defined
|
|
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
|
|
}
|
|
|
|
// generateHasManyFieldName generates a field name for has-many relationships
|
|
// Uses the foreign key column name + source table name to avoid duplicates
|
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
|
// For has-many, we need to include the source table name to avoid duplicates
|
|
// e.g., multiple tables referencing the same column on this table
|
|
if len(constraint.Columns) > 0 {
|
|
columnName := constraint.Columns[0]
|
|
// Get the model name for the source table (pluralized)
|
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
|
// Remove "Model" prefix if present
|
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
|
|
|
// Convert column to PascalCase and combine with source table
|
|
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
|
|
columnPart := SnakeCaseToPascalCase(columnName)
|
|
return "Rel" + columnPart + Pluralize(sourceModelName)
|
|
}
|
|
|
|
// Fallback to table-based naming
|
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
|
return "Rel" + Pluralize(sourceModelName)
|
|
}
|
|
|
|
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
|
|
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
|
|
originalName := fieldName
|
|
count := usedNames[originalName]
|
|
|
|
if count > 0 {
|
|
// Name is already used, add numeric suffix
|
|
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
|
|
}
|
|
|
|
// Increment the counter for this base name
|
|
usedNames[originalName]++
|
|
|
|
return fieldName
|
|
}
|
|
|
|
// getPackageName returns the package name from options or defaults to "models"
|
|
func (w *Writer) getPackageName() string {
|
|
if w.options.PackageName != "" {
|
|
return w.options.PackageName
|
|
}
|
|
return "models"
|
|
}
|
|
|
|
// formatCode formats Go code using gofmt
|
|
func (w *Writer) formatCode(code string) (string, error) {
|
|
formatted, err := format.Source([]byte(code))
|
|
if err != nil {
|
|
return "", fmt.Errorf("format error: %w", err)
|
|
}
|
|
return string(formatted), nil
|
|
}
|
|
|
|
// writeOutput writes the content to file or stdout
|
|
func (w *Writer) writeOutput(content string) error {
|
|
if w.options.OutputPath != "" {
|
|
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
|
}
|
|
|
|
// Print to stdout
|
|
fmt.Print(content)
|
|
return nil
|
|
}
|
|
|
|
// runGoFmt runs go fmt on the specified file
|
|
func (w *Writer) runGoFmt(filepath string) {
|
|
cmd := exec.Command("gofmt", "-w", filepath)
|
|
if err := cmd.Run(); err != nil {
|
|
// Don't fail the whole operation if gofmt fails, just warn
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
|
}
|
|
}
|
|
|
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
|
func (w *Writer) shouldUseMultiFile() bool {
|
|
// Check if multi_file is explicitly set in metadata
|
|
if w.options.Metadata != nil {
|
|
if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
|
|
return mf
|
|
}
|
|
}
|
|
|
|
// Auto-detect based on output path
|
|
if w.options.OutputPath == "" {
|
|
// No output path means stdout (single file)
|
|
return false
|
|
}
|
|
|
|
// Check if path ends with .go (explicit file)
|
|
if strings.HasSuffix(w.options.OutputPath, ".go") {
|
|
return false
|
|
}
|
|
|
|
// Check if path ends with directory separator
|
|
if strings.HasSuffix(w.options.OutputPath, "/") || strings.HasSuffix(w.options.OutputPath, "\\") {
|
|
return true
|
|
}
|
|
|
|
// Check if path exists and is a directory
|
|
info, err := os.Stat(w.options.OutputPath)
|
|
if err == nil && info.IsDir() {
|
|
return true
|
|
}
|
|
|
|
// Default to single file for ambiguous cases
|
|
return false
|
|
}
|
|
|
|
// createDatabaseRef creates a shallow copy of database without schemas to avoid circular references
|
|
func (w *Writer) createDatabaseRef(db *models.Database) *models.Database {
|
|
return &models.Database{
|
|
Name: db.Name,
|
|
Description: db.Description,
|
|
Comment: db.Comment,
|
|
DatabaseType: db.DatabaseType,
|
|
DatabaseVersion: db.DatabaseVersion,
|
|
SourceFormat: db.SourceFormat,
|
|
Schemas: nil, // Don't include schemas to avoid circular reference
|
|
GUID: db.GUID,
|
|
}
|
|
}
|
|
|
|
// createSchemaRef creates a shallow copy of schema without tables to avoid circular references
|
|
func (w *Writer) createSchemaRef(schema *models.Schema, db *models.Database) *models.Schema {
|
|
return &models.Schema{
|
|
Name: schema.Name,
|
|
Description: schema.Description,
|
|
Owner: schema.Owner,
|
|
Permissions: schema.Permissions,
|
|
Comment: schema.Comment,
|
|
Metadata: schema.Metadata,
|
|
Scripts: schema.Scripts,
|
|
Sequence: schema.Sequence,
|
|
RefDatabase: w.createDatabaseRef(db), // Include database ref
|
|
Tables: nil, // Don't include tables to avoid circular reference
|
|
GUID: schema.GUID,
|
|
}
|
|
}
|