7 Commits

Author SHA1 Message Date
165623bb1d feat(pgsql): Add templates for constraints and sequences
All checks were successful
CI / Test (1.24) (push) Successful in -26m21s
CI / Test (1.25) (push) Successful in -26m13s
CI / Build (push) Successful in -26m39s
CI / Lint (push) Successful in -26m29s
Release / Build and Release (push) Successful in -26m28s
Integration Tests / Integration Tests (push) Successful in -26m10s
* Introduce new templates for creating unique, check, and foreign key constraints with existence checks.
* Add templates for setting sequence values and creating sequences.
* Refactor existing SQL generation logic to utilize new templates for better maintainability and readability.
* Ensure identifiers are properly quoted to handle special characters and reserved keywords.
2026-01-31 21:04:43 +02:00
3c20c3c5d9 feat(writer): 🎉 Add support for check constraints in schema generation
All checks were successful
CI / Test (1.24) (push) Successful in -26m17s
CI / Test (1.25) (push) Successful in -26m14s
CI / Build (push) Successful in -26m41s
CI / Lint (push) Successful in -26m32s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m13s
* Implement check constraints in the schema writer.
* Generate SQL statements to add check constraints if they do not exist.
* Add tests to verify correct generation of check constraints.
2026-01-31 20:42:19 +02:00
a54594e49b feat(writer): 🎉 Add support for unique constraints in schema generation
All checks were successful
CI / Test (1.24) (push) Successful in -26m26s
CI / Test (1.25) (push) Successful in -26m18s
CI / Lint (push) Successful in -26m25s
CI / Build (push) Successful in -26m35s
Release / Build and Release (push) Successful in -26m29s
Integration Tests / Integration Tests (push) Successful in -26m11s
* Implement unique constraint handling in GenerateSchemaStatements
* Add writeUniqueConstraints method for generating SQL statements
* Create unit test for unique constraints in writer_test.go
2026-01-31 20:33:08 +02:00
cafe6a461f feat(scripts): 🎉 Add --ignore-errors flag for script execution
All checks were successful
CI / Test (1.24) (push) Successful in -26m18s
CI / Test (1.25) (push) Successful in -26m14s
CI / Build (push) Successful in -26m38s
CI / Lint (push) Successful in -26m30s
Release / Build and Release (push) Successful in -26m27s
Integration Tests / Integration Tests (push) Successful in -26m10s
- Allow continued execution of scripts even if errors occur.
- Update execution summary to include counts of successful and failed scripts.
- Enhance error handling and reporting for better visibility.
2026-01-31 20:21:22 +02:00
abdb9b4c78 feat(dbml/reader): 🎉 Implement splitIdentifier function for parsing
All checks were successful
CI / Test (1.24) (push) Successful in -26m24s
CI / Test (1.25) (push) Successful in -26m17s
CI / Build (push) Successful in -26m44s
CI / Lint (push) Successful in -26m33s
Integration Tests / Integration Tests (push) Successful in -26m11s
Release / Build and Release (push) Successful in -26m36s
2026-01-31 19:45:24 +02:00
e7a15c8e4f feat(writer): 🎉 Implement add column statements for schema evolution
All checks were successful
CI / Test (1.24) (push) Successful in -26m24s
CI / Test (1.25) (push) Successful in -26m14s
CI / Lint (push) Successful in -26m30s
CI / Build (push) Successful in -26m41s
Release / Build and Release (push) Successful in -26m29s
Integration Tests / Integration Tests (push) Successful in -26m13s
* Add functionality to generate ALTER TABLE ADD COLUMN statements for existing tables.
* Introduce tests for generating and writing add column statements.
* Enhance schema evolution capabilities when new columns are added.
2026-01-31 19:12:00 +02:00
c36b5ede2b feat(writer): 🎉 Enhance primary key handling and add tests
All checks were successful
CI / Test (1.24) (push) Successful in -26m18s
CI / Test (1.25) (push) Successful in -26m11s
CI / Build (push) Successful in -26m43s
CI / Lint (push) Successful in -26m34s
Release / Build and Release (push) Successful in -26m31s
Integration Tests / Integration Tests (push) Successful in -26m20s
* Implement checks for existing primary keys before adding new ones.
* Drop auto-generated primary keys if they exist.
* Add tests for primary key existence and column size specifiers.
* Improve type conversion handling for PostgreSQL compatibility.
2026-01-31 18:59:32 +02:00
28 changed files with 1712 additions and 129 deletions

View File

@@ -14,10 +14,11 @@ import (
)
var (
scriptsDir string
scriptsConn string
scriptsSchemaName string
scriptsDBName string
scriptsDir string
scriptsConn string
scriptsSchemaName string
scriptsDBName string
scriptsIgnoreErrors bool
)
var scriptsCmd = &cobra.Command{
@@ -62,7 +63,7 @@ var scriptsExecuteCmd = &cobra.Command{
Long: `Execute SQL scripts from a directory against a PostgreSQL database.
Scripts are executed in order: Priority (ascending), Sequence (ascending), Name (alphabetical).
Execution stops immediately on the first error.
By default, execution stops immediately on the first error. Use --ignore-errors to continue execution.
The directory is scanned recursively for all subdirectories and files matching the patterns:
{priority}_{sequence}_{name}.sql or .pgsql (underscore format)
@@ -86,7 +87,12 @@ Examples:
# Execute with SSL disabled
relspec scripts execute --dir ./sql \
--conn "postgres://user:pass@localhost/db?sslmode=disable"`,
--conn "postgres://user:pass@localhost/db?sslmode=disable"
# Continue executing even if errors occur
relspec scripts execute --dir ./migrations \
--conn "postgres://localhost/mydb" \
--ignore-errors`,
RunE: runScriptsExecute,
}
@@ -105,6 +111,7 @@ func init() {
scriptsExecuteCmd.Flags().StringVar(&scriptsConn, "conn", "", "PostgreSQL connection string (required)")
scriptsExecuteCmd.Flags().StringVar(&scriptsSchemaName, "schema", "public", "Schema name (optional, default: public)")
scriptsExecuteCmd.Flags().StringVar(&scriptsDBName, "database", "database", "Database name (optional, default: database)")
scriptsExecuteCmd.Flags().BoolVar(&scriptsIgnoreErrors, "ignore-errors", false, "Continue executing scripts even if errors occur")
err = scriptsExecuteCmd.MarkFlagRequired("dir")
if err != nil {
@@ -250,17 +257,39 @@ func runScriptsExecute(cmd *cobra.Command, args []string) error {
writer := sqlexec.NewWriter(&writers.WriterOptions{
Metadata: map[string]any{
"connection_string": scriptsConn,
"ignore_errors": scriptsIgnoreErrors,
},
})
if err := writer.WriteSchema(schema); err != nil {
fmt.Fprintf(os.Stderr, "\n")
return fmt.Errorf("execution failed: %w", err)
return fmt.Errorf("script execution failed: %w", err)
}
// Get execution results from writer metadata
totalCount := len(schema.Scripts)
successCount := totalCount
failedCount := 0
opts := writer.Options()
if total, exists := opts.Metadata["execution_total"].(int); exists {
totalCount = total
}
if success, exists := opts.Metadata["execution_success"].(int); exists {
successCount = success
}
if failed, exists := opts.Metadata["execution_failed"].(int); exists {
failedCount = failed
}
fmt.Fprintf(os.Stderr, "\n=== Execution Complete ===\n")
fmt.Fprintf(os.Stderr, "Completed at: %s\n", getCurrentTimestamp())
fmt.Fprintf(os.Stderr, "Successfully executed %d script(s)\n\n", len(schema.Scripts))
fmt.Fprintf(os.Stderr, "Total scripts: %d\n", totalCount)
fmt.Fprintf(os.Stderr, "Successful: %d\n", successCount)
if failedCount > 0 {
fmt.Fprintf(os.Stderr, "Failed: %d\n", failedCount)
}
fmt.Fprintf(os.Stderr, "\n")
return nil
}

View File

@@ -128,6 +128,46 @@ func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
return db, nil
}
// splitIdentifier breaks a dot-separated identifier into its segments,
// treating dots inside single- or double-quoted sections as literal text.
// Quote characters are kept in the returned segments.
// Example: `"schema.with.dots"."table"."column"` yields three segments.
func splitIdentifier(s string) []string {
	var (
		segments []string
		buf      strings.Builder
		quote    byte // active quote character; 0 while outside quotes
	)
	flush := func() {
		if buf.Len() > 0 {
			segments = append(segments, buf.String())
			buf.Reset()
		}
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case quote != 0:
			// Inside a quoted run: copy everything, close on the matching quote.
			buf.WriteByte(c)
			if c == quote {
				quote = 0
			}
		case c == '"' || c == '\'':
			quote = c
			buf.WriteByte(c)
		case c == '.':
			flush()
		default:
			buf.WriteByte(c)
		}
	}
	flush()
	return segments
}
// stripQuotes removes surrounding quotes and comments from an identifier
func stripQuotes(s string) string {
s = strings.TrimSpace(s)
@@ -409,7 +449,9 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) {
// Parse Table definition
if matches := tableRegex.FindStringSubmatch(line); matches != nil {
tableName := matches[1]
parts := strings.Split(tableName, ".")
// Strip comments/notes before parsing to avoid dots in notes
tableName = strings.TrimSpace(regexp.MustCompile(`\s*\[.*?\]\s*`).ReplaceAllString(tableName, ""))
parts := splitIdentifier(tableName)
if len(parts) == 2 {
currentSchema = stripQuotes(parts[0])
@@ -562,7 +604,7 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
} else if attr == "unique" {
// Create a unique constraint
uniqueConstraint := models.InitConstraint(
fmt.Sprintf("uq_%s", columnName),
fmt.Sprintf("uq_%s_%s", tableName, columnName),
models.UniqueConstraint,
)
uniqueConstraint.Schema = schemaName
@@ -610,8 +652,8 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
constraint.Table = tableName
constraint.Columns = []string{columnName}
}
// Generate short constraint name based on the column
constraint.Name = fmt.Sprintf("fk_%s", constraint.Columns[0])
// Generate constraint name based on table and columns
constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_"))
}
}
}
@@ -695,7 +737,11 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
// Generate name if not provided
if index.Name == "" {
index.Name = fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
prefix := "idx"
if index.Unique {
prefix = "uidx"
}
index.Name = fmt.Sprintf("%s_%s_%s", prefix, tableName, strings.Join(columns, "_"))
}
return index
@@ -755,10 +801,10 @@ func (r *Reader) parseRef(refStr string) *models.Constraint {
return nil
}
// Generate short constraint name based on the source column
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, toTable)
if len(fromColumns) > 0 {
constraintName = fmt.Sprintf("fk_%s", fromColumns[0])
// Generate constraint name based on table and columns
constraintName := fmt.Sprintf("fk_%s_%s", fromTable, strings.Join(fromColumns, "_"))
if len(fromColumns) == 0 {
constraintName = fmt.Sprintf("fk_%s_%s", fromTable, toTable)
}
constraint := models.InitConstraint(
@@ -814,7 +860,7 @@ func (r *Reader) parseTableRef(ref string) (schema, table string, columns []stri
}
// Parse schema, table, and optionally column
parts := strings.Split(strings.TrimSpace(ref), ".")
parts := splitIdentifier(strings.TrimSpace(ref))
if len(parts) == 3 {
// Format: "schema"."table"."column"
schema = stripQuotes(parts[0])

View File

@@ -777,6 +777,76 @@ func TestParseFilePrefix(t *testing.T) {
}
}
func TestConstraintNaming(t *testing.T) {
	// Verify that generated constraint and index names follow the expected
	// prefix conventions: uq_, fk_, idx_, uidx_.
	reader := NewReader(&readers.ReaderOptions{
		FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "complex.dbml"),
	})
	db, err := reader.ReadDatabase()
	if err != nil {
		t.Fatalf("ReadDatabase() error = %v", err)
	}

	// Locate the two tables exercised by the fixture.
	found := map[string]*models.Table{}
	for _, schema := range db.Schemas {
		for _, table := range schema.Tables {
			switch table.Name {
			case "users", "posts":
				found[table.Name] = table
			}
		}
	}
	usersTable, postsTable := found["users"], found["posts"]
	if usersTable == nil {
		t.Fatal("Users table not found")
	}
	if postsTable == nil {
		t.Fatal("Posts table not found")
	}

	// requireConstraint asserts that a named constraint exists on tbl.
	requireConstraint := func(tbl *models.Table, name, kind string) {
		if _, ok := tbl.Constraints[name]; !ok {
			t.Errorf("Expected %s '%s' not found", kind, name)
			t.Logf("Available constraints: %v", getKeys(tbl.Constraints))
		}
	}
	// requireIndex asserts that a named index exists on tbl.
	requireIndex := func(tbl *models.Table, name, kind string) {
		if _, ok := tbl.Indexes[name]; !ok {
			t.Errorf("Expected %s '%s' not found", kind, name)
			t.Logf("Available indexes: %v", getKeys(tbl.Indexes))
		}
	}

	// Unique constraints: uq_<table>_<column>
	requireConstraint(usersTable, "uq_users_email", "unique constraint")
	requireConstraint(postsTable, "uq_posts_slug", "unique constraint")
	// Foreign keys: fk_<table>_<column>
	requireConstraint(postsTable, "fk_posts_user_id", "foreign key")
	// Unique index: uidx_<table>_<columns>; regular index: idx_<table>_<columns>
	requireIndex(postsTable, "uidx_posts_slug", "unique index")
	requireIndex(postsTable, "idx_posts_user_id_published", "index")
}
func getKeys[V any](m map[string]V) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}
func TestHasCommentedRefs(t *testing.T) {
// Test with the actual multifile test fixtures
tests := []struct {

View File

@@ -215,3 +215,70 @@ func TestTemplateExecutor_AuditFunction(t *testing.T) {
t.Error("SQL missing DELETE handling")
}
}
func TestWriteMigration_NumericConstraintNames(t *testing.T) {
	// A constraint whose name begins with a digit must come out double-quoted
	// in the generated migration, or the SQL is syntactically invalid.
	const constraintName = "215162_fk_actor"

	// Current database state: empty schema.
	current := models.InitDatabase("testdb")
	current.Schemas = append(current.Schemas, models.InitSchema("entity"))

	// Desired model: one table carrying the numerically-named foreign key.
	table := models.InitTable("individual_actor_relationship", "entity")

	id := models.InitColumn("id", "individual_actor_relationship", "entity")
	id.Type = "integer"
	id.IsPrimaryKey = true
	table.Columns["id"] = id

	actorID := models.InitColumn("actor_id", "individual_actor_relationship", "entity")
	actorID.Type = "integer"
	table.Columns["actor_id"] = actorID

	table.Constraints[constraintName] = &models.Constraint{
		Name:              constraintName,
		Type:              models.ForeignKeyConstraint,
		Columns:           []string{"actor_id"},
		ReferencedSchema:  "entity",
		ReferencedTable:   "actor",
		ReferencedColumns: []string{"id"},
		OnDelete:          "CASCADE",
		OnUpdate:          "NO ACTION",
	}

	modelSchema := models.InitSchema("entity")
	modelSchema.Tables = append(modelSchema.Tables, table)
	model := models.InitDatabase("testdb")
	model.Schemas = append(model.Schemas, modelSchema)

	// Capture the migration output in memory.
	migWriter, err := NewMigrationWriter(&writers.WriterOptions{})
	if err != nil {
		t.Fatalf("Failed to create writer: %v", err)
	}
	var buf bytes.Buffer
	migWriter.writer = &buf

	if err := migWriter.WriteMigration(model, current); err != nil {
		t.Fatalf("WriteMigration failed: %v", err)
	}

	sql := buf.String()
	t.Logf("Generated migration:\n%s", sql)

	// The digit-leading name must be quoted, and the statement must carry the
	// required FK keywords.
	if !strings.Contains(sql, `"215162_fk_actor"`) {
		t.Error("Constraint name starting with number should be quoted")
	}
	if !strings.Contains(sql, "ADD CONSTRAINT") {
		t.Error("Migration missing ADD CONSTRAINT")
	}
	if !strings.Contains(sql, "FOREIGN KEY") {
		t.Error("Migration missing FOREIGN KEY")
	}
}

View File

@@ -21,6 +21,7 @@ func TemplateFunctions() map[string]interface{} {
"quote": quote,
"escape": escape,
"safe_identifier": safeIdentifier,
"quote_ident": quoteIdent,
// Type conversion
"goTypeToSQL": goTypeToSQL,
@@ -122,6 +123,43 @@ func safeIdentifier(s string) string {
return strings.ToLower(safe)
}
// quoteIdent quotes a PostgreSQL identifier when it cannot be used bare.
// An identifier needs quoting when it:
//   - starts with a digit,
//   - contains an uppercase letter (quoting preserves case), or
//   - contains any character other than letters, digits, and underscores.
//
// Embedded double quotes are escaped by doubling, per SQL rules.
//
// NOTE(review): reserved keywords are NOT detected here — an identifier such
// as "user" passes through unquoted. Callers that may emit keywords should
// quote unconditionally or extend this check.
func quoteIdent(s string) string {
	if s == "" {
		return `""`
	}
	if identNeedsQuoting(s) {
		// Escape double quotes by doubling them, then wrap in quotes.
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	return s
}

// identNeedsQuoting reports whether s is unsafe to emit as a bare
// PostgreSQL identifier (see quoteIdent for the criteria).
func identNeedsQuoting(s string) bool {
	// Leading digit forces quoting.
	if unicode.IsDigit(rune(s[0])) {
		return true
	}
	for _, r := range s {
		if unicode.IsUpper(r) {
			return true
		}
		if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
			return true
		}
	}
	return false
}
// Type conversion functions
// goTypeToSQL converts Go type to PostgreSQL type

View File

@@ -101,6 +101,31 @@ func TestSafeIdentifier(t *testing.T) {
}
}
func TestQuoteIdent(t *testing.T) {
	// Map of input identifier to expected (possibly quoted) output.
	cases := map[string]string{
		"valid_name":           "valid_name",
		"ValidName":            `"ValidName"`,
		"123column":            `"123column"`,
		"215162_fk_constraint": `"215162_fk_constraint"`,
		"user-id":              `"user-id"`,
		"user@domain":          `"user@domain"`,
		`"quoted"`:             `"""quoted"""`,
		"":                     `""`,
		"lowercase":            "lowercase",
		"with_underscore":      "with_underscore",
	}
	for input, expected := range cases {
		if got := quoteIdent(input); got != expected {
			t.Errorf("quoteIdent(%q) = %q, want %q", input, got, expected)
		}
	}
}
func TestGoTypeToSQL(t *testing.T) {
tests := []struct {
input string
@@ -243,7 +268,7 @@ func TestTemplateFunctions(t *testing.T) {
// Check that all expected functions are registered
expectedFuncs := []string{
"upper", "lower", "snake_case", "camelCase",
"indent", "quote", "escape", "safe_identifier",
"indent", "quote", "escape", "safe_identifier", "quote_ident",
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
"first", "last", "filter", "mapFunc", "join_with",
"join",

View File

@@ -177,6 +177,72 @@ type AuditTriggerData struct {
Events string
}
// CreateUniqueConstraintData contains data for create unique constraint template
type CreateUniqueConstraintData struct {
	SchemaName     string // schema owning the table
	TableName      string // table receiving the constraint
	ConstraintName string // name for the UNIQUE constraint
	Columns        string // comma-separated column list placed inside UNIQUE (...)
}

// CreateCheckConstraintData contains data for create check constraint template
type CreateCheckConstraintData struct {
	SchemaName     string // schema owning the table
	TableName      string // table receiving the constraint
	ConstraintName string // name for the CHECK constraint
	Expression     string // raw boolean expression placed inside CHECK (...)
}

// CreateForeignKeyWithCheckData contains data for create foreign key with existence check template
type CreateForeignKeyWithCheckData struct {
	SchemaName     string // schema of the referencing (child) table
	TableName      string // referencing (child) table
	ConstraintName string // name for the FOREIGN KEY constraint
	SourceColumns  string // comma-separated referencing columns
	TargetSchema   string // schema of the referenced (parent) table
	TargetTable    string // referenced (parent) table
	TargetColumns  string // comma-separated referenced columns
	OnDelete       string // ON DELETE action, e.g. "CASCADE"
	OnUpdate       string // ON UPDATE action, e.g. "NO ACTION"
	Deferrable     bool   // when true, the DEFERRABLE clause is emitted
}

// SetSequenceValueData contains data for set sequence value template
type SetSequenceValueData struct {
	SchemaName   string // schema containing both the sequence and the table
	TableName    string // table whose column drives the new sequence value
	SequenceName string // sequence to advance via setval
	ColumnName   string // column whose MAX() seeds the sequence
}

// CreateSequenceData contains data for create sequence template
type CreateSequenceData struct {
	SchemaName   string // schema to create the sequence in
	SequenceName string // sequence name
	Increment    int    // INCREMENT step
	MinValue     int64  // MINVALUE bound
	MaxValue     int64  // MAXVALUE bound
	StartValue   int64  // START value
	CacheSize    int    // CACHE size
}

// AddColumnWithCheckData contains data for add column with existence check template
type AddColumnWithCheckData struct {
	SchemaName       string // schema owning the table
	TableName        string // table to alter
	ColumnName       string // column name checked for existence before adding
	ColumnDefinition string // full column definition appended after ADD COLUMN
}

// CreatePrimaryKeyWithAutoGenCheckData contains data for primary key with auto-generated key check template
type CreatePrimaryKeyWithAutoGenCheckData struct {
	SchemaName     string // schema owning the table
	TableName      string // table to receive the primary key
	ConstraintName string // desired primary key constraint name
	AutoGenNames   string // Comma-separated list of names like "'name1', 'name2'"
	Columns        string // comma-separated primary key columns
}
// Execute methods for each template
// ExecuteCreateTable executes the create table template
@@ -319,6 +385,76 @@ func (te *TemplateExecutor) ExecuteAuditTrigger(data AuditTriggerData) (string,
return buf.String(), nil
}
// execNamed renders the template file tmplFile with data and returns the
// rendered SQL. label is the template's base name, used only to build the
// error message on failure.
func (te *TemplateExecutor) execNamed(tmplFile, label string, data any) (string, error) {
	var out bytes.Buffer
	if err := te.templates.ExecuteTemplate(&out, tmplFile, data); err != nil {
		return "", fmt.Errorf("failed to execute %s template: %w", label, err)
	}
	return out.String(), nil
}

// ExecuteCreateUniqueConstraint executes the create unique constraint template
func (te *TemplateExecutor) ExecuteCreateUniqueConstraint(data CreateUniqueConstraintData) (string, error) {
	return te.execNamed("create_unique_constraint.tmpl", "create_unique_constraint", data)
}

// ExecuteCreateCheckConstraint executes the create check constraint template
func (te *TemplateExecutor) ExecuteCreateCheckConstraint(data CreateCheckConstraintData) (string, error) {
	return te.execNamed("create_check_constraint.tmpl", "create_check_constraint", data)
}

// ExecuteCreateForeignKeyWithCheck executes the create foreign key with check template
func (te *TemplateExecutor) ExecuteCreateForeignKeyWithCheck(data CreateForeignKeyWithCheckData) (string, error) {
	return te.execNamed("create_foreign_key_with_check.tmpl", "create_foreign_key_with_check", data)
}

// ExecuteSetSequenceValue executes the set sequence value template
func (te *TemplateExecutor) ExecuteSetSequenceValue(data SetSequenceValueData) (string, error) {
	return te.execNamed("set_sequence_value.tmpl", "set_sequence_value", data)
}

// ExecuteCreateSequence executes the create sequence template
func (te *TemplateExecutor) ExecuteCreateSequence(data CreateSequenceData) (string, error) {
	return te.execNamed("create_sequence.tmpl", "create_sequence", data)
}

// ExecuteAddColumnWithCheck executes the add column with check template
func (te *TemplateExecutor) ExecuteAddColumnWithCheck(data AddColumnWithCheckData) (string, error) {
	return te.execNamed("add_column_with_check.tmpl", "add_column_with_check", data)
}

// ExecuteCreatePrimaryKeyWithAutoGenCheck executes the create primary key with auto-generated key check template
func (te *TemplateExecutor) ExecuteCreatePrimaryKeyWithAutoGenCheck(data CreatePrimaryKeyWithAutoGenCheckData) (string, error) {
	return te.execNamed("create_primary_key_with_autogen_check.tmpl", "create_primary_key_with_autogen_check", data)
}
// Helper functions to build template data from models
// BuildCreateTableData builds CreateTableData from a models.Table

View File

@@ -1,4 +1,4 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ADD COLUMN IF NOT EXISTS {{.ColumnName}} {{.ColumnType}}
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD COLUMN IF NOT EXISTS {{quote_ident .ColumnName}} {{.ColumnType}}
{{- if .Default}} DEFAULT {{.Default}}{{end}}
{{- if .NotNull}} NOT NULL{{end}};

View File

@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND column_name = '{{.ColumnName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD COLUMN {{.ColumnDefinition}};
END IF;
END;
$$;

View File

@@ -1,7 +1,7 @@
{{- if .SetDefault -}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} SET DEFAULT {{.DefaultValue}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} SET DEFAULT {{.DefaultValue}};
{{- else -}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} DROP DEFAULT;
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} DROP DEFAULT;
{{- end -}}

View File

@@ -1,2 +1,2 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ALTER COLUMN {{.ColumnName}} TYPE {{.NewType}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ALTER COLUMN {{quote_ident .ColumnName}} TYPE {{.NewType}};

View File

@@ -1 +1 @@
COMMENT ON COLUMN {{.SchemaName}}.{{.TableName}}.{{.ColumnName}} IS '{{.Comment}}';
COMMENT ON COLUMN {{quote_ident .SchemaName}}.{{quote_ident .TableName}}.{{quote_ident .ColumnName}} IS '{{.Comment}}';

View File

@@ -1 +1 @@
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS '{{.Comment}}';
COMMENT ON TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} IS '{{.Comment}}';

View File

@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} CHECK ({{.Expression}});
END IF;
END;
$$;

View File

@@ -1,10 +1,10 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}}
DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ADD CONSTRAINT {{.ConstraintName}}
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD CONSTRAINT {{quote_ident .ConstraintName}}
FOREIGN KEY ({{.SourceColumns}})
REFERENCES {{.TargetSchema}}.{{.TargetTable}} ({{.TargetColumns}})
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
ON DELETE {{.OnDelete}}
ON UPDATE {{.OnUpdate}}
DEFERRABLE;

View File

@@ -0,0 +1,18 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD CONSTRAINT {{quote_ident .ConstraintName}}
FOREIGN KEY ({{.SourceColumns}})
REFERENCES {{quote_ident .TargetSchema}}.{{quote_ident .TargetTable}} ({{.TargetColumns}})
ON DELETE {{.OnDelete}}
ON UPDATE {{.OnUpdate}}{{if .Deferrable}}
DEFERRABLE{{end}};
END IF;
END;
$$;

View File

@@ -1,2 +1,2 @@
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
ON {{.SchemaName}}.{{.TableName}} USING {{.IndexType}} ({{.Columns}});
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{quote_ident .IndexName}}
ON {{quote_ident .SchemaName}}.{{quote_ident .TableName}} USING {{.IndexType}} ({{.Columns}});

View File

@@ -6,8 +6,8 @@ BEGIN
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{.SchemaName}}.{{.TableName}}
ADD CONSTRAINT {{.ConstraintName}} PRIMARY KEY ({{.Columns}});
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
$$;

View File

@@ -0,0 +1,27 @@
DO $$
DECLARE
auto_pk_name text;
BEGIN
-- Drop auto-generated primary key if it exists
SELECT constraint_name INTO auto_pk_name
FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_type = 'PRIMARY KEY'
AND constraint_name IN ({{.AutoGenNames}});
IF auto_pk_name IS NOT NULL THEN
EXECUTE 'ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT ' || quote_ident(auto_pk_name);
END IF;
-- Add named primary key if it doesn't exist
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} PRIMARY KEY ({{.Columns}});
END IF;
END;
$$;

View File

@@ -0,0 +1,6 @@
CREATE SEQUENCE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}
INCREMENT {{.Increment}}
MINVALUE {{.MinValue}}
MAXVALUE {{.MaxValue}}
START {{.StartValue}}
CACHE {{.CacheSize}};

View File

@@ -1,7 +1,7 @@
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
CREATE TABLE IF NOT EXISTS {{quote_ident .SchemaName}}.{{quote_ident .TableName}} (
{{- range $i, $col := .Columns}}
{{- if $i}},{{end}}
{{$col.Name}} {{$col.Type}}
{{quote_ident $col.Name}} {{$col.Type}}
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
{{- if $col.NotNull}} NOT NULL{{end}}
{{- end}}

View File

@@ -0,0 +1,12 @@
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema = '{{.SchemaName}}'
AND table_name = '{{.TableName}}'
AND constraint_name = '{{.ConstraintName}}'
) THEN
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} ADD CONSTRAINT {{quote_ident .ConstraintName}} UNIQUE ({{.Columns}});
END IF;
END;
$$;

View File

@@ -1 +1 @@
ALTER TABLE {{.SchemaName}}.{{.TableName}} DROP CONSTRAINT IF EXISTS {{.ConstraintName}};
ALTER TABLE {{quote_ident .SchemaName}}.{{quote_ident .TableName}} DROP CONSTRAINT IF EXISTS {{quote_ident .ConstraintName}};

View File

@@ -1 +1 @@
DROP INDEX IF EXISTS {{.SchemaName}}.{{.IndexName}} CASCADE;
DROP INDEX IF EXISTS {{quote_ident .SchemaName}}.{{quote_ident .IndexName}} CASCADE;

View File

@@ -0,0 +1,19 @@
DO $$
DECLARE
m_cnt bigint;
BEGIN
IF EXISTS (
SELECT 1 FROM pg_class c
INNER JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relname = '{{.SequenceName}}'
AND n.nspname = '{{.SchemaName}}'
AND c.relkind = 'S'
) THEN
SELECT COALESCE(MAX({{quote_ident .ColumnName}}), 0) + 1
FROM {{quote_ident .SchemaName}}.{{quote_ident .TableName}}
INTO m_cnt;
PERFORM setval('{{quote_ident .SchemaName}}.{{quote_ident .SequenceName}}'::regclass, m_cnt);
END IF;
END;
$$;

View File

@@ -22,6 +22,7 @@ type Writer struct {
options *writers.WriterOptions
writer io.Writer
executionReport *ExecutionReport
executor *TemplateExecutor
}
// ExecutionReport tracks the execution status of SQL statements
@@ -57,8 +58,10 @@ type ExecutionError struct {
// NewWriter creates a new PostgreSQL SQL writer
func NewWriter(options *writers.WriterOptions) *Writer {
executor, _ := NewTemplateExecutor()
return &Writer{
options: options,
options: options,
executor: executor,
}
}
@@ -168,6 +171,13 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
statements = append(statements, stmts...)
}
// Phase 3.5: Add missing columns (for existing tables)
addColStmts, err := w.GenerateAddColumnStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate add column statements: %w", err)
}
statements = append(statements, addColStmts...)
// Phase 4: Primary keys
for _, table := range schema.Tables {
// First check for explicit PrimaryKeyConstraint
@@ -179,27 +189,50 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
}
}
var pkColumns []string
var pkName string
if pkConstraint != nil {
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", "))
statements = append(statements, stmt)
pkColumns = pkConstraint.Columns
pkName = pkConstraint.Name
} else {
// No explicit constraint, check for columns with IsPrimaryKey = true
pkColumns := []string{}
pkCols := []string{}
for _, col := range table.Columns {
if col.IsPrimaryKey {
pkColumns = append(pkColumns, col.SQLName())
pkCols = append(pkCols, col.SQLName())
}
}
if len(pkColumns) > 0 {
if len(pkCols) > 0 {
// Sort for consistent output
sort.Strings(pkColumns)
pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", "))
statements = append(statements, stmt)
sort.Strings(pkCols)
pkColumns = pkCols
pkName = fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName())
}
}
if len(pkColumns) > 0 {
// Auto-generated primary key names to check for and drop
autoGenPKNames := []string{
fmt.Sprintf("%s_pkey", table.Name),
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
}
// Use template to generate primary key statement
data := CreatePrimaryKeyWithAutoGenCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(pkColumns, ", "),
}
stmt, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
if err != nil {
return nil, fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
// Phase 5: Indexes
@@ -243,7 +276,53 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
}
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
uniqueStr, quoteIdentifier(index.Name), schema.SQLName(), table.SQLName(), indexType, strings.Join(columnExprs, ", "), whereClause)
statements = append(statements, stmt)
}
}
// Phase 5.5: Unique constraints
for _, table := range schema.Tables {
for _, constraint := range table.Constraints {
if constraint.Type != models.UniqueConstraint {
continue
}
// Use template to generate unique constraint statement
data := CreateUniqueConstraintData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
Columns: strings.Join(constraint.Columns, ", "),
}
stmt, err := w.executor.ExecuteCreateUniqueConstraint(data)
if err != nil {
return nil, fmt.Errorf("failed to generate unique constraint for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
// Phase 5.7: Check constraints
for _, table := range schema.Tables {
for _, constraint := range table.Constraints {
if constraint.Type != models.CheckConstraint {
continue
}
// Use template to generate check constraint statement
data := CreateCheckConstraintData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
Expression: constraint.Expression,
}
stmt, err := w.executor.ExecuteCreateCheckConstraint(data)
if err != nil {
return nil, fmt.Errorf("failed to generate check constraint for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
@@ -270,12 +349,24 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
onUpdate = "NO ACTION"
}
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
schema.SQLName(), table.SQLName(), constraint.Name,
strings.Join(constraint.Columns, ", "),
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
strings.Join(constraint.ReferencedColumns, ", "),
onDelete, onUpdate)
// Use template to generate foreign key statement
data := CreateForeignKeyWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
SourceColumns: strings.Join(constraint.Columns, ", "),
TargetSchema: refSchema,
TargetTable: constraint.ReferencedTable,
TargetColumns: strings.Join(constraint.ReferencedColumns, ", "),
OnDelete: onDelete,
OnUpdate: onUpdate,
Deferrable: false,
}
stmt, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
if err != nil {
return nil, fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
}
statements = append(statements, stmt)
}
}
@@ -300,6 +391,67 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro
return statements, nil
}
// GenerateAddColumnStatements generates ALTER TABLE ADD COLUMN statements for existing tables.
// This is useful for schema evolution when new columns are added to existing tables.
//
// The returned slice starts with a comment header for the schema, followed by one
// guarded ADD COLUMN statement per column (rendered via the AddColumnWithCheck
// template, which only adds the column if it does not already exist).
func (w *Writer) GenerateAddColumnStatements(schema *models.Schema) ([]string, error) {
	statements := []string{}
	statements = append(statements, fmt.Sprintf("-- Add missing columns for schema: %s", schema.Name))
	for _, table := range schema.Tables {
		// Use the shared getSortedColumns helper (sequence, then name) so the
		// ordering stays consistent with writeAddColumns instead of duplicating
		// the sort logic inline.
		for _, col := range getSortedColumns(table.Columns) {
			colDef := w.generateColumnDefinition(col)
			// Use template to generate add column statement
			data := AddColumnWithCheckData{
				SchemaName:       schema.Name,
				TableName:        table.Name,
				ColumnName:       col.Name,
				ColumnDefinition: colDef,
			}
			stmt, err := w.executor.ExecuteAddColumnWithCheck(data)
			if err != nil {
				return nil, fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, table.Name, col.Name, err)
			}
			statements = append(statements, stmt)
		}
	}
	return statements, nil
}
// GenerateAddColumnsForDatabase generates ALTER TABLE ADD COLUMN statements for the entire database.
// It emits a fixed comment banner followed by the per-schema statements produced
// by GenerateAddColumnStatements for every schema in the database.
func (w *Writer) GenerateAddColumnsForDatabase(db *models.Database) ([]string, error) {
	out := []string{
		"-- Add missing columns to existing tables",
		fmt.Sprintf("-- Database: %s", db.Name),
		"-- Generated by RelSpec",
	}
	for _, sch := range db.Schemas {
		stmts, err := w.GenerateAddColumnStatements(sch)
		if err != nil {
			return nil, fmt.Errorf("failed to generate add column statements for schema %s: %w", sch.Name, err)
		}
		out = append(out, stmts...)
	}
	return out, nil
}
// generateCreateTableStatement generates CREATE TABLE statement
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
statements := []string{}
@@ -322,7 +474,7 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
columnDefs = append(columnDefs, " "+def)
}
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (\n%s\n)",
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
statements = append(statements, stmt)
@@ -336,14 +488,25 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string {
// Type with length/precision - convert to valid PostgreSQL type
baseType := pgsql.ConvertSQLType(col.Type)
typeStr := baseType
// Only add size specifiers for types that support them
if col.Length > 0 && col.Precision == 0 {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if col.Precision > 0 {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
if supportsLength(baseType) {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
} else if isTextTypeWithoutLength(baseType) {
// Convert text with length to varchar
typeStr = fmt.Sprintf("varchar(%d)", col.Length)
}
// For types that don't support length (integer, bigint, etc.), ignore the length
} else if col.Precision > 0 {
if supportsPrecision(baseType) {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
}
}
// For types that don't support precision, ignore it
}
parts = append(parts, typeStr)
@@ -396,6 +559,11 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
// Phase 3.5: Add missing columns (priority 120)
if err := w.writeAddColumns(schema); err != nil {
return err
}
// Phase 4: Create primary keys (priority 160)
if err := w.writePrimaryKeys(schema); err != nil {
return err
@@ -406,6 +574,16 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
return err
}
// Phase 5.5: Create unique constraints (priority 185)
if err := w.writeUniqueConstraints(schema); err != nil {
return err
}
// Phase 5.7: Create check constraints (priority 190)
if err := w.writeCheckConstraints(schema); err != nil {
return err
}
// Phase 6: Create foreign key constraints (priority 195)
if err := w.writeForeignKeys(schema); err != nil {
return err
@@ -437,6 +615,44 @@ func (w *Writer) WriteTable(table *models.Table) error {
return w.WriteSchema(schema)
}
// WriteAddColumnStatements writes ALTER TABLE ADD COLUMN statements for a database.
// This is used for schema evolution/migration when new columns are added.
//
// Output destination precedence: an already-injected writer (used by tests),
// then options.OutputPath, then stdout.
func (w *Writer) WriteAddColumnStatements(db *models.Database) error {
	// Resolve the output destination only if one has not been injected.
	if w.writer == nil {
		if w.options.OutputPath != "" {
			file, err := os.Create(w.options.OutputPath)
			if err != nil {
				return fmt.Errorf("failed to create output file: %w", err)
			}
			defer file.Close()
			w.writer = file
		} else {
			w.writer = os.Stdout
		}
	}
	// Generate the statements, then write each one terminated by ";" and a blank line.
	statements, err := w.GenerateAddColumnsForDatabase(db)
	if err != nil {
		return err
	}
	for _, stmt := range statements {
		fmt.Fprintf(w.writer, "%s;\n\n", stmt)
	}
	return nil
}
// writeCreateSchema generates CREATE SCHEMA statement
func (w *Writer) writeCreateSchema(schema *models.Schema) error {
if schema.Name == "public" {
@@ -465,13 +681,23 @@ func (w *Writer) writeSequences(schema *models.Schema) error {
}
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
fmt.Fprintf(w.writer, "CREATE SEQUENCE IF NOT EXISTS %s.%s\n",
schema.SQLName(), seqName)
fmt.Fprintf(w.writer, " INCREMENT 1\n")
fmt.Fprintf(w.writer, " MINVALUE 1\n")
fmt.Fprintf(w.writer, " MAXVALUE 9223372036854775807\n")
fmt.Fprintf(w.writer, " START 1\n")
fmt.Fprintf(w.writer, " CACHE 1;\n\n")
data := CreateSequenceData{
SchemaName: schema.Name,
SequenceName: seqName,
Increment: 1,
MinValue: 1,
MaxValue: 9223372036854775807,
StartValue: 1,
CacheSize: 1,
}
sql, err := w.executor.ExecuteCreateSequence(data)
if err != nil {
return fmt.Errorf("failed to generate create sequence for %s.%s: %w", schema.Name, seqName, err)
}
fmt.Fprint(w.writer, sql)
fmt.Fprint(w.writer, "\n")
}
return nil
@@ -490,15 +716,8 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
columnDefs := make([]string, 0, len(columns))
for _, col := range columns {
colDef := fmt.Sprintf(" %s %s", col.SQLName(), pgsql.ConvertSQLType(col.Type))
// Add default value if present
if col.Default != nil && col.Default != "" {
// Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't
defaultVal := fmt.Sprintf("%v", col.Default)
colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal))
}
// Use generateColumnDefinition to properly handle type, length, precision, and defaults
colDef := " " + w.generateColumnDefinition(col)
columnDefs = append(columnDefs, colDef)
}
@@ -509,6 +728,36 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
return nil
}
// writeAddColumns generates ALTER TABLE ADD COLUMN statements for missing columns.
// Columns are emitted in the deterministic order produced by getSortedColumns.
func (w *Writer) writeAddColumns(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Add missing columns for schema: %s\n", schema.Name)
	for _, tbl := range schema.Tables {
		for _, col := range getSortedColumns(tbl.Columns) {
			// Render a guarded ADD COLUMN statement via the template executor.
			data := AddColumnWithCheckData{
				SchemaName:       schema.Name,
				TableName:        tbl.Name,
				ColumnName:       col.Name,
				ColumnDefinition: w.generateColumnDefinition(col),
			}
			sql, err := w.executor.ExecuteAddColumnWithCheck(data)
			if err != nil {
				return fmt.Errorf("failed to generate add column for %s.%s.%s: %w", schema.Name, tbl.Name, col.Name, err)
			}
			fmt.Fprintf(w.writer, "%s\n", sql)
		}
	}
	return nil
}
// writePrimaryKeys generates ALTER TABLE statements for primary keys
func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Primary keys for schema: %s\n", schema.Name)
@@ -550,18 +799,26 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error {
continue
}
fmt.Fprintf(w.writer, "DO $$\nBEGIN\n")
fmt.Fprintf(w.writer, " IF NOT EXISTS (\n")
fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n")
fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name)
fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name)
fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", pkName)
fmt.Fprintf(w.writer, " ) THEN\n")
fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s PRIMARY KEY (%s);\n",
pkName, strings.Join(columnNames, ", "))
fmt.Fprintf(w.writer, " END IF;\n")
fmt.Fprintf(w.writer, "END;\n$$;\n\n")
// Auto-generated primary key names to check for and drop
autoGenPKNames := []string{
fmt.Sprintf("%s_pkey", table.Name),
fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name),
}
data := CreatePrimaryKeyWithAutoGenCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: pkName,
AutoGenNames: formatStringList(autoGenPKNames),
Columns: strings.Join(columnNames, ", "),
}
sql, err := w.executor.ExecuteCreatePrimaryKeyWithAutoGenCheck(data)
if err != nil {
return fmt.Errorf("failed to generate primary key for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
fmt.Fprint(w.writer, "\n")
}
return nil
@@ -644,6 +901,91 @@ func (w *Writer) writeIndexes(schema *models.Schema) error {
return nil
}
// writeUniqueConstraints generates ALTER TABLE statements for unique constraints.
// Constraints are emitted in sorted-name order for deterministic output; columns
// that cannot be resolved against the table are skipped, and a constraint with
// no resolvable columns is dropped entirely.
func (w *Writer) writeUniqueConstraints(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Unique constraints for schema: %s\n", schema.Name)
	for _, tbl := range schema.Tables {
		// Gather the unique-constraint names and sort them.
		names := make([]string, 0, len(tbl.Constraints))
		for name, c := range tbl.Constraints {
			if c.Type == models.UniqueConstraint {
				names = append(names, name)
			}
		}
		sort.Strings(names)
		for _, name := range names {
			c := tbl.Constraints[name]
			// Resolve constraint columns to their SQL names via the table.
			cols := make([]string, 0, len(c.Columns))
			for _, cn := range c.Columns {
				if col, ok := tbl.Columns[cn]; ok {
					cols = append(cols, col.SQLName())
				}
			}
			if len(cols) == 0 {
				continue
			}
			sql, err := w.executor.ExecuteCreateUniqueConstraint(CreateUniqueConstraintData{
				SchemaName:     schema.Name,
				TableName:      tbl.Name,
				ConstraintName: c.Name,
				Columns:        strings.Join(cols, ", "),
			})
			if err != nil {
				return fmt.Errorf("failed to generate unique constraint: %w", err)
			}
			fmt.Fprintf(w.writer, "%s\n\n", sql)
		}
	}
	return nil
}
// writeCheckConstraints generates ALTER TABLE statements for check constraints.
// Constraints are emitted in sorted-name order for deterministic output;
// constraints with an empty expression are skipped.
func (w *Writer) writeCheckConstraints(schema *models.Schema) error {
	fmt.Fprintf(w.writer, "-- Check constraints for schema: %s\n", schema.Name)
	for _, tbl := range schema.Tables {
		// Gather the check-constraint names and sort them.
		names := make([]string, 0, len(tbl.Constraints))
		for name, c := range tbl.Constraints {
			if c.Type == models.CheckConstraint {
				names = append(names, name)
			}
		}
		sort.Strings(names)
		for _, name := range names {
			c := tbl.Constraints[name]
			if c.Expression == "" {
				// Nothing to enforce without an expression.
				continue
			}
			sql, err := w.executor.ExecuteCreateCheckConstraint(CreateCheckConstraintData{
				SchemaName:     schema.Name,
				TableName:      tbl.Name,
				ConstraintName: c.Name,
				Expression:     c.Expression,
			})
			if err != nil {
				return fmt.Errorf("failed to generate check constraint: %w", err)
			}
			fmt.Fprintf(w.writer, "%s\n\n", sql)
		}
	}
	return nil
}
// writeForeignKeys generates ALTER TABLE statements for foreign keys
func (w *Writer) writeForeignKeys(schema *models.Schema) error {
fmt.Fprintf(w.writer, "-- Foreign keys for schema: %s\n", schema.Name)
@@ -711,13 +1053,6 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
onUpdate = strings.ToUpper(fkConstraint.OnUpdate)
}
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName)
fmt.Fprintf(w.writer, "\n")
fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName)
fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", "))
// Use constraint's referenced schema/table or relationship's ToSchema/ToTable
refSchema := fkConstraint.ReferencedSchema
if refSchema == "" {
@@ -728,11 +1063,103 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error {
refTable = rel.ToTable
}
fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n",
refSchema, refTable, strings.Join(targetColumns, ", "))
fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete)
fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate)
fmt.Fprintf(w.writer, " DEFERRABLE;\n\n")
// Use template executor to generate foreign key with existence check
data := CreateForeignKeyWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: fkName,
SourceColumns: strings.Join(sourceColumns, ", "),
TargetSchema: refSchema,
TargetTable: refTable,
TargetColumns: strings.Join(targetColumns, ", "),
OnDelete: onDelete,
OnUpdate: onUpdate,
Deferrable: true,
}
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
if err != nil {
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
}
// Also process any foreign key constraints that don't have a relationship
processedConstraints := make(map[string]bool)
for _, rel := range table.Relationships {
fkName := rel.ForeignKey
if fkName == "" {
fkName = rel.Name
}
if fkName != "" {
processedConstraints[fkName] = true
}
}
// Find unprocessed foreign key constraints
constraintNames := make([]string, 0)
for name, constraint := range table.Constraints {
if constraint.Type == models.ForeignKeyConstraint && !processedConstraints[name] {
constraintNames = append(constraintNames, name)
}
}
sort.Strings(constraintNames)
for _, name := range constraintNames {
constraint := table.Constraints[name]
// Build column lists
sourceColumns := make([]string, 0, len(constraint.Columns))
for _, colName := range constraint.Columns {
if col, ok := table.Columns[colName]; ok {
sourceColumns = append(sourceColumns, col.SQLName())
} else {
sourceColumns = append(sourceColumns, colName)
}
}
targetColumns := make([]string, 0, len(constraint.ReferencedColumns))
for _, colName := range constraint.ReferencedColumns {
targetColumns = append(targetColumns, strings.ToLower(colName))
}
if len(sourceColumns) == 0 || len(targetColumns) == 0 {
continue
}
onDelete := "NO ACTION"
if constraint.OnDelete != "" {
onDelete = strings.ToUpper(constraint.OnDelete)
}
onUpdate := "NO ACTION"
if constraint.OnUpdate != "" {
onUpdate = strings.ToUpper(constraint.OnUpdate)
}
refSchema := constraint.ReferencedSchema
if refSchema == "" {
refSchema = schema.Name
}
refTable := constraint.ReferencedTable
// Use template executor to generate foreign key with existence check
data := CreateForeignKeyWithCheckData{
SchemaName: schema.Name,
TableName: table.Name,
ConstraintName: constraint.Name,
SourceColumns: strings.Join(sourceColumns, ", "),
TargetSchema: refSchema,
TargetTable: refTable,
TargetColumns: strings.Join(targetColumns, ", "),
OnDelete: onDelete,
OnUpdate: onUpdate,
Deferrable: false,
}
sql, err := w.executor.ExecuteCreateForeignKeyWithCheck(data)
if err != nil {
return fmt.Errorf("failed to generate foreign key for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
}
}
@@ -751,26 +1178,19 @@ func (w *Writer) writeSetSequenceValues(schema *models.Schema) error {
seqName := fmt.Sprintf("identity_%s_%s", table.SQLName(), pk.SQLName())
fmt.Fprintf(w.writer, "DO $$\n")
fmt.Fprintf(w.writer, "DECLARE\n")
fmt.Fprintf(w.writer, " m_cnt bigint;\n")
fmt.Fprintf(w.writer, "BEGIN\n")
fmt.Fprintf(w.writer, " IF EXISTS (\n")
fmt.Fprintf(w.writer, " SELECT 1 FROM pg_class c\n")
fmt.Fprintf(w.writer, " INNER JOIN pg_namespace n ON n.oid = c.relnamespace\n")
fmt.Fprintf(w.writer, " WHERE c.relname = '%s'\n", seqName)
fmt.Fprintf(w.writer, " AND n.nspname = '%s'\n", schema.Name)
fmt.Fprintf(w.writer, " AND c.relkind = 'S'\n")
fmt.Fprintf(w.writer, " ) THEN\n")
fmt.Fprintf(w.writer, " SELECT COALESCE(MAX(%s), 0) + 1\n", pk.SQLName())
fmt.Fprintf(w.writer, " FROM %s.%s\n", schema.SQLName(), table.SQLName())
fmt.Fprintf(w.writer, " INTO m_cnt;\n")
fmt.Fprintf(w.writer, " \n")
fmt.Fprintf(w.writer, " PERFORM setval('%s.%s'::regclass, m_cnt);\n",
schema.SQLName(), seqName)
fmt.Fprintf(w.writer, " END IF;\n")
fmt.Fprintf(w.writer, "END;\n")
fmt.Fprintf(w.writer, "$$;\n\n")
// Use template executor to generate set sequence value statement
data := SetSequenceValueData{
SchemaName: schema.Name,
TableName: table.Name,
SequenceName: seqName,
ColumnName: pk.Name,
}
sql, err := w.executor.ExecuteSetSequenceValue(data)
if err != nil {
return fmt.Errorf("failed to generate set sequence value for %s.%s: %w", schema.Name, table.Name, err)
}
fmt.Fprint(w.writer, sql)
fmt.Fprint(w.writer, "\n")
}
return nil
@@ -844,6 +1264,44 @@ func isTextType(colType string) bool {
return false
}
// supportsLength checks if a PostgreSQL type supports length specification.
// It accepts both bare type names ("varchar") and names that already carry a
// length suffix ("varchar(255)"); the comparison is case-insensitive.
func supportsLength(colType string) bool {
	// Compare on the base type name: everything before a "(" suffix, lowercased.
	base, _, _ := strings.Cut(strings.ToLower(colType), "(")
	switch base {
	case "varchar", "character varying", "char", "character",
		"bit", "bit varying", "varbit":
		return true
	}
	return false
}
// supportsPrecision checks if a PostgreSQL type supports precision/scale specification.
// It accepts both bare type names ("numeric") and names that already carry a
// precision suffix ("numeric(10,2)"); the comparison is case-insensitive.
func supportsPrecision(colType string) bool {
	// Compare on the base type name: everything before a "(" suffix, lowercased.
	base, _, _ := strings.Cut(strings.ToLower(colType), "(")
	switch base {
	case "numeric", "decimal", "time", "timestamp", "timestamptz",
		"timestamp with time zone", "timestamp without time zone",
		"time with time zone", "time without time zone", "interval":
		return true
	}
	return false
}
// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified).
// The comparison is case-insensitive and matches the exact type name only.
func isTextTypeWithoutLength(colType string) bool {
	return strings.ToLower(colType) == "text"
}
// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list
func formatStringList(items []string) string {
quoted := make([]string, len(items))
for i, item := range items {
quoted[i] = fmt.Sprintf("'%s'", escapeQuote(item))
}
return strings.Join(quoted, ", ")
}
// extractOperatorClass extracts operator class from index comment/note
// Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc.
func extractOperatorClass(comment string) string {
@@ -1086,3 +1544,9 @@ func truncateStatement(stmt string) string {
// getCurrentTimestamp returns the current local time formatted as
// "YYYY-MM-DD HH:MM:SS".
func getCurrentTimestamp() string {
	// time.DateTime is the stdlib constant for "2006-01-02 15:04:05".
	return time.Now().Format(time.DateTime)
}
// quoteIdentifier wraps an identifier in double quotes if necessary.
// This is needed for identifiers that start with numbers or contain special
// characters or reserved keywords. It delegates to quoteIdent, which holds the
// actual quoting rules; this alias exists as the exported-style entry point
// used by the statement generators.
func quoteIdentifier(s string) string {
	return quoteIdent(s)
}

View File

@@ -164,6 +164,296 @@ func TestWriteForeignKeys(t *testing.T) {
}
}
// TestWriteUniqueConstraints verifies that unique constraints defined on a
// table are rendered as "ADD CONSTRAINT ... UNIQUE (...)" statements under the
// unique-constraints section header.
func TestWriteUniqueConstraints(t *testing.T) {
	// Build a users table with two columns, each covered by a unique constraint.
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("users", "public")

	email := models.InitColumn("email", "users", "public")
	email.Type = "varchar(255)"
	email.NotNull = true
	table.Columns["email"] = email

	guid := models.InitColumn("guid", "users", "public")
	guid.Type = "uuid"
	guid.NotNull = true
	table.Columns["guid"] = guid

	table.Constraints["uq_email"] = &models.Constraint{
		Name:    "uq_email",
		Type:    models.UniqueConstraint,
		Schema:  "public",
		Table:   "users",
		Columns: []string{"email"},
	}
	table.Constraints["uq_guid"] = &models.Constraint{
		Name:    "uq_guid",
		Type:    models.UniqueConstraint,
		Schema:  "public",
		Table:   "users",
		Columns: []string{"guid"},
	}

	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)

	// Capture the generated SQL in a buffer instead of a file.
	var buf bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &buf

	if err := w.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	output := buf.String()
	t.Logf("Generated SQL:\n%s", output)

	// The section header and both constraints must appear in the output.
	if !strings.Contains(output, "-- Unique constraints for schema: public") {
		t.Errorf("Output missing unique constraints header")
	}
	if !strings.Contains(output, "ADD CONSTRAINT uq_email UNIQUE (email)") {
		t.Errorf("Output missing uq_email unique constraint\nFull output:\n%s", output)
	}
	if !strings.Contains(output, "ADD CONSTRAINT uq_guid UNIQUE (guid)") {
		t.Errorf("Output missing uq_guid unique constraint\nFull output:\n%s", output)
	}
}
// TestWriteCheckConstraints verifies that check constraints defined on a table
// are rendered as "ADD CONSTRAINT ... CHECK (...)" statements under the
// check-constraints section header.
func TestWriteCheckConstraints(t *testing.T) {
	// Build a products table with three columns, each guarded by a check constraint.
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")
	table := models.InitTable("products", "public")

	price := models.InitColumn("price", "products", "public")
	price.Type = "numeric(10,2)"
	table.Columns["price"] = price

	status := models.InitColumn("status", "products", "public")
	status.Type = "varchar(20)"
	table.Columns["status"] = status

	quantity := models.InitColumn("quantity", "products", "public")
	quantity.Type = "integer"
	table.Columns["quantity"] = quantity

	table.Constraints["ck_price_positive"] = &models.Constraint{
		Name:       "ck_price_positive",
		Type:       models.CheckConstraint,
		Schema:     "public",
		Table:      "products",
		Expression: "price >= 0",
	}
	table.Constraints["ck_status_valid"] = &models.Constraint{
		Name:       "ck_status_valid",
		Type:       models.CheckConstraint,
		Schema:     "public",
		Table:      "products",
		Expression: "status IN ('active', 'inactive', 'discontinued')",
	}
	table.Constraints["ck_quantity_nonnegative"] = &models.Constraint{
		Name:       "ck_quantity_nonnegative",
		Type:       models.CheckConstraint,
		Schema:     "public",
		Table:      "products",
		Expression: "quantity >= 0",
	}

	schema.Tables = append(schema.Tables, table)
	db.Schemas = append(db.Schemas, schema)

	// Capture the generated SQL in a buffer instead of a file.
	var buf bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &buf

	if err := w.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	output := buf.String()
	t.Logf("Generated SQL:\n%s", output)

	// The section header and all three constraints must appear in the output.
	if !strings.Contains(output, "-- Check constraints for schema: public") {
		t.Errorf("Output missing check constraints header")
	}
	if !strings.Contains(output, "ADD CONSTRAINT ck_price_positive CHECK (price >= 0)") {
		t.Errorf("Output missing ck_price_positive check constraint\nFull output:\n%s", output)
	}
	if !strings.Contains(output, "ADD CONSTRAINT ck_status_valid CHECK (status IN ('active', 'inactive', 'discontinued'))") {
		t.Errorf("Output missing ck_status_valid check constraint\nFull output:\n%s", output)
	}
	if !strings.Contains(output, "ADD CONSTRAINT ck_quantity_nonnegative CHECK (quantity >= 0)") {
		t.Errorf("Output missing ck_quantity_nonnegative check constraint\nFull output:\n%s", output)
	}
}
// TestWriteAllConstraintTypes is an end-to-end check that a single table
// carrying a primary key, a unique constraint, two check constraints, and a
// foreign key produces all four constraint kinds (and their section headers)
// in the generated SQL.
func TestWriteAllConstraintTypes(t *testing.T) {
	// Create a comprehensive test with all constraint types
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")

	// Create orders table
	ordersTable := models.InitTable("orders", "public")

	// Add columns
	idCol := models.InitColumn("id", "orders", "public")
	idCol.Type = "integer"
	idCol.IsPrimaryKey = true
	ordersTable.Columns["id"] = idCol

	userIDCol := models.InitColumn("user_id", "orders", "public")
	userIDCol.Type = "integer"
	userIDCol.NotNull = true
	ordersTable.Columns["user_id"] = userIDCol

	orderNumberCol := models.InitColumn("order_number", "orders", "public")
	orderNumberCol.Type = "varchar(50)"
	orderNumberCol.NotNull = true
	ordersTable.Columns["order_number"] = orderNumberCol

	totalCol := models.InitColumn("total", "orders", "public")
	totalCol.Type = "numeric(10,2)"
	ordersTable.Columns["total"] = totalCol

	statusCol := models.InitColumn("status", "orders", "public")
	statusCol.Type = "varchar(20)"
	ordersTable.Columns["status"] = statusCol

	// Add primary key constraint
	pkConstraint := &models.Constraint{
		Name:    "pk_orders",
		Type:    models.PrimaryKeyConstraint,
		Schema:  "public",
		Table:   "orders",
		Columns: []string{"id"},
	}
	ordersTable.Constraints["pk_orders"] = pkConstraint

	// Add unique constraint
	uniqueConstraint := &models.Constraint{
		Name:    "uq_order_number",
		Type:    models.UniqueConstraint,
		Schema:  "public",
		Table:   "orders",
		Columns: []string{"order_number"},
	}
	ordersTable.Constraints["uq_order_number"] = uniqueConstraint

	// Add check constraint
	checkConstraint := &models.Constraint{
		Name:       "ck_total_positive",
		Type:       models.CheckConstraint,
		Schema:     "public",
		Table:      "orders",
		Expression: "total > 0",
	}
	ordersTable.Constraints["ck_total_positive"] = checkConstraint

	// Second check constraint, restricting status to an enumerated set.
	statusCheckConstraint := &models.Constraint{
		Name:       "ck_status_valid",
		Type:       models.CheckConstraint,
		Schema:     "public",
		Table:      "orders",
		Expression: "status IN ('pending', 'completed', 'cancelled')",
	}
	ordersTable.Constraints["ck_status_valid"] = statusCheckConstraint

	// Add foreign key constraint (referencing a users table)
	fkConstraint := &models.Constraint{
		Name:              "fk_orders_user",
		Type:              models.ForeignKeyConstraint,
		Schema:            "public",
		Table:             "orders",
		Columns:           []string{"user_id"},
		ReferencedSchema:  "public",
		ReferencedTable:   "users",
		ReferencedColumns: []string{"id"},
		OnDelete:          "CASCADE",
		OnUpdate:          "CASCADE",
	}
	ordersTable.Constraints["fk_orders_user"] = fkConstraint

	schema.Tables = append(schema.Tables, ordersTable)
	db.Schemas = append(db.Schemas, schema)

	// Create writer with output to buffer
	var buf bytes.Buffer
	options := &writers.WriterOptions{}
	writer := NewWriter(options)
	writer.writer = &buf

	// Write the database
	err := writer.WriteDatabase(db)
	if err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	output := buf.String()

	// Print output for debugging
	t.Logf("Generated SQL:\n%s", output)

	// Verify all constraint types are present
	// (map iteration order is random; each expectation is independent).
	expectedConstraints := map[string]string{
		"Primary Key":    "PRIMARY KEY",
		"Unique":         "ADD CONSTRAINT uq_order_number UNIQUE (order_number)",
		"Check (total)":  "ADD CONSTRAINT ck_total_positive CHECK (total > 0)",
		"Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))",
		"Foreign Key":    "FOREIGN KEY",
	}
	for name, expected := range expectedConstraints {
		if !strings.Contains(output, expected) {
			t.Errorf("Output missing %s constraint: %s\nFull output:\n%s", name, expected, output)
		}
	}

	// Verify section headers
	sections := []string{
		"-- Primary keys for schema: public",
		"-- Unique constraints for schema: public",
		"-- Check constraints for schema: public",
		"-- Foreign keys for schema: public",
	}
	for _, section := range sections {
		if !strings.Contains(output, section) {
			t.Errorf("Output missing section header: %s", section)
		}
	}
}
func TestWriteTable(t *testing.T) {
// Create a single table
table := models.InitTable("products", "public")
@@ -305,3 +595,263 @@ func TestTypeConversion(t *testing.T) {
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
}
}
// TestPrimaryKeyExistenceCheck verifies that generated DDL applies the
// pk_<schema>_<table> naming convention, drops auto-generated primary
// keys (e.g. products_pkey), and guards ADD CONSTRAINT behind an
// existence check on the named constraint.
func TestPrimaryKeyExistenceCheck(t *testing.T) {
	database := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")
	products := models.InitTable("products", "public")

	// Primary-key column.
	pkCol := models.InitColumn("id", "products", "public")
	pkCol.Type = "integer"
	pkCol.IsPrimaryKey = true
	products.Columns["id"] = pkCol

	// Plain data column.
	textCol := models.InitColumn("name", "products", "public")
	textCol.Type = "text"
	products.Columns["name"] = textCol

	publicSchema.Tables = append(publicSchema.Tables, products)
	database.Schemas = append(database.Schemas, publicSchema)

	// Capture generated SQL in a buffer instead of executing it.
	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out
	if err := w.WriteDatabase(database); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}
	sql := out.String()
	t.Logf("Generated SQL:\n%s", sql)

	// The writer's own naming convention must appear in the output.
	if !strings.Contains(sql, "pk_public_products") {
		t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", sql)
	}
	// Auto-generated primary keys must be dropped before ours is added.
	if !strings.Contains(sql, "products_pkey") || !strings.Contains(sql, "DROP CONSTRAINT") {
		t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", sql)
	}
	// The named constraint must be checked for before being created.
	if !strings.Contains(sql, "constraint_name = 'pk_public_products'") {
		t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", sql)
	}
}
// TestColumnSizeSpecifiers verifies that size specifiers are emitted only
// where PostgreSQL accepts them: integer/bigint/smallint must drop any
// length, text with a length becomes varchar(n), varchar keeps its length,
// and decimal keeps precision and scale.
func TestColumnSizeSpecifiers(t *testing.T) {
	database := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")
	sizesTable := models.InitTable("test_sizes", "public")

	// integer with a size specifier - size must be dropped.
	colInt := models.InitColumn("int_col", "test_sizes", "public")
	colInt.Type = "integer"
	colInt.Length = 32
	sizesTable.Columns["int_col"] = colInt

	// bigint with a size specifier - size must be dropped.
	colBig := models.InitColumn("bigint_col", "test_sizes", "public")
	colBig.Type = "bigint"
	colBig.Length = 64
	sizesTable.Columns["bigint_col"] = colBig

	// smallint with a size specifier - size must be dropped.
	colSmall := models.InitColumn("smallint_col", "test_sizes", "public")
	colSmall.Type = "smallint"
	colSmall.Length = 16
	sizesTable.Columns["smallint_col"] = colSmall

	// text with a length - expected to convert to varchar(n).
	colText := models.InitColumn("text_col", "test_sizes", "public")
	colText.Type = "text"
	colText.Length = 100
	sizesTable.Columns["text_col"] = colText

	// varchar with a length - length must be preserved.
	colVarchar := models.InitColumn("varchar_col", "test_sizes", "public")
	colVarchar.Type = "varchar"
	colVarchar.Length = 50
	sizesTable.Columns["varchar_col"] = colVarchar

	// decimal with precision and scale - both must be preserved.
	colDec := models.InitColumn("decimal_col", "test_sizes", "public")
	colDec.Type = "decimal"
	colDec.Precision = 19
	colDec.Scale = 4
	sizesTable.Columns["decimal_col"] = colDec

	publicSchema.Tables = append(publicSchema.Tables, sizesTable)
	database.Schemas = append(database.Schemas, publicSchema)

	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out
	if err := w.WriteDatabase(database); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}
	sql := out.String()
	t.Logf("Generated SQL:\n%s", sql)

	// These forms are rejected by PostgreSQL and must never appear.
	for _, bad := range []string{"integer(32)", "bigint(64)", "smallint(16)", "text(100)"} {
		if strings.Contains(sql, bad) {
			t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", bad, sql)
		}
	}
	// These are the forms the writer is expected to emit.
	for _, good := range []string{
		"integer",       // without size
		"bigint",        // without size
		"smallint",      // without size
		"varchar(100)",  // text converted to varchar with length
		"varchar(50)",   // varchar with length
		"decimal(19,4)", // decimal with precision and scale
	} {
		if !strings.Contains(sql, good) {
			t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", good, sql)
		}
	}
}
// TestGenerateAddColumnStatements verifies that GenerateAddColumnsForDatabase
// produces one conditional ALTER TABLE ... ADD COLUMN per column, each
// wrapped in a DO block that consults information_schema.columns first.
func TestGenerateAddColumnStatements(t *testing.T) {
	database := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")
	users := models.InitTable("users", "public")

	// Existing integer column.
	col := models.InitColumn("id", "users", "public")
	col.Type = "integer"
	col.NotNull = true
	col.Sequence = 1
	users.Columns["id"] = col

	// New NOT NULL varchar column.
	col = models.InitColumn("email", "users", "public")
	col.Type = "varchar"
	col.Length = 255
	col.NotNull = true
	col.Sequence = 2
	users.Columns["email"] = col

	// New column carrying a default value.
	col = models.InitColumn("status", "users", "public")
	col.Type = "text"
	col.Default = "active"
	col.Sequence = 3
	users.Columns["status"] = col

	publicSchema.Tables = append(publicSchema.Tables, users)
	database.Schemas = append(database.Schemas, publicSchema)

	w := NewWriter(&writers.WriterOptions{})
	statements, err := w.GenerateAddColumnsForDatabase(database)
	if err != nil {
		t.Fatalf("GenerateAddColumnsForDatabase failed: %v", err)
	}
	// Join the statements so content checks can span all of them.
	joined := strings.Join(statements, "\n")
	t.Logf("Generated ADD COLUMN statements:\n%s", joined)

	// Every column gets a guarded ADD COLUMN plus an existence lookup.
	for _, want := range []string{
		"ALTER TABLE public.users ADD COLUMN id integer NOT NULL",
		"ALTER TABLE public.users ADD COLUMN email varchar(255) NOT NULL",
		"ALTER TABLE public.users ADD COLUMN status text DEFAULT 'active'",
		"information_schema.columns",
		"table_schema = 'public'",
		"table_name = 'users'",
		"column_name = 'id'",
		"column_name = 'email'",
		"column_name = 'status'",
	} {
		if !strings.Contains(joined, want) {
			t.Errorf("Output missing expected string: %s\nFull output:\n%s", want, joined)
		}
	}
	// At least one DO block and one IF NOT EXISTS guard per column.
	if n := strings.Count(joined, "DO $$"); n < 3 {
		t.Errorf("Expected at least 3 DO blocks (one per column), got %d", n)
	}
	if n := strings.Count(joined, "IF NOT EXISTS"); n < 3 {
		t.Errorf("Expected at least 3 IF NOT EXISTS checks (one per column), got %d", n)
	}
}
// TestWriteAddColumnStatements verifies that WriteAddColumnStatements writes
// conditional ADD COLUMN DO blocks for every column to the configured output.
func TestWriteAddColumnStatements(t *testing.T) {
	database := models.InitDatabase("testdb")
	publicSchema := models.InitSchema("public")
	products := models.InitTable("products", "public")

	keyCol := models.InitColumn("id", "products", "public")
	keyCol.Type = "integer"
	products.Columns["id"] = keyCol

	// Nullable text column to be added.
	descriptionCol := models.InitColumn("description", "products", "public")
	descriptionCol.Type = "text"
	descriptionCol.NotNull = false
	products.Columns["description"] = descriptionCol

	publicSchema.Tables = append(publicSchema.Tables, products)
	database.Schemas = append(database.Schemas, publicSchema)

	// Capture the statements in a buffer instead of executing them.
	var out bytes.Buffer
	w := NewWriter(&writers.WriterOptions{})
	w.writer = &out
	if err := w.WriteAddColumnStatements(database); err != nil {
		t.Fatalf("WriteAddColumnStatements failed: %v", err)
	}
	got := out.String()
	t.Logf("Generated output:\n%s", got)

	// Each expected fragment pairs with its failure description.
	for _, check := range []struct{ substr, missing string }{
		{"ALTER TABLE public.products ADD COLUMN id integer", "Output missing ADD COLUMN for id"},
		{"ALTER TABLE public.products ADD COLUMN description text", "Output missing ADD COLUMN for description"},
		{"DO $$", "Output missing DO block"},
	} {
		if !strings.Contains(got, check.substr) {
			t.Errorf("%s\nFull output:\n%s", check.missing, got)
		}
	}
}

View File

@@ -23,6 +23,11 @@ func NewWriter(options *writers.WriterOptions) *Writer {
}
}
// Options returns the writer's WriterOptions.
//
// After WriteDatabase completes, callers can read execution results from
// Options().Metadata, where executeScripts records "execution_total",
// "execution_success", and "execution_failed".
func (w *Writer) Options() *writers.WriterOptions {
return w.options
}
// WriteDatabase executes all scripts from all schemas in the database
func (w *Writer) WriteDatabase(db *models.Database) error {
if db == nil {
@@ -92,6 +97,22 @@ func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*
return nil
}
// Check if we should ignore errors
ignoreErrors := false
if val, ok := w.options.Metadata["ignore_errors"].(bool); ok {
ignoreErrors = val
}
// Track failed scripts and execution counts
var failedScripts []struct {
name string
priority int
sequence uint
err error
}
successCount := 0
totalCount := 0
// Sort scripts by Priority (ascending), Sequence (ascending), then Name (ascending)
sortedScripts := make([]*models.Script, len(scripts))
copy(sortedScripts, scripts)
@@ -111,18 +132,49 @@ func (w *Writer) executeScripts(ctx context.Context, conn *pgx.Conn, scripts []*
continue
}
totalCount++
fmt.Printf("Executing script: %s (Priority=%d, Sequence=%d)\n",
script.Name, script.Priority, script.Sequence)
// Execute the SQL script
_, err := conn.Exec(ctx, script.SQL)
if err != nil {
return fmt.Errorf("failed to execute script %s (Priority=%d, Sequence=%d): %w",
if ignoreErrors {
fmt.Printf("⚠ Error executing %s: %v (continuing due to --ignore-errors)\n", script.Name, err)
failedScripts = append(failedScripts, struct {
name string
priority int
sequence uint
err error
}{
name: script.Name,
priority: script.Priority,
sequence: script.Sequence,
err: err,
})
continue
}
return fmt.Errorf("script %s (Priority=%d, Sequence=%d): %w",
script.Name, script.Priority, script.Sequence, err)
}
successCount++
fmt.Printf("✓ Successfully executed: %s\n", script.Name)
}
// Store execution results in metadata for caller
w.options.Metadata["execution_total"] = totalCount
w.options.Metadata["execution_success"] = successCount
w.options.Metadata["execution_failed"] = len(failedScripts)
// Print summary of failed scripts if any
if len(failedScripts) > 0 {
fmt.Printf("\n⚠ Failed Scripts Summary (%d failed):\n", len(failedScripts))
for i, failed := range failedScripts {
fmt.Printf(" %d. %s (Priority=%d, Sequence=%d)\n Error: %v\n",
i+1, failed.name, failed.priority, failed.sequence, failed.err)
}
}
return nil
}