Added more examples and pgsql reader
This commit is contained in:
600
pkg/readers/pgsql/queries.go
Normal file
600
pkg/readers/pgsql/queries.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// querySchemas retrieves all non-system schemas from the database
|
||||
func (r *Reader) querySchemas() ([]*models.Schema, error) {
|
||||
query := `
|
||||
SELECT
|
||||
nspname as schema_name,
|
||||
obj_description(oid, 'pg_namespace') as description
|
||||
FROM pg_namespace
|
||||
WHERE nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND nspname NOT LIKE 'pg_temp_%'
|
||||
AND nspname NOT LIKE 'pg_toast_temp_%'
|
||||
ORDER BY nspname
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
schemas := make([]*models.Schema, 0)
|
||||
for rows.Next() {
|
||||
var name string
|
||||
var description *string
|
||||
|
||||
if err := rows.Scan(&name, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schema := models.InitSchema(name)
|
||||
if description != nil {
|
||||
schema.Description = *description
|
||||
}
|
||||
|
||||
schemas = append(schemas, schema)
|
||||
}
|
||||
|
||||
return schemas, rows.Err()
|
||||
}
|
||||
|
||||
// queryTables retrieves all tables for a given schema
|
||||
func (r *Reader) queryTables(schemaName string) ([]*models.Table, error) {
|
||||
query := `
|
||||
SELECT
|
||||
schemaname,
|
||||
tablename,
|
||||
obj_description((schemaname||'.'||tablename)::regclass, 'pg_class') as description
|
||||
FROM pg_tables
|
||||
WHERE schemaname = $1
|
||||
ORDER BY tablename
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tables := make([]*models.Table, 0)
|
||||
for rows.Next() {
|
||||
var schema, tableName string
|
||||
var description *string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
table := models.InitTable(tableName, schema)
|
||||
if description != nil {
|
||||
table.Description = *description
|
||||
}
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// queryViews retrieves all views for a given schema
|
||||
func (r *Reader) queryViews(schemaName string) ([]*models.View, error) {
|
||||
query := `
|
||||
SELECT
|
||||
schemaname,
|
||||
viewname,
|
||||
definition,
|
||||
obj_description((schemaname||'.'||viewname)::regclass, 'pg_class') as description
|
||||
FROM pg_views
|
||||
WHERE schemaname = $1
|
||||
ORDER BY viewname
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
views := make([]*models.View, 0)
|
||||
for rows.Next() {
|
||||
var schema, viewName, definition string
|
||||
var description *string
|
||||
|
||||
if err := rows.Scan(&schema, &viewName, &definition, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
view := models.InitView(viewName, schema)
|
||||
view.Definition = definition
|
||||
if description != nil {
|
||||
view.Description = *description
|
||||
}
|
||||
|
||||
views = append(views, view)
|
||||
}
|
||||
|
||||
return views, rows.Err()
|
||||
}
|
||||
|
||||
// querySequences retrieves all sequences for a given schema
|
||||
func (r *Reader) querySequences(schemaName string) ([]*models.Sequence, error) {
|
||||
query := `
|
||||
SELECT
|
||||
ps.schemaname,
|
||||
ps.sequencename,
|
||||
ps.start_value,
|
||||
ps.min_value,
|
||||
ps.max_value,
|
||||
ps.increment_by,
|
||||
ps.cycle,
|
||||
ps.cache_size,
|
||||
obj_description((ps.schemaname||'.'||ps.sequencename)::regclass, 'pg_class') as description,
|
||||
owner_table.relname as owned_by_table,
|
||||
owner_column.attname as owned_by_column
|
||||
FROM pg_sequences ps
|
||||
LEFT JOIN pg_class seq_class ON seq_class.relname = ps.sequencename
|
||||
AND seq_class.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = ps.schemaname)
|
||||
LEFT JOIN pg_depend ON pg_depend.objid = seq_class.oid AND pg_depend.deptype = 'a'
|
||||
LEFT JOIN pg_class owner_table ON pg_depend.refobjid = owner_table.oid
|
||||
LEFT JOIN pg_attribute owner_column ON pg_depend.refobjid = owner_column.attrelid
|
||||
AND pg_depend.refobjsubid = owner_column.attnum
|
||||
WHERE ps.schemaname = $1
|
||||
ORDER BY ps.sequencename
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
sequences := make([]*models.Sequence, 0)
|
||||
for rows.Next() {
|
||||
var schema, seqName string
|
||||
var startValue, minValue, maxValue, incrementBy, cacheSize int64
|
||||
var cycle bool
|
||||
var description, tableName, columnName *string
|
||||
|
||||
if err := rows.Scan(&schema, &seqName, &startValue, &minValue, &maxValue, &incrementBy, &cycle, &cacheSize, &description, &tableName, &columnName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seq := models.InitSequence(seqName, schema)
|
||||
seq.StartValue = startValue
|
||||
seq.MinValue = minValue
|
||||
seq.MaxValue = maxValue
|
||||
seq.IncrementBy = incrementBy
|
||||
seq.Cycle = cycle
|
||||
seq.CacheSize = cacheSize
|
||||
if description != nil {
|
||||
seq.Description = *description
|
||||
}
|
||||
if tableName != nil {
|
||||
seq.OwnedByTable = *tableName
|
||||
}
|
||||
if columnName != nil {
|
||||
seq.OwnedByColumn = *columnName
|
||||
}
|
||||
|
||||
sequences = append(sequences, seq)
|
||||
}
|
||||
|
||||
return sequences, rows.Err()
|
||||
}
|
||||
|
||||
// queryColumns retrieves all columns for tables and views in a schema
|
||||
// Returns map[schema.table]map[columnName]*Column
|
||||
func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.Column, error) {
|
||||
query := `
|
||||
SELECT
|
||||
c.table_schema,
|
||||
c.table_name,
|
||||
c.column_name,
|
||||
c.ordinal_position,
|
||||
c.column_default,
|
||||
c.is_nullable,
|
||||
c.data_type,
|
||||
c.character_maximum_length,
|
||||
c.numeric_precision,
|
||||
c.numeric_scale,
|
||||
c.udt_name,
|
||||
col_description((c.table_schema||'.'||c.table_name)::regclass, c.ordinal_position) as description
|
||||
FROM information_schema.columns c
|
||||
WHERE c.table_schema = $1
|
||||
ORDER BY c.table_schema, c.table_name, c.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columnsMap := make(map[string]map[string]*models.Column)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, columnName, isNullable, dataType, udtName string
|
||||
var ordinalPosition int
|
||||
var columnDefault, description *string
|
||||
var charMaxLength, numPrecision, numScale *int
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &description); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
column := models.InitColumn(columnName, tableName, schema)
|
||||
column.Type = r.mapDataType(dataType, udtName)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
if columnDefault != nil {
|
||||
// Parse default value - remove nextval for sequences
|
||||
defaultVal := *columnDefault
|
||||
if strings.HasPrefix(defaultVal, "nextval") {
|
||||
column.AutoIncrement = true
|
||||
column.Default = defaultVal
|
||||
} else {
|
||||
column.Default = defaultVal
|
||||
}
|
||||
}
|
||||
|
||||
if description != nil {
|
||||
column.Description = *description
|
||||
}
|
||||
|
||||
if charMaxLength != nil {
|
||||
column.Length = *charMaxLength
|
||||
}
|
||||
|
||||
if numPrecision != nil {
|
||||
column.Precision = *numPrecision
|
||||
}
|
||||
|
||||
if numScale != nil {
|
||||
column.Scale = *numScale
|
||||
}
|
||||
|
||||
// Create table key
|
||||
tableKey := schema + "." + tableName
|
||||
if columnsMap[tableKey] == nil {
|
||||
columnsMap[tableKey] = make(map[string]*models.Column)
|
||||
}
|
||||
columnsMap[tableKey][columnName] = column
|
||||
}
|
||||
|
||||
return columnsMap, rows.Err()
|
||||
}
|
||||
|
||||
// queryPrimaryKeys retrieves all primary key constraints for a schema
|
||||
// Returns map[schema.table]*Constraint
|
||||
func (r *Reader) queryPrimaryKeys(schemaName string) (map[string]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
tc.table_schema,
|
||||
tc.table_name,
|
||||
tc.constraint_name,
|
||||
array_agg(kcu.column_name ORDER BY kcu.ordinal_position) as columns
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
WHERE tc.constraint_type = 'PRIMARY KEY'
|
||||
AND tc.table_schema = $1
|
||||
GROUP BY tc.table_schema, tc.table_name, tc.constraint_name
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
primaryKeys := make(map[string]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName string
|
||||
var columns []string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &columns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.PrimaryKeyConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
primaryKeys[tableKey] = constraint
|
||||
}
|
||||
|
||||
return primaryKeys, rows.Err()
|
||||
}
|
||||
|
||||
// queryForeignKeys retrieves all foreign key constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryForeignKeys(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
tc.table_schema,
|
||||
tc.table_name,
|
||||
tc.constraint_name,
|
||||
kcu.table_schema as foreign_table_schema,
|
||||
kcu.table_name as foreign_table_name,
|
||||
kcu.column_name as foreign_column,
|
||||
ccu.table_schema as referenced_table_schema,
|
||||
ccu.table_name as referenced_table_name,
|
||||
ccu.column_name as referenced_column,
|
||||
rc.update_rule,
|
||||
rc.delete_rule
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
JOIN information_schema.referential_constraints rc
|
||||
ON rc.constraint_name = tc.constraint_name
|
||||
AND rc.constraint_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_schema = $1
|
||||
ORDER BY tc.table_schema, tc.table_name, tc.constraint_name, kcu.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// First pass: collect all FK data
|
||||
type fkData struct {
|
||||
schema string
|
||||
tableName string
|
||||
constraintName string
|
||||
foreignColumns []string
|
||||
referencedSchema string
|
||||
referencedTable string
|
||||
referencedColumns []string
|
||||
updateRule string
|
||||
deleteRule string
|
||||
}
|
||||
|
||||
fkMap := make(map[string]*fkData)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName string
|
||||
var foreignSchema, foreignTable, foreignColumn string
|
||||
var referencedSchema, referencedTable, referencedColumn string
|
||||
var updateRule, deleteRule string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &foreignSchema, &foreignTable, &foreignColumn, &referencedSchema, &referencedTable, &referencedColumn, &updateRule, &deleteRule); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := schema + "." + tableName + "." + constraintName
|
||||
|
||||
if _, exists := fkMap[key]; !exists {
|
||||
fkMap[key] = &fkData{
|
||||
schema: schema,
|
||||
tableName: tableName,
|
||||
constraintName: constraintName,
|
||||
foreignColumns: []string{},
|
||||
referencedSchema: referencedSchema,
|
||||
referencedTable: referencedTable,
|
||||
referencedColumns: []string{},
|
||||
updateRule: updateRule,
|
||||
deleteRule: deleteRule,
|
||||
}
|
||||
}
|
||||
|
||||
fkMap[key].foreignColumns = append(fkMap[key].foreignColumns, foreignColumn)
|
||||
fkMap[key].referencedColumns = append(fkMap[key].referencedColumns, referencedColumn)
|
||||
}
|
||||
|
||||
// Second pass: create constraints
|
||||
foreignKeys := make(map[string][]*models.Constraint)
|
||||
|
||||
for _, fk := range fkMap {
|
||||
constraint := models.InitConstraint(fk.constraintName, models.ForeignKeyConstraint)
|
||||
constraint.Schema = fk.schema
|
||||
constraint.Table = fk.tableName
|
||||
constraint.Columns = fk.foreignColumns
|
||||
constraint.ReferencedSchema = fk.referencedSchema
|
||||
constraint.ReferencedTable = fk.referencedTable
|
||||
constraint.ReferencedColumns = fk.referencedColumns
|
||||
constraint.OnUpdate = fk.updateRule
|
||||
constraint.OnDelete = fk.deleteRule
|
||||
|
||||
tableKey := fk.schema + "." + fk.tableName
|
||||
foreignKeys[tableKey] = append(foreignKeys[tableKey], constraint)
|
||||
}
|
||||
|
||||
return foreignKeys, rows.Err()
|
||||
}
|
||||
|
||||
// queryUniqueConstraints retrieves all unique constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryUniqueConstraints(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
tc.table_schema,
|
||||
tc.table_name,
|
||||
tc.constraint_name,
|
||||
array_agg(kcu.column_name ORDER BY kcu.ordinal_position) as columns
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
WHERE tc.constraint_type = 'UNIQUE'
|
||||
AND tc.table_schema = $1
|
||||
GROUP BY tc.table_schema, tc.table_name, tc.constraint_name
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
uniqueConstraints := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName string
|
||||
var columns []string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &columns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.UniqueConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Columns = columns
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
uniqueConstraints[tableKey] = append(uniqueConstraints[tableKey], constraint)
|
||||
}
|
||||
|
||||
return uniqueConstraints, rows.Err()
|
||||
}
|
||||
|
||||
// queryCheckConstraints retrieves all check constraints for a schema
|
||||
// Returns map[schema.table][]*Constraint
|
||||
func (r *Reader) queryCheckConstraints(schemaName string) (map[string][]*models.Constraint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
tc.table_schema,
|
||||
tc.table_name,
|
||||
tc.constraint_name,
|
||||
cc.check_clause
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.check_constraints cc
|
||||
ON tc.constraint_name = cc.constraint_name
|
||||
WHERE tc.constraint_type = 'CHECK'
|
||||
AND tc.table_schema = $1
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
checkConstraints := make(map[string][]*models.Constraint)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, constraintName, checkClause string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &constraintName, &checkClause); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
constraint := models.InitConstraint(constraintName, models.CheckConstraint)
|
||||
constraint.Schema = schema
|
||||
constraint.Table = tableName
|
||||
constraint.Expression = checkClause
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
checkConstraints[tableKey] = append(checkConstraints[tableKey], constraint)
|
||||
}
|
||||
|
||||
return checkConstraints, rows.Err()
|
||||
}
|
||||
|
||||
// queryIndexes retrieves all indexes for a schema
|
||||
// Returns map[schema.table][]*Index
|
||||
func (r *Reader) queryIndexes(schemaName string) (map[string][]*models.Index, error) {
|
||||
query := `
|
||||
SELECT
|
||||
schemaname,
|
||||
tablename,
|
||||
indexname,
|
||||
indexdef
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = $1
|
||||
ORDER BY schemaname, tablename, indexname
|
||||
`
|
||||
|
||||
rows, err := r.conn.Query(r.ctx, query, schemaName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
indexes := make(map[string][]*models.Index)
|
||||
|
||||
for rows.Next() {
|
||||
var schema, tableName, indexName, indexDef string
|
||||
|
||||
if err := rows.Scan(&schema, &tableName, &indexName, &indexDef); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
index, err := r.parseIndexDefinition(indexName, tableName, schema, indexDef)
|
||||
if err != nil {
|
||||
// If parsing fails, create a basic index
|
||||
index = models.InitIndex(indexName)
|
||||
index.Table = tableName
|
||||
index.Schema = schema
|
||||
}
|
||||
|
||||
tableKey := schema + "." + tableName
|
||||
indexes[tableKey] = append(indexes[tableKey], index)
|
||||
}
|
||||
|
||||
return indexes, rows.Err()
|
||||
}
|
||||
|
||||
// parseIndexDefinition parses a PostgreSQL index definition
|
||||
func (r *Reader) parseIndexDefinition(indexName, tableName, schema, indexDef string) (*models.Index, error) {
|
||||
index := models.InitIndex(indexName)
|
||||
index.Table = tableName
|
||||
index.Schema = schema
|
||||
|
||||
// Check if unique
|
||||
if strings.Contains(strings.ToUpper(indexDef), "UNIQUE") {
|
||||
index.Unique = true
|
||||
}
|
||||
|
||||
// Extract index method (USING btree, hash, gin, gist, etc.)
|
||||
usingRegex := regexp.MustCompile(`USING\s+(\w+)`)
|
||||
if matches := usingRegex.FindStringSubmatch(indexDef); len(matches) > 1 {
|
||||
index.Type = strings.ToLower(matches[1])
|
||||
} else {
|
||||
index.Type = "btree" // default
|
||||
}
|
||||
|
||||
// Extract columns - pattern: (column1, column2, ...)
|
||||
columnsRegex := regexp.MustCompile(`\(([^)]+)\)`)
|
||||
if matches := columnsRegex.FindStringSubmatch(indexDef); len(matches) > 1 {
|
||||
columnsStr := matches[1]
|
||||
// Split by comma and clean up
|
||||
columnParts := strings.Split(columnsStr, ",")
|
||||
for _, col := range columnParts {
|
||||
col = strings.TrimSpace(col)
|
||||
// Remove any ordering (ASC/DESC) or other modifiers
|
||||
col = strings.Fields(col)[0]
|
||||
// Remove parentheses if it's an expression
|
||||
if !strings.Contains(col, "(") {
|
||||
index.Columns = append(index.Columns, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract WHERE clause for partial indexes
|
||||
whereRegex := regexp.MustCompile(`WHERE\s+(.+)$`)
|
||||
if matches := whereRegex.FindStringSubmatch(indexDef); len(matches) > 1 {
|
||||
index.Where = strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
346
pkg/readers/pgsql/reader.go
Normal file
346
pkg/readers/pgsql/reader.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
// Reader implements the readers.Reader interface for PostgreSQL databases
|
||||
type Reader struct {
|
||||
options *readers.ReaderOptions
|
||||
conn *pgx.Conn
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewReader creates a new PostgreSQL reader
|
||||
func NewReader(options *readers.ReaderOptions) *Reader {
|
||||
return &Reader{
|
||||
options: options,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadDatabase reads the entire database schema from PostgreSQL
|
||||
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
||||
// Validate connection string
|
||||
if r.options.ConnectionString == "" {
|
||||
return nil, fmt.Errorf("connection string is required")
|
||||
}
|
||||
|
||||
// Connect to the database
|
||||
if err := r.connect(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
defer r.close()
|
||||
|
||||
// Get database name from connection
|
||||
var dbName string
|
||||
err := r.conn.QueryRow(r.ctx, "SELECT current_database()").Scan(&dbName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get database name: %w", err)
|
||||
}
|
||||
|
||||
// Initialize database model
|
||||
db := models.InitDatabase(dbName)
|
||||
db.DatabaseType = models.PostgresqlDatabaseType
|
||||
db.SourceFormat = "pgsql"
|
||||
|
||||
// Get PostgreSQL version
|
||||
var version string
|
||||
err = r.conn.QueryRow(r.ctx, "SELECT version()").Scan(&version)
|
||||
if err == nil {
|
||||
db.DatabaseVersion = version
|
||||
}
|
||||
|
||||
// Query all schemas
|
||||
schemas, err := r.querySchemas()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query schemas: %w", err)
|
||||
}
|
||||
|
||||
// Process each schema
|
||||
for _, schema := range schemas {
|
||||
// Query tables for this schema
|
||||
tables, err := r.queryTables(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
schema.Tables = tables
|
||||
|
||||
// Query views for this schema
|
||||
views, err := r.queryViews(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query views for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
schema.Views = views
|
||||
|
||||
// Query sequences for this schema
|
||||
sequences, err := r.querySequences(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query sequences for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
schema.Sequences = sequences
|
||||
|
||||
// Query columns for tables and views
|
||||
columnsMap, err := r.queryColumns(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Populate table columns
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if cols, exists := columnsMap[tableKey]; exists {
|
||||
table.Columns = cols
|
||||
}
|
||||
}
|
||||
|
||||
// Populate view columns
|
||||
for _, view := range schema.Views {
|
||||
viewKey := schema.Name + "." + view.Name
|
||||
if cols, exists := columnsMap[viewKey]; exists {
|
||||
view.Columns = cols
|
||||
}
|
||||
}
|
||||
|
||||
// Query primary keys
|
||||
primaryKeys, err := r.queryPrimaryKeys(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query primary keys for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply primary keys to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if pk, exists := primaryKeys[tableKey]; exists {
|
||||
table.Constraints[pk.Name] = pk
|
||||
// Mark columns as primary key and not null
|
||||
for _, colName := range pk.Columns {
|
||||
if col, colExists := table.Columns[colName]; colExists {
|
||||
col.IsPrimaryKey = true
|
||||
col.NotNull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query foreign keys
|
||||
foreignKeys, err := r.queryForeignKeys(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query foreign keys for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply foreign keys to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if fks, exists := foreignKeys[tableKey]; exists {
|
||||
for _, fk := range fks {
|
||||
table.Constraints[fk.Name] = fk
|
||||
// Derive relationship from foreign key
|
||||
r.deriveRelationship(table, fk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query unique constraints
|
||||
uniqueConstraints, err := r.queryUniqueConstraints(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query unique constraints for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply unique constraints to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if ucs, exists := uniqueConstraints[tableKey]; exists {
|
||||
for _, uc := range ucs {
|
||||
table.Constraints[uc.Name] = uc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query check constraints
|
||||
checkConstraints, err := r.queryCheckConstraints(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query check constraints for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply check constraints to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if ccs, exists := checkConstraints[tableKey]; exists {
|
||||
for _, cc := range ccs {
|
||||
table.Constraints[cc.Name] = cc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query indexes
|
||||
indexes, err := r.queryIndexes(schema.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query indexes for schema %s: %w", schema.Name, err)
|
||||
}
|
||||
|
||||
// Apply indexes to tables
|
||||
for _, table := range schema.Tables {
|
||||
tableKey := schema.Name + "." + table.Name
|
||||
if idxs, exists := indexes[tableKey]; exists {
|
||||
for _, idx := range idxs {
|
||||
table.Indexes[idx.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set RefDatabase for schema
|
||||
schema.RefDatabase = db
|
||||
|
||||
// Set RefSchema for tables and views
|
||||
for _, table := range schema.Tables {
|
||||
table.RefSchema = schema
|
||||
}
|
||||
for _, view := range schema.Views {
|
||||
view.RefSchema = schema
|
||||
}
|
||||
for _, seq := range schema.Sequences {
|
||||
seq.RefSchema = schema
|
||||
}
|
||||
|
||||
// Add schema to database
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// ReadSchema reads a single schema (returns the first schema from the database)
|
||||
func (r *Reader) ReadSchema() (*models.Schema, error) {
|
||||
db, err := r.ReadDatabase()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(db.Schemas) == 0 {
|
||||
return nil, fmt.Errorf("no schemas found in database")
|
||||
}
|
||||
return db.Schemas[0], nil
|
||||
}
|
||||
|
||||
// ReadTable reads a single table (returns the first table from the first schema)
|
||||
func (r *Reader) ReadTable() (*models.Table, error) {
|
||||
schema, err := r.ReadSchema()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(schema.Tables) == 0 {
|
||||
return nil, fmt.Errorf("no tables found in schema")
|
||||
}
|
||||
return schema.Tables[0], nil
|
||||
}
|
||||
|
||||
// connect establishes a connection to the PostgreSQL database
|
||||
func (r *Reader) connect() error {
|
||||
conn, err := pgx.Connect(r.ctx, r.options.ConnectionString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.conn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
// close closes the database connection
|
||||
func (r *Reader) close() {
|
||||
if r.conn != nil {
|
||||
r.conn.Close(r.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// mapDataType maps PostgreSQL data types to canonical types
|
||||
func (r *Reader) mapDataType(pgType, udtName string) string {
|
||||
// Map common PostgreSQL types
|
||||
typeMap := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int64",
|
||||
"smallint": "int16",
|
||||
"int": "int",
|
||||
"int2": "int16",
|
||||
"int4": "int",
|
||||
"int8": "int64",
|
||||
"serial": "int",
|
||||
"bigserial": "int64",
|
||||
"smallserial": "int16",
|
||||
"numeric": "decimal",
|
||||
"decimal": "decimal",
|
||||
"real": "float32",
|
||||
"double precision": "float64",
|
||||
"float4": "float32",
|
||||
"float8": "float64",
|
||||
"money": "decimal",
|
||||
"character varying": "string",
|
||||
"varchar": "string",
|
||||
"character": "string",
|
||||
"char": "string",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"bool": "bool",
|
||||
"date": "date",
|
||||
"time": "time",
|
||||
"time without time zone": "time",
|
||||
"time with time zone": "timetz",
|
||||
"timestamp": "timestamp",
|
||||
"timestamp without time zone": "timestamp",
|
||||
"timestamp with time zone": "timestamptz",
|
||||
"timestamptz": "timestamptz",
|
||||
"interval": "interval",
|
||||
"uuid": "uuid",
|
||||
"json": "json",
|
||||
"jsonb": "jsonb",
|
||||
"bytea": "bytea",
|
||||
"inet": "inet",
|
||||
"cidr": "cidr",
|
||||
"macaddr": "macaddr",
|
||||
"xml": "xml",
|
||||
}
|
||||
|
||||
// Try mapped type first
|
||||
if mapped, exists := typeMap[pgType]; exists {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// Use pgsql utilities if available
|
||||
if pgsql.ValidSQLType(pgType) {
|
||||
return pgsql.GetSQLType(pgType)
|
||||
}
|
||||
|
||||
// Return UDT name for custom types
|
||||
if udtName != "" {
|
||||
return udtName
|
||||
}
|
||||
|
||||
// Default to the original type
|
||||
return pgType
|
||||
}
|
||||
|
||||
// deriveRelationship creates a relationship from a foreign key constraint
|
||||
func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint) {
|
||||
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
||||
|
||||
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
||||
relationship.FromTable = fk.ReferencedTable
|
||||
relationship.FromSchema = fk.ReferencedSchema
|
||||
relationship.ToTable = table.Name
|
||||
relationship.ToSchema = table.Schema
|
||||
relationship.ForeignKey = fk.Name
|
||||
|
||||
// Store constraint actions in properties
|
||||
if fk.OnDelete != "" {
|
||||
relationship.Properties["on_delete"] = fk.OnDelete
|
||||
}
|
||||
if fk.OnUpdate != "" {
|
||||
relationship.Properties["on_update"] = fk.OnUpdate
|
||||
}
|
||||
|
||||
table.Relationships[relationshipName] = relationship
|
||||
}
|
||||
371
pkg/readers/pgsql/reader_test.go
Normal file
371
pkg/readers/pgsql/reader_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
// getTestConnectionString returns a PostgreSQL connection string from environment
|
||||
// or skips the test if not available
|
||||
func getTestConnectionString(t *testing.T) string {
|
||||
connStr := os.Getenv("RELSPEC_TEST_PG_CONN")
|
||||
if connStr == "" {
|
||||
t.Skip("Skipping PostgreSQL reader test: RELSPEC_TEST_PG_CONN environment variable not set")
|
||||
}
|
||||
return connStr
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase(t *testing.T) {
|
||||
connStr := getTestConnectionString(t)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: connStr,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read database: %v", err)
|
||||
}
|
||||
|
||||
// Verify database properties
|
||||
if db.Name == "" {
|
||||
t.Error("Database name should not be empty")
|
||||
}
|
||||
|
||||
if db.DatabaseType != models.PostgresqlDatabaseType {
|
||||
t.Errorf("Expected database type %s, got %s", models.PostgresqlDatabaseType, db.DatabaseType)
|
||||
}
|
||||
|
||||
if db.SourceFormat != "pgsql" {
|
||||
t.Errorf("Expected source format 'pgsql', got %s", db.SourceFormat)
|
||||
}
|
||||
|
||||
// Verify schemas
|
||||
if len(db.Schemas) == 0 {
|
||||
t.Error("Expected at least one schema, got none")
|
||||
}
|
||||
|
||||
// Check that system schemas are excluded
|
||||
for _, schema := range db.Schemas {
|
||||
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
|
||||
t.Errorf("System schema %s should be excluded", schema.Name)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Successfully read database '%s' with %d schemas", db.Name, len(db.Schemas))
|
||||
|
||||
// Log schema details
|
||||
for _, schema := range db.Schemas {
|
||||
t.Logf(" Schema: %s (Tables: %d, Views: %d, Sequences: %d)",
|
||||
schema.Name, len(schema.Tables), len(schema.Views), len(schema.Sequences))
|
||||
|
||||
// Verify tables have columns
|
||||
for _, table := range schema.Tables {
|
||||
if len(table.Columns) == 0 {
|
||||
t.Logf(" Warning: Table %s.%s has no columns", schema.Name, table.Name)
|
||||
} else {
|
||||
t.Logf(" Table: %s.%s (Columns: %d, Constraints: %d, Indexes: %d, Relationships: %d)",
|
||||
schema.Name, table.Name, len(table.Columns), len(table.Constraints),
|
||||
len(table.Indexes), len(table.Relationships))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify views have columns and definitions
|
||||
for _, view := range schema.Views {
|
||||
if view.Definition == "" {
|
||||
t.Errorf("View %s.%s should have a definition", schema.Name, view.Name)
|
||||
}
|
||||
t.Logf(" View: %s.%s (Columns: %d)", schema.Name, view.Name, len(view.Columns))
|
||||
}
|
||||
|
||||
// Verify sequences
|
||||
for _, seq := range schema.Sequences {
|
||||
if seq.IncrementBy == 0 {
|
||||
t.Errorf("Sequence %s.%s should have non-zero increment", schema.Name, seq.Name)
|
||||
}
|
||||
t.Logf(" Sequence: %s.%s (Start: %d, Increment: %d)", schema.Name, seq.Name, seq.StartValue, seq.IncrementBy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadSchema(t *testing.T) {
|
||||
connStr := getTestConnectionString(t)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: connStr,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
schema, err := reader.ReadSchema()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read schema: %v", err)
|
||||
}
|
||||
|
||||
if schema.Name == "" {
|
||||
t.Error("Schema name should not be empty")
|
||||
}
|
||||
|
||||
t.Logf("Successfully read schema '%s' with %d tables, %d views, %d sequences",
|
||||
schema.Name, len(schema.Tables), len(schema.Views), len(schema.Sequences))
|
||||
}
|
||||
|
||||
func TestReader_ReadTable(t *testing.T) {
|
||||
connStr := getTestConnectionString(t)
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: connStr,
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
table, err := reader.ReadTable()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read table: %v", err)
|
||||
}
|
||||
|
||||
if table.Name == "" {
|
||||
t.Error("Table name should not be empty")
|
||||
}
|
||||
|
||||
if table.Schema == "" {
|
||||
t.Error("Table schema should not be empty")
|
||||
}
|
||||
|
||||
t.Logf("Successfully read table '%s.%s' with %d columns",
|
||||
table.Schema, table.Name, len(table.Columns))
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_InvalidConnectionString(t *testing.T) {
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: "invalid connection string",
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
_, err := reader.ReadDatabase()
|
||||
if err == nil {
|
||||
t.Error("Expected error with invalid connection string, got nil")
|
||||
}
|
||||
|
||||
t.Logf("Correctly rejected invalid connection string: %v", err)
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_EmptyConnectionString(t *testing.T) {
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: "",
|
||||
}
|
||||
|
||||
reader := NewReader(options)
|
||||
_, err := reader.ReadDatabase()
|
||||
if err == nil {
|
||||
t.Error("Expected error with empty connection string, got nil")
|
||||
}
|
||||
|
||||
expectedMsg := "connection string is required"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error message '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapDataType(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
pgType string
|
||||
udtName string
|
||||
expected string
|
||||
}{
|
||||
{"integer", "int4", "int"},
|
||||
{"bigint", "int8", "int64"},
|
||||
{"smallint", "int2", "int16"},
|
||||
{"character varying", "varchar", "string"},
|
||||
{"text", "text", "string"},
|
||||
{"boolean", "bool", "bool"},
|
||||
{"timestamp without time zone", "timestamp", "timestamp"},
|
||||
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
||||
{"json", "json", "json"},
|
||||
{"jsonb", "jsonb", "jsonb"},
|
||||
{"uuid", "uuid", "uuid"},
|
||||
{"numeric", "numeric", "decimal"},
|
||||
{"real", "float4", "float32"},
|
||||
{"double precision", "float8", "float64"},
|
||||
{"date", "date", "date"},
|
||||
{"time without time zone", "time", "time"},
|
||||
{"bytea", "bytea", "bytea"},
|
||||
{"unknown_type", "custom", "custom"}, // Should return UDT name
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pgType, func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.pgType, tt.udtName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIndexDefinition(t *testing.T) {
|
||||
reader := &Reader{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
indexName string
|
||||
tableName string
|
||||
schema string
|
||||
indexDef string
|
||||
wantType string
|
||||
wantUnique bool
|
||||
wantColumns int
|
||||
}{
|
||||
{
|
||||
name: "simple btree index",
|
||||
indexName: "idx_users_email",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_email ON public.users USING btree (email)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
wantColumns: 1,
|
||||
},
|
||||
{
|
||||
name: "unique index",
|
||||
indexName: "idx_users_username",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE UNIQUE INDEX idx_users_username ON public.users USING btree (username)",
|
||||
wantType: "btree",
|
||||
wantUnique: true,
|
||||
wantColumns: 1,
|
||||
},
|
||||
{
|
||||
name: "composite index",
|
||||
indexName: "idx_users_name",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_name ON public.users USING btree (first_name, last_name)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
wantColumns: 2,
|
||||
},
|
||||
{
|
||||
name: "gin index",
|
||||
indexName: "idx_posts_tags",
|
||||
tableName: "posts",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_posts_tags ON public.posts USING gin (tags)",
|
||||
wantType: "gin",
|
||||
wantUnique: false,
|
||||
wantColumns: 1,
|
||||
},
|
||||
{
|
||||
name: "partial index with where clause",
|
||||
indexName: "idx_users_active",
|
||||
tableName: "users",
|
||||
schema: "public",
|
||||
indexDef: "CREATE INDEX idx_users_active ON public.users USING btree (id) WHERE (active = true)",
|
||||
wantType: "btree",
|
||||
wantUnique: false,
|
||||
wantColumns: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
index, err := reader.parseIndexDefinition(tt.indexName, tt.tableName, tt.schema, tt.indexDef)
|
||||
if err != nil {
|
||||
t.Fatalf("parseIndexDefinition() error = %v", err)
|
||||
}
|
||||
|
||||
if index.Name != tt.indexName {
|
||||
t.Errorf("Name = %s, want %s", index.Name, tt.indexName)
|
||||
}
|
||||
|
||||
if index.Type != tt.wantType {
|
||||
t.Errorf("Type = %s, want %s", index.Type, tt.wantType)
|
||||
}
|
||||
|
||||
if index.Unique != tt.wantUnique {
|
||||
t.Errorf("Unique = %v, want %v", index.Unique, tt.wantUnique)
|
||||
}
|
||||
|
||||
if len(index.Columns) != tt.wantColumns {
|
||||
t.Errorf("Columns count = %d, want %d (columns: %v)", len(index.Columns), tt.wantColumns, index.Columns)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveRelationship(t *testing.T) {
|
||||
table := models.InitTable("orders", "public")
|
||||
|
||||
fk := models.InitConstraint("fk_orders_user_id", models.ForeignKeyConstraint)
|
||||
fk.Schema = "public"
|
||||
fk.Table = "orders"
|
||||
fk.Columns = []string{"user_id"}
|
||||
fk.ReferencedSchema = "public"
|
||||
fk.ReferencedTable = "users"
|
||||
fk.ReferencedColumns = []string{"id"}
|
||||
fk.OnDelete = "CASCADE"
|
||||
fk.OnUpdate = "RESTRICT"
|
||||
|
||||
reader := &Reader{}
|
||||
reader.deriveRelationship(table, fk)
|
||||
|
||||
if len(table.Relationships) != 1 {
|
||||
t.Fatalf("Expected 1 relationship, got %d", len(table.Relationships))
|
||||
}
|
||||
|
||||
relName := "orders_to_users"
|
||||
rel, exists := table.Relationships[relName]
|
||||
if !exists {
|
||||
t.Fatalf("Expected relationship '%s', not found", relName)
|
||||
}
|
||||
|
||||
if rel.Type != models.OneToMany {
|
||||
t.Errorf("Expected relationship type %s, got %s", models.OneToMany, rel.Type)
|
||||
}
|
||||
|
||||
if rel.FromTable != "users" {
|
||||
t.Errorf("Expected FromTable 'users', got '%s'", rel.FromTable)
|
||||
}
|
||||
|
||||
if rel.ToTable != "orders" {
|
||||
t.Errorf("Expected ToTable 'orders', got '%s'", rel.ToTable)
|
||||
}
|
||||
|
||||
if rel.ForeignKey != "fk_orders_user_id" {
|
||||
t.Errorf("Expected ForeignKey 'fk_orders_user_id', got '%s'", rel.ForeignKey)
|
||||
}
|
||||
|
||||
if rel.Properties["on_delete"] != "CASCADE" {
|
||||
t.Errorf("Expected on_delete 'CASCADE', got '%s'", rel.Properties["on_delete"])
|
||||
}
|
||||
|
||||
if rel.Properties["on_update"] != "RESTRICT" {
|
||||
t.Errorf("Expected on_update 'RESTRICT', got '%s'", rel.Properties["on_update"])
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkReader_ReadDatabase(b *testing.B) {
|
||||
connStr := os.Getenv("RELSPEC_TEST_PG_CONN")
|
||||
if connStr == "" {
|
||||
b.Skip("Skipping benchmark: RELSPEC_TEST_PG_CONN environment variable not set")
|
||||
}
|
||||
|
||||
options := &readers.ReaderOptions{
|
||||
ConnectionString: connStr,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader := NewReader(options)
|
||||
_, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to read database: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user