feat(pgsql): enhance data type mapping to support serial types
All checks were successful
CI / Test (1.25) (push) Successful in -24m18s
CI / Test (1.24) (push) Successful in -24m6s
CI / Build (push) Successful in -25m14s
CI / Lint (push) Successful in -24m47s
Release / Build and Release (push) Successful in -25m37s
Integration Tests / Integration Tests (push) Successful in -25m9s

This commit is contained in:
2026-02-08 17:31:28 +02:00
parent 59c4a5ebf8
commit 1dcbc79387
3 changed files with 75 additions and 38 deletions

View File

@@ -231,14 +231,13 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
}
column := models.InitColumn(columnName, tableName, schema)
column.Type = r.mapDataType(dataType, udtName)
column.NotNull = (isNullable == "NO")
column.Sequence = uint(ordinalPosition)
// Check if this is a serial type (has nextval default)
hasNextval := false
if columnDefault != nil {
// Parse default value - remove nextval for sequences
defaultVal := *columnDefault
if strings.HasPrefix(defaultVal, "nextval") {
hasNextval = true
column.AutoIncrement = true
column.Default = defaultVal
} else {
@@ -246,6 +245,11 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
}
}
// Map data type, preserving serial types when detected
column.Type = r.mapDataType(dataType, udtName, hasNextval)
column.NotNull = (isNullable == "NO")
column.Sequence = uint(ordinalPosition)
if description != nil {
column.Description = *description
}

View File

@@ -3,6 +3,7 @@ package pgsql
import (
"context"
"fmt"
"strings"
"github.com/jackc/pgx/v5"
@@ -259,33 +260,46 @@ func (r *Reader) close() {
}
// mapDataType maps PostgreSQL data types to canonical types
func (r *Reader) mapDataType(pgType, udtName string) string {
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
// If the column has a nextval default, it's likely a serial type
// Map to the appropriate serial type instead of the base integer type
if hasNextval {
switch strings.ToLower(pgType) {
case "integer", "int", "int4":
return "serial"
case "bigint", "int8":
return "bigserial"
case "smallint", "int2":
return "smallserial"
}
}
// Map common PostgreSQL types
typeMap := map[string]string{
"integer": "int",
"bigint": "int64",
"smallint": "int16",
"int": "int",
"int2": "int16",
"int4": "int",
"int8": "int64",
"serial": "int",
"bigserial": "int64",
"smallserial": "int16",
"numeric": "decimal",
"integer": "integer",
"bigint": "bigint",
"smallint": "smallint",
"int": "integer",
"int2": "smallint",
"int4": "integer",
"int8": "bigint",
"serial": "serial",
"bigserial": "bigserial",
"smallserial": "smallserial",
"numeric": "numeric",
"decimal": "decimal",
"real": "float32",
"double precision": "float64",
"float4": "float32",
"float8": "float64",
"money": "decimal",
"character varying": "string",
"varchar": "string",
"character": "string",
"char": "string",
"text": "string",
"boolean": "bool",
"bool": "bool",
"real": "real",
"double precision": "double precision",
"float4": "real",
"float8": "double precision",
"money": "money",
"character varying": "varchar",
"varchar": "varchar",
"character": "char",
"char": "char",
"text": "text",
"boolean": "boolean",
"bool": "boolean",
"date": "date",
"time": "time",
"time without time zone": "time",

View File

@@ -177,20 +177,20 @@ func TestMapDataType(t *testing.T) {
udtName string
expected string
}{
{"integer", "int4", "int"},
{"bigint", "int8", "int64"},
{"smallint", "int2", "int16"},
{"character varying", "varchar", "string"},
{"text", "text", "string"},
{"boolean", "bool", "bool"},
{"integer", "int4", "integer"},
{"bigint", "int8", "bigint"},
{"smallint", "int2", "smallint"},
{"character varying", "varchar", "varchar"},
{"text", "text", "text"},
{"boolean", "bool", "boolean"},
{"timestamp without time zone", "timestamp", "timestamp"},
{"timestamp with time zone", "timestamptz", "timestamptz"},
{"json", "json", "json"},
{"jsonb", "jsonb", "jsonb"},
{"uuid", "uuid", "uuid"},
{"numeric", "numeric", "decimal"},
{"real", "float4", "float32"},
{"double precision", "float8", "float64"},
{"numeric", "numeric", "numeric"},
{"real", "float4", "real"},
{"double precision", "float8", "double precision"},
{"date", "date", "date"},
{"time without time zone", "time", "time"},
{"bytea", "bytea", "bytea"},
@@ -199,12 +199,31 @@ func TestMapDataType(t *testing.T) {
for _, tt := range tests {
t.Run(tt.pgType, func(t *testing.T) {
result := reader.mapDataType(tt.pgType, tt.udtName)
result := reader.mapDataType(tt.pgType, tt.udtName, false)
if result != tt.expected {
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
}
})
}
// Test serial type detection with hasNextval=true
serialTests := []struct {
pgType string
expected string
}{
{"integer", "serial"},
{"bigint", "bigserial"},
{"smallint", "smallserial"},
}
for _, tt := range serialTests {
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
result := reader.mapDataType(tt.pgType, "", true)
if result != tt.expected {
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
}
})
}
}
func TestParseIndexDefinition(t *testing.T) {