Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4181cb1fbd | |||
| 120ffc6a5a | |||
| b20ad35485 |
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TemplateData represents the data passed to the template for code generation
|
// TemplateData represents the data passed to the template for code generation
|
||||||
@@ -133,8 +134,10 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
// Find primary key
|
// Find primary key
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name to remove backticks
|
||||||
model.IDColumnName = col.Name
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
|
model.IDColumnName = safeName
|
||||||
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
// Check if PK type is a SQL type (contains resolvespec_common or sql_types)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
model.PrimaryKeyIsSQL = strings.Contains(goType, "resolvespec_common") || strings.Contains(goType, "sql_types")
|
||||||
@@ -154,10 +157,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
|
|
||||||
// columnToField converts a models.Column to FieldData
|
// columnToField converts a models.Column to FieldData
|
||||||
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
||||||
fieldName := SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name first to remove backticks before generating field name
|
||||||
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
fieldName := SnakeCaseToPascalCase(safeName)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
bunTag := typeMapper.BuildBunTag(col, table)
|
bunTag := typeMapper.BuildBunTag(col, table)
|
||||||
jsonTag := col.Name // Use column name for JSON tag
|
// Use same sanitized name for JSON tag
|
||||||
|
jsonTag := safeName
|
||||||
|
|
||||||
return &FieldData{
|
return &FieldData{
|
||||||
Name: fieldName,
|
Name: fieldName,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TypeMapper handles type conversions between SQL and Go types for Bun
|
// TypeMapper handles type conversions between SQL and Go types for Bun
|
||||||
@@ -164,11 +165,14 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
var parts []string
|
var parts []string
|
||||||
|
|
||||||
// Column name comes first (no prefix)
|
// Column name comes first (no prefix)
|
||||||
parts = append(parts, column.Name)
|
// Sanitize to remove backticks which would break struct tag syntax
|
||||||
|
safeName := writers.SanitizeStructTagValue(column.Name)
|
||||||
|
parts = append(parts, safeName)
|
||||||
|
|
||||||
// Add type if specified
|
// Add type if specified
|
||||||
if column.Type != "" {
|
if column.Type != "" {
|
||||||
typeStr := column.Type
|
// Sanitize type to remove backticks
|
||||||
|
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||||
if column.Length > 0 {
|
if column.Length > 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||||
} else if column.Precision > 0 {
|
} else if column.Precision > 0 {
|
||||||
@@ -188,7 +192,9 @@ func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) st
|
|||||||
|
|
||||||
// Default value
|
// Default value
|
||||||
if column.Default != nil {
|
if column.Default != nil {
|
||||||
parts = append(parts, fmt.Sprintf("default:%v", column.Default))
|
// Sanitize default value to remove backticks
|
||||||
|
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
||||||
|
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nullable (Bun uses nullzero for nullable fields)
|
// Nullable (Bun uses nullzero for nullable fields)
|
||||||
@@ -263,7 +269,7 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
|||||||
|
|
||||||
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common)
|
// GetSQLTypesImport returns the import path for sql_types (ResolveSpec common)
|
||||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||||
return "github.com/bitechdev/ResolveSpec/pkg/common"
|
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetBunImport returns the import path for Bun
|
// GetBunImport returns the import path for Bun
|
||||||
|
|||||||
@@ -225,6 +225,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
|
|
||||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||||
|
// Track used field names to detect duplicates
|
||||||
|
usedFieldNames := make(map[string]int)
|
||||||
|
|
||||||
// For each foreign key in this table, add a belongs-to/has-one relationship
|
// For each foreign key in this table, add a belongs-to/has-one relationship
|
||||||
for _, constraint := range table.Constraints {
|
for _, constraint := range table.Constraints {
|
||||||
if constraint.Type != models.ForeignKeyConstraint {
|
if constraint.Type != models.ForeignKeyConstraint {
|
||||||
@@ -239,7 +242,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
|
|
||||||
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
// Create relationship field (has-one in Bun, similar to belongs-to in GORM)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
fieldName := w.generateHasOneFieldName(constraint)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-one")
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -267,7 +271,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherTable.Name)
|
||||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, "has-many")
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -310,10 +315,60 @@ func (w *Writer) getModelName(tableName string) string {
|
|||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRelationshipFieldName generates a field name for a relationship
|
// generateHasOneFieldName generates a field name for has-one relationships
|
||||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
// Uses the foreign key column name for uniqueness
|
||||||
// Use just the prefix (3 letters) for relationship fields
|
func (w *Writer) generateHasOneFieldName(constraint *models.Constraint) string {
|
||||||
return GeneratePrefix(tableName)
|
// Use the foreign key column name to ensure uniqueness
|
||||||
|
// If there are multiple columns, use the first one
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Convert to PascalCase for proper Go field naming
|
||||||
|
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
|
||||||
|
return "Rel" + SnakeCaseToPascalCase(columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based prefix if no columns defined
|
||||||
|
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
||||||
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
|
// e.g., multiple tables referencing the same column on this table
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Get the model name for the source table (pluralized)
|
||||||
|
sourceModelName := w.getModelName(sourceTableName)
|
||||||
|
// Remove "Model" prefix if present
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
|
// Convert column to PascalCase and combine with source table
|
||||||
|
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
|
||||||
|
columnPart := SnakeCaseToPascalCase(columnName)
|
||||||
|
return "Rel" + columnPart + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based naming
|
||||||
|
sourceModelName := w.getModelName(sourceTableName)
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
|
||||||
|
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
|
||||||
|
originalName := fieldName
|
||||||
|
count := usedNames[originalName]
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
// Name is already used, add numeric suffix
|
||||||
|
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment the counter for this base name
|
||||||
|
usedNames[originalName]++
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPackageName returns the package name from options or defaults to "models"
|
// getPackageName returns the package name from options or defaults to "models"
|
||||||
|
|||||||
@@ -175,12 +175,310 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
postsStr := string(postsContent)
|
postsStr := string(postsContent)
|
||||||
|
|
||||||
// Verify relationship is present with Bun format
|
// Verify relationship is present with Bun format
|
||||||
if !strings.Contains(postsStr, "USE") {
|
// Should now be RelUserID (has-one) instead of USE
|
||||||
t.Errorf("Missing relationship field USE")
|
if !strings.Contains(postsStr, "RelUserID") {
|
||||||
|
t.Errorf("Missing relationship field RelUserID (new naming convention)")
|
||||||
}
|
}
|
||||||
if !strings.Contains(postsStr, "rel:has-one") {
|
if !strings.Contains(postsStr, "rel:has-one") {
|
||||||
t.Errorf("Missing Bun relationship tag: %s", postsStr)
|
t.Errorf("Missing Bun relationship tag: %s", postsStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check users file contains has-many relationship
|
||||||
|
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read users file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
|
// Should have RelUserIDPosts (has-many) field
|
||||||
|
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
||||||
|
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||||
|
// Test scenario: api_event table with multiple foreign keys to filepointer table
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API event table with two foreign keys to filepointer
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_request",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_response",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add constraints
|
||||||
|
apiEvent.Constraints["fk_request"] = &models.Constraint{
|
||||||
|
Name: "fk_request",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_request"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_response"] = &models.Constraint{
|
||||||
|
Name: "fk_response",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_response"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_event file
|
||||||
|
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_event file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiEventContent)
|
||||||
|
|
||||||
|
// Verify both relationships have unique names based on column names
|
||||||
|
expectations := []struct {
|
||||||
|
fieldName string
|
||||||
|
tag string
|
||||||
|
}{
|
||||||
|
{"RelRIDFilepointerRequest", "join:rid_filepointer_request=id_filepointer"},
|
||||||
|
{"RelRIDFilepointerResponse", "join:rid_filepointer_response=id_filepointer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range expectations {
|
||||||
|
if !strings.Contains(contentStr, exp.fieldName) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentStr, exp.tag) {
|
||||||
|
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
|
||||||
|
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
|
||||||
|
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also verify has-many relationships on filepointer table
|
||||||
|
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read filepointer file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filepointerStr := string(filepointerContent)
|
||||||
|
|
||||||
|
// Should have two different has-many relationships with unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
||||||
|
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(filepointerStr, exp) {
|
||||||
|
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||||
|
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Owner table
|
||||||
|
owner := models.InitTable("owner", "org")
|
||||||
|
owner.Columns["id_owner"] = &models.Column{
|
||||||
|
Name: "id_owner",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, owner)
|
||||||
|
|
||||||
|
// API Provider table
|
||||||
|
apiProvider := models.InitTable("api_provider", "org")
|
||||||
|
apiProvider.Columns["id_api_provider"] = &models.Column{
|
||||||
|
Name: "id_api_provider",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiProvider.Columns["rid_owner"] = &models.Column{
|
||||||
|
Name: "rid_owner",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiProvider.Constraints["fk_owner"] = &models.Constraint{
|
||||||
|
Name: "fk_owner",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_owner"},
|
||||||
|
ReferencedTable: "owner",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_owner"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiProvider)
|
||||||
|
|
||||||
|
// Login table
|
||||||
|
login := models.InitTable("login", "org")
|
||||||
|
login.Columns["id_login"] = &models.Column{
|
||||||
|
Name: "id_login",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
login.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
login.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, login)
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
filepointer.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API Event table
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_provider file
|
||||||
|
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_provider file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiProviderContent)
|
||||||
|
|
||||||
|
// Verify all has-many relationships have unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDAPIProviderLogins", // Has many via Login
|
||||||
|
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
||||||
|
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
||||||
|
"RelRIDOwner", // Has one via rid_owner
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(contentStr, exp) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names
|
||||||
|
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
|
||||||
|
count := strings.Count(contentStr, "RelRIDAPIProvider")
|
||||||
|
if count != 3 {
|
||||||
|
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no duplicate declarations (would cause compilation error)
|
||||||
|
duplicatePattern := "RelRIDAPIProviders []*Model"
|
||||||
|
if strings.Contains(contentStr, duplicatePattern) {
|
||||||
|
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
func TestTypeMapper_SQLTypeToGoType_Bun(t *testing.T) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TemplateData represents the data passed to the template for code generation
|
// TemplateData represents the data passed to the template for code generation
|
||||||
@@ -131,8 +132,10 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
// Find primary key
|
// Find primary key
|
||||||
for _, col := range table.Columns {
|
for _, col := range table.Columns {
|
||||||
if col.IsPrimaryKey {
|
if col.IsPrimaryKey {
|
||||||
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name to remove backticks
|
||||||
model.IDColumnName = col.Name
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
model.PrimaryKeyField = SnakeCaseToPascalCase(safeName)
|
||||||
|
model.IDColumnName = safeName
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,10 +152,13 @@ func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *M
|
|||||||
|
|
||||||
// columnToField converts a models.Column to FieldData
|
// columnToField converts a models.Column to FieldData
|
||||||
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
||||||
fieldName := SnakeCaseToPascalCase(col.Name)
|
// Sanitize column name first to remove backticks before generating field name
|
||||||
|
safeName := writers.SanitizeStructTagValue(col.Name)
|
||||||
|
fieldName := SnakeCaseToPascalCase(safeName)
|
||||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||||
gormTag := typeMapper.BuildGormTag(col, table)
|
gormTag := typeMapper.BuildGormTag(col, table)
|
||||||
jsonTag := col.Name // Use column name for JSON tag
|
// Use same sanitized name for JSON tag
|
||||||
|
jsonTag := safeName
|
||||||
|
|
||||||
return &FieldData{
|
return &FieldData{
|
||||||
Name: fieldName,
|
Name: fieldName,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||||
|
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TypeMapper handles type conversions between SQL and Go types
|
// TypeMapper handles type conversions between SQL and Go types
|
||||||
@@ -199,12 +200,15 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
|||||||
var parts []string
|
var parts []string
|
||||||
|
|
||||||
// Always include column name (lowercase as per user requirement)
|
// Always include column name (lowercase as per user requirement)
|
||||||
parts = append(parts, fmt.Sprintf("column:%s", column.Name))
|
// Sanitize to remove backticks which would break struct tag syntax
|
||||||
|
safeName := writers.SanitizeStructTagValue(column.Name)
|
||||||
|
parts = append(parts, fmt.Sprintf("column:%s", safeName))
|
||||||
|
|
||||||
// Add type if specified
|
// Add type if specified
|
||||||
if column.Type != "" {
|
if column.Type != "" {
|
||||||
// Include length, precision, scale if present
|
// Include length, precision, scale if present
|
||||||
typeStr := column.Type
|
// Sanitize type to remove backticks
|
||||||
|
typeStr := writers.SanitizeStructTagValue(column.Type)
|
||||||
if column.Length > 0 {
|
if column.Length > 0 {
|
||||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||||
} else if column.Precision > 0 {
|
} else if column.Precision > 0 {
|
||||||
@@ -234,7 +238,9 @@ func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) s
|
|||||||
|
|
||||||
// Default value
|
// Default value
|
||||||
if column.Default != nil {
|
if column.Default != nil {
|
||||||
parts = append(parts, fmt.Sprintf("default:%v", column.Default))
|
// Sanitize default value to remove backticks
|
||||||
|
safeDefault := writers.SanitizeStructTagValue(fmt.Sprintf("%v", column.Default))
|
||||||
|
parts = append(parts, fmt.Sprintf("default:%s", safeDefault))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for unique constraint
|
// Check for unique constraint
|
||||||
@@ -331,5 +337,5 @@ func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
|||||||
|
|
||||||
// GetSQLTypesImport returns the import path for sql_types
|
// GetSQLTypesImport returns the import path for sql_types
|
||||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||||
return "github.com/bitechdev/ResolveSpec/pkg/common/sql_types"
|
return "github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -219,6 +219,9 @@ func (w *Writer) writeMultiFile(db *models.Database) error {
|
|||||||
|
|
||||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||||
|
// Track used field names to detect duplicates
|
||||||
|
usedFieldNames := make(map[string]int)
|
||||||
|
|
||||||
// For each foreign key in this table, add a belongs-to relationship
|
// For each foreign key in this table, add a belongs-to relationship
|
||||||
for _, constraint := range table.Constraints {
|
for _, constraint := range table.Constraints {
|
||||||
if constraint.Type != models.ForeignKeyConstraint {
|
if constraint.Type != models.ForeignKeyConstraint {
|
||||||
@@ -233,7 +236,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
|
|
||||||
// Create relationship field (belongs-to)
|
// Create relationship field (belongs-to)
|
||||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
fieldName := w.generateBelongsToFieldName(constraint)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -261,7 +265,8 @@ func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table
|
|||||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||||
// Add has-many relationship
|
// Add has-many relationship
|
||||||
otherModelName := w.getModelName(otherTable.Name)
|
otherModelName := w.getModelName(otherTable.Name)
|
||||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
fieldName := w.generateHasManyFieldName(constraint, otherTable.Name)
|
||||||
|
fieldName = w.ensureUniqueFieldName(fieldName, usedFieldNames)
|
||||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||||
|
|
||||||
modelData.AddRelationshipField(&FieldData{
|
modelData.AddRelationshipField(&FieldData{
|
||||||
@@ -304,10 +309,60 @@ func (w *Writer) getModelName(tableName string) string {
|
|||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRelationshipFieldName generates a field name for a relationship
|
// generateBelongsToFieldName generates a field name for belongs-to relationships
|
||||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
// Uses the foreign key column name for uniqueness
|
||||||
// Use just the prefix (3 letters) for relationship fields
|
func (w *Writer) generateBelongsToFieldName(constraint *models.Constraint) string {
|
||||||
return GeneratePrefix(tableName)
|
// Use the foreign key column name to ensure uniqueness
|
||||||
|
// If there are multiple columns, use the first one
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Convert to PascalCase for proper Go field naming
|
||||||
|
// e.g., "rid_filepointer_request" -> "RelRIDFilepointerRequest"
|
||||||
|
return "Rel" + SnakeCaseToPascalCase(columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based prefix if no columns defined
|
||||||
|
return "Rel" + GeneratePrefix(constraint.ReferencedTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateHasManyFieldName generates a field name for has-many relationships
|
||||||
|
// Uses the foreign key column name + source table name to avoid duplicates
|
||||||
|
func (w *Writer) generateHasManyFieldName(constraint *models.Constraint, sourceTableName string) string {
|
||||||
|
// For has-many, we need to include the source table name to avoid duplicates
|
||||||
|
// e.g., multiple tables referencing the same column on this table
|
||||||
|
if len(constraint.Columns) > 0 {
|
||||||
|
columnName := constraint.Columns[0]
|
||||||
|
// Get the model name for the source table (pluralized)
|
||||||
|
sourceModelName := w.getModelName(sourceTableName)
|
||||||
|
// Remove "Model" prefix if present
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
|
||||||
|
// Convert column to PascalCase and combine with source table
|
||||||
|
// e.g., "rid_api_provider" + "Login" -> "RelRIDAPIProviderLogins"
|
||||||
|
columnPart := SnakeCaseToPascalCase(columnName)
|
||||||
|
return "Rel" + columnPart + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to table-based naming
|
||||||
|
sourceModelName := w.getModelName(sourceTableName)
|
||||||
|
sourceModelName = strings.TrimPrefix(sourceModelName, "Model")
|
||||||
|
return "Rel" + Pluralize(sourceModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureUniqueFieldName ensures a field name is unique by adding numeric suffixes if needed
|
||||||
|
func (w *Writer) ensureUniqueFieldName(fieldName string, usedNames map[string]int) string {
|
||||||
|
originalName := fieldName
|
||||||
|
count := usedNames[originalName]
|
||||||
|
|
||||||
|
if count > 0 {
|
||||||
|
// Name is already used, add numeric suffix
|
||||||
|
fieldName = fmt.Sprintf("%s%d", originalName, count+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment the counter for this base name
|
||||||
|
usedNames[originalName]++
|
||||||
|
|
||||||
|
return fieldName
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPackageName returns the package name from options or defaults to "models"
|
// getPackageName returns the package name from options or defaults to "models"
|
||||||
|
|||||||
@@ -164,9 +164,309 @@ func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
|||||||
t.Fatalf("Failed to read posts file: %v", err)
|
t.Fatalf("Failed to read posts file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(string(postsContent), "USE *ModelUser") {
|
postsStr := string(postsContent)
|
||||||
// Relationship field should be present
|
|
||||||
t.Logf("Posts content:\n%s", string(postsContent))
|
// Verify relationship is present with new naming convention
|
||||||
|
// Should now be RelUserID (belongs-to) instead of USE
|
||||||
|
if !strings.Contains(postsStr, "RelUserID") {
|
||||||
|
t.Errorf("Missing relationship field RelUserID (new naming convention)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check users file contains has-many relationship
|
||||||
|
usersContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_users.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read users file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usersStr := string(usersContent)
|
||||||
|
|
||||||
|
// Should have RelUserIDPosts (has-many) field
|
||||||
|
if !strings.Contains(usersStr, "RelUserIDPosts") {
|
||||||
|
t.Errorf("Missing has-many relationship field RelUserIDPosts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleReferencesToSameTable(t *testing.T) {
|
||||||
|
// Test scenario: api_event table with multiple foreign keys to filepointer table
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API event table with two foreign keys to filepointer
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_request"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_request",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_filepointer_response"] = &models.Column{
|
||||||
|
Name: "rid_filepointer_response",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add constraints
|
||||||
|
apiEvent.Constraints["fk_request"] = &models.Constraint{
|
||||||
|
Name: "fk_request",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_request"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_response"] = &models.Constraint{
|
||||||
|
Name: "fk_response",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_filepointer_response"},
|
||||||
|
ReferencedTable: "filepointer",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_filepointer"},
|
||||||
|
}
|
||||||
|
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_event file
|
||||||
|
apiEventContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_event.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_event file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiEventContent)
|
||||||
|
|
||||||
|
// Verify both relationships have unique names based on column names
|
||||||
|
expectations := []struct {
|
||||||
|
fieldName string
|
||||||
|
tag string
|
||||||
|
}{
|
||||||
|
{"RelRIDFilepointerRequest", "foreignKey:RIDFilepointerRequest"},
|
||||||
|
{"RelRIDFilepointerResponse", "foreignKey:RIDFilepointerResponse"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range expectations {
|
||||||
|
if !strings.Contains(contentStr, exp.fieldName) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp.fieldName, contentStr)
|
||||||
|
}
|
||||||
|
if !strings.Contains(contentStr, exp.tag) {
|
||||||
|
t.Errorf("Missing relationship tag: %s\nGenerated:\n%s", exp.tag, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names (old behavior would create duplicate "FIL" fields)
|
||||||
|
if strings.Contains(contentStr, "FIL *ModelFilepointer") {
|
||||||
|
t.Errorf("Found old prefix-based naming (FIL), should use column-based naming")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also verify has-many relationships on filepointer table
|
||||||
|
filepointerContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_filepointer.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read filepointer file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filepointerStr := string(filepointerContent)
|
||||||
|
|
||||||
|
// Should have two different has-many relationships with unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDFilepointerRequestAPIEvents", // Has many via rid_filepointer_request
|
||||||
|
"RelRIDFilepointerResponseAPIEvents", // Has many via rid_filepointer_response
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(filepointerStr, exp) {
|
||||||
|
t.Errorf("Missing has-many relationship field: %s\nGenerated:\n%s", exp, filepointerStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriter_MultipleHasManyRelationships(t *testing.T) {
|
||||||
|
// Test scenario: api_provider table referenced by multiple tables via rid_api_provider
|
||||||
|
db := models.InitDatabase("testdb")
|
||||||
|
schema := models.InitSchema("org")
|
||||||
|
|
||||||
|
// Owner table
|
||||||
|
owner := models.InitTable("owner", "org")
|
||||||
|
owner.Columns["id_owner"] = &models.Column{
|
||||||
|
Name: "id_owner",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, owner)
|
||||||
|
|
||||||
|
// API Provider table
|
||||||
|
apiProvider := models.InitTable("api_provider", "org")
|
||||||
|
apiProvider.Columns["id_api_provider"] = &models.Column{
|
||||||
|
Name: "id_api_provider",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiProvider.Columns["rid_owner"] = &models.Column{
|
||||||
|
Name: "rid_owner",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiProvider.Constraints["fk_owner"] = &models.Constraint{
|
||||||
|
Name: "fk_owner",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_owner"},
|
||||||
|
ReferencedTable: "owner",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_owner"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiProvider)
|
||||||
|
|
||||||
|
// Login table
|
||||||
|
login := models.InitTable("login", "org")
|
||||||
|
login.Columns["id_login"] = &models.Column{
|
||||||
|
Name: "id_login",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
login.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
login.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, login)
|
||||||
|
|
||||||
|
// Filepointer table
|
||||||
|
filepointer := models.InitTable("filepointer", "org")
|
||||||
|
filepointer.Columns["id_filepointer"] = &models.Column{
|
||||||
|
Name: "id_filepointer",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
filepointer.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
filepointer.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, filepointer)
|
||||||
|
|
||||||
|
// API Event table
|
||||||
|
apiEvent := models.InitTable("api_event", "org")
|
||||||
|
apiEvent.Columns["id_api_event"] = &models.Column{
|
||||||
|
Name: "id_api_event",
|
||||||
|
Type: "bigserial",
|
||||||
|
NotNull: true,
|
||||||
|
IsPrimaryKey: true,
|
||||||
|
}
|
||||||
|
apiEvent.Columns["rid_api_provider"] = &models.Column{
|
||||||
|
Name: "rid_api_provider",
|
||||||
|
Type: "bigint",
|
||||||
|
NotNull: true,
|
||||||
|
}
|
||||||
|
apiEvent.Constraints["fk_api_provider"] = &models.Constraint{
|
||||||
|
Name: "fk_api_provider",
|
||||||
|
Type: models.ForeignKeyConstraint,
|
||||||
|
Columns: []string{"rid_api_provider"},
|
||||||
|
ReferencedTable: "api_provider",
|
||||||
|
ReferencedSchema: "org",
|
||||||
|
ReferencedColumns: []string{"id_api_provider"},
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, apiEvent)
|
||||||
|
|
||||||
|
db.Schemas = append(db.Schemas, schema)
|
||||||
|
|
||||||
|
// Create writer
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
opts := &writers.WriterOptions{
|
||||||
|
PackageName: "models",
|
||||||
|
OutputPath: tmpDir,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"multi_file": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := NewWriter(opts)
|
||||||
|
err := writer.WriteDatabase(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteDatabase failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the api_provider file
|
||||||
|
apiProviderContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_org_api_provider.go"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read api_provider file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(apiProviderContent)
|
||||||
|
|
||||||
|
// Verify all has-many relationships have unique names
|
||||||
|
hasManyExpectations := []string{
|
||||||
|
"RelRIDAPIProviderLogins", // Has many via Login
|
||||||
|
"RelRIDAPIProviderFilepointers", // Has many via Filepointer
|
||||||
|
"RelRIDAPIProviderAPIEvents", // Has many via APIEvent
|
||||||
|
"RelRIDOwner", // Belongs to via rid_owner
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, exp := range hasManyExpectations {
|
||||||
|
if !strings.Contains(contentStr, exp) {
|
||||||
|
t.Errorf("Missing relationship field: %s\nGenerated:\n%s", exp, contentStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify NO duplicate field names
|
||||||
|
// Count occurrences of "RelRIDAPIProvider" fields - should have 3 unique ones
|
||||||
|
count := strings.Count(contentStr, "RelRIDAPIProvider")
|
||||||
|
if count != 3 {
|
||||||
|
t.Errorf("Expected 3 RelRIDAPIProvider* fields, found %d\nGenerated:\n%s", count, contentStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no duplicate declarations (would cause compilation error)
|
||||||
|
duplicatePattern := "RelRIDAPIProviders []*Model"
|
||||||
|
if strings.Contains(contentStr, duplicatePattern) {
|
||||||
|
t.Errorf("Found duplicate field declaration pattern, fields should be unique")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,3 +61,26 @@ func SanitizeFilename(name string) string {
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeStructTagValue sanitizes a value to be safely used inside Go struct tags.
|
||||||
|
// Go struct tags are delimited by backticks, so any backtick in the value would break the syntax.
|
||||||
|
// This function:
|
||||||
|
// - Removes DBML/DCTX comments in brackets
|
||||||
|
// - Removes all quotes (double, single, and backticks)
|
||||||
|
// - Returns a clean identifier safe for use in struct tags and field names
|
||||||
|
func SanitizeStructTagValue(value string) string {
|
||||||
|
// Remove DBML/DCTX style comments in brackets (e.g., [note: 'description'])
|
||||||
|
commentRegex := regexp.MustCompile(`\s*\[.*?\]\s*`)
|
||||||
|
value = commentRegex.ReplaceAllString(value, "")
|
||||||
|
|
||||||
|
// Trim whitespace
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
|
||||||
|
// Remove all quotes: backticks, double quotes, and single quotes
|
||||||
|
// This ensures the value is clean for use as Go identifiers and struct tag values
|
||||||
|
value = strings.ReplaceAll(value, "`", "")
|
||||||
|
value = strings.ReplaceAll(value, `"`, "")
|
||||||
|
value = strings.ReplaceAll(value, `'`, "")
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user