package pgsql import ( "fmt" "io" "os" "sort" "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/writers" ) // Writer implements the Writer interface for PostgreSQL SQL output type Writer struct { options *writers.WriterOptions writer io.Writer } // NewWriter creates a new PostgreSQL SQL writer func NewWriter(options *writers.WriterOptions) *Writer { return &Writer{ options: options, } } // WriteDatabase writes the entire database schema as SQL func (w *Writer) WriteDatabase(db *models.Database) error { var writer io.Writer var file *os.File var err error // Use existing writer if already set (for testing) if w.writer != nil { writer = w.writer } else if w.options.OutputPath != "" { // Determine output destination file, err = os.Create(w.options.OutputPath) if err != nil { return fmt.Errorf("failed to create output file: %w", err) } defer file.Close() writer = file } else { writer = os.Stdout } w.writer = writer // Write header comment fmt.Fprintf(w.writer, "-- PostgreSQL Database Schema\n") fmt.Fprintf(w.writer, "-- Database: %s\n", db.Name) fmt.Fprintf(w.writer, "-- Generated by RelSpec\n\n") // Process each schema in the database for _, schema := range db.Schemas { if err := w.WriteSchema(schema); err != nil { return fmt.Errorf("failed to write schema %s: %w", schema.Name, err) } } return nil } // GenerateDatabaseStatements generates SQL statements as a list for the entire database // Returns a slice of SQL statements that can be executed independently func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, error) { statements := []string{} // Add header comment statements = append(statements, "-- PostgreSQL Database Schema") statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name)) statements = append(statements, "-- Generated by RelSpec") // Process each schema in the database for _, schema := range db.Schemas { schemaStatements, err := w.GenerateSchemaStatements(schema) if err != nil { return nil, fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err) } statements = append(statements, schemaStatements...) } return statements, nil } // GenerateSchemaStatements generates SQL statements as a list for a single schema func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) { statements := []string{} // Phase 1: Create schema if schema.Name != "public" { statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name)) statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName())) } // Phase 2: Create sequences for _, table := range schema.Tables { pk := table.GetPrimaryKey() if pk == nil || !isIntegerType(pk.Type) || pk.Default == "" { continue } defaultStr, ok := pk.Default.(string) if !ok || !strings.Contains(strings.ToLower(defaultStr), "nextval") { continue } seqName := extractSequenceName(defaultStr) if seqName == "" { continue } stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1", schema.SQLName(), seqName) statements = append(statements, stmt) } // Phase 3: Create tables for _, table := range schema.Tables { stmts, err := w.generateCreateTableStatement(schema, table) if err != nil { return nil, fmt.Errorf("failed to generate table %s: %w", table.Name, err) } statements = append(statements, stmts...) } // Phase 4: Primary keys for _, table := range schema.Tables { for _, constraint := range table.Constraints { if constraint.Type != models.PrimaryKeyConstraint { continue } stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)", schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(constraint.Columns, ", ")) statements = append(statements, stmt) } } // Phase 5: Indexes for _, table := range schema.Tables { for _, index := range table.Indexes { // Skip primary key indexes if strings.HasSuffix(index.Name, "_pkey") { continue } uniqueStr := "" if index.Unique { uniqueStr = "UNIQUE " } indexType := index.Type if indexType == "" { indexType = "btree" } whereClause := "" if index.Where != "" { whereClause = fmt.Sprintf(" WHERE %s", index.Where) } stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s", uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(index.Columns, ", "), whereClause) statements = append(statements, stmt) } } // Phase 6: Foreign keys for _, table := range schema.Tables { for _, constraint := range table.Constraints { if constraint.Type != models.ForeignKeyConstraint { continue } refSchema := constraint.ReferencedSchema if refSchema == "" { refSchema = schema.Name } onDelete := constraint.OnDelete if onDelete == "" { onDelete = "NO ACTION" } onUpdate := constraint.OnUpdate if onUpdate == "" { onUpdate = "NO ACTION" } stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s", schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(constraint.Columns, ", "), strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable), strings.Join(constraint.ReferencedColumns, ", "), onDelete, onUpdate) statements = append(statements, stmt) } } // Phase 7: Comments for _, table := range schema.Tables { if table.Comment != "" { stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'", schema.SQLName(), table.SQLName(), escapeQuote(table.Comment)) statements = append(statements, stmt) } for _, column := range table.Columns { if column.Comment != "" { stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'", schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment)) statements = append(statements, stmt) } } } return statements, nil } // generateCreateTableStatement generates CREATE TABLE statement func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) { statements := []string{} // Sort columns by sequence or name columns := make([]*models.Column, 0, len(table.Columns)) for _, col := range table.Columns { columns = append(columns, col) } sort.Slice(columns, func(i, j int) bool { if columns[i].Sequence != columns[j].Sequence { return columns[i].Sequence < columns[j].Sequence } return columns[i].Name < columns[j].Name }) columnDefs := []string{} for _, col := range columns { def := w.generateColumnDefinition(col) columnDefs = append(columnDefs, " "+def) } stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)", schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n")) statements = append(statements, stmt) return statements, nil } // generateColumnDefinition generates column definition func (w *Writer) generateColumnDefinition(col *models.Column) string { parts := []string{col.SQLName()} // Type with length/precision typeStr := col.Type if col.Length > 0 && col.Precision == 0 { typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length) } else if col.Precision > 0 { if col.Scale > 0 { typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale) } else { typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision) } } parts = append(parts, typeStr) // NOT NULL if col.NotNull { parts = append(parts, "NOT NULL") } // DEFAULT if col.Default != nil { switch v := col.Default.(type) { case string: if strings.HasPrefix(v, "nextval") || strings.HasPrefix(v, "CURRENT_") || strings.Contains(v, "()") { parts = append(parts, fmt.Sprintf("DEFAULT %s", v)) } else if v == "true" || v == "false" { parts = append(parts, fmt.Sprintf("DEFAULT %s", v)) } else { parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(v))) } case bool: parts = append(parts, fmt.Sprintf("DEFAULT %v", v)) default: parts = append(parts, fmt.Sprintf("DEFAULT %v", v)) } } return strings.Join(parts, " ") } // WriteSchema writes a single schema and all its tables func (w *Writer) WriteSchema(schema *models.Schema) error { if w.writer == nil { w.writer = os.Stdout } // Phase 1: Create schema (priority 1) if err := w.writeCreateSchema(schema); err != nil { return err } // Phase 2: Create sequences (priority 80) if err := w.writeSequences(schema); err != nil { return err } // Phase 3: Create tables with columns (priority 100) if err := w.writeCreateTables(schema); err != nil { return err } // Phase 4: Create primary keys (priority 160) if err := w.writePrimaryKeys(schema); err != nil { return err } // Phase 5: Create indexes (priority 180) if err := w.writeIndexes(schema); err != nil { return err } // Phase 6: Create foreign key constraints (priority 195) if err := w.writeForeignKeys(schema); err != nil { return err } // Phase 7: Set sequence values (priority 200) if err := w.writeSetSequenceValues(schema); err != nil { return err } // Phase 8: Add comments (priority 200+) if err := w.writeComments(schema); err != nil { return err } return nil } // WriteTable writes a single table with all its elements func (w *Writer) WriteTable(table *models.Table) error { if w.writer == nil { w.writer = os.Stdout } // Create a temporary schema with just this table schema := models.InitSchema(table.Schema) schema.Tables = append(schema.Tables, table) return w.WriteSchema(schema) } // writeCreateSchema generates CREATE SCHEMA statement func (w *Writer) writeCreateSchema(schema *models.Schema) error { if schema.Name == "public" { // public schema exists by default return nil } fmt.Fprintf(w.writer, "-- Schema: %s\n", schema.Name) fmt.Fprintf(w.writer, "CREATE SCHEMA IF NOT EXISTS %s;\n\n", schema.SQLName()) return nil } // writeSequences generates CREATE SEQUENCE statements for identity columns func (w *Writer) writeSequences(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Sequences for schema: %s\n", schema.Name) for _, table := range schema.Tables { pk := table.GetPrimaryKey() if pk == nil { continue } // Only create sequences for integer-type PKs with identity if !isIntegerType(pk.Type) { continue } seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName()) fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n", schema.SQLName(), seqName) fmt.Fprintf(w.writer, " INCREMENT 1\n") fmt.Fprintf(w.writer, " MINVALUE 1\n") fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n") fmt.Fprintf(w.writer, " START 1\n") fmt.Fprintf(w.writer, " CACHE 1;\n\n") } return nil } // writeCreateTables generates CREATE TABLE statements func (w *Writer) writeCreateTables(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Tables for schema: %s\n", schema.Name) for _, table := range schema.Tables { fmt.Fprintf(w.writer, "CREATE TABLE IF NOT EXISTS %s.%s (\n", schema.SQLName(), table.SQLName()) // Write columns columns := getSortedColumns(table.Columns) columnDefs := make([]string, 0, len(columns)) for _, col := range columns { colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type) // Add default value if present if col.Default != "" { colDef += fmt.Sprintf(" DEFAULT %s", col.Default) } columnDefs = append(columnDefs, colDef) } fmt.Fprintf(w.writer, "%s\n", strings.Join(columnDefs, ",\n")) fmt.Fprintf(w.writer, ");\n\n") } return nil } // writePrimaryKeys generates ALTER TABLE statements for primary keys func (w *Writer) writePrimaryKeys(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name) for _, table := range schema.Tables { // Find primary key constraint var pkConstraint *models.Constraint for name, constraint := range table.Constraints { if constraint.Type == models.PrimaryKeyConstraint { pkConstraint = constraint _ = name // Use the name variable break } } if pkConstraint == nil { // No explicit PK constraint, skip continue } pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName()) // Build column list columnNames := make([]string, 0, len(pkConstraint.Columns)) for _, colName := range pkConstraint.Columns { if col, ok := table.Columns[colName]; ok { columnNames = append(columnNames, col.SQLName()) } } if len(columnNames) == 0 { continue } fmt.Fprintf(w.writer, "DO $$\nBEGIN\n") fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n") fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name) fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName) fmt.Fprintf(w.writer, " ) THEN\n") fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n", pkName, strings.Join(columnNames, ", ")) fmt.Fprintf(w.writer, " END IF;\n") fmt.Fprintf(w.writer, "END;\n$$;\n\n") } return nil } // writeIndexes generates CREATE INDEX statements func (w *Writer) writeIndexes(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Indexes for schema: %s\n", schema.Name) for _, table := range schema.Tables { // Sort indexes by name for consistent output indexNames := make([]string, 0, len(table.Indexes)) for name := range table.Indexes { indexNames = append(indexNames, name) } sort.Strings(indexNames) for _, name := range indexNames { index := table.Indexes[name] // Skip if it's a primary key index (based on name convention or columns) // Primary keys are handled separately if strings.HasPrefix(strings.ToLower(index.Name), "pk_") { continue } indexName := index.Name if indexName == "" { indexType := "idx" if index.Unique { indexType = "uk" } indexName = fmt.Sprintf("%s_%s_%s", indexType, schema.SQLName(), table.SQLName()) } // Build column list columnNames := make([]string, 0, len(index.Columns)) for _, colName := range index.Columns { if col, ok := table.Columns[colName]; ok { columnNames = append(columnNames, col.SQLName()) } } if len(columnNames) == 0 { continue } unique := "" if index.Unique { unique = "UNIQUE " } fmt.Fprintf(w.writer, "CREATE %sINDEX IF NOT EXISTS %s\n", unique, indexName) fmt.Fprintf(w.writer, " ON %s.%s USING btree (%s);\n\n", schema.SQLName(), table.SQLName(), strings.Join(columnNames, ", ")) } } return nil } // writeForeignKeys generates ALTER TABLE statements for foreign keys func (w *Writer) writeForeignKeys(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name) for _, table := range schema.Tables { // Sort relationships by name for consistent output relNames := make([]string, 0, len(table.Relationships)) for name := range table.Relationships { relNames = append(relNames, name) } sort.Strings(relNames) for _, name := range relNames { rel := table.Relationships[name] // For relationships, we need to look up the foreign key constraint // that defines the actual column mappings fkName := rel.ForeignKey if fkName == "" { fkName = name } if fkName == "" { fkName = fmt.Sprintf("fk_%s_%s", table.SQLName(), rel.ToTable) } // Find the foreign key constraint that matches this relationship var fkConstraint *models.Constraint for _, constraint := range table.Constraints { if constraint.Type == models.ForeignKeyConstraint && (constraint.Name == fkName || constraint.ReferencedTable == rel.ToTable) { fkConstraint = constraint break } } // If no constraint found, skip this relationship if fkConstraint == nil { continue } // Build column lists from the constraint sourceColumns := make([]string, 0, len(fkConstraint.Columns)) for _, colName := range fkConstraint.Columns { if col, ok := table.Columns[colName]; ok { sourceColumns = append(sourceColumns, col.SQLName()) } } targetColumns := make([]string, 0, len(fkConstraint.ReferencedColumns)) for _, colName := range fkConstraint.ReferencedColumns { targetColumns = append(targetColumns, strings.ToLower(colName)) } if len(sourceColumns) == 0 || len(targetColumns) == 0 { continue } onDelete := "NO ACTION" if fkConstraint.OnDelete != "" { onDelete = strings.ToUpper(fkConstraint.OnDelete) } onUpdate := "NO ACTION" if fkConstraint.OnUpdate != "" { onUpdate = strings.ToUpper(fkConstraint.OnUpdate) } fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName) fmt.Fprintf(w.writer, "\n") fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName) fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", ")) // Use constraint's referenced schema/table or relationship's ToSchema/ToTable refSchema := fkConstraint.ReferencedSchema if refSchema == "" { refSchema = rel.ToSchema } refTable := fkConstraint.ReferencedTable if refTable == "" { refTable = rel.ToTable } fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n", refSchema, refTable, strings.Join(targetColumns, ", ")) fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete) fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate) fmt.Fprintf(w.writer, " DEFERRABLE;\n\n") } } return nil } // writeSetSequenceValues generates statements to set sequence current values func (w *Writer) writeSetSequenceValues(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Set sequence values for schema: %s\n", schema.Name) for _, table := range schema.Tables { pk := table.GetPrimaryKey() if pk == nil || !isIntegerType(pk.Type) { continue } seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName()) fmt.Fprintf(w.writer, "DO $$\n") fmt.Fprintf(w.writer, "DECLARE\n") fmt.Fprintf(w.writer, " m_cnt bigint;\n") fmt.Fprintf(w.writer, "BEGIN\n") fmt.Fprintf(w.writer, " IF EXISTS (\n") fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n") fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n") fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName) fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name) fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n") fmt.Fprintf(w.writer, " ) THEN\n") fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName()) fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName()) fmt.Fprintf(w.writer, " INTO m_cnt;\n") fmt.Fprintf(w.writer, " \n") fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n", schema.SQLName(), seqName) fmt.Fprintf(w.writer, " END IF;\n") fmt.Fprintf(w.writer, "END;\n") fmt.Fprintf(w.writer, "$$;\n\n") } return nil } // writeComments generates COMMENT statements for tables and columns func (w *Writer) writeComments(schema *models.Schema) error { fmt.Fprintf(w.writer, "-- Comments for schema: %s\n", schema.Name) for _, table := range schema.Tables { // Table comment if table.Description != "" { fmt.Fprintf(w.writer, "COMMENT ON TABLE %s.%s IS '%s';\n", schema.SQLName(), table.SQLName(), escapeQuote(table.Description)) } // Column comments for _, col := range getSortedColumns(table.Columns) { if col.Description != "" { fmt.Fprintf(w.writer, "COMMENT ON COLUMN %s.%s.%s IS '%s';\n", schema.SQLName(), table.SQLName(), col.SQLName(), escapeQuote(col.Description)) } } fmt.Fprintf(w.writer, "\n") } return nil } // Helper functions // getSortedColumns returns columns sorted by name func getSortedColumns(columns map[string]*models.Column) []*models.Column { names := make([]string, 0, len(columns)) for name := range columns { names = append(names, name) } sort.Strings(names) sorted := make([]*models.Column, 0, len(columns)) for _, name := range names { sorted = append(sorted, columns[name]) } return sorted } // isIntegerType checks if a column type is an integer type func isIntegerType(colType string) bool { intTypes := []string{"integer", "int", "bigint", "smallint", "serial", "bigserial"} lowerType := strings.ToLower(colType) for _, t := range intTypes { if strings.HasPrefix(lowerType, t) { return true } } return false } // escapeQuote escapes single quotes in strings for SQL func escapeQuote(s string) string { return strings.ReplaceAll(s, "'", "''") } // extractSequenceName extracts sequence name from nextval() expression // Example: "nextval('public.users_id_seq'::regclass)" returns "users_id_seq" func extractSequenceName(defaultExpr string) string { // Look for nextval('schema.sequence_name'::regclass) pattern start := strings.Index(defaultExpr, "'") if start == -1 { return "" } end := strings.Index(defaultExpr[start+1:], "'") if end == -1 { return "" } fullName := defaultExpr[start+1 : start+1+end] // Remove schema prefix if present parts := strings.Split(fullName, ".") if len(parts) > 1 { return parts[len(parts)-1] } return fullName }