package mssql

import (
	"context"
	"database/sql"
	"fmt"
	"io"
	"os"
	"sort"
	"strings"

	_ "github.com/microsoft/go-mssqldb" // MSSQL driver

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/mssql"
	"git.warky.dev/wdevs/relspecgo/pkg/writers"
)

// defaultSchema is SQL Server's conventional default schema. It is never
// created explicitly, and it is the schema assumed for unqualified objects
// when FlattenSchema is enabled (NOTE: assumes the connecting user's default
// schema is dbo).
const defaultSchema = "dbo"

// Writer implements the Writer interface for MSSQL SQL output.
type Writer struct {
	options *writers.WriterOptions
	writer  io.Writer
}

// NewWriter creates a new MSSQL SQL writer.
func NewWriter(options *writers.WriterOptions) *Writer {
	return &Writer{
		options: options,
	}
}

// tablePhase is one ordered step of schema generation: a section title and a
// builder that renders that step's statements for a single table.
type tablePhase struct {
	title string
	build func(schema *models.Schema, table *models.Table) []string
}

// tablePhases lists the per-table generation steps in order: tables first,
// then keys, indexes and constraints, foreign keys once every table of the
// schema has been created, and descriptions last. Both WriteSchema (script
// output) and generateSchemaStatements (direct execution) iterate this list,
// so the two output modes always produce the same statements.
func (w *Writer) tablePhases() []tablePhase {
	return []tablePhase{
		{title: "Tables", build: w.createTableStatements},
		{title: "Primary keys", build: w.primaryKeyStatements},
		{title: "Indexes", build: w.indexStatements},
		{title: "Unique constraints", build: w.uniqueConstraintStatements},
		{title: "Check constraints", build: w.checkConstraintStatements},
		{title: "Foreign keys", build: w.foreignKeyStatements},
		{title: "Comments", build: w.commentStatements},
	}
}

// quoteIdent wraps an identifier in square brackets, doubling any closing
// bracket so a name such as "a]b" cannot terminate the quoted identifier early.
func quoteIdent(name string) string {
	return "[" + strings.ReplaceAll(name, "]", "]]") + "]"
}

// quoteIdents applies quoteIdent to every name, preserving order.
func quoteIdents(names []string) []string {
	quoted := make([]string, 0, len(names))
	for _, name := range names {
		quoted = append(quoted, quoteIdent(name))
	}
	return quoted
}

// unicodeLiteral renders s as an N'...' string literal so non-ASCII text is
// not narrowed to the database code page.
func unicodeLiteral(s string) string {
	return "N'" + escapeQuote(s) + "'"
}

// nameOr returns name, or fallback when name is empty.
func nameOr(name, fallback string) string {
	if name != "" {
		return name
	}
	return fallback
}

// qualTable returns a schema-qualified name using bracket notation. When
// FlattenSchema is set, only the table name is emitted.
func (w *Writer) qualTable(schema, name string) string {
	if w.options.FlattenSchema {
		return quoteIdent(name)
	}
	return quoteIdent(schema) + "." + quoteIdent(name)
}

// WriteDatabase writes the entire database schema as SQL.
//
// If options.Metadata["connection_string"] is a non-empty string, the
// statements are executed against that database instead of being written.
// Otherwise output goes to an already-set writer (used by tests), to
// options.OutputPath, or to stdout, in that order of preference.
func (w *Writer) WriteDatabase(db *models.Database) (err error) {
	if connString, ok := w.options.Metadata["connection_string"].(string); ok && connString != "" {
		return w.executeDatabaseSQL(db, connString)
	}

	switch {
	case w.writer != nil:
		// An injected writer takes precedence.
	case w.options.OutputPath != "":
		file, createErr := os.Create(w.options.OutputPath)
		if createErr != nil {
			return fmt.Errorf("failed to create output file: %w", createErr)
		}
		w.writer = file
		defer func() {
			// The file belongs to this call: forget it so a later call does not
			// write to a closed handle, and report a failed close if nothing
			// else failed first.
			w.writer = nil
			if closeErr := file.Close(); closeErr != nil && err == nil {
				err = fmt.Errorf("failed to close output file: %w", closeErr)
			}
		}()
	default:
		w.writer = os.Stdout
	}

	header := fmt.Sprintf("-- MSSQL Database Schema\n-- Database: %s\n-- Generated by RelSpec\n\n", db.Name)
	if _, writeErr := io.WriteString(w.writer, header); writeErr != nil {
		return fmt.Errorf("failed to write header: %w", writeErr)
	}

	for _, schema := range db.Schemas {
		if schemaErr := w.WriteSchema(schema); schemaErr != nil {
			return fmt.Errorf("failed to write schema %s: %w", schema.Name, schemaErr)
		}
	}
	return nil
}

// WriteSchema writes a single schema and all its tables, phase by phase (see
// tablePhases), each phase preceded by a "-- <phase> for schema: <name>"
// comment. Defaults to stdout when no writer is set. Returns the first write
// error encountered.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	if w.writer == nil {
		w.writer = os.Stdout
	}

	if stmt, ok := w.createSchemaStatement(schema); ok {
		if _, err := fmt.Fprintf(w.writer, "-- Schema: %s\n%s\n\n", schema.Name, stmt); err != nil {
			return fmt.Errorf("failed to write schema creation: %w", err)
		}
	}

	for _, phase := range w.tablePhases() {
		if _, err := fmt.Fprintf(w.writer, "-- %s for schema: %s\n", phase.title, schema.Name); err != nil {
			return fmt.Errorf("failed to write %s header: %w", strings.ToLower(phase.title), err)
		}
		for _, table := range schema.Tables {
			if err := w.emit(phase.build(schema, table)); err != nil {
				return fmt.Errorf("failed to write %s for table %s: %w", strings.ToLower(phase.title), table.Name, err)
			}
		}
	}
	return nil
}

// WriteTable writes a single table with all its elements by wrapping it in a
// temporary schema. A table without a schema is placed in dbo rather than
// producing an empty "[]" qualifier.
func (w *Writer) WriteTable(table *models.Table) error {
	schemaName := table.Schema
	if schemaName == "" {
		schemaName = defaultSchema
	}
	schema := models.InitSchema(schemaName)
	schema.Tables = append(schema.Tables, table)
	return w.WriteSchema(schema)
}

// emit writes each statement followed by a blank line, the layout used for
// every statement in the generated script.
func (w *Writer) emit(statements []string) error {
	for _, stmt := range statements {
		if _, err := fmt.Fprintf(w.writer, "%s\n\n", stmt); err != nil {
			return err
		}
	}
	return nil
}

// createSchemaStatement returns the CREATE SCHEMA statement for schema, or
// false when none is needed (the dbo schema, or FlattenSchema mode).
func (w *Writer) createSchemaStatement(schema *models.Schema) (string, bool) {
	if strings.EqualFold(schema.Name, defaultSchema) || w.options.FlattenSchema {
		return "", false
	}
	return fmt.Sprintf("CREATE SCHEMA %s;", quoteIdent(schema.Name)), true
}

// writeCreateTable writes the CREATE TABLE statement for table.
func (w *Writer) writeCreateTable(schema *models.Schema, table *models.Table) error {
	return w.emit(w.createTableStatements(schema, table))
}

// createTableStatements builds the CREATE TABLE statement, with columns
// ordered by name (see getSortedColumns).
func (w *Writer) createTableStatements(schema *models.Schema, table *models.Table) []string {
	columns := getSortedColumns(table.Columns)
	defs := make([]string, 0, len(columns))
	for _, col := range columns {
		defs = append(defs, " "+w.generateColumnDefinition(col))
	}
	return []string{fmt.Sprintf("CREATE TABLE %s (\n%s\n);",
		w.qualTable(schema.Name, table.Name), strings.Join(defs, ",\n"))}
}

// generateColumnDefinition generates an MSSQL column definition:
// name, type, optional IDENTITY(1,1), optional NOT NULL, optional DEFAULT.
func (w *Writer) generateColumnDefinition(col *models.Column) string {
	parts := []string{quoteIdent(col.Name), columnType(col)}
	if col.AutoIncrement {
		parts = append(parts, "IDENTITY(1,1)")
	}
	if col.NotNull {
		parts = append(parts, "NOT NULL")
	}
	if clause, ok := defaultClause(col.Default); ok {
		parts = append(parts, clause)
	}
	return strings.Join(parts, " ")
}

// columnType maps the column's canonical type to MSSQL and applies its
// length or precision/scale.
//
// Precision (with optional scale) wins when set. Otherwise a length is applied
// to character types only. SQL Server caps NVARCHAR/NCHAR at 4000 and
// VARCHAR/CHAR at 8000: longer variable-length columns become (MAX), and
// over-long fixed-length columns keep the converter's default type.
func columnType(col *models.Column) string {
	baseType := mssql.ConvertCanonicalToMSSQL(col.Type)
	baseName := strings.Split(baseType, "(")[0]

	switch {
	case col.Precision > 0:
		if col.Scale > 0 {
			return fmt.Sprintf("%s(%d,%d)", baseName, col.Precision, col.Scale)
		}
		return fmt.Sprintf("%s(%d)", baseName, col.Precision)
	case col.Length > 0:
		switch {
		case strings.HasPrefix(baseType, "NVARCHAR"):
			if col.Length > 4000 {
				return baseName + "(MAX)"
			}
		case strings.HasPrefix(baseType, "NCHAR"):
			if col.Length > 4000 {
				return baseType
			}
		case strings.HasPrefix(baseType, "VARCHAR"):
			if col.Length > 8000 {
				return baseName + "(MAX)"
			}
		case strings.HasPrefix(baseType, "CHAR"):
			if col.Length > 8000 {
				return baseType
			}
		default:
			// Length is meaningless for non-character types.
			return baseType
		}
		return fmt.Sprintf("%s(%d)", baseName, col.Length)
	}
	return baseType
}

// defaultClause renders a column default as a DEFAULT clause. It returns false
// for nil and for value types it does not know how to render.
func defaultClause(value interface{}) (string, bool) {
	switch v := value.(type) {
	case nil:
		return "", false
	case string:
		return "DEFAULT " + stringDefault(v), true
	case bool:
		if v {
			return "DEFAULT 1", true
		}
		return "DEFAULT 0", true
	case int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64:
		// Decoded models commonly carry float64 numbers; these were previously dropped.
		return fmt.Sprintf("DEFAULT %v", v), true
	}
	return "", false
}

// stringDefault renders a string default. Server-side expressions and NULL are
// passed through unquoted, booleans become BIT literals, and anything else is
// emitted as a quoted string literal. Backticks are stripped first.
func stringDefault(raw string) string {
	clean := stripBackticks(raw)
	upper := strings.ToUpper(strings.TrimSpace(clean))

	// Prefixes of SQL Server functions that are valid as column defaults.
	serverFuncs := []string{
		"GETDATE", "GETUTCDATE", "SYSDATETIME", "SYSUTCDATETIME",
		"CURRENT_", "NEWID(", "NEWSEQUENTIALID(",
	}
	for _, prefix := range serverFuncs {
		if strings.HasPrefix(upper, prefix) {
			return clean
		}
	}

	switch {
	case upper == "NOW()":
		// NOW() (PostgreSQL/MySQL) does not exist in T-SQL.
		return "GETDATE()"
	case upper == "NULL":
		return "NULL"
	case strings.EqualFold(clean, "true"):
		return "1"
	case strings.EqualFold(clean, "false"):
		return "0"
	}
	return "'" + escapeQuote(clean) + "'"
}

// sortedConstraintKeys returns the map keys of the table's constraints that
// satisfy keep, sorted so output is deterministic.
func sortedConstraintKeys(table *models.Table, keep func(*models.Constraint) bool) []string {
	var keys []string
	for key, constraint := range table.Constraints {
		if keep(constraint) {
			keys = append(keys, key)
		}
	}
	sort.Strings(keys)
	return keys
}

// writePrimaryKey writes the ALTER TABLE ... PRIMARY KEY statement, if any.
func (w *Writer) writePrimaryKey(schema *models.Schema, table *models.Table) error {
	return w.emit(w.primaryKeyStatements(schema, table))
}

// primaryKeyStatements builds the primary key statement.
//
// An explicit primary key constraint wins, taking the first in key order
// for determinism. If it is missing or lists no columns, columns flagged
// IsPrimaryKey are used, sorted. The name defaults to PK_<schema>_<table>.
func (w *Writer) primaryKeyStatements(schema *models.Schema, table *models.Table) []string {
	pkName := fmt.Sprintf("PK_%s_%s", schema.Name, table.Name)
	var columnNames []string

	keys := sortedConstraintKeys(table, func(c *models.Constraint) bool {
		return c.Type == models.PrimaryKeyConstraint
	})
	if len(keys) > 0 {
		pk := table.Constraints[keys[0]]
		pkName = nameOr(pk.Name, pkName)
		columnNames = quoteIdents(pk.Columns)
	}

	if len(columnNames) == 0 {
		for _, col := range table.Columns {
			if col.IsPrimaryKey {
				columnNames = append(columnNames, quoteIdent(col.Name))
			}
		}
		sort.Strings(columnNames)
	}

	if len(columnNames) == 0 {
		return nil
	}
	return []string{fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s);",
		w.qualTable(schema.Name, table.Name), quoteIdent(pkName), strings.Join(columnNames, ", "))}
}

// writeIndexes writes CREATE INDEX statements.
func (w *Writer) writeIndexes(schema *models.Schema, table *models.Table) error {
	return w.emit(w.indexStatements(schema, table))
}

// indexStatements builds CREATE [UNIQUE] INDEX statements in index-key order.
// Indexes whose name starts with "pk_" (any case) are skipped, as are indexes
// without columns. An index without a Name uses its map key.
func (w *Writer) indexStatements(schema *models.Schema, table *models.Table) []string {
	keys := make([]string, 0, len(table.Indexes))
	for key := range table.Indexes {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var statements []string
	for _, key := range keys {
		index := table.Indexes[key]
		name := nameOr(index.Name, key)
		if strings.HasPrefix(strings.ToLower(name), "pk_") {
			continue
		}
		columns := quoteIdents(index.Columns)
		if len(columns) == 0 {
			continue
		}
		unique := ""
		if index.Unique {
			unique = "UNIQUE "
		}
		statements = append(statements, fmt.Sprintf("CREATE %sINDEX %s ON %s (%s);",
			unique, quoteIdent(name), w.qualTable(schema.Name, table.Name), strings.Join(columns, ", ")))
	}
	return statements
}

// writeUniqueConstraints writes ALTER TABLE statements for unique constraints.
func (w *Writer) writeUniqueConstraints(schema *models.Schema, table *models.Table) error {
	return w.emit(w.uniqueConstraintStatements(schema, table))
}

// uniqueConstraintStatements builds ALTER TABLE ... UNIQUE statements in
// constraint-key order, skipping constraints without columns.
func (w *Writer) uniqueConstraintStatements(schema *models.Schema, table *models.Table) []string {
	keys := sortedConstraintKeys(table, func(c *models.Constraint) bool {
		return c.Type == models.UniqueConstraint
	})

	var statements []string
	for _, key := range keys {
		constraint := table.Constraints[key]
		columns := quoteIdents(constraint.Columns)
		if len(columns) == 0 {
			continue
		}
		statements = append(statements, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s);",
			w.qualTable(schema.Name, table.Name), quoteIdent(nameOr(constraint.Name, key)), strings.Join(columns, ", ")))
	}
	return statements
}

// writeCheckConstraints writes ALTER TABLE statements for check constraints.
func (w *Writer) writeCheckConstraints(schema *models.Schema, table *models.Table) error {
	return w.emit(w.checkConstraintStatements(schema, table))
}

// checkConstraintStatements builds ALTER TABLE ... CHECK statements in
// constraint-key order. The expression is emitted verbatim; constraints with
// an empty expression are skipped.
func (w *Writer) checkConstraintStatements(schema *models.Schema, table *models.Table) []string {
	keys := sortedConstraintKeys(table, func(c *models.Constraint) bool {
		return c.Type == models.CheckConstraint
	})

	var statements []string
	for _, key := range keys {
		constraint := table.Constraints[key]
		if constraint.Expression == "" {
			continue
		}
		statements = append(statements, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s);",
			w.qualTable(schema.Name, table.Name), quoteIdent(nameOr(constraint.Name, key)), constraint.Expression))
	}
	return statements
}

// writeForeignKeys writes ALTER TABLE statements for foreign keys.
func (w *Writer) writeForeignKeys(schema *models.Schema, table *models.Table) error {
	return w.emit(w.foreignKeyStatements(schema, table))
}

// foreignKeyStatements builds ALTER TABLE ... FOREIGN KEY statements in
// constraint-key order. Constraints missing source or target columns are
// skipped. An empty ReferencedSchema means the table's own schema.
func (w *Writer) foreignKeyStatements(schema *models.Schema, table *models.Table) []string {
	keys := sortedConstraintKeys(table, func(c *models.Constraint) bool {
		return c.Type == models.ForeignKeyConstraint
	})

	var statements []string
	for _, key := range keys {
		constraint := table.Constraints[key]
		sourceColumns := quoteIdents(constraint.Columns)
		targetColumns := quoteIdents(constraint.ReferencedColumns)
		if len(sourceColumns) == 0 || len(targetColumns) == 0 {
			continue
		}

		refSchema := constraint.ReferencedSchema
		if refSchema == "" {
			refSchema = schema.Name
		}

		statements = append(statements, fmt.Sprintf(
			"ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s)\n REFERENCES %s (%s)\n ON DELETE %s ON UPDATE %s;",
			w.qualTable(schema.Name, table.Name), quoteIdent(nameOr(constraint.Name, key)), strings.Join(sourceColumns, ", "),
			w.qualTable(refSchema, constraint.ReferencedTable), strings.Join(targetColumns, ", "),
			referentialAction(constraint.OnDelete), referentialAction(constraint.OnUpdate)))
	}
	return statements
}

// referentialAction normalizes an ON DELETE / ON UPDATE action for T-SQL.
//
// Case, underscores and extra whitespace are normalized ("set_null" becomes
// "SET NULL"). RESTRICT is not valid T-SQL; SQL Server has no deferred
// constraints, so NO ACTION gives the same immediate rejection. An empty
// action also becomes NO ACTION.
func referentialAction(action string) string {
	normalized := strings.Join(strings.Fields(strings.ToUpper(strings.ReplaceAll(action, "_", " "))), " ")
	switch normalized {
	case "", "RESTRICT":
		return "NO ACTION"
	default:
		return normalized
	}
}

// writeComments writes EXEC sp_addextendedproperty statements for table and
// column descriptions.
func (w *Writer) writeComments(schema *models.Schema, table *models.Table) error {
	return w.emit(w.commentStatements(schema, table))
}

// commentStatements builds MS_Description extended-property statements for
// the table description and each column description, in column-name order.
// In FlattenSchema mode tables are created unqualified, so the properties
// target the default schema instead of the model's schema name.
func (w *Writer) commentStatements(schema *models.Schema, table *models.Table) []string {
	schemaName := schema.Name
	if w.options.FlattenSchema {
		schemaName = defaultSchema
	}

	var statements []string
	if table.Description != "" {
		statements = append(statements, extendedPropertyStatement(table.Description, schemaName, table.Name, ""))
	}
	for _, col := range getSortedColumns(table.Columns) {
		if col.Description != "" {
			statements = append(statements, extendedPropertyStatement(col.Description, schemaName, table.Name, col.Name))
		}
	}
	return statements
}

// extendedPropertyStatement renders one sp_addextendedproperty call that sets
// MS_Description on a table, or on a column when columnName is non-empty.
func extendedPropertyStatement(description, schemaName, tableName, columnName string) string {
	var b strings.Builder
	b.WriteString("EXEC sp_addextendedproperty\n")
	b.WriteString(" @name = 'MS_Description',\n")
	fmt.Fprintf(&b, " @value = %s,\n", unicodeLiteral(description))
	fmt.Fprintf(&b, " @level0type = 'SCHEMA', @level0name = %s,\n", unicodeLiteral(schemaName))
	if columnName == "" {
		fmt.Fprintf(&b, " @level1type = 'TABLE', @level1name = %s;", unicodeLiteral(tableName))
		return b.String()
	}
	fmt.Fprintf(&b, " @level1type = 'TABLE', @level1name = %s,\n", unicodeLiteral(tableName))
	fmt.Fprintf(&b, " @level2type = 'COLUMN', @level2name = %s;", unicodeLiteral(columnName))
	return b.String()
}

// executeDatabaseSQL executes the generated statements directly on an MSSQL
// database.
//
// Execution is best-effort: a failing statement is reported on stderr and
// execution continues. Only connection failures are returned as errors.
func (w *Writer) executeDatabaseSQL(db *models.Database, connString string) error {
	statements := []string{
		"-- MSSQL Database Schema",
		fmt.Sprintf("-- Database: %s", db.Name),
		"-- Generated by RelSpec",
	}
	for _, schema := range db.Schemas {
		if err := w.generateSchemaStatements(schema, &statements); err != nil {
			return fmt.Errorf("failed to generate statements for schema %s: %w", schema.Name, err)
		}
	}

	// Drop comments and blanks up front so progress counts are accurate.
	executable := make([]string, 0, len(statements))
	for _, stmt := range statements {
		trimmed := strings.TrimSpace(stmt)
		if trimmed == "" || strings.HasPrefix(trimmed, "--") {
			continue
		}
		executable = append(executable, stmt)
	}

	dbConn, err := sql.Open("mssql", connString)
	if err != nil {
		return fmt.Errorf("failed to connect to database: %w", err)
	}
	defer dbConn.Close()

	ctx := context.Background()
	if err = dbConn.PingContext(ctx); err != nil {
		return fmt.Errorf("failed to ping database: %w", err)
	}

	failed := 0
	for i, stmt := range executable {
		fmt.Fprintf(os.Stderr, "Executing statement %d/%d...\n", i+1, len(executable))
		if _, execErr := dbConn.ExecContext(ctx, stmt); execErr != nil {
			failed++
			fmt.Fprintf(os.Stderr, "⚠ Warning: Statement failed: %v\n", execErr)
		}
	}

	fmt.Fprintf(os.Stderr, "✓ Successfully executed %d statements\n", len(executable)-failed)
	if failed > 0 {
		fmt.Fprintf(os.Stderr, "⚠ %d statements failed\n", failed)
	}
	return nil
}

// generateSchemaStatements appends every statement for schema to statements,
// in the same order and with the same section comments as WriteSchema
// produces: schema creation, tables, keys, indexes, constraints, foreign keys
// and comments.
func (w *Writer) generateSchemaStatements(schema *models.Schema, statements *[]string) error {
	if stmt, ok := w.createSchemaStatement(schema); ok {
		*statements = append(*statements, fmt.Sprintf("-- Schema: %s", schema.Name), stmt)
	}

	for _, phase := range w.tablePhases() {
		*statements = append(*statements, fmt.Sprintf("-- %s for schema: %s", phase.title, schema.Name))
		for _, table := range schema.Tables {
			*statements = append(*statements, phase.build(schema, table)...)
		}
	}
	return nil
}

// Helper functions

// getSortedColumns returns the columns ordered by their map key (the column
// name), giving deterministic output regardless of map iteration order.
func getSortedColumns(columns map[string]*models.Column) []*models.Column {
	names := make([]string, 0, len(columns))
	for name := range columns {
		names = append(names, name)
	}
	sort.Strings(names)

	sorted := make([]*models.Column, 0, len(columns))
	for _, name := range names {
		sorted = append(sorted, columns[name])
	}
	return sorted
}

// escapeQuote escapes single quotes in strings for SQL string literals.
func escapeQuote(s string) string {
	return strings.ReplaceAll(s, "'", "''")
}

// stripBackticks removes backticks (MySQL identifier quoting) from SQL
// expressions.
func stripBackticks(s string) string {
	return strings.ReplaceAll(s, "`", "")
}