So far so good
This commit is contained in:
284
pkg/writers/gorm/name_converter.go
Normal file
284
pkg/writers/gorm/name_converter.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// SnakeCaseToPascalCase converts snake_case to PascalCase
|
||||
// Examples: user_id → UserID, http_request → HTTPRequest
|
||||
func SnakeCaseToPascalCase(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(s, "_")
|
||||
for i, part := range parts {
|
||||
parts[i] = capitalize(part)
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
// SnakeCaseToCamelCase converts snake_case to camelCase
|
||||
// Examples: user_id → userID, http_request → httpRequest
|
||||
func SnakeCaseToCamelCase(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(s, "_")
|
||||
for i, part := range parts {
|
||||
if i == 0 {
|
||||
parts[i] = strings.ToLower(part)
|
||||
} else {
|
||||
parts[i] = capitalize(part)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
// PascalCaseToSnakeCase converts PascalCase to snake_case
|
||||
// Examples: UserID → user_id, HTTPRequest → http_request
|
||||
func PascalCaseToSnakeCase(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
var prevUpper bool
|
||||
var nextUpper bool
|
||||
|
||||
runes := []rune(s)
|
||||
for i, r := range runes {
|
||||
isUpper := unicode.IsUpper(r)
|
||||
|
||||
if i+1 < len(runes) {
|
||||
nextUpper = unicode.IsUpper(runes[i+1])
|
||||
} else {
|
||||
nextUpper = false
|
||||
}
|
||||
|
||||
if i > 0 && isUpper {
|
||||
// Add underscore before uppercase letter if:
|
||||
// 1. Previous char was lowercase, OR
|
||||
// 2. Next char is lowercase (end of acronym)
|
||||
if !prevUpper || (nextUpper == false && i+1 < len(runes)) {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
}
|
||||
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
prevUpper = isUpper
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// capitalize capitalizes the first letter and handles common acronyms
|
||||
func capitalize(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(s)
|
||||
|
||||
// Handle common acronyms
|
||||
acronyms := map[string]bool{
|
||||
"ID": true,
|
||||
"UUID": true,
|
||||
"GUID": true,
|
||||
"URL": true,
|
||||
"URI": true,
|
||||
"HTTP": true,
|
||||
"HTTPS": true,
|
||||
"API": true,
|
||||
"JSON": true,
|
||||
"XML": true,
|
||||
"SQL": true,
|
||||
"HTML": true,
|
||||
"CSS": true,
|
||||
"RID": true,
|
||||
}
|
||||
|
||||
if acronyms[upper] {
|
||||
return upper
|
||||
}
|
||||
|
||||
// Capitalize first letter
|
||||
runes := []rune(s)
|
||||
runes[0] = unicode.ToUpper(runes[0])
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
// Pluralize converts a singular word to plural
|
||||
// Basic implementation with common rules
|
||||
func Pluralize(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Special cases
|
||||
irregular := map[string]string{
|
||||
"person": "people",
|
||||
"child": "children",
|
||||
"tooth": "teeth",
|
||||
"foot": "feet",
|
||||
"man": "men",
|
||||
"woman": "women",
|
||||
"mouse": "mice",
|
||||
"goose": "geese",
|
||||
"ox": "oxen",
|
||||
"datum": "data",
|
||||
"medium": "media",
|
||||
"analysis": "analyses",
|
||||
"crisis": "crises",
|
||||
"status": "statuses",
|
||||
}
|
||||
|
||||
if plural, ok := irregular[strings.ToLower(s)]; ok {
|
||||
return plural
|
||||
}
|
||||
|
||||
// Already plural (ends in 's' but not 'ss' or 'us')
|
||||
if strings.HasSuffix(s, "s") && !strings.HasSuffix(s, "ss") && !strings.HasSuffix(s, "us") {
|
||||
return s
|
||||
}
|
||||
|
||||
// Words ending in s, x, z, ch, sh
|
||||
if strings.HasSuffix(s, "s") || strings.HasSuffix(s, "x") ||
|
||||
strings.HasSuffix(s, "z") || strings.HasSuffix(s, "ch") ||
|
||||
strings.HasSuffix(s, "sh") {
|
||||
return s + "es"
|
||||
}
|
||||
|
||||
// Words ending in consonant + y
|
||||
if len(s) >= 2 && strings.HasSuffix(s, "y") {
|
||||
prevChar := s[len(s)-2]
|
||||
if !isVowel(prevChar) {
|
||||
return s[:len(s)-1] + "ies"
|
||||
}
|
||||
}
|
||||
|
||||
// Words ending in f or fe
|
||||
if strings.HasSuffix(s, "f") {
|
||||
return s[:len(s)-1] + "ves"
|
||||
}
|
||||
if strings.HasSuffix(s, "fe") {
|
||||
return s[:len(s)-2] + "ves"
|
||||
}
|
||||
|
||||
// Words ending in consonant + o
|
||||
if len(s) >= 2 && strings.HasSuffix(s, "o") {
|
||||
prevChar := s[len(s)-2]
|
||||
if !isVowel(prevChar) {
|
||||
return s + "es"
|
||||
}
|
||||
}
|
||||
|
||||
// Default: add 's'
|
||||
return s + "s"
|
||||
}
|
||||
|
||||
// Singularize converts a plural word to singular
|
||||
// Basic implementation with common rules
|
||||
func Singularize(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Special cases
|
||||
irregular := map[string]string{
|
||||
"people": "person",
|
||||
"children": "child",
|
||||
"teeth": "tooth",
|
||||
"feet": "foot",
|
||||
"men": "man",
|
||||
"women": "woman",
|
||||
"mice": "mouse",
|
||||
"geese": "goose",
|
||||
"oxen": "ox",
|
||||
"data": "datum",
|
||||
"media": "medium",
|
||||
"analyses": "analysis",
|
||||
"crises": "crisis",
|
||||
"statuses": "status",
|
||||
}
|
||||
|
||||
if singular, ok := irregular[strings.ToLower(s)]; ok {
|
||||
return singular
|
||||
}
|
||||
|
||||
// Words ending in ies
|
||||
if strings.HasSuffix(s, "ies") && len(s) > 3 {
|
||||
return s[:len(s)-3] + "y"
|
||||
}
|
||||
|
||||
// Words ending in ves
|
||||
if strings.HasSuffix(s, "ves") {
|
||||
return s[:len(s)-3] + "f"
|
||||
}
|
||||
|
||||
// Words ending in ses, xes, zes, ches, shes
|
||||
if strings.HasSuffix(s, "ses") || strings.HasSuffix(s, "xes") ||
|
||||
strings.HasSuffix(s, "zes") || strings.HasSuffix(s, "ches") ||
|
||||
strings.HasSuffix(s, "shes") {
|
||||
return s[:len(s)-2]
|
||||
}
|
||||
|
||||
// Words ending in s (not ss)
|
||||
if strings.HasSuffix(s, "s") && !strings.HasSuffix(s, "ss") {
|
||||
return s[:len(s)-1]
|
||||
}
|
||||
|
||||
// Already singular
|
||||
return s
|
||||
}
|
||||
|
||||
// GeneratePrefix generates a 3-letter prefix from a table name
|
||||
// Examples: process → PRO, mastertask → MTL, user → USR
|
||||
func GeneratePrefix(tableName string) string {
|
||||
if tableName == "" {
|
||||
return "TBL"
|
||||
}
|
||||
|
||||
// Remove common prefixes
|
||||
tableName = strings.TrimPrefix(tableName, "tbl_")
|
||||
tableName = strings.TrimPrefix(tableName, "tb_")
|
||||
|
||||
// Split by underscore and take first letters
|
||||
parts := strings.Split(tableName, "_")
|
||||
|
||||
var prefix strings.Builder
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
prefix.WriteRune(unicode.ToUpper(rune(part[0])))
|
||||
if prefix.Len() >= 3 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
result := prefix.String()
|
||||
|
||||
// If we don't have 3 letters yet, add more from the first part
|
||||
if len(result) < 3 && len(parts) > 0 {
|
||||
firstPart := parts[0]
|
||||
for i := 1; i < len(firstPart) && len(result) < 3; i++ {
|
||||
result += strings.ToUpper(string(firstPart[i]))
|
||||
}
|
||||
}
|
||||
|
||||
// Pad with 'X' if still too short
|
||||
for len(result) < 3 {
|
||||
result += "X"
|
||||
}
|
||||
|
||||
return result[:3]
|
||||
}
|
||||
|
||||
// isVowel checks if a byte is a vowel
|
||||
func isVowel(c byte) bool {
|
||||
c = byte(unicode.ToLower(rune(c)))
|
||||
return c == 'a' || c == 'e' || c == 'i' || c == 'o' || c == 'u'
|
||||
}
|
||||
250
pkg/writers/gorm/template_data.go
Normal file
250
pkg/writers/gorm/template_data.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// TemplateData represents the data passed to the template for code generation
|
||||
type TemplateData struct {
|
||||
PackageName string
|
||||
Imports []string
|
||||
Models []*ModelData
|
||||
Config *MethodConfig
|
||||
}
|
||||
|
||||
// ModelData represents a single model/struct in the template
|
||||
type ModelData struct {
|
||||
Name string
|
||||
TableName string // schema.table format
|
||||
SchemaName string
|
||||
TableNameOnly string // just table name without schema
|
||||
Comment string
|
||||
Fields []*FieldData
|
||||
Config *MethodConfig
|
||||
PrimaryKeyField string // Name of the primary key field
|
||||
IDColumnName string // Name of the ID column in database
|
||||
Prefix string // 3-letter prefix
|
||||
}
|
||||
|
||||
// FieldData represents a single field in a struct
|
||||
type FieldData struct {
|
||||
Name string // Go field name (PascalCase)
|
||||
Type string // Go type
|
||||
GormTag string // Complete gorm tag
|
||||
JSONTag string // JSON tag
|
||||
Comment string // Field comment
|
||||
}
|
||||
|
||||
// MethodConfig controls which helper methods to generate
|
||||
type MethodConfig struct {
|
||||
GenerateTableName bool
|
||||
GenerateSchemaName bool
|
||||
GenerateTableNameOnly bool
|
||||
GenerateGetID bool
|
||||
GenerateGetIDStr bool
|
||||
GenerateSetID bool
|
||||
GenerateUpdateID bool
|
||||
GenerateGetIDName bool
|
||||
GenerateGetPrefix bool
|
||||
}
|
||||
|
||||
// DefaultMethodConfig returns a MethodConfig with all methods enabled
|
||||
func DefaultMethodConfig() *MethodConfig {
|
||||
return &MethodConfig{
|
||||
GenerateTableName: true,
|
||||
GenerateSchemaName: true,
|
||||
GenerateTableNameOnly: true,
|
||||
GenerateGetID: true,
|
||||
GenerateGetIDStr: true,
|
||||
GenerateSetID: true,
|
||||
GenerateUpdateID: true,
|
||||
GenerateGetIDName: true,
|
||||
GenerateGetPrefix: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTemplateData creates a new TemplateData with the given package name and config
|
||||
func NewTemplateData(packageName string, config *MethodConfig) *TemplateData {
|
||||
if config == nil {
|
||||
config = DefaultMethodConfig()
|
||||
}
|
||||
|
||||
return &TemplateData{
|
||||
PackageName: packageName,
|
||||
Imports: make([]string, 0),
|
||||
Models: make([]*ModelData, 0),
|
||||
Config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// AddModel adds a model to the template data
|
||||
func (td *TemplateData) AddModel(model *ModelData) {
|
||||
model.Config = td.Config
|
||||
td.Models = append(td.Models, model)
|
||||
}
|
||||
|
||||
// AddImport adds an import to the template data (deduplicates automatically)
|
||||
func (td *TemplateData) AddImport(importPath string) {
|
||||
// Check if already exists
|
||||
for _, imp := range td.Imports {
|
||||
if imp == importPath {
|
||||
return
|
||||
}
|
||||
}
|
||||
td.Imports = append(td.Imports, importPath)
|
||||
}
|
||||
|
||||
// FinalizeImports sorts and organizes imports
|
||||
func (td *TemplateData) FinalizeImports() {
|
||||
// Sort imports alphabetically
|
||||
sort.Strings(td.Imports)
|
||||
}
|
||||
|
||||
// NewModelData creates a new ModelData from a models.Table
|
||||
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
|
||||
tableName := table.Name
|
||||
if schema != "" {
|
||||
tableName = schema + "." + table.Name
|
||||
}
|
||||
|
||||
// Generate model name: singularize and convert to PascalCase
|
||||
singularTable := Singularize(table.Name)
|
||||
modelName := SnakeCaseToPascalCase(singularTable)
|
||||
|
||||
// Add "Model" prefix if not already present
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
}
|
||||
|
||||
model := &ModelData{
|
||||
Name: modelName,
|
||||
TableName: tableName,
|
||||
SchemaName: schema,
|
||||
TableNameOnly: table.Name,
|
||||
Comment: formatComment(table.Description, table.Comment),
|
||||
Fields: make([]*FieldData, 0),
|
||||
Prefix: GeneratePrefix(table.Name),
|
||||
}
|
||||
|
||||
// Find primary key
|
||||
for _, col := range table.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
|
||||
model.IDColumnName = col.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Convert columns to fields (sorted by sequence or name)
|
||||
columns := sortColumns(table.Columns)
|
||||
for _, col := range columns {
|
||||
field := columnToField(col, table, typeMapper)
|
||||
model.Fields = append(model.Fields, field)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
// columnToField converts a models.Column to FieldData
|
||||
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
|
||||
fieldName := SnakeCaseToPascalCase(col.Name)
|
||||
goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
|
||||
gormTag := typeMapper.BuildGormTag(col, table)
|
||||
jsonTag := col.Name // Use column name for JSON tag
|
||||
|
||||
return &FieldData{
|
||||
Name: fieldName,
|
||||
Type: goType,
|
||||
GormTag: gormTag,
|
||||
JSONTag: jsonTag,
|
||||
Comment: formatComment(col.Description, col.Comment),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRelationshipField adds a relationship field to the model
|
||||
func (md *ModelData) AddRelationshipField(field *FieldData) {
|
||||
md.Fields = append(md.Fields, field)
|
||||
}
|
||||
|
||||
// formatComment combines description and comment into a single comment string
|
||||
func formatComment(description, comment string) string {
|
||||
if description != "" && comment != "" {
|
||||
return description + " - " + comment
|
||||
}
|
||||
if description != "" {
|
||||
return description
|
||||
}
|
||||
return comment
|
||||
}
|
||||
|
||||
// hasModelPrefix checks if a name already has "Model" prefix
|
||||
func hasModelPrefix(name string) bool {
|
||||
return len(name) >= 5 && name[:5] == "Model"
|
||||
}
|
||||
|
||||
// sortColumns sorts columns by sequence, then by name
|
||||
func sortColumns(columns map[string]*models.Column) []*models.Column {
|
||||
result := make([]*models.Column, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
result = append(result, col)
|
||||
}
|
||||
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
// Sort by sequence if both have it
|
||||
if result[i].Sequence > 0 && result[j].Sequence > 0 {
|
||||
return result[i].Sequence < result[j].Sequence
|
||||
}
|
||||
|
||||
// Put primary keys first
|
||||
if result[i].IsPrimaryKey != result[j].IsPrimaryKey {
|
||||
return result[i].IsPrimaryKey
|
||||
}
|
||||
|
||||
// Otherwise sort alphabetically
|
||||
return result[i].Name < result[j].Name
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// LoadMethodConfigFromMetadata loads method configuration from metadata map
|
||||
func LoadMethodConfigFromMetadata(metadata map[string]interface{}) *MethodConfig {
|
||||
config := DefaultMethodConfig()
|
||||
|
||||
if metadata == nil {
|
||||
return config
|
||||
}
|
||||
|
||||
// Load each setting from metadata if present
|
||||
if val, ok := metadata["generate_table_name"].(bool); ok {
|
||||
config.GenerateTableName = val
|
||||
}
|
||||
if val, ok := metadata["generate_schema_name"].(bool); ok {
|
||||
config.GenerateSchemaName = val
|
||||
}
|
||||
if val, ok := metadata["generate_table_name_only"].(bool); ok {
|
||||
config.GenerateTableNameOnly = val
|
||||
}
|
||||
if val, ok := metadata["generate_get_id"].(bool); ok {
|
||||
config.GenerateGetID = val
|
||||
}
|
||||
if val, ok := metadata["generate_get_id_str"].(bool); ok {
|
||||
config.GenerateGetIDStr = val
|
||||
}
|
||||
if val, ok := metadata["generate_set_id"].(bool); ok {
|
||||
config.GenerateSetID = val
|
||||
}
|
||||
if val, ok := metadata["generate_update_id"].(bool); ok {
|
||||
config.GenerateUpdateID = val
|
||||
}
|
||||
if val, ok := metadata["generate_get_id_name"].(bool); ok {
|
||||
config.GenerateGetIDName = val
|
||||
}
|
||||
if val, ok := metadata["generate_get_prefix"].(bool); ok {
|
||||
config.GenerateGetPrefix = val
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
109
pkg/writers/gorm/templates.go
Normal file
109
pkg/writers/gorm/templates.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// modelTemplate defines the template for generating GORM models
|
||||
const modelTemplate = `// Code generated by relspecgo. DO NOT EDIT.
|
||||
package {{.PackageName}}
|
||||
|
||||
{{if .Imports -}}
|
||||
import (
|
||||
{{range .Imports -}}
|
||||
{{.}}
|
||||
{{end -}}
|
||||
)
|
||||
{{end}}
|
||||
{{range .Models}}
|
||||
{{if .Comment}}// {{.Comment}}{{end}}
|
||||
type {{.Name}} struct {
|
||||
{{- range .Fields}}
|
||||
{{.Name}} {{.Type}} ` + "`gorm:\"{{.GormTag}}\" json:\"{{.JSONTag}}\"`" + `{{if .Comment}} // {{.Comment}}{{end}}
|
||||
{{- end}}
|
||||
}
|
||||
{{if .Config.GenerateTableName}}
|
||||
// TableName returns the table name for {{.Name}}
|
||||
func (m {{.Name}}) TableName() string {
|
||||
return "{{.TableName}}"
|
||||
}
|
||||
{{end}}
|
||||
{{if .Config.GenerateTableNameOnly}}
|
||||
// TableNameOnly returns the table name without schema for {{.Name}}
|
||||
func (m {{.Name}}) TableNameOnly() string {
|
||||
return "{{.TableNameOnly}}"
|
||||
}
|
||||
{{end}}
|
||||
{{if .Config.GenerateSchemaName}}
|
||||
// SchemaName returns the schema name for {{.Name}}
|
||||
func (m {{.Name}}) SchemaName() string {
|
||||
return "{{.SchemaName}}"
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetID .PrimaryKeyField}}
|
||||
// GetID returns the primary key value
|
||||
func (m {{.Name}}) GetID() int64 {
|
||||
return int64(m.{{.PrimaryKeyField}})
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
|
||||
// GetIDStr returns the primary key as a string
|
||||
func (m {{.Name}}) GetIDStr() string {
|
||||
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateSetID .PrimaryKeyField}}
|
||||
// SetID sets the primary key value
|
||||
func (m {{.Name}}) SetID(newid int64) {
|
||||
m.UpdateID(newid)
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
|
||||
// UpdateID updates the primary key value
|
||||
func (m *{{.Name}}) UpdateID(newid int64) {
|
||||
m.{{.PrimaryKeyField}} = int32(newid)
|
||||
}
|
||||
{{end}}
|
||||
{{if and .Config.GenerateGetIDName .IDColumnName}}
|
||||
// GetIDName returns the name of the primary key column
|
||||
func (m {{.Name}}) GetIDName() string {
|
||||
return "{{.IDColumnName}}"
|
||||
}
|
||||
{{end}}
|
||||
{{if .Config.GenerateGetPrefix}}
|
||||
// GetPrefix returns the table prefix
|
||||
func (m {{.Name}}) GetPrefix() string {
|
||||
return "{{.Prefix}}"
|
||||
}
|
||||
{{end}}
|
||||
{{end -}}
|
||||
`
|
||||
|
||||
// Templates holds the parsed templates
|
||||
type Templates struct {
|
||||
modelTmpl *template.Template
|
||||
}
|
||||
|
||||
// NewTemplates creates and parses the templates
|
||||
func NewTemplates() (*Templates, error) {
|
||||
modelTmpl, err := template.New("model").Parse(modelTemplate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Templates{
|
||||
modelTmpl: modelTmpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateCode executes the template with the given data
|
||||
func (t *Templates) GenerateCode(data *TemplateData) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := t.modelTmpl.Execute(&buf, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
335
pkg/writers/gorm/type_mapper.go
Normal file
335
pkg/writers/gorm/type_mapper.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
)
|
||||
|
||||
// TypeMapper handles type conversions between SQL and Go types
|
||||
type TypeMapper struct {
|
||||
// Package alias for sql_types import
|
||||
sqlTypesAlias string
|
||||
}
|
||||
|
||||
// NewTypeMapper creates a new TypeMapper with default settings
|
||||
func NewTypeMapper() *TypeMapper {
|
||||
return &TypeMapper{
|
||||
sqlTypesAlias: "sql_types",
|
||||
}
|
||||
}
|
||||
|
||||
// SQLTypeToGoType converts a SQL type to its Go equivalent
|
||||
// Handles nullable types using ResolveSpec sql_types package
|
||||
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
|
||||
// Normalize SQL type (lowercase, remove length/precision)
|
||||
baseType := tm.extractBaseType(sqlType)
|
||||
|
||||
// If not null, use base Go types
|
||||
if notNull {
|
||||
return tm.baseGoType(baseType)
|
||||
}
|
||||
|
||||
// For nullable fields, use sql_types
|
||||
return tm.nullableGoType(baseType)
|
||||
}
|
||||
|
||||
// extractBaseType extracts the base type from a SQL type string
|
||||
// Examples: varchar(100) → varchar, numeric(10,2) → numeric
|
||||
func (tm *TypeMapper) extractBaseType(sqlType string) string {
|
||||
sqlType = strings.ToLower(strings.TrimSpace(sqlType))
|
||||
|
||||
// Remove everything after '('
|
||||
if idx := strings.Index(sqlType, "("); idx > 0 {
|
||||
sqlType = sqlType[:idx]
|
||||
}
|
||||
|
||||
return sqlType
|
||||
}
|
||||
|
||||
// baseGoType returns the base Go type for a SQL type (not null)
|
||||
func (tm *TypeMapper) baseGoType(sqlType string) string {
|
||||
typeMap := map[string]string{
|
||||
// Integer types
|
||||
"integer": "int32",
|
||||
"int": "int32",
|
||||
"int4": "int32",
|
||||
"smallint": "int16",
|
||||
"int2": "int16",
|
||||
"bigint": "int64",
|
||||
"int8": "int64",
|
||||
"serial": "int32",
|
||||
"bigserial": "int64",
|
||||
"smallserial": "int16",
|
||||
|
||||
// String types
|
||||
"text": "string",
|
||||
"varchar": "string",
|
||||
"char": "string",
|
||||
"character": "string",
|
||||
"citext": "string",
|
||||
"bpchar": "string",
|
||||
|
||||
// Boolean
|
||||
"boolean": "bool",
|
||||
"bool": "bool",
|
||||
|
||||
// Float types
|
||||
"real": "float32",
|
||||
"float4": "float32",
|
||||
"double precision": "float64",
|
||||
"float8": "float64",
|
||||
"numeric": "float64",
|
||||
"decimal": "float64",
|
||||
|
||||
// Date/Time types
|
||||
"timestamp": "time.Time",
|
||||
"timestamp without time zone": "time.Time",
|
||||
"timestamp with time zone": "time.Time",
|
||||
"timestamptz": "time.Time",
|
||||
"date": "time.Time",
|
||||
"time": "time.Time",
|
||||
"time without time zone": "time.Time",
|
||||
"time with time zone": "time.Time",
|
||||
"timetz": "time.Time",
|
||||
|
||||
// Binary
|
||||
"bytea": "[]byte",
|
||||
|
||||
// UUID
|
||||
"uuid": "string",
|
||||
|
||||
// JSON
|
||||
"json": "string",
|
||||
"jsonb": "string",
|
||||
|
||||
// Network
|
||||
"inet": "string",
|
||||
"cidr": "string",
|
||||
"macaddr": "string",
|
||||
|
||||
// Other
|
||||
"money": "float64",
|
||||
}
|
||||
|
||||
if goType, ok := typeMap[sqlType]; ok {
|
||||
return goType
|
||||
}
|
||||
|
||||
// Default to string for unknown types
|
||||
return "string"
|
||||
}
|
||||
|
||||
// nullableGoType returns the nullable Go type using sql_types package
|
||||
func (tm *TypeMapper) nullableGoType(sqlType string) string {
|
||||
typeMap := map[string]string{
|
||||
// Integer types
|
||||
"integer": tm.sqlTypesAlias + ".SqlInt32",
|
||||
"int": tm.sqlTypesAlias + ".SqlInt32",
|
||||
"int4": tm.sqlTypesAlias + ".SqlInt32",
|
||||
"smallint": tm.sqlTypesAlias + ".SqlInt16",
|
||||
"int2": tm.sqlTypesAlias + ".SqlInt16",
|
||||
"bigint": tm.sqlTypesAlias + ".SqlInt64",
|
||||
"int8": tm.sqlTypesAlias + ".SqlInt64",
|
||||
"serial": tm.sqlTypesAlias + ".SqlInt32",
|
||||
"bigserial": tm.sqlTypesAlias + ".SqlInt64",
|
||||
"smallserial": tm.sqlTypesAlias + ".SqlInt16",
|
||||
|
||||
// String types
|
||||
"text": tm.sqlTypesAlias + ".SqlString",
|
||||
"varchar": tm.sqlTypesAlias + ".SqlString",
|
||||
"char": tm.sqlTypesAlias + ".SqlString",
|
||||
"character": tm.sqlTypesAlias + ".SqlString",
|
||||
"citext": tm.sqlTypesAlias + ".SqlString",
|
||||
"bpchar": tm.sqlTypesAlias + ".SqlString",
|
||||
|
||||
// Boolean
|
||||
"boolean": tm.sqlTypesAlias + ".SqlBool",
|
||||
"bool": tm.sqlTypesAlias + ".SqlBool",
|
||||
|
||||
// Float types
|
||||
"real": tm.sqlTypesAlias + ".SqlFloat32",
|
||||
"float4": tm.sqlTypesAlias + ".SqlFloat32",
|
||||
"double precision": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
"float8": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
"numeric": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
"decimal": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
|
||||
// Date/Time types
|
||||
"timestamp": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamp with time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timestamptz": tm.sqlTypesAlias + ".SqlTime",
|
||||
"date": tm.sqlTypesAlias + ".SqlDate",
|
||||
"time": tm.sqlTypesAlias + ".SqlTime",
|
||||
"time without time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"time with time zone": tm.sqlTypesAlias + ".SqlTime",
|
||||
"timetz": tm.sqlTypesAlias + ".SqlTime",
|
||||
|
||||
// Binary
|
||||
"bytea": "[]byte", // No nullable version needed
|
||||
|
||||
// UUID
|
||||
"uuid": tm.sqlTypesAlias + ".SqlUUID",
|
||||
|
||||
// JSON
|
||||
"json": tm.sqlTypesAlias + ".SqlString",
|
||||
"jsonb": tm.sqlTypesAlias + ".SqlString",
|
||||
|
||||
// Network
|
||||
"inet": tm.sqlTypesAlias + ".SqlString",
|
||||
"cidr": tm.sqlTypesAlias + ".SqlString",
|
||||
"macaddr": tm.sqlTypesAlias + ".SqlString",
|
||||
|
||||
// Other
|
||||
"money": tm.sqlTypesAlias + ".SqlFloat64",
|
||||
}
|
||||
|
||||
if goType, ok := typeMap[sqlType]; ok {
|
||||
return goType
|
||||
}
|
||||
|
||||
// Default to SqlString for unknown types
|
||||
return tm.sqlTypesAlias + ".SqlString"
|
||||
}
|
||||
|
||||
// BuildGormTag generates a complete GORM tag string for a column
|
||||
func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) string {
|
||||
var parts []string
|
||||
|
||||
// Always include column name (lowercase as per user requirement)
|
||||
parts = append(parts, fmt.Sprintf("column:%s", column.Name))
|
||||
|
||||
// Add type if specified
|
||||
if column.Type != "" {
|
||||
// Include length, precision, scale if present
|
||||
typeStr := column.Type
|
||||
if column.Length > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
|
||||
} else if column.Precision > 0 {
|
||||
if column.Scale > 0 {
|
||||
typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
|
||||
} else {
|
||||
typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Precision)
|
||||
}
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("type:%s", typeStr))
|
||||
}
|
||||
|
||||
// Primary key
|
||||
if column.IsPrimaryKey {
|
||||
parts = append(parts, "primaryKey")
|
||||
}
|
||||
|
||||
// Auto increment
|
||||
if column.AutoIncrement {
|
||||
parts = append(parts, "autoIncrement")
|
||||
}
|
||||
|
||||
// Not null (skip if primary key, as it's implied)
|
||||
if column.NotNull && !column.IsPrimaryKey {
|
||||
parts = append(parts, "not null")
|
||||
}
|
||||
|
||||
// Default value
|
||||
if column.Default != nil {
|
||||
parts = append(parts, fmt.Sprintf("default:%v", column.Default))
|
||||
}
|
||||
|
||||
// Check for unique constraint
|
||||
if table != nil {
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type == models.UniqueConstraint {
|
||||
for _, col := range constraint.Columns {
|
||||
if col == column.Name {
|
||||
if constraint.Name != "" {
|
||||
parts = append(parts, fmt.Sprintf("uniqueIndex:%s", constraint.Name))
|
||||
} else {
|
||||
parts = append(parts, "unique")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for index
|
||||
for _, index := range table.Indexes {
|
||||
for _, col := range index.Columns {
|
||||
if col == column.Name {
|
||||
if index.Unique {
|
||||
if index.Name != "" {
|
||||
parts = append(parts, fmt.Sprintf("uniqueIndex:%s", index.Name))
|
||||
} else {
|
||||
parts = append(parts, "unique")
|
||||
}
|
||||
} else {
|
||||
if index.Name != "" {
|
||||
parts = append(parts, fmt.Sprintf("index:%s", index.Name))
|
||||
} else {
|
||||
parts = append(parts, "index")
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, ";")
|
||||
}
|
||||
|
||||
// BuildRelationshipTag generates GORM tag for relationship fields
|
||||
func (tm *TypeMapper) BuildRelationshipTag(constraint *models.Constraint, isParent bool) string {
|
||||
var parts []string
|
||||
|
||||
if !isParent {
|
||||
// Child side (has foreign key)
|
||||
if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
|
||||
// foreignKey points to the field name in this struct
|
||||
fkFieldName := SnakeCaseToPascalCase(constraint.Columns[0])
|
||||
parts = append(parts, fmt.Sprintf("foreignKey:%s", fkFieldName))
|
||||
|
||||
// references points to the field name in the other struct
|
||||
refFieldName := SnakeCaseToPascalCase(constraint.ReferencedColumns[0])
|
||||
parts = append(parts, fmt.Sprintf("references:%s", refFieldName))
|
||||
}
|
||||
} else {
|
||||
// Parent side (being referenced)
|
||||
if len(constraint.Columns) > 0 {
|
||||
fkFieldName := SnakeCaseToPascalCase(constraint.Columns[0])
|
||||
parts = append(parts, fmt.Sprintf("foreignKey:%s", fkFieldName))
|
||||
}
|
||||
}
|
||||
|
||||
// Add constraint actions
|
||||
if constraint.OnDelete != "" {
|
||||
parts = append(parts, fmt.Sprintf("constraint:OnDelete:%s", strings.ToUpper(constraint.OnDelete)))
|
||||
}
|
||||
if constraint.OnUpdate != "" {
|
||||
if len(parts) > 0 && strings.Contains(parts[len(parts)-1], "constraint:") {
|
||||
// Append to existing constraint
|
||||
parts[len(parts)-1] += fmt.Sprintf(",OnUpdate:%s", strings.ToUpper(constraint.OnUpdate))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf("constraint:OnUpdate:%s", strings.ToUpper(constraint.OnUpdate)))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, ";")
|
||||
}
|
||||
|
||||
// NeedsTimeImport checks if the Go type requires time package import
|
||||
func (tm *TypeMapper) NeedsTimeImport(goType string) bool {
|
||||
return strings.Contains(goType, "time.Time")
|
||||
}
|
||||
|
||||
// NeedsFmtImport checks if we need fmt import (for GetIDStr method)
|
||||
func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
|
||||
return generateGetIDStr
|
||||
}
|
||||
|
||||
// GetSQLTypesImport returns the import path for sql_types
|
||||
func (tm *TypeMapper) GetSQLTypesImport() string {
|
||||
return "github.com/bitechdev/ResolveSpec/pkg/common/sql_types"
|
||||
}
|
||||
324
pkg/writers/gorm/writer.go
Normal file
324
pkg/writers/gorm/writer.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
// Writer implements the writers.Writer interface for GORM models
|
||||
type Writer struct {
|
||||
options *writers.WriterOptions
|
||||
typeMapper *TypeMapper
|
||||
templates *Templates
|
||||
config *MethodConfig
|
||||
}
|
||||
|
||||
// NewWriter creates a new GORM writer with the given options
|
||||
func NewWriter(options *writers.WriterOptions) *Writer {
|
||||
w := &Writer{
|
||||
options: options,
|
||||
typeMapper: NewTypeMapper(),
|
||||
config: LoadMethodConfigFromMetadata(options.Metadata),
|
||||
}
|
||||
|
||||
// Initialize templates
|
||||
tmpl, err := NewTemplates()
|
||||
if err != nil {
|
||||
// Should not happen with embedded templates
|
||||
panic(fmt.Sprintf("failed to initialize templates: %v", err))
|
||||
}
|
||||
w.templates = tmpl
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
// WriteDatabase writes a complete database as GORM models
|
||||
func (w *Writer) WriteDatabase(db *models.Database) error {
|
||||
// Check if multi-file mode is enabled
|
||||
multiFile := false
|
||||
if w.options.Metadata != nil {
|
||||
if mf, ok := w.options.Metadata["multi_file"].(bool); ok {
|
||||
multiFile = mf
|
||||
}
|
||||
}
|
||||
|
||||
if multiFile {
|
||||
return w.writeMultiFile(db)
|
||||
}
|
||||
|
||||
return w.writeSingleFile(db)
|
||||
}
|
||||
|
||||
// WriteSchema writes a schema as GORM models
|
||||
func (w *Writer) WriteSchema(schema *models.Schema) error {
|
||||
// Create a temporary database with just this schema
|
||||
db := models.InitDatabase(schema.Name)
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// WriteTable writes a single table as a GORM model
|
||||
func (w *Writer) WriteTable(table *models.Table) error {
|
||||
// Create a temporary schema and database
|
||||
schema := models.InitSchema(table.Schema)
|
||||
schema.Tables = []*models.Table{table}
|
||||
|
||||
db := models.InitDatabase(schema.Name)
|
||||
db.Schemas = []*models.Schema{schema}
|
||||
|
||||
return w.WriteDatabase(db)
|
||||
}
|
||||
|
||||
// writeSingleFile writes all models to a single file
|
||||
func (w *Writer) writeSingleFile(db *models.Database) error {
|
||||
packageName := w.getPackageName()
|
||||
templateData := NewTemplateData(packageName, w.config)
|
||||
|
||||
// Add sql_types import (always needed for nullable types)
|
||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Collect all models
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
|
||||
templateData.AddModel(modelData)
|
||||
|
||||
// Check if we need time import
|
||||
for _, field := range modelData.Fields {
|
||||
if w.typeMapper.NeedsTimeImport(field.Type) {
|
||||
templateData.AddImport("\"time\"")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
// Finalize imports
|
||||
templateData.FinalizeImports()
|
||||
|
||||
// Generate code
|
||||
code, err := w.templates.GenerateCode(templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate code: %w", err)
|
||||
}
|
||||
|
||||
// Format code
|
||||
formatted, err := w.formatCode(code)
|
||||
if err != nil {
|
||||
// Return unformatted code with warning
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to format code: %v\n", err)
|
||||
formatted = code
|
||||
}
|
||||
|
||||
// Write output
|
||||
return w.writeOutput(formatted)
|
||||
}
|
||||
|
||||
// writeMultiFile writes each table to a separate file
|
||||
func (w *Writer) writeMultiFile(db *models.Database) error {
|
||||
packageName := w.getPackageName()
|
||||
|
||||
// Ensure output path is a directory
|
||||
if w.options.OutputPath == "" {
|
||||
return fmt.Errorf("output path is required for multi-file mode")
|
||||
}
|
||||
|
||||
// Create output directory if it doesn't exist
|
||||
if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate a file for each table
|
||||
for _, schema := range db.Schemas {
|
||||
for _, table := range schema.Tables {
|
||||
// Create template data for this single table
|
||||
templateData := NewTemplateData(packageName, w.config)
|
||||
|
||||
// Add sql_types import
|
||||
templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))
|
||||
|
||||
// Create model data
|
||||
modelData := NewModelData(table, schema.Name, w.typeMapper)
|
||||
|
||||
// Add relationship fields
|
||||
w.addRelationshipFields(modelData, table, schema, db)
|
||||
|
||||
templateData.AddModel(modelData)
|
||||
|
||||
// Check if we need time import
|
||||
for _, field := range modelData.Fields {
|
||||
if w.typeMapper.NeedsTimeImport(field.Type) {
|
||||
templateData.AddImport("\"time\"")
|
||||
}
|
||||
}
|
||||
|
||||
// Add fmt import if GetIDStr is enabled
|
||||
if w.config.GenerateGetIDStr {
|
||||
templateData.AddImport("\"fmt\"")
|
||||
}
|
||||
|
||||
// Finalize imports
|
||||
templateData.FinalizeImports()
|
||||
|
||||
// Generate code
|
||||
code, err := w.templates.GenerateCode(templateData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
|
||||
}
|
||||
|
||||
// Format code
|
||||
formatted, err := w.formatCode(code)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to format code for %s: %v\n", table.Name, err)
|
||||
formatted = code
|
||||
}
|
||||
|
||||
// Generate filename: sql_{schema}_{table}.go
|
||||
filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
|
||||
filepath := filepath.Join(w.options.OutputPath, filename)
|
||||
|
||||
// Write file
|
||||
if err := os.WriteFile(filepath, []byte(formatted), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write file %s: %w", filename, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRelationshipFields adds relationship fields to the model based on foreign keys
|
||||
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
|
||||
// For each foreign key in this table, add a belongs-to relationship
|
||||
for _, constraint := range table.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the referenced table
|
||||
refTable := w.findTable(constraint.ReferencedSchema, constraint.ReferencedTable, db)
|
||||
if refTable == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create relationship field (belongs-to)
|
||||
refModelName := w.getModelName(constraint.ReferencedTable)
|
||||
fieldName := w.generateRelationshipFieldName(constraint.ReferencedTable)
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, false)
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
Name: fieldName,
|
||||
Type: "*" + refModelName, // Pointer type
|
||||
GormTag: relationTag,
|
||||
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
||||
Comment: fmt.Sprintf("Belongs to %s", refModelName),
|
||||
})
|
||||
}
|
||||
|
||||
// For each table that references this table, add a has-many relationship
|
||||
for _, otherSchema := range db.Schemas {
|
||||
for _, otherTable := range otherSchema.Tables {
|
||||
if otherTable.Name == table.Name && otherSchema.Name == schema.Name {
|
||||
continue // Skip self
|
||||
}
|
||||
|
||||
for _, constraint := range otherTable.Constraints {
|
||||
if constraint.Type != models.ForeignKeyConstraint {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this constraint references our table
|
||||
if constraint.ReferencedTable == table.Name && constraint.ReferencedSchema == schema.Name {
|
||||
// Add has-many relationship
|
||||
otherModelName := w.getModelName(otherTable.Name)
|
||||
fieldName := w.generateRelationshipFieldName(otherTable.Name) + "s" // Pluralize
|
||||
relationTag := w.typeMapper.BuildRelationshipTag(constraint, true)
|
||||
|
||||
modelData.AddRelationshipField(&FieldData{
|
||||
Name: fieldName,
|
||||
Type: "[]*" + otherModelName, // Slice of pointers
|
||||
GormTag: relationTag,
|
||||
JSONTag: strings.ToLower(fieldName) + ",omitempty",
|
||||
Comment: fmt.Sprintf("Has many %s", otherModelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findTable finds a table by schema and name in the database
|
||||
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
|
||||
for _, schema := range db.Schemas {
|
||||
if schema.Name != schemaName {
|
||||
continue
|
||||
}
|
||||
for _, table := range schema.Tables {
|
||||
if table.Name == tableName {
|
||||
return table
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getModelName generates the model name from a table name
|
||||
func (w *Writer) getModelName(tableName string) string {
|
||||
singular := Singularize(tableName)
|
||||
modelName := SnakeCaseToPascalCase(singular)
|
||||
|
||||
if !hasModelPrefix(modelName) {
|
||||
modelName = "Model" + modelName
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
// generateRelationshipFieldName generates a field name for a relationship
|
||||
func (w *Writer) generateRelationshipFieldName(tableName string) string {
|
||||
// Use just the prefix (3 letters) for relationship fields
|
||||
return GeneratePrefix(tableName)
|
||||
}
|
||||
|
||||
// getPackageName returns the package name from options or defaults to "models"
|
||||
func (w *Writer) getPackageName() string {
|
||||
if w.options.PackageName != "" {
|
||||
return w.options.PackageName
|
||||
}
|
||||
return "models"
|
||||
}
|
||||
|
||||
// formatCode formats Go code using gofmt
|
||||
func (w *Writer) formatCode(code string) (string, error) {
|
||||
formatted, err := format.Source([]byte(code))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("format error: %w", err)
|
||||
}
|
||||
return string(formatted), nil
|
||||
}
|
||||
|
||||
// writeOutput writes the content to file or stdout
|
||||
func (w *Writer) writeOutput(content string) error {
|
||||
if w.options.OutputPath != "" {
|
||||
return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// Print to stdout
|
||||
fmt.Print(content)
|
||||
return nil
|
||||
}
|
||||
243
pkg/writers/gorm/writer_test.go
Normal file
243
pkg/writers/gorm/writer_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/models"
|
||||
"git.warky.dev/wdevs/relspecgo/pkg/writers"
|
||||
)
|
||||
|
||||
func TestWriter_WriteTable(t *testing.T) {
|
||||
// Create a simple table
|
||||
table := models.InitTable("users", "public")
|
||||
table.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
AutoIncrement: true,
|
||||
Sequence: 1,
|
||||
}
|
||||
table.Columns["email"] = &models.Column{
|
||||
Name: "email",
|
||||
Type: "varchar",
|
||||
Length: 255,
|
||||
NotNull: false,
|
||||
Sequence: 2,
|
||||
}
|
||||
table.Columns["created_at"] = &models.Column{
|
||||
Name: "created_at",
|
||||
Type: "timestamp",
|
||||
NotNull: true,
|
||||
Sequence: 3,
|
||||
}
|
||||
|
||||
// Create writer
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
Metadata: map[string]interface{}{
|
||||
"generate_table_name": true,
|
||||
"generate_get_id": true,
|
||||
},
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
|
||||
// Write to temporary file
|
||||
tmpDir := t.TempDir()
|
||||
opts.OutputPath = filepath.Join(tmpDir, "test.go")
|
||||
|
||||
err := writer.WriteTable(table)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTable failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the generated file
|
||||
content, err := os.ReadFile(opts.OutputPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read generated file: %v", err)
|
||||
}
|
||||
|
||||
generated := string(content)
|
||||
|
||||
// Verify key elements are present
|
||||
expectations := []string{
|
||||
"package models",
|
||||
"type ModelUser struct",
|
||||
"ID",
|
||||
"int64",
|
||||
"Email",
|
||||
"sql_types.SqlString",
|
||||
"CreatedAt",
|
||||
"time.Time",
|
||||
"gorm:\"column:id",
|
||||
"gorm:\"column:email",
|
||||
"func (m ModelUser) TableName() string",
|
||||
"return \"public.users\"",
|
||||
"func (m ModelUser) GetID() int64",
|
||||
}
|
||||
|
||||
for _, expected := range expectations {
|
||||
if !strings.Contains(generated, expected) {
|
||||
t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", expected, generated)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
|
||||
// Create a database with two tables
|
||||
db := models.InitDatabase("testdb")
|
||||
schema := models.InitSchema("public")
|
||||
|
||||
// Table 1: users
|
||||
users := models.InitTable("users", "public")
|
||||
users.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
schema.Tables = append(schema.Tables, users)
|
||||
|
||||
// Table 2: posts
|
||||
posts := models.InitTable("posts", "public")
|
||||
posts.Columns["id"] = &models.Column{
|
||||
Name: "id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
IsPrimaryKey: true,
|
||||
}
|
||||
posts.Columns["user_id"] = &models.Column{
|
||||
Name: "user_id",
|
||||
Type: "bigint",
|
||||
NotNull: true,
|
||||
}
|
||||
posts.Constraints["fk_user"] = &models.Constraint{
|
||||
Name: "fk_user",
|
||||
Type: models.ForeignKeyConstraint,
|
||||
Columns: []string{"user_id"},
|
||||
ReferencedTable: "users",
|
||||
ReferencedSchema: "public",
|
||||
ReferencedColumns: []string{"id"},
|
||||
OnDelete: "CASCADE",
|
||||
}
|
||||
schema.Tables = append(schema.Tables, posts)
|
||||
|
||||
db.Schemas = append(db.Schemas, schema)
|
||||
|
||||
// Create writer with multi-file mode
|
||||
tmpDir := t.TempDir()
|
||||
opts := &writers.WriterOptions{
|
||||
PackageName: "models",
|
||||
OutputPath: tmpDir,
|
||||
Metadata: map[string]interface{}{
|
||||
"multi_file": true,
|
||||
},
|
||||
}
|
||||
|
||||
writer := NewWriter(opts)
|
||||
|
||||
err := writer.WriteDatabase(db)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteDatabase failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify two files were created
|
||||
expectedFiles := []string{
|
||||
"sql_public_users.go",
|
||||
"sql_public_posts.go",
|
||||
}
|
||||
|
||||
for _, filename := range expectedFiles {
|
||||
filepath := filepath.Join(tmpDir, filename)
|
||||
if _, err := os.Stat(filepath); os.IsNotExist(err) {
|
||||
t.Errorf("Expected file not created: %s", filename)
|
||||
}
|
||||
}
|
||||
|
||||
// Check posts file contains relationship
|
||||
postsContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_posts.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read posts file: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(postsContent), "USE *ModelUser") {
|
||||
// Relationship field should be present
|
||||
t.Logf("Posts content:\n%s", string(postsContent))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"user_id", "UserID"},
|
||||
{"http_request", "HTTPRequest"},
|
||||
{"user_profiles", "UserProfiles"},
|
||||
{"guid", "GUID"},
|
||||
{"rid_process", "RIDProcess"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := SnakeCaseToPascalCase(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SnakeCaseToPascalCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameConverter_Pluralize(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"user", "users"},
|
||||
{"process", "processes"},
|
||||
{"child", "children"},
|
||||
{"person", "people"},
|
||||
{"status", "statuses"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := Pluralize(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Pluralize(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
|
||||
mapper := NewTypeMapper()
|
||||
|
||||
tests := []struct {
|
||||
sqlType string
|
||||
notNull bool
|
||||
want string
|
||||
}{
|
||||
{"bigint", true, "int64"},
|
||||
{"bigint", false, "sql_types.SqlInt64"},
|
||||
{"varchar", true, "string"},
|
||||
{"varchar", false, "sql_types.SqlString"},
|
||||
{"timestamp", true, "time.Time"},
|
||||
{"timestamp", false, "sql_types.SqlTime"},
|
||||
{"boolean", true, "bool"},
|
||||
{"boolean", false, "sql_types.SqlBool"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.sqlType, func(t *testing.T) {
|
||||
result := mapper.SQLTypeToGoType(tt.sqlType, tt.notNull)
|
||||
if result != tt.want {
|
||||
t.Errorf("SQLTypeToGoType(%q, %v) = %q, want %q", tt.sqlType, tt.notNull, result, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user