9190df81dd
* Introduced extractTypeParts function to handle embedded dimensions in type strings. * Updated columnTypeConflict to utilize new type extraction logic. * Improved PostgreSQL type normalization and handling in various components.
344 lines
9.5 KiB
Go
344 lines
9.5 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
|
)
|
|
|
|
// Reader implements the readers.Reader interface for PostgreSQL databases
|
|
type Reader struct {
|
|
options *readers.ReaderOptions
|
|
conn *pgx.Conn
|
|
ctx context.Context
|
|
}
|
|
|
|
// NewReader creates a new PostgreSQL reader
|
|
func NewReader(options *readers.ReaderOptions) *Reader {
|
|
return &Reader{
|
|
options: options,
|
|
ctx: context.Background(),
|
|
}
|
|
}
|
|
|
|
// ReadDatabase reads the entire database schema from PostgreSQL
|
|
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
|
// Validate connection string
|
|
if r.options.ConnectionString == "" {
|
|
return nil, fmt.Errorf("connection string is required")
|
|
}
|
|
|
|
// Connect to the database
|
|
if err := r.connect(); err != nil {
|
|
return nil, fmt.Errorf("failed to connect: %w", err)
|
|
}
|
|
defer r.close()
|
|
|
|
// Get database name from connection
|
|
var dbName string
|
|
err := r.conn.QueryRow(r.ctx, "SELECT current_database()").Scan(&dbName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get database name: %w", err)
|
|
}
|
|
|
|
// Initialize database model
|
|
db := models.InitDatabase(dbName)
|
|
db.DatabaseType = models.PostgresqlDatabaseType
|
|
db.SourceFormat = "pgsql"
|
|
|
|
// Get PostgreSQL version
|
|
var version string
|
|
err = r.conn.QueryRow(r.ctx, "SELECT version()").Scan(&version)
|
|
if err == nil {
|
|
db.DatabaseVersion = version
|
|
}
|
|
|
|
// Query all schemas
|
|
schemas, err := r.querySchemas()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query schemas: %w", err)
|
|
}
|
|
|
|
// Process each schema
|
|
for _, schema := range schemas {
|
|
// Query tables for this schema
|
|
tables, err := r.queryTables(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query tables for schema %s: %w", schema.Name, err)
|
|
}
|
|
schema.Tables = tables
|
|
|
|
// Query views for this schema
|
|
views, err := r.queryViews(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query views for schema %s: %w", schema.Name, err)
|
|
}
|
|
schema.Views = views
|
|
|
|
// Query sequences for this schema
|
|
sequences, err := r.querySequences(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query sequences for schema %s: %w", schema.Name, err)
|
|
}
|
|
schema.Sequences = sequences
|
|
|
|
// Query columns for tables and views
|
|
columnsMap, err := r.queryColumns(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query columns for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Populate table columns
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if cols, exists := columnsMap[tableKey]; exists {
|
|
table.Columns = cols
|
|
}
|
|
}
|
|
|
|
// Populate view columns
|
|
for _, view := range schema.Views {
|
|
viewKey := schema.Name + "." + view.Name
|
|
if cols, exists := columnsMap[viewKey]; exists {
|
|
view.Columns = cols
|
|
}
|
|
}
|
|
|
|
// Query primary keys
|
|
primaryKeys, err := r.queryPrimaryKeys(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query primary keys for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply primary keys to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if pk, exists := primaryKeys[tableKey]; exists {
|
|
table.Constraints[pk.Name] = pk
|
|
// Mark columns as primary key and not null
|
|
for _, colName := range pk.Columns {
|
|
if col, colExists := table.Columns[colName]; colExists {
|
|
col.IsPrimaryKey = true
|
|
col.NotNull = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query foreign keys
|
|
foreignKeys, err := r.queryForeignKeys(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query foreign keys for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply foreign keys to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if fks, exists := foreignKeys[tableKey]; exists {
|
|
for _, fk := range fks {
|
|
table.Constraints[fk.Name] = fk
|
|
// Derive relationship from foreign key
|
|
r.deriveRelationship(table, fk)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query unique constraints
|
|
uniqueConstraints, err := r.queryUniqueConstraints(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query unique constraints for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply unique constraints to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if ucs, exists := uniqueConstraints[tableKey]; exists {
|
|
for _, uc := range ucs {
|
|
table.Constraints[uc.Name] = uc
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query check constraints
|
|
checkConstraints, err := r.queryCheckConstraints(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query check constraints for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply check constraints to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if ccs, exists := checkConstraints[tableKey]; exists {
|
|
for _, cc := range ccs {
|
|
table.Constraints[cc.Name] = cc
|
|
}
|
|
}
|
|
}
|
|
|
|
// Query indexes
|
|
indexes, err := r.queryIndexes(schema.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query indexes for schema %s: %w", schema.Name, err)
|
|
}
|
|
|
|
// Apply indexes to tables
|
|
for _, table := range schema.Tables {
|
|
tableKey := schema.Name + "." + table.Name
|
|
if idxs, exists := indexes[tableKey]; exists {
|
|
for _, idx := range idxs {
|
|
table.Indexes[idx.Name] = idx
|
|
}
|
|
}
|
|
}
|
|
|
|
// Set RefDatabase for schema
|
|
schema.RefDatabase = db
|
|
|
|
// Set RefSchema for tables and views
|
|
for _, table := range schema.Tables {
|
|
table.RefSchema = schema
|
|
}
|
|
for _, view := range schema.Views {
|
|
view.RefSchema = schema
|
|
}
|
|
for _, seq := range schema.Sequences {
|
|
seq.RefSchema = schema
|
|
}
|
|
|
|
// Add schema to database
|
|
db.Schemas = append(db.Schemas, schema)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// ReadSchema reads a single schema (returns the first schema from the database)
|
|
func (r *Reader) ReadSchema() (*models.Schema, error) {
|
|
db, err := r.ReadDatabase()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(db.Schemas) == 0 {
|
|
return nil, fmt.Errorf("no schemas found in database")
|
|
}
|
|
return db.Schemas[0], nil
|
|
}
|
|
|
|
// ReadTable reads a single table (returns the first table from the first schema)
|
|
func (r *Reader) ReadTable() (*models.Table, error) {
|
|
schema, err := r.ReadSchema()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(schema.Tables) == 0 {
|
|
return nil, fmt.Errorf("no tables found in schema")
|
|
}
|
|
return schema.Tables[0], nil
|
|
}
|
|
|
|
// connect establishes a connection to the PostgreSQL database
|
|
func (r *Reader) connect() error {
|
|
conn, err := pgsql.Connect(r.ctx, r.options.ConnectionString, "reader-pgsql")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.conn = conn
|
|
return nil
|
|
}
|
|
|
|
// close closes the database connection
|
|
func (r *Reader) close() {
|
|
if r.conn != nil {
|
|
r.conn.Close(r.ctx)
|
|
}
|
|
}
|
|
|
|
// mapDataType maps a PostgreSQL data type to its canonical RelSpec name.
|
|
// For known built-in types, dimensions are stripped from the type string (they are
|
|
// stored separately in column.Length/Precision/Scale). For custom types (e.g.
|
|
// vector(1536), postgis geometries), the full formatted type is preserved.
|
|
func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string {
|
|
normalizedPGType := strings.ToLower(strings.TrimSpace(pgType))
|
|
|
|
// Detect serial types from nextval defaults before anything else.
|
|
if hasNextval {
|
|
switch normalizedPGType {
|
|
case "integer", "int", "int4":
|
|
return "serial"
|
|
case "bigint", "int8":
|
|
return "bigserial"
|
|
case "smallint", "int2":
|
|
return "smallserial"
|
|
}
|
|
}
|
|
|
|
// information_schema reports arrays generically as "ARRAY" with udt_name like "_text".
|
|
if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
|
return udtName[1:] + "[]"
|
|
}
|
|
|
|
// Use the database-formatted type when available. For known built-in types, strip
|
|
// embedded dimensions (they are stored in column.Length/Precision/Scale separately).
|
|
// For unknown/custom types, keep the full formatted string (e.g. vector(1536)).
|
|
if strings.TrimSpace(formattedType) != "" {
|
|
lower := strings.ToLower(strings.TrimSpace(formattedType))
|
|
isArray := strings.HasSuffix(lower, "[]")
|
|
base := strings.TrimSuffix(lower, "[]")
|
|
if idx := strings.Index(base, "("); idx >= 0 {
|
|
base = strings.TrimSpace(base[:idx])
|
|
}
|
|
canonical := pgsql.NormalizePGType(base)
|
|
if pgsql.IsKnownPGBaseType(canonical) {
|
|
if isArray {
|
|
return canonical + "[]"
|
|
}
|
|
return canonical
|
|
}
|
|
return formattedType
|
|
}
|
|
|
|
// Fall back to normalizing the information_schema type name directly.
|
|
canonical := pgsql.NormalizePGType(normalizedPGType)
|
|
if pgsql.IsKnownPGBaseType(canonical) {
|
|
return canonical
|
|
}
|
|
|
|
// Return UDT name for custom types.
|
|
if udtName != "" {
|
|
if strings.HasPrefix(udtName, "_") && len(udtName) > 1 {
|
|
return udtName[1:] + "[]"
|
|
}
|
|
return udtName
|
|
}
|
|
|
|
return pgType
|
|
}
|
|
|
|
// deriveRelationship creates a relationship from a foreign key constraint
|
|
func (r *Reader) deriveRelationship(table *models.Table, fk *models.Constraint) {
|
|
relationshipName := fmt.Sprintf("%s_to_%s", table.Name, fk.ReferencedTable)
|
|
|
|
relationship := models.InitRelationship(relationshipName, models.OneToMany)
|
|
relationship.FromTable = table.Name
|
|
relationship.FromSchema = table.Schema
|
|
relationship.ToTable = fk.ReferencedTable
|
|
relationship.ToSchema = fk.ReferencedSchema
|
|
relationship.ForeignKey = fk.Name
|
|
|
|
// Store constraint actions in properties
|
|
if fk.OnDelete != "" {
|
|
relationship.Properties["on_delete"] = fk.OnDelete
|
|
}
|
|
if fk.OnUpdate != "" {
|
|
relationship.Properties["on_update"] = fk.OnUpdate
|
|
}
|
|
|
|
table.Relationships[relationshipName] = relationship
|
|
}
|