fix: readers and linting issues
This commit is contained in:
9
Makefile
9
Makefile
@@ -40,6 +40,15 @@ lint: ## Run linter
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
lintfix: ## Run linter
|
||||
@echo "Running linter..."
|
||||
@if command -v golangci-lint > /dev/null; then \
|
||||
golangci-lint run --config=.golangci.json --fix; \
|
||||
else \
|
||||
echo "golangci-lint not installed. Install with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
clean: ## Clean build artifacts
|
||||
@echo "Cleaning..."
|
||||
$(GOCLEAN)
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/bun"
|
||||
@@ -29,7 +31,6 @@ import (
|
||||
wprisma "git.warky.dev/wdevs/relspecgo/pkg/writers/prisma"
|
||||
wtypeorm "git.warky.dev/wdevs/relspecgo/pkg/writers/typeorm"
|
||||
wyaml "git.warky.dev/wdevs/relspecgo/pkg/writers/yaml"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -140,9 +141,18 @@ func init() {
|
||||
convertCmd.Flags().StringVar(&convertPackageName, "package", "", "Package name (for code generation formats like gorm/bun)")
|
||||
convertCmd.Flags().StringVar(&convertSchemaFilter, "schema", "", "Filter to a specific schema by name (required for formats like dctx that only support single schemas)")
|
||||
|
||||
convertCmd.MarkFlagRequired("from")
|
||||
convertCmd.MarkFlagRequired("to")
|
||||
convertCmd.MarkFlagRequired("to-path")
|
||||
err := convertCmd.MarkFlagRequired("from")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking from flag as required: %v\n", err)
|
||||
}
|
||||
err = convertCmd.MarkFlagRequired("to")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking to flag as required: %v\n", err)
|
||||
}
|
||||
err = convertCmd.MarkFlagRequired("to-path")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking to-path flag as required: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func runConvert(cmd *cobra.Command, args []string) error {
|
||||
@@ -344,7 +354,7 @@ func writeDatabase(db *models.Database, dbType, outputPath, packageName, schemaF
|
||||
}
|
||||
|
||||
// For formats like DCTX that don't support full database writes, require schema filter
|
||||
if strings.ToLower(dbType) == "dctx" {
|
||||
if strings.EqualFold(dbType, "dctx") {
|
||||
if len(db.Schemas) == 0 {
|
||||
return fmt.Errorf("no schemas found in database")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/diff"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
@@ -15,7 +17,6 @@ import (
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/json"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/pgsql"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers/yaml"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -96,8 +97,14 @@ func init() {
|
||||
diffCmd.Flags().StringVar(&outputFormat, "format", "summary", "Output format (summary, json, html)")
|
||||
diffCmd.Flags().StringVar(&outputPath, "output", "", "Output file path (default: stdout for summary, required for json/html)")
|
||||
|
||||
diffCmd.MarkFlagRequired("from")
|
||||
diffCmd.MarkFlagRequired("to")
|
||||
err := diffCmd.MarkFlagRequired("from")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking from flag as required: %v\n", err)
|
||||
}
|
||||
err = diffCmd.MarkFlagRequired("to")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking to flag as required: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func runDiff(cmd *cobra.Command, args []string) error {
|
||||
|
||||
@@ -2,14 +2,15 @@ package diff
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// CompareDatabases compares two database models and returns the differences
|
||||
func CompareDatabases(source, target *models.Database) *DiffResult {
|
||||
result := &DiffResult{
|
||||
Source: source.Name,
|
||||
Target: target.Name,
|
||||
Source: source.Name,
|
||||
Target: target.Name,
|
||||
Schemas: compareSchemas(source.Schemas, target.Schemas),
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -4,8 +4,8 @@ import "git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
|
||||
// DiffResult represents the complete difference analysis between two databases
|
||||
type DiffResult struct {
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Schemas *SchemaDiff `json:"schemas"`
|
||||
}
|
||||
|
||||
@@ -18,17 +18,17 @@ type SchemaDiff struct {
|
||||
|
||||
// SchemaChange represents changes within a schema
|
||||
type SchemaChange struct {
|
||||
Name string `json:"name"`
|
||||
Tables *TableDiff `json:"tables,omitempty"`
|
||||
Views *ViewDiff `json:"views,omitempty"`
|
||||
Sequences *SequenceDiff `json:"sequences,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Tables *TableDiff `json:"tables,omitempty"`
|
||||
Views *ViewDiff `json:"views,omitempty"`
|
||||
Sequences *SequenceDiff `json:"sequences,omitempty"`
|
||||
}
|
||||
|
||||
// TableDiff represents differences in tables
|
||||
type TableDiff struct {
|
||||
Missing []*models.Table `json:"missing"` // Tables in source but not in target
|
||||
Extra []*models.Table `json:"extra"` // Tables in target but not in source
|
||||
Modified []*TableChange `json:"modified"` // Tables that exist in both but differ
|
||||
Missing []*models.Table `json:"missing"` // Tables in source but not in target
|
||||
Extra []*models.Table `json:"extra"` // Tables in target but not in source
|
||||
Modified []*TableChange `json:"modified"` // Tables that exist in both but differ
|
||||
}
|
||||
|
||||
// TableChange represents changes within a table
|
||||
@@ -50,16 +50,16 @@ type ColumnDiff struct {
|
||||
|
||||
// ColumnChange represents a modified column
|
||||
type ColumnChange struct {
|
||||
Name string `json:"name"`
|
||||
Source *models.Column `json:"source"`
|
||||
Target *models.Column `json:"target"`
|
||||
Changes map[string]any `json:"changes"` // Map of field name to what changed
|
||||
Name string `json:"name"`
|
||||
Source *models.Column `json:"source"`
|
||||
Target *models.Column `json:"target"`
|
||||
Changes map[string]any `json:"changes"` // Map of field name to what changed
|
||||
}
|
||||
|
||||
// IndexDiff represents differences in indexes
|
||||
type IndexDiff struct {
|
||||
Missing []*models.Index `json:"missing"` // Indexes in source but not in target
|
||||
Extra []*models.Index `json:"extra"` // Indexes in target but not in source
|
||||
Missing []*models.Index `json:"missing"` // Indexes in source but not in target
|
||||
Extra []*models.Index `json:"extra"` // Indexes in target but not in source
|
||||
Modified []*IndexChange `json:"modified"` // Indexes that exist in both but differ
|
||||
}
|
||||
|
||||
@@ -103,8 +103,8 @@ type RelationshipChange struct {
|
||||
|
||||
// ViewDiff represents differences in views
|
||||
type ViewDiff struct {
|
||||
Missing []*models.View `json:"missing"` // Views in source but not in target
|
||||
Extra []*models.View `json:"extra"` // Views in target but not in source
|
||||
Missing []*models.View `json:"missing"` // Views in source but not in target
|
||||
Extra []*models.View `json:"extra"` // Views in target but not in source
|
||||
Modified []*ViewChange `json:"modified"` // Views that exist in both but differ
|
||||
}
|
||||
|
||||
@@ -133,14 +133,14 @@ type SequenceChange struct {
|
||||
|
||||
// Summary provides counts for quick overview
|
||||
type Summary struct {
|
||||
Schemas SchemaSummary `json:"schemas"`
|
||||
Tables TableSummary `json:"tables"`
|
||||
Columns ColumnSummary `json:"columns"`
|
||||
Indexes IndexSummary `json:"indexes"`
|
||||
Constraints ConstraintSummary `json:"constraints"`
|
||||
Schemas SchemaSummary `json:"schemas"`
|
||||
Tables TableSummary `json:"tables"`
|
||||
Columns ColumnSummary `json:"columns"`
|
||||
Indexes IndexSummary `json:"indexes"`
|
||||
Constraints ConstraintSummary `json:"constraints"`
|
||||
Relationships RelationshipSummary `json:"relationships"`
|
||||
Views ViewSummary `json:"views"`
|
||||
Sequences SequenceSummary `json:"sequences"`
|
||||
Views ViewSummary `json:"views"`
|
||||
Sequences SequenceSummary `json:"sequences"`
|
||||
}
|
||||
|
||||
type SchemaSummary struct {
|
||||
|
||||
@@ -626,17 +626,14 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
// - nullzero tag means the field is nullable (can be NULL in DB)
|
||||
// - absence of nullzero means the field is NOT NULL
|
||||
// - primitive types (int64, bool, string) are NOT NULL by default
|
||||
column.NotNull = true
|
||||
// Primary keys are always NOT NULL
|
||||
|
||||
if strings.Contains(bunTag, "nullzero") {
|
||||
column.NotNull = false
|
||||
} else if r.isNullableGoType(fieldType) {
|
||||
// SqlString, SqlInt, etc. without nullzero tag means NOT NULL
|
||||
column.NotNull = true
|
||||
} else {
|
||||
// Primitive types are NOT NULL by default
|
||||
column.NotNull = true
|
||||
column.NotNull = !r.isNullableGoType(fieldType)
|
||||
}
|
||||
|
||||
// Primary keys are always NOT NULL
|
||||
if column.IsPrimaryKey {
|
||||
column.NotNull = true
|
||||
}
|
||||
|
||||
522
pkg/readers/bun/reader_test.go
Normal file
522
pkg/readers/bun/reader_test.go
Normal file
@@ -0,0 +1,522 @@
|
||||
package bun
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
func TestReader_ReadDatabase_Simple(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
if len(db.Schemas) == 0 {
|
||||
t.Fatal("Expected at least one schema")
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
if schema.Name != "public" {
|
||||
t.Errorf("Expected schema name 'public', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if len(schema.Tables) != 1 {
|
||||
t.Fatalf("Expected 1 table, got %d", len(schema.Tables))
|
||||
}
|
||||
|
||||
table := schema.Tables[0]
|
||||
if table.Name != "users" {
|
||||
t.Errorf("Expected table name 'users', got '%s'", table.Name)
|
||||
}
|
||||
|
||||
if len(table.Columns) != 6 {
|
||||
t.Errorf("Expected 6 columns, got %d", len(table.Columns))
|
||||
}
|
||||
|
||||
// Verify id column - primary key should be NOT NULL
|
||||
idCol, exists := table.Columns["id"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'id' not found")
|
||||
}
|
||||
if !idCol.IsPrimaryKey {
|
||||
t.Error("Column 'id' should be primary key")
|
||||
}
|
||||
if !idCol.AutoIncrement {
|
||||
t.Error("Column 'id' should be auto-increment")
|
||||
}
|
||||
if !idCol.NotNull {
|
||||
t.Error("Column 'id' should be NOT NULL (primary keys are always NOT NULL)")
|
||||
}
|
||||
if idCol.Type != "bigint" {
|
||||
t.Errorf("Expected id type 'bigint', got '%s'", idCol.Type)
|
||||
}
|
||||
|
||||
// Verify email column - explicit notnull tag should be NOT NULL
|
||||
emailCol, exists := table.Columns["email"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'email' not found")
|
||||
}
|
||||
if !emailCol.NotNull {
|
||||
t.Error("Column 'email' should be NOT NULL (explicit 'notnull' tag)")
|
||||
}
|
||||
if emailCol.Type != "varchar" || emailCol.Length != 255 {
|
||||
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
|
||||
}
|
||||
|
||||
// Verify name column - primitive string type should be NOT NULL by default in Bun
|
||||
nameCol, exists := table.Columns["name"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'name' not found")
|
||||
}
|
||||
if !nameCol.NotNull {
|
||||
t.Error("Column 'name' should be NOT NULL (primitive string type, no nullzero tag)")
|
||||
}
|
||||
if nameCol.Type != "text" {
|
||||
t.Errorf("Expected name type 'text', got '%s'", nameCol.Type)
|
||||
}
|
||||
|
||||
// Verify age column - pointer type should be nullable (NOT NULL = false)
|
||||
ageCol, exists := table.Columns["age"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'age' not found")
|
||||
}
|
||||
if ageCol.NotNull {
|
||||
t.Error("Column 'age' should be nullable (pointer type *int)")
|
||||
}
|
||||
if ageCol.Type != "integer" {
|
||||
t.Errorf("Expected age type 'integer', got '%s'", ageCol.Type)
|
||||
}
|
||||
|
||||
// Verify is_active column - primitive bool type should be NOT NULL by default
|
||||
isActiveCol, exists := table.Columns["is_active"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'is_active' not found")
|
||||
}
|
||||
if !isActiveCol.NotNull {
|
||||
t.Error("Column 'is_active' should be NOT NULL (primitive bool type, no nullzero tag)")
|
||||
}
|
||||
if isActiveCol.Type != "boolean" {
|
||||
t.Errorf("Expected is_active type 'boolean', got '%s'", isActiveCol.Type)
|
||||
}
|
||||
|
||||
// Verify created_at column - time.Time should be NOT NULL by default
|
||||
createdAtCol, exists := table.Columns["created_at"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'created_at' not found")
|
||||
}
|
||||
if !createdAtCol.NotNull {
|
||||
t.Error("Column 'created_at' should be NOT NULL (time.Time is NOT NULL by default)")
|
||||
}
|
||||
if createdAtCol.Type != "timestamp" {
|
||||
t.Errorf("Expected created_at type 'timestamp', got '%s'", createdAtCol.Type)
|
||||
}
|
||||
|
||||
// Verify unique index on email
|
||||
if len(table.Indexes) < 1 {
|
||||
t.Error("Expected at least 1 index on users table")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_Complex(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "complex.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
// Verify schema
|
||||
if len(db.Schemas) != 1 {
|
||||
t.Fatalf("Expected 1 schema, got %d", len(db.Schemas))
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
if schema.Name != "public" {
|
||||
t.Errorf("Expected schema name 'public', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
// Verify tables
|
||||
if len(schema.Tables) != 3 {
|
||||
t.Fatalf("Expected 3 tables, got %d", len(schema.Tables))
|
||||
}
|
||||
|
||||
// Find tables
|
||||
var usersTable, postsTable, commentsTable *models.Table
|
||||
for _, table := range schema.Tables {
|
||||
switch table.Name {
|
||||
case "users":
|
||||
usersTable = table
|
||||
case "posts":
|
||||
postsTable = table
|
||||
case "comments":
|
||||
commentsTable = table
|
||||
}
|
||||
}
|
||||
|
||||
if usersTable == nil {
|
||||
t.Fatal("Users table not found")
|
||||
}
|
||||
if postsTable == nil {
|
||||
t.Fatal("Posts table not found")
|
||||
}
|
||||
if commentsTable == nil {
|
||||
t.Fatal("Comments table not found")
|
||||
}
|
||||
|
||||
// Verify users table - test NOT NULL logic for various field types
|
||||
if len(usersTable.Columns) != 10 {
|
||||
t.Errorf("Expected 10 columns in users table, got %d", len(usersTable.Columns))
|
||||
}
|
||||
|
||||
// username - NOT NULL (explicit notnull tag)
|
||||
usernameCol, exists := usersTable.Columns["username"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'username' not found")
|
||||
}
|
||||
if !usernameCol.NotNull {
|
||||
t.Error("Column 'username' should be NOT NULL (explicit 'notnull' tag)")
|
||||
}
|
||||
|
||||
// first_name - nullable (pointer type)
|
||||
firstNameCol, exists := usersTable.Columns["first_name"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'first_name' not found")
|
||||
}
|
||||
if firstNameCol.NotNull {
|
||||
t.Error("Column 'first_name' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// last_name - nullable (pointer type)
|
||||
lastNameCol, exists := usersTable.Columns["last_name"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'last_name' not found")
|
||||
}
|
||||
if lastNameCol.NotNull {
|
||||
t.Error("Column 'last_name' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// bio - nullable (pointer type)
|
||||
bioCol, exists := usersTable.Columns["bio"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'bio' not found")
|
||||
}
|
||||
if bioCol.NotNull {
|
||||
t.Error("Column 'bio' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// is_active - NOT NULL (primitive bool without nullzero)
|
||||
isActiveCol, exists := usersTable.Columns["is_active"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'is_active' not found")
|
||||
}
|
||||
if !isActiveCol.NotNull {
|
||||
t.Error("Column 'is_active' should be NOT NULL (primitive bool type without nullzero)")
|
||||
}
|
||||
|
||||
// Verify users table indexes
|
||||
if len(usersTable.Indexes) < 1 {
|
||||
t.Error("Expected at least 1 index on users table")
|
||||
}
|
||||
|
||||
// Verify posts table
|
||||
if len(postsTable.Columns) != 11 {
|
||||
t.Errorf("Expected 11 columns in posts table, got %d", len(postsTable.Columns))
|
||||
}
|
||||
|
||||
// excerpt - nullable (pointer type)
|
||||
excerptCol, exists := postsTable.Columns["excerpt"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'excerpt' not found")
|
||||
}
|
||||
if excerptCol.NotNull {
|
||||
t.Error("Column 'excerpt' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// published - NOT NULL (primitive bool without nullzero)
|
||||
publishedCol, exists := postsTable.Columns["published"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'published' not found")
|
||||
}
|
||||
if !publishedCol.NotNull {
|
||||
t.Error("Column 'published' should be NOT NULL (primitive bool type without nullzero)")
|
||||
}
|
||||
|
||||
// published_at - nullable (has nullzero tag)
|
||||
publishedAtCol, exists := postsTable.Columns["published_at"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'published_at' not found")
|
||||
}
|
||||
if publishedAtCol.NotNull {
|
||||
t.Error("Column 'published_at' should be nullable (has nullzero tag)")
|
||||
}
|
||||
|
||||
// view_count - NOT NULL (primitive int64 without nullzero)
|
||||
viewCountCol, exists := postsTable.Columns["view_count"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'view_count' not found")
|
||||
}
|
||||
if !viewCountCol.NotNull {
|
||||
t.Error("Column 'view_count' should be NOT NULL (primitive int64 type without nullzero)")
|
||||
}
|
||||
|
||||
// Verify posts table indexes
|
||||
if len(postsTable.Indexes) < 1 {
|
||||
t.Error("Expected at least 1 index on posts table")
|
||||
}
|
||||
|
||||
// Verify comments table
|
||||
if len(commentsTable.Columns) != 6 {
|
||||
t.Errorf("Expected 6 columns in comments table, got %d", len(commentsTable.Columns))
|
||||
}
|
||||
|
||||
// user_id - nullable (pointer type)
|
||||
userIDCol, exists := commentsTable.Columns["user_id"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'user_id' not found in comments table")
|
||||
}
|
||||
if userIDCol.NotNull {
|
||||
t.Error("Column 'user_id' should be nullable (pointer type *int64)")
|
||||
}
|
||||
|
||||
// post_id - NOT NULL (explicit notnull tag)
|
||||
postIDCol, exists := commentsTable.Columns["post_id"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'post_id' not found in comments table")
|
||||
}
|
||||
if !postIDCol.NotNull {
|
||||
t.Error("Column 'post_id' should be NOT NULL (explicit 'notnull' tag)")
|
||||
}
|
||||
|
||||
// Verify foreign key constraints are created from relationship tags
|
||||
// In Bun, relationships are defined with rel: tags
|
||||
// The constraints should be created on the referenced tables
|
||||
if len(postsTable.Constraints) > 0 {
|
||||
// Check if FK constraint exists
|
||||
var fkPostsUser *models.Constraint
|
||||
for _, c := range postsTable.Constraints {
|
||||
if c.Type == models.ForeignKeyConstraint && c.ReferencedTable == "users" {
|
||||
fkPostsUser = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if fkPostsUser != nil {
|
||||
if len(fkPostsUser.Columns) != 1 || fkPostsUser.Columns[0] != "user_id" {
|
||||
t.Error("Expected FK column 'user_id'")
|
||||
}
|
||||
if len(fkPostsUser.ReferencedColumns) != 1 || fkPostsUser.ReferencedColumns[0] != "id" {
|
||||
t.Error("Expected FK referenced column 'id'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(commentsTable.Constraints) > 0 {
|
||||
// Check if FK constraints exist
|
||||
var fkCommentsPost, fkCommentsUser *models.Constraint
|
||||
for _, c := range commentsTable.Constraints {
|
||||
if c.Type == models.ForeignKeyConstraint {
|
||||
if c.ReferencedTable == "posts" {
|
||||
fkCommentsPost = c
|
||||
} else if c.ReferencedTable == "users" {
|
||||
fkCommentsUser = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fkCommentsPost != nil {
|
||||
if len(fkCommentsPost.Columns) != 1 || fkCommentsPost.Columns[0] != "post_id" {
|
||||
t.Error("Expected FK column 'post_id'")
|
||||
}
|
||||
}
|
||||
|
||||
if fkCommentsUser != nil {
|
||||
if len(fkCommentsUser.Columns) != 1 || fkCommentsUser.Columns[0] != "user_id" {
|
||||
t.Error("Expected FK column 'user_id'")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadSchema(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
schema, err := reader.ReadSchema()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadSchema() error = %v", err)
|
||||
}
|
||||
|
||||
if schema == nil {
|
||||
t.Fatal("ReadSchema() returned nil schema")
|
||||
}
|
||||
|
||||
if schema.Name != "public" {
|
||||
t.Errorf("Expected schema name 'public', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if len(schema.Tables) != 1 {
|
||||
t.Errorf("Expected 1 table, got %d", len(schema.Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadTable(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "simple.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
table, err := reader.ReadTable()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadTable() error = %v", err)
|
||||
}
|
||||
|
||||
if table == nil {
|
||||
t.Fatal("ReadTable() returned nil table")
|
||||
}
|
||||
|
||||
if table.Name != "users" {
|
||||
t.Errorf("Expected table name 'users', got '%s'", table.Name)
|
||||
}
|
||||
|
||||
if len(table.Columns) != 6 {
|
||||
t.Errorf("Expected 6 columns, got %d", len(table.Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_Directory(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
// Should read both simple.go and complex.go
|
||||
if len(db.Schemas) == 0 {
|
||||
t.Fatal("Expected at least one schema")
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
// Should have at least 3 tables from complex.go (users, posts, comments)
|
||||
// plus 1 from simple.go (users) - but same table name, so may be overwritten
|
||||
if len(schema.Tables) < 3 {
|
||||
t.Errorf("Expected at least 3 tables, got %d", len(schema.Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_InvalidPath(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: "/nonexistent/file.go",
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
_, err := reader.ReadDatabase()
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid file path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_EmptyPath(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: "",
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
_, err := reader.ReadDatabase()
|
||||
if err == nil {
|
||||
t.Error("Expected error for empty file path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_NullableTypes(t *testing.T) {
|
||||
// This test specifically verifies the NOT NULL logic changes
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "bun", "complex.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
// Find posts table
|
||||
var postsTable *models.Table
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == "posts" {
|
||||
postsTable = table
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if postsTable == nil {
|
||||
t.Fatal("Posts table not found")
|
||||
}
|
||||
|
||||
// Test all nullability scenarios
|
||||
tests := []struct {
|
||||
column string
|
||||
notNull bool
|
||||
reason string
|
||||
}{
|
||||
{"id", true, "primary key"},
|
||||
{"user_id", true, "explicit notnull tag"},
|
||||
{"title", true, "explicit notnull tag"},
|
||||
{"slug", true, "explicit notnull tag"},
|
||||
{"content", true, "explicit notnull tag"},
|
||||
{"excerpt", false, "pointer type *string"},
|
||||
{"published", true, "primitive bool without nullzero"},
|
||||
{"view_count", true, "primitive int64 without nullzero"},
|
||||
{"published_at", false, "has nullzero tag"},
|
||||
{"created_at", true, "time.Time without nullzero"},
|
||||
{"updated_at", true, "time.Time without nullzero"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
col, exists := postsTable.Columns[tt.column]
|
||||
if !exists {
|
||||
t.Errorf("Column '%s' not found", tt.column)
|
||||
continue
|
||||
}
|
||||
|
||||
if col.NotNull != tt.notNull {
|
||||
if tt.notNull {
|
||||
t.Errorf("Column '%s' should be NOT NULL (%s), but NotNull=%v",
|
||||
tt.column, tt.reason, col.NotNull)
|
||||
} else {
|
||||
t.Errorf("Column '%s' should be nullable (%s), but NotNull=%v",
|
||||
tt.column, tt.reason, col.NotNull)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -693,7 +693,7 @@ func (r *Reader) deriveTableName(structName string) string {
|
||||
|
||||
// parseColumn parses a struct field into a Column model
|
||||
// Returns the column and any inline reference information (e.g., "mainaccount(id_mainaccount)")
|
||||
func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) (*models.Column, string) {
|
||||
func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, sequence uint) (col *models.Column, ref string) {
|
||||
// Extract gorm tag
|
||||
gormTag := r.extractGormTag(tag)
|
||||
if gormTag == "" {
|
||||
@@ -756,20 +756,14 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
||||
// - explicit "not null" tag means NOT NULL
|
||||
// - absence of "not null" tag with sql_types means nullable
|
||||
// - primitive types (string, int64, bool) default to NOT NULL unless explicitly nullable
|
||||
// Primary keys are always NOT NULL
|
||||
column.NotNull = false
|
||||
if _, hasNotNull := parts["not null"]; hasNotNull {
|
||||
column.NotNull = true
|
||||
} else {
|
||||
// If no explicit "not null" tag, check the Go type
|
||||
if r.isNullableGoType(fieldType) {
|
||||
// sql_types.SqlString, etc. are nullable by default
|
||||
column.NotNull = false
|
||||
} else {
|
||||
// Primitive types default to NOT NULL
|
||||
column.NotNull = false // Default to nullable unless explicitly set
|
||||
}
|
||||
// sql_types.SqlString, etc. are nullable by default
|
||||
column.NotNull = !r.isNullableGoType(fieldType)
|
||||
}
|
||||
|
||||
// Primary keys are always NOT NULL
|
||||
if column.IsPrimaryKey {
|
||||
column.NotNull = true
|
||||
}
|
||||
|
||||
464
pkg/readers/gorm/reader_test.go
Normal file
464
pkg/readers/gorm/reader_test.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/readers"
|
||||
)
|
||||
|
||||
func TestReader_ReadDatabase_Simple(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
if len(db.Schemas) == 0 {
|
||||
t.Fatal("Expected at least one schema")
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
if schema.Name != "public" {
|
||||
t.Errorf("Expected schema name 'public', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if len(schema.Tables) != 1 {
|
||||
t.Fatalf("Expected 1 table, got %d", len(schema.Tables))
|
||||
}
|
||||
|
||||
table := schema.Tables[0]
|
||||
if table.Name != "users" {
|
||||
t.Errorf("Expected table name 'users', got '%s'", table.Name)
|
||||
}
|
||||
|
||||
if len(table.Columns) != 6 {
|
||||
t.Errorf("Expected 6 columns, got %d", len(table.Columns))
|
||||
}
|
||||
|
||||
// Verify id column - primary key should be NOT NULL
|
||||
idCol, exists := table.Columns["id"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'id' not found")
|
||||
}
|
||||
if !idCol.IsPrimaryKey {
|
||||
t.Error("Column 'id' should be primary key")
|
||||
}
|
||||
if !idCol.AutoIncrement {
|
||||
t.Error("Column 'id' should be auto-increment")
|
||||
}
|
||||
if !idCol.NotNull {
|
||||
t.Error("Column 'id' should be NOT NULL (primary keys are always NOT NULL)")
|
||||
}
|
||||
if idCol.Type != "bigint" {
|
||||
t.Errorf("Expected id type 'bigint', got '%s'", idCol.Type)
|
||||
}
|
||||
|
||||
// Verify email column - explicit "not null" tag should be NOT NULL
|
||||
emailCol, exists := table.Columns["email"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'email' not found")
|
||||
}
|
||||
if !emailCol.NotNull {
|
||||
t.Error("Column 'email' should be NOT NULL (explicit 'not null' tag)")
|
||||
}
|
||||
if emailCol.Type != "varchar" || emailCol.Length != 255 {
|
||||
t.Errorf("Expected email type 'varchar(255)', got '%s' with length %d", emailCol.Type, emailCol.Length)
|
||||
}
|
||||
|
||||
// Verify name column - primitive string type should be NOT NULL by default
|
||||
nameCol, exists := table.Columns["name"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'name' not found")
|
||||
}
|
||||
if !nameCol.NotNull {
|
||||
t.Error("Column 'name' should be NOT NULL (primitive string type defaults to NOT NULL)")
|
||||
}
|
||||
if nameCol.Type != "text" {
|
||||
t.Errorf("Expected name type 'text', got '%s'", nameCol.Type)
|
||||
}
|
||||
|
||||
// Verify age column - pointer type should be nullable (NOT NULL = false)
|
||||
ageCol, exists := table.Columns["age"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'age' not found")
|
||||
}
|
||||
if ageCol.NotNull {
|
||||
t.Error("Column 'age' should be nullable (pointer type *int)")
|
||||
}
|
||||
if ageCol.Type != "integer" {
|
||||
t.Errorf("Expected age type 'integer', got '%s'", ageCol.Type)
|
||||
}
|
||||
|
||||
// Verify is_active column - primitive bool type should be NOT NULL by default
|
||||
isActiveCol, exists := table.Columns["is_active"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'is_active' not found")
|
||||
}
|
||||
if !isActiveCol.NotNull {
|
||||
t.Error("Column 'is_active' should be NOT NULL (primitive bool type defaults to NOT NULL)")
|
||||
}
|
||||
if isActiveCol.Type != "boolean" {
|
||||
t.Errorf("Expected is_active type 'boolean', got '%s'", isActiveCol.Type)
|
||||
}
|
||||
|
||||
// Verify created_at column - time.Time should be NOT NULL by default
|
||||
createdAtCol, exists := table.Columns["created_at"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'created_at' not found")
|
||||
}
|
||||
if !createdAtCol.NotNull {
|
||||
t.Error("Column 'created_at' should be NOT NULL (time.Time is NOT NULL by default)")
|
||||
}
|
||||
if createdAtCol.Type != "timestamp" {
|
||||
t.Errorf("Expected created_at type 'timestamp', got '%s'", createdAtCol.Type)
|
||||
}
|
||||
if createdAtCol.Default != "now()" {
|
||||
t.Errorf("Expected created_at default 'now()', got '%v'", createdAtCol.Default)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_Complex(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "complex.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
// Verify schema
|
||||
if len(db.Schemas) != 1 {
|
||||
t.Fatalf("Expected 1 schema, got %d", len(db.Schemas))
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
if schema.Name != "public" {
|
||||
t.Errorf("Expected schema name 'public', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
// Verify tables
|
||||
if len(schema.Tables) != 3 {
|
||||
t.Fatalf("Expected 3 tables, got %d", len(schema.Tables))
|
||||
}
|
||||
|
||||
// Find tables
|
||||
var usersTable, postsTable, commentsTable *models.Table
|
||||
for _, table := range schema.Tables {
|
||||
switch table.Name {
|
||||
case "users":
|
||||
usersTable = table
|
||||
case "posts":
|
||||
postsTable = table
|
||||
case "comments":
|
||||
commentsTable = table
|
||||
}
|
||||
}
|
||||
|
||||
if usersTable == nil {
|
||||
t.Fatal("Users table not found")
|
||||
}
|
||||
if postsTable == nil {
|
||||
t.Fatal("Posts table not found")
|
||||
}
|
||||
if commentsTable == nil {
|
||||
t.Fatal("Comments table not found")
|
||||
}
|
||||
|
||||
// Verify users table - test NOT NULL logic for various field types
|
||||
if len(usersTable.Columns) != 10 {
|
||||
t.Errorf("Expected 10 columns in users table, got %d", len(usersTable.Columns))
|
||||
}
|
||||
|
||||
// username - NOT NULL (explicit tag)
|
||||
usernameCol, exists := usersTable.Columns["username"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'username' not found")
|
||||
}
|
||||
if !usernameCol.NotNull {
|
||||
t.Error("Column 'username' should be NOT NULL (explicit 'not null' tag)")
|
||||
}
|
||||
|
||||
// first_name - nullable (pointer type)
|
||||
firstNameCol, exists := usersTable.Columns["first_name"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'first_name' not found")
|
||||
}
|
||||
if firstNameCol.NotNull {
|
||||
t.Error("Column 'first_name' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// last_name - nullable (pointer type)
|
||||
lastNameCol, exists := usersTable.Columns["last_name"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'last_name' not found")
|
||||
}
|
||||
if lastNameCol.NotNull {
|
||||
t.Error("Column 'last_name' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// bio - nullable (pointer type)
|
||||
bioCol, exists := usersTable.Columns["bio"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'bio' not found")
|
||||
}
|
||||
if bioCol.NotNull {
|
||||
t.Error("Column 'bio' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// is_active - NOT NULL (primitive bool)
|
||||
isActiveCol, exists := usersTable.Columns["is_active"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'is_active' not found")
|
||||
}
|
||||
if !isActiveCol.NotNull {
|
||||
t.Error("Column 'is_active' should be NOT NULL (primitive bool type)")
|
||||
}
|
||||
|
||||
// Verify users table indexes
|
||||
if len(usersTable.Indexes) < 1 {
|
||||
t.Error("Expected at least 1 index on users table")
|
||||
}
|
||||
|
||||
// Verify posts table
|
||||
if len(postsTable.Columns) != 11 {
|
||||
t.Errorf("Expected 11 columns in posts table, got %d", len(postsTable.Columns))
|
||||
}
|
||||
|
||||
// excerpt - nullable (pointer type)
|
||||
excerptCol, exists := postsTable.Columns["excerpt"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'excerpt' not found")
|
||||
}
|
||||
if excerptCol.NotNull {
|
||||
t.Error("Column 'excerpt' should be nullable (pointer type *string)")
|
||||
}
|
||||
|
||||
// published - NOT NULL (primitive bool with default)
|
||||
publishedCol, exists := postsTable.Columns["published"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'published' not found")
|
||||
}
|
||||
if !publishedCol.NotNull {
|
||||
t.Error("Column 'published' should be NOT NULL (primitive bool type)")
|
||||
}
|
||||
if publishedCol.Default != "false" {
|
||||
t.Errorf("Expected published default 'false', got '%v'", publishedCol.Default)
|
||||
}
|
||||
|
||||
// published_at - nullable (pointer to time.Time)
|
||||
publishedAtCol, exists := postsTable.Columns["published_at"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'published_at' not found")
|
||||
}
|
||||
if publishedAtCol.NotNull {
|
||||
t.Error("Column 'published_at' should be nullable (pointer type *time.Time)")
|
||||
}
|
||||
|
||||
// view_count - NOT NULL (primitive int64 with default)
|
||||
viewCountCol, exists := postsTable.Columns["view_count"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'view_count' not found")
|
||||
}
|
||||
if !viewCountCol.NotNull {
|
||||
t.Error("Column 'view_count' should be NOT NULL (primitive int64 type)")
|
||||
}
|
||||
if viewCountCol.Default != "0" {
|
||||
t.Errorf("Expected view_count default '0', got '%v'", viewCountCol.Default)
|
||||
}
|
||||
|
||||
// Verify posts table indexes
|
||||
if len(postsTable.Indexes) < 1 {
|
||||
t.Error("Expected at least 1 index on posts table")
|
||||
}
|
||||
|
||||
// Verify comments table
|
||||
if len(commentsTable.Columns) != 6 {
|
||||
t.Errorf("Expected 6 columns in comments table, got %d", len(commentsTable.Columns))
|
||||
}
|
||||
|
||||
// user_id - nullable (pointer type)
|
||||
userIDCol, exists := commentsTable.Columns["user_id"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'user_id' not found in comments table")
|
||||
}
|
||||
if userIDCol.NotNull {
|
||||
t.Error("Column 'user_id' should be nullable (pointer type *int64)")
|
||||
}
|
||||
|
||||
// post_id - NOT NULL (explicit tag)
|
||||
postIDCol, exists := commentsTable.Columns["post_id"]
|
||||
if !exists {
|
||||
t.Fatal("Column 'post_id' not found in comments table")
|
||||
}
|
||||
if !postIDCol.NotNull {
|
||||
t.Error("Column 'post_id' should be NOT NULL (explicit 'not null' tag)")
|
||||
}
|
||||
|
||||
// Verify foreign key constraints
|
||||
if len(postsTable.Constraints) == 0 {
|
||||
t.Error("Expected at least one constraint on posts table")
|
||||
}
|
||||
|
||||
// Find FK constraint to users
|
||||
var fkPostsUser *models.Constraint
|
||||
for _, c := range postsTable.Constraints {
|
||||
if c.Type == models.ForeignKeyConstraint && c.ReferencedTable == "users" {
|
||||
fkPostsUser = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if fkPostsUser != nil {
|
||||
if fkPostsUser.OnDelete != "CASCADE" {
|
||||
t.Errorf("Expected ON DELETE CASCADE for posts->users FK, got '%s'", fkPostsUser.OnDelete)
|
||||
}
|
||||
if fkPostsUser.OnUpdate != "CASCADE" {
|
||||
t.Errorf("Expected ON UPDATE CASCADE for posts->users FK, got '%s'", fkPostsUser.OnUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify comments table constraints
|
||||
if len(commentsTable.Constraints) == 0 {
|
||||
t.Error("Expected at least one constraint on comments table")
|
||||
}
|
||||
|
||||
// Find FK constraints
|
||||
var fkCommentsPost, fkCommentsUser *models.Constraint
|
||||
for _, c := range commentsTable.Constraints {
|
||||
if c.Type == models.ForeignKeyConstraint {
|
||||
if c.ReferencedTable == "posts" {
|
||||
fkCommentsPost = c
|
||||
} else if c.ReferencedTable == "users" {
|
||||
fkCommentsUser = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fkCommentsPost != nil {
|
||||
if fkCommentsPost.OnDelete != "CASCADE" {
|
||||
t.Errorf("Expected ON DELETE CASCADE for comments->posts FK, got '%s'", fkCommentsPost.OnDelete)
|
||||
}
|
||||
}
|
||||
|
||||
if fkCommentsUser != nil {
|
||||
if fkCommentsUser.OnDelete != "SET NULL" {
|
||||
t.Errorf("Expected ON DELETE SET NULL for comments->users FK, got '%s'", fkCommentsUser.OnDelete)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadSchema(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
schema, err := reader.ReadSchema()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadSchema() error = %v", err)
|
||||
}
|
||||
|
||||
if schema == nil {
|
||||
t.Fatal("ReadSchema() returned nil schema")
|
||||
}
|
||||
|
||||
if schema.Name != "public" {
|
||||
t.Errorf("Expected schema name 'public', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if len(schema.Tables) != 1 {
|
||||
t.Errorf("Expected 1 table, got %d", len(schema.Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadTable(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm", "simple.go"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
table, err := reader.ReadTable()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadTable() error = %v", err)
|
||||
}
|
||||
|
||||
if table == nil {
|
||||
t.Fatal("ReadTable() returned nil table")
|
||||
}
|
||||
|
||||
if table.Name != "users" {
|
||||
t.Errorf("Expected table name 'users', got '%s'", table.Name)
|
||||
}
|
||||
|
||||
if len(table.Columns) != 6 {
|
||||
t.Errorf("Expected 6 columns, got %d", len(table.Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_Directory(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "gorm"),
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
db, err := reader.ReadDatabase()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadDatabase() error = %v", err)
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
t.Fatal("ReadDatabase() returned nil database")
|
||||
}
|
||||
|
||||
// Should read both simple.go and complex.go
|
||||
if len(db.Schemas) == 0 {
|
||||
t.Fatal("Expected at least one schema")
|
||||
}
|
||||
|
||||
schema := db.Schemas[0]
|
||||
// Should have at least 3 tables from complex.go (users, posts, comments)
|
||||
// plus 1 from simple.go (users) - but same table name, so may be overwritten
|
||||
if len(schema.Tables) < 3 {
|
||||
t.Errorf("Expected at least 3 tables, got %d", len(schema.Tables))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_InvalidPath(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: "/nonexistent/file.go",
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
_, err := reader.ReadDatabase()
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid file path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReader_ReadDatabase_EmptyPath(t *testing.T) {
|
||||
opts := &readers.ReaderOptions{
|
||||
FilePath: "",
|
||||
}
|
||||
|
||||
reader := NewReader(opts)
|
||||
_, err := reader.ReadDatabase()
|
||||
if err == nil {
|
||||
t.Error("Expected error for empty file path")
|
||||
}
|
||||
}
|
||||
@@ -67,15 +67,6 @@ func (r *Reader) ReadTable() (*models.Table, error) {
|
||||
return schema.Tables[0], nil
|
||||
}
|
||||
|
||||
// stripQuotes removes surrounding quotes from an identifier
|
||||
func stripQuotes(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
|
||||
return s[1 : len(s)-1]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// parsePrisma parses Prisma schema content and returns a Database model
|
||||
func (r *Reader) parsePrisma(content string) (*models.Database, error) {
|
||||
db := models.InitDatabase("database")
|
||||
@@ -239,8 +230,8 @@ func (r *Reader) parseModelFields(lines []string, table *models.Table) {
|
||||
if matches := fieldRegex.FindStringSubmatch(trimmed); matches != nil {
|
||||
fieldName := matches[1]
|
||||
fieldType := matches[2]
|
||||
modifier := matches[3] // ? or []
|
||||
attributes := matches[4] // @... part
|
||||
modifier := matches[3] // ? or []
|
||||
attributes := matches[4] // @... part
|
||||
|
||||
column := r.parseField(fieldName, fieldType, modifier, attributes, table)
|
||||
if column != nil {
|
||||
@@ -370,9 +361,10 @@ func (r *Reader) extractDefaultValue(attributes string) string {
|
||||
i := start
|
||||
|
||||
for i < len(attributes) && depth > 0 {
|
||||
if attributes[i] == '(' {
|
||||
switch attributes[i] {
|
||||
case '(':
|
||||
depth++
|
||||
} else if attributes[i] == ')' {
|
||||
case ')':
|
||||
depth--
|
||||
}
|
||||
i++
|
||||
@@ -479,11 +471,11 @@ func (r *Reader) parseBlockAttribute(attrName, content string, table *models.Tab
|
||||
|
||||
// relationField stores information about a relation field for second-pass processing
|
||||
type relationField struct {
|
||||
tableName string
|
||||
fieldName string
|
||||
relatedModel string
|
||||
isArray bool
|
||||
relationAttr string
|
||||
tableName string
|
||||
fieldName string
|
||||
relatedModel string
|
||||
isArray bool
|
||||
relationAttr string
|
||||
}
|
||||
|
||||
// resolveRelationships performs a second pass to resolve @relation attributes
|
||||
|
||||
@@ -578,9 +578,9 @@ func (r *Reader) parseColumnOptions(decorator string, column *models.Column, tab
|
||||
func (r *Reader) resolveRelationships(entities []entityInfo, tableMap map[string]*models.Table, schema *models.Schema) {
|
||||
// Track M2M relations that need join tables
|
||||
type m2mRelation struct {
|
||||
ownerEntity string
|
||||
ownerEntity string
|
||||
targetEntity string
|
||||
ownerField string
|
||||
ownerField string
|
||||
}
|
||||
m2mRelations := make([]m2mRelation, 0)
|
||||
|
||||
|
||||
@@ -520,9 +520,9 @@ func (w *Writer) generateInverseRelations(table *models.Table, schema *models.Sc
|
||||
fieldName := w.pluralize(strings.ToLower(otherTable.Name))
|
||||
inverseName := strings.ToLower(table.Name)
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" @OneToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable.Name, strings.ToLower(otherTable.Name), strings.ToLower(otherTable.Name), inverseName))
|
||||
sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable.Name))
|
||||
fmt.Fprintf(sb, " @OneToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable.Name, strings.ToLower(otherTable.Name), strings.ToLower(otherTable.Name), inverseName)
|
||||
fmt.Fprintf(sb, " %s: %s[];\n", fieldName, otherTable.Name)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
@@ -570,15 +570,15 @@ func (w *Writer) generateManyToManyRelations(table *models.Table, schema *models
|
||||
inverseName := w.pluralize(strings.ToLower(table.Name))
|
||||
|
||||
if isOwner {
|
||||
sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName))
|
||||
fmt.Fprintf(sb, " @ManyToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName)
|
||||
sb.WriteString(" @JoinTable()\n")
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" @ManyToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName))
|
||||
fmt.Fprintf(sb, " @ManyToMany(() => %s, %s => %s.%s)\n",
|
||||
otherTable, strings.ToLower(otherTable), strings.ToLower(otherTable), inverseName)
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf(" %s: %s[];\n", fieldName, otherTable))
|
||||
fmt.Fprintf(sb, " %s: %s[];\n", fieldName, otherTable)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
60
tests/assets/bun/complex.go
Normal file
60
tests/assets/bun/complex.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// ModelUser represents a user in the system
|
||||
type ModelUser struct {
|
||||
bun.BaseModel `bun:"table:users,alias:u"`
|
||||
|
||||
ID int64 `bun:"id,pk,autoincrement,type:bigint"`
|
||||
Username string `bun:"username,notnull,type:varchar(100),unique:idx_username"`
|
||||
Email string `bun:"email,notnull,type:varchar(255),unique"`
|
||||
Password string `bun:"password,notnull,type:varchar(255)"`
|
||||
FirstName *string `bun:"first_name,type:varchar(100)"`
|
||||
LastName *string `bun:"last_name,type:varchar(100)"`
|
||||
Bio *string `bun:"bio,type:text"`
|
||||
IsActive bool `bun:"is_active,type:boolean"`
|
||||
CreatedAt time.Time `bun:"created_at,type:timestamp"`
|
||||
UpdatedAt time.Time `bun:"updated_at,type:timestamp"`
|
||||
|
||||
Posts []*ModelPost `bun:"rel:has-many,join:id=user_id"`
|
||||
}
|
||||
|
||||
// ModelPost represents a blog post
|
||||
type ModelPost struct {
|
||||
bun.BaseModel `bun:"table:posts,alias:p"`
|
||||
|
||||
ID int64 `bun:"id,pk,autoincrement,type:bigint"`
|
||||
UserID int64 `bun:"user_id,notnull,type:bigint"`
|
||||
Title string `bun:"title,notnull,type:varchar(255)"`
|
||||
Slug string `bun:"slug,notnull,type:varchar(255),unique:idx_slug"`
|
||||
Content string `bun:"content,notnull,type:text"`
|
||||
Excerpt *string `bun:"excerpt,type:text"`
|
||||
Published bool `bun:"published,type:boolean"`
|
||||
ViewCount int64 `bun:"view_count,type:bigint"`
|
||||
PublishedAt *time.Time `bun:"published_at,type:timestamp,nullzero"`
|
||||
CreatedAt time.Time `bun:"created_at,type:timestamp"`
|
||||
UpdatedAt time.Time `bun:"updated_at,type:timestamp"`
|
||||
|
||||
User *ModelUser `bun:"rel:belongs-to,join:user_id=id"`
|
||||
Comments []*ModelComment `bun:"rel:has-many,join:id=post_id"`
|
||||
}
|
||||
|
||||
// ModelComment represents a comment on a post
|
||||
type ModelComment struct {
|
||||
bun.BaseModel `bun:"table:comments,alias:c"`
|
||||
|
||||
ID int64 `bun:"id,pk,autoincrement,type:bigint"`
|
||||
PostID int64 `bun:"post_id,notnull,type:bigint"`
|
||||
UserID *int64 `bun:"user_id,type:bigint"`
|
||||
Content string `bun:"content,notnull,type:text"`
|
||||
CreatedAt time.Time `bun:"created_at,type:timestamp"`
|
||||
UpdatedAt time.Time `bun:"updated_at,type:timestamp"`
|
||||
|
||||
Post *ModelPost `bun:"rel:belongs-to,join:post_id=id"`
|
||||
User *ModelUser `bun:"rel:belongs-to,join:user_id=id"`
|
||||
}
|
||||
18
tests/assets/bun/simple.go
Normal file
18
tests/assets/bun/simple.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
bun.BaseModel `bun:"table:users,alias:u"`
|
||||
|
||||
ID int64 `bun:"id,pk,autoincrement,type:bigint"`
|
||||
Email string `bun:"email,notnull,type:varchar(255),unique"`
|
||||
Name string `bun:"name,type:text"`
|
||||
Age *int `bun:"age,type:integer"`
|
||||
IsActive bool `bun:"is_active,type:boolean"`
|
||||
CreatedAt time.Time `bun:"created_at,type:timestamp,default:now()"`
|
||||
}
|
||||
65
tests/assets/gorm/complex.go
Normal file
65
tests/assets/gorm/complex.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelUser represents a user in the system
|
||||
type ModelUser struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"`
|
||||
Username string `gorm:"column:username;type:varchar(100);not null;uniqueIndex:idx_username"`
|
||||
Email string `gorm:"column:email;type:varchar(255);not null;uniqueIndex"`
|
||||
Password string `gorm:"column:password;type:varchar(255);not null"`
|
||||
FirstName *string `gorm:"column:first_name;type:varchar(100)"`
|
||||
LastName *string `gorm:"column:last_name;type:varchar(100)"`
|
||||
Bio *string `gorm:"column:bio;type:text"`
|
||||
IsActive bool `gorm:"column:is_active;type:boolean;default:true"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;default:now()"`
|
||||
|
||||
Posts []*ModelPost `gorm:"foreignKey:UserID;association_foreignkey:ID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE"`
|
||||
Comments []*ModelComment `gorm:"foreignKey:UserID;association_foreignkey:ID;constraint:OnDelete:SET NULL"`
|
||||
}
|
||||
|
||||
func (ModelUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// ModelPost represents a blog post
|
||||
type ModelPost struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"`
|
||||
UserID int64 `gorm:"column:user_id;type:bigint;not null;index:idx_user_id"`
|
||||
Title string `gorm:"column:title;type:varchar(255);not null"`
|
||||
Slug string `gorm:"column:slug;type:varchar(255);not null;uniqueIndex:idx_slug"`
|
||||
Content string `gorm:"column:content;type:text;not null"`
|
||||
Excerpt *string `gorm:"column:excerpt;type:text"`
|
||||
Published bool `gorm:"column:published;type:boolean;default:false"`
|
||||
ViewCount int64 `gorm:"column:view_count;type:bigint;default:0"`
|
||||
PublishedAt *time.Time `gorm:"column:published_at;type:timestamp"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;default:now()"`
|
||||
|
||||
User *ModelUser `gorm:"foreignKey:UserID;references:ID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE"`
|
||||
Comments []*ModelComment `gorm:"foreignKey:PostID;association_foreignkey:ID;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
func (ModelPost) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
// ModelComment represents a comment on a post
|
||||
type ModelComment struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"`
|
||||
PostID int64 `gorm:"column:post_id;type:bigint;not null;index:idx_post_id"`
|
||||
UserID *int64 `gorm:"column:user_id;type:bigint;index:idx_user_id"`
|
||||
Content string `gorm:"column:content;type:text;not null"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;default:now()"`
|
||||
|
||||
Post *ModelPost `gorm:"foreignKey:PostID;references:ID;constraint:OnDelete:CASCADE"`
|
||||
User *ModelUser `gorm:"foreignKey:UserID;references:ID;constraint:OnDelete:SET NULL"`
|
||||
}
|
||||
|
||||
func (ModelComment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
18
tests/assets/gorm/simple.go
Normal file
18
tests/assets/gorm/simple.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement;type:bigint"`
|
||||
Email string `gorm:"column:email;type:varchar(255);not null"`
|
||||
Name string `gorm:"column:name;type:text"`
|
||||
Age *int `gorm:"column:age;type:integer"`
|
||||
IsActive bool `gorm:"column:is_active;type:boolean"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;default:now()"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
Reference in New Issue
Block a user