Files
relspecgo/pkg/writers/pgsql/migration_writer.go
Hein 5e1448dcdb
Some checks are pending
CI / Test (1.23) (push) Waiting to run
CI / Test (1.24) (push) Waiting to run
CI / Test (1.25) (push) Waiting to run
CI / Lint (push) Waiting to run
CI / Build (push) Waiting to run
sql writer
2025-12-17 20:44:02 +02:00

839 lines
23 KiB
Go

package pgsql
import (
"fmt"
"io"
"os"
"sort"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// MigrationScript represents a single migration script with priority.
// WriteMigration orders scripts by Priority and then by Sequence before
// writing each Body to the output.
type MigrationScript struct {
// ObjectName is the dotted name of the affected object, for example
// "schema.table" or "schema.table.constraint".
ObjectName string
// ObjectType is a short label for the operation, such as "create table"
// or "drop index"; it is printed in the comment line above each script.
ObjectType string
// Schema is the name of the schema the script applies to.
Schema string
// Priority is the primary sort key; lower values are written first.
Priority int
// Sequence is the secondary sort key used when priorities are equal.
Sequence int
// Body is the SQL text produced by the template executor.
Body string
}
// MigrationWriter generates differential migration SQL scripts using templates.
type MigrationWriter struct {
// options supplies the output path and metadata read by WriteMigration.
options *writers.WriterOptions
// writer is the output destination; when it is already set, WriteMigration
// uses it instead of opening a file or falling back to stdout (used for testing).
writer io.Writer
// executor renders the SQL for each migration step from templates.
executor *TemplateExecutor
}
// NewMigrationWriter builds a MigrationWriter backed by a freshly created
// template executor. The supplied options are stored as-is; the output writer
// is left unset. An error is returned when the template executor cannot be
// created.
func NewMigrationWriter(options *writers.WriterOptions) (*MigrationWriter, error) {
	tmplExec, err := NewTemplateExecutor()
	if err != nil {
		return nil, fmt.Errorf("failed to create template executor: %w", err)
	}

	mw := new(MigrationWriter)
	mw.options = options
	mw.executor = tmplExec
	return mw, nil
}
// WriteMigration generates the differential migration that transforms the
// current database into model and writes it out as a single SQL script.
//
// Destination: a writer already set on w is used as-is (this is how tests
// capture output); otherwise the file named by options.OutputPath is created,
// and when no path is configured the script goes to standard output. The
// destination is resolved per call and is not stored on w, so a later call
// never reuses a file handle that has already been closed.
//
// All scripts are generated before the destination is opened, so a generation
// error does not truncate an existing output file. Scripts are ordered by
// Priority, then Sequence; scripts that tie keep the order in which they were
// generated.
//
// A nil current database is treated as empty (no schema matches, so every
// object in model is created). An error is returned when model is nil, when a
// generator fails, or when the output cannot be created, written or closed.
func (w *MigrationWriter) WriteMigration(model *models.Database, current *models.Database) (err error) {
	if model == nil {
		return fmt.Errorf("model database is nil")
	}

	// Read everything needed from options once; nil options mean "no output
	// path, no audit configuration".
	var outputPath string
	var auditConfig *AuditConfig
	if w.options != nil {
		outputPath = w.options.OutputPath
		if ac, ok := w.options.Metadata["audit_config"].(*AuditConfig); ok {
			auditConfig = ac
		}
	}

	scripts := make([]MigrationScript, 0)

	// Audit tables are created once for the whole migration (priority 90).
	if auditConfig != nil && len(auditConfig.EnabledTables) > 0 {
		auditTableScripts, genErr := w.generateAuditTablesScript(auditConfig)
		if genErr != nil {
			return fmt.Errorf("failed to generate audit tables: %w", genErr)
		}
		scripts = append(scripts, auditTableScripts...)
	}

	for _, modelSchema := range model.Schemas {
		// Match the current schema by case-insensitive name; it stays nil when
		// the schema does not exist yet.
		var currentSchema *models.Schema
		if current != nil {
			for _, cs := range current.Schemas {
				if strings.EqualFold(cs.Name, modelSchema.Name) {
					currentSchema = cs
					break
				}
			}
		}

		schemaScripts, genErr := w.generateSchemaScripts(modelSchema, currentSchema)
		if genErr != nil {
			return fmt.Errorf("failed to generate schema scripts: %w", genErr)
		}
		scripts = append(scripts, schemaScripts...)

		if auditConfig != nil {
			auditScripts, genErr := w.generateAuditScripts(modelSchema, auditConfig)
			if genErr != nil {
				return fmt.Errorf("failed to generate audit scripts: %w", genErr)
			}
			scripts = append(scripts, auditScripts...)
		}
	}

	// Sequence numbers restart in every generator, so ties are common; a stable
	// sort keeps tied scripts in generation order.
	sort.SliceStable(scripts, func(i, j int) bool {
		if scripts[i].Priority != scripts[j].Priority {
			return scripts[i].Priority < scripts[j].Priority
		}
		return scripts[i].Sequence < scripts[j].Sequence
	})

	currentName := ""
	if current != nil {
		currentName = current.Name
	}

	// Render into memory first so the destination receives one write whose
	// error can be checked.
	var out strings.Builder
	out.WriteString("-- PostgreSQL Migration Script\n")
	out.WriteString("-- Generated by RelSpec\n")
	fmt.Fprintf(&out, "-- Source: %s -> %s\n\n", currentName, model.Name)
	for _, script := range scripts {
		fmt.Fprintf(&out, "-- Priority: %d | Type: %s | Object: %s\n",
			script.Priority, script.ObjectType, script.ObjectName)
		fmt.Fprintf(&out, "%s\n\n", script.Body)
	}

	dest := w.writer
	if dest == nil {
		if outputPath != "" {
			file, createErr := os.Create(outputPath)
			if createErr != nil {
				return fmt.Errorf("failed to create output file: %w", createErr)
			}
			// A failed Close can mean lost data, so surface it unless an
			// earlier error is already being returned.
			defer func() {
				if closeErr := file.Close(); closeErr != nil && err == nil {
					err = fmt.Errorf("failed to close output file: %w", closeErr)
				}
			}()
			dest = file
		} else {
			dest = os.Stdout
		}
	}

	if _, writeErr := io.WriteString(dest, out.String()); writeErr != nil {
		return fmt.Errorf("failed to write migration script: %w", writeErr)
	}
	return nil
}
// generateSchemaScripts runs every per-schema generator in a fixed order and
// concatenates their results: drops (only when a current schema exists),
// tables/columns, indexes, foreign keys and finally comments. The first
// generator error aborts the run and is returned wrapped with the name of the
// failing phase.
func (w *MigrationWriter) generateSchemaScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
	type phase struct {
		errFormat string
		enabled   bool
		run       func(*models.Schema, *models.Schema) ([]MigrationScript, error)
	}

	phases := []phase{
		// Drop constraints and indexes that changed (Priority 11-50).
		{"failed to generate drop scripts: %w", current != nil, w.generateDropScripts},
		// Create/Alter tables and columns (Priority 100-145).
		{"failed to generate table scripts: %w", true, w.generateTableScripts},
		// Create indexes (Priority 160-180).
		{"failed to generate index scripts: %w", true, w.generateIndexScripts},
		// Create foreign keys (Priority 195).
		{"failed to generate foreign key scripts: %w", true, w.generateForeignKeyScripts},
		// Add comments (Priority 200+).
		{"failed to generate comment scripts: %w", true, w.generateCommentScripts},
	}

	collected := make([]MigrationScript, 0)
	for _, p := range phases {
		if !p.enabled {
			continue
		}
		produced, err := p.run(model, current)
		if err != nil {
			return nil, fmt.Errorf(p.errFormat, err)
		}
		collected = append(collected, produced...)
	}
	return collected, nil
}
// generateDropScripts generates DROP scripts for constraints (priority 11) and
// indexes (priority 20) that exist in the current schema but are either
// missing from the model or defined differently there.
//
// Only tables present in both schemas (matched by case-insensitive name) are
// inspected; tables absent from the model are skipped entirely. Constraint and
// index names are visited in sorted order so that the generated scripts and
// their Sequence numbers are identical from run to run (Go map iteration order
// is unspecified).
func (w *MigrationWriter) generateDropScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
	scripts := make([]MigrationScript, 0)

	// Index model tables by lower-cased name for quick lookup.
	modelTables := make(map[string]*models.Table, len(model.Tables))
	for _, table := range model.Tables {
		modelTables[strings.ToLower(table.Name)] = table
	}

	for _, currentTable := range current.Tables {
		modelTable, existsInModel := modelTables[strings.ToLower(currentTable.Name)]
		if !existsInModel {
			continue
		}

		// Constraints that disappeared or changed.
		constraintNames := make([]string, 0, len(currentTable.Constraints))
		for name := range currentTable.Constraints {
			constraintNames = append(constraintNames, name)
		}
		sort.Strings(constraintNames)

		for _, constraintName := range constraintNames {
			currentConstraint := currentTable.Constraints[constraintName]
			modelConstraint, inModel := modelTable.Constraints[constraintName]
			if inModel && constraintsEqual(modelConstraint, currentConstraint) {
				continue
			}

			sql, err := w.executor.ExecuteDropConstraint(DropConstraintData{
				SchemaName:     current.Name,
				TableName:      currentTable.Name,
				ConstraintName: constraintName,
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: fmt.Sprintf("%s.%s.%s", current.Name, currentTable.Name, constraintName),
				ObjectType: "drop constraint",
				Schema:     current.Name,
				Priority:   11,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}

		// Indexes that disappeared or changed.
		indexNames := make([]string, 0, len(currentTable.Indexes))
		for name := range currentTable.Indexes {
			indexNames = append(indexNames, name)
		}
		sort.Strings(indexNames)

		for _, indexName := range indexNames {
			currentIndex := currentTable.Indexes[indexName]
			modelIndex, inModel := modelTable.Indexes[indexName]
			if inModel && indexesEqual(modelIndex, currentIndex) {
				continue
			}

			sql, err := w.executor.ExecuteDropIndex(DropIndexData{
				SchemaName: current.Name,
				IndexName:  indexName,
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: fmt.Sprintf("%s.%s.%s", current.Name, currentTable.Name, indexName),
				ObjectType: "drop index",
				Schema:     current.Name,
				Priority:   20,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}
	}

	return scripts, nil
}
// generateTableScripts walks the model tables and, for each one, either emits
// a CREATE TABLE script (priority 100) when no table with the same
// case-insensitive name exists in the current schema, or delegates to
// generateAlterTableScripts to diff the columns of the existing table.
// current may be nil, in which case every model table is created.
func (w *MigrationWriter) generateTableScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
	// Existing tables, keyed by lower-cased name.
	existing := make(map[string]*models.Table)
	if current != nil {
		for _, tbl := range current.Tables {
			existing[strings.ToLower(tbl.Name)] = tbl
		}
	}

	result := make([]MigrationScript, 0)
	for _, wanted := range model.Tables {
		if found, ok := existing[strings.ToLower(wanted.Name)]; ok {
			// The table is already there: only its columns may need changes.
			altered, err := w.generateAlterTableScripts(model, wanted, found)
			if err != nil {
				return nil, err
			}
			result = append(result, altered...)
			continue
		}

		createSQL, err := w.executor.ExecuteCreateTable(BuildCreateTableData(model.Name, wanted))
		if err != nil {
			return nil, err
		}
		result = append(result, MigrationScript{
			ObjectName: fmt.Sprintf("%s.%s", model.Name, wanted.Name),
			ObjectType: "create table",
			Schema:     model.Name,
			Priority:   100,
			Sequence:   len(result),
			Body:       createSQL,
		})
	}

	return result, nil
}
// generateAlterTableScripts diffs the columns of modelTable against
// currentTable and generates the ALTER TABLE scripts needed to reconcile them:
//
//   - a column missing from the current table is added (priority 120);
//   - a column whose type differs gets an ALTER COLUMN TYPE (priority 120);
//   - a column whose default differs gets an ALTER COLUMN DEFAULT (priority 145).
//
// Columns are matched by case-insensitive name and visited in sorted key order
// so the output is the same on every run (Go map iteration order is
// unspecified). Types are compared case-insensitively, matching columnsEqual,
// so "INTEGER" versus "integer" does not trigger a type change.
//
// NOTE(review): a change that affects only NotNull is detected by columnsEqual
// but no script is generated for it here; columns that exist only in the
// current table are not dropped either.
func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, modelTable *models.Table, currentTable *models.Table) ([]MigrationScript, error) {
	scripts := make([]MigrationScript, 0)

	// Index current columns by lower-cased name.
	currentColumns := make(map[string]*models.Column, len(currentTable.Columns))
	for name, col := range currentTable.Columns {
		currentColumns[strings.ToLower(name)] = col
	}

	columnKeys := make([]string, 0, len(modelTable.Columns))
	for key := range modelTable.Columns {
		columnKeys = append(columnKeys, key)
	}
	sort.Strings(columnKeys)

	for _, key := range columnKeys {
		modelCol := modelTable.Columns[key]
		objectName := fmt.Sprintf("%s.%s.%s", schema.Name, modelTable.Name, modelCol.Name)

		currentCol, exists := currentColumns[strings.ToLower(modelCol.Name)]
		if !exists {
			// Column doesn't exist yet: add it.
			defaultVal := ""
			if modelCol.Default != nil {
				defaultVal = fmt.Sprintf("%v", modelCol.Default)
			}
			sql, err := w.executor.ExecuteAddColumn(AddColumnData{
				SchemaName: schema.Name,
				TableName:  modelTable.Name,
				ColumnName: modelCol.Name,
				ColumnType: modelCol.Type,
				Default:    defaultVal,
				NotNull:    modelCol.NotNull,
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: objectName,
				ObjectType: "create column",
				Schema:     schema.Name,
				Priority:   120,
				Sequence:   len(scripts),
				Body:       sql,
			})
			continue
		}

		if columnsEqual(modelCol, currentCol) {
			continue
		}

		// Column exists but its definition changed.
		if !strings.EqualFold(modelCol.Type, currentCol.Type) {
			sql, err := w.executor.ExecuteAlterColumnType(AlterColumnTypeData{
				SchemaName: schema.Name,
				TableName:  modelTable.Name,
				ColumnName: modelCol.Name,
				NewType:    modelCol.Type,
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: objectName,
				ObjectType: "alter column type",
				Schema:     schema.Name,
				Priority:   120,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}

		// Defaults are compared through their %v rendering, as in columnsEqual.
		if fmt.Sprintf("%v", modelCol.Default) != fmt.Sprintf("%v", currentCol.Default) {
			setDefault := modelCol.Default != nil
			defaultVal := ""
			if setDefault {
				defaultVal = fmt.Sprintf("%v", modelCol.Default)
			}
			sql, err := w.executor.ExecuteAlterColumnDefault(AlterColumnDefaultData{
				SchemaName:   schema.Name,
				TableName:    modelTable.Name,
				ColumnName:   modelCol.Name,
				SetDefault:   setDefault,
				DefaultValue: defaultVal,
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: objectName,
				ObjectType: "alter column default",
				Schema:     schema.Name,
				Priority:   145,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}
	}

	return scripts, nil
}
// generateIndexScripts generates the scripts that create primary keys
// (priority 160) and indexes (priority 180) which are missing from the current
// schema or defined differently there.
//
// For each model table the primary-key constraints are handled first, then the
// indexes. Indexes whose name starts with "pk_" (case-insensitive) are skipped
// as primary-key indexes. An index without an explicit type is created as
// "btree". Constraint and index names are visited in sorted order so the
// output is the same on every run (Go map iteration order is unspecified).
// current may be nil, in which case everything is created.
func (w *MigrationWriter) generateIndexScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
	scripts := make([]MigrationScript, 0)

	// Index current tables by lower-cased name.
	currentTables := make(map[string]*models.Table)
	if current != nil {
		for _, table := range current.Tables {
			currentTables[strings.ToLower(table.Name)] = table
		}
	}

	for _, modelTable := range model.Tables {
		// nil when the table does not exist yet.
		currentTable := currentTables[strings.ToLower(modelTable.Name)]

		// Primary keys first.
		constraintNames := make([]string, 0, len(modelTable.Constraints))
		for name := range modelTable.Constraints {
			constraintNames = append(constraintNames, name)
		}
		sort.Strings(constraintNames)

		for _, constraintName := range constraintNames {
			constraint := modelTable.Constraints[constraintName]
			if constraint.Type != models.PrimaryKeyConstraint {
				continue
			}
			if currentTable != nil {
				if currentConstraint, exists := currentTable.Constraints[constraintName]; exists &&
					constraintsEqual(constraint, currentConstraint) {
					continue // unchanged
				}
			}

			sql, err := w.executor.ExecuteCreatePrimaryKey(CreatePrimaryKeyData{
				SchemaName:     model.Name,
				TableName:      modelTable.Name,
				ConstraintName: constraintName,
				Columns:        strings.Join(constraint.Columns, ", "),
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName),
				ObjectType: "create primary key",
				Schema:     model.Name,
				Priority:   160,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}

		// Then the remaining indexes.
		indexNames := make([]string, 0, len(modelTable.Indexes))
		for name := range modelTable.Indexes {
			indexNames = append(indexNames, name)
		}
		sort.Strings(indexNames)

		for _, indexName := range indexNames {
			// Skip primary key indexes.
			if strings.HasPrefix(strings.ToLower(indexName), "pk_") {
				continue
			}
			modelIndex := modelTable.Indexes[indexName]
			if currentTable != nil {
				if currentIndex, exists := currentTable.Indexes[indexName]; exists &&
					indexesEqual(modelIndex, currentIndex) {
					continue // unchanged
				}
			}

			indexType := "btree"
			if modelIndex.Type != "" {
				indexType = modelIndex.Type
			}
			sql, err := w.executor.ExecuteCreateIndex(CreateIndexData{
				SchemaName: model.Name,
				TableName:  modelTable.Name,
				IndexName:  indexName,
				IndexType:  indexType,
				Columns:    strings.Join(modelIndex.Columns, ", "),
				Unique:     modelIndex.Unique,
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, indexName),
				ObjectType: "create index",
				Schema:     model.Name,
				Priority:   180,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}
	}

	return scripts, nil
}
// generateForeignKeyScripts generates ADD CONSTRAINT ... FOREIGN KEY scripts
// (priority 195) for every foreign key in the model that is missing from the
// current schema or defined differently there.
//
// ON DELETE / ON UPDATE actions are upper-cased, and an empty action is emitted
// as "NO ACTION". Constraint names are visited in sorted order so the output is
// the same on every run (Go map iteration order is unspecified). current may be
// nil, in which case every foreign key is created.
func (w *MigrationWriter) generateForeignKeyScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
	scripts := make([]MigrationScript, 0)

	// Index current tables by lower-cased name.
	currentTables := make(map[string]*models.Table)
	if current != nil {
		for _, table := range current.Tables {
			currentTables[strings.ToLower(table.Name)] = table
		}
	}

	for _, modelTable := range model.Tables {
		// nil when the table does not exist yet.
		currentTable := currentTables[strings.ToLower(modelTable.Name)]

		constraintNames := make([]string, 0, len(modelTable.Constraints))
		for name := range modelTable.Constraints {
			constraintNames = append(constraintNames, name)
		}
		sort.Strings(constraintNames)

		for _, constraintName := range constraintNames {
			constraint := modelTable.Constraints[constraintName]
			if constraint.Type != models.ForeignKeyConstraint {
				continue
			}
			if currentTable != nil {
				if currentConstraint, exists := currentTable.Constraints[constraintName]; exists &&
					constraintsEqual(constraint, currentConstraint) {
					continue // unchanged
				}
			}

			onDelete := "NO ACTION"
			if constraint.OnDelete != "" {
				onDelete = strings.ToUpper(constraint.OnDelete)
			}
			onUpdate := "NO ACTION"
			if constraint.OnUpdate != "" {
				onUpdate = strings.ToUpper(constraint.OnUpdate)
			}

			sql, err := w.executor.ExecuteCreateForeignKey(CreateForeignKeyData{
				SchemaName:     model.Name,
				TableName:      modelTable.Name,
				ConstraintName: constraintName,
				SourceColumns:  strings.Join(constraint.Columns, ", "),
				TargetSchema:   constraint.ReferencedSchema,
				TargetTable:    constraint.ReferencedTable,
				TargetColumns:  strings.Join(constraint.ReferencedColumns, ", "),
				OnDelete:       onDelete,
				OnUpdate:       onUpdate,
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, constraintName),
				ObjectType: "create foreign key",
				Schema:     model.Name,
				Priority:   195,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}
	}

	return scripts, nil
}
// generateCommentScripts generates COMMENT ON scripts (priority 200) for every
// model table and column that has a non-empty Description. Descriptions are
// passed through escapeQuote before being handed to the template.
//
// For each table the table comment comes first, followed by its column
// comments in sorted column-key order so the output is the same on every run
// (Go map iteration order is unspecified).
//
// The current schema is not consulted yet, so comments are emitted even when
// they are unchanged.
func (w *MigrationWriter) generateCommentScripts(model *models.Schema, current *models.Schema) ([]MigrationScript, error) {
	scripts := make([]MigrationScript, 0)
	_ = current // TODO: Compare with current schema to only add new/changed comments

	for _, modelTable := range model.Tables {
		// Table comment.
		if modelTable.Description != "" {
			sql, err := w.executor.ExecuteCommentTable(CommentTableData{
				SchemaName: model.Name,
				TableName:  modelTable.Name,
				Comment:    escapeQuote(modelTable.Description),
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: fmt.Sprintf("%s.%s", model.Name, modelTable.Name),
				ObjectType: "comment on table",
				Schema:     model.Name,
				Priority:   200,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}

		// Column comments, in deterministic order.
		columnKeys := make([]string, 0, len(modelTable.Columns))
		for key := range modelTable.Columns {
			columnKeys = append(columnKeys, key)
		}
		sort.Strings(columnKeys)

		for _, key := range columnKeys {
			col := modelTable.Columns[key]
			if col.Description == "" {
				continue
			}
			sql, err := w.executor.ExecuteCommentColumn(CommentColumnData{
				SchemaName: model.Name,
				TableName:  modelTable.Name,
				ColumnName: col.Name,
				Comment:    escapeQuote(col.Description),
			})
			if err != nil {
				return nil, err
			}
			scripts = append(scripts, MigrationScript{
				ObjectName: fmt.Sprintf("%s.%s.%s", model.Name, modelTable.Name, col.Name),
				ObjectType: "comment on column",
				Schema:     model.Name,
				Priority:   200,
				Sequence:   len(scripts),
				Body:       sql,
			})
		}
	}

	return scripts, nil
}
// generateAuditTablesScript renders the audit-table creation SQL and returns
// it as a single script with priority 90 and sequence 0. The tables are placed
// in auditConfig.AuditSchema, or in "public" when that field is empty.
func (w *MigrationWriter) generateAuditTablesScript(auditConfig *AuditConfig) ([]MigrationScript, error) {
	targetSchema := auditConfig.AuditSchema
	if targetSchema == "" {
		targetSchema = "public"
	}

	body, err := w.executor.ExecuteAuditTables(AuditTablesData{AuditSchema: targetSchema})
	if err != nil {
		return nil, err
	}

	result := make([]MigrationScript, 0)
	result = append(result, MigrationScript{
		ObjectName: fmt.Sprintf("%s.atevent+atdetail", targetSchema),
		ObjectType: "create audit tables",
		Schema:     targetSchema,
		Priority:   90,
		Sequence:   0,
		Body:       body,
	})
	return result, nil
}
// generateAuditScripts generates the audit function (priority 345) and audit
// trigger (priority 355) for every audited table of schema.
//
// A table is skipped when it is not audited according to auditConfig, when it
// has no table configuration, or when it has no primary key. The function is
// named "ft_audit_<table>" and the trigger "t_audit_<table>". The trigger
// fires on the INSERT / UPDATE / DELETE events enabled in the table
// configuration; when none is enabled only the function script is produced.
//
// The audit schema is auditConfig.AuditSchema, falling back to "public" when
// it is empty. This is the same fallback generateAuditTablesScript uses when
// it creates the audit tables, so the generated functions always reference the
// schema in which those tables are actually created.
func (w *MigrationWriter) generateAuditScripts(schema *models.Schema, auditConfig *AuditConfig) ([]MigrationScript, error) {
	scripts := make([]MigrationScript, 0)

	// Resolved once: it does not depend on the table being processed.
	auditSchema := auditConfig.AuditSchema
	if auditSchema == "" {
		auditSchema = "public"
	}

	for _, table := range schema.Tables {
		if !auditConfig.IsTableAudited(schema.Name, table.Name) {
			continue
		}
		config := auditConfig.GetTableConfig(schema.Name, table.Name)
		if config == nil {
			continue
		}
		// Tables without a primary key are skipped.
		pk := table.GetPrimaryKey()
		if pk == nil {
			continue
		}

		// Audit function.
		funcName := fmt.Sprintf("ft_audit_%s", table.Name)
		funcData := BuildAuditFunctionData(schema.Name, table, pk, config, auditSchema, auditConfig.UserFunction)
		funcSQL, err := w.executor.ExecuteAuditFunction(funcData)
		if err != nil {
			return nil, err
		}
		scripts = append(scripts, MigrationScript{
			ObjectName: fmt.Sprintf("%s.%s", schema.Name, funcName),
			ObjectType: "create audit function",
			Schema:     schema.Name,
			Priority:   345,
			Sequence:   len(scripts),
			Body:       funcSQL,
		})

		// Audit trigger, only when at least one event is audited.
		triggerName := fmt.Sprintf("t_audit_%s", table.Name)
		events := make([]string, 0, 3)
		if config.AuditInsert {
			events = append(events, "INSERT")
		}
		if config.AuditUpdate {
			events = append(events, "UPDATE")
		}
		if config.AuditDelete {
			events = append(events, "DELETE")
		}
		if len(events) == 0 {
			continue
		}

		triggerSQL, err := w.executor.ExecuteAuditTrigger(AuditTriggerData{
			SchemaName:   schema.Name,
			TableName:    table.Name,
			TriggerName:  triggerName,
			FunctionName: funcName,
			Events:       strings.Join(events, " OR "),
		})
		if err != nil {
			return nil, err
		}
		scripts = append(scripts, MigrationScript{
			ObjectName: fmt.Sprintf("%s.%s", schema.Name, triggerName),
			ObjectType: "create audit trigger",
			Schema:     schema.Name,
			Priority:   355,
			Sequence:   len(scripts),
			Body:       triggerSQL,
		})
	}

	return scripts, nil
}
// Helper functions for comparing database objects
// columnsEqual reports whether two columns share the same definition: the
// same type (ignoring case), the same NotNull flag and the same default value
// as rendered with %v. A nil column never equals anything, including another
// nil column.
func columnsEqual(col1, col2 *models.Column) bool {
	if col1 == nil || col2 == nil {
		return false
	}
	if !strings.EqualFold(col1.Type, col2.Type) {
		return false
	}
	if col1.NotNull != col2.NotNull {
		return false
	}
	return fmt.Sprintf("%v", col1.Default) == fmt.Sprintf("%v", col2.Default)
}
// constraintsEqual reports whether two constraints are equivalent: same type
// and the same columns in the same order (names compared case-insensitively).
// For foreign keys the referenced table, the referenced columns and the
// ON DELETE / ON UPDATE actions must match as well.
//
// Referential actions are compared after normalisation: case is ignored and an
// empty action counts as "NO ACTION". This mirrors how
// generateForeignKeyScripts emits them, so a model that leaves the action
// empty is not reported as different from a database that reports
// "NO ACTION", which would otherwise force the key to be dropped and recreated
// on every run.
//
// A nil constraint never equals anything, including another nil constraint.
//
// NOTE(review): ReferencedSchema is not part of the comparison.
func constraintsEqual(c1, c2 *models.Constraint) bool {
	if c1 == nil || c2 == nil {
		return false
	}
	if c1.Type != c2.Type {
		return false
	}

	// Compare columns positionally.
	if len(c1.Columns) != len(c2.Columns) {
		return false
	}
	for i, col := range c1.Columns {
		if !strings.EqualFold(col, c2.Columns[i]) {
			return false
		}
	}

	if c1.Type != models.ForeignKeyConstraint {
		return true
	}

	// Foreign keys: also compare the target and the referential actions.
	if !strings.EqualFold(c1.ReferencedTable, c2.ReferencedTable) {
		return false
	}
	if len(c1.ReferencedColumns) != len(c2.ReferencedColumns) {
		return false
	}
	for i, col := range c1.ReferencedColumns {
		if !strings.EqualFold(col, c2.ReferencedColumns[i]) {
			return false
		}
	}

	normalizeAction := func(action string) string {
		if action == "" {
			return "NO ACTION"
		}
		return strings.ToUpper(action)
	}
	if normalizeAction(c1.OnDelete) != normalizeAction(c2.OnDelete) {
		return false
	}
	return normalizeAction(c1.OnUpdate) == normalizeAction(c2.OnUpdate)
}
// indexesEqual reports whether two indexes are equivalent: same uniqueness,
// same index type and the same columns in the same order (names compared
// case-insensitively).
//
// Index types are compared case-insensitively and an empty type counts as
// "btree", the type generateIndexScripts uses when none is specified. Without
// this, a model index with no explicit type would never equal the same index
// reported as "btree" and would be dropped and recreated on every run.
//
// A nil index never equals anything, including another nil index.
func indexesEqual(idx1, idx2 *models.Index) bool {
	if idx1 == nil || idx2 == nil {
		return false
	}
	if idx1.Unique != idx2.Unique {
		return false
	}

	effectiveType := func(indexType string) string {
		if indexType == "" {
			return "btree"
		}
		return indexType
	}
	if !strings.EqualFold(effectiveType(idx1.Type), effectiveType(idx2.Type)) {
		return false
	}

	if len(idx1.Columns) != len(idx2.Columns) {
		return false
	}
	for i, col := range idx1.Columns {
		if !strings.EqualFold(col, idx2.Columns[i]) {
			return false
		}
	}
	return true
}