More Roundtrip tests
This commit is contained in:
@@ -29,7 +29,6 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// If no output path, print to stdout
|
||||
fmt.Print(content)
|
||||
return nil
|
||||
}
|
||||
@@ -48,7 +47,7 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
|
||||
// WriteTable writes a Table model to DBML format
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
content := w.tableToDBML(table, table.Schema)
|
||||
content := w.tableToDBML(table)
|
||||
|
||||
if w.options.OutputPath != "" {
|
||||
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
||||
@@ -60,70 +59,63 @@ func (w *Writer) WriteTable(table *models.Table) error {
|
||||
|
||||
// databaseToDBML converts a Database to DBML format string
|
||||
func (w *Writer) databaseToDBML(d *models.Database) string {
|
||||
var result string
|
||||
var sb strings.Builder
|
||||
|
||||
// Add database comment if exists
|
||||
if d.Description != "" {
|
||||
result += fmt.Sprintf("// %s\n", d.Description)
|
||||
sb.WriteString(fmt.Sprintf("// %s\n", d.Description))
|
||||
}
|
||||
if d.Comment != "" {
|
||||
result += fmt.Sprintf("// %s\n", d.Comment)
|
||||
sb.WriteString(fmt.Sprintf("// %s\n", d.Comment))
|
||||
}
|
||||
if d.Description != "" || d.Comment != "" {
|
||||
result += "\n"
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Process each schema
|
||||
for _, schema := range d.Schemas {
|
||||
result += w.schemaToDBML(schema)
|
||||
sb.WriteString(w.schemaToDBML(schema))
|
||||
}
|
||||
|
||||
// Add relationships
|
||||
result += "\n// Relationships\n"
|
||||
sb.WriteString("\n// Relationships\n")
|
||||
for _, schema := range d.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.ForeignKeyConstraint {
|
||||
result += w.constraintToDBML(constraint, schema.Name, table.Name)
|
||||
sb.WriteString(w.constraintToDBML(constraint, table))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// schemaToDBML converts a Schema to DBML format string
|
||||
func (w *Writer) schemaToDBML(schema *models.Schema) string {
|
||||
var result string
|
||||
var sb strings.Builder
|
||||
|
||||
if schema.Description != "" {
|
||||
result += fmt.Sprintf("// Schema: %s - %s\n", schema.Name, schema.Description)
|
||||
sb.WriteString(fmt.Sprintf("// Schema: %s - %s\n", schema.Name, schema.Description))
|
||||
}
|
||||
|
||||
// Process tables
|
||||
for _, table := range schema.Tables {
|
||||
result += w.tableToDBML(table, schema.Name)
|
||||
result += "\n"
|
||||
sb.WriteString(w.tableToDBML(table))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return result
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// tableToDBML converts a Table to DBML format string
|
||||
func (w *Writer) tableToDBML(t *models.Table, schemaName string) string {
|
||||
var result string
|
||||
func (w *Writer) tableToDBML(t *models.Table) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Table definition
|
||||
tableName := fmt.Sprintf("%s.%s", schemaName, t.Name)
|
||||
result += fmt.Sprintf("Table %s {\n", tableName)
|
||||
tableName := fmt.Sprintf("%s.%s", t.Schema, t.Name)
|
||||
sb.WriteString(fmt.Sprintf("Table %s {\n", tableName))
|
||||
|
||||
// Add columns
|
||||
for _, column := range t.Columns {
|
||||
result += fmt.Sprintf(" %s %s", column.Name, column.Type)
|
||||
sb.WriteString(fmt.Sprintf(" %s %s", column.Name, column.Type))
|
||||
|
||||
// Add column attributes
|
||||
attrs := make([]string, 0)
|
||||
var attrs []string
|
||||
if column.IsPrimaryKey {
|
||||
attrs = append(attrs, "pk")
|
||||
}
|
||||
@@ -134,77 +126,74 @@ func (w *Writer) tableToDBML(t *models.Table, schemaName string) string {
|
||||
attrs = append(attrs, "increment")
|
||||
}
|
||||
if column.Default != nil {
|
||||
attrs = append(attrs, fmt.Sprintf("default: %v", column.Default))
|
||||
attrs = append(attrs, fmt.Sprintf("default: `%v`", column.Default))
|
||||
}
|
||||
|
||||
if len(attrs) > 0 {
|
||||
result += fmt.Sprintf(" [%s]", strings.Join(attrs, ", "))
|
||||
sb.WriteString(fmt.Sprintf(" [%s]", strings.Join(attrs, ", ")))
|
||||
}
|
||||
|
||||
if column.Comment != "" {
|
||||
result += fmt.Sprintf(" // %s", column.Comment)
|
||||
sb.WriteString(fmt.Sprintf(" // %s", column.Comment))
|
||||
}
|
||||
result += "\n"
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// Add indexes
|
||||
indexCount := 0
|
||||
for _, index := range t.Indexes {
|
||||
if indexCount == 0 {
|
||||
result += "\n indexes {\n"
|
||||
}
|
||||
indexAttrs := make([]string, 0)
|
||||
if index.Unique {
|
||||
indexAttrs = append(indexAttrs, "unique")
|
||||
}
|
||||
if index.Name != "" {
|
||||
indexAttrs = append(indexAttrs, fmt.Sprintf("name: '%s'", index.Name))
|
||||
}
|
||||
if index.Type != "" {
|
||||
indexAttrs = append(indexAttrs, fmt.Sprintf("type: %s", index.Type))
|
||||
}
|
||||
if len(t.Indexes) > 0 {
|
||||
sb.WriteString("\n indexes {\n")
|
||||
for _, index := range t.Indexes {
|
||||
var indexAttrs []string
|
||||
if index.Unique {
|
||||
indexAttrs = append(indexAttrs, "unique")
|
||||
}
|
||||
if index.Name != "" {
|
||||
indexAttrs = append(indexAttrs, fmt.Sprintf("name: '%s'", index.Name))
|
||||
}
|
||||
if index.Type != "" {
|
||||
indexAttrs = append(indexAttrs, fmt.Sprintf("type: %s", index.Type))
|
||||
}
|
||||
|
||||
result += fmt.Sprintf(" (%s)", strings.Join(index.Columns, ", "))
|
||||
if len(indexAttrs) > 0 {
|
||||
result += fmt.Sprintf(" [%s]", strings.Join(indexAttrs, ", "))
|
||||
sb.WriteString(fmt.Sprintf(" (%s)", strings.Join(index.Columns, ", ")))
|
||||
if len(indexAttrs) > 0 {
|
||||
sb.WriteString(fmt.Sprintf(" [%s]", strings.Join(indexAttrs, ", ")))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
result += "\n"
|
||||
indexCount++
|
||||
}
|
||||
if indexCount > 0 {
|
||||
result += " }\n"
|
||||
sb.WriteString(" }\n")
|
||||
}
|
||||
|
||||
// Add table note
|
||||
if t.Description != "" || t.Comment != "" {
|
||||
note := t.Description
|
||||
if note != "" && t.Comment != "" {
|
||||
note += " - "
|
||||
}
|
||||
note += t.Comment
|
||||
result += fmt.Sprintf("\n Note: '%s'\n", note)
|
||||
note := strings.TrimSpace(t.Description + " " + t.Comment)
|
||||
if note != "" {
|
||||
sb.WriteString(fmt.Sprintf("\n Note: '%s'\n", note))
|
||||
}
|
||||
|
||||
result += "}\n"
|
||||
return result
|
||||
sb.WriteString("}\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// constraintToDBML converts a Constraint to DBML format string
|
||||
func (w *Writer) constraintToDBML(c *models.Constraint, schemaName, tableName string) string {
|
||||
func (w *Writer) constraintToDBML(c *models.Constraint, t *models.Table) string {
|
||||
if c.Type != models.ForeignKeyConstraint || c.ReferencedTable == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
fromTable := fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
fromTable := fmt.Sprintf("%s.%s", c.Schema, c.Table)
|
||||
toTable := fmt.Sprintf("%s.%s", c.ReferencedSchema, c.ReferencedTable)
|
||||
|
||||
// Determine relationship cardinality
|
||||
// For foreign keys, it's typically many-to-one
|
||||
relationship := ">"
|
||||
relationship := ">" // Default to many-to-one
|
||||
for _, index := range t.Indexes {
|
||||
if index.Unique && strings.Join(index.Columns, ",") == strings.Join(c.Columns, ",") {
|
||||
relationship = "-" // one-to-one
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, column := range c.Columns {
|
||||
if t.Columns[column].IsPrimaryKey {
|
||||
relationship = "-" // one-to-one
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Build from and to column references
|
||||
// For single columns: table.column
|
||||
// For multiple columns: table.(col1, col2)
|
||||
var fromRef, toRef string
|
||||
if len(c.Columns) == 1 {
|
||||
fromRef = fmt.Sprintf("%s.%s", fromTable, c.Columns[0])
|
||||
@@ -218,20 +207,18 @@ func (w *Writer) constraintToDBML(c *models.Constraint, schemaName, tableName st
|
||||
toRef = fmt.Sprintf("%s.(%s)", toTable, strings.Join(c.ReferencedColumns, ", "))
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("Ref: %s %s %s", fromRef, relationship, toRef)
|
||||
|
||||
// Add actions
|
||||
actions := make([]string, 0)
|
||||
var actions []string
|
||||
if c.OnDelete != "" {
|
||||
actions = append(actions, fmt.Sprintf("ondelete: %s", c.OnDelete))
|
||||
actions = append(actions, fmt.Sprintf("delete: %s", c.OnDelete))
|
||||
}
|
||||
if c.OnUpdate != "" {
|
||||
actions = append(actions, fmt.Sprintf("onupdate: %s", c.OnUpdate))
|
||||
}
|
||||
if len(actions) > 0 {
|
||||
result += fmt.Sprintf(" [%s]", strings.Join(actions, ", "))
|
||||
actions = append(actions, fmt.Sprintf("update: %s", c.OnUpdate))
|
||||
}
|
||||
|
||||
result += "\n"
|
||||
return result
|
||||
}
|
||||
refLine := fmt.Sprintf("Ref: %s %s %s", fromRef, relationship, toRef)
|
||||
if len(actions) > 0 {
|
||||
refLine += fmt.Sprintf(" [%s]", strings.Join(actions, ", "))
|
||||
}
|
||||
|
||||
return refLine + "\n"
|
||||
}
|
||||
@@ -3,11 +3,11 @@ package dbml
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWriter_WriteTable(t *testing.T) {
|
||||
@@ -46,96 +46,40 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable() error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
output := string(content)
|
||||
|
||||
// Verify table structure
|
||||
if !strings.Contains(output, "Table public.users {") {
|
||||
t.Error("Output should contain table definition")
|
||||
}
|
||||
|
||||
// Verify columns
|
||||
if !strings.Contains(output, "id bigint") {
|
||||
t.Error("Output should contain id column")
|
||||
}
|
||||
if !strings.Contains(output, "pk") {
|
||||
t.Error("Output should contain pk attribute for id")
|
||||
}
|
||||
if !strings.Contains(output, "increment") {
|
||||
t.Error("Output should contain increment attribute for id")
|
||||
}
|
||||
if !strings.Contains(output, "email varchar(255)") {
|
||||
t.Error("Output should contain email column")
|
||||
}
|
||||
if !strings.Contains(output, "not null") {
|
||||
t.Error("Output should contain not null attribute")
|
||||
}
|
||||
|
||||
// Verify table note
|
||||
if !strings.Contains(output, "Note:") && table.Description != "" {
|
||||
t.Error("Output should contain table note when description is present")
|
||||
}
|
||||
assert.Contains(t, output, "Table public.users {")
|
||||
assert.Contains(t, output, "id bigint [pk, increment]")
|
||||
assert.Contains(t, output, "email varchar(255) [not null]")
|
||||
assert.Contains(t, output, "Note: 'User accounts table'")
|
||||
}
|
||||
|
||||
func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
|
||||
db := models.InitDatabase("test_db")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create users table
|
||||
usersTable := models.InitTable("users", "public")
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.IsPrimaryKey = true
|
||||
idCol.AutoIncrement = true
|
||||
idCol.NotNull = true
|
||||
usersTable.Columns["id"] = idCol
|
||||
|
||||
emailCol := models.InitColumn("email", "users", "public")
|
||||
emailCol.Type = "varchar(255)"
|
||||
emailCol.NotNull = true
|
||||
usersTable.Columns["email"] = emailCol
|
||||
|
||||
// Add index to users table
|
||||
emailIdx := models.InitIndex("idx_users_email")
|
||||
emailIdx := models.InitIndex("idx_users_email", "users", "public")
|
||||
emailIdx.Columns = []string{"email"}
|
||||
emailIdx.Unique = true
|
||||
emailIdx.Table = "users"
|
||||
emailIdx.Schema = "public"
|
||||
usersTable.Indexes["idx_users_email"] = emailIdx
|
||||
schema.Tables = append(schema.Tables, usersTable)
|
||||
|
||||
// Create posts table
|
||||
postsTable := models.InitTable("posts", "public")
|
||||
postIdCol := models.InitColumn("id", "posts", "public")
|
||||
postIdCol.Type = "bigint"
|
||||
postIdCol.IsPrimaryKey = true
|
||||
postIdCol.AutoIncrement = true
|
||||
postIdCol.NotNull = true
|
||||
postsTable.Columns["id"] = postIdCol
|
||||
|
||||
userIdCol := models.InitColumn("user_id", "posts", "public")
|
||||
userIdCol.Type = "bigint"
|
||||
userIdCol.NotNull = true
|
||||
postsTable.Columns["user_id"] = userIdCol
|
||||
|
||||
titleCol := models.InitColumn("title", "posts", "public")
|
||||
titleCol.Type = "varchar(200)"
|
||||
titleCol.NotNull = true
|
||||
postsTable.Columns["title"] = titleCol
|
||||
|
||||
publishedCol := models.InitColumn("published", "posts", "public")
|
||||
publishedCol.Type = "boolean"
|
||||
publishedCol.Default = "false"
|
||||
postsTable.Columns["published"] = publishedCol
|
||||
|
||||
// Add foreign key constraint
|
||||
fk := models.InitConstraint("fk_posts_user", models.ForeignKeyConstraint)
|
||||
fk.Table = "posts"
|
||||
fk.Schema = "public"
|
||||
@@ -144,353 +88,68 @@ func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
|
||||
fk.ReferencedSchema = "public"
|
||||
fk.ReferencedColumns = []string{"id"}
|
||||
fk.OnDelete = "CASCADE"
|
||||
fk.OnUpdate = "CASCADE"
|
||||
postsTable.Constraints["fk_posts_user"] = fk
|
||||
|
||||
schema.Tables = append(schema.Tables, usersTable, postsTable)
|
||||
schema.Tables = append(schema.Tables, postsTable)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputPath := filepath.Join(tmpDir, "test.dbml")
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
|
||||
opts := &writers.WriterOptions{OutputPath: outputPath}
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase() error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
output := string(content)
|
||||
|
||||
// Verify tables
|
||||
if !strings.Contains(output, "Table public.users {") {
|
||||
t.Error("Output should contain users table")
|
||||
}
|
||||
if !strings.Contains(output, "Table public.posts {") {
|
||||
t.Error("Output should contain posts table")
|
||||
}
|
||||
|
||||
// Verify foreign key reference
|
||||
if !strings.Contains(output, "Ref:") {
|
||||
t.Error("Output should contain Ref for foreign key")
|
||||
}
|
||||
if !strings.Contains(output, "public.posts.user_id") {
|
||||
t.Error("Output should contain posts.user_id in reference")
|
||||
}
|
||||
if !strings.Contains(output, "public.users.id") {
|
||||
t.Error("Output should contain users.id in reference")
|
||||
}
|
||||
if !strings.Contains(output, "ondelete: CASCADE") {
|
||||
t.Error("Output should contain ondelete: CASCADE")
|
||||
}
|
||||
if !strings.Contains(output, "onupdate: CASCADE") {
|
||||
t.Error("Output should contain onupdate: CASCADE")
|
||||
}
|
||||
|
||||
// Verify index
|
||||
if !strings.Contains(output, "indexes") {
|
||||
t.Error("Output should contain indexes section")
|
||||
}
|
||||
if !strings.Contains(output, "(email)") {
|
||||
t.Error("Output should contain email index")
|
||||
}
|
||||
if !strings.Contains(output, "unique") {
|
||||
t.Error("Output should contain unique attribute for email index")
|
||||
}
|
||||
assert.Contains(t, output, "Table public.users {")
|
||||
assert.Contains(t, output, "Table public.posts {")
|
||||
assert.Contains(t, output, "Ref: public.posts.user_id > public.users.id [delete: CASCADE]")
|
||||
assert.Contains(t, output, "(email) [unique, name: 'idx_users_email']")
|
||||
}
|
||||
|
||||
func TestWriter_WriteSchema(t *testing.T) {
|
||||
func TestWriter_WriteDatabase_OneToOneRelationship(t *testing.T) {
|
||||
db := models.InitDatabase("test_db")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
table := models.InitTable("users", "public")
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.IsPrimaryKey = true
|
||||
idCol.NotNull = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
usernameCol := models.InitColumn("username", "users", "public")
|
||||
usernameCol.Type = "varchar(50)"
|
||||
usernameCol.NotNull = true
|
||||
table.Columns["username"] = usernameCol
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputPath := filepath.Join(tmpDir, "test.dbml")
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteSchema(schema)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteSchema() error = %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
output := string(content)
|
||||
|
||||
// Verify table exists
|
||||
if !strings.Contains(output, "Table public.users {") {
|
||||
t.Error("Output should contain users table")
|
||||
}
|
||||
|
||||
// Verify columns
|
||||
if !strings.Contains(output, "id bigint") {
|
||||
t.Error("Output should contain id column")
|
||||
}
|
||||
if !strings.Contains(output, "username varchar(50)") {
|
||||
t.Error("Output should contain username column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_WriteDatabase_MultipleSchemas(t *testing.T) {
|
||||
db := models.InitDatabase("test_db")
|
||||
|
||||
// Create public schema with users table
|
||||
publicSchema := models.InitSchema("public")
|
||||
usersTable := models.InitTable("users", "public")
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.IsPrimaryKey = true
|
||||
usersTable.Columns["id"] = idCol
|
||||
publicSchema.Tables = append(publicSchema.Tables, usersTable)
|
||||
schema.Tables = append(schema.Tables, usersTable)
|
||||
|
||||
// Create admin schema with audit_logs table
|
||||
adminSchema := models.InitSchema("admin")
|
||||
auditTable := models.InitTable("audit_logs", "admin")
|
||||
auditIdCol := models.InitColumn("id", "audit_logs", "admin")
|
||||
auditIdCol.Type = "bigint"
|
||||
auditIdCol.IsPrimaryKey = true
|
||||
auditTable.Columns["id"] = auditIdCol
|
||||
|
||||
userIdCol := models.InitColumn("user_id", "audit_logs", "admin")
|
||||
profilesTable := models.InitTable("profiles", "public")
|
||||
profileIdCol := models.InitColumn("id", "profiles", "public")
|
||||
profileIdCol.Type = "bigint"
|
||||
profilesTable.Columns["id"] = profileIdCol
|
||||
userIdCol := models.InitColumn("user_id", "profiles", "public")
|
||||
userIdCol.Type = "bigint"
|
||||
auditTable.Columns["user_id"] = userIdCol
|
||||
userIdCol.IsPrimaryKey = true // This makes it a one-to-one
|
||||
profilesTable.Columns["user_id"] = userIdCol
|
||||
|
||||
// Add foreign key from admin.audit_logs to public.users
|
||||
fk := models.InitConstraint("fk_audit_user", models.ForeignKeyConstraint)
|
||||
fk.Table = "audit_logs"
|
||||
fk.Schema = "admin"
|
||||
fk := models.InitConstraint("fk_profiles_user", models.ForeignKeyConstraint)
|
||||
fk.Table = "profiles"
|
||||
fk.Schema = "public"
|
||||
fk.Columns = []string{"user_id"}
|
||||
fk.ReferencedTable = "users"
|
||||
fk.ReferencedSchema = "public"
|
||||
fk.ReferencedColumns = []string{"id"}
|
||||
fk.OnDelete = "SET NULL"
|
||||
auditTable.Constraints["fk_audit_user"] = fk
|
||||
|
||||
adminSchema.Tables = append(adminSchema.Tables, auditTable)
|
||||
|
||||
db.Schemas = append(db.Schemas, publicSchema, adminSchema)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputPath := filepath.Join(tmpDir, "test.dbml")
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
output := string(content)
|
||||
|
||||
// Verify both schemas present
|
||||
if !strings.Contains(output, "public.users") {
|
||||
t.Error("Output should contain public.users table")
|
||||
}
|
||||
if !strings.Contains(output, "admin.audit_logs") {
|
||||
t.Error("Output should contain admin.audit_logs table")
|
||||
}
|
||||
|
||||
// Verify cross-schema foreign key
|
||||
if !strings.Contains(output, "admin.audit_logs.user_id") {
|
||||
t.Error("Output should contain admin.audit_logs.user_id in reference")
|
||||
}
|
||||
if !strings.Contains(output, "public.users.id") {
|
||||
t.Error("Output should contain public.users.id in reference")
|
||||
}
|
||||
if !strings.Contains(output, "ondelete: SET NULL") {
|
||||
t.Error("Output should contain ondelete: SET NULL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_WriteTable_WithDefaults(t *testing.T) {
|
||||
table := models.InitTable("products", "public")
|
||||
|
||||
idCol := models.InitColumn("id", "products", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.IsPrimaryKey = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
isActiveCol := models.InitColumn("is_active", "products", "public")
|
||||
isActiveCol.Type = "boolean"
|
||||
isActiveCol.Default = "true"
|
||||
table.Columns["is_active"] = isActiveCol
|
||||
|
||||
createdCol := models.InitColumn("created_at", "products", "public")
|
||||
createdCol.Type = "timestamp"
|
||||
createdCol.Default = "CURRENT_TIMESTAMP"
|
||||
table.Columns["created_at"] = createdCol
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputPath := filepath.Join(tmpDir, "test.dbml")
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable() error = %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
output := string(content)
|
||||
|
||||
// Verify default values
|
||||
if !strings.Contains(output, "default:") {
|
||||
t.Error("Output should contain default values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_WriteTable_EmptyPath(t *testing.T) {
|
||||
table := models.InitTable("users", "public")
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
// When OutputPath is empty, it should print to stdout (not error)
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: "",
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable() with empty path should not error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_WriteDatabase_WithComments(t *testing.T) {
|
||||
db := models.InitDatabase("test_db")
|
||||
db.Description = "Test database description"
|
||||
db.Comment = "Additional comment"
|
||||
|
||||
schema := models.InitSchema("public")
|
||||
table := models.InitTable("users", "public")
|
||||
table.Comment = "Users table comment"
|
||||
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.IsPrimaryKey = true
|
||||
idCol.Comment = "Primary key"
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
profilesTable.Constraints["fk_profiles_user"] = fk
|
||||
schema.Tables = append(schema.Tables, profilesTable)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputPath := filepath.Join(tmpDir, "test.dbml")
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
|
||||
opts := &writers.WriterOptions{OutputPath: outputPath}
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase() error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
output := string(content)
|
||||
|
||||
// Verify comments are present
|
||||
if !strings.Contains(output, "//") {
|
||||
t.Error("Output should contain comments")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_WriteDatabase_WithIndexType(t *testing.T) {
|
||||
db := models.InitDatabase("test_db")
|
||||
schema := models.InitSchema("public")
|
||||
table := models.InitTable("users", "public")
|
||||
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.IsPrimaryKey = true
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
emailCol := models.InitColumn("email", "users", "public")
|
||||
emailCol.Type = "varchar(255)"
|
||||
table.Columns["email"] = emailCol
|
||||
|
||||
// Add index with type
|
||||
idx := models.InitIndex("idx_email")
|
||||
idx.Columns = []string{"email"}
|
||||
idx.Type = "btree"
|
||||
idx.Unique = true
|
||||
idx.Table = "users"
|
||||
idx.Schema = "public"
|
||||
table.Indexes["idx_email"] = idx
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
outputPath := filepath.Join(tmpDir, "test.dbml")
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(outputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read output file: %v", err)
|
||||
}
|
||||
|
||||
output := string(content)
|
||||
|
||||
// Verify index with type
|
||||
if !strings.Contains(output, "type:") || !strings.Contains(output, "btree") {
|
||||
t.Error("Output should contain index type")
|
||||
}
|
||||
}
|
||||
assert.Contains(t, output, "Ref: public.profiles.user_id - public.users.id")
|
||||
}
|
||||
194
pkg/writers/dctx/roundtrip_test.go
Normal file
194
pkg/writers/dctx/roundtrip_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package dctx
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
dctxreader "git.warky.dev/wdevs/relspecgo/pkg/readers/dctx"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRoundTrip_WriteAndRead(t *testing.T) {
|
||||
// 1. Create a sample schema with relationships
|
||||
schema := models.InitSchema("public")
|
||||
schema.Name = "TestDB"
|
||||
|
||||
// Table 1: users
|
||||
usersTable := models.InitTable("users", "public")
|
||||
usersTable.Comment = "Stores user information"
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "serial"
|
||||
idCol.IsPrimaryKey = true
|
||||
idCol.NotNull = true
|
||||
usersTable.Columns["id"] = idCol
|
||||
nameCol := models.InitColumn("name", "users", "public")
|
||||
nameCol.Type = "varchar"
|
||||
nameCol.Length = 100
|
||||
usersTable.Columns["name"] = nameCol
|
||||
pkIndex := models.InitIndex("users_pkey", "users", "public")
|
||||
pkIndex.Unique = true
|
||||
pkIndex.Columns = []string{"id"}
|
||||
usersTable.Indexes["users_pkey"] = pkIndex
|
||||
|
||||
pkConstraint := models.InitConstraint("users_pkey", models.PrimaryKeyConstraint)
|
||||
pkConstraint.Table = "users"
|
||||
pkConstraint.Schema = "public"
|
||||
pkConstraint.Columns = []string{"id"}
|
||||
usersTable.Constraints["users_pkey"] = pkConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, usersTable)
|
||||
|
||||
// Table 2: posts
|
||||
postsTable := models.InitTable("posts", "public")
|
||||
postsTable.Comment = "Stores blog posts"
|
||||
postIDCol := models.InitColumn("id", "posts", "public")
|
||||
postIDCol.Type = "serial"
|
||||
postIDCol.IsPrimaryKey = true
|
||||
postIDCol.NotNull = true
|
||||
postsTable.Columns["id"] = postIDCol
|
||||
titleCol := models.InitColumn("title", "posts", "public")
|
||||
titleCol.Type = "varchar"
|
||||
titleCol.Length = 255
|
||||
postsTable.Columns["title"] = titleCol
|
||||
userIDCol := models.InitColumn("user_id", "posts", "public")
|
||||
userIDCol.Type = "integer"
|
||||
postsTable.Columns["user_id"] = userIDCol
|
||||
postsPKIndex := models.InitIndex("posts_pkey", "posts", "public")
|
||||
postsPKIndex.Unique = true
|
||||
postsPKIndex.Columns = []string{"id"}
|
||||
postsTable.Indexes["posts_pkey"] = postsPKIndex
|
||||
|
||||
fkIndex := models.InitIndex("posts_user_id_idx", "posts", "public")
|
||||
fkIndex.Columns = []string{"user_id"}
|
||||
postsTable.Indexes["posts_user_id_idx"] = fkIndex
|
||||
|
||||
postsPKConstraint := models.InitConstraint("posts_pkey", models.PrimaryKeyConstraint)
|
||||
postsPKConstraint.Table = "posts"
|
||||
postsPKConstraint.Schema = "public"
|
||||
postsPKConstraint.Columns = []string{"id"}
|
||||
postsTable.Constraints["posts_pkey"] = postsPKConstraint
|
||||
|
||||
// Foreign key constraint
|
||||
fkConstraint := models.InitConstraint("fk_posts_users", models.ForeignKeyConstraint)
|
||||
fkConstraint.Table = "posts"
|
||||
fkConstraint.Schema = "public"
|
||||
fkConstraint.Columns = []string{"user_id"}
|
||||
fkConstraint.ReferencedTable = "users"
|
||||
fkConstraint.ReferencedSchema = "public"
|
||||
fkConstraint.ReferencedColumns = []string{"id"}
|
||||
fkConstraint.OnDelete = "CASCADE"
|
||||
fkConstraint.OnUpdate = "NO ACTION"
|
||||
postsTable.Constraints["fk_posts_users"] = fkConstraint
|
||||
|
||||
schema.Tables = append(schema.Tables, postsTable)
|
||||
|
||||
// Relation
|
||||
relation := models.InitRelationship("posts_to_users", models.OneToMany)
|
||||
relation.FromTable = "posts"
|
||||
relation.FromSchema = "public"
|
||||
relation.ToTable = "users"
|
||||
relation.ToSchema = "public"
|
||||
relation.ForeignKey = "fk_posts_users"
|
||||
schema.Relations = append(schema.Relations, relation)
|
||||
|
||||
// 2. Write the schema to DCTX
|
||||
outputPath := filepath.Join(t.TempDir(), "roundtrip.dctx")
|
||||
writerOpts := &writers.WriterOptions{
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
writer := NewWriter(writerOpts)
|
||||
|
||||
err := writer.WriteSchema(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify file was created
|
||||
_, err = os.Stat(outputPath)
|
||||
assert.NoError(t, err, "Output file should exist")
|
||||
|
||||
// 3. Read the schema back from DCTX
|
||||
readerOpts := &readers.ReaderOptions{
|
||||
FilePath: outputPath,
|
||||
}
|
||||
reader := dctxreader.NewReader(readerOpts)
|
||||
|
||||
db, err := reader.ReadDatabase()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, db)
|
||||
|
||||
// 4. Verify the schema was read correctly
|
||||
assert.Len(t, db.Schemas, 1, "Should have one schema")
|
||||
readSchema := db.Schemas[0]
|
||||
|
||||
// Verify tables
|
||||
assert.Len(t, readSchema.Tables, 2, "Should have two tables")
|
||||
|
||||
// Find users and posts tables
|
||||
var readUsersTable, readPostsTable *models.Table
|
||||
for _, table := range readSchema.Tables {
|
||||
switch table.Name {
|
||||
case "users":
|
||||
readUsersTable = table
|
||||
case "posts":
|
||||
readPostsTable = table
|
||||
}
|
||||
}
|
||||
|
||||
assert.NotNil(t, readUsersTable, "Users table should exist")
|
||||
assert.NotNil(t, readPostsTable, "Posts table should exist")
|
||||
|
||||
// Verify columns
|
||||
assert.Len(t, readUsersTable.Columns, 2, "Users table should have 2 columns")
|
||||
assert.NotNil(t, readUsersTable.Columns["id"])
|
||||
assert.NotNil(t, readUsersTable.Columns["name"])
|
||||
|
||||
assert.Len(t, readPostsTable.Columns, 3, "Posts table should have 3 columns")
|
||||
assert.NotNil(t, readPostsTable.Columns["id"])
|
||||
assert.NotNil(t, readPostsTable.Columns["title"])
|
||||
assert.NotNil(t, readPostsTable.Columns["user_id"])
|
||||
|
||||
// Verify relationships were preserved
|
||||
// The DCTX reader stores relationships on the foreign table (posts)
|
||||
assert.NotEmpty(t, readPostsTable.Relationships, "Posts table should have relationships")
|
||||
|
||||
// Debug: print all relationships
|
||||
t.Logf("Posts table has %d relationships:", len(readPostsTable.Relationships))
|
||||
for name, rel := range readPostsTable.Relationships {
|
||||
t.Logf(" - %s: from=%s to=%s fk=%s", name, rel.FromTable, rel.ToTable, rel.ForeignKey)
|
||||
}
|
||||
|
||||
// Find the relationship - the reader creates it with FromTable as primary and ToTable as foreign
|
||||
var postsToUsersRel *models.Relationship
|
||||
for _, rel := range readPostsTable.Relationships {
|
||||
// The relationship should have posts as ToTable (foreign) and users as FromTable (primary)
|
||||
if rel.FromTable == "users" && rel.ToTable == "posts" {
|
||||
postsToUsersRel = rel
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.NotNil(t, postsToUsersRel, "Should have relationship from users to posts")
|
||||
if postsToUsersRel != nil {
|
||||
assert.Equal(t, "users", postsToUsersRel.FromTable, "Relationship should come from users (primary) table")
|
||||
assert.Equal(t, "posts", postsToUsersRel.ToTable, "Relationship should point to posts (foreign) table")
|
||||
assert.NotEmpty(t, postsToUsersRel.ForeignKey, "Relationship should have a foreign key")
|
||||
}
|
||||
|
||||
// Verify foreign key constraint
|
||||
fks := readPostsTable.GetForeignKeys()
|
||||
assert.NotEmpty(t, fks, "Posts table should have foreign keys")
|
||||
|
||||
if len(fks) > 0 {
|
||||
fk := fks[0]
|
||||
assert.Equal(t, models.ForeignKeyConstraint, fk.Type)
|
||||
assert.Contains(t, fk.Columns, "user_id")
|
||||
assert.Equal(t, "users", fk.ReferencedTable)
|
||||
assert.Contains(t, fk.ReferencedColumns, "id")
|
||||
assert.Equal(t, "CASCADE", fk.OnDelete)
|
||||
}
|
||||
|
||||
t.Logf("Round-trip test successful: wrote and read back %d tables with relationships", len(readSchema.Tables))
|
||||
}
|
||||
@@ -1,36 +1,379 @@
|
||||
package dctx
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Writer implements the writers.Writer interface for DCTX format
|
||||
// Note: DCTX is a read-only format used for loading Clarion dictionary files
|
||||
type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
options *writers.WriterOptions
|
||||
fieldGuidMap map[string]string // key: "table.column", value: guid
|
||||
keyGuidMap map[string]string // key: "table.index", value: guid
|
||||
tableGuidMap map[string]string // key: "table", value: guid
|
||||
}
|
||||
|
||||
// NewWriter creates a new DCTX writer with the given options
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
return &Writer{
|
||||
options: options,
|
||||
options: options,
|
||||
fieldGuidMap: make(map[string]string),
|
||||
keyGuidMap: make(map[string]string),
|
||||
tableGuidMap: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDatabase returns an error as DCTX format is read-only
|
||||
// WriteDatabase is not implemented for DCTX
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
|
||||
return fmt.Errorf("writing a full database is not supported for DCTX, please write a single schema")
|
||||
}
|
||||
|
||||
// WriteSchema returns an error as DCTX format is read-only
|
||||
// WriteSchema writes a schema to the writer in DCTX format
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
|
||||
dctx := models.DCTXDictionary{
|
||||
Name: schema.Name,
|
||||
Version: "1",
|
||||
Tables: make([]models.DCTXTable, len(schema.Tables)),
|
||||
}
|
||||
|
||||
tableSlice := make([]*models.Table, 0, len(schema.Tables))
|
||||
for _, t := range schema.Tables {
|
||||
tableSlice = append(tableSlice, t)
|
||||
}
|
||||
|
||||
// Pass 1: Create fields and populate fieldGuidMap
|
||||
for i, table := range tableSlice {
|
||||
dctx.Tables[i] = w.mapTableFields(table)
|
||||
}
|
||||
|
||||
// Pass 2: Create keys and populate keyGuidMap
|
||||
for i, table := range tableSlice {
|
||||
dctx.Tables[i].Keys = w.mapTableKeys(table)
|
||||
}
|
||||
|
||||
// Pass 3: Collect all relationships (from schema and tables)
|
||||
var allRelations []*models.Relationship
|
||||
|
||||
// Add schema-level relations
|
||||
allRelations = append(allRelations, schema.Relations...)
|
||||
|
||||
// Add table-level relationships
|
||||
for _, table := range tableSlice {
|
||||
for _, rel := range table.Relationships {
|
||||
// Check if this relationship is already in the list (avoid duplicates)
|
||||
isDuplicate := false
|
||||
for _, existing := range allRelations {
|
||||
if existing.Name == rel.Name &&
|
||||
existing.FromTable == rel.FromTable &&
|
||||
existing.ToTable == rel.ToTable {
|
||||
isDuplicate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isDuplicate {
|
||||
allRelations = append(allRelations, rel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map all relations to DCTX format
|
||||
dctx.Relations = make([]models.DCTXRelation, len(allRelations))
|
||||
for i, rel := range allRelations {
|
||||
dctx.Relations[i] = w.mapRelation(rel, schema)
|
||||
}
|
||||
|
||||
output, err := xml.MarshalIndent(dctx, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := os.Create(w.options.OutputPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := file.Write([]byte(xml.Header)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write(output)
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteTable returns an error as DCTX format is read-only
|
||||
// WriteTable writes a single table to the writer in DCTX format
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
|
||||
dctxTable := w.mapTableFields(table)
|
||||
dctxTable.Keys = w.mapTableKeys(table)
|
||||
|
||||
output, err := xml.MarshalIndent(dctxTable, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file, err := os.Create(w.options.OutputPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = file.Write(output)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Writer) mapTableFields(table *models.Table) models.DCTXTable {
|
||||
// Generate prefix (first 3 chars, or full name if shorter)
|
||||
prefix := table.Name
|
||||
if len(table.Name) > 3 {
|
||||
prefix = table.Name[:3]
|
||||
}
|
||||
|
||||
tableGuid := w.newGUID()
|
||||
w.tableGuidMap[table.Name] = tableGuid
|
||||
|
||||
dctxTable := models.DCTXTable{
|
||||
Guid: tableGuid,
|
||||
Name: table.Name,
|
||||
Prefix: prefix,
|
||||
Description: table.Comment,
|
||||
Fields: make([]models.DCTXField, len(table.Columns)),
|
||||
Options: []models.DCTXOption{
|
||||
{
|
||||
Property: "SQL",
|
||||
PropertyType: "1",
|
||||
PropertyValue: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
i := 0
|
||||
for _, column := range table.Columns {
|
||||
dctxTable.Fields[i] = w.mapField(column)
|
||||
i++
|
||||
}
|
||||
|
||||
return dctxTable
|
||||
}
|
||||
|
||||
func (w *Writer) mapTableKeys(table *models.Table) []models.DCTXKey {
|
||||
keys := make([]models.DCTXKey, len(table.Indexes))
|
||||
i := 0
|
||||
for _, index := range table.Indexes {
|
||||
keys[i] = w.mapKey(index, table)
|
||||
i++
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (w *Writer) mapField(column *models.Column) models.DCTXField {
|
||||
guid := w.newGUID()
|
||||
fieldKey := fmt.Sprintf("%s.%s", column.Table, column.Name)
|
||||
w.fieldGuidMap[fieldKey] = guid
|
||||
|
||||
return models.DCTXField{
|
||||
Guid: guid,
|
||||
Name: column.Name,
|
||||
DataType: w.mapDataType(column.Type),
|
||||
Size: column.Length,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) mapDataType(dataType string) string {
|
||||
switch dataType {
|
||||
case "integer", "int", "int4", "serial":
|
||||
return "LONG"
|
||||
case "bigint", "int8", "bigserial":
|
||||
return "DECIMAL"
|
||||
case "smallint", "int2":
|
||||
return "SHORT"
|
||||
case "boolean", "bool":
|
||||
return "BYTE"
|
||||
case "text", "varchar", "char":
|
||||
return "CSTRING"
|
||||
case "date":
|
||||
return "DATE"
|
||||
case "time":
|
||||
return "TIME"
|
||||
case "timestamp", "timestamptz":
|
||||
return "STRING"
|
||||
case "decimal", "numeric":
|
||||
return "DECIMAL"
|
||||
default:
|
||||
return "STRING"
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) mapKey(index *models.Index, table *models.Table) models.DCTXKey {
|
||||
guid := w.newGUID()
|
||||
keyKey := fmt.Sprintf("%s.%s", table.Name, index.Name)
|
||||
w.keyGuidMap[keyKey] = guid
|
||||
|
||||
key := models.DCTXKey{
|
||||
Guid: guid,
|
||||
Name: index.Name,
|
||||
Primary: strings.HasSuffix(index.Name, "_pkey"),
|
||||
Unique: index.Unique,
|
||||
Components: make([]models.DCTXComponent, len(index.Columns)),
|
||||
Description: index.Comment,
|
||||
}
|
||||
|
||||
for i, colName := range index.Columns {
|
||||
fieldKey := fmt.Sprintf("%s.%s", table.Name, colName)
|
||||
fieldID := w.fieldGuidMap[fieldKey]
|
||||
key.Components[i] = models.DCTXComponent{
|
||||
Guid: w.newGUID(),
|
||||
FieldId: fieldID,
|
||||
Order: i + 1,
|
||||
Ascend: true,
|
||||
}
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
func (w *Writer) mapRelation(rel *models.Relationship, schema *models.Schema) models.DCTXRelation {
|
||||
// Find the foreign key constraint from the 'from' table
|
||||
var fromTable *models.Table
|
||||
for _, t := range schema.Tables {
|
||||
if t.Name == rel.FromTable {
|
||||
fromTable = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var constraint *models.Constraint
|
||||
if fromTable != nil {
|
||||
for _, c := range fromTable.Constraints {
|
||||
if c.Name == rel.ForeignKey {
|
||||
constraint = c
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var foreignKeyGUID string
|
||||
var fkColumns []string
|
||||
if constraint != nil {
|
||||
fkColumns = constraint.Columns
|
||||
// In DCTX, a relation is often linked by a foreign key which is an index.
|
||||
// We'll look for an index that matches the constraint columns.
|
||||
for _, index := range fromTable.Indexes {
|
||||
if strings.Join(index.Columns, ",") == strings.Join(constraint.Columns, ",") {
|
||||
keyKey := fmt.Sprintf("%s.%s", fromTable.Name, index.Name)
|
||||
foreignKeyGUID = w.keyGuidMap[keyKey]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find the primary key of the 'to' table
|
||||
var toTable *models.Table
|
||||
for _, t := range schema.Tables {
|
||||
if t.Name == rel.ToTable {
|
||||
toTable = t
|
||||
break
|
||||
}
|
||||
}
|
||||
var primaryKeyGUID string
|
||||
var pkColumns []string
|
||||
|
||||
// Use referenced columns from the constraint if available
|
||||
if constraint != nil && len(constraint.ReferencedColumns) > 0 {
|
||||
pkColumns = constraint.ReferencedColumns
|
||||
}
|
||||
|
||||
if toTable != nil {
|
||||
// Find the matching primary key index
|
||||
for _, index := range toTable.Indexes {
|
||||
// If we have referenced columns, try to match them
|
||||
if len(pkColumns) > 0 {
|
||||
if strings.Join(index.Columns, ",") == strings.Join(pkColumns, ",") {
|
||||
keyKey := fmt.Sprintf("%s.%s", toTable.Name, index.Name)
|
||||
primaryKeyGUID = w.keyGuidMap[keyKey]
|
||||
break
|
||||
}
|
||||
} else if strings.HasSuffix(index.Name, "_pkey") {
|
||||
// Fall back to finding primary key by naming convention
|
||||
keyKey := fmt.Sprintf("%s.%s", toTable.Name, index.Name)
|
||||
primaryKeyGUID = w.keyGuidMap[keyKey]
|
||||
pkColumns = index.Columns
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create field mappings
|
||||
// NOTE: DCTX has backwards naming - ForeignMapping contains PRIMARY table fields,
|
||||
// and PrimaryMapping contains FOREIGN table fields
|
||||
var foreignMappings []models.DCTXFieldMapping // Will contain primary table fields
|
||||
var primaryMappings []models.DCTXFieldMapping // Will contain foreign table fields
|
||||
|
||||
// Map foreign key columns (from foreign table) to PrimaryMapping
|
||||
for _, colName := range fkColumns {
|
||||
fieldKey := fmt.Sprintf("%s.%s", rel.FromTable, colName)
|
||||
if fieldGUID, exists := w.fieldGuidMap[fieldKey]; exists {
|
||||
primaryMappings = append(primaryMappings, models.DCTXFieldMapping{
|
||||
Guid: w.newGUID(),
|
||||
Field: fieldGUID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Map primary key columns (from primary table) to ForeignMapping
|
||||
for _, colName := range pkColumns {
|
||||
fieldKey := fmt.Sprintf("%s.%s", rel.ToTable, colName)
|
||||
if fieldGUID, exists := w.fieldGuidMap[fieldKey]; exists {
|
||||
foreignMappings = append(foreignMappings, models.DCTXFieldMapping{
|
||||
Guid: w.newGUID(),
|
||||
Field: fieldGUID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Get OnDelete and OnUpdate actions from the constraint
|
||||
onDelete := ""
|
||||
onUpdate := ""
|
||||
if constraint != nil {
|
||||
onDelete = w.mapReferentialAction(constraint.OnDelete)
|
||||
onUpdate = w.mapReferentialAction(constraint.OnUpdate)
|
||||
}
|
||||
|
||||
return models.DCTXRelation{
|
||||
Guid: w.newGUID(),
|
||||
PrimaryTable: w.tableGuidMap[rel.ToTable], // GUID of the 'to' table (e.g., users)
|
||||
ForeignTable: w.tableGuidMap[rel.FromTable], // GUID of the 'from' table (e.g., posts)
|
||||
PrimaryKey: primaryKeyGUID,
|
||||
ForeignKey: foreignKeyGUID,
|
||||
Delete: onDelete,
|
||||
Update: onUpdate,
|
||||
ForeignMappings: foreignMappings,
|
||||
PrimaryMappings: primaryMappings,
|
||||
}
|
||||
}
|
||||
|
||||
// mapReferentialAction maps SQL referential actions to DCTX format
|
||||
func (w *Writer) mapReferentialAction(action string) string {
|
||||
switch strings.ToUpper(action) {
|
||||
case "RESTRICT":
|
||||
return "RESTRICT_SERVER"
|
||||
case "CASCADE":
|
||||
return "CASCADE_SERVER"
|
||||
case "SET NULL":
|
||||
return "SET_NULL_SERVER"
|
||||
case "SET DEFAULT":
|
||||
return "SET_DEFAULT_SERVER"
|
||||
case "NO ACTION":
|
||||
return "NO_ACTION_SERVER"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) newGUID() string {
|
||||
return "{" + uuid.New().String() + "}"
|
||||
}
|
||||
|
||||
@@ -1,110 +1,152 @@
|
||||
package dctx
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"encoding/xml"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestWriter_WriteDatabase_ReturnsError tests that WriteDatabase returns an error
|
||||
// since DCTX format is read-only
|
||||
func TestWriter_WriteDatabase_ReturnsError(t *testing.T) {
|
||||
db := models.InitDatabase("test_db")
|
||||
func TestWriter_WriteSchema(t *testing.T) {
|
||||
// 1. Create a sample schema
|
||||
schema := models.InitSchema("public")
|
||||
table := models.InitTable("users", "public")
|
||||
schema.Name = "TestDB"
|
||||
|
||||
// Table 1: users
|
||||
usersTable := models.InitTable("users", "public")
|
||||
usersTable.Comment = "Stores user information"
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: "/tmp/test.dctx",
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteDatabase(db)
|
||||
|
||||
if err == nil {
|
||||
t.Error("WriteDatabase() should return an error for read-only format")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "read-only") {
|
||||
t.Errorf("Error message should mention 'read-only', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriter_WriteSchema_ReturnsError tests that WriteSchema returns an error
|
||||
// since DCTX format is read-only
|
||||
func TestWriter_WriteSchema_ReturnsError(t *testing.T) {
|
||||
schema := models.InitSchema("public")
|
||||
table := models.InitTable("users", "public")
|
||||
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
table.Columns["id"] = idCol
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: "/tmp/test.dctx",
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteSchema(schema)
|
||||
|
||||
if err == nil {
|
||||
t.Error("WriteSchema() should return an error for read-only format")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "read-only") {
|
||||
t.Errorf("Error message should mention 'read-only', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriter_WriteTable_ReturnsError tests that WriteTable returns an error
|
||||
// since DCTX format is read-only
|
||||
func TestWriter_WriteTable_ReturnsError(t *testing.T) {
|
||||
table := models.InitTable("users", "public")
|
||||
|
||||
idCol := models.InitColumn("id", "users", "public")
|
||||
idCol.Type = "bigint"
|
||||
idCol.Type = "serial"
|
||||
idCol.IsPrimaryKey = true
|
||||
table.Columns["id"] = idCol
|
||||
usersTable.Columns["id"] = idCol
|
||||
nameCol := models.InitColumn("name", "users", "public")
|
||||
nameCol.Type = "varchar"
|
||||
nameCol.Length = 100
|
||||
usersTable.Columns["name"] = nameCol
|
||||
pkIndex := models.InitIndex("users_pkey", "users", "public")
|
||||
pkIndex.Unique = true
|
||||
pkIndex.Columns = []string{"id"}
|
||||
usersTable.Indexes["users_pkey"] = pkIndex
|
||||
schema.Tables = append(schema.Tables, usersTable)
|
||||
|
||||
// Table 2: posts
|
||||
postsTable := models.InitTable("posts", "public")
|
||||
postsTable.Comment = "Stores blog posts"
|
||||
postIDCol := models.InitColumn("id", "posts", "public")
|
||||
postIDCol.Type = "serial"
|
||||
postIDCol.IsPrimaryKey = true
|
||||
postsTable.Columns["id"] = postIDCol
|
||||
titleCol := models.InitColumn("title", "posts", "public")
|
||||
titleCol.Type = "varchar"
|
||||
titleCol.Length = 255
|
||||
postsTable.Columns["title"] = titleCol
|
||||
userIDCol := models.InitColumn("user_id", "posts", "public")
|
||||
userIDCol.Type = "integer"
|
||||
postsTable.Columns["user_id"] = userIDCol
|
||||
postsPKIndex := models.InitIndex("posts_pkey", "posts", "public")
|
||||
postsPKIndex.Unique = true
|
||||
postsPKIndex.Columns = []string{"id"}
|
||||
postsTable.Indexes["posts_pkey"] = postsPKIndex
|
||||
|
||||
fkIndex := models.InitIndex("posts_user_id_idx", "posts", "public")
|
||||
fkIndex.Columns = []string{"user_id"}
|
||||
postsTable.Indexes["posts_user_id_idx"] = fkIndex
|
||||
schema.Tables = append(schema.Tables, postsTable)
|
||||
|
||||
// Constraint for the relationship
|
||||
fkConstraint := models.InitConstraint("fk_posts_users", models.ForeignKeyConstraint)
|
||||
fkConstraint.Table = "posts"
|
||||
fkConstraint.Schema = "public"
|
||||
fkConstraint.Columns = []string{"user_id"}
|
||||
fkConstraint.ReferencedTable = "users"
|
||||
fkConstraint.ReferencedSchema = "public"
|
||||
fkConstraint.ReferencedColumns = []string{"id"}
|
||||
postsTable.Constraints["fk_posts_users"] = fkConstraint
|
||||
|
||||
// Relation
|
||||
relation := models.InitRelation("fk_posts_users", "public")
|
||||
relation.FromTable = "posts"
|
||||
relation.ToTable = "users"
|
||||
relation.ForeignKey = "fk_posts_users"
|
||||
schema.Relations = append(schema.Relations, relation)
|
||||
|
||||
// 2. Setup writer
|
||||
outputPath := "/tmp/test.dctx"
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: "/tmp/test.dctx",
|
||||
OutputPath: outputPath,
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
err := writer.WriteTable(table)
|
||||
|
||||
if err == nil {
|
||||
t.Error("WriteTable() should return an error for read-only format")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "read-only") {
|
||||
t.Errorf("Error message should mention 'read-only', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewWriter tests that NewWriter creates a valid writer instance
|
||||
func TestNewWriter(t *testing.T) {
|
||||
opts := &writers.WriterOptions{
|
||||
OutputPath: "/tmp/test.dctx",
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
|
||||
if writer == nil {
|
||||
t.Error("NewWriter() should return a non-nil writer")
|
||||
}
|
||||
// 3. Write the schema
|
||||
err := writer.WriteSchema(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if writer.options != opts {
|
||||
t.Error("Writer options should match the provided options")
|
||||
}
|
||||
}
|
||||
// 4. Read the file and unmarshal it
|
||||
actualBytes, err := os.ReadFile(outputPath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var dctx models.DCTXDictionary
|
||||
err = xml.Unmarshal(actualBytes, &dctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5. Assert properties of the unmarshaled struct
|
||||
assert.Equal(t, "TestDB", dctx.Name)
|
||||
assert.Equal(t, "1", dctx.Version)
|
||||
assert.Len(t, dctx.Tables, 2)
|
||||
assert.Len(t, dctx.Relations, 1)
|
||||
|
||||
// Assert users table
|
||||
usersTableResult := dctx.Tables[0]
|
||||
assert.Equal(t, "users", usersTableResult.Name)
|
||||
assert.Len(t, usersTableResult.Fields, 2)
|
||||
assert.Len(t, usersTableResult.Keys, 1)
|
||||
userPK := usersTableResult.Keys[0]
|
||||
assert.True(t, userPK.Primary)
|
||||
assert.Equal(t, "users_pkey", userPK.Name)
|
||||
assert.Len(t, userPK.Components, 1)
|
||||
userPKComponent := userPK.Components[0]
|
||||
assert.NotEmpty(t, userPKComponent.FieldId)
|
||||
|
||||
// Assert posts table
|
||||
postsTableResult := dctx.Tables[1]
|
||||
assert.Equal(t, "posts", postsTableResult.Name)
|
||||
assert.Len(t, postsTableResult.Fields, 3)
|
||||
assert.Len(t, postsTableResult.Keys, 2)
|
||||
postsFK := postsTableResult.Keys[1] // Assuming order
|
||||
assert.False(t, postsFK.Primary)
|
||||
assert.Equal(t, "posts_user_id_idx", postsFK.Name)
|
||||
assert.Len(t, postsFK.Components, 1)
|
||||
postsFKComponent := postsFK.Components[0]
|
||||
assert.NotEmpty(t, postsFKComponent.FieldId)
|
||||
|
||||
// Assert relation
|
||||
relationResult := dctx.Relations[0]
|
||||
// PrimaryTable and ForeignTable should be GUIDs in DCTX format
|
||||
assert.NotEmpty(t, relationResult.PrimaryTable, "PrimaryTable should have a GUID")
|
||||
assert.NotEmpty(t, relationResult.ForeignTable, "ForeignTable should have a GUID")
|
||||
assert.NotEmpty(t, relationResult.PrimaryKey)
|
||||
assert.NotEmpty(t, relationResult.ForeignKey)
|
||||
|
||||
// Check if the table GUIDs match
|
||||
assert.Equal(t, usersTableResult.Guid, relationResult.PrimaryTable, "PrimaryTable GUID should match users table")
|
||||
assert.Equal(t, postsTableResult.Guid, relationResult.ForeignTable, "ForeignTable GUID should match posts table")
|
||||
|
||||
// Check if the key GUIDs match up
|
||||
assert.Equal(t, userPK.Guid, relationResult.PrimaryKey)
|
||||
assert.Equal(t, postsFK.Guid, relationResult.ForeignKey)
|
||||
|
||||
// Verify field mappings exist
|
||||
assert.NotEmpty(t, relationResult.ForeignMappings, "Relation should have ForeignMappings")
|
||||
assert.NotEmpty(t, relationResult.PrimaryMappings, "Relation should have PrimaryMappings")
|
||||
|
||||
// ForeignMapping should reference primary table (users) fields
|
||||
assert.Len(t, relationResult.ForeignMappings, 1)
|
||||
assert.NotEmpty(t, relationResult.ForeignMappings[0].Field)
|
||||
|
||||
// PrimaryMapping should reference foreign table (posts) fields
|
||||
assert.Len(t, relationResult.PrimaryMappings, 1)
|
||||
assert.NotEmpty(t, relationResult.PrimaryMappings[0].Field)
|
||||
}
|
||||
@@ -33,7 +33,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
table.Columns["name"] = nameCol
|
||||
|
||||
// Add index
|
||||
emailIdx := models.InitIndex("idx_users_email")
|
||||
emailIdx := models.InitIndex("idx_users_email", "users", "public")
|
||||
emailIdx.Columns = []string{"email"}
|
||||
emailIdx.Unique = true
|
||||
emailIdx.Table = "users"
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
|
||||
usersTable.Columns["email"] = emailCol
|
||||
|
||||
// Add index
|
||||
emailIdx := models.InitIndex("idx_users_email")
|
||||
emailIdx := models.InitIndex("idx_users_email", "users", "public")
|
||||
emailIdx.Columns = []string{"email"}
|
||||
emailIdx.Unique = true
|
||||
emailIdx.Type = "btree"
|
||||
|
||||
@@ -1,696 +0,0 @@
|
||||
# PostgreSQL Migration Templates
|
||||
|
||||
## Overview
|
||||
|
||||
The PostgreSQL migration writer uses Go text templates to generate SQL, making the code much more maintainable and customizable than hardcoded string concatenation.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
pkg/writers/pgsql/
|
||||
├── templates/ # Template files
|
||||
│ ├── create_table.tmpl # CREATE TABLE
|
||||
│ ├── add_column.tmpl # ALTER TABLE ADD COLUMN
|
||||
│ ├── alter_column_type.tmpl # ALTER TABLE ALTER COLUMN TYPE
|
||||
│ ├── alter_column_default.tmpl # ALTER TABLE ALTER COLUMN DEFAULT
|
||||
│ ├── create_primary_key.tmpl # ADD CONSTRAINT PRIMARY KEY
|
||||
│ ├── create_index.tmpl # CREATE INDEX
|
||||
│ ├── create_foreign_key.tmpl # ADD CONSTRAINT FOREIGN KEY
|
||||
│ ├── drop_constraint.tmpl # DROP CONSTRAINT
|
||||
│ ├── drop_index.tmpl # DROP INDEX
|
||||
│ ├── comment_table.tmpl # COMMENT ON TABLE
|
||||
│ ├── comment_column.tmpl # COMMENT ON COLUMN
|
||||
│ ├── audit_tables.tmpl # CREATE audit tables
|
||||
│ ├── audit_function.tmpl # CREATE audit function
|
||||
│ └── audit_trigger.tmpl # CREATE audit trigger
|
||||
├── templates.go # Template executor and data structures
|
||||
└── migration_writer_templated.go # Templated migration writer
|
||||
```
|
||||
|
||||
## Using Templates
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```go
|
||||
// Create template executor
|
||||
executor, err := pgsql.NewTemplateExecutor()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Prepare data
|
||||
data := pgsql.CreateTableData{
|
||||
SchemaName: "public",
|
||||
TableName: "users",
|
||||
Columns: []pgsql.ColumnData{
|
||||
{Name: "id", Type: "integer", NotNull: true},
|
||||
{Name: "name", Type: "text"},
|
||||
},
|
||||
}
|
||||
|
||||
// Execute template
|
||||
sql, err := executor.ExecuteCreateTable(data)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(sql)
|
||||
```
|
||||
|
||||
### Using Templated Migration Writer
|
||||
|
||||
```go
|
||||
// Create templated migration writer
|
||||
writer, err := pgsql.NewTemplatedMigrationWriter(&writers.WriterOptions{
|
||||
OutputPath: "migration.sql",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate migration (uses templates internally)
|
||||
err = writer.WriteMigration(modelDB, currentDB)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
## Template Data Structures
|
||||
|
||||
### CreateTableData
|
||||
|
||||
For `create_table.tmpl`:
|
||||
|
||||
```go
|
||||
type CreateTableData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
Columns []ColumnData
|
||||
}
|
||||
|
||||
type ColumnData struct {
|
||||
Name string
|
||||
Type string
|
||||
Default string
|
||||
NotNull bool
|
||||
}
|
||||
```
|
||||
|
||||
Example:
|
||||
```go
|
||||
data := CreateTableData{
|
||||
SchemaName: "public",
|
||||
TableName: "products",
|
||||
Columns: []ColumnData{
|
||||
{Name: "id", Type: "serial", NotNull: true},
|
||||
{Name: "name", Type: "text", NotNull: true},
|
||||
{Name: "price", Type: "numeric(10,2)", Default: "0.00"},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### AddColumnData
|
||||
|
||||
For `add_column.tmpl`:
|
||||
|
||||
```go
|
||||
type AddColumnData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
ColumnType string
|
||||
Default string
|
||||
NotNull bool
|
||||
}
|
||||
```
|
||||
|
||||
### CreateIndexData
|
||||
|
||||
For `create_index.tmpl`:
|
||||
|
||||
```go
|
||||
type CreateIndexData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
IndexName string
|
||||
IndexType string // btree, hash, gin, gist
|
||||
Columns string // comma-separated
|
||||
Unique bool
|
||||
}
|
||||
```
|
||||
|
||||
### CreateForeignKeyData
|
||||
|
||||
For `create_foreign_key.tmpl`:
|
||||
|
||||
```go
|
||||
type CreateForeignKeyData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
ConstraintName string
|
||||
SourceColumns string // comma-separated
|
||||
TargetSchema string
|
||||
TargetTable string
|
||||
TargetColumns string // comma-separated
|
||||
OnDelete string // CASCADE, SET NULL, etc.
|
||||
OnUpdate string
|
||||
}
|
||||
```
|
||||
|
||||
### AuditFunctionData
|
||||
|
||||
For `audit_function.tmpl`:
|
||||
|
||||
```go
|
||||
type AuditFunctionData struct {
|
||||
SchemaName string
|
||||
FunctionName string
|
||||
TableName string
|
||||
TablePrefix string
|
||||
PrimaryKey string
|
||||
AuditSchema string
|
||||
UserFunction string
|
||||
AuditInsert bool
|
||||
AuditUpdate bool
|
||||
AuditDelete bool
|
||||
UpdateCondition string
|
||||
UpdateColumns []AuditColumnData
|
||||
DeleteColumns []AuditColumnData
|
||||
}
|
||||
|
||||
type AuditColumnData struct {
|
||||
Name string
|
||||
OldValue string // SQL expression for old value
|
||||
NewValue string // SQL expression for new value
|
||||
}
|
||||
```
|
||||
|
||||
## Customizing Templates
|
||||
|
||||
### Modifying Existing Templates
|
||||
|
||||
Templates are embedded in the binary but can be modified at compile time:
|
||||
|
||||
1. **Edit template file** in `pkg/writers/pgsql/templates/`:
|
||||
|
||||
```go
|
||||
// templates/create_table.tmpl
|
||||
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
|
||||
{{- range $i, $col := .Columns}}
|
||||
{{- if $i}},{{end}}
|
||||
{{$col.Name}} {{$col.Type}}
|
||||
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
||||
{{- if $col.NotNull}} NOT NULL{{end}}
|
||||
{{- end}}
|
||||
);
|
||||
|
||||
-- Custom comment
|
||||
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS 'Auto-generated by RelSpec';
|
||||
```
|
||||
|
||||
2. **Rebuild** the application:
|
||||
|
||||
```bash
|
||||
go build ./cmd/relspec
|
||||
```
|
||||
|
||||
The new template is automatically embedded.
|
||||
|
||||
### Template Syntax Reference
|
||||
|
||||
#### Variables
|
||||
|
||||
```go
|
||||
{{.FieldName}} // Access field
|
||||
{{.SchemaName}} // String field
|
||||
{{.NotNull}} // Boolean field
|
||||
```
|
||||
|
||||
#### Conditionals
|
||||
|
||||
```go
|
||||
{{if .NotNull}}
|
||||
NOT NULL
|
||||
{{end}}
|
||||
|
||||
{{if .Default}}
|
||||
DEFAULT {{.Default}}
|
||||
{{else}}
|
||||
-- No default
|
||||
{{end}}
|
||||
```
|
||||
|
||||
#### Loops
|
||||
|
||||
```go
|
||||
{{range $i, $col := .Columns}}
|
||||
Column: {{$col.Name}} Type: {{$col.Type}}
|
||||
{{end}}
|
||||
```
|
||||
|
||||
#### Functions
|
||||
|
||||
```go
|
||||
{{if eq .Type "CASCADE"}}
|
||||
ON DELETE CASCADE
|
||||
{{end}}
|
||||
|
||||
{{join .Columns ", "}} // Join string slice
|
||||
```
|
||||
|
||||
### Creating New Templates
|
||||
|
||||
1. **Create template file** in `pkg/writers/pgsql/templates/`:
|
||||
|
||||
```go
|
||||
// templates/custom_operation.tmpl
|
||||
-- Custom operation for {{.TableName}}
|
||||
ALTER TABLE {{.SchemaName}}.{{.TableName}}
|
||||
{{.CustomOperation}};
|
||||
```
|
||||
|
||||
2. **Define data structure** in `templates.go`:
|
||||
|
||||
```go
|
||||
type CustomOperationData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
CustomOperation string
|
||||
}
|
||||
```
|
||||
|
||||
3. **Add executor method** in `templates.go`:
|
||||
|
||||
```go
|
||||
func (te *TemplateExecutor) ExecuteCustomOperation(data CustomOperationData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := te.templates.ExecuteTemplate(&buf, "custom_operation.tmpl", data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute custom_operation template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
```
|
||||
|
||||
4. **Use in migration writer**:
|
||||
|
||||
```go
|
||||
sql, err := w.executor.ExecuteCustomOperation(CustomOperationData{
|
||||
SchemaName: "public",
|
||||
TableName: "users",
|
||||
CustomOperation: "ADD COLUMN custom_field text",
|
||||
})
|
||||
```
|
||||
|
||||
## Template Examples
|
||||
|
||||
### Example 1: Custom Table Creation
|
||||
|
||||
Modify `create_table.tmpl` to add table options:
|
||||
|
||||
```sql
|
||||
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
|
||||
{{- range $i, $col := .Columns}}
|
||||
{{- if $i}},{{end}}
|
||||
{{$col.Name}} {{$col.Type}}
|
||||
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
|
||||
{{- if $col.NotNull}} NOT NULL{{end}}
|
||||
{{- end}}
|
||||
) WITH (fillfactor = 90);
|
||||
|
||||
-- Add automatic comment
|
||||
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}}
|
||||
IS 'Created: {{.CreatedDate}} | Version: {{.Version}}';
|
||||
```
|
||||
|
||||
### Example 2: Custom Index with WHERE Clause
|
||||
|
||||
Add to `create_index.tmpl`:
|
||||
|
||||
```sql
|
||||
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
|
||||
ON {{.SchemaName}}.{{.TableName}}
|
||||
USING {{.IndexType}} ({{.Columns}})
|
||||
{{- if .Where}}
|
||||
WHERE {{.Where}}
|
||||
{{- end}}
|
||||
{{- if .Include}}
|
||||
INCLUDE ({{.Include}})
|
||||
{{- end}};
|
||||
```
|
||||
|
||||
Update data structure:
|
||||
|
||||
```go
|
||||
type CreateIndexData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
IndexName string
|
||||
IndexType string
|
||||
Columns string
|
||||
Unique bool
|
||||
Where string // New field for partial indexes
|
||||
Include string // New field for covering indexes
|
||||
}
|
||||
```
|
||||
|
||||
### Example 3: Enhanced Audit Function
|
||||
|
||||
Modify `audit_function.tmpl` to add custom logging:
|
||||
|
||||
```sql
|
||||
CREATE OR REPLACE FUNCTION {{.SchemaName}}.{{.FunctionName}}()
|
||||
RETURNS trigger AS
|
||||
$body$
|
||||
DECLARE
|
||||
m_funcname text = '{{.FunctionName}}';
|
||||
m_user text;
|
||||
m_atevent integer;
|
||||
m_application_name text;
|
||||
BEGIN
|
||||
-- Get current user and application
|
||||
m_user := {{.UserFunction}}::text;
|
||||
m_application_name := current_setting('application_name', true);
|
||||
|
||||
-- Custom logging
|
||||
RAISE NOTICE 'Audit: % on %.% by % from %',
|
||||
TG_OP, TG_TABLE_SCHEMA, TG_TABLE_NAME, m_user, m_application_name;
|
||||
|
||||
-- Rest of function...
|
||||
...
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Keep Templates Simple
|
||||
|
||||
Templates should focus on SQL generation. Complex logic belongs in Go code:
|
||||
|
||||
**Good:**
|
||||
```go
|
||||
// In Go code
|
||||
columns := buildColumnList(table)
|
||||
|
||||
// In template
|
||||
{{range .Columns}}
|
||||
{{.Name}} {{.Type}}
|
||||
{{end}}
|
||||
```
|
||||
|
||||
**Bad:**
|
||||
```go
|
||||
// Don't do complex transformations in templates
|
||||
{{range .Columns}}
|
||||
{{if eq .Type "integer"}}
|
||||
{{.Name}} serial
|
||||
{{else}}
|
||||
{{.Name}} {{.Type}}
|
||||
{{end}}
|
||||
{{end}}
|
||||
```
|
||||
|
||||
### 2. Use Descriptive Field Names
|
||||
|
||||
```go
|
||||
// Good
|
||||
type CreateTableData struct {
|
||||
SchemaName string
|
||||
TableName string
|
||||
}
|
||||
|
||||
// Bad
|
||||
type CreateTableData struct {
|
||||
S string // What is S?
|
||||
T string // What is T?
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Document Template Data
|
||||
|
||||
Always document what data a template expects:
|
||||
|
||||
```go
|
||||
// CreateTableData contains data for create table template.
|
||||
// Used by templates/create_table.tmpl
|
||||
type CreateTableData struct {
|
||||
SchemaName string // Schema where table will be created
|
||||
TableName string // Name of the table
|
||||
Columns []ColumnData // List of columns to create
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Handle SQL Injection
|
||||
|
||||
Always escape user input:
|
||||
|
||||
```go
|
||||
// In Go code - escape before passing to template
|
||||
data := CommentTableData{
|
||||
SchemaName: schema,
|
||||
TableName: table,
|
||||
Comment: escapeQuote(userComment), // Escape quotes
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Test Templates Thoroughly
|
||||
|
||||
```go
|
||||
func TestTemplate_CreateTable(t *testing.T) {
|
||||
executor, _ := NewTemplateExecutor()
|
||||
|
||||
data := CreateTableData{
|
||||
SchemaName: "public",
|
||||
TableName: "test",
|
||||
Columns: []ColumnData{{Name: "id", Type: "integer"}},
|
||||
}
|
||||
|
||||
sql, err := executor.ExecuteCreateTable(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify expected SQL patterns
|
||||
if !strings.Contains(sql, "CREATE TABLE") {
|
||||
t.Error("Missing CREATE TABLE")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Benefits of Template-Based Approach
|
||||
|
||||
### Maintainability
|
||||
|
||||
**Before (string concatenation):**
|
||||
```go
|
||||
sql := fmt.Sprintf(`CREATE TABLE %s.%s (
|
||||
%s %s%s%s
|
||||
);`, schema, table, col, typ,
|
||||
func() string {
|
||||
if def != "" {
|
||||
return " DEFAULT " + def
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
func() string {
|
||||
if notNull {
|
||||
return " NOT NULL"
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
)
|
||||
```
|
||||
|
||||
**After (templates):**
|
||||
```go
|
||||
sql, _ := executor.ExecuteCreateTable(CreateTableData{
|
||||
SchemaName: schema,
|
||||
TableName: table,
|
||||
Columns: columns,
|
||||
})
|
||||
```
|
||||
|
||||
### Customization
|
||||
|
||||
Users can modify templates without changing Go code:
|
||||
- Edit template file
|
||||
- Rebuild application
|
||||
- New SQL generation logic active
|
||||
|
||||
### Testing
|
||||
|
||||
Templates can be tested independently:
|
||||
```go
|
||||
func TestAuditTemplate(t *testing.T) {
|
||||
executor, _ := NewTemplateExecutor()
|
||||
|
||||
// Test with various data
|
||||
for _, testCase := range testCases {
|
||||
sql, err := executor.ExecuteAuditFunction(testCase.data)
|
||||
// Verify output
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Readability
|
||||
|
||||
SQL templates are easier to read and review than Go string building code.
|
||||
|
||||
## Migration from Old Writer
|
||||
|
||||
To migrate from the old string-based writer to templates:
|
||||
|
||||
### Option 1: Use TemplatedMigrationWriter
|
||||
|
||||
```go
|
||||
// Old
|
||||
writer := pgsql.NewMigrationWriter(options)
|
||||
|
||||
// New
|
||||
writer, err := pgsql.NewTemplatedMigrationWriter(options)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Same interface
|
||||
writer.WriteMigration(model, current)
|
||||
```
|
||||
|
||||
### Option 2: Keep Both
|
||||
|
||||
Both writers are available:
|
||||
- `MigrationWriter` - Original string-based
|
||||
- `TemplatedMigrationWriter` - New template-based
|
||||
|
||||
Choose based on your needs.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Template Not Found
|
||||
|
||||
```
|
||||
Error: template: "my_template.tmpl" not defined
|
||||
```
|
||||
|
||||
Solution: Ensure template file exists in `templates/` directory and rebuild.
|
||||
|
||||
### Template Execution Error
|
||||
|
||||
```
|
||||
Error: template: create_table.tmpl:5:10: executing "create_table.tmpl"
|
||||
at <.InvalidField>: can't evaluate field InvalidField
|
||||
```
|
||||
|
||||
Solution: Check data structure has all fields used in template.
|
||||
|
||||
### Embedded Files Not Updating
|
||||
|
||||
If template changes aren't reflected:
|
||||
|
||||
1. Clean build cache: `go clean -cache`
|
||||
2. Rebuild: `go build ./cmd/relspec`
|
||||
3. Verify template file is in `templates/` directory
|
||||
|
||||
## Custom Template Functions
|
||||
|
||||
RelSpec provides a comprehensive library of template functions for SQL generation:
|
||||
|
||||
### String Manipulation
|
||||
- `upper`, `lower` - Case conversion
|
||||
- `snake_case`, `camelCase` - Naming convention conversion
|
||||
- Usage: `{{upper .TableName}}` → `USERS`
|
||||
|
||||
### SQL Formatting
|
||||
- `indent(spaces, text)` - Indent text
|
||||
- `quote(string)` - Quote for SQL with escaping
|
||||
- `escape(string)` - Escape special characters
|
||||
- `safe_identifier(string)` - Make SQL-safe identifier
|
||||
- Usage: `{{quote "O'Brien"}}` → `'O''Brien'`
|
||||
|
||||
### Type Conversion
|
||||
- `goTypeToSQL(type)` - Convert Go type to PostgreSQL type
|
||||
- `sqlTypeToGo(type)` - Convert PostgreSQL type to Go type
|
||||
- `isNumeric(type)`, `isText(type)` - Type checking
|
||||
- Usage: `{{goTypeToSQL "int64"}}` → `bigint`
|
||||
|
||||
### Collection Helpers
|
||||
- `first(slice)`, `last(slice)` - Get elements
|
||||
- `join_with(slice, sep)` - Join with custom separator
|
||||
- Usage: `{{join_with .Columns ", "}}` → `id, name, email`
|
||||
|
||||
See [template_functions.go](template_functions.go) for full documentation.
|
||||
|
||||
## Template Inheritance and Composition
|
||||
|
||||
RelSpec supports Go template inheritance using `{{template}}` and `{{block}}`:
|
||||
|
||||
### Base Templates
|
||||
- `base_ddl.tmpl` - Common DDL patterns
|
||||
- `base_constraint.tmpl` - Constraint operations
|
||||
- `fragments.tmpl` - Reusable fragments
|
||||
|
||||
### Using Fragments
|
||||
```gotmpl
|
||||
{{/* Use predefined fragments */}}
|
||||
CREATE TABLE {{template "qualified_table" .}} (
|
||||
{{range .Columns}}
|
||||
{{template "column_definition" .}}
|
||||
{{end}}
|
||||
);
|
||||
```
|
||||
|
||||
### Template Blocks
|
||||
```gotmpl
|
||||
{{/* Define with override capability */}}
|
||||
{{define "table_options"}}
|
||||
) {{block "storage_options" .}}WITH (fillfactor = 90){{end}};
|
||||
{{end}}
|
||||
```
|
||||
|
||||
See [TEMPLATE_INHERITANCE.md](TEMPLATE_INHERITANCE.md) for detailed guide.
|
||||
|
||||
## Visual Template Editor
|
||||
|
||||
A VS Code extension is available for visual template editing:
|
||||
|
||||
### Features
|
||||
- **Live Preview** - See rendered SQL as you type
|
||||
- **IntelliSense** - Auto-completion for functions
|
||||
- **Validation** - Syntax checking and error highlighting
|
||||
- **Scaffolding** - Quick template creation
|
||||
- **Function Browser** - Browse available functions
|
||||
|
||||
### Installation
|
||||
```bash
|
||||
cd vscode-extension
|
||||
npm install
|
||||
npm run compile
|
||||
code .
|
||||
# Press F5 to launch
|
||||
```
|
||||
|
||||
See [vscode-extension/README.md](../../vscode-extension/README.md) for full documentation.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Completed:
|
||||
- [x] Template inheritance/composition
|
||||
- [x] Custom template functions library
|
||||
- [x] Visual template editor (VS Code)
|
||||
|
||||
Potential future improvements:
|
||||
- [ ] Parameterized templates (load from config)
|
||||
- [ ] Template validation CLI tool
|
||||
- [ ] Template library/marketplace
|
||||
- [ ] Template versioning
|
||||
- [ ] Hot-reload during development
|
||||
|
||||
## Contributing Templates
|
||||
|
||||
When contributing new templates:
|
||||
|
||||
1. Place in `pkg/writers/pgsql/templates/`
|
||||
2. Use `.tmpl` extension
|
||||
3. Document data structure in `templates.go`
|
||||
4. Add executor method
|
||||
5. Write tests
|
||||
6. Update this documentation
|
||||
@@ -62,6 +62,234 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateDatabaseStatements generates SQL statements as a list for the entire database
|
||||
// Returns a slice of SQL statements that can be executed independently
|
||||
func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
// Add header comment
|
||||
statements = append(statements, fmt.Sprintf("-- PostgreSQL Database Schema"))
|
||||
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
|
||||
statements = append(statements, fmt.Sprintf("-- Generated by RelSpec"))
|
||||
|
||||
// Process each schema in the database
|
||||
for _, schema := range db.Schemas {
|
||||
schemaStatements, err := w.GenerateSchemaStatements(schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
statements = append(statements, schemaStatements...)
|
||||
}
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// GenerateSchemaStatements generates SQL statements as a list for a single schema
|
||||
func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
// Phase 1: Create schema
|
||||
if schema.Name != "public" {
|
||||
statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name))
|
||||
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
|
||||
}
|
||||
|
||||
// Phase 2: Create sequences
|
||||
for _, table := range schema.Tables {
|
||||
pk := table.GetPrimaryKey()
|
||||
if pk == nil || !isIntegerType(pk.Type) || pk.Default == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
defaultStr, ok := pk.Default.(string)
|
||||
if !ok || !strings.Contains(strings.ToLower(defaultStr), "nextval") {
|
||||
continue
|
||||
}
|
||||
|
||||
seqName := extractSequenceName(defaultStr)
|
||||
if seqName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
|
||||
schema.SQLName(), seqName)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
// Phase 3: Create tables
|
||||
for _, table := range schema.Tables {
|
||||
stmts, err := w.generateCreateTableStatement(schema, table)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate table %s: %w", table.Name, err)
|
||||
}
|
||||
statements = append(statements, stmts...)
|
||||
}
|
||||
|
||||
// Phase 4: Primary keys
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.PrimaryKeyConstraint {
|
||||
continue
|
||||
}
|
||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
|
||||
schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(constraint.Columns, ", "))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 5: Indexes
|
||||
for _, table := range schema.Tables {
|
||||
for _, index := range table.Indexes {
|
||||
// Skip primary key indexes
|
||||
if strings.HasSuffix(index.Name, "_pkey") {
|
||||
continue
|
||||
}
|
||||
|
||||
uniqueStr := ""
|
||||
if index.Unique {
|
||||
uniqueStr = "UNIQUE "
|
||||
}
|
||||
|
||||
indexType := index.Type
|
||||
if indexType == "" {
|
||||
indexType = "btree"
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if index.Where != "" {
|
||||
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
|
||||
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(index.Columns, ", "), whereClause)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 6: Foreign keys
|
||||
for _, table := range schema.Tables {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
refSchema := constraint.ReferencedSchema
|
||||
if refSchema == "" {
|
||||
refSchema = schema.Name
|
||||
}
|
||||
|
||||
onDelete := constraint.OnDelete
|
||||
if onDelete == "" {
|
||||
onDelete = "NO ACTION"
|
||||
}
|
||||
|
||||
onUpdate := constraint.OnUpdate
|
||||
if onUpdate == "" {
|
||||
onUpdate = "NO ACTION"
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
|
||||
schema.SQLName(), table.SQLName(), constraint.Name,
|
||||
strings.Join(constraint.Columns, ", "),
|
||||
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
|
||||
strings.Join(constraint.ReferencedColumns, ", "),
|
||||
onDelete, onUpdate)
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 7: Comments
|
||||
for _, table := range schema.Tables {
|
||||
if table.Comment != "" {
|
||||
stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'",
|
||||
schema.SQLName(), table.SQLName(), escapeQuote(table.Comment))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
|
||||
for _, column := range table.Columns {
|
||||
if column.Comment != "" {
|
||||
stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'",
|
||||
schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment))
|
||||
statements = append(statements, stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// generateCreateTableStatement generates CREATE TABLE statement
|
||||
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
|
||||
statements := []string{}
|
||||
|
||||
// Sort columns by sequence or name
|
||||
columns := make([]*models.Column, 0, len(table.Columns))
|
||||
for _, col := range table.Columns {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
sort.Slice(columns, func(i, j int) bool {
|
||||
if columns[i].Sequence != columns[j].Sequence {
|
||||
return columns[i].Sequence < columns[j].Sequence
|
||||
}
|
||||
return columns[i].Name < columns[j].Name
|
||||
})
|
||||
|
||||
columnDefs := []string{}
|
||||
for _, col := range columns {
|
||||
def := w.generateColumnDefinition(col)
|
||||
columnDefs = append(columnDefs, " "+def)
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
|
||||
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
|
||||
statements = append(statements, stmt)
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
// generateColumnDefinition generates column definition
|
||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
parts := []string{col.SQLName()}
|
||||
|
||||
// Type with length/precision
|
||||
typeStr := col.Type
|
||||
if col.Length > 0 && col.Precision == 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
|
||||
} else if col.Precision > 0 {
|
||||
if col.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
|
||||
} else {
|
||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
|
||||
}
|
||||
}
|
||||
parts = append(parts, typeStr)
|
||||
|
||||
// NOT NULL
|
||||
if col.NotNull {
|
||||
parts = append(parts, "NOT NULL")
|
||||
}
|
||||
|
||||
// DEFAULT
|
||||
if col.Default != nil {
|
||||
switch v := col.Default.(type) {
|
||||
case string:
|
||||
if strings.HasPrefix(v, "nextval") || strings.HasPrefix(v, "CURRENT_") || strings.Contains(v, "()") {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
||||
} else if v == "true" || v == "false" {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(v)))
|
||||
}
|
||||
case bool:
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||
default:
|
||||
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// WriteSchema writes a single schema and all its tables
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
if w.writer == nil {
|
||||
@@ -494,3 +722,26 @@ func isIntegerType(colType string) bool {
|
||||
func escapeQuote(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
// extractSequenceName extracts sequence name from nextval() expression
|
||||
// Example: "nextval('public.users_id_seq'::regclass)" returns "users_id_seq"
|
||||
func extractSequenceName(defaultExpr string) string {
|
||||
// Look for nextval('schema.sequence_name'::regclass) pattern
|
||||
start := strings.Index(defaultExpr, "'")
|
||||
if start == -1 {
|
||||
return ""
|
||||
}
|
||||
end := strings.Index(defaultExpr[start+1:], "'")
|
||||
if end == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
fullName := defaultExpr[start+1 : start+1+end]
|
||||
|
||||
// Remove schema prefix if present
|
||||
parts := strings.Split(fullName, ".")
|
||||
if len(parts) > 1 {
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
return fullName
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
|
||||
usersTable.Columns["email"] = emailCol
|
||||
|
||||
// Add index
|
||||
emailIdx := models.InitIndex("idx_users_email")
|
||||
emailIdx := models.InitIndex("idx_users_email", "users", "public")
|
||||
emailIdx.Columns = []string{"email"}
|
||||
emailIdx.Unique = true
|
||||
emailIdx.Type = "btree"
|
||||
|
||||
Reference in New Issue
Block a user