feat(mssql): add MSSQL writer for generating DDL from database schema
All checks were successful
All checks were successful
- Implement MSSQL writer to generate SQL scripts for creating schemas, tables, and constraints. - Support for identity columns, indexes, and extended properties. - Add tests for column definitions, table creation, primary keys, foreign keys, and comments. - Include testing guide and sample schema for integration tests.
This commit is contained in:
91
pkg/readers/mssql/README.md
Normal file
91
pkg/readers/mssql/README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# MSSQL Reader
|
||||
|
||||
Reads database schema from Microsoft SQL Server databases using a live connection.
|
||||
|
||||
## Features
|
||||
|
||||
- **Live Connection**: Connects to MSSQL databases using the Microsoft ODBC driver
|
||||
- **Multi-Schema Support**: Reads multiple schemas with full support for user-defined schemas
|
||||
- **Comprehensive Metadata**: Reads tables, columns, constraints, indexes, and extended properties
|
||||
- **Type Mapping**: Converts MSSQL types to canonical types for cross-database compatibility
|
||||
- **Extended Properties**: Extracts table and column descriptions from MS_Description
|
||||
- **Identity Columns**: Maps IDENTITY columns to AutoIncrement
|
||||
- **Relationships**: Derives relationships from foreign key constraints
|
||||
|
||||
## Connection String Format
|
||||
|
||||
```
|
||||
sqlserver://[user[:password]@][host][:port][?query]
|
||||
```
|
||||
|
||||
Examples:
|
||||
```
|
||||
sqlserver://sa:password@localhost/dbname
|
||||
sqlserver://user:pass@192.168.1.100:1433/production
|
||||
sqlserver://localhost/testdb?encrypt=disable
|
||||
```
|
||||
|
||||
## Supported Constraints
|
||||
|
||||
- Primary Keys
|
||||
- Foreign Keys (with ON DELETE and ON UPDATE actions)
|
||||
- Unique Constraints
|
||||
- Check Constraints
|
||||
|
||||
## Type Mappings
|
||||
|
||||
| MSSQL Type | Canonical Type |
|
||||
|------------|----------------|
|
||||
| INT | int |
|
||||
| BIGINT | int64 |
|
||||
| SMALLINT | int16 |
|
||||
| TINYINT | int8 |
|
||||
| BIT | bool |
|
||||
| REAL | float32 |
|
||||
| FLOAT | float64 |
|
||||
| NUMERIC, DECIMAL | decimal |
|
||||
| NVARCHAR, VARCHAR | string |
|
||||
| DATETIME2 | timestamp |
|
||||
| DATETIMEOFFSET | timestamptz |
|
||||
| UNIQUEIDENTIFIER | uuid |
|
||||
| VARBINARY | bytea |
|
||||
| DATE | date |
|
||||
| TIME | time |
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/readers/mssql"
|
||||
import "git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
|
||||
reader := mssql.NewReader(&readers.ReaderOptions{
|
||||
ConnectionString: "sqlserver://sa:password@localhost/mydb",
|
||||
})
|
||||
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Process schema...
|
||||
for _, schema := range db.Schemas {
|
||||
fmt.Printf("Schema: %s\n", schema.Name)
|
||||
for _, table := range schema.Tables {
|
||||
fmt.Printf(" Table: %s\n", table.Name)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run tests with:
|
||||
```bash
|
||||
go test ./pkg/readers/mssql/...
|
||||
```
|
||||
|
||||
For integration testing with a live MSSQL database:
|
||||
```bash
|
||||
docker-compose up -d mssql
|
||||
go test -tags=integration ./pkg/readers/mssql/...
|
||||
docker-compose down
|
||||
```
|
||||
416
pkg/readers/mssql/queries.go
Normal file
416
pkg/readers/mssql/queries.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// querySchemas retrieves all user-defined schemas from the database
|
||||
func (r *Reader) querySchemas() ([]*models.Schema, error) {
|
||||
query := `
|
||||
SELECT s.name, ISNULL(ep.value, '') as description
|
||||
FROM sys.schemas s
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = s.schema_id
|
||||
AND ep.minor_id = 0
|
||||
AND ep.class = 3
|
||||
AND ep.name = 'MS_Description'
|
||||
WHERE s.name NOT IN ('dbo', 'guest', 'INFORMATION_SCHEMA', 'sys')
|
||||
ORDER BY s.name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
schemas := make([]*models.Schema, 0)
|
||||
for rows.Next() {
|
||||
var name, description string
|
||||
|
||||
if err := rows.Scan(&name, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schema := models.InitSchema(name)
|
||||
if description != "" {
|
||||
schema.Description = description
|
||||
}
|
||||
|
||||
schemas = append(schemas, schema)
|
||||
}
|
||||
|
||||
// Always include dbo schema if it has tables
|
||||
dboSchema := models.InitSchema("dbo")
|
||||
schemas = append(schemas, dboSchema)
|
||||
|
||||
return schemas, rows.Err()
|
||||
}
|
||||
|
||||
// queryTables retrieves all tables for a given schema
|
||||
func (r *Reader) queryTables(schemaName string) ([]*models.Table, error) {
|
||||
query := `
|
||||
SELECT t.table_schema, t.table_name, ISNULL(ep.value, '') as description
|
||||
FROM information_schema.tables t
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = OBJECT_ID(QUOTENAME(t.table_schema) + '.' + QUOTENAME(t.table_name))
|
||||
AND ep.minor_id = 0
|
||||
AND ep.class = 1
|
||||
AND ep.name = 'MS_Description'
|
||||
WHERE t.table_schema = ? AND t.table_type = 'BASE TABLE'
|
||||
ORDER BY t.table_name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tables := make([]*models.Table, 0)
|
||||
for rows.Next() {
|
||||
var schema, tableName, description string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
table := models.InitTable(tableName, schema)
|
||||
if description != "" {
|
||||
table.Description = description
|
||||
}
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// queryColumns retrieves all columns for tables in a schema
|
||||
// Returns map[schema.table]map[columnName]*Column
|
||||
func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.Column, error) {
|
||||
query := `
|
||||
SELECT
|
||||
c.table_schema,
|
||||
c.table_name,
|
||||
c.column_name,
|
||||
c.ordinal_position,
|
||||
c.column_default,
|
||||
c.is_nullable,
|
||||
c.data_type,
|
||||
c.character_maximum_length,
|
||||
c.numeric_precision,
|
||||
c.numeric_scale,
|
||||
ISNULL(ep.value, '') as description,
|
||||
COLUMNPROPERTY(OBJECT_ID(QUOTENAME(c.table_schema) + '.' + QUOTENAME(c.table_name)), c.column_name, 'IsIdentity') as is_identity
|
||||
FROM information_schema.columns c
|
||||
LEFT JOIN sys.extended_properties ep
|
||||
ON ep.major_id = OBJECT_ID(QUOTENAME(c.table_schema) + '.' + QUOTENAME(c.table_name))
|
||||
AND ep.minor_id = COLUMNPROPERTY(OBJECT_ID(QUOTENAME(c.table_schema) + '.' + QUOTENAME(c.table_name)), c.column_name, 'ColumnId')
|
||||
AND ep.class = 1
|
||||
AND ep.name = 'MS_Description'
|
||||
WHERE c.table_schema = ?
|
||||
ORDER BY c.table_schema, c.table_name, c.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columnsMap := make(map[string]map[string]*models.Column)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, columnName, isNullable, dataType, description string
|
||||
var ordinalPosition int
|
||||
var columnDefault, charMaxLength, numPrecision, numScale, isIdentity *int
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &description, &isIdentity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
column := models.InitColumn(columnName, tableName, schema)
|
||||
column.Type = r.mapDataType(dataType)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
if description != "" {
|
||||
column.Description = description
|
||||
}
|
||||
|
||||
// Check if this is an identity column (auto-increment)
|
||||
if isIdentity != nil && *isIdentity == 1 {
|
||||
column.AutoIncrement = true
|
||||
}
|
||||
|
||||
if charMaxLength != nil && *charMaxLength > 0 {
|
||||
column.Length = *charMaxLength
|
||||
}
|
||||
|
||||
if numPrecision != nil && *numPrecision > 0 {
|
||||
column.Precision = *numPrecision
|
||||
}
|
||||
|
||||
if numScale != nil && *numScale > 0 {
|
||||
column.Scale = *numScale
|
||||
}
|
||||
|
||||
// Create table key
|
||||
tableKey := schema + "." + tableName
|
||||
if columnsMap[tableKey] == nil {
|
||||
columnsMap[tableKey] = make(map[string]*models.Column)
|
||||
}
|
||||
columnsMap[tableKey][columnName] = column
|
||||
}
|
||||
|
||||
return columnsMap, rows.Err()
|
||||
}
|
||||
|
||||
// queryPrimaryKeys retrieves all primary key constraints for a schema
|
||||
// Returns map[schema.table]*Constraint
|
||||
func (r *Reader) queryPrimaryKeys(schemaName string) (map[string]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
i.name as constraint_name,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) as columns
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.is_primary_key = 1
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
|
||||
INNER JOIN sys.columns c ON t.object_id = c.object_id AND ic.column_id = c.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, i.name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
primaryKeys := make(map[string]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, columnsStr string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &columnsStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.PrimaryKeyConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
primaryKeys[tableKey] = constraint
|
||||
}
|
||||
|
||||
return primaryKeys, rows.Err()
|
||||
}
|
||||
|
||||
// queryForeignKeys retrieves all foreign key constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryForeignKeys(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
fk.name as constraint_name,
|
||||
rs.name as referenced_schema,
|
||||
rt.name as referenced_table,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY fkc.constraint_column_id) as columns,
|
||||
STRING_AGG(rc.name, ',') WITHIN GROUP (ORDER BY fkc.constraint_column_id) as referenced_columns,
|
||||
fk.delete_referential_action_desc,
|
||||
fk.update_referential_action_desc
|
||||
FROM sys.foreign_keys fk
|
||||
INNER JOIN sys.tables t ON fk.parent_object_id = t.object_id
|
||||
INNER JOIN sys.tables rt ON fk.referenced_object_id = rt.object_id
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.schemas rs ON rt.schema_id = rs.schema_id
|
||||
INNER JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id
|
||||
INNER JOIN sys.columns c ON fkc.parent_object_id = c.object_id AND fkc.parent_column_id = c.column_id
|
||||
INNER JOIN sys.columns rc ON fkc.referenced_object_id = rc.object_id AND fkc.referenced_column_id = rc.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, fk.name, rs.name, rt.name, fk.delete_referential_action_desc, fk.update_referential_action_desc
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
foreignKeys := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, refSchema, refTable, columnsStr, refColumnsStr, deleteAction, updateAction string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &refSchema, &refTable, &columnsStr, &refColumnsStr, &deleteAction, &updateAction); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
refColumns := strings.Split(refColumnsStr, ",")
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.ForeignKeyConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
constraint.ReferencedSchema = refSchema
|
||||
constraint.ReferencedTable = refTable
|
||||
constraint.ReferencedColumns = refColumns
|
||||
constraint.OnDelete = strings.ToUpper(deleteAction)
|
||||
constraint.OnUpdate = strings.ToUpper(updateAction)
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
foreignKeys[tableKey] = append(foreignKeys[tableKey], constraint)
|
||||
}
|
||||
|
||||
return foreignKeys, rows.Err()
|
||||
}
|
||||
|
||||
// queryUniqueConstraints retrieves all unique constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryUniqueConstraints(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
i.name as constraint_name,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) as columns
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.is_unique = 1 AND i.is_primary_key = 0
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
|
||||
INNER JOIN sys.columns c ON t.object_id = c.object_id AND ic.column_id = c.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, i.name
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
uniqueConstraints := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, columnsStr string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &columnsStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.UniqueConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
uniqueConstraints[tableKey] = append(uniqueConstraints[tableKey], constraint)
|
||||
}
|
||||
|
||||
return uniqueConstraints, rows.Err()
|
||||
}
|
||||
|
||||
// queryCheckConstraints retrieves all check constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryCheckConstraints(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
cc.name as constraint_name,
|
||||
cc.definition
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.check_constraints cc ON t.object_id = cc.parent_object_id
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
WHERE s.name = ?
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
checkConstraints := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, definition string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &definition); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.CheckConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Expression = definition
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
checkConstraints[tableKey] = append(checkConstraints[tableKey], constraint)
|
||||
}
|
||||
|
||||
return checkConstraints, rows.Err()
|
||||
}
|
||||
|
||||
// queryIndexes retrieves all indexes for a schema
|
||||
// Returns map[schema.table][]*Index
|
||||
func (r *Reader) queryIndexes(schemaName string) (map[string][]*models.Index, error) {
|
||||
query := `
|
||||
SELECT
|
||||
s.name as schema_name,
|
||||
t.name as table_name,
|
||||
i.name as index_name,
|
||||
i.is_unique,
|
||||
STRING_AGG(c.name, ',') WITHIN GROUP (ORDER BY ic.key_ordinal) as columns
|
||||
FROM sys.tables t
|
||||
INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.is_primary_key = 0 AND i.name IS NOT NULL
|
||||
INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
|
||||
INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id
|
||||
INNER JOIN sys.columns c ON t.object_id = c.object_id AND ic.column_id = c.column_id
|
||||
WHERE s.name = ?
|
||||
GROUP BY s.name, t.name, i.name, i.is_unique
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
indexes := make(map[string][]*models.Index)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, indexName, columnsStr string
|
||||
var isUnique int
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &indexName, &isUnique, &columnsStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := strings.Split(columnsStr, ",")
|
||||
|
||||
index := models.InitIndex(indexName, tableName, schema)
|
||||
index.Columns = columns
|
||||
index.Unique = (isUnique == 1)
|
||||
index.Type = "btree" // MSSQL uses btree by default
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
indexes[tableKey] = append(indexes[tableKey], index)
|
||||
}
|
||||
|
||||
return indexes, rows.Err()
|
||||
}
|
||||
266
pkg/readers/mssql/reader.go
Normal file
266
pkg/readers/mssql/reader.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/microsoft/go-mssqldb" // MSSQL driver
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/mssql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
// Reader implements the readers.Reader interface for MSSQL databases
|
||||
type Reader struct {
|
||||
options *readers.ReaderOptions
|
||||
db *sql.DB
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewReader creates a new MSSQL reader
|
||||
func NewReader(options *readers.ReaderOptions) *Reader {
|
||||
return &Reader{
|
||||
options: options,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadDatabase reads the entire database schema from MSSQL
|
||||
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
||||
// Validate connection string
|
||||
if r.options.ConnectionString == "" {
|
||||
return nil, fmt.Errorf("connection string is required")
|
||||
}
|
||||
|
||||
// Connect to the database
|
||||
if err := r.connect(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
defer r.close()
|
||||
|
||||
// Get database name
|
||||
var dbName string
|
||||
err := r.db.QueryRowContext(r.ctx, "SELECT DB_NAME()").Scan(&dbName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get database name: %w", err)
|
||||
}
|
||||
|
||||
// Initialize database model
|
||||
db := models.InitDatabase(dbName)
|
||||
db.DatabaseType = models.MSSQLDatabaseType
|
||||
db.SourceFormat = "mssql"
|
||||
|
||||
// Get MSSQL version
|
||||
var version string
|
||||
err = r.db.QueryRowContext(r.ctx, "SELECT @@VERSION").Scan(&version)
|
||||
if err == nil {
|
||||
db.DatabaseVersion = version
|
||||
}
|
||||
|
||||
// Query all schemas
|
||||
schemas, err := r.querySchemas()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query schemas: %w", err)
|
||||
}
|
||||
|
||||
// Process each schema
|
||||
for _, schema := range schemas {
|
||||
// Query tables for this schema
|
||||
tables, err := r.queryTables(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
schema.Tables = tables
|
||||
|
||||
// Query columns for tables
|
||||
columnsMap, err := r.queryColumns(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Populate table columns
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if cols, exists := columnsMap[tableKey]; exists {
|
||||
table.Columns = cols
|
||||
}
|
||||
}
|
||||
|
||||
// Query primary keys
|
||||
primaryKeys, err := r.queryPrimaryKeys(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query primary keys for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply primary keys to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if pk, exists := primaryKeys[tableKey]; exists {
|
||||
table.Constraints[pk.Name] = pk
|
||||
// Mark columns as primary key and not null
|
||||
for _, colName := range pk.Columns {
|
||||
if col, colExists := table.Columns[colName]; colExists {
|
||||
col.IsPrimaryKey = true
|
||||
col.NotNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query foreign keys
|
||||
foreignKeys, err := r.queryForeignKeys(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query foreign keys for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply foreign keys to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if fks, exists := foreignKeys[tableKey]; exists {
|
||||
for _, fk := range fks {
|
||||
table.Constraints[fk.Name] = fk
|
||||
// Derive relationship from foreign key
|
||||
r.deriveRelationship(table, fk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query unique constraints
|
||||
uniqueConstraints, err := r.queryUniqueConstraints(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query unique constraints for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply unique constraints to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if ucs, exists := uniqueConstraints[tableKey]; exists {
|
||||
for _, uc := range ucs {
|
||||
table.Constraints[uc.Name] = uc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query check constraints
|
||||
checkConstraints, err := r.queryCheckConstraints(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query check constraints for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply check constraints to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if ccs, exists := checkConstraints[tableKey]; exists {
|
||||
for _, cc := range ccs {
|
||||
table.Constraints[cc.Name] = cc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query indexes
|
||||
indexes, err := r.queryIndexes(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query indexes for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply indexes to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if idxs, exists := indexes[tableKey]; exists {
|
||||
for _, idx := range idxs {
|
||||
table.Indexes[idx.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set RefDatabase for schema
|
||||
schema.RefDatabase = db
|
||||
|
||||
// Set RefSchema for tables
|
||||
for _, table := range schema.Tables {
|
||||
table.RefSchema = schema
|
||||
}
|
||||
|
||||
// Add schema to database
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// ReadSchema reads a single schema (returns the first schema from the database)
|
||||
func (r *Reader) ReadSchema() (*models.Schema, error) {
|
||||
db, err := r.ReadDatabase()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(db.Schemas) == 0 {
|
||||
return nil, fmt.Errorf("no schemas found in database")
|
||||
}
|
||||
return db.Schemas[0], nil
|
||||
}
|
||||
|
||||
// ReadTable reads a single table (returns the first table from the first schema)
|
||||
func (r *Reader) ReadTable() (*models.Table, error) {
|
||||
schema, err := r.ReadSchema()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(schema.Tables) == 0 {
|
||||
return nil, fmt.Errorf("no tables found in schema")
|
||||
}
|
||||
return schema.Tables[0], nil
|
||||
}
|
||||
|
||||
// connect establishes a connection to the MSSQL database
|
||||
func (r *Reader) connect() error {
|
||||
db, err := sql.Open("mssql", r.options.ConnectionString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Test connection
|
||||
if err = db.PingContext(r.ctx); err != nil {
|
||||
db.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
r.db = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// close closes the database connection
|
||||
func (r *Reader) close() {
|
||||
if r.db != nil {
|
||||
r.db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// mapDataType maps MSSQL data types to canonical types
|
||||
func (r *Reader) mapDataType(mssqlType string) string {
|
||||
return mssql.ConvertMSSQLToCanonical(mssqlType)
|
||||
}
|
||||
|
||||
// deriveRelationship creates a relationship from a foreign key constraint
|
||||
func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint) {
|
||||
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
||||
|
||||
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
||||
relationship.FromTable = table.Name
|
||||
relationship.FromSchema = table.Schema
|
||||
relationship.ToTable = fk.ReferencedTable
|
||||
relationship.ToSchema = fk.ReferencedSchema
|
||||
relationship.ForeignKey = fk.Name
|
||||
|
||||
// Store constraint actions in properties
|
||||
if fk.OnDelete != "" {
|
||||
relationship.Properties["on_delete"] = fk.OnDelete
|
||||
}
|
||||
if fk.OnUpdate != "" {
|
||||
relationship.Properties["on_update"] = fk.OnUpdate
|
||||
}
|
||||
|
||||
table.Relationships[relationshipName] = relationship
|
||||
}
|
||||
86
pkg/readers/mssql/reader_test.go
Normal file
86
pkg/readers/mssql/reader_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/mssql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestMapDataType tests MSSQL type mapping to canonical types
|
||||
func TestMapDataType(t *testing.T) {
|
||||
reader := NewReader(&readers.ReaderOptions{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mssqlType string
|
||||
expectedType string
|
||||
}{
|
||||
{"INT to int", "INT", "int"},
|
||||
{"BIGINT to int64", "BIGINT", "int64"},
|
||||
{"BIT to bool", "BIT", "bool"},
|
||||
{"NVARCHAR to string", "NVARCHAR(255)", "string"},
|
||||
{"DATETIME2 to timestamp", "DATETIME2", "timestamp"},
|
||||
{"DATETIMEOFFSET to timestamptz", "DATETIMEOFFSET", "timestamptz"},
|
||||
{"UNIQUEIDENTIFIER to uuid", "UNIQUEIDENTIFIER", "uuid"},
|
||||
{"FLOAT to float64", "FLOAT", "float64"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.mssqlType)
|
||||
assert.Equal(t, tt.expectedType, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertCanonicalToMSSQL tests canonical to MSSQL type conversion
|
||||
func TestConvertCanonicalToMSSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
canonicalType string
|
||||
expectedMSSQL string
|
||||
}{
|
||||
{"int to INT", "int", "INT"},
|
||||
{"int64 to BIGINT", "int64", "BIGINT"},
|
||||
{"bool to BIT", "bool", "BIT"},
|
||||
{"string to NVARCHAR(255)", "string", "NVARCHAR(255)"},
|
||||
{"text to NVARCHAR(MAX)", "text", "NVARCHAR(MAX)"},
|
||||
{"timestamp to DATETIME2", "timestamp", "DATETIME2"},
|
||||
{"timestamptz to DATETIMEOFFSET", "timestamptz", "DATETIMEOFFSET"},
|
||||
{"uuid to UNIQUEIDENTIFIER", "uuid", "UNIQUEIDENTIFIER"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := mssql.ConvertCanonicalToMSSQL(tt.canonicalType)
|
||||
assert.Equal(t, tt.expectedMSSQL, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertMSSQLToCanonical tests MSSQL to canonical type conversion
|
||||
func TestConvertMSSQLToCanonical(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mssqlType string
|
||||
expectedType string
|
||||
}{
|
||||
{"INT to int", "INT", "int"},
|
||||
{"BIGINT to int64", "BIGINT", "int64"},
|
||||
{"BIT to bool", "BIT", "bool"},
|
||||
{"NVARCHAR with params", "NVARCHAR(255)", "string"},
|
||||
{"DATETIME2 to timestamp", "DATETIME2", "timestamp"},
|
||||
{"DATETIMEOFFSET to timestamptz", "DATETIMEOFFSET", "timestamptz"},
|
||||
{"UNIQUEIDENTIFIER to uuid", "UNIQUEIDENTIFIER", "uuid"},
|
||||
{"VARBINARY to bytea", "VARBINARY(MAX)", "bytea"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := mssql.ConvertMSSQLToCanonical(tt.mssqlType)
|
||||
assert.Equal(t, tt.expectedType, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user