748 lines
21 KiB
Go
748 lines
21 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
// Writer implements the Writer interface for PostgreSQL SQL output
|
|
type Writer struct {
|
|
options *writers.WriterOptions
|
|
writer io.Writer
|
|
}
|
|
|
|
// NewWriter creates a new PostgreSQL SQL writer
|
|
func NewWriter(options *writers.WriterOptions) *Writer {
|
|
return &Writer{
|
|
options: options,
|
|
}
|
|
}
|
|
|
|
// WriteDatabase writes the entire database schema as SQL
|
|
func (w *Writer) WriteDatabase(db *models.Database) error {
|
|
var writer io.Writer
|
|
var file *os.File
|
|
var err error
|
|
|
|
// Use existing writer if already set (for testing)
|
|
if w.writer != nil {
|
|
writer = w.writer
|
|
} else if w.options.OutputPath != "" {
|
|
// Determine output destination
|
|
file, err = os.Create(w.options.OutputPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create output file: %w", err)
|
|
}
|
|
defer file.Close()
|
|
writer = file
|
|
} else {
|
|
writer = os.Stdout
|
|
}
|
|
|
|
w.writer = writer
|
|
|
|
// Write header comment
|
|
fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n")
|
|
fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name)
|
|
fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n")
|
|
|
|
// Process each schema in the database
|
|
for _, schema := range db.Schemas {
|
|
if err := w.WriteSchema(schema); err != nil {
|
|
return fmt.Errorf("failed to write schema %s: %w", schema.Name, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GenerateDatabaseStatements generates SQL statements as a list for the entire database
|
|
// Returns a slice of SQL statements that can be executed independently
|
|
func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, error) {
|
|
statements := []string{}
|
|
|
|
// Add header comment
|
|
statements = append(statements, "-- PostgreSQL Database Schema")
|
|
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
|
|
statements = append(statements, "-- Generated by RelSpec")
|
|
|
|
// Process each schema in the database
|
|
for _, schema := range db.Schemas {
|
|
schemaStatements, err := w.GenerateSchemaStatements(schema)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err)
|
|
}
|
|
statements = append(statements, schemaStatements...)
|
|
}
|
|
|
|
return statements, nil
|
|
}
|
|
|
|
// GenerateSchemaStatements generates SQL statements as a list for a single schema
|
|
func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) {
|
|
statements := []string{}
|
|
|
|
// Phase 1: Create schema
|
|
if schema.Name != "public" {
|
|
statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name))
|
|
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
|
}
|
|
|
|
// Phase 2: Create sequences
|
|
for _, table := range schema.Tables {
|
|
pk := table.GetPrimaryKey()
|
|
if pk == nil || !isIntegerType(pk.Type) || pk.Default == "" {
|
|
continue
|
|
}
|
|
|
|
defaultStr, ok := pk.Default.(string)
|
|
if !ok || !strings.Contains(strings.ToLower(defaultStr), "nextval") {
|
|
continue
|
|
}
|
|
|
|
seqName := extractSequenceName(defaultStr)
|
|
if seqName == "" {
|
|
continue
|
|
}
|
|
|
|
stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
|
|
schema.SQLName(), seqName)
|
|
statements = append(statements, stmt)
|
|
}
|
|
|
|
// Phase 3: Create tables
|
|
for _, table := range schema.Tables {
|
|
stmts, err := w.generateCreateTableStatement(schema, table)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate table %s: %w", table.Name, err)
|
|
}
|
|
statements = append(statements, stmts...)
|
|
}
|
|
|
|
// Phase 4: Primary keys
|
|
for _, table := range schema.Tables {
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type != models.PrimaryKeyConstraint {
|
|
continue
|
|
}
|
|
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
|
schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(constraint.Columns, ", "))
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
|
|
// Phase 5: Indexes
|
|
for _, table := range schema.Tables {
|
|
for _, index := range table.Indexes {
|
|
// Skip primary key indexes
|
|
if strings.HasSuffix(index.Name, "_pkey") {
|
|
continue
|
|
}
|
|
|
|
uniqueStr := ""
|
|
if index.Unique {
|
|
uniqueStr = "UNIQUE "
|
|
}
|
|
|
|
indexType := index.Type
|
|
if indexType == "" {
|
|
indexType = "btree"
|
|
}
|
|
|
|
whereClause := ""
|
|
if index.Where != "" {
|
|
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
|
}
|
|
|
|
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
|
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(index.Columns, ", "), whereClause)
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
|
|
// Phase 6: Foreign keys
|
|
for _, table := range schema.Tables {
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type != models.ForeignKeyConstraint {
|
|
continue
|
|
}
|
|
|
|
refSchema := constraint.ReferencedSchema
|
|
if refSchema == "" {
|
|
refSchema = schema.Name
|
|
}
|
|
|
|
onDelete := constraint.OnDelete
|
|
if onDelete == "" {
|
|
onDelete = "NO ACTION"
|
|
}
|
|
|
|
onUpdate := constraint.OnUpdate
|
|
if onUpdate == "" {
|
|
onUpdate = "NO ACTION"
|
|
}
|
|
|
|
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
|
|
schema.SQLName(), table.SQLName(), constraint.Name,
|
|
strings.Join(constraint.Columns, ", "),
|
|
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
|
strings.Join(constraint.ReferencedColumns, ", "),
|
|
onDelete, onUpdate)
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
|
|
// Phase 7: Comments
|
|
for _, table := range schema.Tables {
|
|
if table.Comment != "" {
|
|
stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'",
|
|
schema.SQLName(), table.SQLName(), escapeQuote(table.Comment))
|
|
statements = append(statements, stmt)
|
|
}
|
|
|
|
for _, column := range table.Columns {
|
|
if column.Comment != "" {
|
|
stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'",
|
|
schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment))
|
|
statements = append(statements, stmt)
|
|
}
|
|
}
|
|
}
|
|
|
|
return statements, nil
|
|
}
|
|
|
|
// generateCreateTableStatement generates CREATE TABLE statement
|
|
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
|
statements := []string{}
|
|
|
|
// Sort columns by sequence or name
|
|
columns := make([]*models.Column, 0, len(table.Columns))
|
|
for _, col := range table.Columns {
|
|
columns = append(columns, col)
|
|
}
|
|
sort.Slice(columns, func(i, j int) bool {
|
|
if columns[i].Sequence != columns[j].Sequence {
|
|
return columns[i].Sequence < columns[j].Sequence
|
|
}
|
|
return columns[i].Name < columns[j].Name
|
|
})
|
|
|
|
columnDefs := []string{}
|
|
for _, col := range columns {
|
|
def := w.generateColumnDefinition(col)
|
|
columnDefs = append(columnDefs, " "+def)
|
|
}
|
|
|
|
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
|
|
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
|
statements = append(statements, stmt)
|
|
|
|
return statements, nil
|
|
}
|
|
|
|
// generateColumnDefinition generates column definition
|
|
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
|
parts := []string{col.SQLName()}
|
|
|
|
// Type with length/precision
|
|
typeStr := col.Type
|
|
if col.Length > 0 && col.Precision == 0 {
|
|
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
|
|
} else if col.Precision > 0 {
|
|
if col.Scale > 0 {
|
|
typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
|
|
} else {
|
|
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
|
|
}
|
|
}
|
|
parts = append(parts, typeStr)
|
|
|
|
// NOT NULL
|
|
if col.NotNull {
|
|
parts = append(parts, "NOT NULL")
|
|
}
|
|
|
|
// DEFAULT
|
|
if col.Default != nil {
|
|
switch v := col.Default.(type) {
|
|
case string:
|
|
if strings.HasPrefix(v, "nextval") || strings.HasPrefix(v, "CURRENT_") || strings.Contains(v, "()") {
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
|
} else if v == "true" || v == "false" {
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
|
} else {
|
|
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(v)))
|
|
}
|
|
case bool:
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
|
default:
|
|
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
|
}
|
|
}
|
|
|
|
return strings.Join(parts, " ")
|
|
}
|
|
|
|
// WriteSchema writes a single schema and all its tables
|
|
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
|
if w.writer == nil {
|
|
w.writer = os.Stdout
|
|
}
|
|
|
|
// Phase 1: Create schema (priority 1)
|
|
if err := w.writeCreateSchema(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 2: Create sequences (priority 80)
|
|
if err := w.writeSequences(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 3: Create tables with columns (priority 100)
|
|
if err := w.writeCreateTables(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 4: Create primary keys (priority 160)
|
|
if err := w.writePrimaryKeys(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 5: Create indexes (priority 180)
|
|
if err := w.writeIndexes(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 6: Create foreign key constraints (priority 195)
|
|
if err := w.writeForeignKeys(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 7: Set sequence values (priority 200)
|
|
if err := w.writeSetSequenceValues(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Phase 8: Add comments (priority 200+)
|
|
if err := w.writeComments(schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// WriteTable writes a single table with all its elements
|
|
func (w *Writer) WriteTable(table *models.Table) error {
|
|
if w.writer == nil {
|
|
w.writer = os.Stdout
|
|
}
|
|
|
|
// Create a temporary schema with just this table
|
|
schema := models.InitSchema(table.Schema)
|
|
schema.Tables = append(schema.Tables, table)
|
|
|
|
return w.WriteSchema(schema)
|
|
}
|
|
|
|
// writeCreateSchema generates CREATE SCHEMA statement
|
|
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
|
|
if schema.Name == "public" {
|
|
// public schema exists by default
|
|
return nil
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "-- Schema: %s\n", schema.Name)
|
|
fmt.Fprintf(w.writer, "CREATE SCHEMA IF NOT EXISTS %s;\n\n", schema.SQLName())
|
|
return nil
|
|
}
|
|
|
|
// writeSequences generates CREATE SEQUENCE statements for identity columns
|
|
func (w *Writer) writeSequences(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
pk := table.GetPrimaryKey()
|
|
if pk == nil {
|
|
continue
|
|
}
|
|
|
|
// Only create sequences for integer-type PKs with identity
|
|
if !isIntegerType(pk.Type) {
|
|
continue
|
|
}
|
|
|
|
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
|
fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n",
|
|
schema.SQLName(), seqName)
|
|
fmt.Fprintf(w.writer, " INCREMENT 1\n")
|
|
fmt.Fprintf(w.writer, " MINVALUE 1\n")
|
|
fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n")
|
|
fmt.Fprintf(w.writer, " START 1\n")
|
|
fmt.Fprintf(w.writer, " CACHE 1;\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeCreateTables generates CREATE TABLE statements
|
|
func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Tables for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
fmt.Fprintf(w.writer, "CREATE TABLE IF NOT EXISTS %s.%s (\n",
|
|
schema.SQLName(), table.SQLName())
|
|
|
|
// Write columns
|
|
columns := getSortedColumns(table.Columns)
|
|
columnDefs := make([]string, 0, len(columns))
|
|
|
|
for _, col := range columns {
|
|
colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type)
|
|
|
|
// Add default value if present
|
|
if col.Default != "" {
|
|
colDef += fmt.Sprintf(" DEFAULT %s", col.Default)
|
|
}
|
|
|
|
columnDefs = append(columnDefs, colDef)
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "%s\n", strings.Join(columnDefs, ",\n"))
|
|
fmt.Fprintf(w.writer, ");\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writePrimaryKeys generates ALTER TABLE statements for primary keys
|
|
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Find primary key constraint
|
|
var pkConstraint *models.Constraint
|
|
for name, constraint := range table.Constraints {
|
|
if constraint.Type == models.PrimaryKeyConstraint {
|
|
pkConstraint = constraint
|
|
_ = name // Use the name variable
|
|
break
|
|
}
|
|
}
|
|
|
|
if pkConstraint == nil {
|
|
// No explicit PK constraint, skip
|
|
continue
|
|
}
|
|
|
|
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
|
|
|
|
// Build column list
|
|
columnNames := make([]string, 0, len(pkConstraint.Columns))
|
|
for _, colName := range pkConstraint.Columns {
|
|
if col, ok := table.Columns[colName]; ok {
|
|
columnNames = append(columnNames, col.SQLName())
|
|
}
|
|
}
|
|
|
|
if len(columnNames) == 0 {
|
|
continue
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
|
|
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
|
|
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
|
|
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
|
|
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
|
|
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName)
|
|
fmt.Fprintf(w.writer, " ) THEN\n")
|
|
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n",
|
|
pkName, strings.Join(columnNames, ", "))
|
|
fmt.Fprintf(w.writer, " END IF;\n")
|
|
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeIndexes generates CREATE INDEX statements
|
|
func (w *Writer) writeIndexes(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Indexes for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Sort indexes by name for consistent output
|
|
indexNames := make([]string, 0, len(table.Indexes))
|
|
for name := range table.Indexes {
|
|
indexNames = append(indexNames, name)
|
|
}
|
|
sort.Strings(indexNames)
|
|
|
|
for _, name := range indexNames {
|
|
index := table.Indexes[name]
|
|
|
|
// Skip if it's a primary key index (based on name convention or columns)
|
|
// Primary keys are handled separately
|
|
if strings.HasPrefix(strings.ToLower(index.Name), "pk_") {
|
|
continue
|
|
}
|
|
|
|
indexName := index.Name
|
|
if indexName == "" {
|
|
indexType := "idx"
|
|
if index.Unique {
|
|
indexType = "uk"
|
|
}
|
|
indexName = fmt.Sprintf("%s_%s_%s", indexType, schema.SQLName(), table.SQLName())
|
|
}
|
|
|
|
// Build column list
|
|
columnNames := make([]string, 0, len(index.Columns))
|
|
for _, colName := range index.Columns {
|
|
if col, ok := table.Columns[colName]; ok {
|
|
columnNames = append(columnNames, col.SQLName())
|
|
}
|
|
}
|
|
|
|
if len(columnNames) == 0 {
|
|
continue
|
|
}
|
|
|
|
unique := ""
|
|
if index.Unique {
|
|
unique = "UNIQUE "
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "CREATE %sINDEX IF NOT EXISTS %s\n",
|
|
unique, indexName)
|
|
fmt.Fprintf(w.writer, " ON %s.%s USING btree (%s);\n\n",
|
|
schema.SQLName(), table.SQLName(), strings.Join(columnNames, ", "))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeForeignKeys generates ALTER TABLE statements for foreign keys
|
|
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Sort relationships by name for consistent output
|
|
relNames := make([]string, 0, len(table.Relationships))
|
|
for name := range table.Relationships {
|
|
relNames = append(relNames, name)
|
|
}
|
|
sort.Strings(relNames)
|
|
|
|
for _, name := range relNames {
|
|
rel := table.Relationships[name]
|
|
|
|
// For relationships, we need to look up the foreign key constraint
|
|
// that defines the actual column mappings
|
|
fkName := rel.ForeignKey
|
|
if fkName == "" {
|
|
fkName = name
|
|
}
|
|
if fkName == "" {
|
|
fkName = fmt.Sprintf("fk_%s_%s", table.SQLName(), rel.ToTable)
|
|
}
|
|
|
|
// Find the foreign key constraint that matches this relationship
|
|
var fkConstraint *models.Constraint
|
|
for _, constraint := range table.Constraints {
|
|
if constraint.Type == models.ForeignKeyConstraint &&
|
|
(constraint.Name == fkName || constraint.ReferencedTable == rel.ToTable) {
|
|
fkConstraint = constraint
|
|
break
|
|
}
|
|
}
|
|
|
|
// If no constraint found, skip this relationship
|
|
if fkConstraint == nil {
|
|
continue
|
|
}
|
|
|
|
// Build column lists from the constraint
|
|
sourceColumns := make([]string, 0, len(fkConstraint.Columns))
|
|
for _, colName := range fkConstraint.Columns {
|
|
if col, ok := table.Columns[colName]; ok {
|
|
sourceColumns = append(sourceColumns, col.SQLName())
|
|
}
|
|
}
|
|
|
|
targetColumns := make([]string, 0, len(fkConstraint.ReferencedColumns))
|
|
for _, colName := range fkConstraint.ReferencedColumns {
|
|
targetColumns = append(targetColumns, strings.ToLower(colName))
|
|
}
|
|
|
|
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
|
|
continue
|
|
}
|
|
|
|
onDelete := "NO ACTION"
|
|
if fkConstraint.OnDelete != "" {
|
|
onDelete = strings.ToUpper(fkConstraint.OnDelete)
|
|
}
|
|
|
|
onUpdate := "NO ACTION"
|
|
if fkConstraint.OnUpdate != "" {
|
|
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
|
|
fmt.Fprintf(w.writer, "\n")
|
|
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
|
|
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
|
|
|
|
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
|
|
refSchema := fkConstraint.ReferencedSchema
|
|
if refSchema == "" {
|
|
refSchema = rel.ToSchema
|
|
}
|
|
refTable := fkConstraint.ReferencedTable
|
|
if refTable == "" {
|
|
refTable = rel.ToTable
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
|
|
refSchema, refTable, strings.Join(targetColumns, ", "))
|
|
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
|
|
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
|
|
fmt.Fprintf(w.writer, " DEFERRABLE;\n\n")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeSetSequenceValues generates statements to set sequence current values
|
|
func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Set sequence values for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
pk := table.GetPrimaryKey()
|
|
if pk == nil || !isIntegerType(pk.Type) {
|
|
continue
|
|
}
|
|
|
|
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
|
|
|
|
fmt.Fprintf(w.writer, "DO $$\n")
|
|
fmt.Fprintf(w.writer, "DECLARE\n")
|
|
fmt.Fprintf(w.writer, " m_cnt bigint;\n")
|
|
fmt.Fprintf(w.writer, "BEGIN\n")
|
|
fmt.Fprintf(w.writer, " IF EXISTS (\n")
|
|
fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n")
|
|
fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n")
|
|
fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName)
|
|
fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name)
|
|
fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n")
|
|
fmt.Fprintf(w.writer, " ) THEN\n")
|
|
fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName())
|
|
fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName())
|
|
fmt.Fprintf(w.writer, " INTO m_cnt;\n")
|
|
fmt.Fprintf(w.writer, " \n")
|
|
fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n",
|
|
schema.SQLName(), seqName)
|
|
fmt.Fprintf(w.writer, " END IF;\n")
|
|
fmt.Fprintf(w.writer, "END;\n")
|
|
fmt.Fprintf(w.writer, "$$;\n\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeComments generates COMMENT statements for tables and columns
|
|
func (w *Writer) writeComments(schema *models.Schema) error {
|
|
fmt.Fprintf(w.writer, "-- Comments for schema: %s\n", schema.Name)
|
|
|
|
for _, table := range schema.Tables {
|
|
// Table comment
|
|
if table.Description != "" {
|
|
fmt.Fprintf(w.writer, "COMMENT ON TABLE %s.%s IS '%s';\n",
|
|
schema.SQLName(), table.SQLName(),
|
|
escapeQuote(table.Description))
|
|
}
|
|
|
|
// Column comments
|
|
for _, col := range getSortedColumns(table.Columns) {
|
|
if col.Description != "" {
|
|
fmt.Fprintf(w.writer, "COMMENT ON COLUMN %s.%s.%s IS '%s';\n",
|
|
schema.SQLName(), table.SQLName(), col.SQLName(),
|
|
escapeQuote(col.Description))
|
|
}
|
|
}
|
|
|
|
fmt.Fprintf(w.writer, "\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// getSortedColumns returns columns sorted by name
|
|
func getSortedColumns(columns map[string]*models.Column) []*models.Column {
|
|
names := make([]string, 0, len(columns))
|
|
for name := range columns {
|
|
names = append(names, name)
|
|
}
|
|
sort.Strings(names)
|
|
|
|
sorted := make([]*models.Column, 0, len(columns))
|
|
for _, name := range names {
|
|
sorted = append(sorted, columns[name])
|
|
}
|
|
return sorted
|
|
}
|
|
|
|
// isIntegerType reports whether colType names an integer column type.
// Accepted names:
//   - PostgreSQL integers: smallint, integer, int, bigint, int2, int4, int8;
//   - serial pseudo-types: smallserial, serial, bigserial, serial2/4/8;
//   - Go-style names: int16, int32, int64.
//
// Matching is case-insensitive and ignores surrounding whitespace, a trailing
// "(...)" modifier and any words after the base name (e.g. "int(11)",
// "bigint unsigned"). The base name must match exactly, so types that merely
// start with "int", such as "interval" or "int4range", are not integers.
func isIntegerType(colType string) bool {
	base := strings.ToLower(strings.TrimSpace(colType))
	if i := strings.IndexAny(base, "( \t"); i >= 0 {
		base = base[:i]
	}

	switch base {
	case "smallint", "integer", "int", "bigint",
		"int2", "int4", "int8",
		"int16", "int32", "int64",
		"smallserial", "serial", "bigserial",
		"serial2", "serial4", "serial8":
		return true
	}
	return false
}
|
|
|
|
// escapeQuote makes s safe to embed inside a single-quoted SQL string literal
// by doubling every single quote (' becomes ''). Nothing else is changed.
func escapeQuote(s string) string {
	if !strings.Contains(s, "'") {
		return s
	}
	return strings.Join(strings.Split(s, "'"), "''")
}
|
|
|
|
// extractSequenceName returns the sequence name found between the first pair
// of single quotes in a nextval() default expression, minus any qualifier up
// to and including the last dot. For example,
// "nextval('public.users_id_seq'::regclass)" yields "users_id_seq".
// It returns "" when no complete quoted section exists.
func extractSequenceName(defaultExpr string) string {
	open := strings.IndexByte(defaultExpr, '\'')
	if open < 0 {
		return ""
	}
	rest := defaultExpr[open+1:]

	closing := strings.IndexByte(rest, '\'')
	if closing < 0 {
		return ""
	}
	quoted := rest[:closing]

	// LastIndexByte yields -1 when there is no dot, which keeps the whole name.
	return quoted[strings.LastIndexByte(quoted, '.')+1:]
}
|