diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index 76f8645..277b4e9 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -179,27 +179,67 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro } } + var pkColumns []string + var pkName string + if pkConstraint != nil { - stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)", - schema.SQLName(), table.SQLName(), pkConstraint.Name, strings.Join(pkConstraint.Columns, ", ")) - statements = append(statements, stmt) + pkColumns = pkConstraint.Columns + pkName = pkConstraint.Name } else { // No explicit constraint, check for columns with IsPrimaryKey = true - pkColumns := []string{} + pkCols := []string{} for _, col := range table.Columns { if col.IsPrimaryKey { - pkColumns = append(pkColumns, col.SQLName()) + pkCols = append(pkCols, col.SQLName()) } } - if len(pkColumns) > 0 { + if len(pkCols) > 0 { // Sort for consistent output - sort.Strings(pkColumns) - pkName := fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName()) - stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)", - schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", ")) - statements = append(statements, stmt) + sort.Strings(pkCols) + pkColumns = pkCols + pkName = fmt.Sprintf("pk_%s_%s", schema.SQLName(), table.SQLName()) } } + + if len(pkColumns) > 0 { + // Auto-generated primary key names to check for and drop + autoGenPKNames := []string{ + fmt.Sprintf("%s_pkey", table.Name), + fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name), + } + + // Wrap in DO block to drop auto-generated PK and add our named PK + stmt := fmt.Sprintf("DO $$\nDECLARE\n"+ + " auto_pk_name text;\n"+ + "BEGIN\n"+ + " -- Drop auto-generated primary key if it exists\n"+ + " SELECT constraint_name INTO auto_pk_name\n"+ + " FROM information_schema.table_constraints\n"+ + " WHERE table_schema = '%s'\n"+ + " AND table_name = '%s'\n"+ + " AND constraint_type = 'PRIMARY KEY'\n"+ + " AND constraint_name IN (%s);\n"+ + "\n"+ + " IF auto_pk_name IS NOT NULL THEN\n"+ + " EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n"+ + " END IF;\n"+ + "\n"+ + " -- Add named primary key if it doesn't exist\n"+ + " IF NOT EXISTS (\n"+ + " SELECT 1 FROM information_schema.table_constraints\n"+ + " WHERE table_schema = '%s'\n"+ + " AND table_name = '%s'\n"+ + " AND constraint_name = '%s'\n"+ + " ) THEN\n"+ + " ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s);\n"+ + " END IF;\n"+ + "END;\n$$", + schema.Name, table.Name, formatStringList(autoGenPKNames), + schema.SQLName(), table.SQLName(), + schema.Name, table.Name, pkName, + schema.SQLName(), table.SQLName(), pkName, strings.Join(pkColumns, ", ")) + statements = append(statements, stmt) + } } // Phase 5: Indexes @@ -270,7 +310,18 @@ func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, erro onUpdate = "NO ACTION" } - stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s", + // Wrap in DO block to check for existing constraint + stmt := fmt.Sprintf("DO $$\nBEGIN\n"+ + " IF NOT EXISTS (\n"+ + " SELECT 1 FROM information_schema.table_constraints\n"+ + " WHERE table_schema = '%s'\n"+ + " AND table_name = '%s'\n"+ + " AND constraint_name = '%s'\n"+ + " ) THEN\n"+ + " ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s;\n"+ + " END IF;\n"+ + "END;\n$$", + schema.Name, table.Name, constraint.Name, schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(constraint.Columns, ", "), strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable), @@ -336,14 +387,25 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string { // Type with length/precision - convert to valid PostgreSQL type baseType := pgsql.ConvertSQLType(col.Type) typeStr := baseType + + // Only add size specifiers for types that support them if col.Length > 0 && col.Precision == 0 { - typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length) - } else if col.Precision > 0 { - if col.Scale > 0 { - typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale) - } else { - typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision) + if supportsLength(baseType) { + typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length) + } else if isTextTypeWithoutLength(baseType) { + // Convert text with length to varchar + typeStr = fmt.Sprintf("varchar(%d)", col.Length) } + // For types that don't support length (integer, bigint, etc.), ignore the length + } else if col.Precision > 0 { + if supportsPrecision(baseType) { + if col.Scale > 0 { + typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale) + } else { + typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision) + } + } + // For types that don't support precision, ignore it } parts = append(parts, typeStr) @@ -490,15 +552,8 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error { columnDefs := make([]string, 0, len(columns)) for _, col := range columns { - colDef := fmt.Sprintf(" %s %s", col.SQLName(), pgsql.ConvertSQLType(col.Type)) - - // Add default value if present - if col.Default != nil && col.Default != "" { - // Strip backticks - DBML uses them for SQL expressions but PostgreSQL doesn't - defaultVal := fmt.Sprintf("%v", col.Default) - colDef += fmt.Sprintf(" DEFAULT %s", stripBackticks(defaultVal)) - } - + // Use generateColumnDefinition to properly handle type, length, precision, and defaults + colDef := " " + w.generateColumnDefinition(col) columnDefs = append(columnDefs, colDef) } @@ -550,7 +605,32 @@ func (w *Writer) writePrimaryKeys(schema *models.Schema) error { continue } - fmt.Fprintf(w.writer, "DO $$\nBEGIN\n") + // Auto-generated primary key names to check for and drop + autoGenPKNames := []string{ + fmt.Sprintf("%s_pkey", table.Name), + fmt.Sprintf("%s_%s_pkey", schema.Name, table.Name), + } + + fmt.Fprintf(w.writer, "DO $$\nDECLARE\n") + fmt.Fprintf(w.writer, " auto_pk_name text;\nBEGIN\n") + + // Check for and drop auto-generated primary keys + fmt.Fprintf(w.writer, " -- Drop auto-generated primary key if it exists\n") + fmt.Fprintf(w.writer, " SELECT constraint_name INTO auto_pk_name\n") + fmt.Fprintf(w.writer, " FROM information_schema.table_constraints\n") + fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) + fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name) + fmt.Fprintf(w.writer, " AND constraint_type = 'PRIMARY KEY'\n") + fmt.Fprintf(w.writer, " AND constraint_name IN (%s);\n", formatStringList(autoGenPKNames)) + fmt.Fprintf(w.writer, "\n") + fmt.Fprintf(w.writer, " IF auto_pk_name IS NOT NULL THEN\n") + fmt.Fprintf(w.writer, " EXECUTE 'ALTER TABLE %s.%s DROP CONSTRAINT ' || quote_ident(auto_pk_name);\n", + schema.SQLName(), table.SQLName()) + fmt.Fprintf(w.writer, " END IF;\n") + fmt.Fprintf(w.writer, "\n") + + // Add our named primary key if it doesn't exist + fmt.Fprintf(w.writer, " -- Add named primary key if it doesn't exist\n") fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n") fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) @@ -711,13 +791,6 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error { onUpdate = strings.ToUpper(fkConstraint.OnUpdate) } - fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) - fmt.Fprintf(w.writer, " DROP CONSTRAINT IF EXISTS %s;\n", fkName) - fmt.Fprintf(w.writer, "\n") - fmt.Fprintf(w.writer, "ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) - fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName) - fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", ")) - // Use constraint's referenced schema/table or relationship's ToSchema/ToTable refSchema := fkConstraint.ReferencedSchema if refSchema == "" { @@ -728,11 +801,24 @@ func (w *Writer) writeForeignKeys(schema *models.Schema) error { refTable = rel.ToTable } - fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n", + // Use DO block to check if constraint exists before adding + fmt.Fprintf(w.writer, "DO $$\nBEGIN\n") + fmt.Fprintf(w.writer, " IF NOT EXISTS (\n") + fmt.Fprintf(w.writer, " SELECT 1 FROM information_schema.table_constraints\n") + fmt.Fprintf(w.writer, " WHERE table_schema = '%s'\n", schema.Name) + fmt.Fprintf(w.writer, " AND table_name = '%s'\n", table.Name) + fmt.Fprintf(w.writer, " AND constraint_name = '%s'\n", fkName) + fmt.Fprintf(w.writer, " ) THEN\n") + fmt.Fprintf(w.writer, " ALTER TABLE %s.%s\n", schema.SQLName(), table.SQLName()) + fmt.Fprintf(w.writer, " ADD CONSTRAINT %s\n", fkName) + fmt.Fprintf(w.writer, " FOREIGN KEY (%s)\n", strings.Join(sourceColumns, ", ")) + fmt.Fprintf(w.writer, " REFERENCES %s.%s (%s)\n", refSchema, refTable, strings.Join(targetColumns, ", ")) - fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete) - fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate) - fmt.Fprintf(w.writer, " DEFERRABLE;\n\n") + fmt.Fprintf(w.writer, " ON DELETE %s\n", onDelete) + fmt.Fprintf(w.writer, " ON UPDATE %s\n", onUpdate) + fmt.Fprintf(w.writer, " DEFERRABLE;\n") + fmt.Fprintf(w.writer, " END IF;\n") + fmt.Fprintf(w.writer, "END;\n$$;\n\n") } } @@ -844,6 +930,44 @@ func isTextType(colType string) bool { return false } +// supportsLength checks if a PostgreSQL type supports length specification +func supportsLength(colType string) bool { + lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"} + lowerType := strings.ToLower(colType) + for _, t := range lengthTypes { + if lowerType == t || strings.HasPrefix(lowerType, t+"(") { + return true + } + } + return false +} + +// supportsPrecision checks if a PostgreSQL type supports precision/scale specification +func supportsPrecision(colType string) bool { + precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"} + lowerType := strings.ToLower(colType) + for _, t := range precisionTypes { + if lowerType == t || strings.HasPrefix(lowerType, t+"(") { + return true + } + } + return false +} + +// isTextTypeWithoutLength checks if type is text (which should convert to varchar when length is specified) +func isTextTypeWithoutLength(colType string) bool { + return strings.EqualFold(colType, "text") +} + +// formatStringList formats a list of strings as a SQL-safe comma-separated quoted list +func formatStringList(items []string) string { + quoted := make([]string, len(items)) + for i, item := range items { + quoted[i] = fmt.Sprintf("'%s'", escapeQuote(item)) + } + return strings.Join(quoted, ", ") +} + // extractOperatorClass extracts operator class from index comment/note // Looks for common operator classes like gin_trgm_ops, gist_trgm_ops, etc. func extractOperatorClass(comment string) string { diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go index 6f57b82..4344881 100644 --- a/pkg/writers/pgsql/writer_test.go +++ b/pkg/writers/pgsql/writer_test.go @@ -305,3 +305,136 @@ func TestTypeConversion(t *testing.T) { t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output) } } + +func TestPrimaryKeyExistenceCheck(t *testing.T) { + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + table := models.InitTable("products", "public") + + idCol := models.InitColumn("id", "products", "public") + idCol.Type = "integer" + idCol.IsPrimaryKey = true + table.Columns["id"] = idCol + + nameCol := models.InitColumn("name", "products", "public") + nameCol.Type = "text" + table.Columns["name"] = nameCol + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + err := writer.WriteDatabase(db) + if err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + t.Logf("Generated SQL:\n%s", output) + + // Verify our naming convention is used + if !strings.Contains(output, "pk_public_products") { + t.Errorf("Output missing expected primary key name 'pk_public_products'\nFull output:\n%s", output) + } + + // Verify it drops auto-generated primary keys + if !strings.Contains(output, "products_pkey") || !strings.Contains(output, "DROP CONSTRAINT") { + t.Errorf("Output missing logic to drop auto-generated primary key\nFull output:\n%s", output) + } + + // Verify it checks for our specific named constraint before adding it + if !strings.Contains(output, "constraint_name = 'pk_public_products'") { + t.Errorf("Output missing check for our named primary key constraint\nFull output:\n%s", output) + } +} + +func TestColumnSizeSpecifiers(t *testing.T) { + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + table := models.InitTable("test_sizes", "public") + + // Integer with invalid size specifier - should ignore size + integerCol := models.InitColumn("int_col", "test_sizes", "public") + integerCol.Type = "integer" + integerCol.Length = 32 + table.Columns["int_col"] = integerCol + + // Bigint with invalid size specifier - should ignore size + bigintCol := models.InitColumn("bigint_col", "test_sizes", "public") + bigintCol.Type = "bigint" + bigintCol.Length = 64 + table.Columns["bigint_col"] = bigintCol + + // Smallint with invalid size specifier - should ignore size + smallintCol := models.InitColumn("smallint_col", "test_sizes", "public") + smallintCol.Type = "smallint" + smallintCol.Length = 16 + table.Columns["smallint_col"] = smallintCol + + // Text with length - should convert to varchar + textCol := models.InitColumn("text_col", "test_sizes", "public") + textCol.Type = "text" + textCol.Length = 100 + table.Columns["text_col"] = textCol + + // Varchar with length - should keep varchar with length + varcharCol := models.InitColumn("varchar_col", "test_sizes", "public") + varcharCol.Type = "varchar" + varcharCol.Length = 50 + table.Columns["varchar_col"] = varcharCol + + // Decimal with precision and scale - should keep them + decimalCol := models.InitColumn("decimal_col", "test_sizes", "public") + decimalCol.Type = "decimal" + decimalCol.Precision = 19 + decimalCol.Scale = 4 + table.Columns["decimal_col"] = decimalCol + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + err := writer.WriteDatabase(db) + if err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + t.Logf("Generated SQL:\n%s", output) + + // Verify invalid size specifiers are NOT present + invalidPatterns := []string{ + "integer(32)", + "bigint(64)", + "smallint(16)", + "text(100)", + } + for _, pattern := range invalidPatterns { + if strings.Contains(output, pattern) { + t.Errorf("Output contains invalid pattern '%s' - PostgreSQL doesn't support this\nFull output:\n%s", pattern, output) + } + } + + // Verify valid patterns ARE present + validPatterns := []string{ + "integer", // without size + "bigint", // without size + "smallint", // without size + "varchar(100)", // text converted to varchar with length + "varchar(50)", // varchar with length + "decimal(19,4)", // decimal with precision and scale + } + for _, pattern := range validPatterns { + if !strings.Contains(output, pattern) { + t.Errorf("Output missing expected pattern '%s'\nFull output:\n%s", pattern, output) + } + } +}