From aad5db51754b5f8875cebd790deecfc75811e042 Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 19 Dec 2025 22:28:24 +0200 Subject: [PATCH] fix: readers and linting issues --- Makefile | 9 + cmd/relspec/convert.go | 20 +- cmd/relspec/diff.go | 13 +- pkg/diff/diff.go | 5 +- pkg/diff/types.go | 48 +-- pkg/readers/bun/reader.go | 11 +- pkg/readers/bun/reader_test.go | 522 ++++++++++++++++++++++++++++++++ pkg/readers/gorm/reader.go | 16 +- pkg/readers/gorm/reader_test.go | 464 ++++++++++++++++++++++++++++ pkg/readers/prisma/reader.go | 28 +- pkg/readers/typeorm/reader.go | 4 +- pkg/writers/typeorm/writer.go | 16 +- tests/assets/bun/complex.go | 60 ++++ tests/assets/bun/simple.go | 18 ++ tests/assets/gorm/complex.go | 65 ++++ tests/assets/gorm/simple.go | 18 ++ 16 files changed, 1237 insertions(+), 80 deletions(-) create mode 100644 pkg/readers/bun/reader_test.go create mode 100644 pkg/readers/gorm/reader_test.go create mode 100644 tests/assets/bun/complex.go create mode 100644 tests/assets/bun/simple.go create mode 100644 tests/assets/gorm/complex.go create mode 100644 tests/assets/gorm/simple.go diff --git a/Makefile b/Makefile index a0b327a..2976aef 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,15 @@ lint: ## Run linter exit 1; \ fi +lintfix: ## Run linter + @echo "Running linter..." + @if command -v golangci-lint > /dev/null; then \ + golangci-lint run --config=.golangci.json --fix; \ + else \ + echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \ + exit 1; \ + fi + clean: ## Clean build artifacts @echo "Cleaning..." $(GOCLEAN) diff --git a/cmd/relspec/convert.go b/cmd/relspec/convert.go index ce59b1a..a04c27d 100644 --- a/cmd/relspec/convert.go +++ b/cmd/relspec/convert.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/spf13/cobra" + "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/readers" "git.warky.dev/wdevs/relspecgo/pkg/readers/bun" @@ -29,7 +31,6 @@ import ( wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma" wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm" wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml" - "github.com/spf13/cobra" ) var ( @@ -140,9 +141,18 @@ func init() { convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)") convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)") - convertCmd.MarkFlagRequired("from") - convertCmd.MarkFlagRequired("to") - convertCmd.MarkFlagRequired("to-path") + err := convertCmd.MarkFlagRequired("from") + if err != nil { + fmt.Fprintf(os.Stderr, "Error marking from flag as required: %v\n", err) + } + err = convertCmd.MarkFlagRequired("to") + if err != nil { + fmt.Fprintf(os.Stderr, "Error marking to flag as required: %v\n", err) + } + err = convertCmd.MarkFlagRequired("to-path") + if err != nil { + fmt.Fprintf(os.Stderr, "Error marking to-path flag as required: %v\n", err) + } } func runConvert(cmd *cobra.Command, args []string) error { @@ -344,7 +354,7 @@ func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaF } // For formats like DCTX that don't support full database writes, require schema filter - if strings.ToLower(dbType) == "dctx" { + if strings.EqualFold(dbType, "dctx") { if len(db.Schemas) == 0 { return fmt.Errorf("no schemas found in database") } diff --git a/cmd/relspec/diff.go b/cmd/relspec/diff.go index 168de58..a6c243e 100644 --- a/cmd/relspec/diff.go +++ b/cmd/relspec/diff.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/spf13/cobra" + "git.warky.dev/wdevs/relspecgo/pkg/diff" "git.warky.dev/wdevs/relspecgo/pkg/models" "git.warky.dev/wdevs/relspecgo/pkg/readers" @@ -15,7 +17,6 @@ import ( "git.warky.dev/wdevs/relspecgo/pkg/readers/json" "git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql" "git.warky.dev/wdevs/relspecgo/pkg/readers/yaml" - "github.com/spf13/cobra" ) var ( @@ -96,8 +97,14 @@ func init() { diffCmd.Flags().StringVar(&outputFormat, "format", "summary", "Output format (summary, json, html)") diffCmd.Flags().StringVar(&outputPath, "output", "", "Output file path (default: stdout for summary, required for json/html)") - diffCmd.MarkFlagRequired("from") - diffCmd.MarkFlagRequired("to") + err := diffCmd.MarkFlagRequired("from") + if err != nil { + fmt.Fprintf(os.Stderr, "Error marking from flag as required: %v\n", err) + } + err = diffCmd.MarkFlagRequired("to") + if err != nil { + fmt.Fprintf(os.Stderr, "Error marking to flag as required: %v\n", err) + } } func runDiff(cmd *cobra.Command, args []string) error { diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go index dbb6a55..fbfd6e2 100644 --- a/pkg/diff/diff.go +++ b/pkg/diff/diff.go @@ -2,14 +2,15 @@ package diff import ( "reflect" + "git.warky.dev/wdevs/relspecgo/pkg/models" ) // CompareDatabases compares two database models and returns the differences func CompareDatabases(source, target *models.Database) *DiffResult { result := &DiffResult{ - Source: source.Name, - Target: target.Name, + Source: source.Name, + Target: target.Name, Schemas: compareSchemas(source.Schemas, target.Schemas), } return result diff --git a/pkg/diff/types.go b/pkg/diff/types.go index 9d2fb7d..26beda9 100644 --- a/pkg/diff/types.go +++ b/pkg/diff/types.go @@ -4,8 +4,8 @@ import "git.warky.dev/wdevs/relspecgo/pkg/models" // DiffResult represents the complete difference analysis between two databases type DiffResult struct { - Source string `json:"source"` - Target string `json:"target"` + Source string `json:"source"` + Target string `json:"target"` Schemas *SchemaDiff `json:"schemas"` } @@ -18,17 +18,17 @@ type SchemaDiff struct { // SchemaChange represents changes within a schema type SchemaChange struct { - Name string `json:"name"` - Tables *TableDiff `json:"tables,omitempty"` - Views *ViewDiff `json:"views,omitempty"` - Sequences *SequenceDiff `json:"sequences,omitempty"` + Name string `json:"name"` + Tables *TableDiff `json:"tables,omitempty"` + Views *ViewDiff `json:"views,omitempty"` + Sequences *SequenceDiff `json:"sequences,omitempty"` } // TableDiff represents differences in tables type TableDiff struct { - Missing []*models.Table `json:"missing"` // Tables in source but not in target - Extra []*models.Table `json:"extra"` // Tables in target but not in source - Modified []*TableChange `json:"modified"` // Tables that exist in both but differ + Missing []*models.Table `json:"missing"` // Tables in source but not in target + Extra []*models.Table `json:"extra"` // Tables in target but not in source + Modified []*TableChange `json:"modified"` // Tables that exist in both but differ } // TableChange represents changes within a table @@ -50,16 +50,16 @@ type ColumnDiff struct { // ColumnChange represents a modified column type ColumnChange struct { - Name string `json:"name"` - Source *models.Column `json:"source"` - Target *models.Column `json:"target"` - Changes map[string]any `json:"changes"` // Map of field name to what changed + Name string `json:"name"` + Source *models.Column `json:"source"` + Target *models.Column `json:"target"` + Changes map[string]any `json:"changes"` // Map of field name to what changed } // IndexDiff represents differences in indexes type IndexDiff struct { - Missing []*models.Index `json:"missing"` // Indexes in source but not in target - Extra []*models.Index `json:"extra"` // Indexes in target but not in source + Missing []*models.Index `json:"missing"` // Indexes in source but not in target + Extra []*models.Index `json:"extra"` // Indexes in target but not in source Modified []*IndexChange `json:"modified"` // Indexes that exist in both but differ } @@ -103,8 +103,8 @@ type RelationshipChange struct { // ViewDiff represents differences in views type ViewDiff struct { - Missing []*models.View `json:"missing"` // Views in source but not in target - Extra []*models.View `json:"extra"` // Views in target but not in source + Missing []*models.View `json:"missing"` // Views in source but not in target + Extra []*models.View `json:"extra"` // Views in target but not in source Modified []*ViewChange `json:"modified"` // Views that exist in both but differ } @@ -133,14 +133,14 @@ type SequenceChange struct { // Summary provides counts for quick overview type Summary struct { - Schemas SchemaSummary `json:"schemas"` - Tables TableSummary `json:"tables"` - Columns ColumnSummary `json:"columns"` - Indexes IndexSummary `json:"indexes"` - Constraints ConstraintSummary `json:"constraints"` + Schemas SchemaSummary `json:"schemas"` + Tables TableSummary `json:"tables"` + Columns ColumnSummary `json:"columns"` + Indexes IndexSummary `json:"indexes"` + Constraints ConstraintSummary `json:"constraints"` Relationships RelationshipSummary `json:"relationships"` - Views ViewSummary `json:"views"` - Sequences SequenceSummary `json:"sequences"` + Views ViewSummary `json:"views"` + Sequences SequenceSummary `json:"sequences"` } type SchemaSummary struct { diff --git a/pkg/readers/bun/reader.go b/pkg/readers/bun/reader.go index 213302b..7b77bed 100644 --- a/pkg/readers/bun/reader.go +++ b/pkg/readers/bun/reader.go @@ -626,17 +626,14 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s // - nullzero tag means the field is nullable (can be NULL in DB) // - absence of nullzero means the field is NOT NULL // - primitive types (int64, bool, string) are NOT NULL by default + column.NotNull = true + // Primary keys are always NOT NULL + if strings.Contains(bunTag, "nullzero") { column.NotNull = false - } else if r.isNullableGoType(fieldType) { - // SqlString, SqlInt, etc. without nullzero tag means NOT NULL - column.NotNull = true } else { - // Primitive types are NOT NULL by default - column.NotNull = true + column.NotNull = !r.isNullableGoType(fieldType) } - - // Primary keys are always NOT NULL if column.IsPrimaryKey { column.NotNull = true } diff --git a/pkg/readers/bun/reader_test.go b/pkg/readers/bun/reader_test.go new file mode 100644 index 0000000..10fb64c --- /dev/null +++ b/pkg/readers/bun/reader_test.go @@ -0,0 +1,522 @@ +package bun + +import ( + "path/filepath" + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/readers" +) + +func TestReader_ReadDatabase_Simple(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + if len(db.Schemas) == 0 { + t.Fatal("Expected at least one schema") + } + + schema := db.Schemas[0] + if schema.Name != "public" { + t.Errorf("Expected schema name 'public', got '%s'", schema.Name) + } + + if len(schema.Tables) != 1 { + t.Fatalf("Expected 1 table, got %d", len(schema.Tables)) + } + + table := schema.Tables[0] + if table.Name != "users" { + t.Errorf("Expected table name 'users', got '%s'", table.Name) + } + + if len(table.Columns) != 6 { + t.Errorf("Expected 6 columns, got %d", len(table.Columns)) + } + + // Verify id column - primary key should be NOT NULL + idCol, exists := table.Columns["id"] + if !exists { + t.Fatal("Column 'id' not found") + } + if !idCol.IsPrimaryKey { + t.Error("Column 'id' should be primary key") + } + if !idCol.AutoIncrement { + t.Error("Column 'id' should be auto-increment") + } + if !idCol.NotNull { + t.Error("Column 'id' should be NOT NULL (primary keys are always NOT NULL)") + } + if idCol.Type != "bigint" { + t.Errorf("Expected id type 'bigint', got '%s'", idCol.Type) + } + + // Verify email column - explicit notnull tag should be NOT NULL + emailCol, exists := table.Columns["email"] + if !exists { + t.Fatal("Column 'email' not found") + } + if !emailCol.NotNull { + t.Error("Column 'email' should be NOT NULL (explicit 'notnull' tag)") + } + if emailCol.Type != "varchar" || emailCol.Length != 255 { + t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) + } + + // Verify name column - primitive string type should be NOT NULL by default in Bun + nameCol, exists := table.Columns["name"] + if !exists { + t.Fatal("Column 'name' not found") + } + if !nameCol.NotNull { + t.Error("Column 'name' should be NOT NULL (primitive string type, no nullzero tag)") + } + if nameCol.Type != "text" { + t.Errorf("Expected name type 'text', got '%s'", nameCol.Type) + } + + // Verify age column - pointer type should be nullable (NOT NULL = false) + ageCol, exists := table.Columns["age"] + if !exists { + t.Fatal("Column 'age' not found") + } + if ageCol.NotNull { + t.Error("Column 'age' should be nullable (pointer type *int)") + } + if ageCol.Type != "integer" { + t.Errorf("Expected age type 'integer', got '%s'", ageCol.Type) + } + + // Verify is_active column - primitive bool type should be NOT NULL by default + isActiveCol, exists := table.Columns["is_active"] + if !exists { + t.Fatal("Column 'is_active' not found") + } + if !isActiveCol.NotNull { + t.Error("Column 'is_active' should be NOT NULL (primitive bool type, no nullzero tag)") + } + if isActiveCol.Type != "boolean" { + t.Errorf("Expected is_active type 'boolean', got '%s'", isActiveCol.Type) + } + + // Verify created_at column - time.Time should be NOT NULL by default + createdAtCol, exists := table.Columns["created_at"] + if !exists { + t.Fatal("Column 'created_at' not found") + } + if !createdAtCol.NotNull { + t.Error("Column 'created_at' should be NOT NULL (time.Time is NOT NULL by default)") + } + if createdAtCol.Type != "timestamp" { + t.Errorf("Expected created_at type 'timestamp', got '%s'", createdAtCol.Type) + } + + // Verify unique index on email + if len(table.Indexes) < 1 { + t.Error("Expected at least 1 index on users table") + } +} + +func TestReader_ReadDatabase_Complex(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "complex.go"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + // Verify schema + if len(db.Schemas) != 1 { + t.Fatalf("Expected 1 schema, got %d", len(db.Schemas)) + } + + schema := db.Schemas[0] + if schema.Name != "public" { + t.Errorf("Expected schema name 'public', got '%s'", schema.Name) + } + + // Verify tables + if len(schema.Tables) != 3 { + t.Fatalf("Expected 3 tables, got %d", len(schema.Tables)) + } + + // Find tables + var usersTable, postsTable, commentsTable *models.Table + for _, table := range schema.Tables { + switch table.Name { + case "users": + usersTable = table + case "posts": + postsTable = table + case "comments": + commentsTable = table + } + } + + if usersTable == nil { + t.Fatal("Users table not found") + } + if postsTable == nil { + t.Fatal("Posts table not found") + } + if commentsTable == nil { + t.Fatal("Comments table not found") + } + + // Verify users table - test NOT NULL logic for various field types + if len(usersTable.Columns) != 10 { + t.Errorf("Expected 10 columns in users table, got %d", len(usersTable.Columns)) + } + + // username - NOT NULL (explicit notnull tag) + usernameCol, exists := usersTable.Columns["username"] + if !exists { + t.Fatal("Column 'username' not found") + } + if !usernameCol.NotNull { + t.Error("Column 'username' should be NOT NULL (explicit 'notnull' tag)") + } + + // first_name - nullable (pointer type) + firstNameCol, exists := usersTable.Columns["first_name"] + if !exists { + t.Fatal("Column 'first_name' not found") + } + if firstNameCol.NotNull { + t.Error("Column 'first_name' should be nullable (pointer type *string)") + } + + // last_name - nullable (pointer type) + lastNameCol, exists := usersTable.Columns["last_name"] + if !exists { + t.Fatal("Column 'last_name' not found") + } + if lastNameCol.NotNull { + t.Error("Column 'last_name' should be nullable (pointer type *string)") + } + + // bio - nullable (pointer type) + bioCol, exists := usersTable.Columns["bio"] + if !exists { + t.Fatal("Column 'bio' not found") + } + if bioCol.NotNull { + t.Error("Column 'bio' should be nullable (pointer type *string)") + } + + // is_active - NOT NULL (primitive bool without nullzero) + isActiveCol, exists := usersTable.Columns["is_active"] + if !exists { + t.Fatal("Column 'is_active' not found") + } + if !isActiveCol.NotNull { + t.Error("Column 'is_active' should be NOT NULL (primitive bool type without nullzero)") + } + + // Verify users table indexes + if len(usersTable.Indexes) < 1 { + t.Error("Expected at least 1 index on users table") + } + + // Verify posts table + if len(postsTable.Columns) != 11 { + t.Errorf("Expected 11 columns in posts table, got %d", len(postsTable.Columns)) + } + + // excerpt - nullable (pointer type) + excerptCol, exists := postsTable.Columns["excerpt"] + if !exists { + t.Fatal("Column 'excerpt' not found") + } + if excerptCol.NotNull { + t.Error("Column 'excerpt' should be nullable (pointer type *string)") + } + + // published - NOT NULL (primitive bool without nullzero) + publishedCol, exists := postsTable.Columns["published"] + if !exists { + t.Fatal("Column 'published' not found") + } + if !publishedCol.NotNull { + t.Error("Column 'published' should be NOT NULL (primitive bool type without nullzero)") + } + + // published_at - nullable (has nullzero tag) + publishedAtCol, exists := postsTable.Columns["published_at"] + if !exists { + t.Fatal("Column 'published_at' not found") + } + if publishedAtCol.NotNull { + t.Error("Column 'published_at' should be nullable (has nullzero tag)") + } + + // view_count - NOT NULL (primitive int64 without nullzero) + viewCountCol, exists := postsTable.Columns["view_count"] + if !exists { + t.Fatal("Column 'view_count' not found") + } + if !viewCountCol.NotNull { + t.Error("Column 'view_count' should be NOT NULL (primitive int64 type without nullzero)") + } + + // Verify posts table indexes + if len(postsTable.Indexes) < 1 { + t.Error("Expected at least 1 index on posts table") + } + + // Verify comments table + if len(commentsTable.Columns) != 6 { + t.Errorf("Expected 6 columns in comments table, got %d", len(commentsTable.Columns)) + } + + // user_id - nullable (pointer type) + userIDCol, exists := commentsTable.Columns["user_id"] + if !exists { + t.Fatal("Column 'user_id' not found in comments table") + } + if userIDCol.NotNull { + t.Error("Column 'user_id' should be nullable (pointer type *int64)") + } + + // post_id - NOT NULL (explicit notnull tag) + postIDCol, exists := commentsTable.Columns["post_id"] + if !exists { + t.Fatal("Column 'post_id' not found in comments table") + } + if !postIDCol.NotNull { + t.Error("Column 'post_id' should be NOT NULL (explicit 'notnull' tag)") + } + + // Verify foreign key constraints are created from relationship tags + // In Bun, relationships are defined with rel: tags + // The constraints should be created on the referenced tables + if len(postsTable.Constraints) > 0 { + // Check if FK constraint exists + var fkPostsUser *models.Constraint + for _, c := range postsTable.Constraints { + if c.Type == models.ForeignKeyConstraint && c.ReferencedTable == "users" { + fkPostsUser = c + break + } + } + + if fkPostsUser != nil { + if len(fkPostsUser.Columns) != 1 || fkPostsUser.Columns[0] != "user_id" { + t.Error("Expected FK column 'user_id'") + } + if len(fkPostsUser.ReferencedColumns) != 1 || fkPostsUser.ReferencedColumns[0] != "id" { + t.Error("Expected FK referenced column 'id'") + } + } + } + + if len(commentsTable.Constraints) > 0 { + // Check if FK constraints exist + var fkCommentsPost, fkCommentsUser *models.Constraint + for _, c := range commentsTable.Constraints { + if c.Type == models.ForeignKeyConstraint { + if c.ReferencedTable == "posts" { + fkCommentsPost = c + } else if c.ReferencedTable == "users" { + fkCommentsUser = c + } + } + } + + if fkCommentsPost != nil { + if len(fkCommentsPost.Columns) != 1 || fkCommentsPost.Columns[0] != "post_id" { + t.Error("Expected FK column 'post_id'") + } + } + + if fkCommentsUser != nil { + if len(fkCommentsUser.Columns) != 1 || fkCommentsUser.Columns[0] != "user_id" { + t.Error("Expected FK column 'user_id'") + } + } + } +} + +func TestReader_ReadSchema(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"), + } + + reader := NewReader(opts) + schema, err := reader.ReadSchema() + if err != nil { + t.Fatalf("ReadSchema() error = %v", err) + } + + if schema == nil { + t.Fatal("ReadSchema() returned nil schema") + } + + if schema.Name != "public" { + t.Errorf("Expected schema name 'public', got '%s'", schema.Name) + } + + if len(schema.Tables) != 1 { + t.Errorf("Expected 1 table, got %d", len(schema.Tables)) + } +} + +func TestReader_ReadTable(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"), + } + + reader := NewReader(opts) + table, err := reader.ReadTable() + if err != nil { + t.Fatalf("ReadTable() error = %v", err) + } + + if table == nil { + t.Fatal("ReadTable() returned nil table") + } + + if table.Name != "users" { + t.Errorf("Expected table name 'users', got '%s'", table.Name) + } + + if len(table.Columns) != 6 { + t.Errorf("Expected 6 columns, got %d", len(table.Columns)) + } +} + +func TestReader_ReadDatabase_Directory(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + // Should read both simple.go and complex.go + if len(db.Schemas) == 0 { + t.Fatal("Expected at least one schema") + } + + schema := db.Schemas[0] + // Should have at least 3 tables from complex.go (users, posts, comments) + // plus 1 from simple.go (users) - but same table name, so may be overwritten + if len(schema.Tables) < 3 { + t.Errorf("Expected at least 3 tables, got %d", len(schema.Tables)) + } +} + +func TestReader_ReadDatabase_InvalidPath(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: "/nonexistent/file.go", + } + + reader := NewReader(opts) + _, err := reader.ReadDatabase() + if err == nil { + t.Error("Expected error for invalid file path") + } +} + +func TestReader_ReadDatabase_EmptyPath(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: "", + } + + reader := NewReader(opts) + _, err := reader.ReadDatabase() + if err == nil { + t.Error("Expected error for empty file path") + } +} + +func TestReader_NullableTypes(t *testing.T) { + // This test specifically verifies the NOT NULL logic changes + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "complex.go"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + // Find posts table + var postsTable *models.Table + for _, schema := range db.Schemas { + for _, table := range schema.Tables { + if table.Name == "posts" { + postsTable = table + break + } + } + } + + if postsTable == nil { + t.Fatal("Posts table not found") + } + + // Test all nullability scenarios + tests := []struct { + column string + notNull bool + reason string + }{ + {"id", true, "primary key"}, + {"user_id", true, "explicit notnull tag"}, + {"title", true, "explicit notnull tag"}, + {"slug", true, "explicit notnull tag"}, + {"content", true, "explicit notnull tag"}, + {"excerpt", false, "pointer type *string"}, + {"published", true, "primitive bool without nullzero"}, + {"view_count", true, "primitive int64 without nullzero"}, + {"published_at", false, "has nullzero tag"}, + {"created_at", true, "time.Time without nullzero"}, + {"updated_at", true, "time.Time without nullzero"}, + } + + for _, tt := range tests { + col, exists := postsTable.Columns[tt.column] + if !exists { + t.Errorf("Column '%s' not found", tt.column) + continue + } + + if col.NotNull != tt.notNull { + if tt.notNull { + t.Errorf("Column '%s' should be NOT NULL (%s), but NotNull=%v", + tt.column, tt.reason, col.NotNull) + } else { + t.Errorf("Column '%s' should be nullable (%s), but NotNull=%v", + tt.column, tt.reason, col.NotNull) + } + } + } +} diff --git a/pkg/readers/gorm/reader.go b/pkg/readers/gorm/reader.go index 7f7abff..a9aba30 100644 --- a/pkg/readers/gorm/reader.go +++ b/pkg/readers/gorm/reader.go @@ -693,7 +693,7 @@ func (r *Reader) deriveTableName(structName string) string { // parseColumn parses a struct field into a Column model // Returns the column and any inline reference information (e.g., "mainaccount(id_mainaccount)") -func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) (*models.Column, string) { +func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) (col *models.Column, ref string) { // Extract gorm tag gormTag := r.extractGormTag(tag) if gormTag == "" { @@ -756,20 +756,14 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s // - explicit "not null" tag means NOT NULL // - absence of "not null" tag with sql_types means nullable // - primitive types (string, int64, bool) default to NOT NULL unless explicitly nullable + // Primary keys are always NOT NULL + column.NotNull = false if _, hasNotNull := parts["not null"]; hasNotNull { column.NotNull = true } else { - // If no explicit "not null" tag, check the Go type - if r.isNullableGoType(fieldType) { - // sql_types.SqlString, etc. are nullable by default - column.NotNull = false - } else { - // Primitive types default to NOT NULL - column.NotNull = false // Default to nullable unless explicitly set - } + // sql_types.SqlString, etc. are nullable by default + column.NotNull = !r.isNullableGoType(fieldType) } - - // Primary keys are always NOT NULL if column.IsPrimaryKey { column.NotNull = true } diff --git a/pkg/readers/gorm/reader_test.go b/pkg/readers/gorm/reader_test.go new file mode 100644 index 0000000..76f53d0 --- /dev/null +++ b/pkg/readers/gorm/reader_test.go @@ -0,0 +1,464 @@ +package gorm + +import ( + "path/filepath" + "testing" + + "git.warky.dev/wdevs/relspecgo/pkg/models" + "git.warky.dev/wdevs/relspecgo/pkg/readers" +) + +func TestReader_ReadDatabase_Simple(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + if len(db.Schemas) == 0 { + t.Fatal("Expected at least one schema") + } + + schema := db.Schemas[0] + if schema.Name != "public" { + t.Errorf("Expected schema name 'public', got '%s'", schema.Name) + } + + if len(schema.Tables) != 1 { + t.Fatalf("Expected 1 table, got %d", len(schema.Tables)) + } + + table := schema.Tables[0] + if table.Name != "users" { + t.Errorf("Expected table name 'users', got '%s'", table.Name) + } + + if len(table.Columns) != 6 { + t.Errorf("Expected 6 columns, got %d", len(table.Columns)) + } + + // Verify id column - primary key should be NOT NULL + idCol, exists := table.Columns["id"] + if !exists { + t.Fatal("Column 'id' not found") + } + if !idCol.IsPrimaryKey { + t.Error("Column 'id' should be primary key") + } + if !idCol.AutoIncrement { + t.Error("Column 'id' should be auto-increment") + } + if !idCol.NotNull { + t.Error("Column 'id' should be NOT NULL (primary keys are always NOT NULL)") + } + if idCol.Type != "bigint" { + t.Errorf("Expected id type 'bigint', got '%s'", idCol.Type) + } + + // Verify email column - explicit "not null" tag should be NOT NULL + emailCol, exists := table.Columns["email"] + if !exists { + t.Fatal("Column 'email' not found") + } + if !emailCol.NotNull { + t.Error("Column 'email' should be NOT NULL (explicit 'not null' tag)") + } + if emailCol.Type != "varchar" || emailCol.Length != 255 { + t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length) + } + + // Verify name column - primitive string type should be NOT NULL by default + nameCol, exists := table.Columns["name"] + if !exists { + t.Fatal("Column 'name' not found") + } + if !nameCol.NotNull { + t.Error("Column 'name' should be NOT NULL (primitive string type defaults to NOT NULL)") + } + if nameCol.Type != "text" { + t.Errorf("Expected name type 'text', got '%s'", nameCol.Type) + } + + // Verify age column - pointer type should be nullable (NOT NULL = false) + ageCol, exists := table.Columns["age"] + if !exists { + t.Fatal("Column 'age' not found") + } + if ageCol.NotNull { + t.Error("Column 'age' should be nullable (pointer type *int)") + } + if ageCol.Type != "integer" { + t.Errorf("Expected age type 'integer', got '%s'", ageCol.Type) + } + + // Verify is_active column - primitive bool type should be NOT NULL by default + isActiveCol, exists := table.Columns["is_active"] + if !exists { + t.Fatal("Column 'is_active' not found") + } + if !isActiveCol.NotNull { + t.Error("Column 'is_active' should be NOT NULL (primitive bool type defaults to NOT NULL)") + } + if isActiveCol.Type != "boolean" { + t.Errorf("Expected is_active type 'boolean', got '%s'", isActiveCol.Type) + } + + // Verify created_at column - time.Time should be NOT NULL by default + createdAtCol, exists := table.Columns["created_at"] + if !exists { + t.Fatal("Column 'created_at' not found") + } + if !createdAtCol.NotNull { + t.Error("Column 'created_at' should be NOT NULL (time.Time is NOT NULL by default)") + } + if createdAtCol.Type != "timestamp" { + t.Errorf("Expected created_at type 'timestamp', got '%s'", createdAtCol.Type) + } + if createdAtCol.Default != "now()" { + t.Errorf("Expected created_at default 'now()', got '%v'", createdAtCol.Default) + } +} + +func TestReader_ReadDatabase_Complex(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "complex.go"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + // Verify schema + if len(db.Schemas) != 1 { + t.Fatalf("Expected 1 schema, got %d", len(db.Schemas)) + } + + schema := db.Schemas[0] + if schema.Name != "public" { + t.Errorf("Expected schema name 'public', got '%s'", schema.Name) + } + + // Verify tables + if len(schema.Tables) != 3 { + t.Fatalf("Expected 3 tables, got %d", len(schema.Tables)) + } + + // Find tables + var usersTable, postsTable, commentsTable *models.Table + for _, table := range schema.Tables { + switch table.Name { + case "users": + usersTable = table + case "posts": + postsTable = table + case "comments": + commentsTable = table + } + } + + if usersTable == nil { + t.Fatal("Users table not found") + } + if postsTable == nil { + t.Fatal("Posts table not found") + } + if commentsTable == nil { + t.Fatal("Comments table not found") + } + + // Verify users table - test NOT NULL logic for various field types + if len(usersTable.Columns) != 10 { + t.Errorf("Expected 10 columns in users table, got %d", len(usersTable.Columns)) + } + + // username - NOT NULL (explicit tag) + usernameCol, exists := usersTable.Columns["username"] + if !exists { + t.Fatal("Column 'username' not found") + } + if !usernameCol.NotNull { + t.Error("Column 'username' should be NOT NULL (explicit 'not null' tag)") + } + + // first_name - nullable (pointer type) + firstNameCol, exists := usersTable.Columns["first_name"] + if !exists { + t.Fatal("Column 'first_name' not found") + } + if firstNameCol.NotNull { + t.Error("Column 'first_name' should be nullable (pointer type *string)") + } + + // last_name - nullable (pointer type) + lastNameCol, exists := usersTable.Columns["last_name"] + if !exists { + t.Fatal("Column 'last_name' not found") + } + if lastNameCol.NotNull { + t.Error("Column 'last_name' should be nullable (pointer type *string)") + } + + // bio - nullable (pointer type) + bioCol, exists := usersTable.Columns["bio"] + if !exists { + t.Fatal("Column 'bio' not found") + } + if bioCol.NotNull { + t.Error("Column 'bio' should be nullable (pointer type *string)") + } + + // is_active - NOT NULL (primitive bool) + isActiveCol, exists := usersTable.Columns["is_active"] + if !exists { + t.Fatal("Column 'is_active' not found") + } + if !isActiveCol.NotNull { + t.Error("Column 'is_active' should be NOT NULL (primitive bool type)") + } + + // Verify users table indexes + if len(usersTable.Indexes) < 1 { + t.Error("Expected at least 1 index on users table") + } + + // Verify posts table + if len(postsTable.Columns) != 11 { + t.Errorf("Expected 11 columns in posts table, got %d", len(postsTable.Columns)) + } + + // excerpt - nullable (pointer type) + excerptCol, exists := postsTable.Columns["excerpt"] + if !exists { + t.Fatal("Column 'excerpt' not found") + } + if excerptCol.NotNull { + t.Error("Column 'excerpt' should be nullable (pointer type *string)") + } + + // published - NOT NULL (primitive bool with default) + publishedCol, exists := postsTable.Columns["published"] + if !exists { + t.Fatal("Column 'published' not found") + } + if !publishedCol.NotNull { + t.Error("Column 'published' should be NOT NULL (primitive bool type)") + } + if publishedCol.Default != "false" { + t.Errorf("Expected published default 'false', got '%v'", publishedCol.Default) + } + + // published_at - nullable (pointer to time.Time) + publishedAtCol, exists := postsTable.Columns["published_at"] + if !exists { + t.Fatal("Column 'published_at' not found") + } + if publishedAtCol.NotNull { + t.Error("Column 'published_at' should be nullable (pointer type *time.Time)") + } + + // view_count - NOT NULL (primitive int64 with default) + viewCountCol, exists := postsTable.Columns["view_count"] + if !exists { + t.Fatal("Column 'view_count' not found") + } + if !viewCountCol.NotNull { + t.Error("Column 'view_count' should be NOT NULL (primitive int64 type)") + } + if viewCountCol.Default != "0" { + t.Errorf("Expected view_count default '0', got '%v'", viewCountCol.Default) + } + + // Verify posts table indexes + if len(postsTable.Indexes) < 1 { + t.Error("Expected at least 1 index on posts table") + } + + // Verify comments table + if len(commentsTable.Columns) != 6 { + t.Errorf("Expected 6 columns in comments table, got %d", len(commentsTable.Columns)) + } + + // user_id - nullable (pointer type) + userIDCol, exists := commentsTable.Columns["user_id"] + if !exists { + t.Fatal("Column 'user_id' not found in comments table") + } + if userIDCol.NotNull { + t.Error("Column 'user_id' should be nullable (pointer type *int64)") + } + + // post_id - NOT NULL (explicit tag) + postIDCol, exists := commentsTable.Columns["post_id"] + if !exists { + t.Fatal("Column 'post_id' not found in comments table") + } + if !postIDCol.NotNull { + t.Error("Column 'post_id' should be NOT NULL (explicit 'not null' tag)") + } + + // Verify foreign key constraints + if len(postsTable.Constraints) == 0 { + t.Error("Expected at least one constraint on posts table") + } + + // Find FK constraint to users + var fkPostsUser *models.Constraint + for _, c := range postsTable.Constraints { + if c.Type == models.ForeignKeyConstraint && c.ReferencedTable == "users" { + fkPostsUser = c + break + } + } + + if fkPostsUser != nil { + if fkPostsUser.OnDelete != "CASCADE" { + t.Errorf("Expected ON DELETE CASCADE for posts->users FK, got '%s'", fkPostsUser.OnDelete) + } + if fkPostsUser.OnUpdate != "CASCADE" { + t.Errorf("Expected ON UPDATE CASCADE for posts->users FK, got '%s'", fkPostsUser.OnUpdate) + } + } + + // Verify comments table constraints + if len(commentsTable.Constraints) == 0 { + t.Error("Expected at least one constraint on comments table") + } + + // Find FK constraints + var fkCommentsPost, fkCommentsUser *models.Constraint + for _, c := range commentsTable.Constraints { + if c.Type == models.ForeignKeyConstraint { + if c.ReferencedTable == "posts" { + fkCommentsPost = c + } else if c.ReferencedTable == "users" { + fkCommentsUser = c + } + } + } + + if fkCommentsPost != nil { + if fkCommentsPost.OnDelete != "CASCADE" { + t.Errorf("Expected ON DELETE CASCADE for comments->posts FK, got '%s'", fkCommentsPost.OnDelete) + } + } + + if fkCommentsUser != nil { + if fkCommentsUser.OnDelete != "SET NULL" { + t.Errorf("Expected ON DELETE SET NULL for comments->users FK, got '%s'", fkCommentsUser.OnDelete) + } + } +} + +func TestReader_ReadSchema(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"), + } + + reader := NewReader(opts) + schema, err := reader.ReadSchema() + if err != nil { + t.Fatalf("ReadSchema() error = %v", err) + } + + if schema == nil { + t.Fatal("ReadSchema() returned nil schema") + } + + if schema.Name != "public" { + t.Errorf("Expected schema name 'public', got '%s'", schema.Name) + } + + if len(schema.Tables) != 1 { + t.Errorf("Expected 1 table, got %d", len(schema.Tables)) + } +} + +func TestReader_ReadTable(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"), + } + + reader := NewReader(opts) + table, err := reader.ReadTable() + if err != nil { + t.Fatalf("ReadTable() error = %v", err) + } + + if table == nil { + t.Fatal("ReadTable() returned nil table") + } + + if table.Name != "users" { + t.Errorf("Expected table name 'users', got '%s'", table.Name) + } + + if len(table.Columns) != 6 { + t.Errorf("Expected 6 columns, got %d", len(table.Columns)) + } +} + +func TestReader_ReadDatabase_Directory(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm"), + } + + reader := NewReader(opts) + db, err := reader.ReadDatabase() + if err != nil { + t.Fatalf("ReadDatabase() error = %v", err) + } + + if db == nil { + t.Fatal("ReadDatabase() returned nil database") + } + + // Should read both simple.go and complex.go + if len(db.Schemas) == 0 { + t.Fatal("Expected at least one schema") + } + + schema := db.Schemas[0] + // Should have at least 3 tables from complex.go (users, posts, comments) + // plus 1 from simple.go (users) - but same table name, so may be overwritten + if len(schema.Tables) < 3 { + t.Errorf("Expected at least 3 tables, got %d", len(schema.Tables)) + } +} + +func TestReader_ReadDatabase_InvalidPath(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: "/nonexistent/file.go", + } + + reader := NewReader(opts) + _, err := reader.ReadDatabase() + if err == nil { + t.Error("Expected error for invalid file path") + } +} + +func TestReader_ReadDatabase_EmptyPath(t *testing.T) { + opts := &readers.ReaderOptions{ + FilePath: "", + } + + reader := NewReader(opts) + _, err := reader.ReadDatabase() + if err == nil { + t.Error("Expected error for empty file path") + } +} diff --git a/pkg/readers/prisma/reader.go b/pkg/readers/prisma/reader.go index 6cf1495..c961669 100644 --- a/pkg/readers/prisma/reader.go +++ b/pkg/readers/prisma/reader.go @@ -67,15 +67,6 @@ func (r *Reader) ReadTable() (*models.Table, error) { return schema.Tables[0], nil } -// stripQuotes removes surrounding quotes from an identifier -func stripQuotes(s string) string { - s = strings.TrimSpace(s) - if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) { - return s[1 : len(s)-1] - } - return s -} - // parsePrisma parses Prisma schema content and returns a Database model func (r *Reader) parsePrisma(content string) (*models.Database, error) { db := models.InitDatabase("database") @@ -239,8 +230,8 @@ func (r *Reader) parseModelFields(lines []string, table *models.Table) { if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil { fieldName := matches[1] fieldType := matches[2] - modifier := matches[3] // ? or [] - attributes := matches[4] // @... part + modifier := matches[3] // ? or [] + attributes := matches[4] // @... part column := r.parseField(fieldName, fieldType, modifier, attributes, table) if column != nil { @@ -370,9 +361,10 @@ func (r *Reader) extractDefaultValue(attributes string) string { i := start for i < len(attributes) && depth > 0 { - if attributes[i] == '(' { + switch attributes[i] { + case '(': depth++ - } else if attributes[i] == ')' { + case ')': depth-- } i++ @@ -479,11 +471,11 @@ func (r *Reader) parseBlockAttribute(attrName, content string, table *models.Tab // relationField stores information about a relation field for second-pass processing type relationField struct { - tableName string - fieldName string - relatedModel string - isArray bool - relationAttr string + tableName string + fieldName string + relatedModel string + isArray bool + relationAttr string } // resolveRelationships performs a second pass to resolve @relation attributes diff --git a/pkg/readers/typeorm/reader.go b/pkg/readers/typeorm/reader.go index 0b6d300..3e1f01c 100644 --- a/pkg/readers/typeorm/reader.go +++ b/pkg/readers/typeorm/reader.go @@ -578,9 +578,9 @@ func (r *Reader) parseColumnOptions(decorator string, column *models.Column, tab func (r *Reader) resolveRelationships(entities []entityInfo, tableMap map[string]*models.Table, schema *models.Schema) { // Track M2M relations that need join tables type m2mRelation struct { - ownerEntity string + ownerEntity string targetEntity string - ownerField string + ownerField string } m2mRelations := make([]m2mRelation, 0) diff --git a/pkg/writers/typeorm/writer.go b/pkg/writers/typeorm/writer.go index 759f796..30a47b5 100644 --- a/pkg/writers/typeorm/writer.go +++ b/pkg/writers/typeorm/writer.go @@ -520,9 +520,9 @@ func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Sc fieldName := w.pluralize(strings.ToLower(otherTable.Name)) inverseName := strings.ToLower(table.Name) - sb.WriteString(fmt.Sprintf(" @OneToMany(() => %s, %s => %s.%s)\n", - otherTable.Name, strings.ToLower(otherTable.Name), strings.ToLower(otherTable.Name), inverseName)) - sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable.Name)) + fmt.Fprintf(sb, " @OneToMany(() => %s, %s => %s.%s)\n", + otherTable.Name, strings.ToLower(otherTable.Name), strings.ToLower(otherTable.Name), inverseName) + fmt.Fprintf(sb, " %s: %s[];\n", fieldName, otherTable.Name) sb.WriteString("\n") } } @@ -570,15 +570,15 @@ func (w *Writer) generateManyToManyRelations(table *models.Table, schema *models inverseName := w.pluralize(strings.ToLower(table.Name)) if isOwner { - sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n", - otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName)) + fmt.Fprintf(sb, " @ManyToMany(() => %s, %s => %s.%s)\n", + otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName) sb.WriteString(" @JoinTable()\n") } else { - sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n", - otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName)) + fmt.Fprintf(sb, " @ManyToMany(() => %s, %s => %s.%s)\n", + otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName) } - sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable)) + fmt.Fprintf(sb, " %s: %s[];\n", fieldName, otherTable) sb.WriteString("\n") } } diff --git a/tests/assets/bun/complex.go b/tests/assets/bun/complex.go new file mode 100644 index 0000000..111661d --- /dev/null +++ b/tests/assets/bun/complex.go @@ -0,0 +1,60 @@ +package models + +import ( + "time" + + "github.com/uptrace/bun" +) + +// ModelUser represents a user in the system +type ModelUser struct { + bun.BaseModel `bun:"table:users,alias:u"` + + ID int64 `bun:"id,pk,autoincrement,type:bigint"` + Username string `bun:"username,notnull,type:varchar(100),unique:idx_username"` + Email string `bun:"email,notnull,type:varchar(255),unique"` + Password string `bun:"password,notnull,type:varchar(255)"` + FirstName *string `bun:"first_name,type:varchar(100)"` + LastName *string `bun:"last_name,type:varchar(100)"` + Bio *string `bun:"bio,type:text"` + IsActive bool `bun:"is_active,type:boolean"` + CreatedAt time.Time `bun:"created_at,type:timestamp"` + UpdatedAt time.Time `bun:"updated_at,type:timestamp"` + + Posts []*ModelPost `bun:"rel:has-many,join:id=user_id"` +} + +// ModelPost represents a blog post +type ModelPost struct { + bun.BaseModel `bun:"table:posts,alias:p"` + + ID int64 `bun:"id,pk,autoincrement,type:bigint"` + UserID int64 `bun:"user_id,notnull,type:bigint"` + Title string `bun:"title,notnull,type:varchar(255)"` + Slug string `bun:"slug,notnull,type:varchar(255),unique:idx_slug"` + Content string `bun:"content,notnull,type:text"` + Excerpt *string `bun:"excerpt,type:text"` + Published bool `bun:"published,type:boolean"` + ViewCount int64 `bun:"view_count,type:bigint"` + PublishedAt *time.Time `bun:"published_at,type:timestamp,nullzero"` + CreatedAt time.Time `bun:"created_at,type:timestamp"` + UpdatedAt time.Time `bun:"updated_at,type:timestamp"` + + User *ModelUser `bun:"rel:belongs-to,join:user_id=id"` + Comments []*ModelComment `bun:"rel:has-many,join:id=post_id"` +} + +// ModelComment represents a comment on a post +type ModelComment struct { + bun.BaseModel `bun:"table:comments,alias:c"` + + ID int64 `bun:"id,pk,autoincrement,type:bigint"` + PostID int64 `bun:"post_id,notnull,type:bigint"` + UserID *int64 `bun:"user_id,type:bigint"` + Content string `bun:"content,notnull,type:text"` + CreatedAt time.Time `bun:"created_at,type:timestamp"` + UpdatedAt time.Time `bun:"updated_at,type:timestamp"` + + Post *ModelPost `bun:"rel:belongs-to,join:post_id=id"` + User *ModelUser `bun:"rel:belongs-to,join:user_id=id"` +} diff --git a/tests/assets/bun/simple.go b/tests/assets/bun/simple.go new file mode 100644 index 0000000..1f4eb1f --- /dev/null +++ b/tests/assets/bun/simple.go @@ -0,0 +1,18 @@ +package models + +import ( + "time" + + "github.com/uptrace/bun" +) + +type User struct { + bun.BaseModel `bun:"table:users,alias:u"` + + ID int64 `bun:"id,pk,autoincrement,type:bigint"` + Email string `bun:"email,notnull,type:varchar(255),unique"` + Name string `bun:"name,type:text"` + Age *int `bun:"age,type:integer"` + IsActive bool `bun:"is_active,type:boolean"` + CreatedAt time.Time `bun:"created_at,type:timestamp,default:now()"` +} diff --git a/tests/assets/gorm/complex.go b/tests/assets/gorm/complex.go new file mode 100644 index 0000000..ed1cc98 --- /dev/null +++ b/tests/assets/gorm/complex.go @@ -0,0 +1,65 @@ +package models + +import ( + "time" +) + +// ModelUser represents a user in the system +type ModelUser struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"` + Username string `gorm:"column:username;type:varchar(100);not null;uniqueIndex:idx_username"` + Email string `gorm:"column:email;type:varchar(255);not null;uniqueIndex"` + Password string `gorm:"column:password;type:varchar(255);not null"` + FirstName *string `gorm:"column:first_name;type:varchar(100)"` + LastName *string `gorm:"column:last_name;type:varchar(100)"` + Bio *string `gorm:"column:bio;type:text"` + IsActive bool `gorm:"column:is_active;type:boolean;default:true"` + CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"` + UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;default:now()"` + + Posts []*ModelPost `gorm:"foreignKey:UserID;association_foreignkey:ID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE"` + Comments []*ModelComment `gorm:"foreignKey:UserID;association_foreignkey:ID;constraint:OnDelete:SET NULL"` +} + +func (ModelUser) TableName() string { + return "users" +} + +// ModelPost represents a blog post +type ModelPost struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"` + UserID int64 `gorm:"column:user_id;type:bigint;not null;index:idx_user_id"` + Title string `gorm:"column:title;type:varchar(255);not null"` + Slug string `gorm:"column:slug;type:varchar(255);not null;uniqueIndex:idx_slug"` + Content string `gorm:"column:content;type:text;not null"` + Excerpt *string `gorm:"column:excerpt;type:text"` + Published bool `gorm:"column:published;type:boolean;default:false"` + ViewCount int64 `gorm:"column:view_count;type:bigint;default:0"` + PublishedAt *time.Time `gorm:"column:published_at;type:timestamp"` + CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"` + UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;default:now()"` + + User *ModelUser `gorm:"foreignKey:UserID;references:ID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE"` + Comments []*ModelComment `gorm:"foreignKey:PostID;association_foreignkey:ID;constraint:OnDelete:CASCADE"` +} + +func (ModelPost) TableName() string { + return "posts" +} + +// ModelComment represents a comment on a post +type ModelComment struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"` + PostID int64 `gorm:"column:post_id;type:bigint;not null;index:idx_post_id"` + UserID *int64 `gorm:"column:user_id;type:bigint;index:idx_user_id"` + Content string `gorm:"column:content;type:text;not null"` + CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"` + UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;default:now()"` + + Post *ModelPost `gorm:"foreignKey:PostID;references:ID;constraint:OnDelete:CASCADE"` + User *ModelUser `gorm:"foreignKey:UserID;references:ID;constraint:OnDelete:SET NULL"` +} + +func (ModelComment) TableName() string { + return "comments" +} diff --git a/tests/assets/gorm/simple.go b/tests/assets/gorm/simple.go new file mode 100644 index 0000000..62c13b9 --- /dev/null +++ b/tests/assets/gorm/simple.go @@ -0,0 +1,18 @@ +package models + +import ( + "time" +) + +type User struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"` + Email string `gorm:"column:email;type:varchar(255);not null"` + Name string `gorm:"column:name;type:text"` + Age *int `gorm:"column:age;type:integer"` + IsActive bool `gorm:"column:is_active;type:boolean"` + CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"` +} + +func (User) TableName() string { + return "users" +}