// TypeSpec describes PostgreSQL type capabilities used by parsers/writers.
type TypeSpec struct {
	// SupportsLength marks types that take a single length/dimension
	// modifier, e.g. varchar(255) or vector(1536).
	SupportsLength bool
	// SupportsPrecision marks types that take a precision modifier
	// (optionally followed by a scale), e.g. numeric(10,2).
	SupportsPrecision bool
}

// postgresBaseTypes is the registry of canonical type names, keyed by the
// lowercase, single-space-separated spelling.
var postgresBaseTypes = map[string]TypeSpec{
	// Numeric types
	"smallint":         {},
	"integer":          {},
	"bigint":           {},
	"decimal":          {SupportsPrecision: true},
	"numeric":          {SupportsPrecision: true},
	"real":             {},
	"double precision": {},
	"smallserial":      {},
	"serial":           {},
	"bigserial":        {},
	"money":            {},

	// Character types
	"char":              {SupportsLength: true},
	"character":         {SupportsLength: true},
	"varchar":           {SupportsLength: true},
	"character varying": {SupportsLength: true},
	"text":              {},
	"name":              {},

	// Binary
	"bytea": {},

	// Date/time
	"timestamp":                   {SupportsPrecision: true},
	"timestamp without time zone": {SupportsPrecision: true},
	"timestamp with time zone":    {SupportsPrecision: true},
	"time":                        {SupportsPrecision: true},
	"time without time zone":      {SupportsPrecision: true},
	"time with time zone":         {SupportsPrecision: true},
	"date":                        {},
	"interval":                    {SupportsPrecision: true},

	// Boolean
	"boolean": {},

	// Geometric
	"point":   {},
	"line":    {},
	"lseg":    {},
	"box":     {},
	"path":    {},
	"polygon": {},
	"circle":  {},

	// Network
	"cidr":     {},
	"inet":     {},
	"macaddr":  {},
	"macaddr8": {},

	// Bit string
	// Note: "varbit" is also listed in postgresTypeAliases; because
	// CanonicalizeBaseType resolves aliases first, lookups for "varbit"
	// land on the "bit varying" entry.
	"bit":         {SupportsLength: true},
	"bit varying": {SupportsLength: true},
	"varbit":      {SupportsLength: true},

	// Text search
	"tsvector": {},
	"tsquery":  {},

	// UUID/XML/JSON
	"uuid":  {},
	"xml":   {},
	"json":  {},
	"jsonb": {},

	// Range
	"int4range":      {},
	"int8range":      {},
	"numrange":       {},
	"tsrange":        {},
	"tstzrange":      {},
	"daterange":      {},
	"int4multirange": {},
	"int8multirange": {},
	"nummultirange":  {},
	"tsmultirange":   {},
	"tstzmultirange": {},
	"datemultirange": {},

	// Object identifier
	"oid":      {},
	"regclass": {},
	"regproc":  {},
	"regtype":  {},

	// Pseudo-ish/common built-ins seen in schemas
	"record": {},
	"void":   {},

	// Common extensions
	"citext":    {},
	"hstore":    {},
	"ltree":     {},
	"lquery":    {},
	"ltxtquery": {},
	"vector":    {SupportsLength: true}, // pgvector: vector(dim)
	"halfvec":   {SupportsLength: true}, // pgvector: halfvec(dim)
	"sparsevec": {SupportsLength: true}, // pgvector: sparsevec(dim)
}

// postgresTypeAliases maps alternative spellings to the canonical key used in
// postgresBaseTypes.
var postgresTypeAliases = map[string]string{
	// Integer aliases
	"int2": "smallint",
	"int4": "integer",
	"int8": "bigint",
	"int":  "integer",

	// Serial aliases
	"serial2": "smallserial",
	"serial4": "serial",
	"serial8": "bigserial",

	// Character aliases
	"bpchar": "char",

	// Float aliases
	"float4": "real",
	"float8": "double precision",
	"float":  "double precision",

	// Time aliases
	"timestamptz": "timestamp with time zone",
	"timetz":      "time with time zone",

	// Bit alias
	"varbit": "bit varying",

	// Boolean alias
	"bool": "boolean",
}

// GetPostgresBaseTypes returns the names of all registered base types
// (aliases are not included).
//
// The slice is freshly allocated on every call. Its order follows Go map
// iteration and is therefore unspecified and may differ between calls; sort
// the result if a stable order is required.
func GetPostgresBaseTypes() []string {
	result := make([]string, 0, len(postgresBaseTypes))
	for t := range postgresBaseTypes {
		result = append(result, t)
	}
	return result
}

// GetPostgresTypes returns the registered PostgreSQL types.
// When includeArrays is true, the base names are followed by one array
// variant ("type[]") per base type, in the same (unspecified) order.
func GetPostgresTypes(includeArrays bool) []string {
	base := GetPostgresBaseTypes()
	if !includeArrays {
		return base
	}

	result := make([]string, 0, len(base)*2)
	result = append(result, base...)
	for _, t := range base {
		result = append(result, t+"[]")
	}
	return result
}

// ExtractBaseType returns the type without outer array suffixes and modifiers.
// Words that follow a modifier are kept, so the result is still a complete
// type name.
// Examples:
//   - varchar(255) -> varchar
//   - text[] -> text
//   - numeric(10,2)[] -> numeric
//   - integer[3][3] -> integer
//   - timestamp(6) with time zone -> timestamp with time zone
func ExtractBaseType(sqlType string) string {
	return stripTypeModifier(stripArraySuffixes(normalizeTypeToken(sqlType)))
}

// ExtractBaseTypeLower is ExtractBaseType with lowercase normalization.
func ExtractBaseTypeLower(sqlType string) string {
	return strings.ToLower(ExtractBaseType(sqlType))
}

// IsArrayType reports whether the SQL type ends in one or more array
// suffixes, either unsized ("[]") or sized ("[3]").
func IsArrayType(sqlType string) bool {
	_, ok := trimArraySuffix(normalizeTypeToken(sqlType))
	return ok
}

// ElementType returns the underlying element type for array types: every
// trailing array suffix is removed while modifiers are kept
// (varchar(20)[][] -> varchar(20)).
// For non-array types it returns the input with surrounding whitespace
// trimmed and internal whitespace runs collapsed to single spaces.
func ElementType(sqlType string) string {
	return stripArraySuffixes(normalizeTypeToken(sqlType))
}

// CanonicalizeBaseType resolves aliases to canonical PostgreSQL type names.
// The input is whitespace-normalized and lowercased first; names without an
// alias entry are returned in that normalized form.
func CanonicalizeBaseType(baseType string) string {
	base := strings.ToLower(normalizeTypeToken(baseType))
	if canonical, ok := postgresTypeAliases[base]; ok {
		return canonical
	}
	return base
}

// IsKnownPostgresType reports whether a type (including array forms) exists in the registry.
func IsKnownPostgresType(sqlType string) bool {
	_, ok := lookupTypeSpec(sqlType)
	return ok
}

// SupportsLength reports if this SQL type accepts a single length/dimension modifier.
// Unknown types report false.
func SupportsLength(sqlType string) bool {
	spec, ok := lookupTypeSpec(sqlType)
	return ok && spec.SupportsLength
}

// SupportsPrecision reports if this SQL type accepts precision (and possibly scale).
// Unknown types report false.
func SupportsPrecision(sqlType string) bool {
	spec, ok := lookupTypeSpec(sqlType)
	return ok && spec.SupportsPrecision
}

// HasExplicitTypeModifier reports if the type already includes "(...)".
// Only the presence of an opening parenthesis is checked.
func HasExplicitTypeModifier(sqlType string) bool {
	return strings.Contains(sqlType, "(")
}

// lookupTypeSpec finds the registry entry for sqlType after removing array
// suffixes and modifiers, lowercasing, and resolving aliases.
func lookupTypeSpec(sqlType string) (TypeSpec, bool) {
	spec, ok := postgresBaseTypes[CanonicalizeBaseType(ExtractBaseTypeLower(sqlType))]
	return spec, ok
}

// stripTypeModifier removes the first parenthesized modifier from t and
// rejoins the text around it with a single space.
func stripTypeModifier(t string) string {
	open := strings.Index(t, "(")
	if open <= 0 {
		// No modifier, or nothing in front of it: leave the token alone.
		return t
	}
	head := strings.TrimSpace(t[:open])

	closeRel := strings.Index(t[open:], ")")
	if closeRel < 0 {
		// Unterminated modifier: keep only what precedes it.
		return head
	}
	tail := strings.TrimSpace(t[open+closeRel+1:])
	if tail == "" || strings.ContainsAny(tail, "()") {
		// Nothing meaningful follows (or the remainder is malformed).
		return head
	}
	return head + " " + tail
}

// trimArraySuffix removes one trailing array suffix from t. The suffix may be
// unsized ("[]") or carry a numeric bound ("[3]"). The boolean result is
// false, and t is returned unchanged, when t does not end in such a suffix.
func trimArraySuffix(t string) (string, bool) {
	if !strings.HasSuffix(t, "]") {
		return t, false
	}
	open := strings.LastIndex(t, "[")
	if open < 0 {
		return t, false
	}
	bound := strings.TrimSpace(t[open+1 : len(t)-1])
	for i := 0; i < len(bound); i++ {
		if bound[i] < '0' || bound[i] > '9' {
			return t, false
		}
	}
	return strings.TrimSpace(t[:open]), true
}

// stripArraySuffixes removes every trailing array suffix from t.
func stripArraySuffixes(t string) string {
	for {
		rest, ok := trimArraySuffix(t)
		if !ok {
			return t
		}
		t = rest
	}
}

// normalizeTypeToken trims t and collapses each internal whitespace run to a
// single space.
func normalizeTypeToken(t string) string {
	return strings.Join(strings.Fields(strings.TrimSpace(t)), " ")
}
"numeric", + wantKnown: true, + wantPrecision: true, + }, + { + input: "int4", + wantBase: "int4", + wantCanonicalBase: "integer", + wantKnown: true, + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + base := ExtractBaseTypeLower(tt.input) + if base != tt.wantBase { + t.Fatalf("ExtractBaseTypeLower(%q) = %q, want %q", tt.input, base, tt.wantBase) + } + + canonical := CanonicalizeBaseType(base) + if canonical != tt.wantCanonicalBase { + t.Fatalf("CanonicalizeBaseType(%q) = %q, want %q", base, canonical, tt.wantCanonicalBase) + } + + if IsArrayType(tt.input) != tt.wantArray { + t.Fatalf("IsArrayType(%q) = %v, want %v", tt.input, IsArrayType(tt.input), tt.wantArray) + } + if IsKnownPostgresType(tt.input) != tt.wantKnown { + t.Fatalf("IsKnownPostgresType(%q) = %v, want %v", tt.input, IsKnownPostgresType(tt.input), tt.wantKnown) + } + if SupportsLength(tt.input) != tt.wantLength { + t.Fatalf("SupportsLength(%q) = %v, want %v", tt.input, SupportsLength(tt.input), tt.wantLength) + } + if SupportsPrecision(tt.input) != tt.wantPrecision { + t.Fatalf("SupportsPrecision(%q) = %v, want %v", tt.input, SupportsPrecision(tt.input), tt.wantPrecision) + } + }) + } +} diff --git a/pkg/readers/bun/reader.go b/pkg/readers/bun/reader.go index b3281de..3dfebaf 100644 --- a/pkg/readers/bun/reader.go +++ b/pkg/readers/bun/reader.go @@ -12,6 +12,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/readers" ) @@ -700,16 +701,21 @@ func (r *Reader) extractBunTag(tag string) string { // parseTypeWithLength parses a type string and extracts length if present // e.g., "varchar(255)" returns ("varchar", 255) func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) { + typeStr = strings.TrimSpace(typeStr) + baseType = typeStr + // Check for type with length: varchar(255), char(10), etc. 
re := regexp.MustCompile(`^([a-zA-Z\s]+)\((\d+)\)$`) matches := re.FindStringSubmatch(typeStr) if len(matches) == 3 { - if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil { - baseType = strings.TrimSpace(matches[1]) - return + rawBaseType := strings.TrimSpace(matches[1]) + if pgsql.SupportsLength(rawBaseType) { + if _, err := fmt.Sscanf(matches[2], "%d", &length); err == nil { + return + } } } - baseType = typeStr + return } diff --git a/pkg/readers/bun/reader_test.go b/pkg/readers/bun/reader_test.go index 10fb64c..a5f2c33 100644 --- a/pkg/readers/bun/reader_test.go +++ b/pkg/readers/bun/reader_test.go @@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) { if !emailCol.NotNull { t.Error("Column 'email' should be NOT NULL (explicit 'notnull' tag)") } - if emailCol.Type != "varchar" || emailCol.Length != 255 { - t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) + if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" { + t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) + } + if emailCol.Length != 255 { + t.Errorf("Expected email length 255, got %d", emailCol.Length) } // Verify name column - primitive string type should be NOT NULL by default in Bun @@ -356,6 +359,33 @@ func TestReader_ReadDatabase_Complex(t *testing.T) { } } +func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) { + reader := &Reader{} + + tests := []struct { + input string + wantType string + wantLength int + }{ + {"varchar(255)", "varchar(255)", 255}, + {"character varying(120)", "character varying(120)", 120}, + {"vector(1536)", "vector(1536)", 1536}, + {"numeric(10,2)", "numeric(10,2)", 0}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + gotType, gotLength := reader.parseTypeWithLength(tt.input) + if gotType != tt.wantType { + t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, 
gotType, tt.wantType) + } + if gotLength != tt.wantLength { + t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength) + } + }) + } +} + func TestReader_ReadSchema(t *testing.T) { opts := &readers.ReaderOptions{ FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"), @@ -485,9 +515,9 @@ func TestReader_NullableTypes(t *testing.T) { // Test all nullability scenarios tests := []struct { - column string - notNull bool - reason string + column string + notNull bool + reason string }{ {"id", true, "primary key"}, {"user_id", true, "explicit notnull tag"}, diff --git a/pkg/readers/dbml/reader.go b/pkg/readers/dbml/reader.go index ea378f9..96fd59c 100644 --- a/pkg/readers/dbml/reader.go +++ b/pkg/readers/dbml/reader.go @@ -567,110 +567,182 @@ func (r *Reader) parseDBML(content string) (*models.Database, error) { // parseColumn parses a DBML column definition func (r *Reader) parseColumn(line, tableName, schemaName string) (*models.Column, *models.Constraint) { // Format: column_name type [attributes] // comment - parts := strings.Fields(line) - if len(parts) < 2 { + lineNoComment, inlineComment := splitInlineComment(line) + signature, attrs := splitColumnSignatureAndAttrs(lineNoComment) + columnName, columnType, ok := parseColumnSignature(signature) + if !ok { return nil, nil } - columnName := stripQuotes(parts[0]) - columnType := stripQuotes(parts[1]) - column := models.InitColumn(columnName, tableName, schemaName) column.Type = columnType var constraint *models.Constraint // Parse attributes in brackets - if strings.Contains(line, "[") && strings.Contains(line, "]") { - attrStart := strings.Index(line, "[") - attrEnd := strings.Index(line, "]") - if attrStart < attrEnd { - attrs := line[attrStart+1 : attrEnd] - attrList := strings.Split(attrs, ",") + if attrs != "" { + attrList := strings.Split(attrs, ",") - for _, attr := range attrList { - attr = strings.TrimSpace(attr) + for _, attr := range attrList { + 
attr = strings.TrimSpace(attr) - if strings.Contains(attr, "primary key") || attr == "pk" { - column.IsPrimaryKey = true - column.NotNull = true - } else if strings.Contains(attr, "not null") { - column.NotNull = true - } else if attr == "increment" { - column.AutoIncrement = true - } else if strings.HasPrefix(attr, "default:") { - defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:")) - column.Default = strings.Trim(defaultVal, "'\"") - } else if attr == "unique" { - // Create a unique constraint - // Clean table name by removing leading underscores to avoid double underscores - cleanTableName := strings.TrimLeft(tableName, "_") - uniqueConstraint := models.InitConstraint( - fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName), - models.UniqueConstraint, - ) - uniqueConstraint.Schema = schemaName - uniqueConstraint.Table = tableName - uniqueConstraint.Columns = []string{columnName} - // Store it to be added later - if constraint == nil { - constraint = uniqueConstraint - } - } else if strings.HasPrefix(attr, "note:") { - // Parse column note/comment - note := strings.TrimSpace(strings.TrimPrefix(attr, "note:")) - column.Comment = strings.Trim(note, "'\"") - } else if strings.HasPrefix(attr, "ref:") { - // Parse inline reference - // DBML semantics depend on context: - // - On FK column: ref: < target means "this FK references target" - // - On PK column: ref: < source means "source references this PK" (reverse notation) - refStr := strings.TrimSpace(strings.TrimPrefix(attr, "ref:")) - - // Check relationship direction operator - refOp := strings.TrimSpace(refStr) - var isReverse bool - if strings.HasPrefix(refOp, "<") { - // < means "is referenced by" - only makes sense on PK columns - isReverse = column.IsPrimaryKey - } - // > means "references" - always a forward FK, never reverse - - constraint = r.parseRef(refStr) - if constraint != nil { - if isReverse { - // Reverse: parsed ref is SOURCE, current column is TARGET - // Constraint should be ON 
the source table - constraint.Schema = constraint.ReferencedSchema - constraint.Table = constraint.ReferencedTable - constraint.Columns = constraint.ReferencedColumns - constraint.ReferencedSchema = schemaName - constraint.ReferencedTable = tableName - constraint.ReferencedColumns = []string{columnName} - } else { - // Forward: current column is SOURCE, parsed ref is TARGET - // Standard FK: constraint is ON current table - constraint.Schema = schemaName - constraint.Table = tableName - constraint.Columns = []string{columnName} - } - // Generate constraint name based on table and columns - constraint.Name = fmt.Sprintf("fk_%s_%s", constraint.Table, strings.Join(constraint.Columns, "_")) + if strings.Contains(attr, "primary key") || attr == "pk" { + column.IsPrimaryKey = true + column.NotNull = true + } else if strings.Contains(attr, "not null") { + column.NotNull = true + } else if attr == "increment" { + column.AutoIncrement = true + } else if strings.HasPrefix(attr, "default:") { + defaultVal := strings.TrimSpace(strings.TrimPrefix(attr, "default:")) + column.Default = strings.Trim(defaultVal, "'\"") + } else if attr == "unique" { + // Create a unique constraint + // Clean table name by removing leading underscores to avoid double underscores + cleanTableName := strings.TrimLeft(tableName, "_") + uniqueConstraint := models.InitConstraint( + fmt.Sprintf("ukey_%s_%s", cleanTableName, columnName), + models.UniqueConstraint, + ) + uniqueConstraint.Schema = schemaName + uniqueConstraint.Table = tableName + uniqueConstraint.Columns = []string{columnName} + // Store it to be added later + if constraint == nil { + constraint = uniqueConstraint + } + } else if strings.HasPrefix(attr, "note:") { + // Parse column note/comment + note := strings.TrimSpace(strings.TrimPrefix(attr, "note:")) + column.Comment = strings.Trim(note, "'\"") + } else if strings.HasPrefix(attr, "ref:") { + // Parse inline reference + // DBML semantics depend on context: + // - On FK column: ref: < 
// splitInlineComment separates a DBML line into its definition part and its
// trailing "//" comment.
//
// It returns the text before the comment marker and the comment text, both
// trimmed. When the line has no comment it returns the line unchanged and an
// empty comment.
//
// A "//" inside a quoted value (single quotes, double quotes or backticks) is
// not a comment marker, so settings such as
//
//	url varchar [default: 'http://example.com']
//
// keep their attribute block intact.
func splitInlineComment(line string) (string, string) {
	commentStart := inlineCommentIndex(line)
	if commentStart == -1 {
		return line, ""
	}

	return strings.TrimSpace(line[:commentStart]), strings.TrimSpace(line[commentStart+2:])
}

// inlineCommentIndex returns the byte offset of the first "//" that lies
// outside a quoted string, or -1 when there is none.
func inlineCommentIndex(line string) int {
	var quote byte // active quote character; 0 while outside a string
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch {
		case quote != 0:
			if c == '\\' {
				i++ // skip the escaped character
			} else if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"' || c == '`':
			quote = c
		case c == '/' && strings.HasPrefix(line[i:], "//"):
			return i
		}
	}

	if quote != 0 {
		// Unbalanced quotes (e.g. a stray apostrophe): the quote tracking is
		// unreliable, so fall back to the first "//" anywhere in the line.
		return strings.Index(line, "//")
	}
	return -1
}
!strings.HasSuffix(trimmed, "]") { + return trimmed, "" + } + + bracketDepth := 0 + for i := len(trimmed) - 1; i >= 0; i-- { + switch trimmed[i] { + case ']': + bracketDepth++ + case '[': + bracketDepth-- + if bracketDepth == 0 { + // DBML attributes are a trailing [ ... ] block preceded by whitespace. + // This avoids confusing array types like text[] with attribute blocks. + if i > 0 && (trimmed[i-1] == ' ' || trimmed[i-1] == '\t') { + return strings.TrimSpace(trimmed[:i]), strings.TrimSpace(trimmed[i+1 : len(trimmed)-1]) + } + } + } + } + + return trimmed, "" +} + +func parseColumnSignature(signature string) (string, string, bool) { + signature = strings.TrimSpace(signature) + if signature == "" { + return "", "", false + } + + var splitAt int + if signature[0] == '"' || signature[0] == '\'' { + quote := signature[0] + splitAt = 1 + for splitAt < len(signature) { + if signature[splitAt] == quote { + splitAt++ + break + } + splitAt++ + } + } else { + for splitAt < len(signature) && signature[splitAt] != ' ' && signature[splitAt] != '\t' { + splitAt++ + } + } + + if splitAt <= 0 || splitAt >= len(signature) { + return "", "", false + } + + columnName := stripQuotes(strings.TrimSpace(signature[:splitAt])) + columnType := stripWrappingQuotes(strings.TrimSpace(signature[splitAt:])) + if columnName == "" || columnType == "" { + return "", "", false + } + + return columnName, columnType, true +} + +func stripWrappingQuotes(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) { + return s[1 : len(s)-1] + } + return s +} + // parseIndex parses a DBML index definition func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index { // Format: (columns) [attributes] OR columnname [attributes] diff --git a/pkg/readers/dbml/reader_test.go b/pkg/readers/dbml/reader_test.go index 1e360dc..e9e39e3 100644 --- a/pkg/readers/dbml/reader_test.go +++ 
b/pkg/readers/dbml/reader_test.go @@ -839,6 +839,67 @@ func TestConstraintNaming(t *testing.T) { } } +func TestParseColumn_PostgresTypes(t *testing.T) { + reader := &Reader{} + + tests := []struct { + name string + line string + wantName string + wantType string + wantNotNull bool + wantComment string + }{ + { + name: "array type with attrs", + line: "tags text[] [not null]", + wantName: "tags", + wantType: "text[]", + wantNotNull: true, + }, + { + name: "vector with dimension", + line: "embedding vector(1536)", + wantName: "embedding", + wantType: "vector(1536)", + }, + { + name: "multi word timestamp type", + line: "published_at timestamp with time zone", + wantName: "published_at", + wantType: "timestamp with time zone", + }, + { + name: "array type with inline comment", + line: "labels varchar(20)[] // column labels", + wantName: "labels", + wantType: "varchar(20)[]", + wantComment: "column labels", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + col, _ := reader.parseColumn(tt.line, "events", "public") + if col == nil { + t.Fatalf("parseColumn() returned nil column") + } + if col.Name != tt.wantName { + t.Errorf("column name = %q, want %q", col.Name, tt.wantName) + } + if col.Type != tt.wantType { + t.Errorf("column type = %q, want %q", col.Type, tt.wantType) + } + if col.NotNull != tt.wantNotNull { + t.Errorf("column not null = %v, want %v", col.NotNull, tt.wantNotNull) + } + if col.Comment != tt.wantComment { + t.Errorf("column comment = %q, want %q", col.Comment, tt.wantComment) + } + }) + } +} + func getKeys[V any](m map[string]V) []string { keys := make([]string, 0, len(m)) for k := range m { diff --git a/pkg/readers/drawdb/reader.go b/pkg/readers/drawdb/reader.go index f42f2c8..33d36d0 100644 --- a/pkg/readers/drawdb/reader.go +++ b/pkg/readers/drawdb/reader.go @@ -8,6 +8,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" 
"git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb" ) @@ -231,30 +232,35 @@ func (r *Reader) convertToColumn(field *drawdb.DrawDBField, tableName, schemaNam // Parse type and dimensions typeStr := field.Type + typeStr = strings.TrimSpace(typeStr) column.Type = typeStr // Try to extract length/precision from type string like "varchar(255)" or "decimal(10,2)" if strings.Contains(typeStr, "(") { parts := strings.Split(typeStr, "(") - column.Type = parts[0] + baseType := strings.TrimSpace(parts[0]) if len(parts) > 1 { dimensions := strings.TrimSuffix(parts[1], ")") if strings.Contains(dimensions, ",") { - // Precision and scale (e.g., decimal(10,2)) - dims := strings.Split(dimensions, ",") - if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil { - column.Precision = precision - } - if len(dims) > 1 { - if scale, err := strconv.Atoi(strings.TrimSpace(dims[1])); err == nil { - column.Scale = scale + // Precision and scale (e.g., decimal(10,2), numeric(10,2)) + if pgsql.SupportsPrecision(baseType) { + dims := strings.Split(dimensions, ",") + if precision, err := strconv.Atoi(strings.TrimSpace(dims[0])); err == nil { + column.Precision = precision + } + if len(dims) > 1 { + if scale, err := strconv.Atoi(strings.TrimSpace(dims[1])); err == nil { + column.Scale = scale + } } } } else { // Just length (e.g., varchar(255)) - if length, err := strconv.Atoi(dimensions); err == nil { - column.Length = length + if pgsql.SupportsLength(baseType) { + if length, err := strconv.Atoi(dimensions); err == nil { + column.Length = length + } } } } diff --git a/pkg/readers/drawdb/reader_test.go b/pkg/readers/drawdb/reader_test.go index fbdf1ab..3ac2911 100644 --- a/pkg/readers/drawdb/reader_test.go +++ b/pkg/readers/drawdb/reader_test.go @@ -6,6 +6,7 @@ import ( "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/readers" + "git.warky.dev/wdevs/relspecgo/pkg/writers/drawdb" ) func 
TestReader_ReadDatabase_Simple(t *testing.T) { @@ -288,6 +289,61 @@ func TestReader_ReadDatabase_Complex(t *testing.T) { } } +func TestConvertToColumn_PreservesExplicitTypeModifiers(t *testing.T) { + reader := &Reader{} + + tests := []struct { + name string + fieldType string + wantType string + wantLength int + wantPrecision int + wantScale int + }{ + { + name: "varchar with length", + fieldType: "varchar(255)", + wantType: "varchar(255)", + wantLength: 255, + }, + { + name: "numeric precision/scale", + fieldType: "numeric(10,2)", + wantType: "numeric(10,2)", + wantPrecision: 10, + wantScale: 2, + }, + { + name: "custom vector modifier", + fieldType: "vector(1536)", + wantType: "vector(1536)", + wantLength: 1536, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + field := &drawdb.DrawDBField{ + Name: tt.name, + Type: tt.fieldType, + } + col := reader.convertToColumn(field, "events", "public") + if col.Type != tt.wantType { + t.Fatalf("column type = %q, want %q", col.Type, tt.wantType) + } + if col.Length != tt.wantLength { + t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength) + } + if col.Precision != tt.wantPrecision { + t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision) + } + if col.Scale != tt.wantScale { + t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale) + } + }) + } +} + func TestReader_ReadSchema(t *testing.T) { opts := &readers.ReaderOptions{ FilePath: filepath.Join("..", "..", "..", "tests", "assets", "drawdb", "simple.json"), diff --git a/pkg/readers/gorm/reader.go b/pkg/readers/gorm/reader.go index a9aba30..b2bc698 100644 --- a/pkg/readers/gorm/reader.go +++ b/pkg/readers/gorm/reader.go @@ -12,6 +12,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/readers" ) @@ -784,11 +785,14 @@ func (r *Reader) extractGormTag(tag string) string { // parseTypeWithLength parses a type string 
and extracts length if present // e.g., "varchar(255)" returns ("varchar", 255) func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length int) { + typeStr = strings.TrimSpace(typeStr) + baseType = typeStr + // Check for type with length: varchar(255), char(10), etc. // Also handle precision/scale: numeric(10,2) if strings.Contains(typeStr, "(") { idx := strings.Index(typeStr, "(") - baseType = strings.TrimSpace(typeStr[:idx]) + rawBaseType := strings.TrimSpace(typeStr[:idx]) // Extract numbers from parentheses parens := typeStr[idx+1:] @@ -796,14 +800,15 @@ func (r *Reader) parseTypeWithLength(typeStr string) (baseType string, length in parens = parens[:endIdx] } - // For now, just handle single number (length) - if !strings.Contains(parens, ",") { + // Only treat as "length" for text-ish SQL types. + // This avoids converting custom modifiers like vector(1536) into Length. + if pgsql.SupportsLength(rawBaseType) && !strings.Contains(parens, ",") { if _, err := fmt.Sscanf(parens, "%d", &length); err == nil { return } } } - baseType = typeStr + return } diff --git a/pkg/readers/gorm/reader_test.go b/pkg/readers/gorm/reader_test.go index 76f53d0..b3bd9bc 100644 --- a/pkg/readers/gorm/reader_test.go +++ b/pkg/readers/gorm/reader_test.go @@ -71,8 +71,11 @@ func TestReader_ReadDatabase_Simple(t *testing.T) { if !emailCol.NotNull { t.Error("Column 'email' should be NOT NULL (explicit 'not null' tag)") } - if emailCol.Type != "varchar" || emailCol.Length != 255 { - t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) + if emailCol.Type != "varchar" && emailCol.Type != "varchar(255)" { + t.Errorf("Expected email type 'varchar' or 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) + } + if emailCol.Length != 255 { + t.Errorf("Expected email length 255, got %d", emailCol.Length) } // Verify name column - primitive string type should be NOT NULL by default @@ -363,6 +366,33 @@ func 
TestReader_ReadDatabase_Complex(t *testing.T) { } } +func TestParseTypeWithLength_PreservesExplicitTypeModifiers(t *testing.T) { + reader := &Reader{} + + tests := []struct { + input string + wantType string + wantLength int + }{ + {"varchar(255)", "varchar(255)", 255}, + {"character varying(120)", "character varying(120)", 120}, + {"vector(1536)", "vector(1536)", 1536}, + {"numeric(10,2)", "numeric(10,2)", 0}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + gotType, gotLength := reader.parseTypeWithLength(tt.input) + if gotType != tt.wantType { + t.Fatalf("parseTypeWithLength(%q) type = %q, want %q", tt.input, gotType, tt.wantType) + } + if gotLength != tt.wantLength { + t.Fatalf("parseTypeWithLength(%q) length = %d, want %d", tt.input, gotLength, tt.wantLength) + } + }) + } +} + func TestReader_ReadSchema(t *testing.T) { opts := &readers.ReaderOptions{ FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"), diff --git a/pkg/readers/pgsql/queries.go b/pkg/readers/pgsql/queries.go index 668cecb..e43d417 100644 --- a/pkg/readers/pgsql/queries.go +++ b/pkg/readers/pgsql/queries.go @@ -206,8 +206,19 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models. c.numeric_precision, c.numeric_scale, c.udt_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) as formatted_data_type, col_description((c.table_schema||'.'||c.table_name)::regclass, c.ordinal_position) as description FROM information_schema.columns c + JOIN pg_catalog.pg_namespace n + ON n.nspname = c.table_schema + JOIN pg_catalog.pg_class cls + ON cls.relname = c.table_name + AND cls.relnamespace = n.oid + JOIN pg_catalog.pg_attribute a + ON a.attrelid = cls.oid + AND a.attname = c.column_name + AND a.attnum > 0 + AND NOT a.attisdropped WHERE c.table_schema = $1 ORDER BY c.table_schema, c.table_name, c.ordinal_position ` @@ -221,12 +232,12 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models. 
columnsMap := make(map[string]map[string]*models.Column) for rows.Next() { - var schema, tableName, columnName, isNullable, dataType, udtName string + var schema, tableName, columnName, isNullable, dataType, udtName, formattedDataType string var ordinalPosition int var columnDefault, description *string var charMaxLength, numPrecision, numScale *int - if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &description); err != nil { + if err := rows.Scan(&schema, &tableName, &columnName, &ordinalPosition, &columnDefault, &isNullable, &dataType, &charMaxLength, &numPrecision, &numScale, &udtName, &formattedDataType, &description); err != nil { return nil, err } @@ -246,7 +257,7 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models. } // Map data type, preserving serial types when detected - column.Type = r.mapDataType(dataType, udtName, hasNextval) + column.Type = r.mapDataType(dataType, udtName, formattedDataType, hasNextval) column.NotNull = (isNullable == "NO") column.Sequence = uint(ordinalPosition) diff --git a/pkg/readers/pgsql/reader.go b/pkg/readers/pgsql/reader.go index a60cd44..fd4bc4b 100644 --- a/pkg/readers/pgsql/reader.go +++ b/pkg/readers/pgsql/reader.go @@ -259,12 +259,14 @@ func (r *Reader) close() { } } -// mapDataType maps PostgreSQL data types to canonical types -func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string { +// mapDataType maps PostgreSQL data types while preserving exact type text when available. 
+func (r *Reader) mapDataType(pgType, udtName, formattedType string, hasNextval bool) string { + normalizedPGType := strings.ToLower(strings.TrimSpace(pgType)) + // If the column has a nextval default, it's likely a serial type // Map to the appropriate serial type instead of the base integer type if hasNextval { - switch strings.ToLower(pgType) { + switch normalizedPGType { case "integer", "int", "int4": return "serial" case "bigint", "int8": @@ -274,6 +276,17 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string { } } + // Prefer the database-provided formatted type; this preserves arrays/custom + // types/modifiers like text[], vector(1536), numeric(10,2), etc. + if strings.TrimSpace(formattedType) != "" { + return formattedType + } + + // information_schema reports arrays generically as "ARRAY" with udt_name like "_text". + if strings.EqualFold(pgType, "ARRAY") && strings.HasPrefix(udtName, "_") && len(udtName) > 1 { + return udtName[1:] + "[]" + } + // Map common PostgreSQL types typeMap := map[string]string{ "integer": "integer", @@ -320,7 +333,7 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string { } // Try mapped type first - if mapped, exists := typeMap[pgType]; exists { + if mapped, exists := typeMap[normalizedPGType]; exists { return mapped } @@ -329,8 +342,11 @@ func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string { return pgsql.GetSQLType(pgType) } - // Return UDT name for custom types + // Return UDT name for custom types (including array fallback when needed) if udtName != "" { + if strings.HasPrefix(udtName, "_") && len(udtName) > 1 { + return udtName[1:] + "[]" + } return udtName } diff --git a/pkg/readers/pgsql/reader_test.go b/pkg/readers/pgsql/reader_test.go index e496b47..e7eb09c 100644 --- a/pkg/readers/pgsql/reader_test.go +++ b/pkg/readers/pgsql/reader_test.go @@ -173,35 +173,39 @@ func TestMapDataType(t *testing.T) { reader := &Reader{} tests := []struct { - 
pgType string - udtName string - expected string + pgType string + udtName string + formattedType string + expected string }{ - {"integer", "int4", "integer"}, - {"bigint", "int8", "bigint"}, - {"smallint", "int2", "smallint"}, - {"character varying", "varchar", "varchar"}, - {"text", "text", "text"}, - {"boolean", "bool", "boolean"}, - {"timestamp without time zone", "timestamp", "timestamp"}, - {"timestamp with time zone", "timestamptz", "timestamptz"}, - {"json", "json", "json"}, - {"jsonb", "jsonb", "jsonb"}, - {"uuid", "uuid", "uuid"}, - {"numeric", "numeric", "numeric"}, - {"real", "float4", "real"}, - {"double precision", "float8", "double precision"}, - {"date", "date", "date"}, - {"time without time zone", "time", "time"}, - {"bytea", "bytea", "bytea"}, - {"unknown_type", "custom", "custom"}, // Should return UDT name + {"integer", "int4", "", "integer"}, + {"bigint", "int8", "", "bigint"}, + {"smallint", "int2", "", "smallint"}, + {"character varying", "varchar", "", "varchar"}, + {"text", "text", "", "text"}, + {"boolean", "bool", "", "boolean"}, + {"timestamp without time zone", "timestamp", "", "timestamp"}, + {"timestamp with time zone", "timestamptz", "", "timestamptz"}, + {"json", "json", "", "json"}, + {"jsonb", "jsonb", "", "jsonb"}, + {"uuid", "uuid", "", "uuid"}, + {"numeric", "numeric", "", "numeric"}, + {"real", "float4", "", "real"}, + {"double precision", "float8", "", "double precision"}, + {"date", "date", "", "date"}, + {"time without time zone", "time", "", "time"}, + {"bytea", "bytea", "", "bytea"}, + {"unknown_type", "custom", "", "custom"}, // Should return UDT name + {"ARRAY", "_text", "", "text[]"}, + {"USER-DEFINED", "vector", "vector(1536)", "vector(1536)"}, + {"character varying", "varchar", "character varying(255)", "character varying(255)"}, } for _, tt := range tests { t.Run(tt.pgType, func(t *testing.T) { - result := reader.mapDataType(tt.pgType, tt.udtName, false) + result := reader.mapDataType(tt.pgType, tt.udtName, 
tt.formattedType, false) if result != tt.expected { - t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected) + t.Errorf("mapDataType(%s, %s, %s) = %s, expected %s", tt.pgType, tt.udtName, tt.formattedType, result, tt.expected) } }) } @@ -218,9 +222,9 @@ func TestMapDataType(t *testing.T) { for _, tt := range serialTests { t.Run(tt.pgType+"_with_nextval", func(t *testing.T) { - result := reader.mapDataType(tt.pgType, "", true) + result := reader.mapDataType(tt.pgType, "", "", true) if result != tt.expected { - t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected) + t.Errorf("mapDataType(%s, '', '', true) = %s, expected %s", tt.pgType, result, tt.expected) } }) } @@ -230,63 +234,63 @@ func TestParseIndexDefinition(t *testing.T) { reader := &Reader{} tests := []struct { - name string - indexName string - tableName string - schema string - indexDef string - wantType string - wantUnique bool + name string + indexName string + tableName string + schema string + indexDef string + wantType string + wantUnique bool wantColumns int }{ { - name: "simple btree index", - indexName: "idx_users_email", - tableName: "users", - schema: "public", - indexDef: "CREATE INDEX idx_users_email ON public.users USING btree (email)", - wantType: "btree", - wantUnique: false, + name: "simple btree index", + indexName: "idx_users_email", + tableName: "users", + schema: "public", + indexDef: "CREATE INDEX idx_users_email ON public.users USING btree (email)", + wantType: "btree", + wantUnique: false, wantColumns: 1, }, { - name: "unique index", - indexName: "idx_users_username", - tableName: "users", - schema: "public", - indexDef: "CREATE UNIQUE INDEX idx_users_username ON public.users USING btree (username)", - wantType: "btree", - wantUnique: true, + name: "unique index", + indexName: "idx_users_username", + tableName: "users", + schema: "public", + indexDef: "CREATE UNIQUE INDEX idx_users_username ON public.users USING 
btree (username)", + wantType: "btree", + wantUnique: true, wantColumns: 1, }, { - name: "composite index", - indexName: "idx_users_name", - tableName: "users", - schema: "public", - indexDef: "CREATE INDEX idx_users_name ON public.users USING btree (first_name, last_name)", - wantType: "btree", - wantUnique: false, + name: "composite index", + indexName: "idx_users_name", + tableName: "users", + schema: "public", + indexDef: "CREATE INDEX idx_users_name ON public.users USING btree (first_name, last_name)", + wantType: "btree", + wantUnique: false, wantColumns: 2, }, { - name: "gin index", - indexName: "idx_posts_tags", - tableName: "posts", - schema: "public", - indexDef: "CREATE INDEX idx_posts_tags ON public.posts USING gin (tags)", - wantType: "gin", - wantUnique: false, + name: "gin index", + indexName: "idx_posts_tags", + tableName: "posts", + schema: "public", + indexDef: "CREATE INDEX idx_posts_tags ON public.posts USING gin (tags)", + wantType: "gin", + wantUnique: false, wantColumns: 1, }, { - name: "partial index with where clause", - indexName: "idx_users_active", - tableName: "users", - schema: "public", - indexDef: "CREATE INDEX idx_users_active ON public.users USING btree (id) WHERE (active = true)", - wantType: "btree", - wantUnique: false, + name: "partial index with where clause", + indexName: "idx_users_active", + tableName: "users", + schema: "public", + indexDef: "CREATE INDEX idx_users_active ON public.users USING btree (id) WHERE (active = true)", + wantType: "btree", + wantUnique: false, wantColumns: 1, }, } diff --git a/pkg/readers/typeorm/reader.go b/pkg/readers/typeorm/reader.go index 3e1f01c..660a8db 100644 --- a/pkg/readers/typeorm/reader.go +++ b/pkg/readers/typeorm/reader.go @@ -5,9 +5,11 @@ import ( "fmt" "os" "regexp" + "strconv" "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/readers" ) @@ -549,6 +551,41 @@ func (r *Reader) 
parseColumnOptions(decorator string, column *models.Column, tab } } + // Preserve explicit type modifiers from options where present. + // Example: @Column({ type: 'varchar', length: 255 }) -> varchar(255) + if column.Type != "" && !strings.Contains(column.Type, "(") { + lengthRegex := regexp.MustCompile(`length:\s*(\d+)`) + precisionRegex := regexp.MustCompile(`precision:\s*(\d+)`) + scaleRegex := regexp.MustCompile(`scale:\s*(\d+)`) + + baseType := strings.ToLower(strings.TrimSpace(column.Type)) + + if pgsql.SupportsLength(baseType) { + if matches := lengthRegex.FindStringSubmatch(content); len(matches) == 2 { + if n, err := strconv.Atoi(matches[1]); err == nil && n > 0 { + column.Length = n + column.Type = fmt.Sprintf("%s(%d)", column.Type, n) + } + } + } + + if pgsql.SupportsPrecision(baseType) { + if matches := precisionRegex.FindStringSubmatch(content); len(matches) == 2 { + if p, err := strconv.Atoi(matches[1]); err == nil && p > 0 { + column.Precision = p + if sm := scaleRegex.FindStringSubmatch(content); len(sm) == 2 { + if s, err := strconv.Atoi(sm[1]); err == nil && s >= 0 { + column.Scale = s + column.Type = fmt.Sprintf("%s(%d,%d)", column.Type, p, s) + } + } else { + column.Type = fmt.Sprintf("%s(%d)", column.Type, p) + } + } + } + } + } + if strings.Contains(content, "nullable: true") || strings.Contains(content, "nullable:true") { column.NotNull = false } diff --git a/pkg/readers/typeorm/reader_test.go b/pkg/readers/typeorm/reader_test.go new file mode 100644 index 0000000..4e98c5a --- /dev/null +++ b/pkg/readers/typeorm/reader_test.go @@ -0,0 +1,60 @@ +package typeorm + +import ( + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" +) + +func TestParseColumnOptions_PreservesTypeModifiers(t *testing.T) { + reader := &Reader{} + table := models.InitTable("users", "public") + + tests := []struct { + name string + decorator string + wantType string + wantLength int + wantPrecision int + wantScale int + }{ + { + name: "varchar with length", + 
decorator: `@Column({ type: 'varchar', length: 255 })`, + wantType: "varchar(255)", + wantLength: 255, + }, + { + name: "numeric with precision and scale", + decorator: `@Column({ type: 'numeric', precision: 10, scale: 2 })`, + wantType: "numeric(10,2)", + wantPrecision: 10, + wantScale: 2, + }, + { + name: "custom type with explicit modifier is preserved", + decorator: `@Column({ type: 'vector(1536)' })`, + wantType: "vector(1536)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + col := models.InitColumn("sample", table.Name, table.Schema) + reader.parseColumnOptions(tt.decorator, col, table) + + if col.Type != tt.wantType { + t.Fatalf("column type = %q, want %q", col.Type, tt.wantType) + } + if col.Length != tt.wantLength { + t.Fatalf("column length = %d, want %d", col.Length, tt.wantLength) + } + if col.Precision != tt.wantPrecision { + t.Fatalf("column precision = %d, want %d", col.Precision, tt.wantPrecision) + } + if col.Scale != tt.wantScale { + t.Fatalf("column scale = %d, want %d", col.Scale, tt.wantScale) + } + }) + } +} diff --git a/pkg/writers/bun/type_mapper.go b/pkg/writers/bun/type_mapper.go index 5767849..e7c136b 100644 --- a/pkg/writers/bun/type_mapper.go +++ b/pkg/writers/bun/type_mapper.go @@ -5,6 +5,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/writers" ) @@ -39,14 +40,7 @@ func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string { // extractBaseType extracts the base type from a SQL type string func (tm *TypeMapper) extractBaseType(sqlType string) string { - sqlType = strings.ToLower(strings.TrimSpace(sqlType)) - - // Remove everything after '(' - if idx := strings.Index(sqlType, "("); idx > 0 { - sqlType = sqlType[:idx] - } - - return sqlType + return pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType)) } // isSimpleType checks if a type should use base Go type when NOT NULL @@ -184,9 
+178,10 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st if column.Type != "" { // Sanitize type to remove backticks typeStr := writers.SanitizeStructTagValue(column.Type) - if column.Length > 0 { + hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr) + if !hasExplicitTypeModifier && column.Length > 0 { typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length) - } else if column.Precision > 0 { + } else if !hasExplicitTypeModifier && column.Precision > 0 { if column.Scale > 0 { typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale) } else { diff --git a/pkg/writers/bun/writer_test.go b/pkg/writers/bun/writer_test.go index 2a7862a..287b122 100644 --- a/pkg/writers/bun/writer_test.go +++ b/pkg/writers/bun/writer_test.go @@ -698,3 +698,23 @@ func TestTypeMapper_BuildBunTag(t *testing.T) { }) } } + +func TestTypeMapper_BuildBunTag_PreservesExplicitTypeModifiers(t *testing.T) { + mapper := NewTypeMapper() + + col := &models.Column{ + Name: "embedding", + Type: "vector(1536)", + Length: 1536, + Precision: 0, + Scale: 0, + } + + tag := mapper.BuildBunTag(col, nil) + if !strings.Contains(tag, "type:vector(1536),") { + t.Fatalf("expected explicit modifier to be preserved, got %q", tag) + } + if strings.Contains(tag, ")(") { + t.Fatalf("type modifier appears duplicated in %q", tag) + } +} diff --git a/pkg/writers/drizzle/type_mapper.go b/pkg/writers/drizzle/type_mapper.go index 97998bd..475d3df 100644 --- a/pkg/writers/drizzle/type_mapper.go +++ b/pkg/writers/drizzle/type_mapper.go @@ -5,6 +5,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" ) // TypeMapper handles SQL to Drizzle type conversions @@ -18,7 +19,7 @@ func NewTypeMapper() *TypeMapper { // SQLTypeToDrizzle converts SQL types to Drizzle column type functions // Returns the Drizzle column constructor (e.g., "integer", "varchar", "text") func (tm *TypeMapper) SQLTypeToDrizzle(sqlType 
string) string { - sqlTypeLower := strings.ToLower(sqlType) + sqlTypeLower := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType)) // PostgreSQL type mapping to Drizzle typeMap := map[string]string{ @@ -87,13 +88,6 @@ func (tm *TypeMapper) SQLTypeToDrizzle(sqlType string) string { return drizzleType } - // Check for partial matches (e.g., "varchar(255)" -> "varchar") - for sqlPattern, drizzleType := range typeMap { - if strings.HasPrefix(sqlTypeLower, sqlPattern) { - return drizzleType - } - } - // Default to text for unknown types return "text" } diff --git a/pkg/writers/gorm/type_mapper.go b/pkg/writers/gorm/type_mapper.go index 097d503..a61e643 100644 --- a/pkg/writers/gorm/type_mapper.go +++ b/pkg/writers/gorm/type_mapper.go @@ -5,6 +5,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/writers" ) @@ -39,14 +40,7 @@ func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string { // extractBaseType extracts the base type from a SQL type string // Examples: varchar(100) → varchar, numeric(10,2) → numeric func (tm *TypeMapper) extractBaseType(sqlType string) string { - sqlType = strings.ToLower(strings.TrimSpace(sqlType)) - - // Remove everything after '(' - if idx := strings.Index(sqlType, "("); idx > 0 { - sqlType = sqlType[:idx] - } - - return sqlType + return pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType)) } // baseGoType returns the base Go type for a SQL type (not null) @@ -209,9 +203,10 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s // Include length, precision, scale if present // Sanitize type to remove backticks typeStr := writers.SanitizeStructTagValue(column.Type) - if column.Length > 0 { + hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(typeStr) + if !hasExplicitTypeModifier && column.Length > 0 { typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length) - } else if 
column.Precision > 0 { + } else if !hasExplicitTypeModifier && column.Precision > 0 { if column.Scale > 0 { typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale) } else { diff --git a/pkg/writers/gorm/writer_test.go b/pkg/writers/gorm/writer_test.go index 65ab0e6..b90a67a 100644 --- a/pkg/writers/gorm/writer_test.go +++ b/pkg/writers/gorm/writer_test.go @@ -14,12 +14,12 @@ func TestWriter_WriteTable(t *testing.T) { // Create a simple table table := models.InitTable("users", "public") table.Columns["id"] = &models.Column{ - Name: "id", - Type: "bigint", - NotNull: true, - IsPrimaryKey: true, + Name: "id", + Type: "bigint", + NotNull: true, + IsPrimaryKey: true, AutoIncrement: true, - Sequence: 1, + Sequence: 1, } table.Columns["email"] = &models.Column{ Name: "email", @@ -444,10 +444,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) { // Verify all has-many relationships have unique names hasManyExpectations := []string{ - "RelRIDAPIProviderOrgLogins", // Has many via Login + "RelRIDAPIProviderOrgLogins", // Has many via Login "RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer - "RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent - "RelRIDOwner", // Belongs to via rid_owner + "RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent + "RelRIDOwner", // Belongs to via rid_owner } for _, exp := range hasManyExpectations { @@ -669,3 +669,23 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) { }) } } + +func TestTypeMapper_BuildGormTag_PreservesExplicitTypeModifiers(t *testing.T) { + mapper := NewTypeMapper() + + col := &models.Column{ + Name: "embedding", + Type: "vector(1536)", + Length: 1536, + Precision: 0, + Scale: 0, + } + + tag := mapper.BuildGormTag(col, nil) + if !strings.Contains(tag, "type:vector(1536)") { + t.Fatalf("expected explicit modifier to be preserved, got %q", tag) + } + if strings.Contains(tag, ")(") { + t.Fatalf("type modifier appears duplicated in %q", tag) + } +} diff --git 
a/pkg/writers/graphql/type_mapping.go b/pkg/writers/graphql/type_mapping.go index c252cea..142a95e 100644 --- a/pkg/writers/graphql/type_mapping.go +++ b/pkg/writers/graphql/type_mapping.go @@ -4,6 +4,7 @@ import ( "strings" "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/pgsql" ) func (w *Writer) sqlTypeToGraphQL(sqlType string, column *models.Column, table *models.Table, schema *models.Schema) string { @@ -33,12 +34,11 @@ func (w *Writer) sqlTypeToGraphQL(sqlType string, column *models.Column, table * } // Standard type mappings - baseType := strings.Split(sqlType, "(")[0] // Remove length/precision - baseType = strings.TrimSpace(baseType) + baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType)) // Handle array types - if strings.HasSuffix(baseType, "[]") { - elemType := strings.TrimSuffix(baseType, "[]") + if pgsql.IsArrayType(sqlType) { + elemType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(pgsql.ElementType(sqlType))) gqlType := w.mapBaseTypeToGraphQL(elemType) return "[" + gqlType + "]" } @@ -108,8 +108,7 @@ func (w *Writer) sqlTypeToCustomScalar(sqlType string) string { "date": "Date", } - baseType := strings.Split(sqlType, "(")[0] - baseType = strings.TrimSpace(baseType) + baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType)) if scalar, ok := scalarMap[baseType]; ok { return scalar @@ -132,8 +131,7 @@ func (w *Writer) isIntegerType(sqlType string) bool { "smallserial": true, } - baseType := strings.Split(sqlType, "(")[0] - baseType = strings.TrimSpace(baseType) + baseType := pgsql.CanonicalizeBaseType(pgsql.ExtractBaseTypeLower(sqlType)) return intTypes[baseType] } diff --git a/pkg/writers/pgsql/writer.go b/pkg/writers/pgsql/writer.go index 726be72..cd588b2 100644 --- a/pkg/writers/pgsql/writer.go +++ b/pkg/writers/pgsql/writer.go @@ -493,18 +493,19 @@ func (w *Writer) generateColumnDefinition(col *models.Column) string { // Type with length/precision - convert to valid 
PostgreSQL type baseType := pgsql.ConvertSQLType(col.Type) typeStr := baseType + hasExplicitTypeModifier := pgsql.HasExplicitTypeModifier(baseType) // Only add size specifiers for types that support them - if col.Length > 0 && col.Precision == 0 { - if supportsLength(baseType) { + if !hasExplicitTypeModifier && col.Length > 0 && col.Precision == 0 { + if pgsql.SupportsLength(baseType) { typeStr = fmt.Sprintf("%s(%d)", baseType, col.Length) } else if isTextTypeWithoutLength(baseType) { // Convert text with length to varchar typeStr = fmt.Sprintf("varchar(%d)", col.Length) } // For types that don't support length (integer, bigint, etc.), ignore the length - } else if col.Precision > 0 { - if supportsPrecision(baseType) { + } else if !hasExplicitTypeModifier && col.Precision > 0 { + if pgsql.SupportsPrecision(baseType) { if col.Scale > 0 { typeStr = fmt.Sprintf("%s(%d,%d)", baseType, col.Precision, col.Scale) } else { @@ -1268,30 +1269,6 @@ func isTextType(colType string) bool { return false } -// supportsLength checks if a PostgreSQL type supports length specification -func supportsLength(colType string) bool { - lengthTypes := []string{"varchar", "character varying", "char", "character", "bit", "bit varying", "varbit"} - lowerType := strings.ToLower(colType) - for _, t := range lengthTypes { - if lowerType == t || strings.HasPrefix(lowerType, t+"(") { - return true - } - } - return false -} - -// supportsPrecision checks if a PostgreSQL type supports precision/scale specification -func supportsPrecision(colType string) bool { - precisionTypes := []string{"numeric", "decimal", "time", "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone", "time with time zone", "time without time zone", "interval"} - lowerType := strings.ToLower(colType) - for _, t := range precisionTypes { - if lowerType == t || strings.HasPrefix(lowerType, t+"(") { - return true - } - } - return false -} - // isTextTypeWithoutLength checks if type is text (which 
should convert to varchar when length is specified) func isTextTypeWithoutLength(colType string) bool { return strings.EqualFold(colType, "text") diff --git a/pkg/writers/pgsql/writer_test.go b/pkg/writers/pgsql/writer_test.go index 06adcdc..776f6f1 100644 --- a/pkg/writers/pgsql/writer_test.go +++ b/pkg/writers/pgsql/writer_test.go @@ -426,11 +426,11 @@ func TestWriteAllConstraintTypes(t *testing.T) { // Verify all constraint types are present expectedConstraints := map[string]string{ - "Primary Key": "PRIMARY KEY", - "Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)", - "Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)", - "Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))", - "Foreign Key": "FOREIGN KEY", + "Primary Key": "PRIMARY KEY", + "Unique": "ADD CONSTRAINT uq_order_number UNIQUE (order_number)", + "Check (total)": "ADD CONSTRAINT ck_total_positive CHECK (total > 0)", + "Check (status)": "ADD CONSTRAINT ck_status_valid CHECK (status IN ('pending', 'completed', 'cancelled'))", + "Foreign Key": "FOREIGN KEY", } for name, expected := range expectedConstraints { @@ -715,11 +715,11 @@ func TestColumnSizeSpecifiers(t *testing.T) { // Verify valid patterns ARE present validPatterns := []string{ - "integer", // without size - "bigint", // without size - "smallint", // without size - "varchar(100)", // text converted to varchar with length - "varchar(50)", // varchar with length + "integer", // without size + "bigint", // without size + "smallint", // without size + "varchar(100)", // text converted to varchar with length + "varchar(50)", // varchar with length "decimal(19,4)", // decimal with precision and scale } for _, pattern := range validPatterns { @@ -729,6 +729,56 @@ func TestColumnSizeSpecifiers(t *testing.T) { } } +func TestGenerateColumnDefinition_PreservesExplicitTypeModifiers(t *testing.T) { + writer := NewWriter(&writers.WriterOptions{}) + + cases := []struct { + 
name string + colType string + length int + precision int + scale int + wantType string + }{ + { + name: "character varying already includes length", + colType: "character varying(50)", + length: 50, + wantType: "character varying(50)", + }, + { + name: "numeric already includes precision", + colType: "numeric(10,2)", + precision: 10, + scale: 2, + wantType: "numeric(10,2)", + }, + { + name: "custom vector modifier preserved", + colType: "vector(1536)", + wantType: "vector(1536)", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + col := models.InitColumn("sample", "events", "public") + col.Type = tc.colType + col.Length = tc.length + col.Precision = tc.precision + col.Scale = tc.scale + + def := writer.generateColumnDefinition(col) + if !strings.Contains(def, " "+tc.wantType+" ") && !strings.HasSuffix(def, " "+tc.wantType) { + t.Fatalf("generated definition %q does not contain expected type %q", def, tc.wantType) + } + if strings.Contains(def, ")(") { + t.Fatalf("generated definition %q appears to duplicate modifiers", def) + } + }) + } +} + func TestGenerateAddColumnStatements(t *testing.T) { // Create a test database with tables that have new columns db := models.InitDatabase("testdb")