diff --git a/pkg/readers/pgsql/queries.go b/pkg/readers/pgsql/queries.go index 7fb81f7..668cecb 100644 --- a/pkg/readers/pgsql/queries.go +++ b/pkg/readers/pgsql/queries.go @@ -231,14 +231,13 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models. } column := models.InitColumn(columnName, tableName, schema) - column.Type = r.mapDataType(dataType, udtName) - column.NotNull = (isNullable == "NO") - column.Sequence = uint(ordinalPosition) + // Check if this is a serial type (has nextval default) + hasNextval := false if columnDefault != nil { - // Parse default value - remove nextval for sequences defaultVal := *columnDefault if strings.HasPrefix(defaultVal, "nextval") { + hasNextval = true column.AutoIncrement = true column.Default = defaultVal } else { @@ -246,6 +245,11 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models. } } + // Map data type, preserving serial types when detected + column.Type = r.mapDataType(dataType, udtName, hasNextval) + column.NotNull = (isNullable == "NO") + column.Sequence = uint(ordinalPosition) + if description != nil { column.Description = *description } diff --git a/pkg/readers/pgsql/reader.go b/pkg/readers/pgsql/reader.go index 75c234e..a60cd44 100644 --- a/pkg/readers/pgsql/reader.go +++ b/pkg/readers/pgsql/reader.go @@ -3,6 +3,7 @@ package pgsql import ( "context" "fmt" + "strings" "github.com/jackc/pgx/v5" @@ -259,33 +260,46 @@ func (r *Reader) close() { } // mapDataType maps PostgreSQL data types to canonical types -func (r *Reader) mapDataType(pgType, udtName string) string { +func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string { + // If the column has a nextval default, it's likely a serial type + // Map to the appropriate serial type instead of the base integer type + if hasNextval { + switch strings.ToLower(pgType) { + case "integer", "int", "int4": + return "serial" + case "bigint", "int8": + return "bigserial" + case "smallint", "int2": + return "smallserial" + } + } + // Map common PostgreSQL types typeMap := map[string]string{ - "integer": "int", - "bigint": "int64", - "smallint": "int16", - "int": "int", - "int2": "int16", - "int4": "int", - "int8": "int64", - "serial": "int", - "bigserial": "int64", - "smallserial": "int16", - "numeric": "decimal", + "integer": "integer", + "bigint": "bigint", + "smallint": "smallint", + "int": "integer", + "int2": "smallint", + "int4": "integer", + "int8": "bigint", + "serial": "serial", + "bigserial": "bigserial", + "smallserial": "smallserial", + "numeric": "numeric", "decimal": "decimal", - "real": "float32", - "double precision": "float64", - "float4": "float32", - "float8": "float64", - "money": "decimal", - "character varying": "string", - "varchar": "string", - "character": "string", - "char": "string", - "text": "string", - "boolean": "bool", - "bool": "bool", + "real": "real", + "double precision": "double precision", + "float4": "real", + "float8": "double precision", + "money": "money", + "character varying": "varchar", + "varchar": "varchar", + "character": "char", + "char": "char", + "text": "text", + "boolean": "boolean", + "bool": "boolean", "date": "date", "time": "time", "time without time zone": "time", diff --git a/pkg/readers/pgsql/reader_test.go b/pkg/readers/pgsql/reader_test.go index 80415de..e496b47 100644 --- a/pkg/readers/pgsql/reader_test.go +++ b/pkg/readers/pgsql/reader_test.go @@ -177,20 +177,20 @@ func TestMapDataType(t *testing.T) { udtName string expected string }{ - {"integer", "int4", "int"}, - {"bigint", "int8", "int64"}, - {"smallint", "int2", "int16"}, - {"character varying", "varchar", "string"}, - {"text", "text", "string"}, - {"boolean", "bool", "bool"}, + {"integer", "int4", "integer"}, + {"bigint", "int8", "bigint"}, + {"smallint", "int2", "smallint"}, + {"character varying", "varchar", "varchar"}, + {"text", "text", "text"}, + {"boolean", "bool", "boolean"}, {"timestamp without time zone", "timestamp", "timestamp"}, {"timestamp with time zone", "timestamptz", "timestamptz"}, {"json", "json", "json"}, {"jsonb", "jsonb", "jsonb"}, {"uuid", "uuid", "uuid"}, - {"numeric", "numeric", "decimal"}, - {"real", "float4", "float32"}, - {"double precision", "float8", "float64"}, + {"numeric", "numeric", "numeric"}, + {"real", "float4", "real"}, + {"double precision", "float8", "double precision"}, {"date", "date", "date"}, {"time without time zone", "time", "time"}, {"bytea", "bytea", "bytea"}, @@ -199,12 +199,31 @@ func TestMapDataType(t *testing.T) { for _, tt := range tests { t.Run(tt.pgType, func(t *testing.T) { - result := reader.mapDataType(tt.pgType, tt.udtName) + result := reader.mapDataType(tt.pgType, tt.udtName, false) if result != tt.expected { t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected) } }) } + + // Test serial type detection with hasNextval=true + serialTests := []struct { + pgType string + expected string + }{ + {"integer", "serial"}, + {"bigint", "bigserial"}, + {"smallint", "smallserial"}, + } + + for _, tt := range serialTests { + t.Run(tt.pgType+"_with_nextval", func(t *testing.T) { + result := reader.mapDataType(tt.pgType, "", true) + if result != tt.expected { + t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected) + } + }) + } } func TestParseIndexDefinition(t *testing.T) {