feat(pgsql): enhance data type mapping to support serial types
All checks were successful
CI / Test (1.25) (push) Successful in -24m18s
CI / Test (1.24) (push) Successful in -24m6s
CI / Build (push) Successful in -25m14s
CI / Lint (push) Successful in -24m47s
Release / Build and Release (push) Successful in -25m37s
Integration Tests / Integration Tests (push) Successful in -25m9s
All checks were successful
CI / Test (1.25) (push) Successful in -24m18s
CI / Test (1.24) (push) Successful in -24m6s
CI / Build (push) Successful in -25m14s
CI / Lint (push) Successful in -24m47s
Release / Build and Release (push) Successful in -25m37s
Integration Tests / Integration Tests (push) Successful in -25m9s
This commit is contained in:
@@ -231,14 +231,13 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
|||||||
}
|
}
|
||||||
|
|
||||||
column := models.InitColumn(columnName, tableName, schema)
|
column := models.InitColumn(columnName, tableName, schema)
|
||||||
column.Type = r.mapDataType(dataType, udtName)
|
|
||||||
column.NotNull = (isNullable == "NO")
|
|
||||||
column.Sequence = uint(ordinalPosition)
|
|
||||||
|
|
||||||
|
// Check if this is a serial type (has nextval default)
|
||||||
|
hasNextval := false
|
||||||
if columnDefault != nil {
|
if columnDefault != nil {
|
||||||
// Parse default value - remove nextval for sequences
|
|
||||||
defaultVal := *columnDefault
|
defaultVal := *columnDefault
|
||||||
if strings.HasPrefix(defaultVal, "nextval") {
|
if strings.HasPrefix(defaultVal, "nextval") {
|
||||||
|
hasNextval = true
|
||||||
column.AutoIncrement = true
|
column.AutoIncrement = true
|
||||||
column.Default = defaultVal
|
column.Default = defaultVal
|
||||||
} else {
|
} else {
|
||||||
@@ -246,6 +245,11 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Map data type, preserving serial types when detected
|
||||||
|
column.Type = r.mapDataType(dataType, udtName, hasNextval)
|
||||||
|
column.NotNull = (isNullable == "NO")
|
||||||
|
column.Sequence = uint(ordinalPosition)
|
||||||
|
|
||||||
if description != nil {
|
if description != nil {
|
||||||
column.Description = *description
|
column.Description = *description
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package pgsql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
|
|
||||||
@@ -259,33 +260,46 @@ func (r *Reader) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// mapDataType maps PostgreSQL data types to canonical types
|
// mapDataType maps PostgreSQL data types to canonical types
|
||||||
func (r *Reader) mapDataType(pgType, udtName string) string {
|
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||||
|
// If the column has a nextval default, it's likely a serial type
|
||||||
|
// Map to the appropriate serial type instead of the base integer type
|
||||||
|
if hasNextval {
|
||||||
|
switch strings.ToLower(pgType) {
|
||||||
|
case "integer", "int", "int4":
|
||||||
|
return "serial"
|
||||||
|
case "bigint", "int8":
|
||||||
|
return "bigserial"
|
||||||
|
case "smallint", "int2":
|
||||||
|
return "smallserial"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Map common PostgreSQL types
|
// Map common PostgreSQL types
|
||||||
typeMap := map[string]string{
|
typeMap := map[string]string{
|
||||||
"integer": "int",
|
"integer": "integer",
|
||||||
"bigint": "int64",
|
"bigint": "bigint",
|
||||||
"smallint": "int16",
|
"smallint": "smallint",
|
||||||
"int": "int",
|
"int": "integer",
|
||||||
"int2": "int16",
|
"int2": "smallint",
|
||||||
"int4": "int",
|
"int4": "integer",
|
||||||
"int8": "int64",
|
"int8": "bigint",
|
||||||
"serial": "int",
|
"serial": "serial",
|
||||||
"bigserial": "int64",
|
"bigserial": "bigserial",
|
||||||
"smallserial": "int16",
|
"smallserial": "smallserial",
|
||||||
"numeric": "decimal",
|
"numeric": "numeric",
|
||||||
"decimal": "decimal",
|
"decimal": "decimal",
|
||||||
"real": "float32",
|
"real": "real",
|
||||||
"double precision": "float64",
|
"double precision": "double precision",
|
||||||
"float4": "float32",
|
"float4": "real",
|
||||||
"float8": "float64",
|
"float8": "double precision",
|
||||||
"money": "decimal",
|
"money": "money",
|
||||||
"character varying": "string",
|
"character varying": "varchar",
|
||||||
"varchar": "string",
|
"varchar": "varchar",
|
||||||
"character": "string",
|
"character": "char",
|
||||||
"char": "string",
|
"char": "char",
|
||||||
"text": "string",
|
"text": "text",
|
||||||
"boolean": "bool",
|
"boolean": "boolean",
|
||||||
"bool": "bool",
|
"bool": "boolean",
|
||||||
"date": "date",
|
"date": "date",
|
||||||
"time": "time",
|
"time": "time",
|
||||||
"time without time zone": "time",
|
"time without time zone": "time",
|
||||||
|
|||||||
@@ -177,20 +177,20 @@ func TestMapDataType(t *testing.T) {
|
|||||||
udtName string
|
udtName string
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{"integer", "int4", "int"},
|
{"integer", "int4", "integer"},
|
||||||
{"bigint", "int8", "int64"},
|
{"bigint", "int8", "bigint"},
|
||||||
{"smallint", "int2", "int16"},
|
{"smallint", "int2", "smallint"},
|
||||||
{"character varying", "varchar", "string"},
|
{"character varying", "varchar", "varchar"},
|
||||||
{"text", "text", "string"},
|
{"text", "text", "text"},
|
||||||
{"boolean", "bool", "bool"},
|
{"boolean", "bool", "boolean"},
|
||||||
{"timestamp without time zone", "timestamp", "timestamp"},
|
{"timestamp without time zone", "timestamp", "timestamp"},
|
||||||
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
||||||
{"json", "json", "json"},
|
{"json", "json", "json"},
|
||||||
{"jsonb", "jsonb", "jsonb"},
|
{"jsonb", "jsonb", "jsonb"},
|
||||||
{"uuid", "uuid", "uuid"},
|
{"uuid", "uuid", "uuid"},
|
||||||
{"numeric", "numeric", "decimal"},
|
{"numeric", "numeric", "numeric"},
|
||||||
{"real", "float4", "float32"},
|
{"real", "float4", "real"},
|
||||||
{"double precision", "float8", "float64"},
|
{"double precision", "float8", "double precision"},
|
||||||
{"date", "date", "date"},
|
{"date", "date", "date"},
|
||||||
{"time without time zone", "time", "time"},
|
{"time without time zone", "time", "time"},
|
||||||
{"bytea", "bytea", "bytea"},
|
{"bytea", "bytea", "bytea"},
|
||||||
@@ -199,12 +199,31 @@ func TestMapDataType(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.pgType, func(t *testing.T) {
|
t.Run(tt.pgType, func(t *testing.T) {
|
||||||
result := reader.mapDataType(tt.pgType, tt.udtName)
|
result := reader.mapDataType(tt.pgType, tt.udtName, false)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test serial type detection with hasNextval=true
|
||||||
|
serialTests := []struct {
|
||||||
|
pgType string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"integer", "serial"},
|
||||||
|
{"bigint", "bigserial"},
|
||||||
|
{"smallint", "smallserial"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range serialTests {
|
||||||
|
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
|
||||||
|
result := reader.mapDataType(tt.pgType, "", true)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseIndexDefinition(t *testing.T) {
|
func TestParseIndexDefinition(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user