More Roundtrip tests
This commit is contained in:
@@ -62,6 +62,234 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateDatabaseStatements generates SQL statements as a list for the entire database
|
||||
// Returns a slice of SQL statements that can be executed independently
|
||||
func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
// Add header comment
|
||||
statements = append(statements, fmt.Sprintf("-- PostgreSQL Database Schema"))
|
||||
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
|
||||
statements = append(statements, fmt.Sprintf("-- Generated by RelSpec"))
|
||||
|
||||
// Process each schema in the database
|
||||
for _, schema := range db.Schemas {
|
||||
schemaStatements, err := w.GenerateSchemaStatements(schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
statements = append(statements, schemaStatements...)
|
||||
}
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// GenerateSchemaStatements generates SQL statements as a list for a single schema
|
||||
func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
// Phase 1: Create schema
|
||||
if schema.Name != "public" {
|
||||
statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name))
|
||||
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
||||
}
|
||||
|
||||
// Phase 2: Create sequences
|
||||
for _, table := range schema.Tables {
|
||||
pk := table.GetPrimaryKey()
|
||||
if pk == nil || !isIntegerType(pk.Type) || pk.Default == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
defaultStr, ok := pk.Default.(string)
|
||||
if !ok || !strings.Contains(strings.ToLower(defaultStr), "nextval") {
|
||||
continue
|
||||
}
|
||||
|
||||
seqName := extractSequenceName(defaultStr)
|
||||
if seqName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
|
||||
schema.SQLName(), seqName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
// Phase 3: Create tables
|
||||
for _, table := range schema.Tables {
|
||||
stmts, err := w.generateCreateTableStatement(schema, table)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate table %s: %w", table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmts...)
|
||||
}
|
||||
|
||||
// Phase 4: Primary keys
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.PrimaryKeyConstraint {
|
||||
continue
|
||||
}
|
||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
||||
schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(constraint.Columns, ", "))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Indexes
|
||||
for _, table := range schema.Tables {
|
||||
for _, index := range table.Indexes {
|
||||
// Skip primary key indexes
|
||||
if strings.HasSuffix(index.Name, "_pkey") {
|
||||
continue
|
||||
}
|
||||
|
||||
uniqueStr := ""
|
||||
if index.Unique {
|
||||
uniqueStr = "UNIQUE "
|
||||
}
|
||||
|
||||
indexType := index.Type
|
||||
if indexType == "" {
|
||||
indexType = "btree"
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if index.Where != "" {
|
||||
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
||||
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(index.Columns, ", "), whereClause)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 6: Foreign keys
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
refSchema := constraint.ReferencedSchema
|
||||
if refSchema == "" {
|
||||
refSchema = schema.Name
|
||||
}
|
||||
|
||||
onDelete := constraint.OnDelete
|
||||
if onDelete == "" {
|
||||
onDelete = "NO ACTION"
|
||||
}
|
||||
|
||||
onUpdate := constraint.OnUpdate
|
||||
if onUpdate == "" {
|
||||
onUpdate = "NO ACTION"
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
|
||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
||||
strings.Join(constraint.Columns, ", "),
|
||||
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
||||
strings.Join(constraint.ReferencedColumns, ", "),
|
||||
onDelete, onUpdate)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 7: Comments
|
||||
for _, table := range schema.Tables {
|
||||
if table.Comment != "" {
|
||||
stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'",
|
||||
schema.SQLName(), table.SQLName(), escapeQuote(table.Comment))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
for _, column := range table.Columns {
|
||||
if column.Comment != "" {
|
||||
stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'",
|
||||
schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// generateCreateTableStatement generates CREATE TABLE statement
|
||||
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
// Sort columns by sequence or name
|
||||
columns := make([]*models.Column, 0, len(table.Columns))
|
||||
for _, col := range table.Columns {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
sort.Slice(columns, func(i, j int) bool {
|
||||
if columns[i].Sequence != columns[j].Sequence {
|
||||
return columns[i].Sequence < columns[j].Sequence
|
||||
}
|
||||
return columns[i].Name < columns[j].Name
|
||||
})
|
||||
|
||||
columnDefs := []string{}
|
||||
for _, col := range columns {
|
||||
def := w.generateColumnDefinition(col)
|
||||
columnDefs = append(columnDefs, " "+def)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
|
||||
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
||||
statements = append(statements, stmt)
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// generateColumnDefinition generates column definition
|
||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
parts := []string{col.SQLName()}
|
||||
|
||||
// Type with length/precision
|
||||
typeStr := col.Type
|
||||
if col.Length > 0 && col.Precision == 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
|
||||
} else if col.Precision > 0 {
|
||||
if col.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
|
||||
} else {
|
||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
|
||||
}
|
||||
}
|
||||
parts = append(parts, typeStr)
|
||||
|
||||
// NOT NULL
|
||||
if col.NotNull {
|
||||
parts = append(parts, "NOT NULL")
|
||||
}
|
||||
|
||||
// DEFAULT
|
||||
if col.Default != nil {
|
||||
switch v := col.Default.(type) {
|
||||
case string:
|
||||
if strings.HasPrefix(v, "nextval") || strings.HasPrefix(v, "CURRENT_") || strings.Contains(v, "()") {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
||||
} else if v == "true" || v == "false" {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(v)))
|
||||
}
|
||||
case bool:
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||
default:
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// WriteSchema writes a single schema and all its tables
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
if w.writer == nil {
|
||||
@@ -494,3 +722,26 @@ func isIntegerType(colType string) bool {
|
||||
func escapeQuote(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// extractSequenceName extracts sequence name from nextval() expression
|
||||
// Example: "nextval('public.users_id_seq'::regclass)" returns "users_id_seq"
|
||||
func extractSequenceName(defaultExpr string) string {
|
||||
// Look for nextval('schema.sequence_name'::regclass) pattern
|
||||
start := strings.Index(defaultExpr, "'")
|
||||
if start == -1 {
|
||||
return ""
|
||||
}
|
||||
end := strings.Index(defaultExpr[start+1:], "'")
|
||||
if end == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
fullName := defaultExpr[start+1 : start+1+end]
|
||||
|
||||
// Remove schema prefix if present
|
||||
parts := strings.Split(fullName, ".")
|
||||
if len(parts) > 1 {
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
return fullName
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user