Files
relspecgo/pkg/writers/pgsql/template_functions.go
Hein 5e1448dcdb
Some checks are pending
CI / Test (1.23) (push) Waiting to run
CI / Test (1.24) (push) Waiting to run
CI / Test (1.25) (push) Waiting to run
CI / Lint (push) Waiting to run
CI / Build (push) Waiting to run
sql writer
2025-12-17 20:44:02 +02:00

286 lines
6.8 KiB
Go

package pgsql
import (
"fmt"
"regexp"
"strings"
"unicode"
)
// TemplateFunctions builds the function map exposed to the SQL templates.
//
// Each key is the name a template uses to call the helper; each value is the
// Go function that implements it. A fresh map is returned on every call.
func TemplateFunctions() map[string]interface{} {
	funcs := make(map[string]interface{}, 18)

	// String manipulation.
	funcs["upper"] = strings.ToUpper
	funcs["lower"] = strings.ToLower
	funcs["snake_case"] = toSnakeCase
	funcs["camelCase"] = toCamelCase

	// SQL formatting.
	funcs["indent"] = indent
	funcs["quote"] = quote
	funcs["escape"] = escape
	funcs["safe_identifier"] = safeIdentifier

	// Type conversion and classification.
	funcs["goTypeToSQL"] = goTypeToSQL
	funcs["sqlTypeToGo"] = sqlTypeToGo
	funcs["isNumeric"] = isNumeric
	funcs["isText"] = isText

	// Collection helpers.
	funcs["first"] = first
	funcs["last"] = last
	funcs["filter"] = filter
	funcs["mapFunc"] = mapFunc
	funcs["join_with"] = joinWith

	// Direct alias for strings.Join (slice first, separator second),
	// registered for convenience.
	funcs["join"] = strings.Join

	return funcs
}
// String manipulation functions
// toSnakeCase converts a camelCase or PascalCase string to snake_case.
//
// Every uppercase letter is lowercased, and an underscore is inserted before
// an uppercase letter that starts a new word. A run of uppercase letters is
// treated as one acronym ("UserID" -> "user_id", "HTTPServer" ->
// "http_server"), and no underscore is added directly after an existing one
// ("User_Name" -> "user_name"). All other characters are copied unchanged.
func toSnakeCase(s string) string {
	runes := []rune(s)

	var result strings.Builder
	result.Grow(len(s))

	for i, r := range runes {
		if !unicode.IsUpper(r) {
			result.WriteRune(r)
			continue
		}
		if i > 0 {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && unicode.IsLower(runes[i+1])
			// Break before this letter unless it directly follows an
			// underscore or sits inside an acronym. The last letter of an
			// acronym still starts a new word when a lowercase letter
			// follows it (the "S" in "HTTPServer").
			if prev != '_' && (!unicode.IsUpper(prev) || nextIsLower) {
				result.WriteRune('_')
			}
		}
		result.WriteRune(unicode.ToLower(r))
	}
	return result.String()
}
// toCamelCase converts an underscore-separated string to camelCase.
//
// The input is split on "_". The first segment is lowercased in full; every
// following non-empty segment gets its first letter uppercased and the rest
// lowercased. Empty segments (from leading, trailing or doubled underscores)
// are dropped. Segments are handled rune by rune, so a multi-byte first
// letter is capitalised correctly rather than being cut in half.
func toCamelCase(s string) string {
	parts := strings.Split(s, "_") // always yields at least one element

	var result strings.Builder
	result.Grow(len(s))
	result.WriteString(strings.ToLower(parts[0]))

	for _, part := range parts[1:] {
		if part == "" {
			continue
		}
		runes := []rune(part)
		result.WriteRune(unicode.ToUpper(runes[0]))
		result.WriteString(strings.ToLower(string(runes[1:])))
	}
	return result.String()
}
// SQL formatting functions
// indent prefixes every non-empty line of text with the given number of
// spaces.
//
// Lines are separated by "\n". Empty lines are left empty so that no trailing
// whitespace is introduced. A zero or negative spaces value returns text
// unchanged (a negative count would otherwise make strings.Repeat panic).
func indent(spaces int, text string) string {
	if spaces <= 0 {
		return text
	}

	pad := strings.Repeat(" ", spaces)
	lines := strings.Split(text, "\n")
	for i := range lines {
		if lines[i] == "" {
			continue
		}
		lines[i] = pad + lines[i]
	}
	return strings.Join(lines, "\n")
}
// quote renders s as a single-quoted SQL string literal, doubling every
// embedded single quote.
func quote(s string) string {
	var literal strings.Builder
	literal.Grow(len(s) + 2)
	literal.WriteByte('\'')
	literal.WriteString(strings.ReplaceAll(s, "'", "''"))
	literal.WriteByte('\'')
	return literal.String()
}
// escape doubles every backslash and every single quote in s. Unlike quote,
// it does not wrap the result in quotes.
func escape(s string) string {
	// Backslashes are handled first; the quote replacement introduces no
	// backslashes, so the two passes do not interfere.
	withBackslashes := strings.ReplaceAll(s, `\`, `\\`)
	return strings.ReplaceAll(withBackslashes, `'`, `''`)
}
// unsafeIdentifierChars matches every character that is not an ASCII letter,
// digit or underscore. It is compiled once at package initialisation instead
// of on every safeIdentifier call.
var unsafeIdentifierChars = regexp.MustCompile(`[^a-zA-Z0-9_]`)

// safeIdentifier makes a string safe to use as an unquoted SQL identifier.
//
// Every character other than an ASCII letter, digit or underscore is replaced
// by "_", a leading digit is prefixed with "_", and the result is lowercased
// (the PostgreSQL convention for unquoted identifiers). An empty input yields
// an empty string.
func safeIdentifier(s string) string {
	safe := unsafeIdentifierChars.ReplaceAllString(s, "_")

	// After the replacement the string is pure ASCII, so inspecting the
	// first byte as a rune is safe.
	if len(safe) > 0 && unicode.IsDigit(rune(safe[0])) {
		safe = "_" + safe
	}
	return strings.ToLower(safe)
}
// Type conversion functions
// goTypeToSQL maps a Go type name to the PostgreSQL column type used for it.
// The match is exact and case-sensitive; any unrecognised type name maps to
// "text".
func goTypeToSQL(goType string) string {
	switch goType {
	case "string":
		return "text"
	case "int", "int32":
		return "integer"
	case "int64":
		return "bigint"
	case "float32":
		return "real"
	case "float64":
		return "double precision"
	case "bool":
		return "boolean"
	case "time.Time":
		return "timestamp"
	case "[]byte":
		return "bytea"
	default:
		// Fallback for anything not listed above.
		return "text"
	}
}
// sqlTypeToGo maps a PostgreSQL type name to the Go type used to hold it.
//
// The lookup is case-insensitive. Before matching, any type modifier is
// dropped (everything from the first "(" onward) and surrounding whitespace
// is trimmed, so "VARCHAR(255)" and "numeric(10,2)" resolve like their base
// types. Unrecognised types map to "string".
func sqlTypeToGo(sqlType string) string {
	base := strings.ToLower(sqlType)
	if i := strings.IndexByte(base, '('); i >= 0 {
		base = base[:i]
	}
	base = strings.TrimSpace(base)

	switch base {
	case "text", "varchar", "char", "uuid":
		return "string"
	case "integer", "int", "serial":
		return "int"
	case "bigint", "bigserial":
		return "int64"
	case "smallint":
		return "int16"
	case "real":
		return "float32"
	case "double precision", "numeric", "decimal":
		return "float64"
	case "boolean":
		return "bool"
	case "timestamp", "timestamptz", "date", "time":
		return "time.Time"
	case "bytea":
		return "[]byte"
	case "json", "jsonb":
		return "json.RawMessage"
	default:
		// Fallback for anything not listed above.
		return "string"
	}
}
// isNumeric reports whether sqlType names a numeric PostgreSQL type.
//
// The check is case-insensitive. Any type modifier (everything from the first
// "(" onward) and surrounding whitespace are removed, and the remaining base
// name is matched exactly against the known numeric type names. Exact
// matching avoids the false positives a substring test produces: "interval"
// and "point" both contain "int" but are not numeric.
func isNumeric(sqlType string) bool {
	base := strings.ToLower(sqlType)
	if i := strings.IndexByte(base, '('); i >= 0 {
		base = base[:i]
	}
	base = strings.TrimSpace(base)

	switch base {
	case "integer", "int", "int2", "int4", "int8",
		"bigint", "smallint",
		"serial", "smallserial", "bigserial", "serial2", "serial4", "serial8",
		"real", "double precision", "numeric", "decimal",
		"float", "float4", "float8":
		return true
	}
	return false
}
// isText reports whether sqlType looks like a text-based SQL type.
//
// The check is a case-insensitive substring match: the type is considered
// textual when its name contains any of "text", "varchar", "char",
// "character" or "string".
func isText(sqlType string) bool {
	lowered := strings.ToLower(sqlType)
	switch {
	case strings.Contains(lowered, "text"),
		strings.Contains(lowered, "varchar"),
		strings.Contains(lowered, "char"),
		strings.Contains(lowered, "character"),
		strings.Contains(lowered, "string"):
		return true
	default:
		return false
	}
}
// Collection helper functions
// first returns the first element of slice.
//
// Only []string, []int and []interface{} are recognised. For an empty slice,
// or a value of any other type (including nil), the result is nil.
func first(slice interface{}) interface{} {
	if strs, ok := slice.([]string); ok && len(strs) > 0 {
		return strs[0]
	}
	if ints, ok := slice.([]int); ok && len(ints) > 0 {
		return ints[0]
	}
	if items, ok := slice.([]interface{}); ok && len(items) > 0 {
		return items[0]
	}
	return nil
}
// last returns the final element of slice.
//
// Only []string, []int and []interface{} are recognised. For an empty slice,
// or a value of any other type (including nil), the result is nil.
func last(slice interface{}) interface{} {
	if strs, ok := slice.([]string); ok && len(strs) > 0 {
		return strs[len(strs)-1]
	}
	if ints, ok := slice.([]int); ok && len(ints) > 0 {
		return ints[len(ints)-1]
	}
	if items, ok := slice.([]interface{}); ok && len(items) > 0 {
		return items[len(items)-1]
	}
	return nil
}
// filter is a placeholder for predicate-based filtering of a slice.
// Usage in template: {{filter .Columns "NotNull"}}
//
// NOTE: no filtering is performed. The input is returned unchanged and
// fieldName is ignored. To filter inside a template, use
// {{range $col := .Columns}}{{if $col.NotNull}}...{{end}}{{end}} instead.
func filter(slice interface{}, fieldName string) interface{} {
	_ = fieldName // accepted for the documented call shape, not used
	return slice
}
// mapFunc is a placeholder for applying a named function to a value.
// Usage in template: {{range .Columns}}{{mapFunc .Name "upper"}}{{end}}
//
// NOTE: no function is applied. The value is returned unchanged and funcName
// is ignored. In a template, call the helper directly instead, e.g.
// {{upper .Name}}.
func mapFunc(value interface{}, funcName string) interface{} {
	_ = funcName // accepted for the documented call shape, not used
	return value
}
// joinWith concatenates the elements of slice, placing separator between
// consecutive elements. An empty or nil slice yields "".
func joinWith(slice []string, separator string) string {
	switch len(slice) {
	case 0:
		return ""
	case 1:
		return slice[0]
	default:
		return strings.Join(slice, separator)
	}
}
// Additional helper functions
// formatType renders a SQL type name with its optional size modifiers.
//
//   - length > 0 and precision > 0: "base(length,precision)"
//   - length > 0 only:              "base(length)"
//   - otherwise:                    "base" (precision alone is ignored)
func formatType(baseType string, length, precision int) string {
	switch {
	case length <= 0:
		return baseType
	case precision > 0:
		return fmt.Sprintf("%s(%d,%d)", baseType, length, precision)
	default:
		return fmt.Sprintf("%s(%d)", baseType, length)
	}
}