diff --git a/pkg/pgsql/datatypes.go b/pkg/pgsql/datatypes.go index a55b0fc..ebe163f 100644 --- a/pkg/pgsql/datatypes.go +++ b/pkg/pgsql/datatypes.go @@ -4,31 +4,31 @@ import "strings" var GoToStdTypes = map[string]string{ "bool": "boolean", - "int64": "integer", + "int64": "bigint", "int": "integer", - "int8": "integer", - "int16": "integer", + "int8": "smallint", + "int16": "smallint", "int32": "integer", "uint": "integer", - "uint8": "integer", - "uint16": "integer", + "uint8": "smallint", + "uint16": "smallint", "uint32": "integer", - "uint64": "integer", - "uintptr": "integer", - "znullint64": "integer", + "uint64": "bigint", + "uintptr": "bigint", + "znullint64": "bigint", "znullint32": "integer", - "znullbyte": "integer", + "znullbyte": "smallint", "float64": "double", "float32": "double", "complex64": "double", "complex128": "double", "customfloat64": "double", - "string": "string", - "Pointer": "integer", + "string": "text", + "Pointer": "bigint", "[]byte": "blob", - "customdate": "string", - "customtime": "string", - "customtimestamp": "string", + "customdate": "date", + "customtime": "time", + "customtimestamp": "timestamp", "sqlfloat64": "double", "sqlfloat16": "double", "sqluuid": "uuid", @@ -36,9 +36,9 @@ var GoToStdTypes = map[string]string{ "sqljson": "json", "sqlint64": "bigint", "sqlint32": "integer", - "sqlint16": "integer", + "sqlint16": "smallint", "sqlbool": "boolean", - "sqlstring": "string", + "sqlstring": "text", "nullablejsonb": "jsonb", "nullablejson": "json", "nullableuuid": "uuid", @@ -67,7 +67,7 @@ var GoToPGSQLTypes = map[string]string{ "float32": "real", "complex64": "double precision", "complex128": "double precision", - "customfloat64": "double precisio", + "customfloat64": "double precision", "string": "text", "Pointer": "bigint", "[]byte": "bytea", @@ -81,9 +81,9 @@ var GoToPGSQLTypes = map[string]string{ "sqljson": "json", "sqlint64": "bigint", "sqlint32": "integer", - "sqlint16": "integer", + "sqlint16": "smallint", "sqlbool": "boolean", - "sqlstring": "string", + "sqlstring": "text", "nullablejsonb": "jsonb", "nullablejson": "json", "nullableuuid": "uuid", diff --git a/pkg/writers/pgsql/migration_writer.go b/pkg/writers/pgsql/migration_writer.go index 671bb93..9278055 100644 --- a/pkg/writers/pgsql/migration_writer.go +++ b/pkg/writers/pgsql/migration_writer.go @@ -8,6 +8,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/writers" ) @@ -335,7 +336,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model SchemaName: schema.Name, TableName: modelTable.Name, ColumnName: modelCol.Name, - ColumnType: modelCol.Type, + ColumnType: pgsql.ConvertSQLType(modelCol.Type), Default: defaultVal, NotNull: modelCol.NotNull, }) @@ -359,7 +360,7 @@ func (w *MigrationWriter) generateAlterTableScripts(schema *models.Schema, model SchemaName: schema.Name, TableName: modelTable.Name, ColumnName: modelCol.Name, - NewType: modelCol.Type, + NewType: pgsql.ConvertSQLType(modelCol.Type), }) if err != nil { return nil, err diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index 528f362..2b12049 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -13,6 +13,7 @@ import ( "github.com/jackc/pgx/v5" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/writers" ) @@ -332,15 +333,16 @@ func (w *Writer) generateCreateTableStatement(schema *models.Schema, table *mode func (w *Writer) generateColumnDefinition(col *models.Column) string { parts := []string{col.SQLName()} - // Type with length/precision - typeStr := col.Type + // Type with length/precision - convert to valid PostgreSQL type + baseType := pgsql.ConvertSQLType(col.Type) + typeStr := baseType if col.Length > 0 && col.Precision == 0 { - typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Length) + typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length) } else if col.Precision > 0 { if col.Scale > 0 { - typeStr = fmt.Sprintf("%s(%d,%d)", col.Type, col.Precision, col.Scale) + typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale) } else { - typeStr = fmt.Sprintf("%s(%d)", col.Type, col.Precision) + typeStr = fmt.Sprintf("%s(%d)", baseType, col.Precision) } } parts = append(parts, typeStr) @@ -488,7 +490,7 @@ func (w *Writer) writeCreateTables(schema *models.Schema) error { columnDefs := make([]string, 0, len(columns)) for _, col := range columns { - colDef := fmt.Sprintf(" %s %s", col.SQLName(), col.Type) + colDef := fmt.Sprintf(" %s %s", col.SQLName(), pgsql.ConvertSQLType(col.Type)) // Add default value if present if col.Default != nil && col.Default != "" { diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go index 097a93d..c159904 100644 --- a/pkg/writers/pgsql/writer_test.go +++ b/pkg/writers/pgsql/writer_test.go @@ -241,3 +241,67 @@ func TestIsIntegerType(t *testing.T) { } } } + +func TestTypeConversion(t *testing.T) { + // Test that invalid Go types are converted to valid PostgreSQL types + db := models.InitDatabase("testdb") + schema := models.InitSchema("public") + + // Create a test table with Go types instead of SQL types + table := models.InitTable("test_types", "public") + + // Add columns with Go types (invalid for PostgreSQL) + stringCol := models.InitColumn("name", "test_types", "public") + stringCol.Type = "string" // Should be converted to "text" + table.Columns["name"] = stringCol + + int64Col := models.InitColumn("big_id", "test_types", "public") + int64Col.Type = "int64" // Should be converted to "bigint" + table.Columns["big_id"] = int64Col + + int16Col := models.InitColumn("small_id", "test_types", "public") + int16Col.Type = "int16" // Should be converted to "smallint" + table.Columns["small_id"] = int16Col + + schema.Tables = append(schema.Tables, table) + db.Schemas = append(db.Schemas, schema) + + // Create writer with output to buffer + var buf bytes.Buffer + options := &writers.WriterOptions{} + writer := NewWriter(options) + writer.writer = &buf + + // Write the database + err := writer.WriteDatabase(db) + if err != nil { + t.Fatalf("WriteDatabase failed: %v", err) + } + + output := buf.String() + + // Print output for debugging + t.Logf("Generated SQL:\n%s", output) + + // Verify that Go types were converted to PostgreSQL types + if strings.Contains(output, "string") { + t.Errorf("Output contains 'string' type - should be converted to 'text'\nFull output:\n%s", output) + } + if strings.Contains(output, "int64") { + t.Errorf("Output contains 'int64' type - should be converted to 'bigint'\nFull output:\n%s", output) + } + if strings.Contains(output, "int16") { + t.Errorf("Output contains 'int16' type - should be converted to 'smallint'\nFull output:\n%s", output) + } + + // Verify correct PostgreSQL types are present + if !strings.Contains(output, "text") { + t.Errorf("Output missing 'text' type (converted from 'string')\nFull output:\n%s", output) + } + if !strings.Contains(output, "bigint") { + t.Errorf("Output missing 'bigint' type (converted from 'int64')\nFull output:\n%s", output) + } + if !strings.Contains(output, "smallint") { + t.Errorf("Output missing 'smallint' type (converted from 'int16')\nFull output:\n%s", output) + } +}