348 lines
9.7 KiB
Go
348 lines
9.7 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
|
)
|
|
|
|
// Reader implements the readers.Reader interface for PostgreSQL databases
|
|
type Reader struct {
|
|
options *readers.ReaderOptions
|
|
conn *pgx.Conn
|
|
ctx context.Context
|
|
}
|
|
|
|
// NewReader creates a new PostgreSQL reader
|
|
func NewReader(options *readers.ReaderOptions) *Reader {
|
|
return &Reader{
|
|
options: options,
|
|
ctx: context.Background(),
|
|
}
|
|
}
|
|
|
|
// ReadDatabase reads the entire database schema from PostgreSQL
|
|
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
|
// Validate connection string
|
|
if r.options.ConnectionString == "" {
|
|
return nil, fmt.Errorf("connection string is required")
|
|
}
|
|
|
|
// Connect to the database
|
|
if err := r.connect(); err != nil {
|
|
return nil, fmt.Errorf("failed to connect: %w", err)
|
|
}
|
|
defer r.close()
|
|
|
|
// Get database name from connection
|
|
var dbName string
|
|
err := r.conn.QueryRow(r.ctx, "SELECT current_database()").Scan(&dbName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get database name: %w", err)
|
|
}
|
|
|
|
// Initialize database model
|
|
db := models.InitDatabase(dbName)
|
|
db.DatabaseType = models.PostgresqlDatabaseType
|
|
db.SourceFormat = "pgsql"
|
|
|
|
// Get PostgreSQL version
|
|
var version string
|
|
err = r.conn.QueryRow(r.ctx, "SELECT version()").Scan(&version)
|
|
if err == nil {
|
|
db.DatabaseVersion = version
|
|
}
|
|
|
|
// Query all schemas
|
|
schemas, err := r.querySchemas()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query schemas: %w", err)
|
|
}
|
|
|
|
// Process each schema
|
|
for _, schema := range schemas {
|
|
// Query tables for this schema
|
|
tables, err := r.queryTables(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query tables for schema %s: %w", schema.Name, err)
|
|
}
|
|
schema.Tables = tables
|
|
|
|
// Query views for this schema
|
|
views, err := r.queryViews(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query views for schema %s: %w", schema.Name, err)
|
|
}
|
|
schema.Views = views
|
|
|
|
// Query sequences for this schema
|
|
sequences, err := r.querySequences(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query sequences for schema %s: %w", schema.Name, err)
|
|
}
|
|
schema.Sequences = sequences
|
|
|
|
// Query columns for tables and views
|
|
columnsMap, err := r.queryColumns(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query columns for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Populate table columns
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if cols, exists := columnsMap[tableKey]; exists {
|
|
table.Columns = cols
|
|
}
|
|
}
|
|
|
|
// Populate view columns
|
|
for _, view := range schema.Views {
|
|
viewKey := schema.Name + "." + view.Name
|
|
if cols, exists := columnsMap[viewKey]; exists {
|
|
view.Columns = cols
|
|
}
|
|
}
|
|
|
|
// Query primary keys
|
|
primaryKeys, err := r.queryPrimaryKeys(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query primary keys for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply primary keys to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if pk, exists := primaryKeys[tableKey]; exists {
|
|
table.Constraints[pk.Name] = pk
|
|
// Mark columns as primary key and not null
|
|
for _, colName := range pk.Columns {
|
|
if col, colExists := table.Columns[colName]; colExists {
|
|
col.IsPrimaryKey = true
|
|
col.NotNull = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query foreign keys
|
|
foreignKeys, err := r.queryForeignKeys(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query foreign keys for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply foreign keys to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if fks, exists := foreignKeys[tableKey]; exists {
|
|
for _, fk := range fks {
|
|
table.Constraints[fk.Name] = fk
|
|
// Derive relationship from foreign key
|
|
r.deriveRelationship(table, fk)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query unique constraints
|
|
uniqueConstraints, err := r.queryUniqueConstraints(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query unique constraints for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply unique constraints to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if ucs, exists := uniqueConstraints[tableKey]; exists {
|
|
for _, uc := range ucs {
|
|
table.Constraints[uc.Name] = uc
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query check constraints
|
|
checkConstraints, err := r.queryCheckConstraints(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query check constraints for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply check constraints to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if ccs, exists := checkConstraints[tableKey]; exists {
|
|
for _, cc := range ccs {
|
|
table.Constraints[cc.Name] = cc
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query indexes
|
|
indexes, err := r.queryIndexes(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query indexes for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply indexes to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if idxs, exists := indexes[tableKey]; exists {
|
|
for _, idx := range idxs {
|
|
table.Indexes[idx.Name] = idx
|
|
}
|
|
}
|
|
}
|
|
|
|
// Set RefDatabase for schema
|
|
schema.RefDatabase = db
|
|
|
|
// Set RefSchema for tables and views
|
|
for _, table := range schema.Tables {
|
|
table.RefSchema = schema
|
|
}
|
|
for _, view := range schema.Views {
|
|
view.RefSchema = schema
|
|
}
|
|
for _, seq := range schema.Sequences {
|
|
seq.RefSchema = schema
|
|
}
|
|
|
|
// Add schema to database
|
|
db.Schemas = append(db.Schemas, schema)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// ReadSchema reads a single schema (returns the first schema from the database)
|
|
func (r *Reader) ReadSchema() (*models.Schema, error) {
|
|
db, err := r.ReadDatabase()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(db.Schemas) == 0 {
|
|
return nil, fmt.Errorf("no schemas found in database")
|
|
}
|
|
return db.Schemas[0], nil
|
|
}
|
|
|
|
// ReadTable reads a single table (returns the first table from the first schema)
|
|
func (r *Reader) ReadTable() (*models.Table, error) {
|
|
schema, err := r.ReadSchema()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(schema.Tables) == 0 {
|
|
return nil, fmt.Errorf("no tables found in schema")
|
|
}
|
|
return schema.Tables[0], nil
|
|
}
|
|
|
|
// connect establishes a connection to the PostgreSQL database
|
|
func (r *Reader) connect() error {
|
|
conn, err := pgx.Connect(r.ctx, r.options.ConnectionString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.conn = conn
|
|
return nil
|
|
}
|
|
|
|
// close closes the database connection
|
|
func (r *Reader) close() {
|
|
if r.conn != nil {
|
|
r.conn.Close(r.ctx)
|
|
}
|
|
}
|
|
|
|
// mapDataType maps PostgreSQL data types to canonical types
|
|
func (r *Reader) mapDataType(pgType, udtName string) string {
|
|
// Map common PostgreSQL types
|
|
typeMap := map[string]string{
|
|
"integer": "int",
|
|
"bigint": "int64",
|
|
"smallint": "int16",
|
|
"int": "int",
|
|
"int2": "int16",
|
|
"int4": "int",
|
|
"int8": "int64",
|
|
"serial": "int",
|
|
"bigserial": "int64",
|
|
"smallserial": "int16",
|
|
"numeric": "decimal",
|
|
"decimal": "decimal",
|
|
"real": "float32",
|
|
"double precision": "float64",
|
|
"float4": "float32",
|
|
"float8": "float64",
|
|
"money": "decimal",
|
|
"character varying": "string",
|
|
"varchar": "string",
|
|
"character": "string",
|
|
"char": "string",
|
|
"text": "string",
|
|
"boolean": "bool",
|
|
"bool": "bool",
|
|
"date": "date",
|
|
"time": "time",
|
|
"time without time zone": "time",
|
|
"time with time zone": "timetz",
|
|
"timestamp": "timestamp",
|
|
"timestamp without time zone": "timestamp",
|
|
"timestamp with time zone": "timestamptz",
|
|
"timestamptz": "timestamptz",
|
|
"interval": "interval",
|
|
"uuid": "uuid",
|
|
"json": "json",
|
|
"jsonb": "jsonb",
|
|
"bytea": "bytea",
|
|
"inet": "inet",
|
|
"cidr": "cidr",
|
|
"macaddr": "macaddr",
|
|
"xml": "xml",
|
|
}
|
|
|
|
// Try mapped type first
|
|
if mapped, exists := typeMap[pgType]; exists {
|
|
return mapped
|
|
}
|
|
|
|
// Use pgsql utilities if available
|
|
if pgsql.ValidSQLType(pgType) {
|
|
return pgsql.GetSQLType(pgType)
|
|
}
|
|
|
|
// Return UDT name for custom types
|
|
if udtName != "" {
|
|
return udtName
|
|
}
|
|
|
|
// Default to the original type
|
|
return pgType
|
|
}
|
|
|
|
// deriveRelationship creates a relationship from a foreign key constraint
|
|
func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint) {
|
|
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
|
|
|
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
|
relationship.FromTable = fk.ReferencedTable
|
|
relationship.FromSchema = fk.ReferencedSchema
|
|
relationship.ToTable = table.Name
|
|
relationship.ToSchema = table.Schema
|
|
relationship.ForeignKey = fk.Name
|
|
|
|
// Store constraint actions in properties
|
|
if fk.OnDelete != "" {
|
|
relationship.Properties["on_delete"] = fk.OnDelete
|
|
}
|
|
if fk.OnUpdate != "" {
|
|
relationship.Properties["on_update"] = fk.OnUpdate
|
|
}
|
|
|
|
table.Relationships[relationshipName] = relationship
|
|
}
|