Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6f55505444 | |||
| e0e7b64c69 | |||
| 4181cb1fbd | |||
| 120ffc6a5a | |||
| b20ad35485 | |||
| f258f8baeb | |||
| 6388daba56 | |||
| f6c3f2b460 | |||
| 156e655571 |
5
.github/workflows/integration-tests.yml
vendored
5
.github/workflows/integration-tests.yml
vendored
@@ -46,6 +46,11 @@ jobs:
|
|||||||
- name: Download dependencies
|
- name: Download dependencies
|
||||||
run: go mod download
|
run: go mod download
|
||||||
|
|
||||||
|
- name: Install PostgreSQL client
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y postgresql-client
|
||||||
|
|
||||||
- name: Initialize test database
|
- name: Initialize test database
|
||||||
env:
|
env:
|
||||||
PGPASSWORD: relspec_test_password
|
PGPASSWORD: relspec_test_password
|
||||||
|
|||||||
@@ -632,6 +632,9 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
|||||||
column.Name = parts[0]
|
column.Name = parts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Track if we found explicit nullability markers
|
||||||
|
hasExplicitNullableMarker := false
|
||||||
|
|
||||||
// Parse tag attributes
|
// Parse tag attributes
|
||||||
for _, part := range parts[1:] {
|
for _, part := range parts[1:] {
|
||||||
kv := strings.SplitN(part, ":", 2)
|
kv := strings.SplitN(part, ":", 2)
|
||||||
@@ -649,6 +652,10 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
|||||||
column.IsPrimaryKey = true
|
column.IsPrimaryKey = true
|
||||||
case "notnull":
|
case "notnull":
|
||||||
column.NotNull = true
|
column.NotNull = true
|
||||||
|
hasExplicitNullableMarker = true
|
||||||
|
case "nullzero":
|
||||||
|
column.NotNull = false
|
||||||
|
hasExplicitNullableMarker = true
|
||||||
case "autoincrement":
|
case "autoincrement":
|
||||||
column.AutoIncrement = true
|
column.AutoIncrement = true
|
||||||
case "default":
|
case "default":
|
||||||
@@ -664,17 +671,15 @@ func (r *Reader) parseColumn(fieldName string, fieldType ast.Expr, tag string, s
|
|||||||
|
|
||||||
// Determine if nullable based on Go type and bun tags
|
// Determine if nullable based on Go type and bun tags
|
||||||
// In Bun:
|
// In Bun:
|
||||||
// - nullzero tag means the field is nullable (can be NULL in DB)
|
// - explicit "notnull" tag means NOT NULL
|
||||||
// - absence of nullzero means the field is NOT NULL
|
// - explicit "nullzero" tag means nullable
|
||||||
// - primitive types (int64, bool, string) are NOT NULL by default
|
// - absence of explicit markers: infer from Go type
|
||||||
column.NotNull = true
|
if !hasExplicitNullableMarker {
|
||||||
// Primary keys are always NOT NULL
|
// Infer from Go type if no explicit marker found
|
||||||
|
|
||||||
if strings.Contains(bunTag, "nullzero") {
|
|
||||||
column.NotNull = false
|
|
||||||
} else {
|
|
||||||
column.NotNull = !r.isNullableGoType(fieldType)
|
column.NotNull = !r.isNullableGoType(fieldType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Primary keys are always NOT NULL
|
||||||
if column.IsPrimaryKey {
|
if column.IsPrimaryKey {
|
||||||
column.NotNull = true
|
column.NotNull = true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
@@ -24,11 +26,23 @@ func NewReader(options *readers.ReaderOptions) *Reader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ReadDatabase reads and parses DBML input, returning a Database model
|
// ReadDatabase reads and parses DBML input, returning a Database model
|
||||||
|
// If FilePath points to a directory, all .dbml files are loaded and merged
|
||||||
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
func (r *Reader) ReadDatabase() (*models.Database, error) {
|
||||||
if r.options.FilePath == "" {
|
if r.options.FilePath == "" {
|
||||||
return nil, fmt.Errorf("file path is required for DBML reader")
|
return nil, fmt.Errorf("file path is required for DBML reader")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if path is a directory
|
||||||
|
info, err := os.Stat(r.options.FilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to stat path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
return r.readDirectoryDBML(r.options.FilePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single file - existing logic
|
||||||
content, err := os.ReadFile(r.options.FilePath)
|
content, err := os.ReadFile(r.options.FilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||||
@@ -67,15 +81,301 @@ func (r *Reader) ReadTable() (*models.Table, error) {
|
|||||||
return schema.Tables[0], nil
|
return schema.Tables[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// stripQuotes removes surrounding quotes from an identifier
|
// readDirectoryDBML processes all .dbml files in directory
|
||||||
|
// Returns merged Database model
|
||||||
|
func (r *Reader) readDirectoryDBML(dirPath string) (*models.Database, error) {
|
||||||
|
// Discover and sort DBML files
|
||||||
|
files, err := r.discoverDBMLFiles(dirPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to discover DBML files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no files found, return empty database
|
||||||
|
if len(files) == 0 {
|
||||||
|
db := models.InitDatabase("database")
|
||||||
|
if r.options.Metadata != nil {
|
||||||
|
if name, ok := r.options.Metadata["name"].(string); ok {
|
||||||
|
db.Name = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize database (will be merged with files)
|
||||||
|
var db *models.Database
|
||||||
|
|
||||||
|
// Process each file in sorted order
|
||||||
|
for _, filePath := range files {
|
||||||
|
content, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read file %s: %w", filePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileDB, err := r.parseDBML(string(content))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse file %s: %w", filePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First file initializes the database
|
||||||
|
if db == nil {
|
||||||
|
db = fileDB
|
||||||
|
} else {
|
||||||
|
// Subsequent files are merged
|
||||||
|
mergeDatabase(db, fileDB)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripQuotes removes surrounding quotes and comments from an identifier
|
||||||
func stripQuotes(s string) string {
|
func stripQuotes(s string) string {
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
// Remove DBML comments in brackets (e.g., [note: 'description'])
|
||||||
|
// This handles inline comments like: "table_name" [note: 'comment']
|
||||||
|
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
|
||||||
|
s = commentRegex.ReplaceAllString(s, "")
|
||||||
|
|
||||||
|
// Trim again after removing comments
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
// Remove surrounding quotes (double or single)
|
||||||
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
|
if len(s) >= 2 && ((s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'')) {
|
||||||
return s[1 : len(s)-1]
|
return s[1 : len(s)-1]
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseFilePrefix extracts numeric prefix from filename
|
||||||
|
// Examples: "1_schema.dbml" -> (1, true), "tables.dbml" -> (0, false)
|
||||||
|
func parseFilePrefix(filename string) (int, bool) {
|
||||||
|
base := filepath.Base(filename)
|
||||||
|
re := regexp.MustCompile(`^(\d+)[_-]`)
|
||||||
|
matches := re.FindStringSubmatch(base)
|
||||||
|
if len(matches) > 1 {
|
||||||
|
var prefix int
|
||||||
|
_, err := fmt.Sscanf(matches[1], "%d", &prefix)
|
||||||
|
if err == nil {
|
||||||
|
return prefix, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasCommentedRefs scans file content for commented-out Ref statements
|
||||||
|
// Returns true if file contains lines like: // Ref: table.col > other.col
|
||||||
|
func hasCommentedRefs(filePath string) (bool, error) {
|
||||||
|
content, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(string(content)))
|
||||||
|
commentedRefRegex := regexp.MustCompile(`^\s*//.*Ref:\s+`)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if commentedRefRegex.MatchString(line) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoverDBMLFiles finds all .dbml files in directory and returns them sorted
|
||||||
|
func (r *Reader) discoverDBMLFiles(dirPath string) ([]string, error) {
|
||||||
|
pattern := filepath.Join(dirPath, "*.dbml")
|
||||||
|
files, err := filepath.Glob(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to glob .dbml files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sortDBMLFiles(files), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortDBMLFiles sorts files by:
|
||||||
|
// 1. Files without commented refs (by numeric prefix, then alphabetically)
|
||||||
|
// 2. Files with commented refs (by numeric prefix, then alphabetically)
|
||||||
|
func sortDBMLFiles(files []string) []string {
|
||||||
|
// Create a slice to hold file info for sorting
|
||||||
|
type fileInfo struct {
|
||||||
|
path string
|
||||||
|
hasCommented bool
|
||||||
|
prefix int
|
||||||
|
hasPrefix bool
|
||||||
|
basename string
|
||||||
|
}
|
||||||
|
|
||||||
|
fileInfos := make([]fileInfo, 0, len(files))
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
hasCommented, err := hasCommentedRefs(file)
|
||||||
|
if err != nil {
|
||||||
|
// If we can't read the file, treat it as not having commented refs
|
||||||
|
hasCommented = false
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, hasPrefix := parseFilePrefix(file)
|
||||||
|
basename := filepath.Base(file)
|
||||||
|
|
||||||
|
fileInfos = append(fileInfos, fileInfo{
|
||||||
|
path: file,
|
||||||
|
hasCommented: hasCommented,
|
||||||
|
prefix: prefix,
|
||||||
|
hasPrefix: hasPrefix,
|
||||||
|
basename: basename,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by: hasCommented (false first), hasPrefix (true first), prefix, basename
|
||||||
|
sort.Slice(fileInfos, func(i, j int) bool {
|
||||||
|
// First, sort by commented refs (files without commented refs come first)
|
||||||
|
if fileInfos[i].hasCommented != fileInfos[j].hasCommented {
|
||||||
|
return !fileInfos[i].hasCommented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then by presence of prefix (files with prefix come first)
|
||||||
|
if fileInfos[i].hasPrefix != fileInfos[j].hasPrefix {
|
||||||
|
return fileInfos[i].hasPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// If both have prefix, sort by prefix value
|
||||||
|
if fileInfos[i].hasPrefix && fileInfos[j].hasPrefix {
|
||||||
|
if fileInfos[i].prefix != fileInfos[j].prefix {
|
||||||
|
return fileInfos[i].prefix < fileInfos[j].prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, sort alphabetically by basename
|
||||||
|
return fileInfos[i].basename < fileInfos[j].basename
|
||||||
|
})
|
||||||
|
|
||||||
|
// Extract sorted paths
|
||||||
|
sortedFiles := make([]string, len(fileInfos))
|
||||||
|
for i, info := range fileInfos {
|
||||||
|
sortedFiles[i] = info.path
|
||||||
|
}
|
||||||
|
|
||||||
|
return sortedFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeTable combines two table definitions
|
||||||
|
// Merges: Columns (map), Constraints (map), Indexes (map), Relationships (map)
|
||||||
|
// Uses first non-empty Description
|
||||||
|
func mergeTable(baseTable, fileTable *models.Table) {
|
||||||
|
// Merge columns (map naturally merges - later keys overwrite)
|
||||||
|
for key, col := range fileTable.Columns {
|
||||||
|
baseTable.Columns[key] = col
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge constraints
|
||||||
|
for key, constraint := range fileTable.Constraints {
|
||||||
|
baseTable.Constraints[key] = constraint
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge indexes
|
||||||
|
for key, index := range fileTable.Indexes {
|
||||||
|
baseTable.Indexes[key] = index
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge relationships
|
||||||
|
for key, rel := range fileTable.Relationships {
|
||||||
|
baseTable.Relationships[key] = rel
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use first non-empty description
|
||||||
|
if baseTable.Description == "" && fileTable.Description != "" {
|
||||||
|
baseTable.Description = fileTable.Description
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge metadata maps
|
||||||
|
if baseTable.Metadata == nil {
|
||||||
|
baseTable.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
for key, val := range fileTable.Metadata {
|
||||||
|
baseTable.Metadata[key] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeSchema finds or creates schema and merges tables
|
||||||
|
func mergeSchema(baseDB *models.Database, fileSchema *models.Schema) {
|
||||||
|
// Find existing schema by name (normalize names by stripping quotes)
|
||||||
|
var existingSchema *models.Schema
|
||||||
|
fileSchemaName := stripQuotes(fileSchema.Name)
|
||||||
|
for _, schema := range baseDB.Schemas {
|
||||||
|
if stripQuotes(schema.Name) == fileSchemaName {
|
||||||
|
existingSchema = schema
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If schema doesn't exist, add it and return
|
||||||
|
if existingSchema == nil {
|
||||||
|
baseDB.Schemas = append(baseDB.Schemas, fileSchema)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge tables from fileSchema into existingSchema
|
||||||
|
for _, fileTable := range fileSchema.Tables {
|
||||||
|
// Find existing table by name (normalize names by stripping quotes)
|
||||||
|
var existingTable *models.Table
|
||||||
|
fileTableName := stripQuotes(fileTable.Name)
|
||||||
|
for _, table := range existingSchema.Tables {
|
||||||
|
if stripQuotes(table.Name) == fileTableName {
|
||||||
|
existingTable = table
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If table doesn't exist, add it
|
||||||
|
if existingTable == nil {
|
||||||
|
existingSchema.Tables = append(existingSchema.Tables, fileTable)
|
||||||
|
} else {
|
||||||
|
// Merge table properties - tables are identical, skip
|
||||||
|
mergeTable(existingTable, fileTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge other schema properties
|
||||||
|
existingSchema.Views = append(existingSchema.Views, fileSchema.Views...)
|
||||||
|
existingSchema.Sequences = append(existingSchema.Sequences, fileSchema.Sequences...)
|
||||||
|
existingSchema.Scripts = append(existingSchema.Scripts, fileSchema.Scripts...)
|
||||||
|
|
||||||
|
// Merge permissions
|
||||||
|
if existingSchema.Permissions == nil {
|
||||||
|
existingSchema.Permissions = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, val := range fileSchema.Permissions {
|
||||||
|
existingSchema.Permissions[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge metadata
|
||||||
|
if existingSchema.Metadata == nil {
|
||||||
|
existingSchema.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
for key, val := range fileSchema.Metadata {
|
||||||
|
existingSchema.Metadata[key] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeDatabase merges schemas from fileDB into baseDB
|
||||||
|
func mergeDatabase(baseDB, fileDB *models.Database) {
|
||||||
|
// Merge each schema from fileDB
|
||||||
|
for _, fileSchema := range fileDB.Schemas {
|
||||||
|
mergeSchema(baseDB, fileSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge domains
|
||||||
|
baseDB.Domains = append(baseDB.Domains, fileDB.Domains...)
|
||||||
|
|
||||||
|
// Use first non-empty description
|
||||||
|
if baseDB.Description == "" && fileDB.Description != "" {
|
||||||
|
baseDB.Description = fileDB.Description
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// parseDBML parses DBML content and returns a Database model
|
// parseDBML parses DBML content and returns a Database model
|
||||||
func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
func (r *Reader) parseDBML(content string) (*models.Database, error) {
|
||||||
db := models.InitDatabase("database")
|
db := models.InitDatabase("database")
|
||||||
@@ -332,27 +632,31 @@ func (r *Reader) parseIndex(line, tableName, schemaName string) *models.Index {
|
|||||||
// Format: (columns) [attributes] OR columnname [attributes]
|
// Format: (columns) [attributes] OR columnname [attributes]
|
||||||
var columns []string
|
var columns []string
|
||||||
|
|
||||||
if strings.Contains(line, "(") && strings.Contains(line, ")") {
|
// Find the attributes section to avoid parsing parentheses in notes/attributes
|
||||||
|
attrStart := strings.Index(line, "[")
|
||||||
|
columnPart := line
|
||||||
|
if attrStart > 0 {
|
||||||
|
columnPart = line[:attrStart]
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(columnPart, "(") && strings.Contains(columnPart, ")") {
|
||||||
// Multi-column format: (col1, col2) [attributes]
|
// Multi-column format: (col1, col2) [attributes]
|
||||||
colStart := strings.Index(line, "(")
|
colStart := strings.Index(columnPart, "(")
|
||||||
colEnd := strings.Index(line, ")")
|
colEnd := strings.Index(columnPart, ")")
|
||||||
if colStart >= colEnd {
|
if colStart >= colEnd {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
columnsStr := line[colStart+1 : colEnd]
|
columnsStr := columnPart[colStart+1 : colEnd]
|
||||||
for _, col := range strings.Split(columnsStr, ",") {
|
for _, col := range strings.Split(columnsStr, ",") {
|
||||||
columns = append(columns, stripQuotes(strings.TrimSpace(col)))
|
columns = append(columns, stripQuotes(strings.TrimSpace(col)))
|
||||||
}
|
}
|
||||||
} else if strings.Contains(line, "[") {
|
} else if attrStart > 0 {
|
||||||
// Single column format: columnname [attributes]
|
// Single column format: columnname [attributes]
|
||||||
// Extract column name before the bracket
|
// Extract column name before the bracket
|
||||||
idx := strings.Index(line, "[")
|
colName := strings.TrimSpace(columnPart)
|
||||||
if idx > 0 {
|
if colName != "" {
|
||||||
colName := strings.TrimSpace(line[:idx])
|
columns = []string{stripQuotes(colName)}
|
||||||
if colName != "" {
|
|
||||||
columns = []string{stripQuotes(colName)}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package dbml
|
package dbml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -517,3 +518,286 @@ func TestGetForeignKeys(t *testing.T) {
|
|||||||
t.Error("Expected foreign key constraint type")
|
t.Error("Expected foreign key constraint type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests for multi-file directory loading
|
||||||
|
|
||||||
|
func TestReadDirectory_MultipleFiles(t *testing.T) {
|
||||||
|
opts := &readers.ReaderOptions{
|
||||||
|
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(opts)
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if db == nil {
|
||||||
|
t.Fatal("ReadDatabase() returned nil database")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have public schema
|
||||||
|
if len(db.Schemas) == 0 {
|
||||||
|
t.Fatal("Expected at least one schema")
|
||||||
|
}
|
||||||
|
|
||||||
|
var publicSchema *models.Schema
|
||||||
|
for _, schema := range db.Schemas {
|
||||||
|
if schema.Name == "public" {
|
||||||
|
publicSchema = schema
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if publicSchema == nil {
|
||||||
|
t.Fatal("Public schema not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have 3 tables: users, posts, comments
|
||||||
|
if len(publicSchema.Tables) != 3 {
|
||||||
|
t.Fatalf("Expected 3 tables, got %d", len(publicSchema.Tables))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find tables
|
||||||
|
var usersTable, postsTable, commentsTable *models.Table
|
||||||
|
for _, table := range publicSchema.Tables {
|
||||||
|
switch table.Name {
|
||||||
|
case "users":
|
||||||
|
usersTable = table
|
||||||
|
case "posts":
|
||||||
|
postsTable = table
|
||||||
|
case "comments":
|
||||||
|
commentsTable = table
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usersTable == nil {
|
||||||
|
t.Fatal("Users table not found")
|
||||||
|
}
|
||||||
|
if postsTable == nil {
|
||||||
|
t.Fatal("Posts table not found")
|
||||||
|
}
|
||||||
|
if commentsTable == nil {
|
||||||
|
t.Fatal("Comments table not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify users table has merged columns from 1_users.dbml and 3_add_columns.dbml
|
||||||
|
expectedUserColumns := []string{"id", "email", "name", "created_at"}
|
||||||
|
if len(usersTable.Columns) != len(expectedUserColumns) {
|
||||||
|
t.Errorf("Expected %d columns in users table, got %d", len(expectedUserColumns), len(usersTable.Columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, colName := range expectedUserColumns {
|
||||||
|
if _, exists := usersTable.Columns[colName]; !exists {
|
||||||
|
t.Errorf("Expected column '%s' in users table", colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify posts table columns
|
||||||
|
expectedPostColumns := []string{"id", "user_id", "title", "content", "created_at"}
|
||||||
|
for _, colName := range expectedPostColumns {
|
||||||
|
if _, exists := postsTable.Columns[colName]; !exists {
|
||||||
|
t.Errorf("Expected column '%s' in posts table", colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDirectory_TableMerging(t *testing.T) {
|
||||||
|
opts := &readers.ReaderOptions{
|
||||||
|
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(opts)
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find users table
|
||||||
|
var usersTable *models.Table
|
||||||
|
for _, schema := range db.Schemas {
|
||||||
|
for _, table := range schema.Tables {
|
||||||
|
if table.Name == "users" && schema.Name == "public" {
|
||||||
|
usersTable = table
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usersTable == nil {
|
||||||
|
t.Fatal("Users table not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify columns from file 1 (id, email)
|
||||||
|
if _, exists := usersTable.Columns["id"]; !exists {
|
||||||
|
t.Error("Column 'id' from 1_users.dbml not found")
|
||||||
|
}
|
||||||
|
if _, exists := usersTable.Columns["email"]; !exists {
|
||||||
|
t.Error("Column 'email' from 1_users.dbml not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify columns from file 3 (name, created_at)
|
||||||
|
if _, exists := usersTable.Columns["name"]; !exists {
|
||||||
|
t.Error("Column 'name' from 3_add_columns.dbml not found")
|
||||||
|
}
|
||||||
|
if _, exists := usersTable.Columns["created_at"]; !exists {
|
||||||
|
t.Error("Column 'created_at' from 3_add_columns.dbml not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify column properties from file 1
|
||||||
|
emailCol := usersTable.Columns["email"]
|
||||||
|
if !emailCol.NotNull {
|
||||||
|
t.Error("Email column should be not null (from 1_users.dbml)")
|
||||||
|
}
|
||||||
|
if emailCol.Type != "varchar(255)" {
|
||||||
|
t.Errorf("Expected email type 'varchar(255)', got '%s'", emailCol.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDirectory_CommentedRefsLast(t *testing.T) {
|
||||||
|
// This test verifies that files with commented refs are processed last
|
||||||
|
// by checking that the file discovery returns them in the correct order
|
||||||
|
dirPath := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile")
|
||||||
|
|
||||||
|
opts := &readers.ReaderOptions{
|
||||||
|
FilePath: dirPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(opts)
|
||||||
|
files, err := reader.discoverDBMLFiles(dirPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("discoverDBMLFiles() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(files) < 2 {
|
||||||
|
t.Skip("Not enough files to test ordering")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that 9_refs.dbml (which has commented refs) comes last
|
||||||
|
lastFile := filepath.Base(files[len(files)-1])
|
||||||
|
if lastFile != "9_refs.dbml" {
|
||||||
|
t.Errorf("Expected last file to be '9_refs.dbml' (has commented refs), got '%s'", lastFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that numbered files without commented refs come first
|
||||||
|
firstFile := filepath.Base(files[0])
|
||||||
|
if firstFile != "1_users.dbml" {
|
||||||
|
t.Errorf("Expected first file to be '1_users.dbml', got '%s'", firstFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDirectory_EmptyDirectory(t *testing.T) {
|
||||||
|
// Create a temporary empty directory
|
||||||
|
tmpDir := filepath.Join("..", "..", "..", "tests", "assets", "dbml", "empty_test_dir")
|
||||||
|
err := os.MkdirAll(tmpDir, 0755)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
opts := &readers.ReaderOptions{
|
||||||
|
FilePath: tmpDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(opts)
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() should not error on empty directory, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if db == nil {
|
||||||
|
t.Fatal("ReadDatabase() returned nil database")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty directory should return empty database
|
||||||
|
if len(db.Schemas) != 0 {
|
||||||
|
t.Errorf("Expected 0 schemas for empty directory, got %d", len(db.Schemas))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadDatabase_BackwardCompat(t *testing.T) {
|
||||||
|
// Test that single file loading still works
|
||||||
|
opts := &readers.ReaderOptions{
|
||||||
|
FilePath: filepath.Join("..", "..", "..", "tests", "assets", "dbml", "simple.dbml"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := NewReader(opts)
|
||||||
|
db, err := reader.ReadDatabase()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadDatabase() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if db == nil {
|
||||||
|
t.Fatal("ReadDatabase() returned nil database")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(db.Schemas) == 0 {
|
||||||
|
t.Fatal("Expected at least one schema")
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := db.Schemas[0]
|
||||||
|
if len(schema.Tables) != 1 {
|
||||||
|
t.Fatalf("Expected 1 table, got %d", len(schema.Tables))
|
||||||
|
}
|
||||||
|
|
||||||
|
table := schema.Tables[0]
|
||||||
|
if table.Name != "users" {
|
||||||
|
t.Errorf("Expected table name 'users', got '%s'", table.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFilePrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filename string
|
||||||
|
wantPrefix int
|
||||||
|
wantHas bool
|
||||||
|
}{
|
||||||
|
{"1_schema.dbml", 1, true},
|
||||||
|
{"2_tables.dbml", 2, true},
|
||||||
|
{"10_relationships.dbml", 10, true},
|
||||||
|
{"99_data.dbml", 99, true},
|
||||||
|
{"schema.dbml", 0, false},
|
||||||
|
{"tables_no_prefix.dbml", 0, false},
|
||||||
|
{"/path/to/1_file.dbml", 1, true},
|
||||||
|
{"/path/to/file.dbml", 0, false},
|
||||||
|
{"1-file.dbml", 1, true},
|
||||||
|
{"2-another.dbml", 2, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.filename, func(t *testing.T) {
|
||||||
|
gotPrefix, gotHas := parseFilePrefix(tt.filename)
|
||||||
|
if gotPrefix != tt.wantPrefix {
|
||||||
|
t.Errorf("parseFilePrefix(%s) prefix = %d, want %d", tt.filename, gotPrefix, tt.wantPrefix)
|
||||||
|
}
|
||||||
|
if gotHas != tt.wantHas {
|
||||||
|
t.Errorf("parseFilePrefix(%s) hasPrefix = %v, want %v", tt.filename, gotHas, tt.wantHas)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasCommentedRefs(t *testing.T) {
|
||||||
|
// Test with the actual multifile test fixtures
|
||||||
|
tests := []struct {
|
||||||
|
filename string
|
||||||
|
wantHas bool
|
||||||
|
}{
|
||||||
|
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "1_users.dbml"), false},
|
||||||
|
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "2_posts.dbml"), false},
|
||||||
|
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "3_add_columns.dbml"), false},
|
||||||
|
{filepath.Join("..", "..", "..", "tests", "assets", "dbml", "multifile", "9_refs.dbml"), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(filepath.Base(tt.filename), func(t *testing.T) {
|
||||||
|
gotHas, err := hasCommentedRefs(tt.filename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("hasCommentedRefs() error = %v", err)
|
||||||
|
}
|
||||||
|
if gotHas != tt.wantHas {
|
||||||
|
t.Errorf("hasCommentedRefs(%s) = %v, want %v", filepath.Base(tt.filename), gotHas, tt.wantHas)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TemplateData represents the data passed to the template for code generation
|
// TemplateData represents the data passed to the template for code generation
|
||||||
@@ -111,13 +112,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
tableName = schema + "." + table.Name
|
tableName = schema + "." + table.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate model name: singularize and convert to PascalCase
|
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||||
singularTable := Singularize(table.Name)
|
singularTable := Singularize(table.Name)
|
||||||
modelName := SnakeCaseToPascalCase(singularTable)
|
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||||
|
|
||||||
// Add "Model" prefix if not already present
|
// Include schema name in model name
|
||||||
if !hasModelPrefix(modelName) {
|
var modelName string
|
||||||
modelName = "Model" + modelName
|
if schema != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schema)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &ModelData{
|
model := &ModelData{
|
||||||
@@ -133,8 +138,10 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
// Find primary key
|
// Find primary key
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name to remove backticks
|
||||||
model.IDColumnName = col.Name
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
|
model.IDColumnName = safeName
|
||||||
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
||||||
@@ -146,6 +153,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,10 +163,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
|
|
||||||
// columnToField converts a models.Column to FieldData
|
// columnToField converts a models.Column to FieldData
|
||||||
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
||||||
fieldName := SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name first to remove backticks before generating field name
|
||||||
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
fieldName := SnakeCaseToPascalCase(safeName)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
bunTag := typeMapper.BuildBunTag(col, table)
|
bunTag := typeMapper.BuildBunTag(col, table)
|
||||||
jsonTag := col.Name // Use column name for JSON tag
|
// Use same sanitized name for JSON tag
|
||||||
|
jsonTag := safeName
|
||||||
|
|
||||||
return &FieldData{
|
return &FieldData{
|
||||||
Name: fieldName,
|
Name: fieldName,
|
||||||
@@ -184,9 +196,28 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasModelPrefix checks if a name already has "Model" prefix
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
func hasModelPrefix(name string) bool {
|
// and adds an underscore suffix if there's a collision
|
||||||
return len(name) >= 5 && name[:5] == "Model"
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TypeMapper handles type conversions between SQL and Go types for Bun
|
// TypeMapper handles type conversions between SQL and Go types for Bun
|
||||||
@@ -164,11 +165,14 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
var parts []string
|
var parts []string
|
||||||
|
|
||||||
// Column name comes first (no prefix)
|
// Column name comes first (no prefix)
|
||||||
parts = append(parts, column.Name)
|
// Sanitize to remove backticks which would break struct tag syntax
|
||||||
|
safeName := writers.SanitizeStructTagValue(column.Name)
|
||||||
|
parts = append(parts, safeName)
|
||||||
|
|
||||||
// Add type if specified
|
// Add type if specified
|
||||||
if column.Type != "" {
|
if column.Type != "" {
|
||||||
typeStr := column.Type
|
// Sanitize type to remove backticks
|
||||||
|
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||||
if column.Length > 0 {
|
if column.Length > 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||||
} else if column.Precision > 0 {
|
} else if column.Precision > 0 {
|
||||||
@@ -188,12 +192,17 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
|
|
||||||
// Default value
|
// Default value
|
||||||
if column.Default != nil {
|
if column.Default != nil {
|
||||||
parts = append(parts, fmt.Sprintf("default:%v", column.Default))
|
// Sanitize default value to remove backticks
|
||||||
|
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
||||||
|
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nullable (Bun uses nullzero for nullable fields)
|
// Nullable (Bun uses nullzero for nullable fields)
|
||||||
|
// and notnull tag for explicitly non-nullable fields
|
||||||
if !column.NotNull && !column.IsPrimaryKey {
|
if !column.NotNull && !column.IsPrimaryKey {
|
||||||
parts = append(parts, "nullzero")
|
parts = append(parts, "nullzero")
|
||||||
|
} else if column.NotNull && !column.IsPrimaryKey {
|
||||||
|
parts = append(parts, "notnull")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for indexes (unique indexes should be added to tag)
|
// Check for indexes (unique indexes should be added to tag)
|
||||||
@@ -260,7 +269,7 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
|||||||
|
|
||||||
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common)
|
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common)
|
||||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||||
return "github.com/bitechdev/ResolveSpec/pkg/common"
|
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetBunImport returns the import path for Bun
|
// GetBunImport returns the import path for Bun
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -124,7 +125,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write output
|
// Write output
|
||||||
return w.writeOutput(formatted)
|
if err := w.writeOutput(formatted); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the output file
|
||||||
|
if w.options.OutputPath != "" {
|
||||||
|
w.runGoFmt(w.options.OutputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeMultiFile writes each table to a separate file
|
// writeMultiFile writes each table to a separate file
|
||||||
@@ -207,13 +217,19 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate filename: sql_{schema}_{table}.go
|
// Generate filename: sql_{schema}_{table}.go
|
||||||
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
|
// Sanitize schema and table names to remove quotes, comments, and invalid characters
|
||||||
|
safeSchemaName := writers.SanitizeFilename(schema.Name)
|
||||||
|
safeTableName := writers.SanitizeFilename(table.Name)
|
||||||
|
filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)
|
||||||
filepath := filepath.Join(w.options.OutputPath, filename)
|
filepath := filepath.Join(w.options.OutputPath, filename)
|
||||||
|
|
||||||
// Write file
|
// Write file
|
||||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the generated file
|
||||||
|
w.runGoFmt(filepath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +238,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
|
|
||||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||||
|
// Track used field names to detect duplicates
|
||||||
|
usedFieldNames := make(map[string]int)
|
||||||
|
|
||||||
// For each foreign key in this table, add a belongs-to/has-one relationship
|
// For each foreign key in this table, add a belongs-to/has-one relationship
|
||||||
for _, constraint := range table.Constraints {
|
for _, constraint := range table.Constraints {
|
||||||
if constraint.Type != models.ForeignKeyConstraint {
|
if constraint.Type != models.ForeignKeyConstraint {
|
||||||
@@ -235,8 +254,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
fieldName := w.generateHasOneFieldName(constraint)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -263,8 +283,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
// Check if this constraint references our table
|
// Check if this constraint references our table
|
||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -295,22 +316,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getModelName generates the model name from a table name
|
// getModelName generates the model name from schema and table name
|
||||||
func (w *Writer) getModelName(tableName string) string {
|
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||||
singular := Singularize(tableName)
|
singular := Singularize(tableName)
|
||||||
modelName := SnakeCaseToPascalCase(singular)
|
tablePart := SnakeCaseToPascalCase(singular)
|
||||||
|
|
||||||
if !hasModelPrefix(modelName) {
|
// Include schema name in model name
|
||||||
modelName = "Model" + modelName
|
var modelName string
|
||||||
|
if schemaName != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRelationshipFieldName generates a field name for a relationship
|
// generateHasOneFieldName generates a field name for has-one relationships
|
||||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
// Uses the foreign key column name for uniqueness
|
||||||
// Use just the prefix (3 letters) for relationship fields
|
func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
|
||||||
return GeneratePrefix(tableName)
|
// Use the foreign key column name to ensure uniqueness
|
||||||
|
// If there are multiple columns, use the first one
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Convert to PascalCase for proper Go field naming
|
||||||
|
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
|
||||||
|
return "Rel" + SnakeCaseToPascalCase(columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based prefix if no columns defined
|
||||||
|
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||||
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
|
// e.g., multiple tables referencing the same column on this table
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Get the model name for the source table (pluralized)
|
||||||
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
|
// Remove "Model" prefix if present
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
|
// Convert column to PascalCase and combine with source table
|
||||||
|
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
|
||||||
|
columnPart := SnakeCaseToPascalCase(columnName)
|
||||||
|
return "Rel" + columnPart + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based naming
|
||||||
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
|
||||||
|
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
|
||||||
|
originalName := fieldName
|
||||||
|
count := usedNames[originalName]
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
// Name is already used, add numeric suffix
|
||||||
|
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment the counter for this base name
|
||||||
|
usedNames[originalName]++
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPackageName returns the package name from options or defaults to "models"
|
// getPackageName returns the package name from options or defaults to "models"
|
||||||
@@ -341,6 +417,15 @@ func (w *Writer) writeOutput(content string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runGoFmt runs go fmt on the specified file
|
||||||
|
func (w *Writer) runGoFmt(filepath string) {
|
||||||
|
cmd := exec.Command("gofmt", "-w", filepath)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
// Don't fail the whole operation if gofmt fails, just warn
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||||
func (w *Writer) shouldUseMultiFile() bool {
|
func (w *Writer) shouldUseMultiFile() bool {
|
||||||
// Check if multi_file is explicitly set in metadata
|
// Check if multi_file is explicitly set in metadata
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
// Verify key elements are present
|
// Verify key elements are present
|
||||||
expectations := []string{
|
expectations := []string{
|
||||||
"package models",
|
"package models",
|
||||||
"type ModelUser struct",
|
"type ModelPublicUser struct",
|
||||||
"bun.BaseModel",
|
"bun.BaseModel",
|
||||||
"table:public.users",
|
"table:public.users",
|
||||||
"alias:users",
|
"alias:users",
|
||||||
@@ -78,9 +78,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
"resolvespec_common.SqlTime",
|
"resolvespec_common.SqlTime",
|
||||||
"bun:\"id",
|
"bun:\"id",
|
||||||
"bun:\"email",
|
"bun:\"email",
|
||||||
"func (m ModelUser) TableName() string",
|
"func (m ModelPublicUser) TableName() string",
|
||||||
"return \"public.users\"",
|
"return \"public.users\"",
|
||||||
"func (m ModelUser) GetID() int64",
|
"func (m ModelPublicUser) GetID() int64",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expected := range expectations {
|
for _, expected := range expectations {
|
||||||
@@ -175,12 +175,378 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
postsStr := string(postsContent)
|
postsStr := string(postsContent)
|
||||||
|
|
||||||
// Verify relationship is present with Bun format
|
// Verify relationship is present with Bun format
|
||||||
if !strings.Contains(postsStr, "USE") {
|
// Should now be RelUserID (has-one) instead of USE
|
||||||
t.Errorf("Missing relationship field USE")
|
if !strings.Contains(postsStr, "RelUserID") {
|
||||||
|
t.Errorf("Missing relationship field RelUserID (new naming convention)")
|
||||||
}
|
}
|
||||||
if !strings.Contains(postsStr, "rel:has-one") {
|
if !strings.Contains(postsStr, "rel:has-one") {
|
||||||
t.Errorf("Missing Bun relationship tag: %s", postsStr)
|
t.Errorf("Missing Bun relationship tag: %s", postsStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check users file contains has-many relationship
|
||||||
|
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read users file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
|
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||||
|
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||||
|
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||||
|
// Test scenario: api_event table with multiple foreign keys to filepointer table
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API event table with two foreign keys to filepointer
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_request",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_response",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add constraints
|
||||||
|
apiEvent.Constraints["fk_request"] = &models.Constraint{
|
||||||
|
Name: "fk_request",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_request"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_response"] = &models.Constraint{
|
||||||
|
Name: "fk_response",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_response"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_event file
|
||||||
|
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_event file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiEventContent)
|
||||||
|
|
||||||
|
// Verify both relationships have unique names based on column names
|
||||||
|
expectations := []struct {
|
||||||
|
fieldName string
|
||||||
|
tag string
|
||||||
|
}{
|
||||||
|
{"RelRIDFilepointerRequest", "join:rid_filepointer_request=id_filepointer"},
|
||||||
|
{"RelRIDFilepointerResponse", "join:rid_filepointer_response=id_filepointer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range expectations {
|
||||||
|
if !strings.Contains(contentStr, exp.fieldName) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentStr, exp.tag) {
|
||||||
|
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
|
||||||
|
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
|
||||||
|
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also verify has-many relationships on filepointer table
|
||||||
|
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read filepointer file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filepointerStr := string(filepointerContent)
|
||||||
|
|
||||||
|
// Should have two different has-many relationships with unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||||
|
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(filepointerStr, exp) {
|
||||||
|
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||||
|
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Owner table
|
||||||
|
owner := models.InitTable("owner", "org")
|
||||||
|
owner.Columns["id_owner"] = &models.Column{
|
||||||
|
Name: "id_owner",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, owner)
|
||||||
|
|
||||||
|
// API Provider table
|
||||||
|
apiProvider := models.InitTable("api_provider", "org")
|
||||||
|
apiProvider.Columns["id_api_provider"] = &models.Column{
|
||||||
|
Name: "id_api_provider",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiProvider.Columns["rid_owner"] = &models.Column{
|
||||||
|
Name: "rid_owner",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiProvider.Constraints["fk_owner"] = &models.Constraint{
|
||||||
|
Name: "fk_owner",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_owner"},
|
||||||
|
ReferencedTable: "owner",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_owner"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiProvider)
|
||||||
|
|
||||||
|
// Login table
|
||||||
|
login := models.InitTable("login", "org")
|
||||||
|
login.Columns["id_login"] = &models.Column{
|
||||||
|
Name: "id_login",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
login.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
login.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, login)
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
filepointer.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API Event table
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_provider file
|
||||||
|
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_provider file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiProviderContent)
|
||||||
|
|
||||||
|
// Verify all has-many relationships have unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||||
|
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||||
|
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||||
|
"RelRIDOwner", // Has one via rid_owner
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(contentStr, exp) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names
|
||||||
|
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
|
||||||
|
count := strings.Count(contentStr, "RelRIDAPIProvider")
|
||||||
|
if count != 3 {
|
||||||
|
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no duplicate declarations (would cause compilation error)
|
||||||
|
duplicatePattern := "RelRIDAPIProviders []*Model"
|
||||||
|
if strings.Contains(contentStr, duplicatePattern) {
|
||||||
|
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `bun:"table_name,`) {
|
||||||
|
t.Errorf("Expected bun tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName resolvespec_common") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||||
|
|||||||
@@ -126,7 +126,15 @@ func (w *Writer) tableToDBML(t *models.Table) string {
|
|||||||
attrs = append(attrs, "increment")
|
attrs = append(attrs, "increment")
|
||||||
}
|
}
|
||||||
if column.Default != nil {
|
if column.Default != nil {
|
||||||
attrs = append(attrs, fmt.Sprintf("default: `%v`", column.Default))
|
// Check if default value contains backticks (DBML expressions like `now()`)
|
||||||
|
defaultStr := fmt.Sprintf("%v", column.Default)
|
||||||
|
if strings.HasPrefix(defaultStr, "`") && strings.HasSuffix(defaultStr, "`") {
|
||||||
|
// Already an expression with backticks, use as-is
|
||||||
|
attrs = append(attrs, fmt.Sprintf("default: %s", defaultStr))
|
||||||
|
} else {
|
||||||
|
// Regular value, wrap in single quotes
|
||||||
|
attrs = append(attrs, fmt.Sprintf("default: '%v'", column.Default))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(attrs) > 0 {
|
if len(attrs) > 0 {
|
||||||
|
|||||||
@@ -196,7 +196,9 @@ func (w *Writer) writeTableFile(table *models.Table, schema *models.Schema, db *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate filename: {tableName}.ts
|
// Generate filename: {tableName}.ts
|
||||||
filename := filepath.Join(w.options.OutputPath, table.Name+".ts")
|
// Sanitize table name to remove quotes, comments, and invalid characters
|
||||||
|
safeTableName := writers.SanitizeFilename(table.Name)
|
||||||
|
filename := filepath.Join(w.options.OutputPath, safeTableName+".ts")
|
||||||
return os.WriteFile(filename, []byte(code), 0644)
|
return os.WriteFile(filename, []byte(code), 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TemplateData represents the data passed to the template for code generation
|
// TemplateData represents the data passed to the template for code generation
|
||||||
@@ -24,6 +25,7 @@ type ModelData struct {
|
|||||||
Fields []*FieldData
|
Fields []*FieldData
|
||||||
Config *MethodConfig
|
Config *MethodConfig
|
||||||
PrimaryKeyField string // Name of the primary key field
|
PrimaryKeyField string // Name of the primary key field
|
||||||
|
PrimaryKeyType string // Go type of the primary key field
|
||||||
IDColumnName string // Name of the ID column in database
|
IDColumnName string // Name of the ID column in database
|
||||||
Prefix string // 3-letter prefix
|
Prefix string // 3-letter prefix
|
||||||
}
|
}
|
||||||
@@ -109,13 +111,17 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
tableName = schema + "." + table.Name
|
tableName = schema + "." + table.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate model name: singularize and convert to PascalCase
|
// Generate model name: Model + Schema + Table (all PascalCase)
|
||||||
singularTable := Singularize(table.Name)
|
singularTable := Singularize(table.Name)
|
||||||
modelName := SnakeCaseToPascalCase(singularTable)
|
tablePart := SnakeCaseToPascalCase(singularTable)
|
||||||
|
|
||||||
// Add "Model" prefix if not already present
|
// Include schema name in model name
|
||||||
if !hasModelPrefix(modelName) {
|
var modelName string
|
||||||
modelName = "Model" + modelName
|
if schema != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schema)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &ModelData{
|
model := &ModelData{
|
||||||
@@ -131,8 +137,11 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
// Find primary key
|
// Find primary key
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name to remove backticks
|
||||||
model.IDColumnName = col.Name
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
|
model.PrimaryKeyType = typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
|
model.IDColumnName = safeName
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -141,6 +150,8 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
columns := sortColumns(table.Columns)
|
columns := sortColumns(table.Columns)
|
||||||
for _, col := range columns {
|
for _, col := range columns {
|
||||||
field := columnToField(col, table, typeMapper)
|
field := columnToField(col, table, typeMapper)
|
||||||
|
// Check for name collision with generated methods and rename if needed
|
||||||
|
field.Name = resolveFieldNameCollision(field.Name)
|
||||||
model.Fields = append(model.Fields, field)
|
model.Fields = append(model.Fields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,10 +160,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
|
|
||||||
// columnToField converts a models.Column to FieldData
|
// columnToField converts a models.Column to FieldData
|
||||||
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
||||||
fieldName := SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name first to remove backticks before generating field name
|
||||||
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
fieldName := SnakeCaseToPascalCase(safeName)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
gormTag := typeMapper.BuildGormTag(col, table)
|
gormTag := typeMapper.BuildGormTag(col, table)
|
||||||
jsonTag := col.Name // Use column name for JSON tag
|
// Use same sanitized name for JSON tag
|
||||||
|
jsonTag := safeName
|
||||||
|
|
||||||
return &FieldData{
|
return &FieldData{
|
||||||
Name: fieldName,
|
Name: fieldName,
|
||||||
@@ -179,9 +193,28 @@ func formatComment(description, comment string) string {
|
|||||||
return comment
|
return comment
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasModelPrefix checks if a name already has "Model" prefix
|
// resolveFieldNameCollision checks if a field name conflicts with generated method names
|
||||||
func hasModelPrefix(name string) bool {
|
// and adds an underscore suffix if there's a collision
|
||||||
return len(name) >= 5 && name[:5] == "Model"
|
func resolveFieldNameCollision(fieldName string) string {
|
||||||
|
// List of method names that are generated by the template
|
||||||
|
reservedNames := map[string]bool{
|
||||||
|
"TableName": true,
|
||||||
|
"TableNameOnly": true,
|
||||||
|
"SchemaName": true,
|
||||||
|
"GetID": true,
|
||||||
|
"GetIDStr": true,
|
||||||
|
"SetID": true,
|
||||||
|
"UpdateID": true,
|
||||||
|
"GetIDName": true,
|
||||||
|
"GetPrefix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name conflicts with a reserved method name
|
||||||
|
if reservedNames[fieldName] {
|
||||||
|
return fieldName + "_"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// sortColumns sorts columns by sequence, then by name
|
// sortColumns sorts columns by sequence, then by name
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (m {{.Name}}) SetID(newid int64) {
|
|||||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||||
// UpdateID updates the primary key value
|
// UpdateID updates the primary key value
|
||||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||||
m.{{.PrimaryKeyField}} = int32(newid)
|
m.{{.PrimaryKeyField}} = {{.PrimaryKeyType}}(newid)
|
||||||
}
|
}
|
||||||
{{end}}
|
{{end}}
|
||||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TypeMapper handles type conversions between SQL and Go types
|
// TypeMapper handles type conversions between SQL and Go types
|
||||||
@@ -199,12 +200,15 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
|||||||
var parts []string
|
var parts []string
|
||||||
|
|
||||||
// Always include column name (lowercase as per user requirement)
|
// Always include column name (lowercase as per user requirement)
|
||||||
parts = append(parts, fmt.Sprintf("column:%s", column.Name))
|
// Sanitize to remove backticks which would break struct tag syntax
|
||||||
|
safeName := writers.SanitizeStructTagValue(column.Name)
|
||||||
|
parts = append(parts, fmt.Sprintf("column:%s", safeName))
|
||||||
|
|
||||||
// Add type if specified
|
// Add type if specified
|
||||||
if column.Type != "" {
|
if column.Type != "" {
|
||||||
// Include length, precision, scale if present
|
// Include length, precision, scale if present
|
||||||
typeStr := column.Type
|
// Sanitize type to remove backticks
|
||||||
|
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||||
if column.Length > 0 {
|
if column.Length > 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||||
} else if column.Precision > 0 {
|
} else if column.Precision > 0 {
|
||||||
@@ -234,7 +238,9 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
|||||||
|
|
||||||
// Default value
|
// Default value
|
||||||
if column.Default != nil {
|
if column.Default != nil {
|
||||||
parts = append(parts, fmt.Sprintf("default:%v", column.Default))
|
// Sanitize default value to remove backticks
|
||||||
|
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
||||||
|
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for unique constraint
|
// Check for unique constraint
|
||||||
@@ -331,5 +337,5 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
|||||||
|
|
||||||
// GetSQLTypesImport returns the import path for sql_types
|
// GetSQLTypesImport returns the import path for sql_types
|
||||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||||
return "github.com/bitechdev/ResolveSpec/pkg/common/sql_types"
|
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -121,7 +122,16 @@ func (w *Writer) writeSingleFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write output
|
// Write output
|
||||||
return w.writeOutput(formatted)
|
if err := w.writeOutput(formatted); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the output file
|
||||||
|
if w.options.OutputPath != "" {
|
||||||
|
w.runGoFmt(w.options.OutputPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeMultiFile writes each table to a separate file
|
// writeMultiFile writes each table to a separate file
|
||||||
@@ -201,13 +211,19 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate filename: sql_{schema}_{table}.go
|
// Generate filename: sql_{schema}_{table}.go
|
||||||
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
|
// Sanitize schema and table names to remove quotes, comments, and invalid characters
|
||||||
|
safeSchemaName := writers.SanitizeFilename(schema.Name)
|
||||||
|
safeTableName := writers.SanitizeFilename(table.Name)
|
||||||
|
filename := fmt.Sprintf("sql_%s_%s.go", safeSchemaName, safeTableName)
|
||||||
filepath := filepath.Join(w.options.OutputPath, filename)
|
filepath := filepath.Join(w.options.OutputPath, filename)
|
||||||
|
|
||||||
// Write file
|
// Write file
|
||||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run go fmt on the generated file
|
||||||
|
w.runGoFmt(filepath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,6 +232,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
|
|
||||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||||
|
// Track used field names to detect duplicates
|
||||||
|
usedFieldNames := make(map[string]int)
|
||||||
|
|
||||||
// For each foreign key in this table, add a belongs-to relationship
|
// For each foreign key in this table, add a belongs-to relationship
|
||||||
for _, constraint := range table.Constraints {
|
for _, constraint := range table.Constraints {
|
||||||
if constraint.Type != models.ForeignKeyConstraint {
|
if constraint.Type != models.ForeignKeyConstraint {
|
||||||
@@ -229,8 +248,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create relationship field (belongs-to)
|
// Create relationship field (belongs-to)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedSchema, constraint.ReferencedTable)
|
||||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
fieldName := w.generateBelongsToFieldName(constraint)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -257,8 +277,9 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
// Check if this constraint references our table
|
// Check if this constraint references our table
|
||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherSchema.Name, otherTable.Name)
|
||||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
fieldName := w.generateHasManyFieldName(constraint, otherSchema.Name, otherTable.Name)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -289,22 +310,77 @@ func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *m
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getModelName generates the model name from a table name
|
// getModelName generates the model name from schema and table name
|
||||||
func (w *Writer) getModelName(tableName string) string {
|
func (w *Writer) getModelName(schemaName, tableName string) string {
|
||||||
singular := Singularize(tableName)
|
singular := Singularize(tableName)
|
||||||
modelName := SnakeCaseToPascalCase(singular)
|
tablePart := SnakeCaseToPascalCase(singular)
|
||||||
|
|
||||||
if !hasModelPrefix(modelName) {
|
// Include schema name in model name
|
||||||
modelName = "Model" + modelName
|
var modelName string
|
||||||
|
if schemaName != "" {
|
||||||
|
schemaPart := SnakeCaseToPascalCase(schemaName)
|
||||||
|
modelName = "Model" + schemaPart + tablePart
|
||||||
|
} else {
|
||||||
|
modelName = "Model" + tablePart
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRelationshipFieldName generates a field name for a relationship
|
// generateBelongsToFieldName generates a field name for belongs-to relationships
|
||||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
// Uses the foreign key column name for uniqueness
|
||||||
// Use just the prefix (3 letters) for relationship fields
|
func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) string {
|
||||||
return GeneratePrefix(tableName)
|
// Use the foreign key column name to ensure uniqueness
|
||||||
|
// If there are multiple columns, use the first one
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Convert to PascalCase for proper Go field naming
|
||||||
|
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
|
||||||
|
return "Rel" + SnakeCaseToPascalCase(columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based prefix if no columns defined
|
||||||
|
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceSchemaName, sourceTableName string) string {
|
||||||
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
|
// e.g., multiple tables referencing the same column on this table
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Get the model name for the source table (pluralized)
|
||||||
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
|
// Remove "Model" prefix if present
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
|
// Convert column to PascalCase and combine with source table
|
||||||
|
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
|
||||||
|
columnPart := SnakeCaseToPascalCase(columnName)
|
||||||
|
return "Rel" + columnPart + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based naming
|
||||||
|
sourceModelName := w.getModelName(sourceSchemaName, sourceTableName)
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
|
||||||
|
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
|
||||||
|
originalName := fieldName
|
||||||
|
count := usedNames[originalName]
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
// Name is already used, add numeric suffix
|
||||||
|
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment the counter for this base name
|
||||||
|
usedNames[originalName]++
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPackageName returns the package name from options or defaults to "models"
|
// getPackageName returns the package name from options or defaults to "models"
|
||||||
@@ -335,6 +411,15 @@ func (w *Writer) writeOutput(content string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runGoFmt runs go fmt on the specified file
|
||||||
|
func (w *Writer) runGoFmt(filepath string) {
|
||||||
|
cmd := exec.Command("gofmt", "-w", filepath)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
// Don't fail the whole operation if gofmt fails, just warn
|
||||||
|
fmt.Fprintf(os.Stderr, "Warning: failed to run gofmt on %s: %v\n", filepath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
// shouldUseMultiFile determines whether to use multi-file mode based on metadata or output path
|
||||||
func (w *Writer) shouldUseMultiFile() bool {
|
func (w *Writer) shouldUseMultiFile() bool {
|
||||||
// Check if multi_file is explicitly set in metadata
|
// Check if multi_file is explicitly set in metadata
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
// Verify key elements are present
|
// Verify key elements are present
|
||||||
expectations := []string{
|
expectations := []string{
|
||||||
"package models",
|
"package models",
|
||||||
"type ModelUser struct",
|
"type ModelPublicUser struct",
|
||||||
"ID",
|
"ID",
|
||||||
"int64",
|
"int64",
|
||||||
"Email",
|
"Email",
|
||||||
@@ -75,9 +75,9 @@ func TestWriter_WriteTable(t *testing.T) {
|
|||||||
"time.Time",
|
"time.Time",
|
||||||
"gorm:\"column:id",
|
"gorm:\"column:id",
|
||||||
"gorm:\"column:email",
|
"gorm:\"column:email",
|
||||||
"func (m ModelUser) TableName() string",
|
"func (m ModelPublicUser) TableName() string",
|
||||||
"return \"public.users\"",
|
"return \"public.users\"",
|
||||||
"func (m ModelUser) GetID() int64",
|
"func (m ModelPublicUser) GetID() int64",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expected := range expectations {
|
for _, expected := range expectations {
|
||||||
@@ -164,9 +164,437 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
t.Fatalf("Failed to read posts file: %v", err)
|
t.Fatalf("Failed to read posts file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(string(postsContent), "USE *ModelUser") {
|
postsStr := string(postsContent)
|
||||||
// Relationship field should be present
|
|
||||||
t.Logf("Posts content:\n%s", string(postsContent))
|
// Verify relationship is present with new naming convention
|
||||||
|
// Should now be RelUserID (belongs-to) instead of USE
|
||||||
|
if !strings.Contains(postsStr, "RelUserID") {
|
||||||
|
t.Errorf("Missing relationship field RelUserID (new naming convention)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check users file contains has-many relationship
|
||||||
|
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read users file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
|
// Should have RelUserIDPublicPosts (has-many) field - includes schema prefix
|
||||||
|
if !strings.Contains(usersStr, "RelUserIDPublicPosts") {
|
||||||
|
t.Errorf("Missing has-many relationship field RelUserIDPublicPosts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||||
|
// Test scenario: api_event table with multiple foreign keys to filepointer table
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API event table with two foreign keys to filepointer
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_request",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_response",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add constraints
|
||||||
|
apiEvent.Constraints["fk_request"] = &models.Constraint{
|
||||||
|
Name: "fk_request",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_request"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_response"] = &models.Constraint{
|
||||||
|
Name: "fk_response",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_response"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_event file
|
||||||
|
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_event file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiEventContent)
|
||||||
|
|
||||||
|
// Verify both relationships have unique names based on column names
|
||||||
|
expectations := []struct {
|
||||||
|
fieldName string
|
||||||
|
tag string
|
||||||
|
}{
|
||||||
|
{"RelRIDFilepointerRequest", "foreignKey:RIDFilepointerRequest"},
|
||||||
|
{"RelRIDFilepointerResponse", "foreignKey:RIDFilepointerResponse"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range expectations {
|
||||||
|
if !strings.Contains(contentStr, exp.fieldName) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentStr, exp.tag) {
|
||||||
|
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
|
||||||
|
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
|
||||||
|
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also verify has-many relationships on filepointer table
|
||||||
|
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read filepointer file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filepointerStr := string(filepointerContent)
|
||||||
|
|
||||||
|
// Should have two different has-many relationships with unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDFilepointerRequestOrgAPIEvents", // Has many via rid_filepointer_request
|
||||||
|
"RelRIDFilepointerResponseOrgAPIEvents", // Has many via rid_filepointer_response
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(filepointerStr, exp) {
|
||||||
|
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||||
|
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Owner table
|
||||||
|
owner := models.InitTable("owner", "org")
|
||||||
|
owner.Columns["id_owner"] = &models.Column{
|
||||||
|
Name: "id_owner",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, owner)
|
||||||
|
|
||||||
|
// API Provider table
|
||||||
|
apiProvider := models.InitTable("api_provider", "org")
|
||||||
|
apiProvider.Columns["id_api_provider"] = &models.Column{
|
||||||
|
Name: "id_api_provider",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiProvider.Columns["rid_owner"] = &models.Column{
|
||||||
|
Name: "rid_owner",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiProvider.Constraints["fk_owner"] = &models.Constraint{
|
||||||
|
Name: "fk_owner",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_owner"},
|
||||||
|
ReferencedTable: "owner",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_owner"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiProvider)
|
||||||
|
|
||||||
|
// Login table
|
||||||
|
login := models.InitTable("login", "org")
|
||||||
|
login.Columns["id_login"] = &models.Column{
|
||||||
|
Name: "id_login",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
login.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
login.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, login)
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
filepointer.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API Event table
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_provider file
|
||||||
|
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_provider file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiProviderContent)
|
||||||
|
|
||||||
|
// Verify all has-many relationships have unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDAPIProviderOrgLogins", // Has many via Login
|
||||||
|
"RelRIDAPIProviderOrgFilepointers", // Has many via Filepointer
|
||||||
|
"RelRIDAPIProviderOrgAPIEvents", // Has many via APIEvent
|
||||||
|
"RelRIDOwner", // Belongs to via rid_owner
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(contentStr, exp) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names
|
||||||
|
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
|
||||||
|
count := strings.Count(contentStr, "RelRIDAPIProvider")
|
||||||
|
if count != 3 {
|
||||||
|
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no duplicate declarations (would cause compilation error)
|
||||||
|
duplicatePattern := "RelRIDAPIProviders []*Model"
|
||||||
|
if strings.Contains(contentStr, duplicatePattern) {
|
||||||
|
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_FieldNameCollision(t *testing.T) {
|
||||||
|
// Test scenario: table with columns that would conflict with generated method names
|
||||||
|
table := models.InitTable("audit_table", "audit")
|
||||||
|
table.Columns["id_audit_table"] = &models.Column{
|
||||||
|
Name: "id_audit_table",
|
||||||
|
Type: "smallint",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
Sequence: 1,
|
||||||
|
}
|
||||||
|
table.Columns["table_name"] = &models.Column{
|
||||||
|
Name: "table_name",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 2,
|
||||||
|
}
|
||||||
|
table.Columns["table_schema"] = &models.Column{
|
||||||
|
Name: "table_schema",
|
||||||
|
Type: "varchar",
|
||||||
|
Length: 100,
|
||||||
|
NotNull: true,
|
||||||
|
Sequence: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated file
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify that TableName field was renamed to TableName_ to avoid collision
|
||||||
|
if !strings.Contains(generated, "TableName_") {
|
||||||
|
t.Errorf("Expected field 'TableName_' (with underscore) but not found\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the struct tag still references the correct database column
|
||||||
|
if !strings.Contains(generated, `gorm:"column:table_name;`) {
|
||||||
|
t.Errorf("Expected gorm tag to reference 'table_name' column\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the TableName() method still exists and doesn't conflict
|
||||||
|
if !strings.Contains(generated, "func (m ModelAuditAuditTable) TableName() string") {
|
||||||
|
t.Errorf("TableName() method should still be generated\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO field named just "TableName" (without underscore)
|
||||||
|
if strings.Contains(generated, "TableName sql_types") || strings.Contains(generated, "TableName string") {
|
||||||
|
t.Errorf("Field 'TableName' without underscore should not exist (would conflict with method)\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_UpdateIDTypeSafety(t *testing.T) {
|
||||||
|
// Test scenario: tables with different primary key types
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pkType string
|
||||||
|
expectedPK string
|
||||||
|
castType string
|
||||||
|
}{
|
||||||
|
{"int32_pk", "int", "int32", "int32(newid)"},
|
||||||
|
{"int16_pk", "smallint", "int16", "int16(newid)"},
|
||||||
|
{"int64_pk", "bigint", "int64", "int64(newid)"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
table := models.InitTable("test_table", "public")
|
||||||
|
table.Columns["id"] = &models.Column{
|
||||||
|
Name: "id",
|
||||||
|
Type: tt.pkType,
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: filepath.Join(tmpDir, "test.go"),
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteTable(table)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(opts.OutputPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read generated file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generated := string(content)
|
||||||
|
|
||||||
|
// Verify UpdateID method has correct type cast
|
||||||
|
if !strings.Contains(generated, tt.castType) {
|
||||||
|
t.Errorf("Expected UpdateID to cast to %s\nGenerated:\n%s", tt.castType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no invalid int32(newid) for non-int32 types
|
||||||
|
if tt.expectedPK != "int32" && strings.Contains(generated, "int32(newid)") {
|
||||||
|
t.Errorf("UpdateID should not cast to int32 for %s type\nGenerated:\n%s", tt.pkType, generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify UpdateID parameter is int64 (for consistency)
|
||||||
|
if !strings.Contains(generated, "UpdateID(newid int64)") {
|
||||||
|
t.Errorf("UpdateID should accept int64 parameter\nGenerated:\n%s", generated)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package writers
|
package writers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,3 +31,56 @@ type WriterOptions struct {
|
|||||||
// Additional options can be added here as needed
|
// Additional options can be added here as needed
|
||||||
Metadata map[string]interface{}
|
Metadata map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeFilename removes quotes, comments, and invalid characters from identifiers
|
||||||
|
// to make them safe for use in filenames. This handles:
|
||||||
|
// - Double and single quotes: "table_name" or 'table_name' -> table_name
|
||||||
|
// - DBML comments: table [note: 'description'] -> table
|
||||||
|
// - Invalid filename characters: replaced with underscores
|
||||||
|
func SanitizeFilename(name string) string {
|
||||||
|
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
|
||||||
|
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
|
||||||
|
name = commentRegex.ReplaceAllString(name, "")
|
||||||
|
|
||||||
|
// Remove quotes (both single and double)
|
||||||
|
name = strings.ReplaceAll(name, `"`, "")
|
||||||
|
name = strings.ReplaceAll(name, `'`, "")
|
||||||
|
|
||||||
|
// Remove backticks (MySQL style identifiers)
|
||||||
|
name = strings.ReplaceAll(name, "`", "")
|
||||||
|
|
||||||
|
// Replace invalid filename characters with underscores
|
||||||
|
// Invalid chars: / \ : * ? " < > | and control characters
|
||||||
|
invalidChars := regexp.MustCompile(`[/\\:*?"<>|\x00-\x1f\x7f]`)
|
||||||
|
name = invalidChars.ReplaceAllString(name, "_")
|
||||||
|
|
||||||
|
// Trim whitespace and consecutive underscores
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
name = regexp.MustCompile(`_+`).ReplaceAllString(name, "_")
|
||||||
|
name = strings.Trim(name, "_")
|
||||||
|
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
||||||
|
// Go struct tags are delimited by backticks, so any backtick in the value would break the syntax.
|
||||||
|
// This function:
|
||||||
|
// - Removes DBML/DCTX comments in brackets
|
||||||
|
// - Removes all quotes (double, single, and backticks)
|
||||||
|
// - Returns a clean identifier safe for use in struct tags and field names
|
||||||
|
func SanitizeStructTagValue(value string) string {
|
||||||
|
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
|
||||||
|
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
|
||||||
|
value = commentRegex.ReplaceAllString(value, "")
|
||||||
|
|
||||||
|
// Trim whitespace
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
|
||||||
|
// Remove all quotes: backticks, double quotes, and single quotes
|
||||||
|
// This ensures the value is clean for use as Go identifiers and struct tag values
|
||||||
|
value = strings.ReplaceAll(value, "`", "")
|
||||||
|
value = strings.ReplaceAll(value, `"`, "")
|
||||||
|
value = strings.ReplaceAll(value, `'`, "")
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|||||||
5
tests/assets/dbml/multifile/1_users.dbml
Normal file
5
tests/assets/dbml/multifile/1_users.dbml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
// First file - users table basic structure
|
||||||
|
Table public.users {
|
||||||
|
id bigint [pk, increment]
|
||||||
|
email varchar(255) [unique, not null]
|
||||||
|
}
|
||||||
8
tests/assets/dbml/multifile/2_posts.dbml
Normal file
8
tests/assets/dbml/multifile/2_posts.dbml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
// Second file - posts table
|
||||||
|
Table public.posts {
|
||||||
|
id bigint [pk, increment]
|
||||||
|
user_id bigint [not null]
|
||||||
|
title varchar(200) [not null]
|
||||||
|
content text
|
||||||
|
created_at timestamp [not null]
|
||||||
|
}
|
||||||
5
tests/assets/dbml/multifile/3_add_columns.dbml
Normal file
5
tests/assets/dbml/multifile/3_add_columns.dbml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
// Third file - adds more columns to users table (tests merging)
|
||||||
|
Table public.users {
|
||||||
|
name varchar(100)
|
||||||
|
created_at timestamp [not null]
|
||||||
|
}
|
||||||
10
tests/assets/dbml/multifile/9_refs.dbml
Normal file
10
tests/assets/dbml/multifile/9_refs.dbml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
// File with commented-out refs - should load last
|
||||||
|
// Contains relationships that depend on earlier tables
|
||||||
|
|
||||||
|
// Ref: public.posts.user_id > public.users.id [ondelete: CASCADE]
|
||||||
|
|
||||||
|
Table public.comments {
|
||||||
|
id bigint [pk, increment]
|
||||||
|
post_id bigint [not null]
|
||||||
|
content text [not null]
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user