244 lines
5.6 KiB
Go
244 lines
5.6 KiB
Go
package gorm
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
|
)
|
|
|
|
func TestWriter_WriteTable(t *testing.T) {
|
|
// Create a simple table
|
|
table := models.InitTable("users", "public")
|
|
table.Columns["id"] = &models.Column{
|
|
Name: "id",
|
|
Type: "bigint",
|
|
NotNull: true,
|
|
IsPrimaryKey: true,
|
|
AutoIncrement: true,
|
|
Sequence: 1,
|
|
}
|
|
table.Columns["email"] = &models.Column{
|
|
Name: "email",
|
|
Type: "varchar",
|
|
Length: 255,
|
|
NotNull: false,
|
|
Sequence: 2,
|
|
}
|
|
table.Columns["created_at"] = &models.Column{
|
|
Name: "created_at",
|
|
Type: "timestamp",
|
|
NotNull: true,
|
|
Sequence: 3,
|
|
}
|
|
|
|
// Create writer
|
|
opts := &writers.WriterOptions{
|
|
PackageName: "models",
|
|
Metadata: map[string]interface{}{
|
|
"generate_table_name": true,
|
|
"generate_get_id": true,
|
|
},
|
|
}
|
|
|
|
writer := NewWriter(opts)
|
|
|
|
// Write to temporary file
|
|
tmpDir := t.TempDir()
|
|
opts.OutputPath = filepath.Join(tmpDir, "test.go")
|
|
|
|
err := writer.WriteTable(table)
|
|
if err != nil {
|
|
t.Fatalf("WriteTable failed: %v", err)
|
|
}
|
|
|
|
// Read the generated file
|
|
content, err := os.ReadFile(opts.OutputPath)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read generated file: %v", err)
|
|
}
|
|
|
|
generated := string(content)
|
|
|
|
// Verify key elements are present
|
|
expectations := []string{
|
|
"package models",
|
|
"type ModelUser struct",
|
|
"ID",
|
|
"int64",
|
|
"Email",
|
|
"sql_types.SqlString",
|
|
"CreatedAt",
|
|
"time.Time",
|
|
"gorm:\"column:id",
|
|
"gorm:\"column:email",
|
|
"func (m ModelUser) TableName() string",
|
|
"return \"public.users\"",
|
|
"func (m ModelUser) GetID() int64",
|
|
}
|
|
|
|
for _, expected := range expectations {
|
|
if !strings.Contains(generated, expected) {
|
|
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|
// Create a database with two tables
|
|
db := models.InitDatabase("testdb")
|
|
schema := models.InitSchema("public")
|
|
|
|
// Table 1: users
|
|
users := models.InitTable("users", "public")
|
|
users.Columns["id"] = &models.Column{
|
|
Name: "id",
|
|
Type: "bigint",
|
|
NotNull: true,
|
|
IsPrimaryKey: true,
|
|
}
|
|
schema.Tables = append(schema.Tables, users)
|
|
|
|
// Table 2: posts
|
|
posts := models.InitTable("posts", "public")
|
|
posts.Columns["id"] = &models.Column{
|
|
Name: "id",
|
|
Type: "bigint",
|
|
NotNull: true,
|
|
IsPrimaryKey: true,
|
|
}
|
|
posts.Columns["user_id"] = &models.Column{
|
|
Name: "user_id",
|
|
Type: "bigint",
|
|
NotNull: true,
|
|
}
|
|
posts.Constraints["fk_user"] = &models.Constraint{
|
|
Name: "fk_user",
|
|
Type: models.ForeignKeyConstraint,
|
|
Columns: []string{"user_id"},
|
|
ReferencedTable: "users",
|
|
ReferencedSchema: "public",
|
|
ReferencedColumns: []string{"id"},
|
|
OnDelete: "CASCADE",
|
|
}
|
|
schema.Tables = append(schema.Tables, posts)
|
|
|
|
db.Schemas = append(db.Schemas, schema)
|
|
|
|
// Create writer with multi-file mode
|
|
tmpDir := t.TempDir()
|
|
opts := &writers.WriterOptions{
|
|
PackageName: "models",
|
|
OutputPath: tmpDir,
|
|
Metadata: map[string]interface{}{
|
|
"multi_file": true,
|
|
},
|
|
}
|
|
|
|
writer := NewWriter(opts)
|
|
|
|
err := writer.WriteDatabase(db)
|
|
if err != nil {
|
|
t.Fatalf("WriteDatabase failed: %v", err)
|
|
}
|
|
|
|
// Verify two files were created
|
|
expectedFiles := []string{
|
|
"sql_public_users.go",
|
|
"sql_public_posts.go",
|
|
}
|
|
|
|
for _, filename := range expectedFiles {
|
|
filepath := filepath.Join(tmpDir, filename)
|
|
if _, err := os.Stat(filepath); os.IsNotExist(err) {
|
|
t.Errorf("Expected file not created: %s", filename)
|
|
}
|
|
}
|
|
|
|
// Check posts file contains relationship
|
|
postsContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_posts.go"))
|
|
if err != nil {
|
|
t.Fatalf("Failed to read posts file: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(string(postsContent), "USE *ModelUser") {
|
|
// Relationship field should be present
|
|
t.Logf("Posts content:\n%s", string(postsContent))
|
|
}
|
|
}
|
|
|
|
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"user_id", "UserID"},
|
|
{"http_request", "HTTPRequest"},
|
|
{"user_profiles", "UserProfiles"},
|
|
{"guid", "GUID"},
|
|
{"rid_process", "RIDProcess"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
result := SnakeCaseToPascalCase(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("SnakeCaseToPascalCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNameConverter_Pluralize(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"user", "users"},
|
|
{"process", "processes"},
|
|
{"child", "children"},
|
|
{"person", "people"},
|
|
{"status", "statuses"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
result := Pluralize(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("Pluralize(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
|
mapper := NewTypeMapper()
|
|
|
|
tests := []struct {
|
|
sqlType string
|
|
notNull bool
|
|
want string
|
|
}{
|
|
{"bigint", true, "int64"},
|
|
{"bigint", false, "sql_types.SqlInt64"},
|
|
{"varchar", true, "string"},
|
|
{"varchar", false, "sql_types.SqlString"},
|
|
{"timestamp", true, "time.Time"},
|
|
{"timestamp", false, "sql_types.SqlTime"},
|
|
{"boolean", true, "bool"},
|
|
{"boolean", false, "sql_types.SqlBool"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.sqlType, func(t *testing.T) {
|
|
result := mapper.SQLTypeToGoType(tt.sqlType, tt.notNull)
|
|
if result != tt.want {
|
|
t.Errorf("SQLTypeToGoType(%q, %v) = %q, want %q", tt.sqlType, tt.notNull, result, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|