Files
relspecgo/pkg/writers/pgsql/writer.go
Hein (Warky) 283b568adb
All checks were successful
CI / Test (1.24) (push) Successful in -25m29s
CI / Test (1.25) (push) Successful in -25m13s
CI / Lint (push) Successful in -26m13s
CI / Build (push) Successful in -26m27s
Integration Tests / Integration Tests (push) Successful in -26m11s
Release / Build and Release (push) Successful in -25m8s
feat(pgsql): add execution reporting for SQL statements
- Implemented ExecutionReport to track the execution status of SQL statements.
- Added SchemaReport and TableReport to monitor execution per schema and table.
- Enhanced WriteDatabase to execute SQL directly on a PostgreSQL database if a connection string is provided.
- Included error handling and logging for failed statements during execution.
- Added functionality to write execution reports to a JSON file.
- Introduced utility functions to extract table names from CREATE TABLE statements and truncate long SQL statements for error messages.
2026-01-29 21:16:14 +02:00

1086 lines
32 KiB
Go

package pgsql
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strings"
"time"
"github.com/jackc/pgx/v5"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the Writer interface for PostgreSQL SQL output.
//
// It either renders DDL as text (WriteDatabase, WriteSchema, WriteTable) or,
// when options.Metadata["connection_string"] is a non-empty string, executes
// the generated statements directly against a database (see WriteDatabase).
type Writer struct {
// options holds the caller-supplied configuration (OutputPath and Metadata
// keys such as "connection_string" and "report_path").
options *writers.WriterOptions
// writer is the text output destination. When nil, WriteDatabase chooses a
// file or stdout, and WriteSchema/WriteTable fall back to stdout.
writer io.Writer
// executionReport is filled in by executeDatabaseSQL during direct
// execution; NewWriter leaves it nil.
executionReport *ExecutionReport
}
// Types describing the outcome of executing generated SQL directly against a
// database. They are serialized to JSON, so field order and tags define the
// report layout.
type (
	// ExecutionReport summarizes one direct-execution run: statement counts,
	// per-schema table outcomes, individual failures and the run's time span.
	ExecutionReport struct {
		TotalStatements    int              `json:"total_statements"`
		ExecutedStatements int              `json:"executed_statements"`
		FailedStatements   int              `json:"failed_statements"`
		Schemas            []SchemaReport   `json:"schemas"`
		Errors             []ExecutionError `json:"errors,omitempty"`
		StartTime          string           `json:"start_time"`
		EndTime            string           `json:"end_time"`
	}

	// SchemaReport groups the table outcomes recorded for one schema.
	SchemaReport struct {
		Name   string        `json:"name"`
		Tables []TableReport `json:"tables"`
	}

	// TableReport records whether a single table was created and, if not, why.
	TableReport struct {
		Name    string `json:"name"`
		Created bool   `json:"created"`
		Error   string `json:"error,omitempty"`
	}

	// ExecutionError describes one statement that failed to execute; Statement
	// may be a shortened form of the SQL.
	ExecutionError struct {
		StatementNumber int    `json:"statement_number"`
		Statement       string `json:"statement"`
		Error           string `json:"error"`
	}
)
// NewWriter returns a PostgreSQL Writer that uses options. No output
// destination is opened here; it is chosen on the first write.
func NewWriter(options *writers.WriterOptions) *Writer {
	w := new(Writer)
	w.options = options
	return w
}
// WriteDatabase writes the entire database schema as SQL.
//
// If options.Metadata["connection_string"] is a non-empty string, the
// generated statements are executed against that database instead of being
// written out (see executeDatabaseSQL).
//
// Otherwise the destination is, in order of preference:
//   - a writer already assigned to w.writer (used by tests);
//   - a file created at options.OutputPath;
//   - os.Stdout.
//
// When this call opens the output file, it owns that file. Before returning
// it closes the file and clears w.writer, so a later call does not write into
// a closed handle. A close failure is returned as an error, because it can
// mean buffered output never reached the file.
func (w *Writer) WriteDatabase(db *models.Database) (err error) {
	// Check if we should execute SQL directly on a database
	if connString, ok := w.options.Metadata["connection_string"].(string); ok && connString != "" {
		return w.executeDatabaseSQL(db, connString)
	}

	switch {
	case w.writer != nil:
		// Keep the caller-provided writer.
	case w.options.OutputPath != "":
		file, createErr := os.Create(w.options.OutputPath)
		if createErr != nil {
			return fmt.Errorf("failed to create output file: %w", createErr)
		}
		w.writer = file
		defer func() {
			w.writer = nil
			if closeErr := file.Close(); closeErr != nil && err == nil {
				err = fmt.Errorf("failed to close output file: %w", closeErr)
			}
		}()
	default:
		w.writer = os.Stdout
	}

	// Write header comment
	fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n")
	fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name)
	fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n")

	// Process each schema in the database
	for _, schema := range db.Schemas {
		if schemaErr := w.WriteSchema(schema); schemaErr != nil {
			return fmt.Errorf("failed to write schema %s: %w", schema.Name, schemaErr)
		}
	}
	return nil
}
// GenerateDatabaseStatements returns the SQL for the whole database as a flat
// list of independently executable statements. The list starts with three
// "--" header comments, followed by each schema's statements in db.Schemas
// order. The first schema that fails to generate aborts the call.
func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, error) {
	result := []string{
		"-- PostgreSQL Database Schema",
		fmt.Sprintf("-- Database: %s", db.Name),
		"-- Generated by RelSpec",
	}
	for _, s := range db.Schemas {
		perSchema, genErr := w.GenerateSchemaStatements(s)
		if genErr != nil {
			return nil, fmt.Errorf("failed to generate statements for schema %s: %w", s.Name, genErr)
		}
		result = append(result, perSchema...)
	}
	return result, nil
}
// GenerateSchemaStatements returns the SQL for a single schema as a list of
// independently executable statements (no trailing semicolons), in this order:
//
//  1. "-- Schema:" comment + CREATE SCHEMA (skipped for "public")
//  2. CREATE SEQUENCE for integer primary keys whose default calls nextval()
//  3. CREATE TABLE
//  4. primary keys (explicit constraint, else columns flagged IsPrimaryKey)
//  5. indexes (names ending in "_pkey" are skipped)
//  6. foreign keys
//  7. table and column comments
//
// Indexes, constraints and columns are visited in sorted order so the same
// model always produces the same statement list; map iteration order is
// random in Go.
func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) {
	statements := []string{}

	// sortedConstraints returns table's constraints ordered by name.
	sortedConstraints := func(table *models.Table) []*models.Constraint {
		list := make([]*models.Constraint, 0, len(table.Constraints))
		for _, constraint := range table.Constraints {
			list = append(list, constraint)
		}
		sort.SliceStable(list, func(i, j int) bool { return list[i].Name < list[j].Name })
		return list
	}

	// Phase 1: Create schema
	if schema.Name != "public" {
		statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name))
		statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
	}

	// Phase 2: Create sequences referenced by nextval() defaults of integer PKs
	for _, table := range schema.Tables {
		pk := table.GetPrimaryKey()
		if pk == nil || !isIntegerType(pk.Type) || pk.Default == "" {
			continue
		}
		defaultStr, ok := pk.Default.(string)
		if !ok || !strings.Contains(strings.ToLower(defaultStr), "nextval") {
			continue
		}
		seqName := extractSequenceName(defaultStr)
		if seqName == "" {
			continue
		}
		stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
			schema.SQLName(), seqName)
		statements = append(statements, stmt)
	}

	// Phase 3: Create tables
	for _, table := range schema.Tables {
		stmts, err := w.generateCreateTableStatement(schema, table)
		if err != nil {
			return nil, fmt.Errorf("failed to generate table %s: %w", table.Name, err)
		}
		statements = append(statements, stmts...)
	}

	// Phase 4: Primary keys
	for _, table := range schema.Tables {
		// First check for explicit PrimaryKeyConstraint
		var pkConstraint *models.Constraint
		for _, constraint := range sortedConstraints(table) {
			if constraint.Type == models.PrimaryKeyConstraint {
				pkConstraint = constraint
				break
			}
		}

		if pkConstraint != nil {
			stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
				schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", "))
			statements = append(statements, stmt)
			continue
		}

		// No explicit constraint, check for columns with IsPrimaryKey = true
		pkColumns := []string{}
		for _, col := range table.Columns {
			if col.IsPrimaryKey {
				pkColumns = append(pkColumns, col.SQLName())
			}
		}
		if len(pkColumns) > 0 {
			// Sort for consistent output
			sort.Strings(pkColumns)
			pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
			stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
				schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
			statements = append(statements, stmt)
		}
	}

	// Phase 5: Indexes, in index-key order
	for _, table := range schema.Tables {
		indexKeys := make([]string, 0, len(table.Indexes))
		for key := range table.Indexes {
			indexKeys = append(indexKeys, key)
		}
		sort.Strings(indexKeys)

		for _, key := range indexKeys {
			index := table.Indexes[key]
			// Skip primary key indexes
			if strings.HasSuffix(index.Name, "_pkey") {
				continue
			}

			indexName := index.Name
			if indexName == "" {
				// Use the same fallback name as writeIndexes. An empty name
				// would otherwise yield "CREATE INDEX IF NOT EXISTS  ON ...",
				// which is invalid SQL.
				namePrefix := "idx"
				if index.Unique {
					namePrefix = "uk"
				}
				indexName = fmt.Sprintf("%s_%s_%s", namePrefix, schema.SQLName(), table.SQLName())
			}

			uniqueStr := ""
			if index.Unique {
				uniqueStr = "UNIQUE "
			}

			indexType := index.Type
			if indexType == "" {
				indexType = "btree"
			}

			// Build column expressions with operator class support for GIN indexes
			columnExprs := make([]string, 0, len(index.Columns))
			for _, colName := range index.Columns {
				colExpr := colName
				if col, ok := table.Columns[colName]; ok {
					// For GIN indexes on text columns, add operator class
					if strings.EqualFold(indexType, "gin") && isTextType(col.Type) {
						opClass := extractOperatorClass(index.Comment)
						if opClass == "" {
							opClass = "gin_trgm_ops"
						}
						colExpr = fmt.Sprintf("%s %s", colName, opClass)
					}
				}
				columnExprs = append(columnExprs, colExpr)
			}

			whereClause := ""
			if index.Where != "" {
				whereClause = fmt.Sprintf(" WHERE %s", index.Where)
			}

			stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
				uniqueStr, indexName, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
			statements = append(statements, stmt)
		}
	}

	// Phase 6: Foreign keys, in constraint-name order
	for _, table := range schema.Tables {
		for _, constraint := range sortedConstraints(table) {
			if constraint.Type != models.ForeignKeyConstraint {
				continue
			}

			refSchema := constraint.ReferencedSchema
			if refSchema == "" {
				refSchema = schema.Name
			}

			onDelete := constraint.OnDelete
			if onDelete == "" {
				onDelete = "NO ACTION"
			}

			onUpdate := constraint.OnUpdate
			if onUpdate == "" {
				onUpdate = "NO ACTION"
			}

			stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
				schema.SQLName(), table.SQLName(), constraint.Name,
				strings.Join(constraint.Columns, ", "),
				strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
				strings.Join(constraint.ReferencedColumns, ", "),
				onDelete, onUpdate)
			statements = append(statements, stmt)
		}
	}

	// Phase 7: Comments (columns in map-key order)
	for _, table := range schema.Tables {
		if table.Comment != "" {
			stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'",
				schema.SQLName(), table.SQLName(), escapeQuote(table.Comment))
			statements = append(statements, stmt)
		}

		for _, column := range getSortedColumns(table.Columns) {
			if column.Comment != "" {
				stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'",
					schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment))
				statements = append(statements, stmt)
			}
		}
	}

	return statements, nil
}
// generateCreateTableStatement renders the CREATE TABLE statement for table
// in schema and returns it as a one-element slice; the error is always nil.
// Columns are ordered by Sequence, with ties broken by column name, and each
// is rendered by generateColumnDefinition.
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
	ordered := make([]*models.Column, 0, len(table.Columns))
	for _, c := range table.Columns {
		ordered = append(ordered, c)
	}
	sort.Slice(ordered, func(a, b int) bool {
		left, right := ordered[a], ordered[b]
		if left.Sequence == right.Sequence {
			return left.Name < right.Name
		}
		return left.Sequence < right.Sequence
	})

	defs := make([]string, len(ordered))
	for i, c := range ordered {
		defs[i] = " " + w.generateColumnDefinition(c)
	}

	createSQL := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
		schema.SQLName(), table.SQLName(), strings.Join(defs, ",\n"))
	return []string{createSQL}, nil
}
// generateColumnDefinition renders one column of a CREATE TABLE statement as
// "<name> <type>[(n)|(p[,s])] [NOT NULL] [DEFAULT <expr>]".
//
// Length is appended only when Precision is zero. A positive Precision takes
// priority, with Scale appended when it is positive.
//
// String defaults have DBML backticks removed and are then classified:
//   - emitted verbatim: SQL expressions (nextval..., CURRENT_..., anything
//     containing "()"), the NULL keyword, and the boolean literals true/false.
//     The nextval, CURRENT_ and NULL checks are case-insensitive because
//     PostgreSQL keywords and function names are.
//   - emitted verbatim: values that already start with a single quote, such
//     as 'x'::text. These are already SQL literals, and quoting them again
//     would store the quote characters in the default.
//   - everything else: emitted as an escaped, single-quoted string literal.
//
// Non-string defaults (bools, numbers, ...) are formatted with %v.
func (w *Writer) generateColumnDefinition(col *models.Column) string {
	parts := []string{col.SQLName()}

	// Type with length/precision
	typeStr := col.Type
	if col.Length > 0 && col.Precision == 0 {
		typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
	} else if col.Precision > 0 {
		if col.Scale > 0 {
			typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
		} else {
			typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
		}
	}
	parts = append(parts, typeStr)

	// NOT NULL
	if col.NotNull {
		parts = append(parts, "NOT NULL")
	}

	// DEFAULT
	if col.Default != nil {
		v, isString := col.Default.(string)
		if !isString {
			parts = append(parts, fmt.Sprintf("DEFAULT %v", col.Default))
			return strings.Join(parts, " ")
		}

		// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
		cleanDefault := stripBackticks(v)
		upperDefault := strings.ToUpper(cleanDefault)
		switch {
		case strings.HasPrefix(upperDefault, "NEXTVAL"),
			strings.HasPrefix(upperDefault, "CURRENT_"),
			strings.Contains(cleanDefault, "()"),
			upperDefault == "NULL",
			cleanDefault == "true", cleanDefault == "false",
			strings.HasPrefix(cleanDefault, "'"):
			parts = append(parts, fmt.Sprintf("DEFAULT %s", cleanDefault))
		default:
			parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(cleanDefault)))
		}
	}

	return strings.Join(parts, " ")
}
// WriteSchema writes a single schema and all its tables to w.writer,
// defaulting to stdout when no writer is set. Output is produced phase by
// phase in dependency order. The first phase that fails stops the write, and
// its error is returned unchanged.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	if w.writer == nil {
		w.writer = os.Stdout
	}

	// Priority order: schema (1), sequences (80), tables (100), primary keys
	// (160), indexes (180), foreign keys (195), sequence values (200),
	// comments (200+).
	phases := []func(*models.Schema) error{
		w.writeCreateSchema,
		w.writeSequences,
		w.writeCreateTables,
		w.writePrimaryKeys,
		w.writeIndexes,
		w.writeForeignKeys,
		w.writeSetSequenceValues,
		w.writeComments,
	}
	for _, phase := range phases {
		if err := phase(schema); err != nil {
			return err
		}
	}
	return nil
}
// WriteTable writes one table and all of its elements. It wraps the table in
// a temporary schema named after table.Schema and delegates to WriteSchema.
// Output goes to w.writer, defaulting to stdout.
func (w *Writer) WriteTable(table *models.Table) error {
	if w.writer == nil {
		w.writer = os.Stdout
	}

	wrapper := models.InitSchema(table.Schema)
	wrapper.Tables = append(wrapper.Tables, table)
	return w.WriteSchema(wrapper)
}
// writeCreateSchema writes a "-- Schema:" comment followed by an idempotent
// CREATE SCHEMA statement. Nothing is written for "public", which exists by
// default.
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
	if schema.Name != "public" {
		fmt.Fprintf(w.writer, "-- Schema: %s\n"+"CREATE SCHEMA IF NOT EXISTS %s;\n\n",
			schema.Name, schema.SQLName())
	}
	return nil
}
// writeSequences writes a CREATE SEQUENCE IF NOT EXISTS statement for each
// table whose primary key has an integer type. The sequence is named
// identity_<table>_<pk column> and lives in the table's schema. The
// "-- Sequences for schema:" header is written even when no table qualifies.
func (w *Writer) writeSequences(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
	for _, table := range schema.Tables {
		pk := table.GetPrimaryKey()
		if pk == nil || !isIntegerType(pk.Type) {
			continue
		}

		seq := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
		fmt.Fprintf(w.writer,
			"CREATE SEQUENCE IF NOT EXISTS %s.%s\n"+
				" INCREMENT 1\n"+
				" MINVALUE 1\n"+
				" MAXVALUE 9223372036854775807\n"+
				" START 1\n"+
				" CACHE 1;\n\n",
			schema.SQLName(), seq)
	}
	return nil
}
// writeCreateTables writes a CREATE TABLE IF NOT EXISTS statement for each
// table in schema. Columns appear in map-key order. Each column carries only
// its name and raw type, plus an unquoted DEFAULT clause when its default is
// neither nil nor "".
func (w *Writer) writeCreateTables(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Tables for schema: %s\n", schema.Name)
	for _, table := range schema.Tables {
		cols := getSortedColumns(table.Columns)
		defs := make([]string, len(cols))
		for i, col := range cols {
			def := fmt.Sprintf(" %s %s", col.SQLName(), col.Type)
			if dflt := col.Default; dflt != nil && dflt != "" {
				// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
				def += fmt.Sprintf(" DEFAULT %s", stripBackticks(fmt.Sprintf("%v", dflt)))
			}
			defs[i] = def
		}

		fmt.Fprintf(w.writer, "CREATE TABLE IF NOT EXISTS %s.%s (\n", schema.SQLName(), table.SQLName())
		fmt.Fprintf(w.writer, "%s\n"+");\n\n", strings.Join(defs, ",\n"))
	}
	return nil
}
// writePrimaryKeys writes, for each table with primary key columns, a DO
// block that adds the constraint pk_<schema>_<table> unless
// information_schema already lists a constraint with that name on the table.
//
// Columns come from the first PrimaryKeyConstraint found, keeping only names
// present in table.Columns. Without such a constraint, they are the columns
// flagged IsPrimaryKey, sorted by name. Tables with no PK columns are skipped.
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
	for _, table := range schema.Tables {
		var explicit *models.Constraint
		for _, c := range table.Constraints {
			if c.Type == models.PrimaryKeyConstraint {
				explicit = c
				break
			}
		}

		var cols []string
		if explicit == nil {
			for _, col := range table.Columns {
				if col.IsPrimaryKey {
					cols = append(cols, col.SQLName())
				}
			}
			sort.Strings(cols)
		} else {
			for _, name := range explicit.Columns {
				if col, found := table.Columns[name]; found {
					cols = append(cols, col.SQLName())
				}
			}
		}
		if len(cols) == 0 {
			continue
		}

		constraintName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
		fmt.Fprintf(w.writer,
			"DO $$\nBEGIN\n"+
				" IF NOT EXISTS (\n"+
				" SELECT 1 FROM information_schema.table_constraints\n"+
				" WHERE table_schema = '%s'\n"+
				" AND table_name = '%s'\n"+
				" AND constraint_name = '%s'\n"+
				" ) THEN\n"+
				" ALTER TABLE %s.%s\n"+
				" ADD CONSTRAINT %s PRIMARY KEY (%s);\n"+
				" END IF;\n"+
				"END;\n$$;\n\n",
			schema.Name, table.Name, constraintName,
			schema.SQLName(), table.SQLName(),
			constraintName, strings.Join(cols, ", "))
	}
	return nil
}
// writeIndexes writes a CREATE [UNIQUE] INDEX IF NOT EXISTS statement for each
// index of each table, in index-key order.
//
// Indexes whose name starts with "pk_" (any case) are skipped; primary keys
// are handled by writePrimaryKeys. Column names missing from table.Columns
// are dropped, and an index left with no columns is skipped. On GIN indexes,
// text columns get an operator class taken from the index comment, defaulting
// to gin_trgm_ops. Unnamed indexes are named idx_/uk_<schema>_<table>, and
// the method defaults to btree.
func (w *Writer) writeIndexes(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Indexes for schema: %s\n", schema.Name)
	for _, table := range schema.Tables {
		keys := make([]string, 0, len(table.Indexes))
		for k := range table.Indexes {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		for _, k := range keys {
			idx := table.Indexes[k]
			if strings.HasPrefix(strings.ToLower(idx.Name), "pk_") {
				continue
			}

			var exprs []string
			for _, colName := range idx.Columns {
				col, found := table.Columns[colName]
				if !found {
					continue
				}
				expr := col.SQLName()
				if strings.EqualFold(idx.Type, "gin") && isTextType(col.Type) {
					opClass := extractOperatorClass(idx.Comment)
					if opClass == "" {
						opClass = "gin_trgm_ops"
					}
					expr = fmt.Sprintf("%s %s", col.SQLName(), opClass)
				}
				exprs = append(exprs, expr)
			}
			if len(exprs) == 0 {
				continue
			}

			name := idx.Name
			if name == "" {
				kind := "idx"
				if idx.Unique {
					kind = "uk"
				}
				name = fmt.Sprintf("%s_%s_%s", kind, schema.SQLName(), table.SQLName())
			}

			method := idx.Type
			if method == "" {
				method = "btree"
			}

			unique := ""
			if idx.Unique {
				unique = "UNIQUE "
			}

			where := ""
			if idx.Where != "" {
				where = fmt.Sprintf(" WHERE %s", idx.Where)
			}

			fmt.Fprintf(w.writer,
				"CREATE %sINDEX IF NOT EXISTS %s\n"+" ON %s.%s USING %s (%s)%s;\n\n",
				unique, name,
				schema.SQLName(), table.SQLName(), method, strings.Join(exprs, ", "), where)
		}
	}
	return nil
}
// writeForeignKeys writes a DROP CONSTRAINT IF EXISTS / ADD CONSTRAINT pair
// for each relationship of each table, in relationship-name order.
//
// A relationship's name is rel.ForeignKey, else its map key, else
// fk_<table>_<target>. The backing FK constraint is found on the same table:
//   - first, the constraint whose name equals that name;
//   - only if none matches, the first FK constraint (by name) that references
//     rel.ToTable.
//
// Trying the exact name before the table fallback matters when a table has
// several FKs to the same target (e.g. created_by and updated_by -> users).
// Otherwise each relationship could resolve to whichever of them map
// iteration yields first.
//
// Relationships with no matching constraint, or with no resolvable
// source/target columns, are skipped. The referenced schema falls back from
// the constraint to the relationship and finally to this schema, so the
// REFERENCES clause is never written as ".table".
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)

	for _, table := range schema.Tables {
		// FK constraints of this table, ordered by name for deterministic matching.
		fkConstraints := make([]*models.Constraint, 0, len(table.Constraints))
		for _, constraint := range table.Constraints {
			if constraint.Type == models.ForeignKeyConstraint {
				fkConstraints = append(fkConstraints, constraint)
			}
		}
		sort.SliceStable(fkConstraints, func(i, j int) bool {
			return fkConstraints[i].Name < fkConstraints[j].Name
		})

		// Sort relationships by name for consistent output
		relNames := make([]string, 0, len(table.Relationships))
		for name := range table.Relationships {
			relNames = append(relNames, name)
		}
		sort.Strings(relNames)

		for _, name := range relNames {
			rel := table.Relationships[name]

			fkName := rel.ForeignKey
			if fkName == "" {
				fkName = name
			}
			if fkName == "" {
				fkName = fmt.Sprintf("fk_%s_%s", table.SQLName(), rel.ToTable)
			}

			// Prefer an exact name match, then fall back to the target table.
			var fkConstraint *models.Constraint
			for _, constraint := range fkConstraints {
				if constraint.Name == fkName {
					fkConstraint = constraint
					break
				}
			}
			if fkConstraint == nil {
				for _, constraint := range fkConstraints {
					if constraint.ReferencedTable == rel.ToTable {
						fkConstraint = constraint
						break
					}
				}
			}
			if fkConstraint == nil {
				continue
			}

			// Build column lists from the constraint
			sourceColumns := make([]string, 0, len(fkConstraint.Columns))
			for _, colName := range fkConstraint.Columns {
				if col, ok := table.Columns[colName]; ok {
					sourceColumns = append(sourceColumns, col.SQLName())
				}
			}
			targetColumns := make([]string, 0, len(fkConstraint.ReferencedColumns))
			for _, colName := range fkConstraint.ReferencedColumns {
				targetColumns = append(targetColumns, strings.ToLower(colName))
			}
			if len(sourceColumns) == 0 || len(targetColumns) == 0 {
				continue
			}

			onDelete := "NO ACTION"
			if fkConstraint.OnDelete != "" {
				onDelete = strings.ToUpper(fkConstraint.OnDelete)
			}
			onUpdate := "NO ACTION"
			if fkConstraint.OnUpdate != "" {
				onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
			}

			// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
			refSchema := fkConstraint.ReferencedSchema
			if refSchema == "" {
				refSchema = rel.ToSchema
			}
			if refSchema == "" {
				refSchema = schema.SQLName()
			}
			refTable := fkConstraint.ReferencedTable
			if refTable == "" {
				refTable = rel.ToTable
			}

			fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
			fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
			fmt.Fprintf(w.writer, "\n")
			fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
			fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
			fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
			fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
				refSchema, refTable, strings.Join(targetColumns, ", "))
			fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
			fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
			fmt.Fprintf(w.writer, " DEFERRABLE;\n\n")
		}
	}
	return nil
}
// writeSetSequenceValues writes, for each table with an integer primary key,
// a DO block that checks whether the identity_<table>_<pk> sequence exists
// in the schema. If it does, the block calls setval on it with
// COALESCE(MAX(pk), 0) + 1 computed over the table. The
// "-- Set sequence values for schema:" header is always written.
func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
	const setvalTemplate = "DO $$\n" +
		"DECLARE\n" +
		" m_cnt bigint;\n" +
		"BEGIN\n" +
		" IF EXISTS (\n" +
		" SELECT 1 FROM pg_class c\n" +
		" INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n" +
		" WHERE c.relname = '%s'\n" +
		" AND n.nspname = '%s'\n" +
		" AND c.relkind = 'S'\n" +
		" ) THEN\n" +
		" SELECT COALESCE(MAX(%s), 0) + 1\n" +
		" FROM %s.%s\n" +
		" INTO m_cnt;\n" +
		" \n" +
		" PERFORM setval('%s.%s'::regclass, m_cnt);\n" +
		" END IF;\n" +
		"END;\n" +
		"$$;\n\n"

	fmt.Fprintf(w.writer, "-- Set sequence values for schema: %s\n", schema.Name)
	for _, table := range schema.Tables {
		pk := table.GetPrimaryKey()
		if pk == nil || !isIntegerType(pk.Type) {
			continue
		}

		seq := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
		fmt.Fprintf(w.writer, setvalTemplate,
			seq, schema.Name,
			pk.SQLName(), schema.SQLName(), table.SQLName(),
			schema.SQLName(), seq)
	}
	return nil
}
// writeComments writes COMMENT ON statements built from table and column
// Description fields, with single quotes escaped. Columns are visited in
// map-key order, and a blank line follows each table.
func (w *Writer) writeComments(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Comments for schema: %s\n", schema.Name)
	for _, table := range schema.Tables {
		if desc := table.Description; desc != "" {
			fmt.Fprintf(w.writer, "COMMENT ON TABLE %s.%s IS '%s';\n",
				schema.SQLName(), table.SQLName(), escapeQuote(desc))
		}

		for _, col := range getSortedColumns(table.Columns) {
			if col.Description == "" {
				continue
			}
			fmt.Fprintf(w.writer, "COMMENT ON COLUMN %s.%s.%s IS '%s';\n",
				schema.SQLName(), table.SQLName(), col.SQLName(), escapeQuote(col.Description))
		}

		fmt.Fprintf(w.writer, "\n")
	}
	return nil
}
// Helper functions

// getSortedColumns returns the map's columns ordered by map key (ascending),
// which makes output independent of Go's random map iteration order.
func getSortedColumns(columns map[string]*models.Column) []*models.Column {
	keys := make([]string, 0, len(columns))
	for k := range columns {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	out := make([]*models.Column, len(keys))
	for i, k := range keys {
		out[i] = columns[k]
	}
	return out
}
// isIntegerType reports whether colType (case-insensitive, surrounding
// whitespace ignored) names an integer or serial type.
//
// Accepted base names are integer, int, bigint, smallint, serial, bigserial
// and smallserial. A base name may be followed by digits (int2, int4, int8,
// serial8, int64, ...) and then only by the end of the string, a '(' type
// modifier, or a space and more words ("int(11)", "bigint unsigned").
//
// A plain prefix test would also accept "interval", the range types
// "int4range"/"int8range", and arrays such as "integer[]". None of those may
// back an integer sequence.
func isIntegerType(colType string) bool {
	lowerType := strings.ToLower(strings.TrimSpace(colType))
	for _, base := range []string{"integer", "int", "bigint", "smallint", "serial", "bigserial", "smallserial"} {
		if !strings.HasPrefix(lowerType, base) {
			continue
		}
		rest := strings.TrimLeft(lowerType[len(base):], "0123456789")
		if rest == "" || rest[0] == '(' || rest[0] == ' ' {
			return true
		}
	}
	return false
}
// isTextType reports whether colType begins (case-insensitively) with a
// text-like type name: text, varchar, character varying, char, character or
// string. It decides whether a GIN index column needs an operator class.
func isTextType(colType string) bool {
	t := strings.ToLower(colType)
	switch {
	case strings.HasPrefix(t, "text"),
		strings.HasPrefix(t, "varchar"),
		strings.HasPrefix(t, "character varying"),
		strings.HasPrefix(t, "char"),
		strings.HasPrefix(t, "character"),
		strings.HasPrefix(t, "string"):
		return true
	default:
		return false
	}
}
// extractOperatorClass scans an index comment/note (case-insensitively) for a
// known GIN/GiST operator class and returns the first match, checked in the
// priority order gin_trgm_ops, gist_trgm_ops, gin_bigm_ops, jsonb_ops,
// jsonb_path_ops, array_ops. It returns "" if none is mentioned.
func extractOperatorClass(comment string) string {
	lower := strings.ToLower(comment)
	known := [...]string{"gin_trgm_ops", "gist_trgm_ops", "gin_bigm_ops", "jsonb_ops", "jsonb_path_ops", "array_ops"}
	for _, candidate := range known {
		if strings.Contains(lower, candidate) {
			return candidate
		}
	}
	return ""
}
// escapeQuote doubles every single quote in s so that it can be embedded in a
// single-quoted SQL string literal.
func escapeQuote(s string) string {
	return strings.Join(strings.Split(s, "'"), "''")
}
// stripBackticks deletes every backtick from s. DBML wraps SQL expressions
// such as `now()` in backticks, which PostgreSQL does not accept.
func stripBackticks(s string) string {
	return strings.Replace(s, "`", "", -1)
}
// extractSequenceName returns the sequence name from a nextval() default
// expression, without any schema qualifier. It takes the text between the
// first pair of single quotes and keeps the part after the last '.'.
// It returns "" when there is no complete quoted section.
// Example: "nextval('public.users_id_seq'::regclass)" returns "users_id_seq".
func extractSequenceName(defaultExpr string) string {
	open := strings.IndexByte(defaultExpr, '\'')
	if open < 0 {
		return ""
	}
	rest := defaultExpr[open+1:]
	closeIdx := strings.IndexByte(rest, '\'')
	if closeIdx < 0 {
		return ""
	}
	quoted := rest[:closeIdx]
	return quoted[strings.LastIndexByte(quoted, '.')+1:]
}
// executeDatabaseSQL generates the SQL for db and executes it, one statement
// at a time, against the PostgreSQL database at connString. The outcome is
// recorded in w.executionReport.
//
// Execution is best-effort: a failing statement is logged to stderr, recorded
// as an ExecutionError, and execution continues with the next statement. An
// error is returned only if generation fails or the connection cannot be
// opened. If options.Metadata["report_path"] is a non-empty string, the
// report is also written there as JSON.
//
// Every schema in db gets a SchemaReport, in db.Schemas order. CREATE TABLE
// outcomes are attributed by the "CREATE TABLE <schema>.<table>" prefix that
// generateCreateTableStatement emits, not by the "-- Schema:" marker comments.
// Those markers are omitted for "public", so marker-based tracking would
// credit public tables to the previous schema or drop them.
func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) error {
	// Initialize execution report
	w.executionReport = &ExecutionReport{
		StartTime: getCurrentTimestamp(),
		Schemas:   make([]SchemaReport, 0, len(db.Schemas)),
		Errors:    make([]ExecutionError, 0),
	}

	// Generate SQL statements
	statements, err := w.GenerateDatabaseStatements(db)
	if err != nil {
		return fmt.Errorf("failed to generate SQL statements: %w", err)
	}
	w.executionReport.TotalStatements = len(statements)

	// Connect to database
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, connString)
	if err != nil {
		return fmt.Errorf("failed to connect to database: %w", err)
	}
	defer conn.Close(ctx)

	// One report per schema, plus the CREATE TABLE prefix that identifies it.
	schemaReports := make([]SchemaReport, len(db.Schemas))
	createPrefixes := make([]string, len(db.Schemas))
	for i, schema := range db.Schemas {
		schemaReports[i] = SchemaReport{Name: schema.Name, Tables: make([]TableReport, 0)}
		createPrefixes[i] = "CREATE TABLE " + schema.SQLName() + "."
	}

	// recordTable adds a TableReport when stmt creates a table in one of db's
	// schemas. execErr is the statement's execution error; nil means success.
	recordTable := func(stmt string, execErr error) {
		for i, prefix := range createPrefixes {
			if !strings.HasPrefix(stmt, prefix) {
				continue
			}
			tableName := extractTableNameFromCreate(stmt)
			if tableName == "" {
				return
			}
			report := TableReport{Name: tableName, Created: execErr == nil}
			if execErr != nil {
				report.Error = execErr.Error()
			}
			schemaReports[i].Tables = append(schemaReports[i].Tables, report)
			return
		}
	}

	// Execute each statement
	for i, stmt := range statements {
		stmtTrimmed := strings.TrimSpace(stmt)

		// Skip comments (including schema marker comments) and empty statements
		if stmtTrimmed == "" || strings.HasPrefix(stmtTrimmed, "--") {
			continue
		}

		fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(statements))
		if _, execErr := conn.Exec(ctx, stmt); execErr != nil {
			w.executionReport.FailedStatements++
			w.executionReport.Errors = append(w.executionReport.Errors, ExecutionError{
				StatementNumber: i + 1,
				Statement:       truncateStatement(stmt),
				Error:           execErr.Error(),
			})
			recordTable(stmtTrimmed, execErr)

			// Continue with next statement instead of failing completely
			fmt.Fprintf(os.Stderr, "⚠ Warning: Statement %d failed: %v\n", i+1, execErr)
			continue
		}

		w.executionReport.ExecutedStatements++
		recordTable(stmtTrimmed, nil)
	}

	w.executionReport.Schemas = append(w.executionReport.Schemas, schemaReports...)
	w.executionReport.EndTime = getCurrentTimestamp()

	// Write report if path is specified
	if reportPath, ok := w.options.Metadata["report_path"].(string); ok && reportPath != "" {
		if err := w.writeReport(reportPath); err != nil {
			fmt.Fprintf(os.Stderr, "⚠ Warning: Failed to write report: %v\n", err)
		} else {
			fmt.Fprintf(os.Stderr, "✓ Report written to: %s\n", reportPath)
		}
	}

	if w.executionReport.FailedStatements > 0 {
		fmt.Fprintf(os.Stderr, "⚠ Completed with %d errors out of %d statements\n",
			w.executionReport.FailedStatements, w.executionReport.TotalStatements)
	} else {
		fmt.Fprintf(os.Stderr, "✓ Successfully executed %d statements\n", w.executionReport.ExecutedStatements)
	}

	return nil
}
// writeReport writes w.executionReport to reportPath as indented JSON ending
// in a newline, creating or truncating the file.
//
// The JSON is built in memory and written with os.WriteFile, which also
// returns any error from closing the file. A deferred, unchecked Close would
// drop that error and could hide a truncated report.
func (w *Writer) writeReport(reportPath string) error {
	data, err := json.MarshalIndent(w.executionReport, "", " ")
	if err != nil {
		return fmt.Errorf("failed to encode report: %w", err)
	}
	// Match json.Encoder output, which terminates each value with a newline.
	data = append(data, '\n')

	if err := os.WriteFile(reportPath, data, 0o666); err != nil {
		return fmt.Errorf("failed to write report file: %w", err)
	}
	return nil
}
// extractTableNameFromCreate returns the unqualified table name from a
// "CREATE TABLE [IF NOT EXISTS] [schema.]name ..." statement. It returns ""
// when stmt contains no "CREATE TABLE" or no name follows it. Quoted
// identifiers are returned as written (not unquoted).
//
// "CREATE TABLE" is found case-insensitively using an ASCII-only upper-case
// copy of stmt. That copy has exactly the same byte offsets as stmt.
// strings.ToUpper can change the byte length of non-ASCII text (e.g. "ı" ->
// "I"), which would make an offset found in its result point at the wrong
// place in stmt.
//
// After the keyword, the text is split on spaces, tabs, CR, LF and '('. This
// lets IF NOT EXISTS be recognized regardless of case or of how many
// whitespace characters separate its words.
func extractTableNameFromCreate(stmt string) string {
	const keyword = "CREATE TABLE"

	upper := []byte(stmt)
	for i, c := range upper {
		if 'a' <= c && c <= 'z' {
			upper[i] = c - ('a' - 'A')
		}
	}
	idx := strings.Index(string(upper), keyword)
	if idx == -1 {
		return ""
	}

	tokens := strings.FieldsFunc(stmt[idx+len(keyword):], func(r rune) bool {
		return r == '(' || r == ' ' || r == '\n' || r == '\t' || r == '\r'
	})

	// Skip an optional IF NOT EXISTS clause.
	if len(tokens) >= 3 &&
		strings.EqualFold(tokens[0], "IF") &&
		strings.EqualFold(tokens[1], "NOT") &&
		strings.EqualFold(tokens[2], "EXISTS") {
		tokens = tokens[3:]
	}
	if len(tokens) == 0 {
		return ""
	}

	// Handle schema.table format
	fullName := tokens[0]
	return fullName[strings.LastIndex(fullName, ".")+1:]
}
// truncateStatement shortens stmt for error reports. Statements of at most
// 200 bytes are returned unchanged. Longer ones are cut to at most 200 bytes
// and suffixed with "...".
//
// The cut is moved back to the start of a UTF-8 character, so a multi-byte
// character is never split; a split character would put invalid UTF-8 into
// the JSON report.
func truncateStatement(stmt string) string {
	const maxLen = 200
	if len(stmt) <= maxLen {
		return stmt
	}

	cut := maxLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes. If the first
	// excluded byte is one, the cut falls inside a character. A UTF-8
	// character is at most 4 bytes, so never back up more than 3.
	for cut > maxLen-3 && stmt[cut]&0xC0 == 0x80 {
		cut--
	}
	return stmt[:cut] + "..."
}
// getCurrentTimestamp returns the current local time formatted as
// "YYYY-MM-DD hh:mm:ss" (24-hour clock, no zone). It is used for the
// execution report's start and end times.
func getCurrentTimestamp() string {
	const layout = "2006-01-02 15:04:05"
	now := time.Now()
	return now.Format(layout)
}