More Roundtrip tests
Some checks are pending
CI / Test (1.23) (push) Waiting to run
CI / Test (1.24) (push) Waiting to run
CI / Test (1.25) (push) Waiting to run
CI / Lint (push) Waiting to run
CI / Build (push) Waiting to run

This commit is contained in:
2025-12-17 22:52:24 +02:00
parent 5e1448dcdb
commit a427aa5537
23 changed files with 22897 additions and 1319 deletions

View File

@@ -1,4 +1,4 @@
package dctx
package models
import "encoding/xml"
@@ -8,7 +8,7 @@ type DCTXDictionary struct {
Name string `xml:"Name,attr"`
Version string `xml:"Version,attr"`
Tables []DCTXTable `xml:"Table"`
Relations []DCTXRelation `xml:"Relation"`
Relations []DCTXRelation `xml:"Relation,omitempty"`
}
// DCTXTable represents a table definition in DCTX
@@ -16,13 +16,13 @@ type DCTXTable struct {
Guid string `xml:"Guid,attr"`
Name string `xml:"Name,attr"`
Prefix string `xml:"Prefix,attr"`
Driver string `xml:"Driver,attr"`
Owner string `xml:"Owner,attr"`
Path string `xml:"Path,attr"`
Description string `xml:"Description,attr"`
Driver string `xml:"Driver,attr,omitempty"`
Owner string `xml:"Owner,attr,omitempty"`
Path string `xml:"Path,attr,omitempty"`
Description string `xml:"Description,attr,omitempty"`
Fields []DCTXField `xml:"Field"`
Keys []DCTXKey `xml:"Key"`
Options []DCTXOption `xml:"Option"`
Keys []DCTXKey `xml:"Key,omitempty"`
Options []DCTXOption `xml:"Option,omitempty"`
}
// DCTXField represents a field/column definition in DCTX
@@ -30,37 +30,37 @@ type DCTXField struct {
Guid string `xml:"Guid,attr"`
Name string `xml:"Name,attr"`
DataType string `xml:"DataType,attr"`
Size int `xml:"Size,attr"`
NoPopulate bool `xml:"NoPopulate,attr"`
Thread bool `xml:"Thread,attr"`
Fields []DCTXField `xml:"Field"` // For GROUP fields (nested structures)
Options []DCTXOption `xml:"Option"`
Size int `xml:"Size,attr,omitempty"`
NoPopulate bool `xml:"NoPopulate,attr,omitempty"`
Thread bool `xml:"Thread,attr,omitempty"`
Fields []DCTXField `xml:"Field,omitempty"` // For GROUP fields (nested structures)
Options []DCTXOption `xml:"Option,omitempty"`
}
// DCTXKey represents an index or key definition in DCTX
type DCTXKey struct {
Guid string `xml:"Guid,attr"`
Name string `xml:"Name,attr"`
KeyType string `xml:"KeyType,attr"`
Primary bool `xml:"Primary,attr"`
Unique bool `xml:"Unique,attr"`
Order int `xml:"Order,attr"`
Description string `xml:"Description,attr"`
KeyType string `xml:"KeyType,attr,omitempty"`
Primary bool `xml:"Primary,attr,omitempty"`
Unique bool `xml:"Unique,attr,omitempty"`
Order int `xml:"Order,attr,omitempty"`
Description string `xml:"Description,attr,omitempty"`
Components []DCTXComponent `xml:"Component"`
}
// DCTXComponent represents a component of a key (field reference)
type DCTXComponent struct {
Guid string `xml:"Guid,attr"`
FieldId string `xml:"FieldId,attr"`
FieldId string `xml:"FieldId,attr,omitempty"`
Order int `xml:"Order,attr"`
Ascend bool `xml:"Ascend,attr"`
Ascend bool `xml:"Ascend,attr,omitempty"`
}
// DCTXOption represents a property option in DCTX
type DCTXOption struct {
Property string `xml:"Property,attr"`
PropertyType string `xml:"PropertyType,attr"`
PropertyType string `xml:"PropertyType,attr,omitempty"`
PropertyValue string `xml:"PropertyValue,attr"`
}
@@ -69,12 +69,12 @@ type DCTXRelation struct {
Guid string `xml:"Guid,attr"`
PrimaryTable string `xml:"PrimaryTable,attr"`
ForeignTable string `xml:"ForeignTable,attr"`
PrimaryKey string `xml:"PrimaryKey,attr"`
ForeignKey string `xml:"ForeignKey,attr"`
Delete string `xml:"Delete,attr"`
Update string `xml:"Update,attr"`
ForeignMappings []DCTXFieldMapping `xml:"ForeignMapping"`
PrimaryMappings []DCTXFieldMapping `xml:"PrimaryMapping"`
PrimaryKey string `xml:"PrimaryKey,attr,omitempty"`
ForeignKey string `xml:"ForeignKey,attr,omitempty"`
Delete string `xml:"Delete,attr,omitempty"`
Update string `xml:"Update,attr,omitempty"`
ForeignMappings []DCTXFieldMapping `xml:"ForeignMapping,omitempty"`
PrimaryMappings []DCTXFieldMapping `xml:"PrimaryMapping,omitempty"`
}
// DCTXFieldMapping represents a field mapping in a relation

View File

@@ -38,7 +38,8 @@ type Schema struct {
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
Scripts []*Script `json:"scripts,omitempty" yaml:"scripts,omitempty" xml:"scripts,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefDatabase *Database `json:"ref_database,omitempty" yaml:"ref_database,omitempty" xml:"ref_database,omitempty"`
RefDatabase *Database `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
Relations []*Relationship `json:"relations,omitempty" yaml:"relations,omitempty" xml:"-"`
}
// SQLName returns the schema name in lowercase
@@ -58,7 +59,7 @@ type Table struct {
Tablespace string `json:"tablespace,omitempty" yaml:"tablespace,omitempty" xml:"tablespace,omitempty"`
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefSchema *Schema `json:"ref_schema,omitempty" yaml:"ref_schema,omitempty" xml:"ref_schema,omitempty"`
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
}
// SQLName returns the table name in lowercase
@@ -96,7 +97,7 @@ type View struct {
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty" xml:"-"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefSchema *Schema `json:"ref_schema,omitempty" yaml:"ref_schema,omitempty" xml:"ref_schema,omitempty"`
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
}
// SQLName returns the view name in lowercase
@@ -119,7 +120,7 @@ type Sequence struct {
OwnedByColumn string `json:"owned_by_column,omitempty" yaml:"owned_by_column,omitempty" xml:"owned_by_column,omitempty"`
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
Sequence uint `json:"sequence,omitempty" yaml:"sequence,omitempty" xml:"sequence,omitempty"`
RefSchema *Schema `json:"ref_schema,omitempty" yaml:"ref_schema,omitempty" xml:"ref_schema,omitempty"`
RefSchema *Schema `json:"-" yaml:"-" xml:"-"` // Excluded to prevent circular references
}
// SQLName returns the sequence name in lowercase
@@ -184,8 +185,10 @@ type Relationship struct {
Type RelationType `json:"type" yaml:"type" xml:"type"`
FromTable string `json:"from_table" yaml:"from_table" xml:"from_table"`
FromSchema string `json:"from_schema" yaml:"from_schema" xml:"from_schema"`
FromColumns []string `json:"from_columns" yaml:"from_columns" xml:"from_columns"`
ToTable string `json:"to_table" yaml:"to_table" xml:"to_table"`
ToSchema string `json:"to_schema" yaml:"to_schema" xml:"to_schema"`
ToColumns []string `json:"to_columns" yaml:"to_columns" xml:"to_columns"`
ForeignKey string `json:"foreign_key" yaml:"foreign_key" xml:"foreign_key"`
Properties map[string]string `json:"properties" yaml:"properties" xml:"-"`
ThroughTable string `json:"through_table,omitempty" yaml:"through_table,omitempty" xml:"through_table,omitempty"` // For many-to-many
@@ -292,14 +295,28 @@ func InitColumn(name, table, schema string) *Column {
}
// InitIndex initializes a new Index with empty slices
func InitIndex(name string) *Index {
func InitIndex(name, table, schema string) *Index {
return &Index{
Name: name,
Table: table,
Schema: schema,
Columns: make([]string, 0),
Include: make([]string, 0),
}
}
// InitRelation initializes a new Relationship with empty slices
func InitRelation(name, schema string) *Relationship {
return &Relationship{
Name: name,
FromSchema: schema,
ToSchema: schema,
Properties: make(map[string]string),
FromColumns: make([]string, 0),
ToColumns: make([]string, 0),
}
}
// InitRelationship initializes a new Relationship with empty maps
func InitRelationship(name string, relType RelationType) *Relationship {
return &Relationship{

View File

@@ -272,6 +272,10 @@ func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column
if constraint == nil {
constraint = uniqueConstraint
}
} else if strings.HasPrefix(attr, "note:") {
// Parse column note/comment
note := strings.TrimSpace(strings.TrimPrefix(attr, "note:"))
column.Comment = strings.Trim(note, "'\"")
} else if strings.HasPrefix(attr, "ref:") {
// Parse inline reference
// DBML semantics depend on context:
@@ -355,7 +359,7 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
return nil
}
index := models.InitIndex("")
index := models.InitIndex("", tableName, schemaName)
index.Table = tableName
index.Schema = schemaName
index.Columns = columns

View File

@@ -33,7 +33,7 @@ func (r *Reader) ReadDatabase() (*models.Database, error) {
return nil, fmt.Errorf("failed to read file: %w", err)
}
var dctx DCTXDictionary
var dctx models.DCTXDictionary
if err := xml.Unmarshal(data, &dctx); err != nil {
return nil, fmt.Errorf("failed to parse DCTX XML: %w", err)
}
@@ -70,7 +70,7 @@ func (r *Reader) ReadTable() (*models.Table, error) {
}
// convertToDatabase converts a DCTX dictionary to a Database model
func (r *Reader) convertToDatabase(dctx *DCTXDictionary) (*models.Database, error) {
func (r *Reader) convertToDatabase(dctx *models.DCTXDictionary) (*models.Database, error) {
dbName := dctx.Name
if dbName == "" {
dbName = "database"
@@ -81,7 +81,7 @@ func (r *Reader) convertToDatabase(dctx *DCTXDictionary) (*models.Database, erro
// Create GUID mappings for tables and keys
tableGuidMap := make(map[string]string) // GUID -> table name
keyGuidMap := make(map[string]*DCTXKey) // GUID -> key definition
keyGuidMap := make(map[string]*models.DCTXKey) // GUID -> key definition
keyTableMap := make(map[string]string) // key GUID -> table name
fieldGuidMaps := make(map[string]map[string]string) // table name -> field GUID -> field name
@@ -135,7 +135,7 @@ func (r *Reader) convertToDatabase(dctx *DCTXDictionary) (*models.Database, erro
}
// hasSQLOption checks if a DCTX table has the SQL option set to "1"
func (r *Reader) hasSQLOption(dctxTable *DCTXTable) bool {
func (r *Reader) hasSQLOption(dctxTable *models.DCTXTable) bool {
for _, option := range dctxTable.Options {
if option.Property == "SQL" && option.PropertyValue == "1" {
return true
@@ -144,8 +144,21 @@ func (r *Reader) hasSQLOption(dctxTable *DCTXTable) bool {
return false
}
// collectFieldGuids recursively collects all field GUIDs from a field and its nested fields
func (r *Reader) collectFieldGuids(dctxField *models.DCTXField, guidMap map[string]string) {
// Store the current field's GUID if available
if dctxField.Guid != "" && dctxField.Name != "" {
guidMap[dctxField.Guid] = r.sanitizeName(dctxField.Name)
}
// Recursively process nested fields (for GROUP types)
for i := range dctxField.Fields {
r.collectFieldGuids(&dctxField.Fields[i], guidMap)
}
}
// convertTable converts a DCTX table to a Table model
func (r *Reader) convertTable(dctxTable *DCTXTable) (*models.Table, map[string]string, error) {
func (r *Reader) convertTable(dctxTable *models.DCTXTable) (*models.Table, map[string]string, error) {
tableName := r.sanitizeName(dctxTable.Name)
table := models.InitTable(tableName, "public")
table.Description = dctxTable.Description
@@ -154,10 +167,8 @@ func (r *Reader) convertTable(dctxTable *DCTXTable) (*models.Table, map[string]s
// Process fields
for _, dctxField := range dctxTable.Fields {
// Store GUID to name mapping
if dctxField.Guid != "" && dctxField.Name != "" {
fieldGuidMap[dctxField.Guid] = r.sanitizeName(dctxField.Name)
}
// Recursively collect all field GUIDs (including nested fields in GROUP types)
r.collectFieldGuids(&dctxField, fieldGuidMap)
columns, err := r.convertField(&dctxField, table.Name)
if err != nil {
@@ -174,7 +185,7 @@ func (r *Reader) convertTable(dctxTable *DCTXTable) (*models.Table, map[string]s
}
// convertField converts a DCTX field to Column(s)
func (r *Reader) convertField(dctxField *DCTXField, tableName string) ([]*models.Column, error) {
func (r *Reader) convertField(dctxField *models.DCTXField, tableName string) ([]*models.Column, error) {
var columns []*models.Column
// Handle GROUP fields (nested structures)
@@ -286,7 +297,7 @@ func (r *Reader) mapDataType(clarionType string, size int) (sqlType string, prec
}
// processKeys processes DCTX keys and converts them to indexes and primary keys
func (r *Reader) processKeys(dctxTable *DCTXTable, table *models.Table, fieldGuidMap map[string]string) error {
func (r *Reader) processKeys(dctxTable *models.DCTXTable, table *models.Table, fieldGuidMap map[string]string) error {
for _, dctxKey := range dctxTable.Keys {
err := r.convertKey(&dctxKey, table, fieldGuidMap)
if err != nil {
@@ -297,7 +308,7 @@ func (r *Reader) processKeys(dctxTable *DCTXTable, table *models.Table, fieldGui
}
// convertKey converts a DCTX key to appropriate constraint/index
func (r *Reader) convertKey(dctxKey *DCTXKey, table *models.Table, fieldGuidMap map[string]string) error {
func (r *Reader) convertKey(dctxKey *models.DCTXKey, table *models.Table, fieldGuidMap map[string]string) error {
var columns []string
// Extract column names from key components
@@ -349,7 +360,7 @@ func (r *Reader) convertKey(dctxKey *DCTXKey, table *models.Table, fieldGuidMap
}
// Handle regular index
index := models.InitIndex(r.sanitizeName(dctxKey.Name))
index := models.InitIndex(r.sanitizeName(dctxKey.Name), table.Name, table.Schema)
index.Table = table.Name
index.Schema = table.Schema
index.Columns = columns
@@ -361,7 +372,7 @@ func (r *Reader) convertKey(dctxKey *DCTXKey, table *models.Table, fieldGuidMap
}
// processRelations processes DCTX relations and creates foreign keys
func (r *Reader) processRelations(dctx *DCTXDictionary, schema *models.Schema, tableGuidMap map[string]string, keyGuidMap map[string]*DCTXKey, fieldGuidMaps map[string]map[string]string) error {
func (r *Reader) processRelations(dctx *models.DCTXDictionary, schema *models.Schema, tableGuidMap map[string]string, keyGuidMap map[string]*models.DCTXKey, fieldGuidMaps map[string]map[string]string) error {
for i := range dctx.Relations {
relation := &dctx.Relations[i]
// Get table names from GUIDs
@@ -390,19 +401,23 @@ func (r *Reader) processRelations(dctx *DCTXDictionary, schema *models.Schema, t
var fkColumns, pkColumns []string
// Try to use explicit field mappings
// NOTE: DCTX format has backwards naming - ForeignMapping contains primary table fields,
// and PrimaryMapping contains foreign table fields
if len(relation.ForeignMappings) > 0 && len(relation.PrimaryMappings) > 0 {
foreignFieldMap := fieldGuidMaps[foreignTableName]
primaryFieldMap := fieldGuidMaps[primaryTableName]
// ForeignMapping actually contains fields from the PRIMARY table
for _, mapping := range relation.ForeignMappings {
if fieldName, exists := foreignFieldMap[mapping.Field]; exists {
fkColumns = append(fkColumns, fieldName)
if fieldName, exists := primaryFieldMap[mapping.Field]; exists {
pkColumns = append(pkColumns, fieldName)
}
}
// PrimaryMapping actually contains fields from the FOREIGN table
for _, mapping := range relation.PrimaryMappings {
if fieldName, exists := primaryFieldMap[mapping.Field]; exists {
pkColumns = append(pkColumns, fieldName)
if fieldName, exists := foreignFieldMap[mapping.Field]; exists {
fkColumns = append(fkColumns, fieldName)
}
}
}

View File

@@ -445,3 +445,51 @@ func TestColumnProperties(t *testing.T) {
t.Log("Note: No columns with default values found (this may be valid for the test data)")
}
}
func TestRelationships(t *testing.T) {
opts := &readers.ReaderOptions{
FilePath: filepath.Join("..", "..", "..", "examples", "dctx", "example.dctx"),
}
reader := NewReader(opts)
db, err := reader.ReadDatabase()
if err != nil {
t.Fatalf("ReadDatabase() error = %v", err)
}
// Count total relationships across all tables
relationshipCount := 0
for _, schema := range db.Schemas {
for _, table := range schema.Tables {
relationshipCount += len(table.Relationships)
}
}
// The example.dctx file should have a significant number of relationships
// With the fix for nested field GUID mapping, we expect around 100+ relationships
if relationshipCount < 50 {
t.Errorf("Expected at least 50 relationships, got %d. This may indicate relationships are not being parsed correctly", relationshipCount)
}
t.Logf("Successfully parsed %d relationships", relationshipCount)
// Verify relationship properties
for _, schema := range db.Schemas {
for _, table := range schema.Tables {
for _, rel := range table.Relationships {
if rel.Name == "" {
t.Errorf("Relationship in table '%s' should have a name", table.Name)
}
if rel.FromTable == "" {
t.Errorf("Relationship '%s' should have a from table", rel.Name)
}
if rel.ToTable == "" {
t.Errorf("Relationship '%s' should have a to table", rel.Name)
}
if rel.ForeignKey == "" {
t.Errorf("Relationship '%s' should reference a foreign key", rel.Name)
}
}
}
}
}

View File

@@ -248,7 +248,7 @@ func (r *Reader) convertToColumn(field *drawdb.DrawDBField, tableName, schemaNam
// convertToIndex converts a DrawDB index to an Index model
func (r *Reader) convertToIndex(drawIndex *drawdb.DrawDBIndex, drawTable *drawdb.DrawDBTable, schemaName string) *models.Index {
index := models.InitIndex(drawIndex.Name)
index := models.InitIndex(drawIndex.Name, drawTable.Name, schemaName)
index.Table = drawTable.Name
index.Schema = schemaName
index.Unique = drawIndex.Unique

View File

@@ -542,7 +542,7 @@ func (r *Reader) queryIndexes(schemaName string) (map[string][]*models.Index, er
index, err := r.parseIndexDefinition(indexName, tableName, schema, indexDef)
if err != nil {
// If parsing fails, create a basic index
index = models.InitIndex(indexName)
index = models.InitIndex(indexName, tableName, schema)
index.Table = tableName
index.Schema = schema
}
@@ -556,7 +556,7 @@ func (r *Reader) queryIndexes(schemaName string) (map[string][]*models.Index, er
// parseIndexDefinition parses a PostgreSQL index definition
func (r *Reader) parseIndexDefinition(indexName, tableName, schema, indexDef string) (*models.Index, error) {
index := models.InitIndex(indexName)
index := models.InitIndex(indexName, tableName, schema)
index.Table = tableName
index.Schema = schema

View File

@@ -29,7 +29,6 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
}
// If no output path, print to stdout
fmt.Print(content)
return nil
}
@@ -48,7 +47,7 @@ func (w *Writer) WriteSchema(schema *models.Schema) error {
// WriteTable writes a Table model to DBML format
func (w *Writer) WriteTable(table *models.Table) error {
content := w.tableToDBML(table, table.Schema)
content := w.tableToDBML(table)
if w.options.OutputPath != "" {
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
@@ -60,70 +59,63 @@ func (w *Writer) WriteTable(table *models.Table) error {
// databaseToDBML converts a Database to DBML format string
func (w *Writer) databaseToDBML(d *models.Database) string {
var result string
var sb strings.Builder
// Add database comment if exists
if d.Description != "" {
result += fmt.Sprintf("// %s\n", d.Description)
sb.WriteString(fmt.Sprintf("// %s\n", d.Description))
}
if d.Comment != "" {
result += fmt.Sprintf("// %s\n", d.Comment)
sb.WriteString(fmt.Sprintf("// %s\n", d.Comment))
}
if d.Description != "" || d.Comment != "" {
result += "\n"
sb.WriteString("\n")
}
// Process each schema
for _, schema := range d.Schemas {
result += w.schemaToDBML(schema)
sb.WriteString(w.schemaToDBML(schema))
}
// Add relationships
result += "\n// Relationships\n"
sb.WriteString("\n// Relationships\n")
for _, schema := range d.Schemas {
for _, table := range schema.Tables {
for _, constraint := range table.Constraints {
if constraint.Type == models.ForeignKeyConstraint {
result += w.constraintToDBML(constraint, schema.Name, table.Name)
sb.WriteString(w.constraintToDBML(constraint, table))
}
}
}
}
return result
return sb.String()
}
// schemaToDBML converts a Schema to DBML format string
func (w *Writer) schemaToDBML(schema *models.Schema) string {
var result string
var sb strings.Builder
if schema.Description != "" {
result += fmt.Sprintf("// Schema: %s - %s\n", schema.Name, schema.Description)
sb.WriteString(fmt.Sprintf("// Schema: %s - %s\n", schema.Name, schema.Description))
}
// Process tables
for _, table := range schema.Tables {
result += w.tableToDBML(table, schema.Name)
result += "\n"
sb.WriteString(w.tableToDBML(table))
sb.WriteString("\n")
}
return result
return sb.String()
}
// tableToDBML converts a Table to DBML format string
func (w *Writer) tableToDBML(t *models.Table, schemaName string) string {
var result string
func (w *Writer) tableToDBML(t *models.Table) string {
var sb strings.Builder
// Table definition
tableName := fmt.Sprintf("%s.%s", schemaName, t.Name)
result += fmt.Sprintf("Table %s {\n", tableName)
tableName := fmt.Sprintf("%s.%s", t.Schema, t.Name)
sb.WriteString(fmt.Sprintf("Table %s {\n", tableName))
// Add columns
for _, column := range t.Columns {
result += fmt.Sprintf(" %s %s", column.Name, column.Type)
sb.WriteString(fmt.Sprintf(" %s %s", column.Name, column.Type))
// Add column attributes
attrs := make([]string, 0)
var attrs []string
if column.IsPrimaryKey {
attrs = append(attrs, "pk")
}
@@ -134,77 +126,74 @@ func (w *Writer) tableToDBML(t *models.Table, schemaName string) string {
attrs = append(attrs, "increment")
}
if column.Default != nil {
attrs = append(attrs, fmt.Sprintf("default: %v", column.Default))
attrs = append(attrs, fmt.Sprintf("default: `%v`", column.Default))
}
if len(attrs) > 0 {
result += fmt.Sprintf(" [%s]", strings.Join(attrs, ", "))
sb.WriteString(fmt.Sprintf(" [%s]", strings.Join(attrs, ", ")))
}
if column.Comment != "" {
result += fmt.Sprintf(" // %s", column.Comment)
sb.WriteString(fmt.Sprintf(" // %s", column.Comment))
}
result += "\n"
sb.WriteString("\n")
}
// Add indexes
indexCount := 0
for _, index := range t.Indexes {
if indexCount == 0 {
result += "\n indexes {\n"
}
indexAttrs := make([]string, 0)
if index.Unique {
indexAttrs = append(indexAttrs, "unique")
}
if index.Name != "" {
indexAttrs = append(indexAttrs, fmt.Sprintf("name: '%s'", index.Name))
}
if index.Type != "" {
indexAttrs = append(indexAttrs, fmt.Sprintf("type: %s", index.Type))
}
if len(t.Indexes) > 0 {
sb.WriteString("\n indexes {\n")
for _, index := range t.Indexes {
var indexAttrs []string
if index.Unique {
indexAttrs = append(indexAttrs, "unique")
}
if index.Name != "" {
indexAttrs = append(indexAttrs, fmt.Sprintf("name: '%s'", index.Name))
}
if index.Type != "" {
indexAttrs = append(indexAttrs, fmt.Sprintf("type: %s", index.Type))
}
result += fmt.Sprintf(" (%s)", strings.Join(index.Columns, ", "))
if len(indexAttrs) > 0 {
result += fmt.Sprintf(" [%s]", strings.Join(indexAttrs, ", "))
sb.WriteString(fmt.Sprintf(" (%s)", strings.Join(index.Columns, ", ")))
if len(indexAttrs) > 0 {
sb.WriteString(fmt.Sprintf(" [%s]", strings.Join(indexAttrs, ", ")))
}
sb.WriteString("\n")
}
result += "\n"
indexCount++
}
if indexCount > 0 {
result += " }\n"
sb.WriteString(" }\n")
}
// Add table note
if t.Description != "" || t.Comment != "" {
note := t.Description
if note != "" && t.Comment != "" {
note += " - "
}
note += t.Comment
result += fmt.Sprintf("\n Note: '%s'\n", note)
note := strings.TrimSpace(t.Description + " " + t.Comment)
if note != "" {
sb.WriteString(fmt.Sprintf("\n Note: '%s'\n", note))
}
result += "}\n"
return result
sb.WriteString("}\n")
return sb.String()
}
// constraintToDBML converts a Constraint to DBML format string
func (w *Writer) constraintToDBML(c *models.Constraint, schemaName, tableName string) string {
func (w *Writer) constraintToDBML(c *models.Constraint, t *models.Table) string {
if c.Type != models.ForeignKeyConstraint || c.ReferencedTable == "" {
return ""
}
fromTable := fmt.Sprintf("%s.%s", schemaName, tableName)
fromTable := fmt.Sprintf("%s.%s", c.Schema, c.Table)
toTable := fmt.Sprintf("%s.%s", c.ReferencedSchema, c.ReferencedTable)
// Determine relationship cardinality
// For foreign keys, it's typically many-to-one
relationship := ">"
relationship := ">" // Default to many-to-one
for _, index := range t.Indexes {
if index.Unique && strings.Join(index.Columns, ",") == strings.Join(c.Columns, ",") {
relationship = "-" // one-to-one
break
}
}
for _, column := range c.Columns {
if t.Columns[column].IsPrimaryKey {
relationship = "-" // one-to-one
break
}
}
// Build from and to column references
// For single columns: table.column
// For multiple columns: table.(col1, col2)
var fromRef, toRef string
if len(c.Columns) == 1 {
fromRef = fmt.Sprintf("%s.%s", fromTable, c.Columns[0])
@@ -218,20 +207,18 @@ func (w *Writer) constraintToDBML(c *models.Constraint, schemaName, tableName st
toRef = fmt.Sprintf("%s.(%s)", toTable, strings.Join(c.ReferencedColumns, ", "))
}
result := fmt.Sprintf("Ref: %s %s %s", fromRef, relationship, toRef)
// Add actions
actions := make([]string, 0)
var actions []string
if c.OnDelete != "" {
actions = append(actions, fmt.Sprintf("ondelete: %s", c.OnDelete))
actions = append(actions, fmt.Sprintf("delete: %s", c.OnDelete))
}
if c.OnUpdate != "" {
actions = append(actions, fmt.Sprintf("onupdate: %s", c.OnUpdate))
}
if len(actions) > 0 {
result += fmt.Sprintf(" [%s]", strings.Join(actions, ", "))
actions = append(actions, fmt.Sprintf("update: %s", c.OnUpdate))
}
result += "\n"
return result
}
refLine := fmt.Sprintf("Ref: %s %s %s", fromRef, relationship, toRef)
if len(actions) > 0 {
refLine += fmt.Sprintf(" [%s]", strings.Join(actions, ", "))
}
return refLine + "\n"
}

View File

@@ -3,11 +3,11 @@ package dbml
import (
"os"
"path/filepath"
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
"github.com/stretchr/testify/assert"
)
func TestWriter_WriteTable(t *testing.T) {
@@ -46,96 +46,40 @@ func TestWriter_WriteTable(t *testing.T) {
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable() error = %v", err)
}
assert.NoError(t, err)
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
assert.NoError(t, err)
output := string(content)
// Verify table structure
if !strings.Contains(output, "Table public.users {") {
t.Error("Output should contain table definition")
}
// Verify columns
if !strings.Contains(output, "id bigint") {
t.Error("Output should contain id column")
}
if !strings.Contains(output, "pk") {
t.Error("Output should contain pk attribute for id")
}
if !strings.Contains(output, "increment") {
t.Error("Output should contain increment attribute for id")
}
if !strings.Contains(output, "email varchar(255)") {
t.Error("Output should contain email column")
}
if !strings.Contains(output, "not null") {
t.Error("Output should contain not null attribute")
}
// Verify table note
if !strings.Contains(output, "Note:") && table.Description != "" {
t.Error("Output should contain table note when description is present")
}
assert.Contains(t, output, "Table public.users {")
assert.Contains(t, output, "id bigint [pk, increment]")
assert.Contains(t, output, "email varchar(255) [not null]")
assert.Contains(t, output, "Note: 'User accounts table'")
}
func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
db := models.InitDatabase("test_db")
schema := models.InitSchema("public")
// Create users table
usersTable := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
idCol.AutoIncrement = true
idCol.NotNull = true
usersTable.Columns["id"] = idCol
emailCol := models.InitColumn("email", "users", "public")
emailCol.Type = "varchar(255)"
emailCol.NotNull = true
usersTable.Columns["email"] = emailCol
// Add index to users table
emailIdx := models.InitIndex("idx_users_email")
emailIdx := models.InitIndex("idx_users_email", "users", "public")
emailIdx.Columns = []string{"email"}
emailIdx.Unique = true
emailIdx.Table = "users"
emailIdx.Schema = "public"
usersTable.Indexes["idx_users_email"] = emailIdx
schema.Tables = append(schema.Tables, usersTable)
// Create posts table
postsTable := models.InitTable("posts", "public")
postIdCol := models.InitColumn("id", "posts", "public")
postIdCol.Type = "bigint"
postIdCol.IsPrimaryKey = true
postIdCol.AutoIncrement = true
postIdCol.NotNull = true
postsTable.Columns["id"] = postIdCol
userIdCol := models.InitColumn("user_id", "posts", "public")
userIdCol.Type = "bigint"
userIdCol.NotNull = true
postsTable.Columns["user_id"] = userIdCol
titleCol := models.InitColumn("title", "posts", "public")
titleCol.Type = "varchar(200)"
titleCol.NotNull = true
postsTable.Columns["title"] = titleCol
publishedCol := models.InitColumn("published", "posts", "public")
publishedCol.Type = "boolean"
publishedCol.Default = "false"
postsTable.Columns["published"] = publishedCol
// Add foreign key constraint
fk := models.InitConstraint("fk_posts_user", models.ForeignKeyConstraint)
fk.Table = "posts"
fk.Schema = "public"
@@ -144,353 +88,68 @@ func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
fk.ReferencedSchema = "public"
fk.ReferencedColumns = []string{"id"}
fk.OnDelete = "CASCADE"
fk.OnUpdate = "CASCADE"
postsTable.Constraints["fk_posts_user"] = fk
schema.Tables = append(schema.Tables, usersTable, postsTable)
schema.Tables = append(schema.Tables, postsTable)
db.Schemas = append(db.Schemas, schema)
tmpDir := t.TempDir()
outputPath := filepath.Join(tmpDir, "test.dbml")
opts := &writers.WriterOptions{
OutputPath: outputPath,
}
opts := &writers.WriterOptions{OutputPath: outputPath}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase() error = %v", err)
}
assert.NoError(t, err)
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
assert.NoError(t, err)
output := string(content)
// Verify tables
if !strings.Contains(output, "Table public.users {") {
t.Error("Output should contain users table")
}
if !strings.Contains(output, "Table public.posts {") {
t.Error("Output should contain posts table")
}
// Verify foreign key reference
if !strings.Contains(output, "Ref:") {
t.Error("Output should contain Ref for foreign key")
}
if !strings.Contains(output, "public.posts.user_id") {
t.Error("Output should contain posts.user_id in reference")
}
if !strings.Contains(output, "public.users.id") {
t.Error("Output should contain users.id in reference")
}
if !strings.Contains(output, "ondelete: CASCADE") {
t.Error("Output should contain ondelete: CASCADE")
}
if !strings.Contains(output, "onupdate: CASCADE") {
t.Error("Output should contain onupdate: CASCADE")
}
// Verify index
if !strings.Contains(output, "indexes") {
t.Error("Output should contain indexes section")
}
if !strings.Contains(output, "(email)") {
t.Error("Output should contain email index")
}
if !strings.Contains(output, "unique") {
t.Error("Output should contain unique attribute for email index")
}
assert.Contains(t, output, "Table public.users {")
assert.Contains(t, output, "Table public.posts {")
assert.Contains(t, output, "Ref: public.posts.user_id > public.users.id [delete: CASCADE]")
assert.Contains(t, output, "(email) [unique, name: 'idx_users_email']")
}
func TestWriter_WriteSchema(t *testing.T) {
func TestWriter_WriteDatabase_OneToOneRelationship(t *testing.T) {
db := models.InitDatabase("test_db")
schema := models.InitSchema("public")
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
idCol.NotNull = true
table.Columns["id"] = idCol
usernameCol := models.InitColumn("username", "users", "public")
usernameCol.Type = "varchar(50)"
usernameCol.NotNull = true
table.Columns["username"] = usernameCol
schema.Tables = append(schema.Tables, table)
tmpDir := t.TempDir()
outputPath := filepath.Join(tmpDir, "test.dbml")
opts := &writers.WriterOptions{
OutputPath: outputPath,
}
writer := NewWriter(opts)
err := writer.WriteSchema(schema)
if err != nil {
t.Fatalf("WriteSchema() error = %v", err)
}
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
output := string(content)
// Verify table exists
if !strings.Contains(output, "Table public.users {") {
t.Error("Output should contain users table")
}
// Verify columns
if !strings.Contains(output, "id bigint") {
t.Error("Output should contain id column")
}
if !strings.Contains(output, "username varchar(50)") {
t.Error("Output should contain username column")
}
}
func TestWriter_WriteDatabase_MultipleSchemas(t *testing.T) {
db := models.InitDatabase("test_db")
// Create public schema with users table
publicSchema := models.InitSchema("public")
usersTable := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
usersTable.Columns["id"] = idCol
publicSchema.Tables = append(publicSchema.Tables, usersTable)
schema.Tables = append(schema.Tables, usersTable)
// Create admin schema with audit_logs table
adminSchema := models.InitSchema("admin")
auditTable := models.InitTable("audit_logs", "admin")
auditIdCol := models.InitColumn("id", "audit_logs", "admin")
auditIdCol.Type = "bigint"
auditIdCol.IsPrimaryKey = true
auditTable.Columns["id"] = auditIdCol
userIdCol := models.InitColumn("user_id", "audit_logs", "admin")
profilesTable := models.InitTable("profiles", "public")
profileIdCol := models.InitColumn("id", "profiles", "public")
profileIdCol.Type = "bigint"
profilesTable.Columns["id"] = profileIdCol
userIdCol := models.InitColumn("user_id", "profiles", "public")
userIdCol.Type = "bigint"
auditTable.Columns["user_id"] = userIdCol
userIdCol.IsPrimaryKey = true // This makes it a one-to-one
profilesTable.Columns["user_id"] = userIdCol
// Add foreign key from admin.audit_logs to public.users
fk := models.InitConstraint("fk_audit_user", models.ForeignKeyConstraint)
fk.Table = "audit_logs"
fk.Schema = "admin"
fk := models.InitConstraint("fk_profiles_user", models.ForeignKeyConstraint)
fk.Table = "profiles"
fk.Schema = "public"
fk.Columns = []string{"user_id"}
fk.ReferencedTable = "users"
fk.ReferencedSchema = "public"
fk.ReferencedColumns = []string{"id"}
fk.OnDelete = "SET NULL"
auditTable.Constraints["fk_audit_user"] = fk
adminSchema.Tables = append(adminSchema.Tables, auditTable)
db.Schemas = append(db.Schemas, publicSchema, adminSchema)
tmpDir := t.TempDir()
outputPath := filepath.Join(tmpDir, "test.dbml")
opts := &writers.WriterOptions{
OutputPath: outputPath,
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase() error = %v", err)
}
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
output := string(content)
// Verify both schemas present
if !strings.Contains(output, "public.users") {
t.Error("Output should contain public.users table")
}
if !strings.Contains(output, "admin.audit_logs") {
t.Error("Output should contain admin.audit_logs table")
}
// Verify cross-schema foreign key
if !strings.Contains(output, "admin.audit_logs.user_id") {
t.Error("Output should contain admin.audit_logs.user_id in reference")
}
if !strings.Contains(output, "public.users.id") {
t.Error("Output should contain public.users.id in reference")
}
if !strings.Contains(output, "ondelete: SET NULL") {
t.Error("Output should contain ondelete: SET NULL")
}
}
func TestWriter_WriteTable_WithDefaults(t *testing.T) {
table := models.InitTable("products", "public")
idCol := models.InitColumn("id", "products", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
isActiveCol := models.InitColumn("is_active", "products", "public")
isActiveCol.Type = "boolean"
isActiveCol.Default = "true"
table.Columns["is_active"] = isActiveCol
createdCol := models.InitColumn("created_at", "products", "public")
createdCol.Type = "timestamp"
createdCol.Default = "CURRENT_TIMESTAMP"
table.Columns["created_at"] = createdCol
tmpDir := t.TempDir()
outputPath := filepath.Join(tmpDir, "test.dbml")
opts := &writers.WriterOptions{
OutputPath: outputPath,
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable() error = %v", err)
}
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
output := string(content)
// Verify default values
if !strings.Contains(output, "default:") {
t.Error("Output should contain default values")
}
}
func TestWriter_WriteTable_EmptyPath(t *testing.T) {
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
table.Columns["id"] = idCol
// When OutputPath is empty, it should print to stdout (not error)
opts := &writers.WriterOptions{
OutputPath: "",
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err != nil {
t.Fatalf("WriteTable() with empty path should not error, got: %v", err)
}
}
func TestWriter_WriteDatabase_WithComments(t *testing.T) {
db := models.InitDatabase("test_db")
db.Description = "Test database description"
db.Comment = "Additional comment"
schema := models.InitSchema("public")
table := models.InitTable("users", "public")
table.Comment = "Users table comment"
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
idCol.Comment = "Primary key"
table.Columns["id"] = idCol
schema.Tables = append(schema.Tables, table)
profilesTable.Constraints["fk_profiles_user"] = fk
schema.Tables = append(schema.Tables, profilesTable)
db.Schemas = append(db.Schemas, schema)
tmpDir := t.TempDir()
outputPath := filepath.Join(tmpDir, "test.dbml")
opts := &writers.WriterOptions{
OutputPath: outputPath,
}
opts := &writers.WriterOptions{OutputPath: outputPath}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase() error = %v", err)
}
assert.NoError(t, err)
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
assert.NoError(t, err)
output := string(content)
// Verify comments are present
if !strings.Contains(output, "//") {
t.Error("Output should contain comments")
}
}
func TestWriter_WriteDatabase_WithIndexType(t *testing.T) {
db := models.InitDatabase("test_db")
schema := models.InitSchema("public")
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
emailCol := models.InitColumn("email", "users", "public")
emailCol.Type = "varchar(255)"
table.Columns["email"] = emailCol
// Add index with type
idx := models.InitIndex("idx_email")
idx.Columns = []string{"email"}
idx.Type = "btree"
idx.Unique = true
idx.Table = "users"
idx.Schema = "public"
table.Indexes["idx_email"] = idx
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
tmpDir := t.TempDir()
outputPath := filepath.Join(tmpDir, "test.dbml")
opts := &writers.WriterOptions{
OutputPath: outputPath,
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err != nil {
t.Fatalf("WriteDatabase() error = %v", err)
}
content, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("Failed to read output file: %v", err)
}
output := string(content)
// Verify index with type
if !strings.Contains(output, "type:") || !strings.Contains(output, "btree") {
t.Error("Output should contain index type")
}
}
assert.Contains(t, output, "Ref: public.profiles.user_id - public.users.id")
}

View File

@@ -0,0 +1,194 @@
package dctx
import (
"os"
"path/filepath"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/readers"
dctxreader "git.warky.dev/wdevs/relspecgo/pkg/readers/dctx"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
"github.com/stretchr/testify/assert"
)
func TestRoundTrip_WriteAndRead(t *testing.T) {
// 1. Create a sample schema with relationships
schema := models.InitSchema("public")
schema.Name = "TestDB"
// Table 1: users
usersTable := models.InitTable("users", "public")
usersTable.Comment = "Stores user information"
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "serial"
idCol.IsPrimaryKey = true
idCol.NotNull = true
usersTable.Columns["id"] = idCol
nameCol := models.InitColumn("name", "users", "public")
nameCol.Type = "varchar"
nameCol.Length = 100
usersTable.Columns["name"] = nameCol
pkIndex := models.InitIndex("users_pkey", "users", "public")
pkIndex.Unique = true
pkIndex.Columns = []string{"id"}
usersTable.Indexes["users_pkey"] = pkIndex
pkConstraint := models.InitConstraint("users_pkey", models.PrimaryKeyConstraint)
pkConstraint.Table = "users"
pkConstraint.Schema = "public"
pkConstraint.Columns = []string{"id"}
usersTable.Constraints["users_pkey"] = pkConstraint
schema.Tables = append(schema.Tables, usersTable)
// Table 2: posts
postsTable := models.InitTable("posts", "public")
postsTable.Comment = "Stores blog posts"
postIDCol := models.InitColumn("id", "posts", "public")
postIDCol.Type = "serial"
postIDCol.IsPrimaryKey = true
postIDCol.NotNull = true
postsTable.Columns["id"] = postIDCol
titleCol := models.InitColumn("title", "posts", "public")
titleCol.Type = "varchar"
titleCol.Length = 255
postsTable.Columns["title"] = titleCol
userIDCol := models.InitColumn("user_id", "posts", "public")
userIDCol.Type = "integer"
postsTable.Columns["user_id"] = userIDCol
postsPKIndex := models.InitIndex("posts_pkey", "posts", "public")
postsPKIndex.Unique = true
postsPKIndex.Columns = []string{"id"}
postsTable.Indexes["posts_pkey"] = postsPKIndex
fkIndex := models.InitIndex("posts_user_id_idx", "posts", "public")
fkIndex.Columns = []string{"user_id"}
postsTable.Indexes["posts_user_id_idx"] = fkIndex
postsPKConstraint := models.InitConstraint("posts_pkey", models.PrimaryKeyConstraint)
postsPKConstraint.Table = "posts"
postsPKConstraint.Schema = "public"
postsPKConstraint.Columns = []string{"id"}
postsTable.Constraints["posts_pkey"] = postsPKConstraint
// Foreign key constraint
fkConstraint := models.InitConstraint("fk_posts_users", models.ForeignKeyConstraint)
fkConstraint.Table = "posts"
fkConstraint.Schema = "public"
fkConstraint.Columns = []string{"user_id"}
fkConstraint.ReferencedTable = "users"
fkConstraint.ReferencedSchema = "public"
fkConstraint.ReferencedColumns = []string{"id"}
fkConstraint.OnDelete = "CASCADE"
fkConstraint.OnUpdate = "NO ACTION"
postsTable.Constraints["fk_posts_users"] = fkConstraint
schema.Tables = append(schema.Tables, postsTable)
// Relation
relation := models.InitRelationship("posts_to_users", models.OneToMany)
relation.FromTable = "posts"
relation.FromSchema = "public"
relation.ToTable = "users"
relation.ToSchema = "public"
relation.ForeignKey = "fk_posts_users"
schema.Relations = append(schema.Relations, relation)
// 2. Write the schema to DCTX
outputPath := filepath.Join(t.TempDir(), "roundtrip.dctx")
writerOpts := &writers.WriterOptions{
OutputPath: outputPath,
}
writer := NewWriter(writerOpts)
err := writer.WriteSchema(schema)
assert.NoError(t, err)
// Verify file was created
_, err = os.Stat(outputPath)
assert.NoError(t, err, "Output file should exist")
// 3. Read the schema back from DCTX
readerOpts := &readers.ReaderOptions{
FilePath: outputPath,
}
reader := dctxreader.NewReader(readerOpts)
db, err := reader.ReadDatabase()
assert.NoError(t, err)
assert.NotNil(t, db)
// 4. Verify the schema was read correctly
assert.Len(t, db.Schemas, 1, "Should have one schema")
readSchema := db.Schemas[0]
// Verify tables
assert.Len(t, readSchema.Tables, 2, "Should have two tables")
// Find users and posts tables
var readUsersTable, readPostsTable *models.Table
for _, table := range readSchema.Tables {
switch table.Name {
case "users":
readUsersTable = table
case "posts":
readPostsTable = table
}
}
assert.NotNil(t, readUsersTable, "Users table should exist")
assert.NotNil(t, readPostsTable, "Posts table should exist")
// Verify columns
assert.Len(t, readUsersTable.Columns, 2, "Users table should have 2 columns")
assert.NotNil(t, readUsersTable.Columns["id"])
assert.NotNil(t, readUsersTable.Columns["name"])
assert.Len(t, readPostsTable.Columns, 3, "Posts table should have 3 columns")
assert.NotNil(t, readPostsTable.Columns["id"])
assert.NotNil(t, readPostsTable.Columns["title"])
assert.NotNil(t, readPostsTable.Columns["user_id"])
// Verify relationships were preserved
// The DCTX reader stores relationships on the foreign table (posts)
assert.NotEmpty(t, readPostsTable.Relationships, "Posts table should have relationships")
// Debug: print all relationships
t.Logf("Posts table has %d relationships:", len(readPostsTable.Relationships))
for name, rel := range readPostsTable.Relationships {
t.Logf(" - %s: from=%s to=%s fk=%s", name, rel.FromTable, rel.ToTable, rel.ForeignKey)
}
// Find the relationship - the reader creates it with FromTable as primary and ToTable as foreign
var postsToUsersRel *models.Relationship
for _, rel := range readPostsTable.Relationships {
// The relationship should have posts as ToTable (foreign) and users as FromTable (primary)
if rel.FromTable == "users" && rel.ToTable == "posts" {
postsToUsersRel = rel
break
}
}
assert.NotNil(t, postsToUsersRel, "Should have relationship from users to posts")
if postsToUsersRel != nil {
assert.Equal(t, "users", postsToUsersRel.FromTable, "Relationship should come from users (primary) table")
assert.Equal(t, "posts", postsToUsersRel.ToTable, "Relationship should point to posts (foreign) table")
assert.NotEmpty(t, postsToUsersRel.ForeignKey, "Relationship should have a foreign key")
}
// Verify foreign key constraint
fks := readPostsTable.GetForeignKeys()
assert.NotEmpty(t, fks, "Posts table should have foreign keys")
if len(fks) > 0 {
fk := fks[0]
assert.Equal(t, models.ForeignKeyConstraint, fk.Type)
assert.Contains(t, fk.Columns, "user_id")
assert.Equal(t, "users", fk.ReferencedTable)
assert.Contains(t, fk.ReferencedColumns, "id")
assert.Equal(t, "CASCADE", fk.OnDelete)
}
t.Logf("Round-trip test successful: wrote and read back %d tables with relationships", len(readSchema.Tables))
}

View File

@@ -1,36 +1,379 @@
package dctx
import (
"encoding/xml"
"fmt"
"os"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
"github.com/google/uuid"
)
// Writer implements the writers.Writer interface for DCTX format
// Note: DCTX is a read-only format used for loading Clarion dictionary files
type Writer struct {
options *writers.WriterOptions
options *writers.WriterOptions
fieldGuidMap map[string]string // key: "table.column", value: guid
keyGuidMap map[string]string // key: "table.index", value: guid
tableGuidMap map[string]string // key: "table", value: guid
}
// NewWriter creates a new DCTX writer with the given options
func NewWriter(options *writers.WriterOptions) *Writer {
return &Writer{
options: options,
options: options,
fieldGuidMap: make(map[string]string),
keyGuidMap: make(map[string]string),
tableGuidMap: make(map[string]string),
}
}
// WriteDatabase returns an error as DCTX format is read-only
// WriteDatabase is not implemented for DCTX
func (w *Writer) WriteDatabase(db *models.Database) error {
return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
return fmt.Errorf("writing a full database is not supported for DCTX, please write a single schema")
}
// WriteSchema returns an error as DCTX format is read-only
// WriteSchema writes a schema to the writer in DCTX format
func (w *Writer) WriteSchema(schema *models.Schema) error {
return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
dctx := models.DCTXDictionary{
Name: schema.Name,
Version: "1",
Tables: make([]models.DCTXTable, len(schema.Tables)),
}
tableSlice := make([]*models.Table, 0, len(schema.Tables))
for _, t := range schema.Tables {
tableSlice = append(tableSlice, t)
}
// Pass 1: Create fields and populate fieldGuidMap
for i, table := range tableSlice {
dctx.Tables[i] = w.mapTableFields(table)
}
// Pass 2: Create keys and populate keyGuidMap
for i, table := range tableSlice {
dctx.Tables[i].Keys = w.mapTableKeys(table)
}
// Pass 3: Collect all relationships (from schema and tables)
var allRelations []*models.Relationship
// Add schema-level relations
allRelations = append(allRelations, schema.Relations...)
// Add table-level relationships
for _, table := range tableSlice {
for _, rel := range table.Relationships {
// Check if this relationship is already in the list (avoid duplicates)
isDuplicate := false
for _, existing := range allRelations {
if existing.Name == rel.Name &&
existing.FromTable == rel.FromTable &&
existing.ToTable == rel.ToTable {
isDuplicate = true
break
}
}
if !isDuplicate {
allRelations = append(allRelations, rel)
}
}
}
// Map all relations to DCTX format
dctx.Relations = make([]models.DCTXRelation, len(allRelations))
for i, rel := range allRelations {
dctx.Relations[i] = w.mapRelation(rel, schema)
}
output, err := xml.MarshalIndent(dctx, "", " ")
if err != nil {
return err
}
file, err := os.Create(w.options.OutputPath)
if err != nil {
return err
}
defer file.Close()
if _, err := file.Write([]byte(xml.Header)); err != nil {
return err
}
_, err = file.Write(output)
return err
}
// WriteTable returns an error as DCTX format is read-only
// WriteTable writes a single table to the writer in DCTX format
func (w *Writer) WriteTable(table *models.Table) error {
return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
dctxTable := w.mapTableFields(table)
dctxTable.Keys = w.mapTableKeys(table)
output, err := xml.MarshalIndent(dctxTable, "", " ")
if err != nil {
return err
}
file, err := os.Create(w.options.OutputPath)
if err != nil {
return err
}
defer file.Close()
_, err = file.Write(output)
return err
}
func (w *Writer) mapTableFields(table *models.Table) models.DCTXTable {
// Generate prefix (first 3 chars, or full name if shorter)
prefix := table.Name
if len(table.Name) > 3 {
prefix = table.Name[:3]
}
tableGuid := w.newGUID()
w.tableGuidMap[table.Name] = tableGuid
dctxTable := models.DCTXTable{
Guid: tableGuid,
Name: table.Name,
Prefix: prefix,
Description: table.Comment,
Fields: make([]models.DCTXField, len(table.Columns)),
Options: []models.DCTXOption{
{
Property: "SQL",
PropertyType: "1",
PropertyValue: "1",
},
},
}
i := 0
for _, column := range table.Columns {
dctxTable.Fields[i] = w.mapField(column)
i++
}
return dctxTable
}
func (w *Writer) mapTableKeys(table *models.Table) []models.DCTXKey {
keys := make([]models.DCTXKey, len(table.Indexes))
i := 0
for _, index := range table.Indexes {
keys[i] = w.mapKey(index, table)
i++
}
return keys
}
func (w *Writer) mapField(column *models.Column) models.DCTXField {
guid := w.newGUID()
fieldKey := fmt.Sprintf("%s.%s", column.Table, column.Name)
w.fieldGuidMap[fieldKey] = guid
return models.DCTXField{
Guid: guid,
Name: column.Name,
DataType: w.mapDataType(column.Type),
Size: column.Length,
}
}
func (w *Writer) mapDataType(dataType string) string {
switch dataType {
case "integer", "int", "int4", "serial":
return "LONG"
case "bigint", "int8", "bigserial":
return "DECIMAL"
case "smallint", "int2":
return "SHORT"
case "boolean", "bool":
return "BYTE"
case "text", "varchar", "char":
return "CSTRING"
case "date":
return "DATE"
case "time":
return "TIME"
case "timestamp", "timestamptz":
return "STRING"
case "decimal", "numeric":
return "DECIMAL"
default:
return "STRING"
}
}
func (w *Writer) mapKey(index *models.Index, table *models.Table) models.DCTXKey {
guid := w.newGUID()
keyKey := fmt.Sprintf("%s.%s", table.Name, index.Name)
w.keyGuidMap[keyKey] = guid
key := models.DCTXKey{
Guid: guid,
Name: index.Name,
Primary: strings.HasSuffix(index.Name, "_pkey"),
Unique: index.Unique,
Components: make([]models.DCTXComponent, len(index.Columns)),
Description: index.Comment,
}
for i, colName := range index.Columns {
fieldKey := fmt.Sprintf("%s.%s", table.Name, colName)
fieldID := w.fieldGuidMap[fieldKey]
key.Components[i] = models.DCTXComponent{
Guid: w.newGUID(),
FieldId: fieldID,
Order: i + 1,
Ascend: true,
}
}
return key
}
func (w *Writer) mapRelation(rel *models.Relationship, schema *models.Schema) models.DCTXRelation {
// Find the foreign key constraint from the 'from' table
var fromTable *models.Table
for _, t := range schema.Tables {
if t.Name == rel.FromTable {
fromTable = t
break
}
}
var constraint *models.Constraint
if fromTable != nil {
for _, c := range fromTable.Constraints {
if c.Name == rel.ForeignKey {
constraint = c
break
}
}
}
var foreignKeyGUID string
var fkColumns []string
if constraint != nil {
fkColumns = constraint.Columns
// In DCTX, a relation is often linked by a foreign key which is an index.
// We'll look for an index that matches the constraint columns.
for _, index := range fromTable.Indexes {
if strings.Join(index.Columns, ",") == strings.Join(constraint.Columns, ",") {
keyKey := fmt.Sprintf("%s.%s", fromTable.Name, index.Name)
foreignKeyGUID = w.keyGuidMap[keyKey]
break
}
}
}
// Find the primary key of the 'to' table
var toTable *models.Table
for _, t := range schema.Tables {
if t.Name == rel.ToTable {
toTable = t
break
}
}
var primaryKeyGUID string
var pkColumns []string
// Use referenced columns from the constraint if available
if constraint != nil && len(constraint.ReferencedColumns) > 0 {
pkColumns = constraint.ReferencedColumns
}
if toTable != nil {
// Find the matching primary key index
for _, index := range toTable.Indexes {
// If we have referenced columns, try to match them
if len(pkColumns) > 0 {
if strings.Join(index.Columns, ",") == strings.Join(pkColumns, ",") {
keyKey := fmt.Sprintf("%s.%s", toTable.Name, index.Name)
primaryKeyGUID = w.keyGuidMap[keyKey]
break
}
} else if strings.HasSuffix(index.Name, "_pkey") {
// Fall back to finding primary key by naming convention
keyKey := fmt.Sprintf("%s.%s", toTable.Name, index.Name)
primaryKeyGUID = w.keyGuidMap[keyKey]
pkColumns = index.Columns
break
}
}
}
// Create field mappings
// NOTE: DCTX has backwards naming - ForeignMapping contains PRIMARY table fields,
// and PrimaryMapping contains FOREIGN table fields
var foreignMappings []models.DCTXFieldMapping // Will contain primary table fields
var primaryMappings []models.DCTXFieldMapping // Will contain foreign table fields
// Map foreign key columns (from foreign table) to PrimaryMapping
for _, colName := range fkColumns {
fieldKey := fmt.Sprintf("%s.%s", rel.FromTable, colName)
if fieldGUID, exists := w.fieldGuidMap[fieldKey]; exists {
primaryMappings = append(primaryMappings, models.DCTXFieldMapping{
Guid: w.newGUID(),
Field: fieldGUID,
})
}
}
// Map primary key columns (from primary table) to ForeignMapping
for _, colName := range pkColumns {
fieldKey := fmt.Sprintf("%s.%s", rel.ToTable, colName)
if fieldGUID, exists := w.fieldGuidMap[fieldKey]; exists {
foreignMappings = append(foreignMappings, models.DCTXFieldMapping{
Guid: w.newGUID(),
Field: fieldGUID,
})
}
}
// Get OnDelete and OnUpdate actions from the constraint
onDelete := ""
onUpdate := ""
if constraint != nil {
onDelete = w.mapReferentialAction(constraint.OnDelete)
onUpdate = w.mapReferentialAction(constraint.OnUpdate)
}
return models.DCTXRelation{
Guid: w.newGUID(),
PrimaryTable: w.tableGuidMap[rel.ToTable], // GUID of the 'to' table (e.g., users)
ForeignTable: w.tableGuidMap[rel.FromTable], // GUID of the 'from' table (e.g., posts)
PrimaryKey: primaryKeyGUID,
ForeignKey: foreignKeyGUID,
Delete: onDelete,
Update: onUpdate,
ForeignMappings: foreignMappings,
PrimaryMappings: primaryMappings,
}
}
// mapReferentialAction maps SQL referential actions to DCTX format
func (w *Writer) mapReferentialAction(action string) string {
switch strings.ToUpper(action) {
case "RESTRICT":
return "RESTRICT_SERVER"
case "CASCADE":
return "CASCADE_SERVER"
case "SET NULL":
return "SET_NULL_SERVER"
case "SET DEFAULT":
return "SET_DEFAULT_SERVER"
case "NO ACTION":
return "NO_ACTION_SERVER"
default:
return ""
}
}
func (w *Writer) newGUID() string {
return "{" + uuid.New().String() + "}"
}

View File

@@ -1,110 +1,152 @@
package dctx
import (
"strings"
"encoding/xml"
"os"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
"github.com/stretchr/testify/assert"
)
// TestWriter_WriteDatabase_ReturnsError tests that WriteDatabase returns an error
// since DCTX format is read-only
func TestWriter_WriteDatabase_ReturnsError(t *testing.T) {
db := models.InitDatabase("test_db")
func TestWriter_WriteSchema(t *testing.T) {
// 1. Create a sample schema
schema := models.InitSchema("public")
table := models.InitTable("users", "public")
schema.Name = "TestDB"
// Table 1: users
usersTable := models.InitTable("users", "public")
usersTable.Comment = "Stores user information"
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
table.Columns["id"] = idCol
schema.Tables = append(schema.Tables, table)
db.Schemas = append(db.Schemas, schema)
opts := &writers.WriterOptions{
OutputPath: "/tmp/test.dctx",
}
writer := NewWriter(opts)
err := writer.WriteDatabase(db)
if err == nil {
t.Error("WriteDatabase() should return an error for read-only format")
}
if !strings.Contains(err.Error(), "read-only") {
t.Errorf("Error message should mention 'read-only', got: %v", err)
}
}
// TestWriter_WriteSchema_ReturnsError tests that WriteSchema returns an error
// since DCTX format is read-only
func TestWriter_WriteSchema_ReturnsError(t *testing.T) {
schema := models.InitSchema("public")
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
table.Columns["id"] = idCol
schema.Tables = append(schema.Tables, table)
opts := &writers.WriterOptions{
OutputPath: "/tmp/test.dctx",
}
writer := NewWriter(opts)
err := writer.WriteSchema(schema)
if err == nil {
t.Error("WriteSchema() should return an error for read-only format")
}
if !strings.Contains(err.Error(), "read-only") {
t.Errorf("Error message should mention 'read-only', got: %v", err)
}
}
// TestWriter_WriteTable_ReturnsError tests that WriteTable returns an error
// since DCTX format is read-only
func TestWriter_WriteTable_ReturnsError(t *testing.T) {
table := models.InitTable("users", "public")
idCol := models.InitColumn("id", "users", "public")
idCol.Type = "bigint"
idCol.Type = "serial"
idCol.IsPrimaryKey = true
table.Columns["id"] = idCol
usersTable.Columns["id"] = idCol
nameCol := models.InitColumn("name", "users", "public")
nameCol.Type = "varchar"
nameCol.Length = 100
usersTable.Columns["name"] = nameCol
pkIndex := models.InitIndex("users_pkey", "users", "public")
pkIndex.Unique = true
pkIndex.Columns = []string{"id"}
usersTable.Indexes["users_pkey"] = pkIndex
schema.Tables = append(schema.Tables, usersTable)
// Table 2: posts
postsTable := models.InitTable("posts", "public")
postsTable.Comment = "Stores blog posts"
postIDCol := models.InitColumn("id", "posts", "public")
postIDCol.Type = "serial"
postIDCol.IsPrimaryKey = true
postsTable.Columns["id"] = postIDCol
titleCol := models.InitColumn("title", "posts", "public")
titleCol.Type = "varchar"
titleCol.Length = 255
postsTable.Columns["title"] = titleCol
userIDCol := models.InitColumn("user_id", "posts", "public")
userIDCol.Type = "integer"
postsTable.Columns["user_id"] = userIDCol
postsPKIndex := models.InitIndex("posts_pkey", "posts", "public")
postsPKIndex.Unique = true
postsPKIndex.Columns = []string{"id"}
postsTable.Indexes["posts_pkey"] = postsPKIndex
fkIndex := models.InitIndex("posts_user_id_idx", "posts", "public")
fkIndex.Columns = []string{"user_id"}
postsTable.Indexes["posts_user_id_idx"] = fkIndex
schema.Tables = append(schema.Tables, postsTable)
// Constraint for the relationship
fkConstraint := models.InitConstraint("fk_posts_users", models.ForeignKeyConstraint)
fkConstraint.Table = "posts"
fkConstraint.Schema = "public"
fkConstraint.Columns = []string{"user_id"}
fkConstraint.ReferencedTable = "users"
fkConstraint.ReferencedSchema = "public"
fkConstraint.ReferencedColumns = []string{"id"}
postsTable.Constraints["fk_posts_users"] = fkConstraint
// Relation
relation := models.InitRelation("fk_posts_users", "public")
relation.FromTable = "posts"
relation.ToTable = "users"
relation.ForeignKey = "fk_posts_users"
schema.Relations = append(schema.Relations, relation)
// 2. Setup writer
outputPath := "/tmp/test.dctx"
opts := &writers.WriterOptions{
OutputPath: "/tmp/test.dctx",
OutputPath: outputPath,
}
writer := NewWriter(opts)
err := writer.WriteTable(table)
if err == nil {
t.Error("WriteTable() should return an error for read-only format")
}
if !strings.Contains(err.Error(), "read-only") {
t.Errorf("Error message should mention 'read-only', got: %v", err)
}
}
// TestNewWriter tests that NewWriter creates a valid writer instance
func TestNewWriter(t *testing.T) {
opts := &writers.WriterOptions{
OutputPath: "/tmp/test.dctx",
}
writer := NewWriter(opts)
if writer == nil {
t.Error("NewWriter() should return a non-nil writer")
}
// 3. Write the schema
err := writer.WriteSchema(schema)
assert.NoError(t, err)
if writer.options != opts {
t.Error("Writer options should match the provided options")
}
}
// 4. Read the file and unmarshal it
actualBytes, err := os.ReadFile(outputPath)
assert.NoError(t, err)
var dctx models.DCTXDictionary
err = xml.Unmarshal(actualBytes, &dctx)
assert.NoError(t, err)
// 5. Assert properties of the unmarshaled struct
assert.Equal(t, "TestDB", dctx.Name)
assert.Equal(t, "1", dctx.Version)
assert.Len(t, dctx.Tables, 2)
assert.Len(t, dctx.Relations, 1)
// Assert users table
usersTableResult := dctx.Tables[0]
assert.Equal(t, "users", usersTableResult.Name)
assert.Len(t, usersTableResult.Fields, 2)
assert.Len(t, usersTableResult.Keys, 1)
userPK := usersTableResult.Keys[0]
assert.True(t, userPK.Primary)
assert.Equal(t, "users_pkey", userPK.Name)
assert.Len(t, userPK.Components, 1)
userPKComponent := userPK.Components[0]
assert.NotEmpty(t, userPKComponent.FieldId)
// Assert posts table
postsTableResult := dctx.Tables[1]
assert.Equal(t, "posts", postsTableResult.Name)
assert.Len(t, postsTableResult.Fields, 3)
assert.Len(t, postsTableResult.Keys, 2)
postsFK := postsTableResult.Keys[1] // Assuming order
assert.False(t, postsFK.Primary)
assert.Equal(t, "posts_user_id_idx", postsFK.Name)
assert.Len(t, postsFK.Components, 1)
postsFKComponent := postsFK.Components[0]
assert.NotEmpty(t, postsFKComponent.FieldId)
// Assert relation
relationResult := dctx.Relations[0]
// PrimaryTable and ForeignTable should be GUIDs in DCTX format
assert.NotEmpty(t, relationResult.PrimaryTable, "PrimaryTable should have a GUID")
assert.NotEmpty(t, relationResult.ForeignTable, "ForeignTable should have a GUID")
assert.NotEmpty(t, relationResult.PrimaryKey)
assert.NotEmpty(t, relationResult.ForeignKey)
// Check if the table GUIDs match
assert.Equal(t, usersTableResult.Guid, relationResult.PrimaryTable, "PrimaryTable GUID should match users table")
assert.Equal(t, postsTableResult.Guid, relationResult.ForeignTable, "ForeignTable GUID should match posts table")
// Check if the key GUIDs match up
assert.Equal(t, userPK.Guid, relationResult.PrimaryKey)
assert.Equal(t, postsFK.Guid, relationResult.ForeignKey)
// Verify field mappings exist
assert.NotEmpty(t, relationResult.ForeignMappings, "Relation should have ForeignMappings")
assert.NotEmpty(t, relationResult.PrimaryMappings, "Relation should have PrimaryMappings")
// ForeignMapping should reference primary table (users) fields
assert.Len(t, relationResult.ForeignMappings, 1)
assert.NotEmpty(t, relationResult.ForeignMappings[0].Field)
// PrimaryMapping should reference foreign table (posts) fields
assert.Len(t, relationResult.PrimaryMappings, 1)
assert.NotEmpty(t, relationResult.PrimaryMappings[0].Field)
}

View File

@@ -33,7 +33,7 @@ func TestWriter_WriteTable(t *testing.T) {
table.Columns["name"] = nameCol
// Add index
emailIdx := models.InitIndex("idx_users_email")
emailIdx := models.InitIndex("idx_users_email", "users", "public")
emailIdx.Columns = []string{"email"}
emailIdx.Unique = true
emailIdx.Table = "users"

View File

@@ -111,7 +111,7 @@ func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
usersTable.Columns["email"] = emailCol
// Add index
emailIdx := models.InitIndex("idx_users_email")
emailIdx := models.InitIndex("idx_users_email", "users", "public")
emailIdx.Columns = []string{"email"}
emailIdx.Unique = true
emailIdx.Type = "btree"

View File

@@ -1,696 +0,0 @@
# PostgreSQL Migration Templates
## Overview
The PostgreSQL migration writer uses Go text templates to generate SQL, making the code much more maintainable and customizable than hardcoded string concatenation.
## Architecture
```
pkg/writers/pgsql/
├── templates/ # Template files
│ ├── create_table.tmpl # CREATE TABLE
│ ├── add_column.tmpl # ALTER TABLE ADD COLUMN
│ ├── alter_column_type.tmpl # ALTER TABLE ALTER COLUMN TYPE
│ ├── alter_column_default.tmpl # ALTER TABLE ALTER COLUMN DEFAULT
│ ├── create_primary_key.tmpl # ADD CONSTRAINT PRIMARY KEY
│ ├── create_index.tmpl # CREATE INDEX
│ ├── create_foreign_key.tmpl # ADD CONSTRAINT FOREIGN KEY
│ ├── drop_constraint.tmpl # DROP CONSTRAINT
│ ├── drop_index.tmpl # DROP INDEX
│ ├── comment_table.tmpl # COMMENT ON TABLE
│ ├── comment_column.tmpl # COMMENT ON COLUMN
│ ├── audit_tables.tmpl # CREATE audit tables
│ ├── audit_function.tmpl # CREATE audit function
│ └── audit_trigger.tmpl # CREATE audit trigger
├── templates.go # Template executor and data structures
└── migration_writer_templated.go # Templated migration writer
```
## Using Templates
### Basic Usage
```go
// Create template executor
executor, err := pgsql.NewTemplateExecutor()
if err != nil {
log.Fatal(err)
}
// Prepare data
data := pgsql.CreateTableData{
SchemaName: "public",
TableName: "users",
Columns: []pgsql.ColumnData{
{Name: "id", Type: "integer", NotNull: true},
{Name: "name", Type: "text"},
},
}
// Execute template
sql, err := executor.ExecuteCreateTable(data)
if err != nil {
log.Fatal(err)
}
fmt.Println(sql)
```
### Using Templated Migration Writer
```go
// Create templated migration writer
writer, err := pgsql.NewTemplatedMigrationWriter(&writers.WriterOptions{
OutputPath: "migration.sql",
})
if err != nil {
log.Fatal(err)
}
// Generate migration (uses templates internally)
err = writer.WriteMigration(modelDB, currentDB)
if err != nil {
log.Fatal(err)
}
```
## Template Data Structures
### CreateTableData
For `create_table.tmpl`:
```go
type CreateTableData struct {
SchemaName string
TableName string
Columns []ColumnData
}
type ColumnData struct {
Name string
Type string
Default string
NotNull bool
}
```
Example:
```go
data := CreateTableData{
SchemaName: "public",
TableName: "products",
Columns: []ColumnData{
{Name: "id", Type: "serial", NotNull: true},
{Name: "name", Type: "text", NotNull: true},
{Name: "price", Type: "numeric(10,2)", Default: "0.00"},
},
}
```
### AddColumnData
For `add_column.tmpl`:
```go
type AddColumnData struct {
SchemaName string
TableName string
ColumnName string
ColumnType string
Default string
NotNull bool
}
```
### CreateIndexData
For `create_index.tmpl`:
```go
type CreateIndexData struct {
SchemaName string
TableName string
IndexName string
IndexType string // btree, hash, gin, gist
Columns string // comma-separated
Unique bool
}
```
### CreateForeignKeyData
For `create_foreign_key.tmpl`:
```go
type CreateForeignKeyData struct {
SchemaName string
TableName string
ConstraintName string
SourceColumns string // comma-separated
TargetSchema string
TargetTable string
TargetColumns string // comma-separated
OnDelete string // CASCADE, SET NULL, etc.
OnUpdate string
}
```
### AuditFunctionData
For `audit_function.tmpl`:
```go
type AuditFunctionData struct {
SchemaName string
FunctionName string
TableName string
TablePrefix string
PrimaryKey string
AuditSchema string
UserFunction string
AuditInsert bool
AuditUpdate bool
AuditDelete bool
UpdateCondition string
UpdateColumns []AuditColumnData
DeleteColumns []AuditColumnData
}
type AuditColumnData struct {
Name string
OldValue string // SQL expression for old value
NewValue string // SQL expression for new value
}
```
## Customizing Templates
### Modifying Existing Templates
Templates are embedded in the binary but can be modified at compile time:
1. **Edit template file** in `pkg/writers/pgsql/templates/`:
```go
// templates/create_table.tmpl
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
{{- range $i, $col := .Columns}}
{{- if $i}},{{end}}
{{$col.Name}} {{$col.Type}}
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
{{- if $col.NotNull}} NOT NULL{{end}}
{{- end}}
);
-- Custom comment
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}} IS 'Auto-generated by RelSpec';
```
2. **Rebuild** the application:
```bash
go build ./cmd/relspec
```
The new template is automatically embedded.
### Template Syntax Reference
#### Variables
```go
{{.FieldName}} // Access field
{{.SchemaName}} // String field
{{.NotNull}} // Boolean field
```
#### Conditionals
```go
{{if .NotNull}}
NOT NULL
{{end}}
{{if .Default}}
DEFAULT {{.Default}}
{{else}}
-- No default
{{end}}
```
#### Loops
```go
{{range $i, $col := .Columns}}
Column: {{$col.Name}} Type: {{$col.Type}}
{{end}}
```
#### Functions
```go
{{if eq .Type "CASCADE"}}
ON DELETE CASCADE
{{end}}
{{join .Columns ", "}} // Join string slice
```
### Creating New Templates
1. **Create template file** in `pkg/writers/pgsql/templates/`:
```go
// templates/custom_operation.tmpl
-- Custom operation for {{.TableName}}
ALTER TABLE {{.SchemaName}}.{{.TableName}}
{{.CustomOperation}};
```
2. **Define data structure** in `templates.go`:
```go
type CustomOperationData struct {
SchemaName string
TableName string
CustomOperation string
}
```
3. **Add executor method** in `templates.go`:
```go
func (te *TemplateExecutor) ExecuteCustomOperation(data CustomOperationData) (string, error) {
var buf bytes.Buffer
err := te.templates.ExecuteTemplate(&buf, "custom_operation.tmpl", data)
if err != nil {
return "", fmt.Errorf("failed to execute custom_operation template: %w", err)
}
return buf.String(), nil
}
```
4. **Use in migration writer**:
```go
sql, err := w.executor.ExecuteCustomOperation(CustomOperationData{
SchemaName: "public",
TableName: "users",
CustomOperation: "ADD COLUMN custom_field text",
})
```
## Template Examples
### Example 1: Custom Table Creation
Modify `create_table.tmpl` to add table options:
```sql
CREATE TABLE IF NOT EXISTS {{.SchemaName}}.{{.TableName}} (
{{- range $i, $col := .Columns}}
{{- if $i}},{{end}}
{{$col.Name}} {{$col.Type}}
{{- if $col.Default}} DEFAULT {{$col.Default}}{{end}}
{{- if $col.NotNull}} NOT NULL{{end}}
{{- end}}
) WITH (fillfactor = 90);
-- Add automatic comment
COMMENT ON TABLE {{.SchemaName}}.{{.TableName}}
IS 'Created: {{.CreatedDate}} | Version: {{.Version}}';
```
### Example 2: Custom Index with WHERE Clause
Add to `create_index.tmpl`:
```sql
CREATE {{if .Unique}}UNIQUE {{end}}INDEX IF NOT EXISTS {{.IndexName}}
ON {{.SchemaName}}.{{.TableName}}
USING {{.IndexType}} ({{.Columns}})
{{- if .Where}}
WHERE {{.Where}}
{{- end}}
{{- if .Include}}
INCLUDE ({{.Include}})
{{- end}};
```
Update data structure:
```go
type CreateIndexData struct {
SchemaName string
TableName string
IndexName string
IndexType string
Columns string
Unique bool
Where string // New field for partial indexes
Include string // New field for covering indexes
}
```
### Example 3: Enhanced Audit Function
Modify `audit_function.tmpl` to add custom logging:
```sql
CREATE OR REPLACE FUNCTION {{.SchemaName}}.{{.FunctionName}}()
RETURNS trigger AS
$body$
DECLARE
m_funcname text = '{{.FunctionName}}';
m_user text;
m_atevent integer;
m_application_name text;
BEGIN
-- Get current user and application
m_user := {{.UserFunction}}::text;
m_application_name := current_setting('application_name', true);
-- Custom logging
RAISE NOTICE 'Audit: % on %.% by % from %',
TG_OP, TG_TABLE_SCHEMA, TG_TABLE_NAME, m_user, m_application_name;
-- Rest of function...
...
```
## Best Practices
### 1. Keep Templates Simple
Templates should focus on SQL generation. Complex logic belongs in Go code:
**Good:**
```go
// In Go code
columns := buildColumnList(table)
// In template
{{range .Columns}}
{{.Name}} {{.Type}}
{{end}}
```
**Bad:**
```go
// Don't do complex transformations in templates
{{range .Columns}}
{{if eq .Type "integer"}}
{{.Name}} serial
{{else}}
{{.Name}} {{.Type}}
{{end}}
{{end}}
```
### 2. Use Descriptive Field Names
```go
// Good
type CreateTableData struct {
SchemaName string
TableName string
}
// Bad
type CreateTableData struct {
S string // What is S?
T string // What is T?
}
```
### 3. Document Template Data
Always document what data a template expects:
```go
// CreateTableData contains data for create table template.
// Used by templates/create_table.tmpl
type CreateTableData struct {
SchemaName string // Schema where table will be created
TableName string // Name of the table
Columns []ColumnData // List of columns to create
}
```
### 4. Handle SQL Injection
Always escape user input:
```go
// In Go code - escape before passing to template
data := CommentTableData{
SchemaName: schema,
TableName: table,
Comment: escapeQuote(userComment), // Escape quotes
}
```
### 5. Test Templates Thoroughly
```go
func TestTemplate_CreateTable(t *testing.T) {
executor, _ := NewTemplateExecutor()
data := CreateTableData{
SchemaName: "public",
TableName: "test",
Columns: []ColumnData{{Name: "id", Type: "integer"}},
}
sql, err := executor.ExecuteCreateTable(data)
if err != nil {
t.Fatal(err)
}
// Verify expected SQL patterns
if !strings.Contains(sql, "CREATE TABLE") {
t.Error("Missing CREATE TABLE")
}
}
```
## Benefits of Template-Based Approach
### Maintainability
**Before (string concatenation):**
```go
sql := fmt.Sprintf(`CREATE TABLE %s.%s (
%s %s%s%s
);`, schema, table, col, typ,
func() string {
if def != "" {
return " DEFAULT " + def
}
return ""
}(),
func() string {
if notNull {
return " NOT NULL"
}
return ""
}(),
)
```
**After (templates):**
```go
sql, _ := executor.ExecuteCreateTable(CreateTableData{
SchemaName: schema,
TableName: table,
Columns: columns,
})
```
### Customization
Users can modify templates without changing Go code:
- Edit template file
- Rebuild application
- New SQL generation logic active
### Testing
Templates can be tested independently:
```go
func TestAuditTemplate(t *testing.T) {
executor, _ := NewTemplateExecutor()
// Test with various data
for _, testCase := range testCases {
sql, err := executor.ExecuteAuditFunction(testCase.data)
// Verify output
}
}
```
### Readability
SQL templates are easier to read and review than Go string building code.
## Migration from Old Writer
To migrate from the old string-based writer to templates:
### Option 1: Use TemplatedMigrationWriter
```go
// Old
writer := pgsql.NewMigrationWriter(options)
// New
writer, err := pgsql.NewTemplatedMigrationWriter(options)
if err != nil {
log.Fatal(err)
}
// Same interface
writer.WriteMigration(model, current)
```
### Option 2: Keep Both
Both writers are available:
- `MigrationWriter` - Original string-based
- `TemplatedMigrationWriter` - New template-based
Choose based on your needs.
## Troubleshooting
### Template Not Found
```
Error: template: "my_template.tmpl" not defined
```
Solution: Ensure template file exists in `templates/` directory and rebuild.
### Template Execution Error
```
Error: template: create_table.tmpl:5:10: executing "create_table.tmpl"
at <.InvalidField>: can't evaluate field InvalidField
```
Solution: Check data structure has all fields used in template.
### Embedded Files Not Updating
If template changes aren't reflected:
1. Clean build cache: `go clean -cache`
2. Rebuild: `go build ./cmd/relspec`
3. Verify template file is in `templates/` directory
## Custom Template Functions
RelSpec provides a comprehensive library of template functions for SQL generation:
### String Manipulation
- `upper`, `lower` - Case conversion
- `snake_case`, `camelCase` - Naming convention conversion
- Usage: `{{upper .TableName}}``USERS`
### SQL Formatting
- `indent(spaces, text)` - Indent text
- `quote(string)` - Quote for SQL with escaping
- `escape(string)` - Escape special characters
- `safe_identifier(string)` - Make SQL-safe identifier
- Usage: `{{quote "O'Brien"}}``'O''Brien'`
### Type Conversion
- `goTypeToSQL(type)` - Convert Go type to PostgreSQL type
- `sqlTypeToGo(type)` - Convert PostgreSQL type to Go type
- `isNumeric(type)`, `isText(type)` - Type checking
- Usage: `{{goTypeToSQL "int64"}}``bigint`
### Collection Helpers
- `first(slice)`, `last(slice)` - Get elements
- `join_with(slice, sep)` - Join with custom separator
- Usage: `{{join_with .Columns ", "}}``id, name, email`
See [template_functions.go](template_functions.go) for full documentation.
## Template Inheritance and Composition
RelSpec supports Go template inheritance using `{{template}}` and `{{block}}`:
### Base Templates
- `base_ddl.tmpl` - Common DDL patterns
- `base_constraint.tmpl` - Constraint operations
- `fragments.tmpl` - Reusable fragments
### Using Fragments
```gotmpl
{{/* Use predefined fragments */}}
CREATE TABLE {{template "qualified_table" .}} (
{{range .Columns}}
{{template "column_definition" .}}
{{end}}
);
```
### Template Blocks
```gotmpl
{{/* Define with override capability */}}
{{define "table_options"}}
) {{block "storage_options" .}}WITH (fillfactor = 90){{end}};
{{end}}
```
See [TEMPLATE_INHERITANCE.md](TEMPLATE_INHERITANCE.md) for detailed guide.
## Visual Template Editor
A VS Code extension is available for visual template editing:
### Features
- **Live Preview** - See rendered SQL as you type
- **IntelliSense** - Auto-completion for functions
- **Validation** - Syntax checking and error highlighting
- **Scaffolding** - Quick template creation
- **Function Browser** - Browse available functions
### Installation
```bash
cd vscode-extension
npm install
npm run compile
code .
# Press F5 to launch
```
See [vscode-extension/README.md](../../vscode-extension/README.md) for full documentation.
## Future Enhancements
Completed:
- [x] Template inheritance/composition
- [x] Custom template functions library
- [x] Visual template editor (VS Code)
Potential future improvements:
- [ ] Parameterized templates (load from config)
- [ ] Template validation CLI tool
- [ ] Template library/marketplace
- [ ] Template versioning
- [ ] Hot-reload during development
## Contributing Templates
When contributing new templates:
1. Place in `pkg/writers/pgsql/templates/`
2. Use `.tmpl` extension
3. Document data structure in `templates.go`
4. Add executor method
5. Write tests
6. Update this documentation

View File

@@ -62,6 +62,234 @@ func (w *Writer) WriteDatabase(db *models.Database) error {
return nil
}
// GenerateDatabaseStatements generates SQL statements as a list for the entire database
// Returns a slice of SQL statements that can be executed independently
func (w *Writer) GenerateDatabaseStatements(db *models.Database) ([]string, error) {
statements := []string{}
// Add header comment
statements = append(statements, fmt.Sprintf("-- PostgreSQL Database Schema"))
statements = append(statements, fmt.Sprintf("-- Database: %s", db.Name))
statements = append(statements, fmt.Sprintf("-- Generated by RelSpec"))
// Process each schema in the database
for _, schema := range db.Schemas {
schemaStatements, err := w.GenerateSchemaStatements(schema)
if err != nil {
return nil, fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err)
}
statements = append(statements, schemaStatements...)
}
return statements, nil
}
// GenerateSchemaStatements generates SQL statements as a list for a single schema
func (w *Writer) GenerateSchemaStatements(schema *models.Schema) ([]string, error) {
statements := []string{}
// Phase 1: Create schema
if schema.Name != "public" {
statements = append(statements, fmt.Sprintf("-- Schema: %s", schema.Name))
statements = append(statements, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema.SQLName()))
}
// Phase 2: Create sequences
for _, table := range schema.Tables {
pk := table.GetPrimaryKey()
if pk == nil || !isIntegerType(pk.Type) || pk.Default == "" {
continue
}
defaultStr, ok := pk.Default.(string)
if !ok || !strings.Contains(strings.ToLower(defaultStr), "nextval") {
continue
}
seqName := extractSequenceName(defaultStr)
if seqName == "" {
continue
}
stmt := fmt.Sprintf("CREATE SEQUENCE IF NOT EXISTS %s.%s\n INCREMENT 1\n MINVALUE 1\n MAXVALUE 9223372036854775807\n START 1\n CACHE 1",
schema.SQLName(), seqName)
statements = append(statements, stmt)
}
// Phase 3: Create tables
for _, table := range schema.Tables {
stmts, err := w.generateCreateTableStatement(schema, table)
if err != nil {
return nil, fmt.Errorf("failed to generate table %s: %w", table.Name, err)
}
statements = append(statements, stmts...)
}
// Phase 4: Primary keys
for _, table := range schema.Tables {
for _, constraint := range table.Constraints {
if constraint.Type != models.PrimaryKeyConstraint {
continue
}
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s PRIMARY KEY (%s)",
schema.SQLName(), table.SQLName(), constraint.Name, strings.Join(constraint.Columns, ", "))
statements = append(statements, stmt)
}
}
// Phase 5: Indexes
for _, table := range schema.Tables {
for _, index := range table.Indexes {
// Skip primary key indexes
if strings.HasSuffix(index.Name, "_pkey") {
continue
}
uniqueStr := ""
if index.Unique {
uniqueStr = "UNIQUE "
}
indexType := index.Type
if indexType == "" {
indexType = "btree"
}
whereClause := ""
if index.Where != "" {
whereClause = fmt.Sprintf(" WHERE %s", index.Where)
}
stmt := fmt.Sprintf("CREATE %sINDEX IF NOT EXISTS %s ON %s.%s USING %s (%s)%s",
uniqueStr, index.Name, schema.SQLName(), table.SQLName(), indexType, strings.Join(index.Columns, ", "), whereClause)
statements = append(statements, stmt)
}
}
// Phase 6: Foreign keys
for _, table := range schema.Tables {
for _, constraint := range table.Constraints {
if constraint.Type != models.ForeignKeyConstraint {
continue
}
refSchema := constraint.ReferencedSchema
if refSchema == "" {
refSchema = schema.Name
}
onDelete := constraint.OnDelete
if onDelete == "" {
onDelete = "NO ACTION"
}
onUpdate := constraint.OnUpdate
if onUpdate == "" {
onUpdate = "NO ACTION"
}
stmt := fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s(%s) ON DELETE %s ON UPDATE %s",
schema.SQLName(), table.SQLName(), constraint.Name,
strings.Join(constraint.Columns, ", "),
strings.ToLower(refSchema), strings.ToLower(constraint.ReferencedTable),
strings.Join(constraint.ReferencedColumns, ", "),
onDelete, onUpdate)
statements = append(statements, stmt)
}
}
// Phase 7: Comments
for _, table := range schema.Tables {
if table.Comment != "" {
stmt := fmt.Sprintf("COMMENT ON TABLE %s.%s IS '%s'",
schema.SQLName(), table.SQLName(), escapeQuote(table.Comment))
statements = append(statements, stmt)
}
for _, column := range table.Columns {
if column.Comment != "" {
stmt := fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'",
schema.SQLName(), table.SQLName(), column.SQLName(), escapeQuote(column.Comment))
statements = append(statements, stmt)
}
}
}
return statements, nil
}
// generateCreateTableStatement generates CREATE TABLE statement
func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *models.Table) ([]string, error) {
statements := []string{}
// Sort columns by sequence or name
columns := make([]*models.Column, 0, len(table.Columns))
for _, col := range table.Columns {
columns = append(columns, col)
}
sort.Slice(columns, func(i, j int) bool {
if columns[i].Sequence != columns[j].Sequence {
return columns[i].Sequence < columns[j].Sequence
}
return columns[i].Name < columns[j].Name
})
columnDefs := []string{}
for _, col := range columns {
def := w.generateColumnDefinition(col)
columnDefs = append(columnDefs, " "+def)
}
stmt := fmt.Sprintf("CREATE TABLE %s.%s (\n%s\n)",
schema.SQLName(), table.SQLName(), strings.Join(columnDefs, ",\n"))
statements = append(statements, stmt)
return statements, nil
}
// generateColumnDefinition generates column definition
func (w *Writer) generateColumnDefinition(col *models.Column) string {
parts := []string{col.SQLName()}
// Type with length/precision
typeStr := col.Type
if col.Length > 0 && col.Precision == 0 {
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
} else if col.Precision > 0 {
if col.Scale > 0 {
typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
} else {
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
}
}
parts = append(parts, typeStr)
// NOT NULL
if col.NotNull {
parts = append(parts, "NOT NULL")
}
// DEFAULT
if col.Default != nil {
switch v := col.Default.(type) {
case string:
if strings.HasPrefix(v, "nextval") || strings.HasPrefix(v, "CURRENT_") || strings.Contains(v, "()") {
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
} else if v == "true" || v == "false" {
parts = append(parts, fmt.Sprintf("DEFAULT %s", v))
} else {
parts = append(parts, fmt.Sprintf("DEFAULT '%s'", escapeQuote(v)))
}
case bool:
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
default:
parts = append(parts, fmt.Sprintf("DEFAULT %v", v))
}
}
return strings.Join(parts, " ")
}
// WriteSchema writes a single schema and all its tables
func (w *Writer) WriteSchema(schema *models.Schema) error {
if w.writer == nil {
@@ -494,3 +722,26 @@ func isIntegerType(colType string) bool {
func escapeQuote(s string) string {
return strings.ReplaceAll(s, "'", "''")
}
// extractSequenceName extracts sequence name from nextval() expression
// Example: "nextval('public.users_id_seq'::regclass)" returns "users_id_seq"
func extractSequenceName(defaultExpr string) string {
// Look for nextval('schema.sequence_name'::regclass) pattern
start := strings.Index(defaultExpr, "'")
if start == -1 {
return ""
}
end := strings.Index(defaultExpr[start+1:], "'")
if end == -1 {
return ""
}
fullName := defaultExpr[start+1 : start+1+end]
// Remove schema prefix if present
parts := strings.Split(fullName, ".")
if len(parts) > 1 {
return parts[len(parts)-1]
}
return fullName
}

View File

@@ -112,7 +112,7 @@ func TestWriter_WriteDatabase_WithRelationships(t *testing.T) {
usersTable.Columns["email"] = emailCol
// Add index
emailIdx := models.InitIndex("idx_users_email")
emailIdx := models.InitIndex("idx_users_email", "users", "public")
emailIdx.Columns = []string{"email"}
emailIdx.Unique = true
emailIdx.Type = "btree"