So far so good
Some checks are pending
CI / Test (1.23) (push) Waiting to run
CI / Test (1.24) (push) Waiting to run
CI / Test (1.25) (push) Waiting to run
CI / Lint (push) Waiting to run
CI / Build (push) Waiting to run

This commit is contained in:
2025-12-16 18:10:40 +02:00
parent b9650739bf
commit 7c7054d2e2
44 changed files with 27029 additions and 48 deletions

View File

@@ -0,0 +1,284 @@
package bun
import (
"strings"
"unicode"
)
// SnakeCaseToPascalCase converts a snake_case identifier to PascalCase,
// upper-casing known initialisms via capitalize.
// Examples: user_id → UserID, http_request → HTTPRequest
func SnakeCaseToPascalCase(s string) string {
	if s == "" {
		return ""
	}
	var b strings.Builder
	for _, segment := range strings.Split(s, "_") {
		b.WriteString(capitalize(segment))
	}
	return b.String()
}
// SnakeCaseToCamelCase converts a snake_case identifier to camelCase: the
// first segment is lower-cased, the rest go through capitalize (which also
// upper-cases known initialisms).
// Examples: user_id → userID, http_request → httpRequest
func SnakeCaseToCamelCase(s string) string {
	if s == "" {
		return ""
	}
	segments := strings.Split(s, "_")
	var b strings.Builder
	b.WriteString(strings.ToLower(segments[0]))
	for _, segment := range segments[1:] {
		b.WriteString(capitalize(segment))
	}
	return b.String()
}
// PascalCaseToSnakeCase converts PascalCase to snake_case.
// Examples: UserID → user_id, HTTPRequest → http_request
//
// An underscore is inserted before an uppercase rune when the previous rune
// was lowercase (a word boundary), or when a following rune exists and is
// lowercase (the end of an acronym, e.g. the "R" in HTTPRequest).
func PascalCaseToSnakeCase(s string) string {
	if s == "" {
		return ""
	}
	runes := []rune(s)
	var b strings.Builder
	prevUpper := false
	for i, r := range runes {
		isUpper := unicode.IsUpper(r)
		if i > 0 && isUpper {
			// nextLower: a following rune exists and is not uppercase,
			// i.e. this uppercase rune closes an acronym.
			// (Replaces the original's "nextUpper == false && i+1 < len" —
			// same logic, idiomatic form.)
			nextLower := i+1 < len(runes) && !unicode.IsUpper(runes[i+1])
			if !prevUpper || nextLower {
				b.WriteRune('_')
			}
		}
		b.WriteRune(unicode.ToLower(r))
		prevUpper = isUpper
	}
	return b.String()
}
// commonInitialisms lists segment words rendered fully uppercase when they
// appear in a snake_case identifier (Go initialism convention). Hoisted to
// package level so the map is built once, not on every capitalize call.
var commonInitialisms = map[string]bool{
	"ID":    true,
	"UUID":  true,
	"GUID":  true,
	"URL":   true,
	"URI":   true,
	"HTTP":  true,
	"HTTPS": true,
	"API":   true,
	"JSON":  true,
	"XML":   true,
	"SQL":   true,
	"HTML":  true,
	"CSS":   true,
	"RID":   true,
}

// capitalize upper-cases the first letter of s, or returns the whole word
// uppercased when it is a well-known initialism (id → ID, http → HTTP).
func capitalize(s string) string {
	if s == "" {
		return ""
	}
	upper := strings.ToUpper(s)
	if commonInitialisms[upper] {
		return upper
	}
	// Not an initialism: capitalize only the first rune.
	runes := []rune(s)
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
// Pluralize converts a singular English word to its plural form using a
// small table of irregular words plus common suffix heuristics. Words the
// heuristics misjudge (e.g. "lens") pass through with a best-effort result.
func Pluralize(s string) string {
	if s == "" {
		return ""
	}
	// Irregular forms that no suffix rule can derive.
	irregular := map[string]string{
		"person":   "people",
		"child":    "children",
		"tooth":    "teeth",
		"foot":     "feet",
		"man":      "men",
		"woman":    "women",
		"mouse":    "mice",
		"goose":    "geese",
		"ox":       "oxen",
		"datum":    "data",
		"medium":   "media",
		"analysis": "analyses",
		"crisis":   "crises",
		"status":   "statuses",
	}
	if plural, ok := irregular[strings.ToLower(s)]; ok {
		return plural
	}
	// Heuristic: a trailing 's' (but not 'ss' or 'us') suggests the word
	// is already plural; return it unchanged.
	if strings.HasSuffix(s, "s") && !strings.HasSuffix(s, "ss") && !strings.HasSuffix(s, "us") {
		return s
	}
	// Sibilant endings take "es".
	switch {
	case strings.HasSuffix(s, "s"), strings.HasSuffix(s, "x"), strings.HasSuffix(s, "z"),
		strings.HasSuffix(s, "ch"), strings.HasSuffix(s, "sh"):
		return s + "es"
	}
	// Consonant + y → ies (city → cities); vowel + y falls through.
	if strings.HasSuffix(s, "y") && len(s) >= 2 && !isVowel(s[len(s)-2]) {
		return s[:len(s)-1] + "ies"
	}
	// f / fe endings → ves (leaf → leaves, knife → knives).
	if strings.HasSuffix(s, "f") {
		return s[:len(s)-1] + "ves"
	}
	if strings.HasSuffix(s, "fe") {
		return s[:len(s)-2] + "ves"
	}
	// Consonant + o → es (hero → heroes); vowel + o falls through.
	if strings.HasSuffix(s, "o") && len(s) >= 2 && !isVowel(s[len(s)-2]) {
		return s + "es"
	}
	// Default rule.
	return s + "s"
}
// Singularize converts a plural English word to its singular form using a
// small table of irregular words plus common suffix heuristics. Words ending
// in "ves" always map back to "f" (so "knives" → "knif" — a known limit of
// the heuristic).
func Singularize(s string) string {
	if s == "" {
		return ""
	}
	// Irregular forms (inverse of the Pluralize table).
	irregular := map[string]string{
		"people":   "person",
		"children": "child",
		"teeth":    "tooth",
		"feet":     "foot",
		"men":      "man",
		"women":    "woman",
		"mice":     "mouse",
		"geese":    "goose",
		"oxen":     "ox",
		"data":     "datum",
		"media":    "medium",
		"analyses": "analysis",
		"crises":   "crisis",
		"statuses": "status",
	}
	if singular, ok := irregular[strings.ToLower(s)]; ok {
		return singular
	}
	switch {
	case strings.HasSuffix(s, "ies") && len(s) > 3:
		// cities → city
		return s[:len(s)-3] + "y"
	case strings.HasSuffix(s, "ves"):
		// leaves → leaf
		return s[:len(s)-3] + "f"
	case strings.HasSuffix(s, "ses"), strings.HasSuffix(s, "xes"), strings.HasSuffix(s, "zes"),
		strings.HasSuffix(s, "ches"), strings.HasSuffix(s, "shes"):
		// boxes → box, matches → match: strip the trailing "es".
		return s[:len(s)-2]
	case strings.HasSuffix(s, "s") && !strings.HasSuffix(s, "ss"):
		// users → user; "ss" endings (glass) are left alone.
		return s[:len(s)-1]
	}
	// Already singular.
	return s
}
// GeneratePrefix derives a three-letter uppercase prefix from a table name.
// It strips a leading "tbl_"/"tb_", takes the first letter of each
// underscore-separated part, tops up with further letters from the first
// part when short, and pads with 'X' to exactly three characters.
// Examples: process → PRO, user → USE, master_task → MTA, "" → TBL.
// (The previous examples "mastertask → MTL" and "user → USR" did not match
// the implemented behavior.)
// NOTE(review): the first *byte* of each part is upper-cased, so non-ASCII
// table names are not handled correctly — confirm names are ASCII.
func GeneratePrefix(tableName string) string {
	if tableName == "" {
		return "TBL"
	}
	// Remove common table-name prefixes so they don't dominate the result.
	tableName = strings.TrimPrefix(tableName, "tbl_")
	tableName = strings.TrimPrefix(tableName, "tb_")
	// First letter of each underscore-separated part, capped at three.
	parts := strings.Split(tableName, "_")
	var prefix strings.Builder
	for _, part := range parts {
		if part == "" {
			continue
		}
		prefix.WriteRune(unicode.ToUpper(rune(part[0])))
		if prefix.Len() >= 3 {
			break
		}
	}
	result := prefix.String()
	// Still short: continue with letters from the first part.
	if len(result) < 3 && len(parts) > 0 {
		firstPart := parts[0]
		for i := 1; i < len(firstPart) && len(result) < 3; i++ {
			result += strings.ToUpper(string(firstPart[i]))
		}
	}
	// Pad so the prefix is always exactly three characters.
	for len(result) < 3 {
		result += "X"
	}
	return result[:3]
}
// isVowel reports whether c is an ASCII vowel, case-insensitively.
func isVowel(c byte) bool {
	switch unicode.ToLower(rune(c)) {
	case 'a', 'e', 'i', 'o', 'u':
		return true
	default:
		return false
	}
}

View File

@@ -0,0 +1,250 @@
package bun
import (
"sort"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// TemplateData represents the data passed to the template for code generation.
type TemplateData struct {
	PackageName string        // Go package name emitted in the generated file header
	Imports     []string      // import paths, deduplicated via AddImport, sorted by FinalizeImports
	Models      []*ModelData  // one entry per generated struct
	Config      *MethodConfig // shared switch set for which helper methods to emit
}
// ModelData represents a single model/struct in the template.
type ModelData struct {
	Name            string // Go struct name (PascalCase, "Model"-prefixed)
	TableName       string // schema.table format
	SchemaName      string
	TableNameOnly   string // just table name without schema
	Comment         string // combined description/comment for the struct doc
	Fields          []*FieldData
	Config          *MethodConfig // set by TemplateData.AddModel
	PrimaryKeyField string        // Name of the primary key field (Go name)
	IDColumnName    string        // Name of the ID column in database
	Prefix          string        // 3-letter prefix derived by GeneratePrefix
}
// FieldData represents a single field in a struct.
// NOTE(review): the model template renders {{.BunTag}} for each field, but
// this struct only declares GormTag (populated by columnToField) — confirm
// which tag field is intended; as written, template execution would fail
// with a missing-field error.
type FieldData struct {
	Name    string // Go field name (PascalCase)
	Type    string // Go type
	GormTag string // Complete gorm tag
	JSONTag string // JSON tag (column name)
	Comment string // Field comment
}
// MethodConfig controls which helper methods to generate. Each flag gates
// one method block in modelTemplate; see DefaultMethodConfig for defaults
// and LoadMethodConfigFromMetadata for the metadata keys that set them.
type MethodConfig struct {
	GenerateTableName     bool
	GenerateSchemaName    bool
	GenerateTableNameOnly bool
	GenerateGetID         bool
	GenerateGetIDStr      bool
	GenerateSetID         bool
	GenerateUpdateID      bool
	GenerateGetIDName     bool
	GenerateGetPrefix     bool
}
// DefaultMethodConfig returns a MethodConfig with every generator enabled.
func DefaultMethodConfig() *MethodConfig {
	cfg := new(MethodConfig)
	cfg.GenerateTableName = true
	cfg.GenerateSchemaName = true
	cfg.GenerateTableNameOnly = true
	cfg.GenerateGetID = true
	cfg.GenerateGetIDStr = true
	cfg.GenerateSetID = true
	cfg.GenerateUpdateID = true
	cfg.GenerateGetIDName = true
	cfg.GenerateGetPrefix = true
	return cfg
}
// NewTemplateData creates a TemplateData for the given package; a nil config
// falls back to DefaultMethodConfig.
func NewTemplateData(packageName string, config *MethodConfig) *TemplateData {
	td := &TemplateData{
		PackageName: packageName,
		Imports:     make([]string, 0),
		Models:      make([]*ModelData, 0),
		Config:      config,
	}
	if td.Config == nil {
		td.Config = DefaultMethodConfig()
	}
	return td
}
// AddModel appends model to the template data, stamping it with the shared
// method-generation config so the template can consult it per model.
func (td *TemplateData) AddModel(model *ModelData) {
	model.Config = td.Config
	td.Models = append(td.Models, model)
}
// AddImport records an import path once; duplicates are silently ignored.
// Linear scan is fine for the handful of imports a generated file carries.
func (td *TemplateData) AddImport(importPath string) {
	for i := range td.Imports {
		if td.Imports[i] == importPath {
			return
		}
	}
	td.Imports = append(td.Imports, importPath)
}
// FinalizeImports sorts the collected imports alphabetically so generated
// output is deterministic.
func (td *TemplateData) FinalizeImports() {
	sort.Strings(td.Imports)
}
// NewModelData creates a ModelData from a models.Table: the struct name is
// the singularized, PascalCase table name with a "Model" prefix, the primary
// key is located, and each column is converted to a field in sorted order.
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
	// Qualified name "schema.table" when a schema is given.
	tableName := table.Name
	if schema != "" {
		tableName = schema + "." + table.Name
	}
	// Generate model name: singularize and convert to PascalCase.
	singularTable := Singularize(table.Name)
	modelName := SnakeCaseToPascalCase(singularTable)
	// Add "Model" prefix if not already present.
	if !hasModelPrefix(modelName) {
		modelName = "Model" + modelName
	}
	model := &ModelData{
		Name:          modelName,
		TableName:     tableName,
		SchemaName:    schema,
		TableNameOnly: table.Name,
		Comment:       formatComment(table.Description, table.Comment),
		Fields:        make([]*FieldData, 0),
		Prefix:        GeneratePrefix(table.Name),
	}
	// Record the first primary-key column found.
	// NOTE(review): composite primary keys keep only the first column here.
	for _, col := range table.Columns {
		if col.IsPrimaryKey {
			model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
			model.IDColumnName = col.Name
			break
		}
	}
	// Convert columns to fields in a deterministic order (sequence, PK, name).
	columns := sortColumns(table.Columns)
	for _, col := range columns {
		field := columnToField(col, table, typeMapper)
		model.Fields = append(model.Fields, field)
	}
	return model
}
// columnToField converts a models.Column to FieldData, mapping the SQL type
// to a Go type and building the ORM struct tag.
// NOTE(review): this calls typeMapper.BuildGormTag and stores the result in
// FieldData.GormTag, while the TypeMapper in this package defines
// BuildBunTag and the template reads {{.BunTag}} — confirm these three
// names are meant to line up.
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
	fieldName := SnakeCaseToPascalCase(col.Name)
	goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
	gormTag := typeMapper.BuildGormTag(col, table)
	jsonTag := col.Name // Use column name for JSON tag
	return &FieldData{
		Name:    fieldName,
		Type:    goType,
		GormTag: gormTag,
		JSONTag: jsonTag,
		Comment: formatComment(col.Description, col.Comment),
	}
}
// AddRelationshipField appends a relationship (rel:...) field to the model,
// after the column fields added by NewModelData.
func (md *ModelData) AddRelationshipField(field *FieldData) {
	md.Fields = append(md.Fields, field)
}
// formatComment merges a description and a comment into a single string,
// joining them with " - " when both are present.
func formatComment(description, comment string) string {
	switch {
	case description == "":
		return comment
	case comment == "":
		return description
	default:
		return description + " - " + comment
	}
}
// hasModelPrefix reports whether name already starts with "Model".
func hasModelPrefix(name string) bool {
	const prefix = "Model"
	return len(name) >= len(prefix) && name[:len(prefix)] == prefix
}
// sortColumns flattens a column map into a slice with a deterministic order:
// by Sequence when both columns carry one, primary keys first otherwise,
// then alphabetically by name. Needed because Go map iteration is random.
func sortColumns(columns map[string]*models.Column) []*models.Column {
	result := make([]*models.Column, 0, len(columns))
	for _, col := range columns {
		result = append(result, col)
	}
	sort.Slice(result, func(i, j int) bool {
		// Sort by sequence if both have it.
		if result[i].Sequence > 0 && result[j].Sequence > 0 {
			return result[i].Sequence < result[j].Sequence
		}
		// Put primary keys first.
		if result[i].IsPrimaryKey != result[j].IsPrimaryKey {
			return result[i].IsPrimaryKey
		}
		// Otherwise sort alphabetically.
		return result[i].Name < result[j].Name
	})
	return result
}
// LoadMethodConfigFromMetadata loads method configuration from a metadata
// map. Missing keys keep their defaults; values that are present but not
// bool are ignored. A nil map yields DefaultMethodConfig unchanged.
func LoadMethodConfigFromMetadata(metadata map[string]interface{}) *MethodConfig {
	config := DefaultMethodConfig()
	if metadata == nil {
		return config
	}
	// Table-driven override: one entry per flag replaces nine copy-pasted
	// type assertions. Map iteration order is irrelevant — keys are disjoint.
	flags := map[string]*bool{
		"generate_table_name":      &config.GenerateTableName,
		"generate_schema_name":     &config.GenerateSchemaName,
		"generate_table_name_only": &config.GenerateTableNameOnly,
		"generate_get_id":          &config.GenerateGetID,
		"generate_get_id_str":      &config.GenerateGetIDStr,
		"generate_set_id":          &config.GenerateSetID,
		"generate_update_id":       &config.GenerateUpdateID,
		"generate_get_id_name":     &config.GenerateGetIDName,
		"generate_get_prefix":      &config.GenerateGetPrefix,
	}
	for key, target := range flags {
		if val, ok := metadata[key].(bool); ok {
			*target = val
		}
	}
	return config
}

View File

@@ -0,0 +1,118 @@
package bun
import (
"bytes"
"text/template"
)
// modelTemplate defines the text/template for generating Bun models.
// NOTE(review): the field loop renders {{.BunTag}}, but FieldData declares
// GormTag (populated via BuildGormTag) — as written, Execute fails with a
// missing-field error; confirm the intended tag field name.
// NOTE(review): {{.PrimaryKeyIsSQL}} is not a field of ModelData — verify it
// exists elsewhere or the GetID/UpdateID blocks will fail at execution time.
// NOTE(review): generated SetID uses a value receiver and delegates to
// UpdateID, so it mutates a copy and the caller never sees the new ID —
// consider generating a pointer receiver like UpdateID's.
const modelTemplate = `// Code generated by relspecgo. DO NOT EDIT.
package {{.PackageName}}
{{if .Imports -}}
import (
{{range .Imports -}}
{{.}}
{{end -}}
)
{{end}}
{{range .Models}}
{{if .Comment}}// {{.Comment}}{{end}}
type {{.Name}} struct {
bun.BaseModel ` + "`bun:\"table:{{.TableName}},alias:{{.TableNameOnly}}\"`" + `
{{- range .Fields}}
{{.Name}} {{.Type}} ` + "`bun:\"{{.BunTag}}\" json:\"{{.JSONTag}}\"`" + `{{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}
{{if .Config.GenerateTableName}}
// TableName returns the table name for {{.Name}}
func (m {{.Name}}) TableName() string {
return "{{.TableName}}"
}
{{end}}
{{if .Config.GenerateTableNameOnly}}
// TableNameOnly returns the table name without schema for {{.Name}}
func (m {{.Name}}) TableNameOnly() string {
return "{{.TableNameOnly}}"
}
{{end}}
{{if .Config.GenerateSchemaName}}
// SchemaName returns the schema name for {{.Name}}
func (m {{.Name}}) SchemaName() string {
return "{{.SchemaName}}"
}
{{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value
func (m {{.Name}}) GetID() int64 {
{{if .PrimaryKeyIsSQL -}}
return m.{{.PrimaryKeyField}}.Int64()
{{- else -}}
return int64(m.{{.PrimaryKeyField}})
{{- end}}
}
{{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string {
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
}
{{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) {
m.UpdateID(newid)
}
{{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) {
{{if .PrimaryKeyIsSQL -}}
m.{{.PrimaryKeyField}}.FromString(fmt.Sprintf("%d", newid))
{{- else -}}
m.{{.PrimaryKeyField}} = int32(newid)
{{- end}}
}
{{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}}
// GetIDName returns the name of the primary key column
func (m {{.Name}}) GetIDName() string {
return "{{.IDColumnName}}"
}
{{end}}
{{if .Config.GenerateGetPrefix}}
// GetPrefix returns the table prefix
func (m {{.Name}}) GetPrefix() string {
return "{{.Prefix}}"
}
{{end}}
{{end -}}
`
// Templates holds the parsed code-generation templates.
type Templates struct {
	modelTmpl *template.Template // parsed modelTemplate
}
// NewTemplates parses modelTemplate and returns the ready-to-use set; the
// error is non-nil only when the template source fails to parse.
func NewTemplates() (*Templates, error) {
	tmpl, err := template.New("model").Parse(modelTemplate)
	if err != nil {
		return nil, err
	}
	return &Templates{modelTmpl: tmpl}, nil
}
// GenerateCode executes the model template with data and returns the
// generated Go source as a string.
func (t *Templates) GenerateCode(data *TemplateData) (string, error) {
	var buf bytes.Buffer
	if err := t.modelTmpl.Execute(&buf, data); err != nil {
		return "", err
	}
	return buf.String(), nil
}

View File

@@ -0,0 +1,253 @@
package bun
import (
"fmt"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// TypeMapper handles type conversions between SQL and Go types for Bun.
type TypeMapper struct {
	// sqlTypesAlias is the package alias prepended to generated wrapper
	// type names (e.g. "resolvespec_common.SqlString").
	sqlTypesAlias string
}
// NewTypeMapper creates a TypeMapper that emits wrapper types under the
// default "resolvespec_common" package alias.
func NewTypeMapper() *TypeMapper {
	tm := new(TypeMapper)
	tm.sqlTypesAlias = "resolvespec_common"
	return tm
}
// SQLTypeToGoType converts a SQL type to its Go equivalent. NOT NULL columns
// of simple types (integers, booleans) map to plain Go types; everything
// else maps to the nullable wrapper types from the sql-types package.
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
	// Normalize SQL type (lowercase, remove length/precision)
	baseType := tm.extractBaseType(sqlType)
	// For Bun, we typically use resolvespec_common types for most fields
	// unless they're explicitly NOT NULL and we want to avoid null handling
	if notNull && tm.isSimpleType(baseType) {
		return tm.baseGoType(baseType)
	}
	// Use resolvespec_common types for nullable fields
	return tm.bunGoType(baseType)
}
// extractBaseType normalizes a SQL type string: lower-cased, trimmed, with
// any "(length/precision)" suffix removed. A '(' at index 0 is kept as-is
// (matching the original idx > 0 behavior).
func (tm *TypeMapper) extractBaseType(sqlType string) string {
	normalized := strings.ToLower(strings.TrimSpace(sqlType))
	idx := strings.Index(normalized, "(")
	if idx <= 0 {
		return normalized
	}
	return normalized[:idx]
}
// isSimpleType reports whether a NOT NULL column of this SQL type should be
// mapped to a plain Go type instead of a nullable wrapper.
func (tm *TypeMapper) isSimpleType(sqlType string) bool {
	// A switch avoids rebuilding a lookup map on every call.
	switch sqlType {
	case "bigint", "integer", "int8", "int4", "boolean", "bool":
		return true
	}
	return false
}
// notNullGoTypes maps simple SQL types to plain Go types used for NOT NULL
// columns. Hoisted to package level so the map is built once, not per call.
var notNullGoTypes = map[string]string{
	"integer":   "int32",
	"int":       "int32",
	"int4":      "int32",
	"smallint":  "int16",
	"int2":      "int16",
	"bigint":    "int64",
	"int8":      "int64",
	"serial":    "int32",
	"bigserial": "int64",
	"boolean":   "bool",
	"bool":      "bool",
}

// baseGoType returns the plain Go type for a simple NOT NULL SQL type,
// falling back to the nullable wrapper type for anything unrecognized.
func (tm *TypeMapper) baseGoType(sqlType string) string {
	if goType, ok := notNullGoTypes[sqlType]; ok {
		return goType
	}
	// Default to resolvespec type
	return tm.bunGoType(sqlType)
}
// bunTypeSuffixes maps SQL types to the wrapper type name inside the
// sql-types package; the emitted Go type is "<alias>.<suffix>". Hoisted to
// package level so the map (previously rebuilt per call with the alias
// baked in) is constructed exactly once.
var bunTypeSuffixes = map[string]string{
	// Integer types
	"integer": "SqlInt32", "int": "SqlInt32", "int4": "SqlInt32",
	"smallint": "SqlInt16", "int2": "SqlInt16",
	"bigint": "SqlInt64", "int8": "SqlInt64",
	"serial": "SqlInt32", "bigserial": "SqlInt64", "smallserial": "SqlInt16",
	// String types
	"text": "SqlString", "varchar": "SqlString", "char": "SqlString",
	"character": "SqlString", "citext": "SqlString", "bpchar": "SqlString",
	// Boolean
	"boolean": "SqlBool", "bool": "SqlBool",
	// Float types
	"real": "SqlFloat32", "float4": "SqlFloat32",
	"double precision": "SqlFloat64", "float8": "SqlFloat64",
	"numeric": "SqlFloat64", "decimal": "SqlFloat64",
	// Date/Time types
	"timestamp": "SqlTime", "timestamp without time zone": "SqlTime",
	"timestamp with time zone": "SqlTime", "timestamptz": "SqlTime",
	"date": "SqlDate", "time": "SqlTime",
	"time without time zone": "SqlTime", "time with time zone": "SqlTime",
	"timetz": "SqlTime",
	// UUID / JSON
	"uuid": "SqlUUID", "json": "SqlJSON", "jsonb": "SqlJSONB",
	// Network
	"inet": "SqlString", "cidr": "SqlString", "macaddr": "SqlString",
	// Other
	"money": "SqlFloat64",
}

// bunGoType returns the nullable wrapper type for a SQL type. "bytea" maps
// to a raw []byte; unknown types fall back to the string wrapper.
func (tm *TypeMapper) bunGoType(sqlType string) string {
	// Binary is the one type that is not an aliased wrapper.
	if sqlType == "bytea" {
		return "[]byte"
	}
	if suffix, ok := bunTypeSuffixes[sqlType]; ok {
		return tm.sqlTypesAlias + "." + suffix
	}
	// Default to SqlString for unknown types
	return tm.sqlTypesAlias + ".SqlString"
}
// BuildBunTag generates a complete Bun tag string for a column.
// Bun format: bun:"column_name,type:type_name,pk,default:value"
// NOTE(review): string defaults are emitted unquoted via %v — confirm Bun
// accepts them; also the trailing comma is kept deliberately (stated Bun
// convention here) — verify against the Bun struct-tag docs.
func (tm *TypeMapper) BuildBunTag(column *models.Column, table *models.Table) string {
	var parts []string
	// Column name comes first (no prefix)
	parts = append(parts, column.Name)
	// Add type if specified; Length wins over Precision/Scale when both set.
	if column.Type != "" {
		typeStr := column.Type
		if column.Length > 0 {
			typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
		} else if column.Precision > 0 {
			if column.Scale > 0 {
				typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
			} else {
				typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Precision)
			}
		}
		parts = append(parts, fmt.Sprintf("type:%s", typeStr))
	}
	// Primary key
	if column.IsPrimaryKey {
		parts = append(parts, "pk")
	}
	// Default value
	if column.Default != nil {
		parts = append(parts, fmt.Sprintf("default:%v", column.Default))
	}
	// Nullable (Bun uses nullzero for nullable fields); primary keys are
	// never marked nullzero.
	if !column.NotNull && !column.IsPrimaryKey {
		parts = append(parts, "nullzero")
	}
	// Mark the column unique when any unique constraint covers it.
	if table != nil {
		for _, constraint := range table.Constraints {
			if constraint.Type == models.UniqueConstraint {
				for _, col := range constraint.Columns {
					if col == column.Name {
						parts = append(parts, "unique")
						break
					}
				}
			}
		}
	}
	// Join with commas and add trailing comma (Bun convention)
	return strings.Join(parts, ",") + ","
}
// BuildRelationshipTag generates a Bun tag for relationship fields.
// Bun format: bun:"rel:has-one,join:local_column=foreign_column"
// Only the first column pair is used — NOTE(review): composite foreign keys
// are therefore truncated; confirm they are out of scope.
func (tm *TypeMapper) BuildRelationshipTag(constraint *models.Constraint, relType string) string {
	var parts []string
	// Add relationship type (e.g. "has-one", "belongs-to").
	parts = append(parts, fmt.Sprintf("rel:%s", relType))
	// Add join clause
	if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
		localCol := constraint.Columns[0]
		foreignCol := constraint.ReferencedColumns[0]
		parts = append(parts, fmt.Sprintf("join:%s=%s", localCol, foreignCol))
	}
	return strings.Join(parts, ",")
}
// NeedsTimeImport reports whether the generated Go type mentions time.Time
// and therefore requires the time package import.
func (tm *TypeMapper) NeedsTimeImport(goType string) bool {
	return strings.Contains(goType, "time.Time")
}
// NeedsFmtImport reports whether the fmt import is required; generated code
// uses fmt only in the GetIDStr method, so this simply echoes the flag.
func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
	return generateGetIDStr
}
// GetSQLTypesImport returns the import path for the sql-types wrapper
// package (ResolveSpec common) referenced by generated field types.
func (tm *TypeMapper) GetSQLTypesImport() string {
	return "github.com/bitechdev/ResolveSpec/pkg/common"
}
// GetBunImport returns the import path for the Bun ORM package used by the
// generated bun.BaseModel embed.
func (tm *TypeMapper) GetBunImport() string {
	return "github.com/uptrace/bun"
}

224
pkg/writers/dbml/writer.go Normal file
View File

@@ -0,0 +1,224 @@
package dbml
import (
	"fmt"
	"os"
	"sort"
	"strings"

	"git.warky.dev/wdevs/relspecgo/pkg/models"
	"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for DBML format.
type Writer struct {
	options *writers.WriterOptions // OutputPath selects file vs stdout output
}
// NewWriter creates a DBML writer configured with the given options.
func NewWriter(options *writers.WriterOptions) *Writer {
	w := new(Writer)
	w.options = options
	return w
}
// WriteDatabase renders db as DBML and writes it to the configured output
// path, or to stdout when no path is set.
func (w *Writer) WriteDatabase(db *models.Database) error {
	content := w.databaseToDBML(db)
	if w.options.OutputPath == "" {
		fmt.Print(content)
		return nil
	}
	return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
}
// WriteSchema renders schema as DBML and writes it to the configured output
// path, or to stdout when no path is set.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	content := w.schemaToDBML(schema)
	if w.options.OutputPath == "" {
		fmt.Print(content)
		return nil
	}
	return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
}
// WriteTable renders a single table as DBML (qualified by its own Schema
// field) and writes it to the configured output path, or to stdout.
func (w *Writer) WriteTable(table *models.Table) error {
	content := w.tableToDBML(table, table.Schema)
	if w.options.OutputPath == "" {
		fmt.Print(content)
		return nil
	}
	return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
}
// databaseToDBML converts a Database to a DBML document: optional header
// comments, each schema's tables, then all foreign-key relationships.
// Uses strings.Builder instead of the original quadratic string "+="
// concatenation.
// NOTE(review): iteration order over d.Schemas / schema.Tables follows the
// container types in pkg/models — if those are maps, output order is
// nondeterministic here too; confirm.
func (w *Writer) databaseToDBML(d *models.Database) string {
	var b strings.Builder
	// Database-level comments, if any.
	if d.Description != "" {
		fmt.Fprintf(&b, "// %s\n", d.Description)
	}
	if d.Comment != "" {
		fmt.Fprintf(&b, "// %s\n", d.Comment)
	}
	if d.Description != "" || d.Comment != "" {
		b.WriteString("\n")
	}
	// Emit each schema's tables.
	for _, schema := range d.Schemas {
		b.WriteString(w.schemaToDBML(schema))
	}
	// Emit foreign-key relationships after all tables.
	b.WriteString("\n// Relationships\n")
	for _, schema := range d.Schemas {
		for _, table := range schema.Tables {
			for _, constraint := range table.Constraints {
				if constraint.Type == models.ForeignKeyConstraint {
					b.WriteString(w.constraintToDBML(constraint, schema.Name, table.Name))
				}
			}
		}
	}
	return b.String()
}
// schemaToDBML converts a Schema and its tables to DBML text, using a
// strings.Builder instead of quadratic string concatenation.
func (w *Writer) schemaToDBML(schema *models.Schema) string {
	var b strings.Builder
	if schema.Description != "" {
		fmt.Fprintf(&b, "// Schema: %s - %s\n", schema.Name, schema.Description)
	}
	// One Table block per table, blank-line separated.
	for _, table := range schema.Tables {
		b.WriteString(w.tableToDBML(table, schema.Name))
		b.WriteString("\n")
	}
	return b.String()
}
// tableToDBML converts a Table to a DBML "Table" block.
// Fix: t.Columns is a map (see the bun writer's sortColumns signature), so
// ranging over it directly produced a random column order on every run;
// columns are now emitted in a stable order (sequence, primary keys, name).
// Also uses strings.Builder instead of quadratic string concatenation.
// NOTE(review): t.Indexes is still ranged directly — if it is also a map,
// index order remains nondeterministic; confirm its type.
func (w *Writer) tableToDBML(t *models.Table, schemaName string) string {
	var b strings.Builder
	fmt.Fprintf(&b, "Table %s.%s {\n", schemaName, t.Name)
	// Columns with their bracketed attribute list.
	for _, column := range orderedColumns(t) {
		fmt.Fprintf(&b, "  %s %s", column.Name, column.Type)
		attrs := make([]string, 0)
		if column.IsPrimaryKey {
			attrs = append(attrs, "primary key")
		}
		if column.NotNull && !column.IsPrimaryKey {
			attrs = append(attrs, "not null")
		}
		if column.AutoIncrement {
			attrs = append(attrs, "increment")
		}
		if column.Default != nil {
			attrs = append(attrs, fmt.Sprintf("default: %v", column.Default))
		}
		if len(attrs) > 0 {
			fmt.Fprintf(&b, " [%s]", strings.Join(attrs, ", "))
		}
		if column.Comment != "" {
			fmt.Fprintf(&b, " // %s", column.Comment)
		}
		b.WriteString("\n")
	}
	// Optional indexes block.
	if len(t.Indexes) > 0 {
		b.WriteString("\n  indexes {\n")
		for _, index := range t.Indexes {
			indexAttrs := make([]string, 0)
			if index.Unique {
				indexAttrs = append(indexAttrs, "unique")
			}
			if index.Name != "" {
				indexAttrs = append(indexAttrs, fmt.Sprintf("name: '%s'", index.Name))
			}
			if index.Type != "" {
				indexAttrs = append(indexAttrs, fmt.Sprintf("type: %s", index.Type))
			}
			fmt.Fprintf(&b, "    (%s)", strings.Join(index.Columns, ", "))
			if len(indexAttrs) > 0 {
				fmt.Fprintf(&b, " [%s]", strings.Join(indexAttrs, ", "))
			}
			b.WriteString("\n")
		}
		b.WriteString("  }\n")
	}
	// Table note from description and/or comment.
	if t.Description != "" || t.Comment != "" {
		note := t.Description
		if note != "" && t.Comment != "" {
			note += " - "
		}
		note += t.Comment
		fmt.Fprintf(&b, "\n  Note: '%s'\n", note)
	}
	b.WriteString("}\n")
	return b.String()
}

// orderedColumns flattens t.Columns into a slice with a stable order: by
// Sequence when both columns carry one, primary keys first otherwise, then
// alphabetically by name (mirrors the bun writer's sortColumns convention).
func orderedColumns(t *models.Table) []*models.Column {
	cols := make([]*models.Column, 0, len(t.Columns))
	for _, c := range t.Columns {
		cols = append(cols, c)
	}
	sort.Slice(cols, func(i, j int) bool {
		if cols[i].Sequence > 0 && cols[j].Sequence > 0 {
			return cols[i].Sequence < cols[j].Sequence
		}
		if cols[i].IsPrimaryKey != cols[j].IsPrimaryKey {
			return cols[i].IsPrimaryKey
		}
		return cols[i].Name < cols[j].Name
	})
	return cols
}
// constraintToDBML converts a foreign-key Constraint to a DBML "Ref" line;
// non-FK constraints (and FKs without a referenced table) yield "".
func (w *Writer) constraintToDBML(c *models.Constraint, schemaName, tableName string) string {
	if c.Type != models.ForeignKeyConstraint || c.ReferencedTable == "" {
		return ""
	}
	fromTable := fmt.Sprintf("%s.%s", schemaName, tableName)
	toTable := fmt.Sprintf("%s.%s", c.ReferencedSchema, c.ReferencedTable)
	// Determine relationship cardinality.
	// For foreign keys, it's typically many-to-one, which DBML spells ">".
	relationship := ">"
	fromCols := strings.Join(c.Columns, ", ")
	toCols := strings.Join(c.ReferencedColumns, ", ")
	result := fmt.Sprintf("Ref: %s.(%s) %s %s.(%s)", fromTable, fromCols, relationship, toTable, toCols)
	// Append ON DELETE / ON UPDATE actions when present.
	actions := make([]string, 0)
	if c.OnDelete != "" {
		actions = append(actions, fmt.Sprintf("ondelete: %s", c.OnDelete))
	}
	if c.OnUpdate != "" {
		actions = append(actions, fmt.Sprintf("onupdate: %s", c.OnUpdate))
	}
	if len(actions) > 0 {
		result += fmt.Sprintf(" [%s]", strings.Join(actions, ", "))
	}
	result += "\n"
	return result
}

View File

@@ -0,0 +1,36 @@
package dctx
import (
"fmt"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for DCTX format.
// Note: DCTX is a read-only format used for loading Clarion dictionary
// files, so every Write* method returns an error.
type Writer struct {
	options *writers.WriterOptions // retained for interface symmetry; unused
}
// NewWriter creates a DCTX writer; all of its Write methods refuse to write
// because DCTX is a load-only format.
func NewWriter(options *writers.WriterOptions) *Writer {
	w := new(Writer)
	w.options = options
	return w
}
// WriteDatabase returns an error: DCTX is a read-only (load-only) format.
// NOTE(review): all three methods build the same constant message with
// fmt.Errorf and no format verbs — a shared errors.New sentinel (e.g.
// ErrReadOnlyFormat) would let callers test with errors.Is.
func (w *Writer) WriteDatabase(db *models.Database) error {
	return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
}

// WriteSchema returns an error: DCTX is a read-only (load-only) format.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
}

// WriteTable returns an error: DCTX is a read-only (load-only) format.
func (w *Writer) WriteTable(table *models.Table) error {
	return fmt.Errorf("DCTX format is read-only and does not support writing - it is used for loading Clarion dictionary files only")
}

View File

@@ -0,0 +1,77 @@
package drawdb
// DrawDBSchema represents the complete DrawDB JSON structure
type DrawDBSchema struct {
Tables []*DrawDBTable `json:"tables" yaml:"tables" xml:"tables"`
Relationships []*DrawDBRelationship `json:"relationships" yaml:"relationships" xml:"relationships"`
Notes []*DrawDBNote `json:"notes,omitempty" yaml:"notes,omitempty" xml:"notes,omitempty"`
SubjectAreas []*DrawDBArea `json:"subjectAreas,omitempty" yaml:"subjectAreas,omitempty" xml:"subjectAreas,omitempty"`
}
// DrawDBTable represents a table in DrawDB format
type DrawDBTable struct {
ID int `json:"id" yaml:"id" xml:"id"`
Name string `json:"name" yaml:"name" xml:"name"`
Schema string `json:"schema,omitempty" yaml:"schema,omitempty" xml:"schema,omitempty"`
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
Color string `json:"color" yaml:"color" xml:"color"`
X int `json:"x" yaml:"x" xml:"x"`
Y int `json:"y" yaml:"y" xml:"y"`
Fields []*DrawDBField `json:"fields" yaml:"fields" xml:"fields"`
Indexes []*DrawDBIndex `json:"indexes,omitempty" yaml:"indexes,omitempty" xml:"indexes,omitempty"`
}
// DrawDBField represents a column/field in DrawDB format
type DrawDBField struct {
ID int `json:"id" yaml:"id" xml:"id"`
Name string `json:"name" yaml:"name" xml:"name"`
Type string `json:"type" yaml:"type" xml:"type"`
Default string `json:"default,omitempty" yaml:"default,omitempty" xml:"default,omitempty"`
Check string `json:"check,omitempty" yaml:"check,omitempty" xml:"check,omitempty"`
Primary bool `json:"primary" yaml:"primary" xml:"primary"`
Unique bool `json:"unique" yaml:"unique" xml:"unique"`
NotNull bool `json:"notNull" yaml:"notNull" xml:"notNull"`
Increment bool `json:"increment" yaml:"increment" xml:"increment"`
Comment string `json:"comment,omitempty" yaml:"comment,omitempty" xml:"comment,omitempty"`
}
// DrawDBIndex represents an index in DrawDB format
type DrawDBIndex struct {
ID int `json:"id" yaml:"id" xml:"id"`
Name string `json:"name" yaml:"name" xml:"name"`
Unique bool `json:"unique" yaml:"unique" xml:"unique"`
Fields []int `json:"fields" yaml:"fields" xml:"fields"` // Field IDs
}
// DrawDBRelationship represents a relationship in DrawDB format
type DrawDBRelationship struct {
ID int `json:"id" yaml:"id" xml:"id"`
Name string `json:"name" yaml:"name" xml:"name"`
StartTableID int `json:"startTableId" yaml:"startTableId" xml:"startTableId"`
EndTableID int `json:"endTableId" yaml:"endTableId" xml:"endTableId"`
StartFieldID int `json:"startFieldId" yaml:"startFieldId" xml:"startFieldId"`
EndFieldID int `json:"endFieldId" yaml:"endFieldId" xml:"endFieldId"`
Cardinality string `json:"cardinality" yaml:"cardinality" xml:"cardinality"` // "One to one", "One to many", "Many to one"
UpdateConstraint string `json:"updateConstraint,omitempty" yaml:"updateConstraint,omitempty" xml:"updateConstraint,omitempty"`
DeleteConstraint string `json:"deleteConstraint,omitempty" yaml:"deleteConstraint,omitempty" xml:"deleteConstraint,omitempty"`
}
// DrawDBNote represents a note in DrawDB format
type DrawDBNote struct {
	// ID is the note identifier.
	ID int `json:"id" yaml:"id" xml:"id"`
	// Content is the free-form note text.
	Content string `json:"content" yaml:"content" xml:"content"`
	// Color is a hex color string, e.g. "#ffd93d".
	Color string `json:"color" yaml:"color" xml:"color"`
	// X and Y are canvas coordinates.
	X int `json:"x" yaml:"x" xml:"x"`
	Y int `json:"y" yaml:"y" xml:"y"`
}
// DrawDBArea represents a subject area/grouping in DrawDB format
type DrawDBArea struct {
	// ID is the area identifier.
	ID int `json:"id" yaml:"id" xml:"id"`
	// Name is the displayed area name (the schema name in this writer).
	Name string `json:"name" yaml:"name" xml:"name"`
	// Color is a hex color string.
	Color string `json:"color" yaml:"color" xml:"color"`
	// X/Y are the top-left canvas coordinates; Width/Height the extent.
	X      int `json:"x" yaml:"x" xml:"x"`
	Y      int `json:"y" yaml:"y" xml:"y"`
	Width  int `json:"width" yaml:"width" xml:"width"`
	Height int `json:"height" yaml:"height" xml:"height"`
}

View File

@@ -0,0 +1,349 @@
package drawdb
import (
"encoding/json"
"fmt"
"os"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for DrawDB JSON format
type Writer struct {
	// options carries the output path and writer metadata.
	options *writers.WriterOptions
}
// NewWriter creates a new DrawDB writer using the supplied writer options.
func NewWriter(options *writers.WriterOptions) *Writer {
	w := &Writer{options: options}
	return w
}
// WriteDatabase converts a Database model to DrawDB form and emits it as JSON.
func (w *Writer) WriteDatabase(db *models.Database) error {
	return w.writeJSON(w.databaseToDrawDB(db))
}
// WriteSchema converts a Schema model to DrawDB form and emits it as JSON.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	return w.writeJSON(w.schemaToDrawDB(schema))
}
// WriteTable converts a single Table model to DrawDB form and emits it as JSON.
func (w *Writer) WriteTable(table *models.Table) error {
	return w.writeJSON(w.tableToDrawDB(table))
}
// writeJSON serializes data as indented JSON and writes it to the configured
// output path, or prints it to stdout when no path is set.
func (w *Writer) writeJSON(data interface{}) error {
	payload, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal to JSON: %w", err)
	}
	if path := w.options.OutputPath; path != "" {
		return os.WriteFile(path, payload, 0644)
	}
	// If no output path, print to stdout
	fmt.Println(string(payload))
	return nil
}
// databaseToDrawDB converts a Database to DrawDB JSON format.
//
// Tables from all schemas are laid out in one shared grid (tablesPerRow per
// row). Schemas with a description/comment additionally get a colored subject
// area sized by their table count. Foreign-key constraints become
// relationships in a second pass, once every table has been assigned an ID.
func (w *Writer) databaseToDrawDB(d *models.Database) *DrawDBSchema {
	schema := &DrawDBSchema{
		Tables:        make([]*DrawDBTable, 0),
		Relationships: make([]*DrawDBRelationship, 0),
		Notes:         make([]*DrawDBNote, 0),
		SubjectAreas:  make([]*DrawDBArea, 0),
	}
	// Track IDs and mappings
	tableID := 0
	fieldID := 0
	relationshipID := 0
	noteID := 0
	areaID := 0
	// Map to track table name to ID
	tableMap := make(map[string]int)
	// Map to track field full path to ID
	// NOTE(review): fieldMap is populated below but never read afterwards —
	// relationships resolve field positions by scanning schema.Tables instead.
	fieldMap := make(map[string]int)
	// Position tables in a grid layout
	gridX, gridY := 50, 50
	colWidth, rowHeight := 300, 200
	tablesPerRow := 4
	tableIndex := 0
	// Create subject areas for schemas
	for schemaIdx, schemaModel := range d.Schemas {
		if schemaModel.Description != "" || schemaModel.Comment != "" {
			// Join description and comment with a newline when both are set.
			note := schemaModel.Description
			if note != "" && schemaModel.Comment != "" {
				note += "\n"
			}
			note += schemaModel.Comment
			// NOTE(review): the area is anchored at the fixed grid origin, so
			// areas for multiple schemas will overlap — confirm intended layout.
			area := &DrawDBArea{
				ID:     areaID,
				Name:   schemaModel.Name,
				Color:  getColorForIndex(schemaIdx),
				X:      gridX - 20,
				Y:      gridY - 20,
				Width:  colWidth*tablesPerRow + 100,
				Height: rowHeight*((len(schemaModel.Tables)/tablesPerRow)+1) + 100,
			}
			schema.SubjectAreas = append(schema.SubjectAreas, area)
			areaID++
		}
		// Process tables in schema
		for _, table := range schemaModel.Tables {
			drawTable, newFieldID := w.convertTableToDrawDB(table, schemaModel.Name, tableID, fieldID, tableIndex, tablesPerRow, gridX, gridY, colWidth, rowHeight, schemaIdx)
			// Store table mapping (keyed "schema.table" for the FK pass below)
			tableKey := fmt.Sprintf("%s.%s", schemaModel.Name, table.Name)
			tableMap[tableKey] = tableID
			// Store field mappings
			for _, field := range drawTable.Fields {
				fieldKey := fmt.Sprintf("%s.%s.%s", schemaModel.Name, table.Name, field.Name)
				fieldMap[fieldKey] = field.ID
			}
			schema.Tables = append(schema.Tables, drawTable)
			fieldID = newFieldID
			tableID++
			tableIndex++
		}
	}
	// Add relationships: one per FK constraint whose two endpoints were both
	// converted above. Only the first column pair of a composite FK is used.
	for _, schemaModel := range d.Schemas {
		for _, table := range schemaModel.Tables {
			for _, constraint := range table.Constraints {
				if constraint.Type == models.ForeignKeyConstraint && constraint.ReferencedTable != "" {
					startTableKey := fmt.Sprintf("%s.%s", schemaModel.Name, table.Name)
					endTableKey := fmt.Sprintf("%s.%s", constraint.ReferencedSchema, constraint.ReferencedTable)
					startTableID, startExists := tableMap[startTableKey]
					endTableID, endExists := tableMap[endTableKey]
					if startExists && endExists && len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
						// Find relative field IDs within their tables
						// (positions inside Fields, not the global field IDs).
						startFieldID := 0
						endFieldID := 0
						for _, t := range schema.Tables {
							if t.ID == startTableID {
								for idx, f := range t.Fields {
									if f.Name == constraint.Columns[0] {
										startFieldID = idx
										break
									}
								}
							}
							if t.ID == endTableID {
								for idx, f := range t.Fields {
									if f.Name == constraint.ReferencedColumns[0] {
										endFieldID = idx
										break
									}
								}
							}
						}
						relationship := &DrawDBRelationship{
							ID:               relationshipID,
							Name:             constraint.Name,
							StartTableID:     startTableID,
							EndTableID:       endTableID,
							StartFieldID:     startFieldID,
							EndFieldID:       endFieldID,
							Cardinality:      "Many to one",
							UpdateConstraint: constraint.OnUpdate,
							DeleteConstraint: constraint.OnDelete,
						}
						schema.Relationships = append(schema.Relationships, relationship)
						relationshipID++
					}
				}
			}
		}
	}
	// Add database description as a note
	if d.Description != "" || d.Comment != "" {
		note := d.Description
		if note != "" && d.Comment != "" {
			note += "\n"
		}
		note += d.Comment
		schema.Notes = append(schema.Notes, &DrawDBNote{
			ID:      noteID,
			Content: fmt.Sprintf("Database: %s\n\n%s", d.Name, note),
			Color:   "#ffd93d",
			X:       10,
			Y:       10,
		})
	}
	return schema
}
// schemaToDrawDB converts a single Schema to DrawDB format.
// Only tables are emitted; relationships, notes and areas stay empty.
func (w *Writer) schemaToDrawDB(schema *models.Schema) *DrawDBSchema {
	out := &DrawDBSchema{
		Tables:        make([]*DrawDBTable, 0),
		Relationships: make([]*DrawDBRelationship, 0),
		Notes:         make([]*DrawDBNote, 0),
		SubjectAreas:  make([]*DrawDBArea, 0),
	}
	const (
		gridX, gridY        = 50, 50
		colWidth, rowHeight = 300, 200
		tablesPerRow        = 4
	)
	nextFieldID := 0
	for idx, table := range schema.Tables {
		drawTable, fieldID := w.convertTableToDrawDB(table, schema.Name, idx, nextFieldID, idx, tablesPerRow, gridX, gridY, colWidth, rowHeight, 0)
		out.Tables = append(out.Tables, drawTable)
		nextFieldID = fieldID
	}
	return out
}
// tableToDrawDB wraps a single Table in a DrawDB schema document.
func (w *Writer) tableToDrawDB(table *models.Table) *DrawDBSchema {
	drawTable, _ := w.convertTableToDrawDB(table, table.Schema, 0, 0, 0, 4, 50, 50, 300, 200, 0)
	return &DrawDBSchema{
		Tables:        []*DrawDBTable{drawTable},
		Relationships: make([]*DrawDBRelationship, 0),
		Notes:         make([]*DrawDBNote, 0),
		SubjectAreas:  make([]*DrawDBArea, 0),
	}
}
// convertTableToDrawDB converts a table to DrawDB format and returns the
// converted table together with the next unused global field ID.
//
// Position is derived from tableIndex within a tablesPerRow-wide grid
// anchored at (gridX, gridY); colorIndex selects the table color from the
// shared palette.
func (w *Writer) convertTableToDrawDB(table *models.Table, schemaName string, tableID, fieldID, tableIndex, tablesPerRow, gridX, gridY, colWidth, rowHeight, colorIndex int) (*DrawDBTable, int) {
	// Calculate position
	x := gridX + (tableIndex%tablesPerRow)*colWidth
	y := gridY + (tableIndex/tablesPerRow)*rowHeight
	drawTable := &DrawDBTable{
		ID:      tableID,
		Name:    table.Name,
		Schema:  schemaName,
		Comment: table.Description,
		Color:   getColorForIndex(colorIndex),
		X:       x,
		Y:       y,
		Fields:  make([]*DrawDBField, 0),
		Indexes: make([]*DrawDBIndex, 0),
	}
	// Add fields
	// NOTE(review): if table.Columns is a map (as sortColumns elsewhere in
	// this repo suggests), iteration order — and therefore field order and
	// field IDs — is nondeterministic. Confirm whether stable output matters.
	for _, column := range table.Columns {
		field := &DrawDBField{
			ID:        fieldID,
			Name:      column.Name,
			Type:      formatTypeForDrawDB(column),
			Primary:   column.IsPrimaryKey,
			NotNull:   column.NotNull,
			Increment: column.AutoIncrement,
			Comment:   column.Comment,
		}
		if column.Default != nil {
			field.Default = fmt.Sprintf("%v", column.Default)
		}
		// Check for unique constraint
		// NOTE(review): every member column of a multi-column UNIQUE
		// constraint is flagged unique individually — confirm intended.
		for _, constraint := range table.Constraints {
			if constraint.Type == models.UniqueConstraint {
				for _, col := range constraint.Columns {
					if col == column.Name {
						field.Unique = true
						break
					}
				}
			}
		}
		drawTable.Fields = append(drawTable.Fields, field)
		fieldID++
	}
	// Add indexes
	indexID := 0
	for _, index := range table.Indexes {
		drawIndex := &DrawDBIndex{
			ID:     indexID,
			Name:   index.Name,
			Unique: index.Unique,
			Fields: make([]int, 0),
		}
		// Map column names to field IDs
		// (stores positions within drawTable.Fields, not global field IDs;
		// unknown column names are silently skipped).
		for _, colName := range index.Columns {
			for idx, field := range drawTable.Fields {
				if field.Name == colName {
					drawIndex.Fields = append(drawIndex.Fields, idx)
					break
				}
			}
		}
		drawTable.Indexes = append(drawTable.Indexes, drawIndex)
		indexID++
	}
	return drawTable, fieldID
}
// Helper functions
// formatTypeForDrawDB renders a column type with its length or
// precision/scale suffix, e.g. varchar(100) or numeric(10,2).
// Length takes precedence over precision when both are set.
func formatTypeForDrawDB(column *models.Column) string {
	switch {
	case column.Length > 0:
		return fmt.Sprintf("%s(%d)", column.Type, column.Length)
	case column.Precision > 0 && column.Scale > 0:
		return fmt.Sprintf("%s(%d,%d)", column.Type, column.Precision, column.Scale)
	case column.Precision > 0:
		return fmt.Sprintf("%s(%d)", column.Type, column.Precision)
	default:
		return column.Type
	}
}
// getColorForIndex returns a deterministic hex color for a table or area,
// cycling through a fixed 8-color palette.
// Go's % operator preserves the sign of the dividend, so a negative index
// would previously panic on slice indexing; it is now normalized first.
func getColorForIndex(index int) string {
	colors := []string{
		"#6366f1", // indigo
		"#8b5cf6", // violet
		"#ec4899", // pink
		"#f43f5e", // rose
		"#14b8a6", // teal
		"#06b6d4", // cyan
		"#0ea5e9", // sky
		"#3b82f6", // blue
	}
	i := index % len(colors)
	if i < 0 {
		i += len(colors)
	}
	return colors[i]
}

View File

@@ -0,0 +1,284 @@
package gorm
import (
"strings"
"unicode"
)
// SnakeCaseToPascalCase converts snake_case to PascalCase.
// Examples: user_id → UserID, http_request → HTTPRequest
func SnakeCaseToPascalCase(s string) string {
	if s == "" {
		return ""
	}
	var b strings.Builder
	for _, part := range strings.Split(s, "_") {
		b.WriteString(capitalize(part))
	}
	return b.String()
}
// SnakeCaseToCamelCase converts snake_case to camelCase.
// Examples: user_id → userID, http_request → httpRequest
func SnakeCaseToCamelCase(s string) string {
	if s == "" {
		return ""
	}
	parts := strings.Split(s, "_")
	var b strings.Builder
	b.WriteString(strings.ToLower(parts[0]))
	for _, part := range parts[1:] {
		b.WriteString(capitalize(part))
	}
	return b.String()
}
// PascalCaseToSnakeCase converts PascalCase to snake_case.
// Acronym runs are kept together: UserID → user_id, HTTPRequest → http_request.
func PascalCaseToSnakeCase(s string) string {
	if s == "" {
		return ""
	}
	var result strings.Builder
	runes := []rune(s)
	prevUpper := false
	for i, r := range runes {
		isUpper := unicode.IsUpper(r)
		// Look ahead one rune to detect the end of an acronym run.
		nextUpper := i+1 < len(runes) && unicode.IsUpper(runes[i+1])
		if i > 0 && isUpper {
			// Insert an underscore when this uppercase rune starts a new word:
			// either the previous rune was lowercase, or this rune ends an
			// acronym (a following rune exists and is lowercase).
			if !prevUpper || (!nextUpper && i+1 < len(runes)) {
				result.WriteRune('_')
			}
		}
		result.WriteRune(unicode.ToLower(r))
		prevUpper = isUpper
	}
	return result.String()
}
// capitalize upper-cases the first letter of s, except that well-known
// acronyms (id → ID, url → URL, …) are returned fully upper-cased.
func capitalize(s string) string {
	if s == "" {
		return ""
	}
	switch u := strings.ToUpper(s); u {
	case "ID", "UUID", "GUID", "URL", "URI", "HTTP", "HTTPS",
		"API", "JSON", "XML", "SQL", "HTML", "CSS", "RID":
		return u
	}
	// Capitalize only the first rune.
	r := []rune(s)
	r[0] = unicode.ToUpper(r[0])
	return string(r)
}
// Pluralize converts a singular English word to its plural form using a
// table of irregular nouns plus a set of common suffix rules.
func Pluralize(s string) string {
	if s == "" {
		return ""
	}
	// Irregular nouns that don't follow suffix rules (matched case-insensitively).
	irregular := map[string]string{
		"person": "people", "child": "children", "tooth": "teeth",
		"foot": "feet", "man": "men", "woman": "women",
		"mouse": "mice", "goose": "geese", "ox": "oxen",
		"datum": "data", "medium": "media", "analysis": "analyses",
		"crisis": "crises", "status": "statuses",
	}
	if p, ok := irregular[strings.ToLower(s)]; ok {
		return p
	}
	vowel := func(c byte) bool {
		return strings.ContainsRune("aeiou", unicode.ToLower(rune(c)))
	}
	switch {
	// Looks already plural: ends in 's' but not 'ss' or 'us'.
	case strings.HasSuffix(s, "s") && !strings.HasSuffix(s, "ss") && !strings.HasSuffix(s, "us"):
		return s
	// Sibilant endings (s, x, z, ch, sh) take "es".
	case strings.HasSuffix(s, "s"), strings.HasSuffix(s, "x"),
		strings.HasSuffix(s, "z"), strings.HasSuffix(s, "ch"),
		strings.HasSuffix(s, "sh"):
		return s + "es"
	}
	// Consonant + y → ies.
	if len(s) >= 2 && strings.HasSuffix(s, "y") && !vowel(s[len(s)-2]) {
		return s[:len(s)-1] + "ies"
	}
	// f / fe → ves.
	if strings.HasSuffix(s, "f") {
		return s[:len(s)-1] + "ves"
	}
	if strings.HasSuffix(s, "fe") {
		return s[:len(s)-2] + "ves"
	}
	// Consonant + o → oes.
	if len(s) >= 2 && strings.HasSuffix(s, "o") && !vowel(s[len(s)-2]) {
		return s + "es"
	}
	// Default: add 's'.
	return s + "s"
}
// Singularize converts a plural English word to singular form using a table
// of irregular nouns plus common suffix rules (basic heuristics only).
func Singularize(s string) string {
	if s == "" {
		return ""
	}
	// Irregular nouns (matched case-insensitively).
	irregular := map[string]string{
		"people": "person", "children": "child", "teeth": "tooth",
		"feet": "foot", "men": "man", "women": "woman",
		"mice": "mouse", "geese": "goose", "oxen": "ox",
		"data": "datum", "media": "medium", "analyses": "analysis",
		"crises": "crisis", "statuses": "status",
	}
	if singular, ok := irregular[strings.ToLower(s)]; ok {
		return singular
	}
	switch {
	case strings.HasSuffix(s, "ies") && len(s) > 3:
		// cities → city
		return s[:len(s)-3] + "y"
	case strings.HasSuffix(s, "ves"):
		// wolves → wolf (note: "fe" endings such as knives → knife are lost)
		return s[:len(s)-3] + "f"
	case strings.HasSuffix(s, "ses"), strings.HasSuffix(s, "xes"),
		strings.HasSuffix(s, "zes"), strings.HasSuffix(s, "ches"),
		strings.HasSuffix(s, "shes"):
		// boxes → box
		return s[:len(s)-2]
	case strings.HasSuffix(s, "s") && !strings.HasSuffix(s, "ss"):
		// cats → cat
		return s[:len(s)-1]
	}
	// Already singular.
	return s
}
// GeneratePrefix generates a 3-letter uppercase prefix from a table name.
// It takes the first letter of each underscore-separated part, pads from the
// remaining letters of the first part, then pads with 'X' if still short.
// Examples: process → PRO, master_task → MTA, user → USE, "" → TBL.
// (The previous doc examples "user → USR" and "mastertask → MTL" did not
// match what the code produces.)
func GeneratePrefix(tableName string) string {
	if tableName == "" {
		return "TBL"
	}
	// Remove common table-name prefixes so they don't dominate the result.
	tableName = strings.TrimPrefix(tableName, "tbl_")
	tableName = strings.TrimPrefix(tableName, "tb_")
	// First letter of each part, up to 3.
	parts := strings.Split(tableName, "_")
	var prefix strings.Builder
	for _, part := range parts {
		if part == "" {
			continue
		}
		prefix.WriteRune(unicode.ToUpper(rune(part[0])))
		if prefix.Len() >= 3 {
			break
		}
	}
	result := prefix.String()
	// Not enough parts: fill from the remaining letters of the first part.
	if len(result) < 3 && len(parts) > 0 {
		firstPart := parts[0]
		for i := 1; i < len(firstPart) && len(result) < 3; i++ {
			result += strings.ToUpper(string(firstPart[i]))
		}
	}
	// Very short names: pad with 'X'.
	for len(result) < 3 {
		result += "X"
	}
	return result[:3]
}
// isVowel reports whether c is an ASCII vowel, case-insensitively.
func isVowel(c byte) bool {
	return strings.ContainsRune("aeiou", unicode.ToLower(rune(c)))
}

View File

@@ -0,0 +1,250 @@
package gorm
import (
"sort"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// TemplateData represents the data passed to the template for code generation
type TemplateData struct {
	// PackageName is emitted in the generated file's package clause.
	PackageName string
	// Imports are the import lines rendered in the generated file.
	Imports []string
	// Models are the structs to generate, in order.
	Models []*ModelData
	// Config selects which helper methods each model gets.
	Config *MethodConfig
}
// ModelData represents a single model/struct in the template
type ModelData struct {
	// Name is the Go struct name (e.g. ModelUser).
	Name          string
	TableName     string // schema.table format
	SchemaName    string
	TableNameOnly string // just table name without schema
	// Comment is the combined table description/comment.
	Comment string
	Fields  []*FieldData
	// Config is stamped from TemplateData.Config when the model is added.
	Config          *MethodConfig
	PrimaryKeyField string // Name of the primary key field
	IDColumnName    string // Name of the ID column in database
	Prefix          string // 3-letter prefix
}
// FieldData represents a single field in a struct
type FieldData struct {
	Name    string // Go field name (PascalCase)
	Type    string // Go type
	GormTag string // Complete gorm tag
	JSONTag string // JSON tag
	Comment string // Field comment
}
// MethodConfig controls which helper methods to generate.
// Each flag corresponds to one generated method on the model struct; see
// the model template for the method bodies.
type MethodConfig struct {
	GenerateTableName     bool // TableName() — "schema.table"
	GenerateSchemaName    bool // SchemaName()
	GenerateTableNameOnly bool // TableNameOnly() — table without schema
	GenerateGetID         bool // GetID() int64
	GenerateGetIDStr      bool // GetIDStr() string
	GenerateSetID         bool // SetID(int64)
	GenerateUpdateID      bool // UpdateID(int64)
	GenerateGetIDName     bool // GetIDName() — primary key column name
	GenerateGetPrefix     bool // GetPrefix() — 3-letter table prefix
}
// DefaultMethodConfig returns a MethodConfig with every helper method enabled.
func DefaultMethodConfig() *MethodConfig {
	cfg := &MethodConfig{}
	cfg.GenerateTableName = true
	cfg.GenerateSchemaName = true
	cfg.GenerateTableNameOnly = true
	cfg.GenerateGetID = true
	cfg.GenerateGetIDStr = true
	cfg.GenerateSetID = true
	cfg.GenerateUpdateID = true
	cfg.GenerateGetIDName = true
	cfg.GenerateGetPrefix = true
	return cfg
}
// NewTemplateData creates template data for the given package, falling back
// to the default method configuration when config is nil.
func NewTemplateData(packageName string, config *MethodConfig) *TemplateData {
	cfg := config
	if cfg == nil {
		cfg = DefaultMethodConfig()
	}
	return &TemplateData{
		PackageName: packageName,
		Imports:     []string{},
		Models:      []*ModelData{},
		Config:      cfg,
	}
}
// AddModel registers a model, stamping it with the shared method config.
func (data *TemplateData) AddModel(m *ModelData) {
	m.Config = data.Config
	data.Models = append(data.Models, m)
}
// AddImport appends an import path unless it is already present (deduplicates).
func (data *TemplateData) AddImport(importPath string) {
	for _, existing := range data.Imports {
		if existing == importPath {
			return
		}
	}
	data.Imports = append(data.Imports, importPath)
}
// FinalizeImports sorts and organizes imports.
// Call once after all AddImport calls so the generated import block is
// deterministic.
func (td *TemplateData) FinalizeImports() {
	// Sort imports alphabetically
	sort.Strings(td.Imports)
}
// NewModelData creates a new ModelData from a models.Table.
//
// The struct name is the singularized, PascalCase table name prefixed with
// "Model" (users → ModelUser). TableName is qualified as "schema.table" when
// a schema is supplied. Columns become fields via typeMapper, ordered by
// sortColumns (sequence, then primary-key flag, then name).
func NewModelData(table *models.Table, schema string, typeMapper *TypeMapper) *ModelData {
	tableName := table.Name
	if schema != "" {
		tableName = schema + "." + table.Name
	}
	// Generate model name: singularize and convert to PascalCase
	singularTable := Singularize(table.Name)
	modelName := SnakeCaseToPascalCase(singularTable)
	// Add "Model" prefix if not already present
	if !hasModelPrefix(modelName) {
		modelName = "Model" + modelName
	}
	model := &ModelData{
		Name:          modelName,
		TableName:     tableName,
		SchemaName:    schema,
		TableNameOnly: table.Name,
		Comment:       formatComment(table.Description, table.Comment),
		Fields:        make([]*FieldData, 0),
		Prefix:        GeneratePrefix(table.Name),
	}
	// Find primary key
	// NOTE(review): map iteration order is unspecified, so with a composite
	// primary key the column recorded here is arbitrary — confirm the
	// single-PK assumption holds for generated models.
	for _, col := range table.Columns {
		if col.IsPrimaryKey {
			model.PrimaryKeyField = SnakeCaseToPascalCase(col.Name)
			model.IDColumnName = col.Name
			break
		}
	}
	// Convert columns to fields (sorted by sequence or name)
	columns := sortColumns(table.Columns)
	for _, col := range columns {
		field := columnToField(col, table, typeMapper)
		model.Fields = append(model.Fields, field)
	}
	return model
}
// columnToField converts a models.Column into the FieldData consumed by the
// template: PascalCase Go name, mapped Go type, full gorm tag, and a JSON
// tag equal to the raw column name.
func columnToField(col *models.Column, table *models.Table, typeMapper *TypeMapper) *FieldData {
	fieldName := SnakeCaseToPascalCase(col.Name)
	goType := typeMapper.SQLTypeToGoType(col.Type, col.NotNull)
	gormTag := typeMapper.BuildGormTag(col, table)
	jsonTag := col.Name // Use column name for JSON tag
	return &FieldData{
		Name:    fieldName,
		Type:    goType,
		GormTag: gormTag,
		JSONTag: jsonTag,
		Comment: formatComment(col.Description, col.Comment),
	}
}
// AddRelationshipField appends an association field (built separately from
// the table's own columns) to the model's field list.
func (md *ModelData) AddRelationshipField(field *FieldData) {
	md.Fields = append(md.Fields, field)
}
// formatComment merges a description and a comment into one string,
// joining them with " - " when both are present.
func formatComment(description, comment string) string {
	switch {
	case description == "":
		return comment
	case comment == "":
		return description
	default:
		return description + " - " + comment
	}
}
// hasModelPrefix reports whether name already starts with "Model".
func hasModelPrefix(name string) bool {
	const prefix = "Model"
	return len(name) >= len(prefix) && name[:len(prefix)] == prefix
}
// sortColumns flattens the column map into a deterministic slice:
// explicit Sequence wins (only when both columns have one), then primary
// keys come first, then names are compared alphabetically.
func sortColumns(columns map[string]*models.Column) []*models.Column {
	result := make([]*models.Column, 0, len(columns))
	for _, col := range columns {
		result = append(result, col)
	}
	sort.Slice(result, func(i, j int) bool {
		// Sort by sequence if both have it
		if result[i].Sequence > 0 && result[j].Sequence > 0 {
			return result[i].Sequence < result[j].Sequence
		}
		// Put primary keys first
		if result[i].IsPrimaryKey != result[j].IsPrimaryKey {
			return result[i].IsPrimaryKey
		}
		// Otherwise sort alphabetically
		return result[i].Name < result[j].Name
	})
	return result
}
// LoadMethodConfigFromMetadata builds a MethodConfig from writer metadata,
// starting from the defaults and overriding any boolean flags present.
func LoadMethodConfigFromMetadata(metadata map[string]interface{}) *MethodConfig {
	config := DefaultMethodConfig()
	if metadata == nil {
		return config
	}
	// override replaces *dst when the metadata key holds a bool value.
	override := func(key string, dst *bool) {
		if val, ok := metadata[key].(bool); ok {
			*dst = val
		}
	}
	override("generate_table_name", &config.GenerateTableName)
	override("generate_schema_name", &config.GenerateSchemaName)
	override("generate_table_name_only", &config.GenerateTableNameOnly)
	override("generate_get_id", &config.GenerateGetID)
	override("generate_get_id_str", &config.GenerateGetIDStr)
	override("generate_set_id", &config.GenerateSetID)
	override("generate_update_id", &config.GenerateUpdateID)
	override("generate_get_id_name", &config.GenerateGetIDName)
	override("generate_get_prefix", &config.GenerateGetPrefix)
	return config
}

View File

@@ -0,0 +1,109 @@
package gorm
import (
"bytes"
"text/template"
)
// modelTemplate defines the template for generating GORM models
const modelTemplate = `// Code generated by relspecgo. DO NOT EDIT.
package {{.PackageName}}
{{if .Imports -}}
import (
{{range .Imports -}}
{{.}}
{{end -}}
)
{{end}}
{{range .Models}}
{{if .Comment}}// {{.Comment}}{{end}}
type {{.Name}} struct {
{{- range .Fields}}
{{.Name}} {{.Type}} ` + "`gorm:\"{{.GormTag}}\" json:\"{{.JSONTag}}\"`" + `{{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}
{{if .Config.GenerateTableName}}
// TableName returns the table name for {{.Name}}
func (m {{.Name}}) TableName() string {
return "{{.TableName}}"
}
{{end}}
{{if .Config.GenerateTableNameOnly}}
// TableNameOnly returns the table name without schema for {{.Name}}
func (m {{.Name}}) TableNameOnly() string {
return "{{.TableNameOnly}}"
}
{{end}}
{{if .Config.GenerateSchemaName}}
// SchemaName returns the schema name for {{.Name}}
func (m {{.Name}}) SchemaName() string {
return "{{.SchemaName}}"
}
{{end}}
{{if and .Config.GenerateGetID .PrimaryKeyField}}
// GetID returns the primary key value
func (m {{.Name}}) GetID() int64 {
return int64(m.{{.PrimaryKeyField}})
}
{{end}}
{{if and .Config.GenerateGetIDStr .PrimaryKeyField}}
// GetIDStr returns the primary key as a string
func (m {{.Name}}) GetIDStr() string {
return fmt.Sprintf("%d", m.{{.PrimaryKeyField}})
}
{{end}}
{{if and .Config.GenerateSetID .PrimaryKeyField}}
// SetID sets the primary key value
func (m {{.Name}}) SetID(newid int64) {
m.UpdateID(newid)
}
{{end}}
{{if and .Config.GenerateUpdateID .PrimaryKeyField}}
// UpdateID updates the primary key value
func (m *{{.Name}}) UpdateID(newid int64) {
m.{{.PrimaryKeyField}} = int32(newid)
}
{{end}}
{{if and .Config.GenerateGetIDName .IDColumnName}}
// GetIDName returns the name of the primary key column
func (m {{.Name}}) GetIDName() string {
return "{{.IDColumnName}}"
}
{{end}}
{{if .Config.GenerateGetPrefix}}
// GetPrefix returns the table prefix
func (m {{.Name}}) GetPrefix() string {
return "{{.Prefix}}"
}
{{end}}
{{end -}}
`
// Templates holds the parsed templates
type Templates struct {
	// modelTmpl is the parsed modelTemplate used by GenerateCode.
	modelTmpl *template.Template
}
// NewTemplates parses the embedded model template and returns the template set.
func NewTemplates() (*Templates, error) {
	tmpl, err := template.New("model").Parse(modelTemplate)
	if err != nil {
		return nil, err
	}
	return &Templates{modelTmpl: tmpl}, nil
}
// GenerateCode renders the model template with the given data and returns
// the generated Go source as a string.
func (t *Templates) GenerateCode(data *TemplateData) (string, error) {
	var out bytes.Buffer
	if err := t.modelTmpl.Execute(&out, data); err != nil {
		return "", err
	}
	return out.String(), nil
}

View File

@@ -0,0 +1,335 @@
package gorm
import (
"fmt"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
)
// TypeMapper handles type conversions between SQL and Go types
type TypeMapper struct {
	// Package alias for sql_types import; prepended to nullable wrapper
	// type names (e.g. "sql_types.SqlString").
	sqlTypesAlias string
}
// NewTypeMapper creates a TypeMapper using the default "sql_types" import alias.
func NewTypeMapper() *TypeMapper {
	return &TypeMapper{sqlTypesAlias: "sql_types"}
}
// SQLTypeToGoType converts a SQL column type to its Go equivalent.
// NOT NULL columns map to plain Go types; nullable columns map to the
// ResolveSpec sql_types wrappers.
func (tm *TypeMapper) SQLTypeToGoType(sqlType string, notNull bool) string {
	base := tm.extractBaseType(sqlType)
	if notNull {
		return tm.baseGoType(base)
	}
	return tm.nullableGoType(base)
}
// extractBaseType normalizes a SQL type to its lowercase, trimmed base name,
// stripping any length/precision suffix: varchar(100) → varchar.
// A '(' at position 0 is left alone, matching the original idx > 0 check.
func (tm *TypeMapper) extractBaseType(sqlType string) string {
	normalized := strings.ToLower(strings.TrimSpace(sqlType))
	if paren := strings.IndexByte(normalized, '('); paren > 0 {
		return normalized[:paren]
	}
	return normalized
}
// baseGoTypeMap maps normalized SQL base types to non-nullable Go types.
// Hoisted to package level so the map is built once instead of on every call.
var baseGoTypeMap = map[string]string{
	// Integer types
	"integer":     "int32",
	"int":         "int32",
	"int4":        "int32",
	"smallint":    "int16",
	"int2":        "int16",
	"bigint":      "int64",
	"int8":        "int64",
	"serial":      "int32",
	"bigserial":   "int64",
	"smallserial": "int16",
	// String types
	"text":      "string",
	"varchar":   "string",
	"char":      "string",
	"character": "string",
	"citext":    "string",
	"bpchar":    "string",
	// Boolean
	"boolean": "bool",
	"bool":    "bool",
	// Float types
	"real":             "float32",
	"float4":           "float32",
	"double precision": "float64",
	"float8":           "float64",
	"numeric":          "float64",
	"decimal":          "float64",
	// Date/Time types
	"timestamp":                   "time.Time",
	"timestamp without time zone": "time.Time",
	"timestamp with time zone":    "time.Time",
	"timestamptz":                 "time.Time",
	"date":                        "time.Time",
	"time":                        "time.Time",
	"time without time zone":      "time.Time",
	"time with time zone":         "time.Time",
	"timetz":                      "time.Time",
	// Binary
	"bytea": "[]byte",
	// UUID
	"uuid": "string",
	// JSON
	"json":  "string",
	"jsonb": "string",
	// Network
	"inet":    "string",
	"cidr":    "string",
	"macaddr": "string",
	// Other
	"money": "float64",
}

// baseGoType returns the Go type used for a NOT NULL column of the given
// normalized SQL base type, defaulting to string for unknown types.
func (tm *TypeMapper) baseGoType(sqlType string) string {
	if goType, ok := baseGoTypeMap[sqlType]; ok {
		return goType
	}
	// Default to string for unknown types
	return "string"
}
// nullableGoType returns the nullable wrapper type (from the sql_types
// package, referenced via tm.sqlTypesAlias) for a normalized SQL base type.
// Unknown types fall back to SqlString; bytea stays []byte because a nil
// slice already expresses NULL.
func (tm *TypeMapper) nullableGoType(sqlType string) string {
	typeMap := map[string]string{
		// Integer types
		"integer":     tm.sqlTypesAlias + ".SqlInt32",
		"int":         tm.sqlTypesAlias + ".SqlInt32",
		"int4":        tm.sqlTypesAlias + ".SqlInt32",
		"smallint":    tm.sqlTypesAlias + ".SqlInt16",
		"int2":        tm.sqlTypesAlias + ".SqlInt16",
		"bigint":      tm.sqlTypesAlias + ".SqlInt64",
		"int8":        tm.sqlTypesAlias + ".SqlInt64",
		"serial":      tm.sqlTypesAlias + ".SqlInt32",
		"bigserial":   tm.sqlTypesAlias + ".SqlInt64",
		"smallserial": tm.sqlTypesAlias + ".SqlInt16",
		// String types
		"text":      tm.sqlTypesAlias + ".SqlString",
		"varchar":   tm.sqlTypesAlias + ".SqlString",
		"char":      tm.sqlTypesAlias + ".SqlString",
		"character": tm.sqlTypesAlias + ".SqlString",
		"citext":    tm.sqlTypesAlias + ".SqlString",
		"bpchar":    tm.sqlTypesAlias + ".SqlString",
		// Boolean
		"boolean": tm.sqlTypesAlias + ".SqlBool",
		"bool":    tm.sqlTypesAlias + ".SqlBool",
		// Float types
		"real":             tm.sqlTypesAlias + ".SqlFloat32",
		"float4":           tm.sqlTypesAlias + ".SqlFloat32",
		"double precision": tm.sqlTypesAlias + ".SqlFloat64",
		"float8":           tm.sqlTypesAlias + ".SqlFloat64",
		"numeric":          tm.sqlTypesAlias + ".SqlFloat64",
		"decimal":          tm.sqlTypesAlias + ".SqlFloat64",
		// Date/Time types
		"timestamp":                   tm.sqlTypesAlias + ".SqlTime",
		"timestamp without time zone": tm.sqlTypesAlias + ".SqlTime",
		"timestamp with time zone":    tm.sqlTypesAlias + ".SqlTime",
		"timestamptz":                 tm.sqlTypesAlias + ".SqlTime",
		"date":                        tm.sqlTypesAlias + ".SqlDate",
		"time":                        tm.sqlTypesAlias + ".SqlTime",
		"time without time zone":      tm.sqlTypesAlias + ".SqlTime",
		"time with time zone":         tm.sqlTypesAlias + ".SqlTime",
		"timetz":                      tm.sqlTypesAlias + ".SqlTime",
		// Binary
		"bytea": "[]byte", // No nullable version needed
		// UUID
		"uuid": tm.sqlTypesAlias + ".SqlUUID",
		// JSON
		"json":  tm.sqlTypesAlias + ".SqlString",
		"jsonb": tm.sqlTypesAlias + ".SqlString",
		// Network
		"inet":    tm.sqlTypesAlias + ".SqlString",
		"cidr":    tm.sqlTypesAlias + ".SqlString",
		"macaddr": tm.sqlTypesAlias + ".SqlString",
		// Other
		"money": tm.sqlTypesAlias + ".SqlFloat64",
	}
	if goType, ok := typeMap[sqlType]; ok {
		return goType
	}
	// Default to SqlString for unknown types
	return tm.sqlTypesAlias + ".SqlString"
}
// BuildGormTag generates a complete GORM tag string for a column.
//
// The tag always carries column:<name>; type (with length or
// precision/scale), primaryKey, autoIncrement, "not null" and default are
// appended when the column defines them. When the owning table is supplied,
// unique constraints and indexes that include this column contribute
// uniqueIndex/index/unique entries. Parts are joined with ';'.
func (tm *TypeMapper) BuildGormTag(column *models.Column, table *models.Table) string {
	var parts []string
	// Always include column name (lowercase as per user requirement)
	parts = append(parts, fmt.Sprintf("column:%s", column.Name))
	// Add type if specified
	if column.Type != "" {
		// Include length, precision, scale if present
		typeStr := column.Type
		if column.Length > 0 {
			typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Length)
		} else if column.Precision > 0 {
			if column.Scale > 0 {
				typeStr = fmt.Sprintf("%s(%d,%d)", typeStr, column.Precision, column.Scale)
			} else {
				typeStr = fmt.Sprintf("%s(%d)", typeStr, column.Precision)
			}
		}
		parts = append(parts, fmt.Sprintf("type:%s", typeStr))
	}
	// Primary key
	if column.IsPrimaryKey {
		parts = append(parts, "primaryKey")
	}
	// Auto increment
	if column.AutoIncrement {
		parts = append(parts, "autoIncrement")
	}
	// Not null (skip if primary key, as it's implied)
	if column.NotNull && !column.IsPrimaryKey {
		parts = append(parts, "not null")
	}
	// Default value
	if column.Default != nil {
		parts = append(parts, fmt.Sprintf("default:%v", column.Default))
	}
	// Check for unique constraint
	// NOTE(review): a column appearing in several constraints/indexes gets
	// one tag entry per appearance — confirm GORM tolerates duplicates.
	if table != nil {
		for _, constraint := range table.Constraints {
			if constraint.Type == models.UniqueConstraint {
				for _, col := range constraint.Columns {
					if col == column.Name {
						if constraint.Name != "" {
							parts = append(parts, fmt.Sprintf("uniqueIndex:%s", constraint.Name))
						} else {
							parts = append(parts, "unique")
						}
						break
					}
				}
			}
		}
		// Check for index
		for _, index := range table.Indexes {
			for _, col := range index.Columns {
				if col == column.Name {
					if index.Unique {
						if index.Name != "" {
							parts = append(parts, fmt.Sprintf("uniqueIndex:%s", index.Name))
						} else {
							parts = append(parts, "unique")
						}
					} else {
						if index.Name != "" {
							parts = append(parts, fmt.Sprintf("index:%s", index.Name))
						} else {
							parts = append(parts, "index")
						}
					}
					break
				}
			}
		}
	}
	return strings.Join(parts, ";")
}
// BuildRelationshipTag generates the GORM tag for an association field
// derived from a foreign-key constraint.
//
// isParent selects the side: the child (FK owner) gets foreignKey plus
// references entries; the parent (referenced) side gets only foreignKey.
// Only the first column of a composite FK is used. OnDelete/OnUpdate are
// folded into a single constraint:... entry when both are present.
func (tm *TypeMapper) BuildRelationshipTag(constraint *models.Constraint, isParent bool) string {
	var parts []string
	if !isParent {
		// Child side (has foreign key)
		if len(constraint.Columns) > 0 && len(constraint.ReferencedColumns) > 0 {
			// foreignKey points to the field name in this struct
			fkFieldName := SnakeCaseToPascalCase(constraint.Columns[0])
			parts = append(parts, fmt.Sprintf("foreignKey:%s", fkFieldName))
			// references points to the field name in the other struct
			refFieldName := SnakeCaseToPascalCase(constraint.ReferencedColumns[0])
			parts = append(parts, fmt.Sprintf("references:%s", refFieldName))
		}
	} else {
		// Parent side (being referenced)
		if len(constraint.Columns) > 0 {
			fkFieldName := SnakeCaseToPascalCase(constraint.Columns[0])
			parts = append(parts, fmt.Sprintf("foreignKey:%s", fkFieldName))
		}
	}
	// Add constraint actions
	if constraint.OnDelete != "" {
		parts = append(parts, fmt.Sprintf("constraint:OnDelete:%s", strings.ToUpper(constraint.OnDelete)))
	}
	if constraint.OnUpdate != "" {
		if len(parts) > 0 && strings.Contains(parts[len(parts)-1], "constraint:") {
			// Append to existing constraint
			parts[len(parts)-1] += fmt.Sprintf(",OnUpdate:%s", strings.ToUpper(constraint.OnUpdate))
		} else {
			parts = append(parts, fmt.Sprintf("constraint:OnUpdate:%s", strings.ToUpper(constraint.OnUpdate)))
		}
	}
	return strings.Join(parts, ";")
}
// NeedsTimeImport reports whether the given Go type string references
// time.Time, meaning the generated file must import the "time" package.
func (tm *TypeMapper) NeedsTimeImport(goType string) bool {
	return strings.Index(goType, "time.Time") >= 0
}
// NeedsFmtImport reports whether the generated file needs the "fmt" import;
// it is required exactly when the GetIDStr helper method is emitted.
func (tm *TypeMapper) NeedsFmtImport(generateGetIDStr bool) bool {
	return generateGetIDStr
}
// GetSQLTypesImport returns the import path of the sql_types package that
// supplies the nullable column wrapper types used in generated models.
func (tm *TypeMapper) GetSQLTypesImport() string {
	const sqlTypesPkg = "github.com/bitechdev/ResolveSpec/pkg/common/sql_types"
	return sqlTypesPkg
}

324
pkg/writers/gorm/writer.go Normal file
View File

@@ -0,0 +1,324 @@
package gorm
import (
"fmt"
"go/format"
"os"
"path/filepath"
"strings"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// Writer implements the writers.Writer interface for GORM models
type Writer struct {
	options    *writers.WriterOptions // output path, package name, and metadata flags
	typeMapper *TypeMapper            // SQL-to-Go type conversion and tag building
	templates  *Templates             // code-generation templates (embedded; see NewWriter)
	config     *MethodConfig          // helper-method generation flags, loaded from options.Metadata
}
// NewWriter creates a new GORM writer configured from the given options.
// It panics if the embedded templates fail to initialize, which indicates a
// programmer error rather than a runtime condition.
func NewWriter(options *writers.WriterOptions) *Writer {
	tmpl, err := NewTemplates()
	if err != nil {
		// Templates are embedded in the binary; failure here is unrecoverable.
		panic(fmt.Sprintf("failed to initialize templates: %v", err))
	}
	return &Writer{
		options:    options,
		typeMapper: NewTypeMapper(),
		config:     LoadMethodConfigFromMetadata(options.Metadata),
		templates:  tmpl,
	}
}
// WriteDatabase writes a complete database as GORM models. When the
// "multi_file" metadata flag is set, each table is written to its own file;
// otherwise everything goes into a single output.
func (w *Writer) WriteDatabase(db *models.Database) error {
	multi := false
	if meta := w.options.Metadata; meta != nil {
		if flag, ok := meta["multi_file"].(bool); ok {
			multi = flag
		}
	}
	if !multi {
		return w.writeSingleFile(db)
	}
	return w.writeMultiFile(db)
}
// WriteSchema writes a single schema as GORM models by wrapping it in a
// temporary database named after the schema and delegating to WriteDatabase.
func (w *Writer) WriteSchema(schema *models.Schema) error {
	wrapper := models.InitDatabase(schema.Name)
	wrapper.Schemas = []*models.Schema{schema}
	return w.WriteDatabase(wrapper)
}
// WriteTable writes a single table as a GORM model by wrapping it in a
// temporary schema and database and delegating to WriteDatabase.
func (w *Writer) WriteTable(table *models.Table) error {
	wrapperSchema := models.InitSchema(table.Schema)
	wrapperSchema.Tables = []*models.Table{table}
	wrapperDB := models.InitDatabase(wrapperSchema.Name)
	wrapperDB.Schemas = []*models.Schema{wrapperSchema}
	return w.WriteDatabase(wrapperDB)
}
// writeSingleFile renders every table of every schema into one Go source
// file (or stdout when no output path is configured).
func (w *Writer) writeSingleFile(db *models.Database) error {
	data := NewTemplateData(w.getPackageName(), w.config)

	// The nullable column types always come from sql_types.
	data.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))

	for _, schema := range db.Schemas {
		for _, table := range schema.Tables {
			model := NewModelData(table, schema.Name, w.typeMapper)
			w.addRelationshipFields(model, table, schema, db)
			data.AddModel(model)

			// Pull in "time" as soon as any field needs time.Time.
			for _, field := range model.Fields {
				if w.typeMapper.NeedsTimeImport(field.Type) {
					data.AddImport("\"time\"")
				}
			}
		}
	}

	// GetIDStr formats the primary key, so it needs fmt.
	if w.config.GenerateGetIDStr {
		data.AddImport("\"fmt\"")
	}
	data.FinalizeImports()

	code, err := w.templates.GenerateCode(data)
	if err != nil {
		return fmt.Errorf("failed to generate code: %w", err)
	}

	formatted, err := w.formatCode(code)
	if err != nil {
		// Emit the unformatted source rather than failing outright.
		fmt.Fprintf(os.Stderr, "Warning: failed to format code: %v\n", err)
		formatted = code
	}
	return w.writeOutput(formatted)
}
// writeMultiFile writes each table to its own file named
// sql_{schema}_{table}.go inside the output directory, which is created if
// it does not already exist. An output path is mandatory in this mode.
func (w *Writer) writeMultiFile(db *models.Database) error {
	packageName := w.getPackageName()
	if w.options.OutputPath == "" {
		return fmt.Errorf("output path is required for multi-file mode")
	}
	if err := os.MkdirAll(w.options.OutputPath, 0755); err != nil {
		return fmt.Errorf("failed to create output directory: %w", err)
	}
	for _, schema := range db.Schemas {
		for _, table := range schema.Tables {
			// Each table gets its own template data and import set.
			templateData := NewTemplateData(packageName, w.config)
			templateData.AddImport(fmt.Sprintf("sql_types \"%s\"", w.typeMapper.GetSQLTypesImport()))

			modelData := NewModelData(table, schema.Name, w.typeMapper)
			w.addRelationshipFields(modelData, table, schema, db)
			templateData.AddModel(modelData)

			// "time" is needed only when a field uses time.Time.
			for _, field := range modelData.Fields {
				if w.typeMapper.NeedsTimeImport(field.Type) {
					templateData.AddImport("\"time\"")
				}
			}
			// GetIDStr formats the primary key, so it needs fmt.
			if w.config.GenerateGetIDStr {
				templateData.AddImport("\"fmt\"")
			}
			templateData.FinalizeImports()

			code, err := w.templates.GenerateCode(templateData)
			if err != nil {
				return fmt.Errorf("failed to generate code for table %s: %w", table.Name, err)
			}
			formatted, err := w.formatCode(code)
			if err != nil {
				// Emit the unformatted source rather than failing outright.
				fmt.Fprintf(os.Stderr, "Warning: failed to format code for %s: %v\n", table.Name, err)
				formatted = code
			}

			filename := fmt.Sprintf("sql_%s_%s.go", schema.Name, table.Name)
			// Named outPath (not "filepath") to avoid shadowing the
			// path/filepath package within this scope.
			outPath := filepath.Join(w.options.OutputPath, filename)
			if err := os.WriteFile(outPath, []byte(formatted), 0644); err != nil {
				return fmt.Errorf("failed to write file %s: %w", filename, err)
			}
		}
	}
	return nil
}
// addRelationshipFields attaches relationship fields to a model: a belongs-to
// pointer field for every foreign key the table declares, and a has-many
// slice field for every other table whose foreign key points back at it.
func (w *Writer) addRelationshipFields(modelData *ModelData, table *models.Table, schema *models.Schema, db *models.Database) {
	// Belongs-to: one pointer field per outgoing foreign key.
	for _, fk := range table.Constraints {
		if fk.Type != models.ForeignKeyConstraint {
			continue
		}
		// Skip links to tables that are not part of this database.
		if w.findTable(fk.ReferencedSchema, fk.ReferencedTable, db) == nil {
			continue
		}
		refModel := w.getModelName(fk.ReferencedTable)
		name := w.generateRelationshipFieldName(fk.ReferencedTable)
		modelData.AddRelationshipField(&FieldData{
			Name:    name,
			Type:    "*" + refModel, // pointer type
			GormTag: w.typeMapper.BuildRelationshipTag(fk, false),
			JSONTag: strings.ToLower(name) + ",omitempty",
			Comment: fmt.Sprintf("Belongs to %s", refModel),
		})
	}

	// Has-many: one slice field per incoming foreign key from another table.
	for _, otherSchema := range db.Schemas {
		for _, otherTable := range otherSchema.Tables {
			if otherSchema.Name == schema.Name && otherTable.Name == table.Name {
				continue // never link a table back to itself here
			}
			for _, fk := range otherTable.Constraints {
				if fk.Type != models.ForeignKeyConstraint {
					continue
				}
				if fk.ReferencedTable != table.Name || fk.ReferencedSchema != schema.Name {
					continue
				}
				otherModel := w.getModelName(otherTable.Name)
				name := w.generateRelationshipFieldName(otherTable.Name) + "s" // naive pluralization
				modelData.AddRelationshipField(&FieldData{
					Name:    name,
					Type:    "[]*" + otherModel, // slice of pointers
					GormTag: w.typeMapper.BuildRelationshipTag(fk, true),
					JSONTag: strings.ToLower(name) + ",omitempty",
					Comment: fmt.Sprintf("Has many %s", otherModel),
				})
			}
		}
	}
}
// findTable looks up a table by schema and table name within the database,
// returning nil when no match exists.
func (w *Writer) findTable(schemaName, tableName string, db *models.Database) *models.Table {
	for _, schema := range db.Schemas {
		if schema.Name != schemaName {
			continue
		}
		for _, candidate := range schema.Tables {
			if candidate.Name == tableName {
				return candidate
			}
		}
	}
	return nil
}
// getModelName derives a Go model name from a table name: singularize,
// convert to PascalCase, and ensure a "Model" prefix.
func (w *Writer) getModelName(tableName string) string {
	name := SnakeCaseToPascalCase(Singularize(tableName))
	if hasModelPrefix(name) {
		return name
	}
	return "Model" + name
}
// generateRelationshipFieldName derives the struct field name used for a
// relationship to the given table; it delegates to GeneratePrefix (a short
// prefix of the table name).
func (w *Writer) generateRelationshipFieldName(tableName string) string {
	return GeneratePrefix(tableName)
}
// getPackageName returns the configured package name, defaulting to "models"
// when none was supplied in the writer options.
func (w *Writer) getPackageName() string {
	if name := w.options.PackageName; name != "" {
		return name
	}
	return "models"
}
// formatCode runs the generated source through go/format (gofmt semantics)
// and returns the formatted text, or an error if the source does not parse.
func (w *Writer) formatCode(code string) (string, error) {
	pretty, err := format.Source([]byte(code))
	if err != nil {
		return "", fmt.Errorf("format error: %w", err)
	}
	return string(pretty), nil
}
// writeOutput delivers the generated content: to the configured output file
// when a path is set, otherwise to stdout.
func (w *Writer) writeOutput(content string) error {
	if w.options.OutputPath == "" {
		fmt.Print(content)
		return nil
	}
	return os.WriteFile(w.options.OutputPath, []byte(content), 0644)
}

View File

@@ -0,0 +1,243 @@
package gorm
import (
"os"
"path/filepath"
"strings"
"testing"
"git.warky.dev/wdevs/relspecgo/pkg/models"
"git.warky.dev/wdevs/relspecgo/pkg/writers"
)
// TestWriter_WriteTable generates a model for a small three-column table and
// checks that the rendered source contains the expected declarations.
func TestWriter_WriteTable(t *testing.T) {
	tbl := models.InitTable("users", "public")
	tbl.Columns["id"] = &models.Column{
		Name:          "id",
		Type:          "bigint",
		NotNull:       true,
		IsPrimaryKey:  true,
		AutoIncrement: true,
		Sequence:      1,
	}
	tbl.Columns["email"] = &models.Column{
		Name:     "email",
		Type:     "varchar",
		Length:   255,
		NotNull:  false,
		Sequence: 2,
	}
	tbl.Columns["created_at"] = &models.Column{
		Name:     "created_at",
		Type:     "timestamp",
		NotNull:  true,
		Sequence: 3,
	}

	opts := &writers.WriterOptions{
		PackageName: "models",
		Metadata: map[string]interface{}{
			"generate_table_name": true,
			"generate_get_id":     true,
		},
	}
	opts.OutputPath = filepath.Join(t.TempDir(), "test.go")

	if err := NewWriter(opts).WriteTable(tbl); err != nil {
		t.Fatalf("WriteTable failed: %v", err)
	}

	raw, err := os.ReadFile(opts.OutputPath)
	if err != nil {
		t.Fatalf("Failed to read generated file: %v", err)
	}
	generated := string(raw)

	// Spot-check the key declarations the generator must emit.
	for _, want := range []string{
		"package models",
		"type ModelUser struct",
		"ID",
		"int64",
		"Email",
		"sql_types.SqlString",
		"CreatedAt",
		"time.Time",
		"gorm:\"column:id",
		"gorm:\"column:email",
		"func (m ModelUser) TableName() string",
		"return \"public.users\"",
		"func (m ModelUser) GetID() int64",
	} {
		if !strings.Contains(generated, want) {
			t.Errorf("Generated code missing expected content: %q\nGenerated:\n%s", want, generated)
		}
	}
}
// TestWriter_WriteDatabase_MultiFile verifies that multi-file mode emits one
// sql_{schema}_{table}.go file per table and renders relationship fields for
// a users <- posts foreign key.
func TestWriter_WriteDatabase_MultiFile(t *testing.T) {
	db := models.InitDatabase("testdb")
	schema := models.InitSchema("public")

	// Table 1: users (the referenced parent table).
	users := models.InitTable("users", "public")
	users.Columns["id"] = &models.Column{
		Name:         "id",
		Type:         "bigint",
		NotNull:      true,
		IsPrimaryKey: true,
	}
	schema.Tables = append(schema.Tables, users)

	// Table 2: posts (child table with a foreign key to users).
	posts := models.InitTable("posts", "public")
	posts.Columns["id"] = &models.Column{
		Name:         "id",
		Type:         "bigint",
		NotNull:      true,
		IsPrimaryKey: true,
	}
	posts.Columns["user_id"] = &models.Column{
		Name:    "user_id",
		Type:    "bigint",
		NotNull: true,
	}
	posts.Constraints["fk_user"] = &models.Constraint{
		Name:              "fk_user",
		Type:              models.ForeignKeyConstraint,
		Columns:           []string{"user_id"},
		ReferencedTable:   "users",
		ReferencedSchema:  "public",
		ReferencedColumns: []string{"id"},
		OnDelete:          "CASCADE",
	}
	schema.Tables = append(schema.Tables, posts)
	db.Schemas = append(db.Schemas, schema)

	// Create writer with multi-file mode enabled.
	tmpDir := t.TempDir()
	opts := &writers.WriterOptions{
		PackageName: "models",
		OutputPath:  tmpDir,
		Metadata: map[string]interface{}{
			"multi_file": true,
		},
	}
	writer := NewWriter(opts)
	if err := writer.WriteDatabase(db); err != nil {
		t.Fatalf("WriteDatabase failed: %v", err)
	}

	// Verify both per-table files were created. The loop variable is named
	// outPath (not "filepath") so it does not shadow the path/filepath package.
	expectedFiles := []string{
		"sql_public_users.go",
		"sql_public_posts.go",
	}
	for _, filename := range expectedFiles {
		outPath := filepath.Join(tmpDir, filename)
		if _, err := os.Stat(outPath); os.IsNotExist(err) {
			t.Errorf("Expected file not created: %s", filename)
		}
	}

	// Check that the posts file contains the belongs-to relationship field.
	postsContent, err := os.ReadFile(filepath.Join(tmpDir, "sql_public_posts.go"))
	if err != nil {
		t.Fatalf("Failed to read posts file: %v", err)
	}
	if !strings.Contains(string(postsContent), "USE *ModelUser") {
		// NOTE(review): this branch only logs and never fails the test, so
		// the relationship-field expectation is effectively unasserted.
		// Promote to t.Errorf once the exact generated field text is confirmed.
		t.Logf("Posts content:\n%s", string(postsContent))
	}
}
// TestNameConverter_SnakeCaseToPascalCase covers acronym handling (ID, HTTP,
// GUID, RID) as well as plain multi-word conversion.
func TestNameConverter_SnakeCaseToPascalCase(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{in: "user_id", want: "UserID"},
		{in: "http_request", want: "HTTPRequest"},
		{in: "user_profiles", want: "UserProfiles"},
		{in: "guid", want: "GUID"},
		{in: "rid_process", want: "RIDProcess"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.in, func(t *testing.T) {
			if got := SnakeCaseToPascalCase(tc.in); got != tc.want {
				t.Errorf("SnakeCaseToPascalCase(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}
// TestNameConverter_Pluralize covers regular plurals, -es endings, and
// irregular nouns (child/children, person/people).
func TestNameConverter_Pluralize(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{in: "user", want: "users"},
		{in: "process", want: "processes"},
		{in: "child", want: "children"},
		{in: "person", want: "people"},
		{in: "status", want: "statuses"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.in, func(t *testing.T) {
			if got := Pluralize(tc.in); got != tc.want {
				t.Errorf("Pluralize(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}
// TestTypeMapper_SQLTypeToGoType checks that NOT NULL columns map to plain Go
// types while nullable columns map to the sql_types wrappers.
func TestTypeMapper_SQLTypeToGoType(t *testing.T) {
	mapper := NewTypeMapper()
	cases := []struct {
		sqlType string
		notNull bool
		want    string
	}{
		{sqlType: "bigint", notNull: true, want: "int64"},
		{sqlType: "bigint", notNull: false, want: "sql_types.SqlInt64"},
		{sqlType: "varchar", notNull: true, want: "string"},
		{sqlType: "varchar", notNull: false, want: "sql_types.SqlString"},
		{sqlType: "timestamp", notNull: true, want: "time.Time"},
		{sqlType: "timestamp", notNull: false, want: "sql_types.SqlTime"},
		{sqlType: "boolean", notNull: true, want: "bool"},
		{sqlType: "boolean", notNull: false, want: "sql_types.SqlBool"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.sqlType, func(t *testing.T) {
			if got := mapper.SQLTypeToGoType(tc.sqlType, tc.notNull); got != tc.want {
				t.Errorf("SQLTypeToGoType(%q, %v) = %q, want %q", tc.sqlType, tc.notNull, got, tc.want)
			}
		})
	}
}

View File

@@ -5,10 +5,16 @@ import (
)
// Writer defines the interface for writing database specifications
// to various output formats
// to various output formats at different granularity levels
type Writer interface {
// Write takes a Database model and writes it to the desired format
Write(db *models.Database) error
// WriteDatabase takes a Database model and writes it to the desired format
WriteDatabase(db *models.Database) error
// WriteSchema takes a Schema model and writes it to the desired format
WriteSchema(schema *models.Schema) error
// WriteTable takes a Table model and writes it to the desired format
WriteTable(table *models.Table) error
}
// WriterOptions contains common options for writers