Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 77436757c8 | |||
| 5e6f03e412 | |||
| 1dcbc79387 |
@@ -231,14 +231,13 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
||||
}
|
||||
|
||||
column := models.InitColumn(columnName, tableName, schema)
|
||||
column.Type = r.mapDataType(dataType, udtName)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
// Check if this is a serial type (has nextval default)
|
||||
hasNextval := false
|
||||
if columnDefault != nil {
|
||||
// Parse default value - remove nextval for sequences
|
||||
defaultVal := *columnDefault
|
||||
if strings.HasPrefix(defaultVal, "nextval") {
|
||||
hasNextval = true
|
||||
column.AutoIncrement = true
|
||||
column.Default = defaultVal
|
||||
} else {
|
||||
@@ -246,6 +245,11 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
|
||||
}
|
||||
}
|
||||
|
||||
// Map data type, preserving serial types when detected
|
||||
column.Type = r.mapDataType(dataType, udtName, hasNextval)
|
||||
column.NotNull = (isNullable == "NO")
|
||||
column.Sequence = uint(ordinalPosition)
|
||||
|
||||
if description != nil {
|
||||
column.Description = *description
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package pgsql
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
@@ -259,33 +260,46 @@ func (r *Reader) close() {
|
||||
}
|
||||
|
||||
// mapDataType maps PostgreSQL data types to canonical types
|
||||
func (r *Reader) mapDataType(pgType, udtName string) string {
|
||||
func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
|
||||
// If the column has a nextval default, it's likely a serial type
|
||||
// Map to the appropriate serial type instead of the base integer type
|
||||
if hasNextval {
|
||||
switch strings.ToLower(pgType) {
|
||||
case "integer", "int", "int4":
|
||||
return "serial"
|
||||
case "bigint", "int8":
|
||||
return "bigserial"
|
||||
case "smallint", "int2":
|
||||
return "smallserial"
|
||||
}
|
||||
}
|
||||
|
||||
// Map common PostgreSQL types
|
||||
typeMap := map[string]string{
|
||||
"integer": "int",
|
||||
"bigint": "int64",
|
||||
"smallint": "int16",
|
||||
"int": "int",
|
||||
"int2": "int16",
|
||||
"int4": "int",
|
||||
"int8": "int64",
|
||||
"serial": "int",
|
||||
"bigserial": "int64",
|
||||
"smallserial": "int16",
|
||||
"numeric": "decimal",
|
||||
"integer": "integer",
|
||||
"bigint": "bigint",
|
||||
"smallint": "smallint",
|
||||
"int": "integer",
|
||||
"int2": "smallint",
|
||||
"int4": "integer",
|
||||
"int8": "bigint",
|
||||
"serial": "serial",
|
||||
"bigserial": "bigserial",
|
||||
"smallserial": "smallserial",
|
||||
"numeric": "numeric",
|
||||
"decimal": "decimal",
|
||||
"real": "float32",
|
||||
"double precision": "float64",
|
||||
"float4": "float32",
|
||||
"float8": "float64",
|
||||
"money": "decimal",
|
||||
"character varying": "string",
|
||||
"varchar": "string",
|
||||
"character": "string",
|
||||
"char": "string",
|
||||
"text": "string",
|
||||
"boolean": "bool",
|
||||
"bool": "bool",
|
||||
"real": "real",
|
||||
"double precision": "double precision",
|
||||
"float4": "real",
|
||||
"float8": "double precision",
|
||||
"money": "money",
|
||||
"character varying": "varchar",
|
||||
"varchar": "varchar",
|
||||
"character": "char",
|
||||
"char": "char",
|
||||
"text": "text",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"date": "date",
|
||||
"time": "time",
|
||||
"time without time zone": "time",
|
||||
|
||||
@@ -177,20 +177,20 @@ func TestMapDataType(t *testing.T) {
|
||||
udtName string
|
||||
expected string
|
||||
}{
|
||||
{"integer", "int4", "int"},
|
||||
{"bigint", "int8", "int64"},
|
||||
{"smallint", "int2", "int16"},
|
||||
{"character varying", "varchar", "string"},
|
||||
{"text", "text", "string"},
|
||||
{"boolean", "bool", "bool"},
|
||||
{"integer", "int4", "integer"},
|
||||
{"bigint", "int8", "bigint"},
|
||||
{"smallint", "int2", "smallint"},
|
||||
{"character varying", "varchar", "varchar"},
|
||||
{"text", "text", "text"},
|
||||
{"boolean", "bool", "boolean"},
|
||||
{"timestamp without time zone", "timestamp", "timestamp"},
|
||||
{"timestamp with time zone", "timestamptz", "timestamptz"},
|
||||
{"json", "json", "json"},
|
||||
{"jsonb", "jsonb", "jsonb"},
|
||||
{"uuid", "uuid", "uuid"},
|
||||
{"numeric", "numeric", "decimal"},
|
||||
{"real", "float4", "float32"},
|
||||
{"double precision", "float8", "float64"},
|
||||
{"numeric", "numeric", "numeric"},
|
||||
{"real", "float4", "real"},
|
||||
{"double precision", "float8", "double precision"},
|
||||
{"date", "date", "date"},
|
||||
{"time without time zone", "time", "time"},
|
||||
{"bytea", "bytea", "bytea"},
|
||||
@@ -199,12 +199,31 @@ func TestMapDataType(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pgType, func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.pgType, tt.udtName)
|
||||
result := reader.mapDataType(tt.pgType, tt.udtName, false)
|
||||
if result != tt.expected {
|
||||
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test serial type detection with hasNextval=true
|
||||
serialTests := []struct {
|
||||
pgType string
|
||||
expected string
|
||||
}{
|
||||
{"integer", "serial"},
|
||||
{"bigint", "bigserial"},
|
||||
{"smallint", "smallserial"},
|
||||
}
|
||||
|
||||
for _, tt := range serialTests {
|
||||
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
|
||||
result := reader.mapDataType(tt.pgType, "", true)
|
||||
if result != tt.expected {
|
||||
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIndexDefinition(t *testing.T) {
|
||||
|
||||
@@ -62,6 +62,17 @@ func (tm *TypeMapper) isSimpleType(sqlType string) bool {
|
||||
return simpleTypes[sqlType]
|
||||
}
|
||||
|
||||
// isSerialType checks if a SQL type is a serial type (auto-incrementing)
|
||||
func (tm *TypeMapper) isSerialType(sqlType string) bool {
|
||||
baseType := tm.extractBaseType(sqlType)
|
||||
serialTypes := map[string]bool{
|
||||
"serial": true,
|
||||
"bigserial": true,
|
||||
"smallserial": true,
|
||||
}
|
||||
return serialTypes[baseType]
|
||||
}
|
||||
|
||||
// baseGoType returns the base Go type for a SQL type (not null, simple types only)
|
||||
func (tm *TypeMapper) baseGoType(sqlType string) string {
|
||||
typeMap := map[string]string{
|
||||
@@ -122,10 +133,10 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
|
||||
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
|
||||
// Date/Time types
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"date": tm.sqlTypesAlias + ".SqlDate",
|
||||
"time": tm.sqlTypesAlias + ".SqlTime",
|
||||
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
@@ -190,6 +201,11 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
||||
parts = append(parts, "pk")
|
||||
}
|
||||
|
||||
// Auto increment (for serial types or explicit auto_increment)
|
||||
if column.AutoIncrement || tm.isSerialType(column.Type) {
|
||||
parts = append(parts, "autoincrement")
|
||||
}
|
||||
|
||||
// Default value
|
||||
if column.Default != nil {
|
||||
// Sanitize default value to remove backticks
|
||||
|
||||
@@ -90,8 +90,8 @@ func TestWriter_WriteTable(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify Bun-specific elements
|
||||
if !strings.Contains(generated, "bun:\"id,type:bigint,pk,") {
|
||||
t.Errorf("Missing Bun-style primary key tag")
|
||||
if !strings.Contains(generated, "bun:\"id,type:bigint,pk,autoincrement,") {
|
||||
t.Errorf("Missing Bun-style primary key tag with autoincrement")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -567,8 +567,8 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||
{"bigint", false, "resolvespec_common.SqlInt64"},
|
||||
{"varchar", true, "resolvespec_common.SqlString"}, // Bun uses sql types even for NOT NULL strings
|
||||
{"varchar", false, "resolvespec_common.SqlString"},
|
||||
{"timestamp", true, "resolvespec_common.SqlTime"},
|
||||
{"timestamp", false, "resolvespec_common.SqlTime"},
|
||||
{"timestamp", true, "resolvespec_common.SqlTimeStamp"},
|
||||
{"timestamp", false, "resolvespec_common.SqlTimeStamp"},
|
||||
{"date", false, "resolvespec_common.SqlDate"},
|
||||
{"boolean", true, "bool"},
|
||||
{"boolean", false, "resolvespec_common.SqlBool"},
|
||||
@@ -624,6 +624,37 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
|
||||
},
|
||||
want: []string{"status,", "type:text,", "default:active,"},
|
||||
},
|
||||
{
|
||||
name: "auto increment with AutoIncrement flag",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
},
|
||||
want: []string{"id,", "type:bigint,", "pk,", "autoincrement,"},
|
||||
},
|
||||
{
|
||||
name: "serial type (auto-increment)",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "serial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
want: []string{"id,", "type:serial,", "pk,", "autoincrement,"},
|
||||
},
|
||||
{
|
||||
name: "bigserial type (auto-increment)",
|
||||
column: &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigserial",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -158,10 +158,10 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
|
||||
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
|
||||
// Date/Time types
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
|
||||
"date": tm.sqlTypesAlias + ".SqlDate",
|
||||
"time": tm.sqlTypesAlias + ".SqlTime",
|
||||
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
|
||||
@@ -655,7 +655,7 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
||||
{"varchar", true, "string"},
|
||||
{"varchar", false, "sql_types.SqlString"},
|
||||
{"timestamp", true, "time.Time"},
|
||||
{"timestamp", false, "sql_types.SqlTime"},
|
||||
{"timestamp", false, "sql_types.SqlTimeStamp"},
|
||||
{"boolean", true, "bool"},
|
||||
{"boolean", false, "sql_types.SqlBool"},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user