feat(writer): enhance type conversion for PostgreSQL compatibility and add tests
This commit is contained in:
@@ -4,31 +4,31 @@ import "strings"
|
||||
|
||||
var GoToStdTypes = map[string]string{
|
||||
"bool": "boolean",
|
||||
"int64": "integer",
|
||||
"int64": "bigint",
|
||||
"int": "integer",
|
||||
"int8": "integer",
|
||||
"int16": "integer",
|
||||
"int8": "smallint",
|
||||
"int16": "smallint",
|
||||
"int32": "integer",
|
||||
"uint": "integer",
|
||||
"uint8": "integer",
|
||||
"uint16": "integer",
|
||||
"uint8": "smallint",
|
||||
"uint16": "smallint",
|
||||
"uint32": "integer",
|
||||
"uint64": "integer",
|
||||
"uintptr": "integer",
|
||||
"znullint64": "integer",
|
||||
"uint64": "bigint",
|
||||
"uintptr": "bigint",
|
||||
"znullint64": "bigint",
|
||||
"znullint32": "integer",
|
||||
"znullbyte": "integer",
|
||||
"znullbyte": "smallint",
|
||||
"float64": "double",
|
||||
"float32": "double",
|
||||
"complex64": "double",
|
||||
"complex128": "double",
|
||||
"customfloat64": "double",
|
||||
"string": "string",
|
||||
"Pointer": "integer",
|
||||
"string": "text",
|
||||
"Pointer": "bigint",
|
||||
"[]byte": "blob",
|
||||
"customdate": "string",
|
||||
"customtime": "string",
|
||||
"customtimestamp": "string",
|
||||
"customdate": "date",
|
||||
"customtime": "time",
|
||||
"customtimestamp": "timestamp",
|
||||
"sqlfloat64": "double",
|
||||
"sqlfloat16": "double",
|
||||
"sqluuid": "uuid",
|
||||
@@ -36,9 +36,9 @@ var GoToStdTypes = map[string]string{
|
||||
"sqljson": "json",
|
||||
"sqlint64": "bigint",
|
||||
"sqlint32": "integer",
|
||||
"sqlint16": "integer",
|
||||
"sqlint16": "smallint",
|
||||
"sqlbool": "boolean",
|
||||
"sqlstring": "string",
|
||||
"sqlstring": "text",
|
||||
"nullablejsonb": "jsonb",
|
||||
"nullablejson": "json",
|
||||
"nullableuuid": "uuid",
|
||||
@@ -67,7 +67,7 @@ var GoToPGSQLTypes = map[string]string{
|
||||
"float32": "real",
|
||||
"complex64": "double precision",
|
||||
"complex128": "double precision",
|
||||
"customfloat64": "double precisio",
|
||||
"customfloat64": "double precision",
|
||||
"string": "text",
|
||||
"Pointer": "bigint",
|
||||
"[]byte": "bytea",
|
||||
@@ -81,9 +81,9 @@ var GoToPGSQLTypes = map[string]string{
|
||||
"sqljson": "json",
|
||||
"sqlint64": "bigint",
|
||||
"sqlint32": "integer",
|
||||
"sqlint16": "integer",
|
||||
"sqlint16": "smallint",
|
||||
"sqlbool": "boolean",
|
||||
"sqlstring": "string",
|
||||
"sqlstring": "text",
|
||||
"nullablejsonb": "jsonb",
|
||||
"nullablejson": "json",
|
||||
"nullableuuid": "uuid",
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
@@ -335,7 +336,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
SchemaName: schema.Name,
|
||||
TableName: modelTable.Name,
|
||||
ColumnName: modelCol.Name,
|
||||
ColumnType: modelCol.Type,
|
||||
ColumnType: pgsql.ConvertSQLType(modelCol.Type),
|
||||
Default: defaultVal,
|
||||
NotNull: modelCol.NotNull,
|
||||
})
|
||||
@@ -359,7 +360,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model
|
||||
SchemaName: schema.Name,
|
||||
TableName: modelTable.Name,
|
||||
ColumnName: modelCol.Name,
|
||||
NewType: modelCol.Type,
|
||||
NewType: pgsql.ConvertSQLType(modelCol.Type),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
@@ -332,15 +333,16 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode
|
||||
func (w *Writer) generateColumnDefinition(col *models.Column) string {
|
||||
parts := []string{col.SQLName()}
|
||||
|
||||
// Type with length/precision
|
||||
typeStr := col.Type
|
||||
// Type with length/precision - convert to valid PostgreSQL type
|
||||
baseType := pgsql.ConvertSQLType(col.Type)
|
||||
typeStr := baseType
|
||||
if col.Length > 0 && col.Precision == 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length)
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length)
|
||||
} else if col.Precision > 0 {
|
||||
if col.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale)
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale)
|
||||
} else {
|
||||
typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision)
|
||||
typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision)
|
||||
}
|
||||
}
|
||||
parts = append(parts, typeStr)
|
||||
@@ -488,7 +490,7 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error {
|
||||
columnDefs := make([]string, 0, len(columns))
|
||||
|
||||
for _, col := range columns {
|
||||
colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type)
|
||||
colDef := fmt.Sprintf(" %s %s", col.SQLName(), pgsql.ConvertSQLType(col.Type))
|
||||
|
||||
// Add default value if present
|
||||
if col.Default != nil && col.Default != "" {
|
||||
|
||||
@@ -241,3 +241,67 @@ func TestIsIntegerType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeConversion(t *testing.T) {
|
||||
// Test that invalid Go types are converted to valid PostgreSQL types
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Create a test table with Go types instead of SQL types
|
||||
table := models.InitTable("test_types", "public")
|
||||
|
||||
// Add columns with Go types (invalid for PostgreSQL)
|
||||
stringCol := models.InitColumn("name", "test_types", "public")
|
||||
stringCol.Type = "string" // Should be converted to "text"
|
||||
table.Columns["name"] = stringCol
|
||||
|
||||
int64Col := models.InitColumn("big_id", "test_types", "public")
|
||||
int64Col.Type = "int64" // Should be converted to "bigint"
|
||||
table.Columns["big_id"] = int64Col
|
||||
|
||||
int16Col := models.InitColumn("small_id", "test_types", "public")
|
||||
int16Col.Type = "int16" // Should be converted to "smallint"
|
||||
table.Columns["small_id"] = int16Col
|
||||
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with output to buffer
|
||||
var buf bytes.Buffer
|
||||
options := &writers.WriterOptions{}
|
||||
writer := NewWriter(options)
|
||||
writer.writer = &buf
|
||||
|
||||
// Write the database
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Print output for debugging
|
||||
t.Logf("Generated SQL:\n%s", output)
|
||||
|
||||
// Verify that Go types were converted to PostgreSQL types
|
||||
if strings.Contains(output, "string") {
|
||||
t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "int64") {
|
||||
t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", output)
|
||||
}
|
||||
if strings.Contains(output, "int16") {
|
||||
t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", output)
|
||||
}
|
||||
|
||||
// Verify correct PostgreSQL types are present
|
||||
if !strings.Contains(output, "text") {
|
||||
t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "bigint") {
|
||||
t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", output)
|
||||
}
|
||||
if !strings.Contains(output, "smallint") {
|
||||
t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user