All checks were successful
CI / Test (1.24) (push) Successful in -25m29s
CI / Test (1.25) (push) Successful in -25m13s
CI / Lint (push) Successful in -26m13s
CI / Build (push) Successful in -26m27s
Integration Tests / Integration Tests (push) Successful in -26m11s
Release / Build and Release (push) Successful in -25m8s
- Implemented ExecutionReport to track the execution status of SQL statements. - Added SchemaReport and TableReport to monitor execution per schema and table. - Enhanced WriteDatabase to execute SQL directly on a PostgreSQL database if a connection string is provided. - Included error handling and logging for failed statements during execution. - Added functionality to write execution reports to a JSON file. - Introduced utility functions to extract table names from CREATE TABLE statements and truncate long SQL statements for error messages.
1086 lines
32 KiB
Go
1086 lines
32 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
// Writer implements the Writer interface for PostgreSQL SQL output.
//
// Depending on its options it either renders DDL to an io.Writer
// (WriteDatabase, WriteSchema, WriteTable) or, when
// options.Metadata["connection_string"] is set, executes the generated
// statements directly against a database.
type Writer struct {
	// options carries the output path and metadata (connection string,
	// report path) supplied to NewWriter.
	options *writers.WriterOptions

	// writer is the destination for rendered SQL. It may be pre-set (tests
	// do this); otherwise WriteDatabase chooses a file or os.Stdout, and
	// WriteSchema/WriteTable fall back to os.Stdout when it is nil.
	writer io.Writer

	// executionReport holds the results of the most recent direct execution.
	executionReport *ExecutionReport
}
|
|
|
|
// ExecutionReport tracks the execution status of SQL statements.
//
// It is filled in when SQL is executed directly against a database and can be
// written out as JSON (see writeReport); the json tags define that file's keys
// and the field order defines their order in the output.
type ExecutionReport struct {
	TotalStatements    int              `json:"total_statements"`    // statements counted for the run
	ExecutedStatements int              `json:"executed_statements"` // statements that ran without error
	FailedStatements   int              `json:"failed_statements"`   // statements whose execution returned an error
	Schemas            []SchemaReport   `json:"schemas"`             // per-schema table outcomes
	Errors             []ExecutionError `json:"errors,omitempty"`    // one entry per failed statement
	StartTime          string           `json:"start_time"`          // local time, "2006-01-02 15:04:05" layout
	EndTime            string           `json:"end_time"`            // local time, "2006-01-02 15:04:05" layout
}
|
|
|
|
// SchemaReport tracks execution per schema: the schema name and the outcome
// of each CREATE TABLE statement attributed to it.
type SchemaReport struct {
	Name   string        `json:"name"`
	Tables []TableReport `json:"tables"`
}
|
|
|
|
// TableReport tracks execution per table.
type TableReport struct {
	Name    string `json:"name"`            // table name as parsed from the CREATE TABLE statement
	Created bool   `json:"created"`         // true when the CREATE TABLE statement succeeded
	Error   string `json:"error,omitempty"` // database error message when it failed
}
|
|
|
|
// ExecutionError represents a failed statement.
type ExecutionError struct {
	StatementNumber int    `json:"statement_number"` // 1-based position of the statement in the run
	Statement       string `json:"statement"`        // statement text, shortened by truncateStatement
	Error           string `json:"error"`            // error message returned by the database driver
}
|
|
|
|
// NewWriter creates a new PostgreSQL SQL writer
|
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
|
return &Writer{
|
|
options: options,
|
|
}
|
|
}
|
|
|
|
// WriteDatabase writes the entire database schema as SQL
|
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
|
// Check if we should execute SQL directly on a database
|
|
if connString, ok := w.options.Metadata["connection_string"].(string); ok && connString != "" {
|
|
return w.executeDatabaseSQL(db, connString)
|
|
}
|
|
|
|
var writer io.Writer
|
|
var file *os.File
|
|
var err error
|
|
|
|
// Use existing writer if already set (for testing)
|
|
if w.writer != nil {
|
|
writer = w.writer
|
|
} else if w.options.OutputPath != "" {
|
|
// Determine output destination
|
|
file, err = os.Create(w.options.OutputPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create output file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
writer = file
|
|
} else {
|
|
writer = os.Stdout
|
|
}
|
|
|
|
w.writer = writer
|
|
|
|
// Write header comment
|
|
fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n")
|
|
fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name)
|
|
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n")
|
|
|
|
// Process each schema in the database
|
|
for _, schema := range db.Schemas {
|
|
if err := w.WriteSchema(schema); err != nil {
|
|
return fmt.Errorf("failed to write schema %s: %w", schema.Name, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GenerateDatabaseStatements generates SQL statements as a list for the entire database
|
|
// Returns a slice of SQL statements that can be executed independently
|
|
func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, error) {
|
|
statements := []string{}
|
|
|
|
// Add header comment
|
|
statements = append(statements, "-- PostgreSQL Database Schema")
|
|
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
|
|
statements = append(statements, "-- Generated by RelSpec")
|
|
|
|
// Process each schema in the database
|
|
for _, schema := range db.Schemas {
|
|
schemaStatements, err := w.GenerateSchemaStatements(schema)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err)
|
|
}
|
|
statements = append(statements, schemaStatements...)
|
|
}
|
|
|
|
return statements, nil
|
|
}
|
|
|
|
// GenerateSchemaStatements generates SQL statements as a list for a single schema
|
|
func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) {
|
|
statements := []string{}
|
|
|
|
// Phase 1: Create schema
|
|
if schema.Name != "public" {
|
|
statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name))
|
|
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
|
}
|
|
|
|
// Phase 2: Create sequences
|
|
for _, table := range schema.Tables {
|
|
pk := table.GetPrimaryKey()
|
|
if pk == nil || !isIntegerType(pk.Type) || pk.Default == "" {
|
|
continue
|
|
}
|
|
|
|
defaultStr, ok := pk.Default.(string)
|
|
if !ok || !strings.Contains(strings.ToLower(defaultStr), "nextval") {
|
|
continue
|
|
}
|
|
|
|
seqName := extractSequenceName(defaultStr)
|
|
if seqName == "" {
|
|
continue
|
|
}
|
|
|
|
stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
|
|
schema.SQLName(), seqName)
|
|
statements = append(statements, stmt)
|
|
}
|
|
|
|
// Phase 3: Create tables
|
|
for _, table := range schema.Tables {
|
|
stmts, err := w.generateCreateTableStatement(schema, table)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate table %s: %w", table.Name, err)
|
|
}
|
|
statements = append(statements, stmts...)
|
|
}
|
|
|
|
// Phase 4: Primary keys
|
|
for _, table := range schema.Tables {
|
|
// First check for explicit PrimaryKeyConstraint
|
|
var pkConstraint *models.Constraint
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.PrimaryKeyConstraint {
|
|
pkConstraint = constraint
|
|
break
|
|
}
|
|
}
|
|
|
|
if pkConstraint != nil {
|
|
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
|
schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", "))
|
|
statements = append(statements, stmt)
|
|
} else {
|
|
// No explicit constraint, check for columns with IsPrimaryKey = true
|
|
pkColumns := []string{}
|
|
for _, col := range table.Columns {
|
|
if col.IsPrimaryKey {
|
|
pkColumns = append(pkColumns, col.SQLName())
|
|
}
|
|
}
|
|
if len(pkColumns) > 0 {
|
|
// Sort for consistent output
|
|
sort.Strings(pkColumns)
|
|
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
|
|
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
|
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Phase 5: Indexes
|
|
for _, table := range schema.Tables {
|
|
for _, index := range table.Indexes {
|
|
// Skip primary key indexes
|
|
if strings.HasSuffix(index.Name, "_pkey") {
|
|
continue
|
|
}
|
|
|
|
uniqueStr := ""
|
|
if index.Unique {
|
|
uniqueStr = "UNIQUE "
|
|
}
|
|
|
|
indexType := index.Type
|
|
if indexType == "" {
|
|
indexType = "btree"
|
|
}
|
|
|
|
// Build column expressions with operator class support for GIN indexes
|
|
columnExprs := make([]string, 0, len(index.Columns))
|
|
for _, colName := range index.Columns {
|
|
colExpr := colName
|
|
if col, ok := table.Columns[colName]; ok {
|
|
// For GIN indexes on text columns, add operator class
|
|
if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
|
|
opClass := extractOperatorClass(index.Comment)
|
|
if opClass == "" {
|
|
opClass = "gin_trgm_ops"
|
|
}
|
|
colExpr = fmt.Sprintf("%s %s", colName, opClass)
|
|
}
|
|
}
|
|
columnExprs = append(columnExprs, colExpr)
|
|
}
|
|
|
|
whereClause := ""
|
|
if index.Where != "" {
|
|
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
|
}
|
|
|
|
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
|
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
|
|
// Phase 6: Foreign keys
|
|
for _, table := range schema.Tables {
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type != models.ForeignKeyConstraint {
|
|
continue
|
|
}
|
|
|
|
refSchema := constraint.ReferencedSchema
|
|
if refSchema == "" {
|
|
refSchema = schema.Name
|
|
}
|
|
|
|
onDelete := constraint.OnDelete
|
|
if onDelete == "" {
|
|
onDelete = "NO ACTION"
|
|
}
|
|
|
|
onUpdate := constraint.OnUpdate
|
|
if onUpdate == "" {
|
|
onUpdate = "NO ACTION"
|
|
}
|
|
|
|
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
|
|
schema.SQLName(), table.SQLName(), constraint.Name,
|
|
strings.Join(constraint.Columns, ", "),
|
|
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
|
strings.Join(constraint.ReferencedColumns, ", "),
|
|
onDelete, onUpdate)
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
|
|
// Phase 7: Comments
|
|
for _, table := range schema.Tables {
|
|
if table.Comment != "" {
|
|
stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'",
|
|
schema.SQLName(), table.SQLName(), escapeQuote(table.Comment))
|
|
statements = append(statements, stmt)
|
|
}
|
|
|
|
for _, column := range table.Columns {
|
|
if column.Comment != "" {
|
|
stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'",
|
|
schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment))
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
}
|
|
|
|
return statements, nil
|
|
}
|
|
|
|
// generateCreateTableStatement generates CREATE TABLE statement
|
|
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
|
statements := []string{}
|
|
|
|
// Sort columns by sequence or name
|
|
columns := make([]*models.Column, 0, len(table.Columns))
|
|
for _, col := range table.Columns {
|
|
columns = append(columns, col)
|
|
}
|
|
sort.Slice(columns, func(i, j int) bool {
|
|
if columns[i].Sequence != columns[j].Sequence {
|
|
return columns[i].Sequence < columns[j].Sequence
|
|
}
|
|
return columns[i].Name < columns[j].Name
|
|
})
|
|
|
|
columnDefs := []string{}
|
|
for _, col := range columns {
|
|
def := w.generateColumnDefinition(col)
|
|
columnDefs = append(columnDefs, " "+def)
|
|
}
|
|
|
|
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
|
|
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
|
statements = append(statements, stmt)
|
|
|
|
return statements, nil
|
|
}
|
|
|
|
// generateColumnDefinition generates column definition
|
|
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
|
parts := []string{col.SQLName()}
|
|
|
|
// Type with length/precision
|
|
typeStr := col.Type
|
|
if col.Length > 0 && col.Precision == 0 {
|
|
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
|
|
} else if col.Precision > 0 {
|
|
if col.Scale > 0 {
|
|
typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
|
|
} else {
|
|
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
|
|
}
|
|
}
|
|
parts = append(parts, typeStr)
|
|
|
|
// NOT NULL
|
|
if col.NotNull {
|
|
parts = append(parts, "NOT NULL")
|
|
}
|
|
|
|
// DEFAULT
|
|
if col.Default != nil {
|
|
switch v := col.Default.(type) {
|
|
case string:
|
|
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
|
cleanDefault := stripBackticks(v)
|
|
if strings.HasPrefix(cleanDefault, "nextval") || strings.HasPrefix(cleanDefault, "CURRENT_") || strings.Contains(cleanDefault, "()") {
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
|
} else if cleanDefault == "true" || cleanDefault == "false" {
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
|
|
} else {
|
|
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
|
|
}
|
|
case bool:
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
|
default:
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
|
}
|
|
}
|
|
|
|
return strings.Join(parts, " ")
|
|
}
|
|
|
|
// WriteSchema writes a single schema and all its tables
|
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|
if w.writer == nil {
|
|
w.writer = os.Stdout
|
|
}
|
|
|
|
// Phase 1: Create schema (priority 1)
|
|
if err := w.writeCreateSchema(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 2: Create sequences (priority 80)
|
|
if err := w.writeSequences(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 3: Create tables with columns (priority 100)
|
|
if err := w.writeCreateTables(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 4: Create primary keys (priority 160)
|
|
if err := w.writePrimaryKeys(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 5: Create indexes (priority 180)
|
|
if err := w.writeIndexes(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 6: Create foreign key constraints (priority 195)
|
|
if err := w.writeForeignKeys(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 7: Set sequence values (priority 200)
|
|
if err := w.writeSetSequenceValues(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 8: Add comments (priority 200+)
|
|
if err := w.writeComments(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// WriteTable writes a single table with all its elements
|
|
func (w *Writer) WriteTable(table *models.Table) error {
|
|
if w.writer == nil {
|
|
w.writer = os.Stdout
|
|
}
|
|
|
|
// Create a temporary schema with just this table
|
|
schema := models.InitSchema(table.Schema)
|
|
schema.Tables = append(schema.Tables, table)
|
|
|
|
return w.WriteSchema(schema)
|
|
}
|
|
|
|
// writeCreateSchema generates CREATE SCHEMA statement
|
|
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
|
if schema.Name == "public" {
|
|
// public schema exists by default
|
|
return nil
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "-- Schema: %s\n", schema.Name)
|
|
fmt.Fprintf(w.writer, "CREATE SCHEMA IF NOT EXISTS %s;\n\n", schema.SQLName())
|
|
return nil
|
|
}
|
|
|
|
// writeSequences generates CREATE SEQUENCE statements for identity columns
|
|
func (w *Writer) writeSequences(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
pk := table.GetPrimaryKey()
|
|
if pk == nil {
|
|
continue
|
|
}
|
|
|
|
// Only create sequences for integer-type PKs with identity
|
|
if !isIntegerType(pk.Type) {
|
|
continue
|
|
}
|
|
|
|
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
|
fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n",
|
|
schema.SQLName(), seqName)
|
|
fmt.Fprintf(w.writer, " INCREMENT 1\n")
|
|
fmt.Fprintf(w.writer, " MINVALUE 1\n")
|
|
fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n")
|
|
fmt.Fprintf(w.writer, " START 1\n")
|
|
fmt.Fprintf(w.writer, " CACHE 1;\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeCreateTables generates CREATE TABLE statements
|
|
func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Tables for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
fmt.Fprintf(w.writer, "CREATE TABLE IF NOT EXISTS %s.%s (\n",
|
|
schema.SQLName(), table.SQLName())
|
|
|
|
// Write columns
|
|
columns := getSortedColumns(table.Columns)
|
|
columnDefs := make([]string, 0, len(columns))
|
|
|
|
for _, col := range columns {
|
|
colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type)
|
|
|
|
// Add default value if present
|
|
if col.Default != nil && col.Default != "" {
|
|
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
|
|
defaultVal := fmt.Sprintf("%v", col.Default)
|
|
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
|
|
}
|
|
|
|
columnDefs = append(columnDefs, colDef)
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "%s\n", strings.Join(columnDefs, ",\n"))
|
|
fmt.Fprintf(w.writer, ");\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
|
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Find primary key constraint
|
|
var pkConstraint *models.Constraint
|
|
for name, constraint := range table.Constraints {
|
|
if constraint.Type == models.PrimaryKeyConstraint {
|
|
pkConstraint = constraint
|
|
_ = name // Use the name variable
|
|
break
|
|
}
|
|
}
|
|
|
|
var columnNames []string
|
|
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
|
|
|
|
if pkConstraint != nil {
|
|
// Build column list from explicit constraint
|
|
columnNames = make([]string, 0, len(pkConstraint.Columns))
|
|
for _, colName := range pkConstraint.Columns {
|
|
if col, ok := table.Columns[colName]; ok {
|
|
columnNames = append(columnNames, col.SQLName())
|
|
}
|
|
}
|
|
} else {
|
|
// No explicit PK constraint, check for columns with IsPrimaryKey = true
|
|
for _, col := range table.Columns {
|
|
if col.IsPrimaryKey {
|
|
columnNames = append(columnNames, col.SQLName())
|
|
}
|
|
}
|
|
// Sort for consistent output
|
|
sort.Strings(columnNames)
|
|
}
|
|
|
|
if len(columnNames) == 0 {
|
|
continue
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
|
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
|
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
|
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
|
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
|
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName)
|
|
fmt.Fprintf(w.writer, " ) THEN\n")
|
|
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n",
|
|
pkName, strings.Join(columnNames, ", "))
|
|
fmt.Fprintf(w.writer, " END IF;\n")
|
|
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeIndexes generates CREATE INDEX statements
|
|
func (w *Writer) writeIndexes(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Indexes for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Sort indexes by name for consistent output
|
|
indexNames := make([]string, 0, len(table.Indexes))
|
|
for name := range table.Indexes {
|
|
indexNames = append(indexNames, name)
|
|
}
|
|
sort.Strings(indexNames)
|
|
|
|
for _, name := range indexNames {
|
|
index := table.Indexes[name]
|
|
|
|
// Skip if it's a primary key index (based on name convention or columns)
|
|
// Primary keys are handled separately
|
|
if strings.HasPrefix(strings.ToLower(index.Name), "pk_") {
|
|
continue
|
|
}
|
|
|
|
indexName := index.Name
|
|
if indexName == "" {
|
|
indexType := "idx"
|
|
if index.Unique {
|
|
indexType = "uk"
|
|
}
|
|
indexName = fmt.Sprintf("%s_%s_%s", indexType, schema.SQLName(), table.SQLName())
|
|
}
|
|
|
|
// Build column list with operator class support for GIN indexes
|
|
columnExprs := make([]string, 0, len(index.Columns))
|
|
for _, colName := range index.Columns {
|
|
if col, ok := table.Columns[colName]; ok {
|
|
colExpr := col.SQLName()
|
|
// For GIN indexes on text columns, add operator class
|
|
if strings.EqualFold(index.Type, "gin") && isTextType(col.Type) {
|
|
opClass := extractOperatorClass(index.Comment)
|
|
if opClass == "" {
|
|
opClass = "gin_trgm_ops"
|
|
}
|
|
colExpr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
|
|
}
|
|
columnExprs = append(columnExprs, colExpr)
|
|
}
|
|
}
|
|
|
|
if len(columnExprs) == 0 {
|
|
continue
|
|
}
|
|
|
|
unique := ""
|
|
if index.Unique {
|
|
unique = "UNIQUE "
|
|
}
|
|
|
|
indexType := index.Type
|
|
if indexType == "" {
|
|
indexType = "btree"
|
|
}
|
|
|
|
whereClause := ""
|
|
if index.Where != "" {
|
|
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "CREATE %sINDEX IF NOT EXISTS %s\n",
|
|
unique, indexName)
|
|
fmt.Fprintf(w.writer, " ON %s.%s USING %s (%s)%s;\n\n",
|
|
schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeForeignKeys generates ALTER TABLE statements for foreign keys
|
|
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Sort relationships by name for consistent output
|
|
relNames := make([]string, 0, len(table.Relationships))
|
|
for name := range table.Relationships {
|
|
relNames = append(relNames, name)
|
|
}
|
|
sort.Strings(relNames)
|
|
|
|
for _, name := range relNames {
|
|
rel := table.Relationships[name]
|
|
|
|
// For relationships, we need to look up the foreign key constraint
|
|
// that defines the actual column mappings
|
|
fkName := rel.ForeignKey
|
|
if fkName == "" {
|
|
fkName = name
|
|
}
|
|
if fkName == "" {
|
|
fkName = fmt.Sprintf("fk_%s_%s", table.SQLName(), rel.ToTable)
|
|
}
|
|
|
|
// Find the foreign key constraint that matches this relationship
|
|
var fkConstraint *models.Constraint
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.ForeignKeyConstraint &&
|
|
(constraint.Name == fkName || constraint.ReferencedTable == rel.ToTable) {
|
|
fkConstraint = constraint
|
|
break
|
|
}
|
|
}
|
|
|
|
// If no constraint found, skip this relationship
|
|
if fkConstraint == nil {
|
|
continue
|
|
}
|
|
|
|
// Build column lists from the constraint
|
|
sourceColumns := make([]string, 0, len(fkConstraint.Columns))
|
|
for _, colName := range fkConstraint.Columns {
|
|
if col, ok := table.Columns[colName]; ok {
|
|
sourceColumns = append(sourceColumns, col.SQLName())
|
|
}
|
|
}
|
|
|
|
targetColumns := make([]string, 0, len(fkConstraint.ReferencedColumns))
|
|
for _, colName := range fkConstraint.ReferencedColumns {
|
|
targetColumns = append(targetColumns, strings.ToLower(colName))
|
|
}
|
|
|
|
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
|
|
continue
|
|
}
|
|
|
|
onDelete := "NO ACTION"
|
|
if fkConstraint.OnDelete != "" {
|
|
onDelete = strings.ToUpper(fkConstraint.OnDelete)
|
|
}
|
|
|
|
onUpdate := "NO ACTION"
|
|
if fkConstraint.OnUpdate != "" {
|
|
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
|
|
fmt.Fprintf(w.writer, "\n")
|
|
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
|
|
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
|
|
|
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
|
|
refSchema := fkConstraint.ReferencedSchema
|
|
if refSchema == "" {
|
|
refSchema = rel.ToSchema
|
|
}
|
|
refTable := fkConstraint.ReferencedTable
|
|
if refTable == "" {
|
|
refTable = rel.ToTable
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
|
refSchema, refTable, strings.Join(targetColumns, ", "))
|
|
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
|
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
|
|
fmt.Fprintf(w.writer, " DEFERRABLE;\n\n")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeSetSequenceValues generates statements to set sequence current values
|
|
func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Set sequence values for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
pk := table.GetPrimaryKey()
|
|
if pk == nil || !isIntegerType(pk.Type) {
|
|
continue
|
|
}
|
|
|
|
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
|
|
|
fmt.Fprintf(w.writer, "DO $$\n")
|
|
fmt.Fprintf(w.writer, "DECLARE\n")
|
|
fmt.Fprintf(w.writer, " m_cnt bigint;\n")
|
|
fmt.Fprintf(w.writer, "BEGIN\n")
|
|
fmt.Fprintf(w.writer, " IF EXISTS (\n")
|
|
fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n")
|
|
fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n")
|
|
fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName)
|
|
fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name)
|
|
fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n")
|
|
fmt.Fprintf(w.writer, " ) THEN\n")
|
|
fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName())
|
|
fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " INTO m_cnt;\n")
|
|
fmt.Fprintf(w.writer, " \n")
|
|
fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n",
|
|
schema.SQLName(), seqName)
|
|
fmt.Fprintf(w.writer, " END IF;\n")
|
|
fmt.Fprintf(w.writer, "END;\n")
|
|
fmt.Fprintf(w.writer, "$$;\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeComments generates COMMENT statements for tables and columns
|
|
func (w *Writer) writeComments(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Comments for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Table comment
|
|
if table.Description != "" {
|
|
fmt.Fprintf(w.writer, "COMMENT ON TABLE %s.%s IS '%s';\n",
|
|
schema.SQLName(), table.SQLName(),
|
|
escapeQuote(table.Description))
|
|
}
|
|
|
|
// Column comments
|
|
for _, col := range getSortedColumns(table.Columns) {
|
|
if col.Description != "" {
|
|
fmt.Fprintf(w.writer, "COMMENT ON COLUMN %s.%s.%s IS '%s';\n",
|
|
schema.SQLName(), table.SQLName(), col.SQLName(),
|
|
escapeQuote(col.Description))
|
|
}
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// getSortedColumns returns columns sorted by name
|
|
func getSortedColumns(columns map[string]*models.Column) []*models.Column {
|
|
names := make([]string, 0, len(columns))
|
|
for name := range columns {
|
|
names = append(names, name)
|
|
}
|
|
sort.Strings(names)
|
|
|
|
sorted := make([]*models.Column, 0, len(columns))
|
|
for _, name := range names {
|
|
sorted = append(sorted, columns[name])
|
|
}
|
|
return sorted
|
|
}
|
|
|
|
// isIntegerType reports whether colType names a PostgreSQL integer type or
// one of the serial pseudo-types.
//
// Matching is case-insensitive on the base type name. Surrounding whitespace,
// a type modifier such as "(11)", and trailing words such as "unsigned" are
// ignored. Names that merely start with "int" (interval, int4range) are not
// integer types, and neither are arrays such as "integer[]". The earlier
// prefix match accepted all of these.
func isIntegerType(colType string) bool {
	base := strings.ToLower(strings.TrimSpace(colType))
	if i := strings.IndexAny(base, "( "); i >= 0 {
		base = base[:i]
	}

	switch base {
	case "integer", "int", "int2", "int4", "int8", "smallint", "bigint",
		"serial", "serial2", "serial4", "serial8", "smallserial", "bigserial":
		return true
	}
	return false
}
|
|
|
|
// isTextType reports whether colType is a character type, which decides
// whether a GIN index column gets a trigram operator class.
//
// Matching is a case-insensitive prefix match on text, varchar, character
// varying, char, character and string (a DBML alias), after trimming
// whitespace. Array types ("text[]", or PostgreSQL's internal "_text" spelling)
// are rejected. Trigram operator classes such as gin_trgm_ops do not accept
// arrays, and emitting one would make CREATE INDEX fail.
func isTextType(colType string) bool {
	lowerType := strings.ToLower(strings.TrimSpace(colType))
	if strings.HasSuffix(lowerType, "]") || strings.HasPrefix(lowerType, "_") {
		return false
	}

	textTypes := []string{"text", "varchar", "character varying", "char", "character", "string"}
	for _, t := range textTypes {
		if strings.HasPrefix(lowerType, t) {
			return true
		}
	}
	return false
}
|
|
|
|
// extractOperatorClass returns the first known GIN/GiST operator class
// mentioned (case-insensitively) in an index comment or note, or "" when
// none is found. Candidates are checked in a fixed order: gin_trgm_ops,
// gist_trgm_ops, gin_bigm_ops, jsonb_ops, jsonb_path_ops, array_ops.
func extractOperatorClass(comment string) string {
	if len(comment) == 0 {
		return ""
	}

	known := [...]string{"gin_trgm_ops", "gist_trgm_ops", "gin_bigm_ops", "jsonb_ops", "jsonb_path_ops", "array_ops"}
	haystack := strings.ToLower(comment)
	for i := range known {
		if strings.Contains(haystack, known[i]) {
			return known[i]
		}
	}
	return ""
}
|
|
|
|
// escapeQuote doubles every single quote in s so that it can be embedded in a
// single-quoted SQL string literal.
func escapeQuote(s string) string {
	pieces := strings.Split(s, "'")
	return strings.Join(pieces, "''")
}
|
|
|
|
// stripBackticks deletes every backtick from s. DBML wraps SQL expressions
// such as `now()` in backticks, which PostgreSQL does not understand.
func stripBackticks(s string) string {
	const backtick = "`"
	return strings.Replace(s, backtick, "", -1)
}
|
|
|
|
// extractSequenceName returns the sequence named in a nextval() default,
// without any schema prefix. It takes the text between the first pair of
// single quotes and keeps only what follows the last dot, e.g.
// "nextval('public.users_id_seq'::regclass)" yields "users_id_seq".
// It returns "" when there is no complete quoted section.
func extractSequenceName(defaultExpr string) string {
	_, afterOpen, found := strings.Cut(defaultExpr, "'")
	if !found {
		return ""
	}
	quoted, _, found := strings.Cut(afterOpen, "'")
	if !found {
		return ""
	}

	if dot := strings.LastIndex(quoted, "."); dot >= 0 {
		return quoted[dot+1:]
	}
	return quoted
}
|
|
|
|
// executeDatabaseSQL executes SQL statements directly on a PostgreSQL database
|
|
func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) error {
|
|
// Initialize execution report
|
|
w.executionReport = &ExecutionReport{
|
|
StartTime: getCurrentTimestamp(),
|
|
Schemas: make([]SchemaReport, 0),
|
|
Errors: make([]ExecutionError, 0),
|
|
}
|
|
|
|
// Generate SQL statements
|
|
statements, err := w.GenerateDatabaseStatements(db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate SQL statements: %w", err)
|
|
}
|
|
|
|
w.executionReport.TotalStatements = len(statements)
|
|
|
|
// Connect to database
|
|
ctx := context.Background()
|
|
conn, err := pgx.Connect(ctx, connString)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
defer conn.Close(ctx)
|
|
|
|
// Track schemas and tables
|
|
schemaMap := make(map[string]*SchemaReport)
|
|
currentSchema := ""
|
|
|
|
// Execute each statement
|
|
for i, stmt := range statements {
|
|
stmtTrimmed := strings.TrimSpace(stmt)
|
|
|
|
// Skip comments
|
|
if strings.HasPrefix(stmtTrimmed, "--") {
|
|
// Check if this is a schema comment to track schema changes
|
|
if strings.Contains(stmtTrimmed, "Schema:") {
|
|
parts := strings.Split(stmtTrimmed, "Schema:")
|
|
if len(parts) > 1 {
|
|
currentSchema = strings.TrimSpace(parts[1])
|
|
if _, exists := schemaMap[currentSchema]; !exists {
|
|
schemaReport := SchemaReport{
|
|
Name: currentSchema,
|
|
Tables: make([]TableReport, 0),
|
|
}
|
|
schemaMap[currentSchema] = &schemaReport
|
|
}
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Skip empty statements
|
|
if stmtTrimmed == "" {
|
|
continue
|
|
}
|
|
|
|
fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(statements))
|
|
|
|
_, execErr := conn.Exec(ctx, stmt)
|
|
if execErr != nil {
|
|
w.executionReport.FailedStatements++
|
|
execError := ExecutionError{
|
|
StatementNumber: i + 1,
|
|
Statement: truncateStatement(stmt),
|
|
Error: execErr.Error(),
|
|
}
|
|
w.executionReport.Errors = append(w.executionReport.Errors, execError)
|
|
|
|
// Track table creation failure
|
|
if strings.Contains(strings.ToUpper(stmtTrimmed), "CREATE TABLE") && currentSchema != "" {
|
|
tableName := extractTableNameFromCreate(stmtTrimmed)
|
|
if tableName != "" && schemaMap[currentSchema] != nil {
|
|
schemaMap[currentSchema].Tables = append(schemaMap[currentSchema].Tables, TableReport{
|
|
Name: tableName,
|
|
Created: false,
|
|
Error: execErr.Error(),
|
|
})
|
|
}
|
|
}
|
|
|
|
// Continue with next statement instead of failing completely
|
|
fmt.Fprintf(os.Stderr, "⚠ Warning: Statement %d failed: %v\n", i+1, execErr)
|
|
continue
|
|
}
|
|
|
|
w.executionReport.ExecutedStatements++
|
|
|
|
// Track successful table creation
|
|
if strings.Contains(strings.ToUpper(stmtTrimmed), "CREATE TABLE") && currentSchema != "" {
|
|
tableName := extractTableNameFromCreate(stmtTrimmed)
|
|
if tableName != "" && schemaMap[currentSchema] != nil {
|
|
schemaMap[currentSchema].Tables = append(schemaMap[currentSchema].Tables, TableReport{
|
|
Name: tableName,
|
|
Created: true,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert schema map to slice
|
|
for _, schemaReport := range schemaMap {
|
|
w.executionReport.Schemas = append(w.executionReport.Schemas, *schemaReport)
|
|
}
|
|
|
|
w.executionReport.EndTime = getCurrentTimestamp()
|
|
|
|
// Write report if path is specified
|
|
if reportPath, ok := w.options.Metadata["report_path"].(string); ok && reportPath != "" {
|
|
if err := w.writeReport(reportPath); err != nil {
|
|
fmt.Fprintf(os.Stderr, "⚠ Warning: Failed to write report: %v\n", err)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, "✓ Report written to: %s\n", reportPath)
|
|
}
|
|
}
|
|
|
|
if w.executionReport.FailedStatements > 0 {
|
|
fmt.Fprintf(os.Stderr, "⚠ Completed with %d errors out of %d statements\n",
|
|
w.executionReport.FailedStatements, w.executionReport.TotalStatements)
|
|
} else {
|
|
fmt.Fprintf(os.Stderr, "✓ Successfully executed %d statements\n", w.executionReport.ExecutedStatements)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeReport writes the execution report to a JSON file
|
|
func (w *Writer) writeReport(reportPath string) error {
|
|
file, err := os.Create(reportPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create report file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
encoder := json.NewEncoder(file)
|
|
encoder.SetIndent("", " ")
|
|
if err := encoder.Encode(w.executionReport); err != nil {
|
|
return fmt.Errorf("failed to encode report: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// extractTableNameFromCreate returns the table name from a
// "CREATE TABLE [IF NOT EXISTS] [schema.]table ..." statement, matching the
// keywords case-insensitively. For schema-qualified names only the last part
// is returned. Double-quoted identifiers may contain spaces, dots and ""
// escapes, and are returned with their quotes.
//
// It returns "" when the statement does not contain "CREATE TABLE" followed by
// whitespace (so "CREATE TABLESPACE" is rejected), or when no name follows.
func extractTableNameFromCreate(stmt string) string {
	const keyword = "CREATE TABLE"
	const space = " \t\r\n\f\v"

	// ASCII-only upper-casing keeps byte offsets identical to stmt, so an
	// index found in upper can safely be used to slice stmt.
	upper := []byte(stmt)
	for i, c := range upper {
		if 'a' <= c && c <= 'z' {
			upper[i] = c - ('a' - 'A')
		}
	}
	idx := strings.Index(string(upper), keyword)
	if idx == -1 {
		return ""
	}

	rest := stmt[idx+len(keyword):]
	trimmed := strings.TrimLeft(rest, space)
	if len(trimmed) == len(rest) {
		return "" // no whitespace after the keyword
	}
	rest = trimmed

	// Skip "IF NOT EXISTS".
	const ifNotExists = "IF NOT EXISTS"
	if len(rest) >= len(ifNotExists) && strings.EqualFold(rest[:len(ifNotExists)], ifNotExists) {
		rest = strings.TrimLeft(rest[len(ifNotExists):], space)
	}

	// Read a possibly schema-qualified name; the last part is the table.
	for {
		var part string
		if strings.HasPrefix(rest, `"`) {
			end := 1
			for end < len(rest) {
				if rest[end] == '"' {
					if end+1 < len(rest) && rest[end+1] == '"' {
						end += 2 // escaped quote inside the identifier
						continue
					}
					break
				}
				end++
			}
			if end >= len(rest) {
				return "" // unterminated quoted identifier
			}
			part, rest = rest[:end+1], rest[end+1:]
		} else {
			end := strings.IndexAny(rest, ".(;,"+space)
			if end == -1 {
				end = len(rest)
			}
			part, rest = rest[:end], rest[end:]
		}

		if part == "" {
			return ""
		}
		if !strings.HasPrefix(rest, ".") {
			return part
		}
		rest = rest[1:]
	}
}
|
|
|
|
// truncateStatement shortens stmt to at most 200 bytes, followed by "...",
// for inclusion in error reports. Shorter statements are returned unchanged.
//
// The cut point is moved back to a UTF-8 rune boundary so that a multi-byte
// character is never split. A split character would leave invalid UTF-8 in
// the JSON report.
func truncateStatement(stmt string) string {
	const maxLen = 200
	if len(stmt) <= maxLen {
		return stmt
	}

	cut := maxLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes.
	for cut > 0 && stmt[cut]&0xC0 == 0x80 {
		cut--
	}
	return stmt[:cut] + "..."
}
|
|
|
|
// getCurrentTimestamp returns the current local time formatted as
// "YYYY-MM-DD HH:MM:SS" (no time zone), as used in execution reports.
func getCurrentTimestamp() string {
	const reportLayout = "2006-01-02 15:04:05"
	now := time.Now()
	return now.Format(reportLayout)
}
|