8 Commits

Author SHA1 Message Date
Hein
3d9cc7ec58 .
All checks were successful
Release / Build and Release (push) Successful in -25m33s
2026-02-20 16:32:19 +02:00
Hein
480038d51d feat(writers): quote default values based on SQL column type
Some checks failed
CI / Test (1.24) (push) Successful in -22m47s
CI / Lint (push) Failing after -24m34s
Integration Tests / Integration Tests (push) Successful in -25m0s
CI / Test (1.25) (push) Successful in -22m35s
CI / Build (push) Successful in -24m43s
Release / Build and Release (push) Successful in -21m46s
Bun and GORM struct tags now emit quoted defaults for string/date/time/UUID
columns (e.g. default:'disconnected') and unquoted defaults for numeric and
boolean columns (e.g. default:0, default:true). Function-call expressions
such as now() or gen_random_uuid() are never quoted regardless of type.

Adds QuoteDefaultValue(value, sqlType) helper in pkg/writers and updates
both type mappers and the bun writer tests accordingly.
2026-02-20 16:03:50 +02:00
77436757c8 fix(type_mapper): update timestamp type mapping to use SqlTimeStamp
All checks were successful
CI / Test (1.24) (push) Successful in -25m13s
CI / Test (1.25) (push) Successful in -25m10s
CI / Build (push) Successful in -26m2s
CI / Lint (push) Successful in -25m39s
Release / Build and Release (push) Successful in -25m49s
Integration Tests / Integration Tests (push) Successful in -25m26s
2026-02-08 21:35:27 +02:00
5e6f03e412 feat(type_mapper): add support for serial types and auto-increment tags
All checks were successful
CI / Build (push) Successful in -25m39s
Integration Tests / Integration Tests (push) Successful in -25m15s
CI / Test (1.24) (push) Successful in -24m39s
CI / Test (1.25) (push) Successful in -24m24s
CI / Lint (push) Successful in -25m9s
Release / Build and Release (push) Successful in -25m21s
2026-02-08 17:48:58 +02:00
1dcbc79387 feat(pgsql): enhance data type mapping to support serial types
All checks were successful
CI / Test (1.25) (push) Successful in -24m18s
CI / Test (1.24) (push) Successful in -24m6s
CI / Build (push) Successful in -25m14s
CI / Lint (push) Successful in -24m47s
Release / Build and Release (push) Successful in -25m37s
Integration Tests / Integration Tests (push) Successful in -25m9s
2026-02-08 17:31:28 +02:00
59c4a5ebf8 test(writer): enhance has-many relationship tests with join tag verification
All checks were successful
CI / Test (1.24) (push) Successful in -25m9s
CI / Test (1.25) (push) Successful in -25m0s
CI / Build (push) Successful in -25m57s
CI / Lint (push) Successful in -25m29s
Release / Build and Release (push) Successful in -25m38s
Integration Tests / Integration Tests (push) Successful in -25m19s
2026-02-08 15:20:20 +02:00
091e1913ee feat(version): retrieve version and build date from VCS if unset
All checks were successful
CI / Test (1.24) (push) Successful in -25m19s
CI / Test (1.25) (push) Successful in -25m1s
CI / Build (push) Successful in -25m56s
CI / Lint (push) Successful in -25m33s
Integration Tests / Integration Tests (push) Successful in -25m32s
2026-02-08 15:04:03 +02:00
0e6e94797c feat(version): add version command to display version and build date
All checks were successful
CI / Test (1.24) (push) Successful in -25m14s
CI / Test (1.25) (push) Successful in -25m10s
CI / Build (push) Successful in -26m0s
CI / Lint (push) Successful in -25m38s
Release / Build and Release (push) Successful in -25m46s
Integration Tests / Integration Tests (push) Successful in -25m13s
2026-02-08 14:58:39 +02:00
13 changed files with 329 additions and 88 deletions

View File

@@ -25,6 +25,7 @@ jobs:
id: get_version id: get_version
run: | run: |
echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
echo "BUILD_DATE=$(date -u '+%Y-%m-%d %H:%M:%S UTC')" >> $GITHUB_OUTPUT
echo "Version: ${GITHUB_REF#refs/tags/}" echo "Version: ${GITHUB_REF#refs/tags/}"
- name: Build binaries for multiple platforms - name: Build binaries for multiple platforms
@@ -32,19 +33,19 @@ jobs:
mkdir -p dist mkdir -p dist
# Linux AMD64 # Linux AMD64
GOOS=linux GOARCH=amd64 go build -o dist/relspec-linux-amd64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec GOOS=linux GOARCH=amd64 go build -o dist/relspec-linux-amd64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
# Linux ARM64 # Linux ARM64
GOOS=linux GOARCH=arm64 go build -o dist/relspec-linux-arm64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec GOOS=linux GOARCH=arm64 go build -o dist/relspec-linux-arm64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
# macOS AMD64 # macOS AMD64
GOOS=darwin GOARCH=amd64 go build -o dist/relspec-darwin-amd64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec GOOS=darwin GOARCH=amd64 go build -o dist/relspec-darwin-amd64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
# macOS ARM64 (Apple Silicon) # macOS ARM64 (Apple Silicon)
GOOS=darwin GOARCH=arm64 go build -o dist/relspec-darwin-arm64 -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec GOOS=darwin GOARCH=arm64 go build -o dist/relspec-darwin-arm64 -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
# Windows AMD64 # Windows AMD64
GOOS=windows GOARCH=amd64 go build -o dist/relspec-windows-amd64.exe -ldflags "-X main.version=${{ steps.get_version.outputs.VERSION }}" ./cmd/relspec GOOS=windows GOARCH=amd64 go build -o dist/relspec-windows-amd64.exe -ldflags "-X 'main.version=${{ steps.get_version.outputs.VERSION }}' -X 'main.buildDate=${{ steps.get_version.outputs.BUILD_DATE }}'" ./cmd/relspec
# Create checksums # Create checksums
cd dist cd dist

View File

@@ -14,6 +14,11 @@ GOGET=$(GOCMD) get
GOMOD=$(GOCMD) mod GOMOD=$(GOCMD) mod
GOCLEAN=$(GOCMD) clean GOCLEAN=$(GOCMD) clean
# Version information
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILD_DATE := $(shell date -u +"%Y-%m-%d %H:%M:%S UTC")
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.buildDate=$(BUILD_DATE)'
# Auto-detect container runtime (Docker or Podman) # Auto-detect container runtime (Docker or Podman)
CONTAINER_RUNTIME := $(shell \ CONTAINER_RUNTIME := $(shell \
if command -v podman > /dev/null 2>&1; then \ if command -v podman > /dev/null 2>&1; then \
@@ -37,9 +42,9 @@ COMPOSE_CMD := $(shell \
all: lint test build ## Run linting, tests, and build all: lint test build ## Run linting, tests, and build
build: deps ## Build the binary build: deps ## Build the binary
@echo "Building $(BINARY_NAME)..." @echo "Building $(BINARY_NAME) $(VERSION)..."
@mkdir -p $(BUILD_DIR) @mkdir -p $(BUILD_DIR)
$(GOBUILD) -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/relspec $(GOBUILD) -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/relspec
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)" @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)"
test: test-unit ## Run all unit tests (alias for test-unit) test: test-unit ## Run all unit tests (alias for test-unit)
@@ -91,8 +96,8 @@ clean: ## Clean build artifacts
@echo "Clean complete" @echo "Clean complete"
install: ## Install the binary to $GOPATH/bin install: ## Install the binary to $GOPATH/bin
@echo "Installing $(BINARY_NAME)..." @echo "Installing $(BINARY_NAME) $(VERSION)..."
$(GOCMD) install ./cmd/relspec $(GOCMD) install -ldflags "$(LDFLAGS)" ./cmd/relspec
@echo "Install complete" @echo "Install complete"
deps: ## Download dependencies deps: ## Download dependencies

View File

@@ -1,9 +1,49 @@
package main package main
import ( import (
"fmt"
"runtime/debug"
"time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var (
// Version information, set via ldflags during build
version = "dev"
buildDate = "unknown"
)
func init() {
// If version wasn't set via ldflags, try to get it from build info
if version == "dev" {
if info, ok := debug.ReadBuildInfo(); ok {
// Try to get version from VCS
var vcsRevision, vcsTime string
for _, setting := range info.Settings {
switch setting.Key {
case "vcs.revision":
if len(setting.Value) >= 7 {
vcsRevision = setting.Value[:7]
}
case "vcs.time":
vcsTime = setting.Value
}
}
if vcsRevision != "" {
version = vcsRevision
}
if vcsTime != "" {
if t, err := time.Parse(time.RFC3339, vcsTime); err == nil {
buildDate = t.UTC().Format("2006-01-02 15:04:05 UTC")
}
}
}
}
}
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
Use: "relspec", Use: "relspec",
Short: "RelSpec - Database schema conversion and analysis tool", Short: "RelSpec - Database schema conversion and analysis tool",
@@ -13,6 +53,9 @@ bidirectional conversion between various database schema formats.
It reads database schemas from multiple sources (live databases, DBML, It reads database schemas from multiple sources (live databases, DBML,
DCTX, DrawDB, etc.) and writes them to various formats (GORM, Bun, DCTX, DrawDB, etc.) and writes them to various formats (GORM, Bun,
JSON, YAML, SQL, etc.).`, JSON, YAML, SQL, etc.).`,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
fmt.Printf("RelSpec %s (built: %s)\n\n", version, buildDate)
},
} }
func init() { func init() {
@@ -24,4 +67,5 @@ func init() {
rootCmd.AddCommand(editCmd) rootCmd.AddCommand(editCmd)
rootCmd.AddCommand(mergeCmd) rootCmd.AddCommand(mergeCmd)
rootCmd.AddCommand(splitCmd) rootCmd.AddCommand(splitCmd)
rootCmd.AddCommand(versionCmd)
} }

16
cmd/relspec/version.go Normal file
View File

@@ -0,0 +1,16 @@
package main
import (
"fmt"
"github.com/spf13/cobra"
)
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print version information",
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("RelSpec %s\n", version)
fmt.Printf("Built: %s\n", buildDate)
},
}

View File

@@ -676,19 +676,8 @@ func (r *Reader) extractTableFromGormTag(tag string) (tablename string, schemaNa
// deriveTableName derives a table name from struct name // deriveTableName derives a table name from struct name
func (r *Reader) deriveTableName(structName string) string { func (r *Reader) deriveTableName(structName string) string {
// Remove "Model" prefix if present // Remove "Model" prefix if present, use the name as-is without transformation
name := strings.TrimPrefix(structName, "Model") return strings.TrimPrefix(structName, "Model")
// Convert PascalCase to snake_case
var result strings.Builder
for i, r := range name {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
} }
// parseColumn parses a struct field into a Column model // parseColumn parses a struct field into a Column model

View File

@@ -231,14 +231,13 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
} }
column := models.InitColumn(columnName, tableName, schema) column := models.InitColumn(columnName, tableName, schema)
column.Type = r.mapDataType(dataType, udtName)
column.NotNull = (isNullable == "NO")
column.Sequence = uint(ordinalPosition)
// Check if this is a serial type (has nextval default)
hasNextval := false
if columnDefault != nil { if columnDefault != nil {
// Parse default value - remove nextval for sequences
defaultVal := *columnDefault defaultVal := *columnDefault
if strings.HasPrefix(defaultVal, "nextval") { if strings.HasPrefix(defaultVal, "nextval") {
hasNextval = true
column.AutoIncrement = true column.AutoIncrement = true
column.Default = defaultVal column.Default = defaultVal
} else { } else {
@@ -246,6 +245,11 @@ func (r *Reader) queryColumns(schemaName string) (map[string]map[string]*models.
} }
} }
// Map data type, preserving serial types when detected
column.Type = r.mapDataType(dataType, udtName, hasNextval)
column.NotNull = (isNullable == "NO")
column.Sequence = uint(ordinalPosition)
if description != nil { if description != nil {
column.Description = *description column.Description = *description
} }

View File

@@ -3,6 +3,7 @@ package pgsql
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
@@ -259,33 +260,46 @@ func (r *Reader) close() {
} }
// mapDataType maps PostgreSQL data types to canonical types // mapDataType maps PostgreSQL data types to canonical types
func (r *Reader) mapDataType(pgType, udtName string) string { func (r *Reader) mapDataType(pgType, udtName string, hasNextval bool) string {
// If the column has a nextval default, it's likely a serial type
// Map to the appropriate serial type instead of the base integer type
if hasNextval {
switch strings.ToLower(pgType) {
case "integer", "int", "int4":
return "serial"
case "bigint", "int8":
return "bigserial"
case "smallint", "int2":
return "smallserial"
}
}
// Map common PostgreSQL types // Map common PostgreSQL types
typeMap := map[string]string{ typeMap := map[string]string{
"integer": "int", "integer": "integer",
"bigint": "int64", "bigint": "bigint",
"smallint": "int16", "smallint": "smallint",
"int": "int", "int": "integer",
"int2": "int16", "int2": "smallint",
"int4": "int", "int4": "integer",
"int8": "int64", "int8": "bigint",
"serial": "int", "serial": "serial",
"bigserial": "int64", "bigserial": "bigserial",
"smallserial": "int16", "smallserial": "smallserial",
"numeric": "decimal", "numeric": "numeric",
"decimal": "decimal", "decimal": "decimal",
"real": "float32", "real": "real",
"double precision": "float64", "double precision": "double precision",
"float4": "float32", "float4": "real",
"float8": "float64", "float8": "double precision",
"money": "decimal", "money": "money",
"character varying": "string", "character varying": "varchar",
"varchar": "string", "varchar": "varchar",
"character": "string", "character": "char",
"char": "string", "char": "char",
"text": "string", "text": "text",
"boolean": "bool", "boolean": "boolean",
"bool": "bool", "bool": "boolean",
"date": "date", "date": "date",
"time": "time", "time": "time",
"time without time zone": "time", "time without time zone": "time",

View File

@@ -177,20 +177,20 @@ func TestMapDataType(t *testing.T) {
udtName string udtName string
expected string expected string
}{ }{
{"integer", "int4", "int"}, {"integer", "int4", "integer"},
{"bigint", "int8", "int64"}, {"bigint", "int8", "bigint"},
{"smallint", "int2", "int16"}, {"smallint", "int2", "smallint"},
{"character varying", "varchar", "string"}, {"character varying", "varchar", "varchar"},
{"text", "text", "string"}, {"text", "text", "text"},
{"boolean", "bool", "bool"}, {"boolean", "bool", "boolean"},
{"timestamp without time zone", "timestamp", "timestamp"}, {"timestamp without time zone", "timestamp", "timestamp"},
{"timestamp with time zone", "timestamptz", "timestamptz"}, {"timestamp with time zone", "timestamptz", "timestamptz"},
{"json", "json", "json"}, {"json", "json", "json"},
{"jsonb", "jsonb", "jsonb"}, {"jsonb", "jsonb", "jsonb"},
{"uuid", "uuid", "uuid"}, {"uuid", "uuid", "uuid"},
{"numeric", "numeric", "decimal"}, {"numeric", "numeric", "numeric"},
{"real", "float4", "float32"}, {"real", "float4", "real"},
{"double precision", "float8", "float64"}, {"double precision", "float8", "double precision"},
{"date", "date", "date"}, {"date", "date", "date"},
{"time without time zone", "time", "time"}, {"time without time zone", "time", "time"},
{"bytea", "bytea", "bytea"}, {"bytea", "bytea", "bytea"},
@@ -199,12 +199,31 @@ func TestMapDataType(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.pgType, func(t *testing.T) { t.Run(tt.pgType, func(t *testing.T) {
result := reader.mapDataType(tt.pgType, tt.udtName) result := reader.mapDataType(tt.pgType, tt.udtName, false)
if result != tt.expected { if result != tt.expected {
t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected) t.Errorf("mapDataType(%s, %s) = %s, expected %s", tt.pgType, tt.udtName, result, tt.expected)
} }
}) })
} }
// Test serial type detection with hasNextval=true
serialTests := []struct {
pgType string
expected string
}{
{"integer", "serial"},
{"bigint", "bigserial"},
{"smallint", "smallserial"},
}
for _, tt := range serialTests {
t.Run(tt.pgType+"_with_nextval", func(t *testing.T) {
result := reader.mapDataType(tt.pgType, "", true)
if result != tt.expected {
t.Errorf("mapDataType(%s, '', true) = %s, expected %s", tt.pgType, result, tt.expected)
}
})
}
} }
func TestParseIndexDefinition(t *testing.T) { func TestParseIndexDefinition(t *testing.T) {

View File

@@ -62,6 +62,17 @@ func (tm *TypeMapper) isSimpleType(sqlType string) bool {
return simpleTypes[sqlType] return simpleTypes[sqlType]
} }
// isSerialType checks if a SQL type is a serial type (auto-incrementing)
func (tm *TypeMapper) isSerialType(sqlType string) bool {
baseType := tm.extractBaseType(sqlType)
serialTypes := map[string]bool{
"serial": true,
"bigserial": true,
"smallserial": true,
}
return serialTypes[baseType]
}
// baseGoType returns the base Go type for a SQL type (not null, simple types only) // baseGoType returns the base Go type for a SQL type (not null, simple types only)
func (tm *TypeMapper) baseGoType(sqlType string) string { func (tm *TypeMapper) baseGoType(sqlType string) string {
typeMap := map[string]string{ typeMap := map[string]string{
@@ -122,10 +133,10 @@ func (tm *TypeMapper) bunGoType(sqlType string) string {
"decimal": tm.sqlTypesAlias + ".SqlFloat64", "decimal": tm.sqlTypesAlias + ".SqlFloat64",
// Date/Time types // Date/Time types
"timestamp": tm.sqlTypesAlias + ".SqlTime", "timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime", "timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime", "timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
"timestamptz": tm.sqlTypesAlias + ".SqlTime", "timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
"date": tm.sqlTypesAlias + ".SqlDate", "date": tm.sqlTypesAlias + ".SqlDate",
"time": tm.sqlTypesAlias + ".SqlTime", "time": tm.sqlTypesAlias + ".SqlTime",
"time without time zone": tm.sqlTypesAlias + ".SqlTime", "time without time zone": tm.sqlTypesAlias + ".SqlTime",
@@ -190,10 +201,15 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
parts = append(parts, "pk") parts = append(parts, "pk")
} }
// Auto increment (for serial types or explicit auto_increment)
if column.AutoIncrement || tm.isSerialType(column.Type) {
parts = append(parts, "autoincrement")
}
// Default value // Default value
if column.Default != nil { if column.Default != nil {
// Sanitize default value to remove backticks // Sanitize default value to remove backticks, then quote based on column type
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default)) safeDefault := writers.QuoteDefaultValue(writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default)), column.Type)
parts = append(parts, fmt.Sprintf("default:%s", safeDefault)) parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
} }
@@ -251,7 +267,15 @@ func (tm *TypeMapper) BuildRelationshipTag(constraint *models.Constraint, relTyp
if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 { if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
localCol := constraint.Columns[0] localCol := constraint.Columns[0]
foreignCol := constraint.ReferencedColumns[0] foreignCol := constraint.ReferencedColumns[0]
parts = append(parts, fmt.Sprintf("join:%s=%s", localCol, foreignCol))
// For has-many relationships, swap the columns
// has-one: join:fk_in_this_table=pk_in_other_table
// has-many: join:pk_in_this_table=fk_in_other_table
if relType == "has-many" {
parts = append(parts, fmt.Sprintf("join:%s=%s", foreignCol, localCol))
} else {
parts = append(parts, fmt.Sprintf("join:%s=%s", localCol, foreignCol))
}
} }
return strings.Join(parts, ",") return strings.Join(parts, ",")

View File

@@ -90,8 +90,8 @@ func TestWriter_WriteTable(t *testing.T) {
} }
// Verify Bun-specific elements // Verify Bun-specific elements
if !strings.Contains(generated, "bun:\"id,type:bigint,pk,") { if !strings.Contains(generated, "bun:\"id,type:bigint,pk,autoincrement,") {
t.Errorf("Missing Bun-style primary key tag") t.Errorf("Missing Bun-style primary key tag with autoincrement")
} }
} }
@@ -308,14 +308,20 @@ func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
filepointerStr := string(filepointerContent) filepointerStr := string(filepointerContent)
// Should have two different has-many relationships with unique names // Should have two different has-many relationships with unique names
hasManyExpectations := []string{ hasManyExpectations := []struct {
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request fieldName string
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response tag string
}{
{"RelRIDFilepointerRequestOrgAPIEvents", "join:id_filepointer=rid_filepointer_request"}, // Has many via rid_filepointer_request
{"RelRIDFilepointerResponseOrgAPIEvents", "join:id_filepointer=rid_filepointer_response"}, // Has many via rid_filepointer_response
} }
for _, exp := range hasManyExpectations { for _, exp := range hasManyExpectations {
if !strings.Contains(filepointerStr, exp) { if !strings.Contains(filepointerStr, exp.fieldName) {
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr) t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp.fieldName, filepointerStr)
}
if !strings.Contains(filepointerStr, exp.tag) {
t.Errorf("Missing has-many relationship join tag: %s\nGenerated:\n%s", exp.tag, filepointerStr)
} }
} }
} }
@@ -455,10 +461,10 @@ func TestWriter_MultipleHasManyRelationships(t *testing.T) {
// Verify all has-many relationships have unique names // Verify all has-many relationships have unique names
hasManyExpectations := []string{ hasManyExpectations := []string{
"RelRIDAPIProviderOrgLogins", // Has many via Login "RelRIDAPIProviderOrgLogins", // Has many via Login
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer "RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent "RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
"RelRIDOwner", // Has one via rid_owner "RelRIDOwner", // Has one via rid_owner
} }
for _, exp := range hasManyExpectations { for _, exp := range hasManyExpectations {
@@ -561,8 +567,8 @@ func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
{"bigint", false, "resolvespec_common.SqlInt64"}, {"bigint", false, "resolvespec_common.SqlInt64"},
{"varchar", true, "resolvespec_common.SqlString"}, // Bun uses sql types even for NOT NULL strings {"varchar", true, "resolvespec_common.SqlString"}, // Bun uses sql types even for NOT NULL strings
{"varchar", false, "resolvespec_common.SqlString"}, {"varchar", false, "resolvespec_common.SqlString"},
{"timestamp", true, "resolvespec_common.SqlTime"}, {"timestamp", true, "resolvespec_common.SqlTimeStamp"},
{"timestamp", false, "resolvespec_common.SqlTime"}, {"timestamp", false, "resolvespec_common.SqlTimeStamp"},
{"date", false, "resolvespec_common.SqlDate"}, {"date", false, "resolvespec_common.SqlDate"},
{"boolean", true, "bool"}, {"boolean", true, "bool"},
{"boolean", false, "resolvespec_common.SqlBool"}, {"boolean", false, "resolvespec_common.SqlBool"},
@@ -609,14 +615,75 @@ func TestTypeMapper_BuildBunTag(t *testing.T) {
want: []string{"email,", "type:varchar(255),", "nullzero,"}, want: []string{"email,", "type:varchar(255),", "nullzero,"},
}, },
{ {
name: "with default", name: "with default string",
column: &models.Column{ column: &models.Column{
Name: "status", Name: "status",
Type: "text", Type: "text",
NotNull: true, NotNull: true,
Default: "active", Default: "active",
}, },
want: []string{"status,", "type:text,", "default:active,"}, want: []string{"status,", "type:text,", "default:'active',"},
},
{
name: "with default integer",
column: &models.Column{
Name: "retries",
Type: "integer",
NotNull: true,
Default: "0",
},
want: []string{"retries,", "type:integer,", "default:0,"},
},
{
name: "with default boolean",
column: &models.Column{
Name: "active",
Type: "boolean",
NotNull: true,
Default: "true",
},
want: []string{"active,", "type:boolean,", "default:true,"},
},
{
name: "with default function call",
column: &models.Column{
Name: "created_at",
Type: "timestamp",
NotNull: true,
Default: "now()",
},
want: []string{"created_at,", "type:timestamp,", "default:now(),"},
},
{
name: "auto increment with AutoIncrement flag",
column: &models.Column{
Name: "id",
Type: "bigint",
NotNull: true,
IsPrimaryKey: true,
AutoIncrement: true,
},
want: []string{"id,", "type:bigint,", "pk,", "autoincrement,"},
},
{
name: "serial type (auto-increment)",
column: &models.Column{
Name: "id",
Type: "serial",
NotNull: true,
IsPrimaryKey: true,
},
want: []string{"id,", "type:serial,", "pk,", "autoincrement,"},
},
{
name: "bigserial type (auto-increment)",
column: &models.Column{
Name: "id",
Type: "bigserial",
NotNull: true,
IsPrimaryKey: true,
},
want: []string{"id,", "type:bigserial,", "pk,", "autoincrement,"},
}, },
} }

View File

@@ -158,10 +158,10 @@ func (tm *TypeMapper) nullableGoType(sqlType string) string {
"decimal": tm.sqlTypesAlias + ".SqlFloat64", "decimal": tm.sqlTypesAlias + ".SqlFloat64",
// Date/Time types // Date/Time types
"timestamp": tm.sqlTypesAlias + ".SqlTime", "timestamp": tm.sqlTypesAlias + ".SqlTimeStamp",
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime", "timestamp without time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime", "timestamp with time zone": tm.sqlTypesAlias + ".SqlTimeStamp",
"timestamptz": tm.sqlTypesAlias + ".SqlTime", "timestamptz": tm.sqlTypesAlias + ".SqlTimeStamp",
"date": tm.sqlTypesAlias + ".SqlDate", "date": tm.sqlTypesAlias + ".SqlDate",
"time": tm.sqlTypesAlias + ".SqlTime", "time": tm.sqlTypesAlias + ".SqlTime",
"time without time zone": tm.sqlTypesAlias + ".SqlTime", "time without time zone": tm.sqlTypesAlias + ".SqlTime",
@@ -238,8 +238,8 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
// Default value // Default value
if column.Default != nil { if column.Default != nil {
// Sanitize default value to remove backticks // Sanitize default value to remove backticks, then quote based on column type
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default)) safeDefault := writers.QuoteDefaultValue(writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default)), column.Type)
parts = append(parts, fmt.Sprintf("default:%s", safeDefault)) parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
} }

View File

@@ -655,7 +655,7 @@ func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
{"varchar", true, "string"}, {"varchar", true, "string"},
{"varchar", false, "sql_types.SqlString"}, {"varchar", false, "sql_types.SqlString"},
{"timestamp", true, "time.Time"}, {"timestamp", true, "time.Time"},
{"timestamp", false, "sql_types.SqlTime"}, {"timestamp", false, "sql_types.SqlTimeStamp"},
{"boolean", true, "bool"}, {"boolean", true, "bool"},
{"boolean", false, "sql_types.SqlBool"}, {"boolean", false, "sql_types.SqlBool"},
} }

View File

@@ -81,6 +81,64 @@ func SanitizeFilename(name string) string {
return name return name
} }
// QuoteDefaultValue wraps a sanitized default value in single quotes when the SQL
// column type requires it (strings, dates, times, UUIDs, enums). Numeric types
// (integers, floats, serials) and boolean types are left unquoted. Function-call
// expressions such as now() or gen_random_uuid() are always left unquoted regardless
// of type, because they contain parentheses.
//
// Examples (varchar): "disconnected" → "'disconnected'"
// Examples (boolean): "true" → "true"
// Examples (bigint): "0" → "0"
// Examples (timestamp): "now()" → "now()" (function call never quoted)
func QuoteDefaultValue(value, sqlType string) string {
// Function calls are never quoted regardless of column type.
if strings.Contains(value, "(") || strings.Contains(value, ")") {
return value
}
// Normalise the SQL type: lowercase, strip length/precision suffix.
baseType := strings.ToLower(strings.TrimSpace(sqlType))
if idx := strings.Index(baseType, "("); idx > 0 {
baseType = baseType[:idx]
}
// Types whose default values must NOT be quoted.
unquotedTypes := map[string]bool{
// Integer types
"integer": true,
"int": true,
"int2": true,
"int4": true,
"int8": true,
"smallint": true,
"bigint": true,
"serial": true,
"smallserial": true,
"bigserial": true,
// Float / numeric types
"real": true,
"float": true,
"float4": true,
"float8": true,
"double precision": true,
"numeric": true,
"decimal": true,
"money": true,
// Boolean
"boolean": true,
"bool": true,
}
if unquotedTypes[baseType] {
return value
}
// Everything else (text, varchar, char, uuid, date, time, timestamp, json, …)
// is treated as a quoted literal.
return "'" + value + "'"
}
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags. // SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
// Go struct tags are delimited by backticks, so any backtick in the value would break the syntax. // Go struct tags are delimited by backticks, so any backtick in the value would break the syntax.
// This function: // This function: