286 lines
6.8 KiB
Go
286 lines
6.8 KiB
Go
package pgsql
|
|
|
|
import (
	"fmt"
	"reflect"
	"regexp"
	"strings"
	"unicode"
)
|
|
|
|
// TemplateFunctions returns a map of custom template functions
|
|
func TemplateFunctions() map[string]interface{} {
|
|
return map[string]interface{}{
|
|
// String manipulation
|
|
"upper": strings.ToUpper,
|
|
"lower": strings.ToLower,
|
|
"snake_case": toSnakeCase,
|
|
"camelCase": toCamelCase,
|
|
|
|
// SQL formatting
|
|
"indent": indent,
|
|
"quote": quote,
|
|
"escape": escape,
|
|
"safe_identifier": safeIdentifier,
|
|
|
|
// Type conversion
|
|
"goTypeToSQL": goTypeToSQL,
|
|
"sqlTypeToGo": sqlTypeToGo,
|
|
"isNumeric": isNumeric,
|
|
"isText": isText,
|
|
|
|
// Collection helpers
|
|
"first": first,
|
|
"last": last,
|
|
"filter": filter,
|
|
"mapFunc": mapFunc,
|
|
"join_with": joinWith,
|
|
|
|
// Built-in Go template function (for convenience)
|
|
"join": strings.Join,
|
|
}
|
|
}
|
|
|
|
// String manipulation functions
|
|
|
|
// toSnakeCase converts a MixedCaps/camelCase string to snake_case.
//
// An underscore is inserted before an uppercase letter that starts a new word.
// A run of capitals is treated as a single word (an acronym), so "UserID"
// becomes "user_id" and "HTTPServer" becomes "http_server". No underscore is
// added directly after an existing '_', so "already_Snake" becomes
// "already_snake" rather than "already__snake".
func toSnakeCase(s string) string {
	runes := []rune(s)
	var b strings.Builder
	for i, r := range runes {
		if !unicode.IsUpper(r) {
			b.WriteRune(r)
			continue
		}
		if i > 0 {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && unicode.IsLower(runes[i+1])
			// Inside an acronym run ("HTTP") only the last capital, when it is
			// followed by a lowercase letter, begins a new word.
			if prev != '_' && (!unicode.IsUpper(prev) || nextIsLower) {
				b.WriteRune('_')
			}
		}
		b.WriteRune(unicode.ToLower(r))
	}
	return b.String()
}
|
|
|
|
// toCamelCase converts an underscore-separated string to camelCase.
//
// The first segment is lowercased entirely. Each later segment has its first
// letter uppercased and the rest lowercased, so "USER_ID" becomes "userId".
// Empty segments (from "a__b" or a trailing "_") are skipped. Capitalization
// works rune by rune, so non-ASCII letters are handled correctly.
func toCamelCase(s string) string {
	parts := strings.Split(s, "_") // always yields at least one element

	var b strings.Builder
	b.Grow(len(s))
	b.WriteString(strings.ToLower(parts[0]))

	for _, part := range parts[1:] {
		for i, r := range part {
			// i is a byte offset; 0 is always the first rune of the segment.
			if i == 0 {
				b.WriteRune(unicode.ToUpper(r))
			} else {
				b.WriteRune(unicode.ToLower(r))
			}
		}
	}
	return b.String()
}
|
|
|
|
// SQL formatting functions
|
|
|
|
// indent prefixes every non-empty line of text with the given number of
// spaces. Empty lines are left empty so no trailing whitespace is produced.
//
// text is the last parameter so the function can end a template pipeline,
// e.g. {{.Body | indent 4}}. A negative spaces value is treated as 0.
func indent(spaces int, text string) string {
	if spaces < 0 {
		// strings.Repeat panics on a negative count.
		spaces = 0
	}
	prefix := strings.Repeat(" ", spaces)
	lines := strings.Split(text, "\n")
	for i, line := range lines {
		if line != "" {
			lines[i] = prefix + line
		}
	}
	return strings.Join(lines, "\n")
}
|
|
|
|
// quote wraps s in single quotes to form a SQL string literal. Each embedded
// single quote is doubled (' becomes ''); no other character is altered.
func quote(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			b.WriteByte('\'') // double the quote
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('\'')
	return b.String()
}
|
|
|
|
// escape returns s with every backslash and every single quote doubled
// (\ becomes \\, ' becomes ''). Surrounding quotes are not added.
//
// NOTE(review): doubling backslashes is only correct inside an E'...' literal
// or when standard_conforming_strings is off. In a plain '...' literal under
// PostgreSQL's default settings, backslashes are literal and would be stored
// doubled. Confirm how templates use this.
func escape(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '\\', '\'':
			// Both characters are escaped by writing them twice.
			b.WriteByte(c)
		}
		b.WriteByte(s[i])
	}
	return b.String()
}
|
|
|
|
// nonIdentifierChars matches each character (rune) that safeIdentifier
// replaces. It is compiled once at package initialization instead of on every
// call.
var nonIdentifierChars = regexp.MustCompile(`[^a-zA-Z0-9_]`)

// safeIdentifier turns s into a bare (unquoted) SQL identifier.
//
// The result contains only lowercase ASCII letters, digits and underscores:
//   - every other rune, including multi-byte ones, is replaced by a single '_';
//   - a leading digit gets an '_' prefix;
//   - everything is lowercased (PostgreSQL folds unquoted names to lowercase).
//
// An empty input yields an empty string. Reserved words such as "user" or
// "select" are not detected or quoted.
func safeIdentifier(s string) string {
	safe := nonIdentifierChars.ReplaceAllString(s, "_")

	// After the replacement safe is pure ASCII, so byte 0 is a whole rune.
	if len(safe) > 0 && unicode.IsDigit(rune(safe[0])) {
		safe = "_" + safe
	}

	return strings.ToLower(safe)
}
|
|
|
|
// Type conversion functions
|
|
|
|
// goTypeToSQL maps the textual name of a Go type to a PostgreSQL column type.
//
// It covers the Go types that sqlTypeToGo produces (including int16 and
// json.RawMessage), so the two conversions round-trip. A single leading '*'
// is ignored, so pointer types (often used for nullable columns) map like
// their pointee: "*int64" gives "bigint". Unrecognized types fall back to
// "text".
func goTypeToSQL(goType string) string {
	switch strings.TrimPrefix(goType, "*") {
	case "string":
		return "text"
	case "int", "int32":
		return "integer"
	case "int8", "int16":
		// PostgreSQL has no 1-byte integer; smallint is the narrowest.
		return "smallint"
	case "int64":
		return "bigint"
	case "float32":
		return "real"
	case "float64":
		return "double precision"
	case "bool":
		return "boolean"
	case "time.Time":
		return "timestamp"
	case "[]byte":
		return "bytea"
	case "json.RawMessage":
		return "jsonb"
	default:
		return "text"
	}
}
|
|
|
|
// sqlTypeToGo maps a PostgreSQL column type to the name of a Go type.
//
// Matching is case-insensitive. The input is normalized first:
//   - a type modifier is dropped, e.g. "varchar(255)" becomes "varchar" and
//     "timestamp(3) with time zone" becomes "timestamp with time zone";
//   - surrounding and repeated whitespace is collapsed.
//
// Both the SQL-standard spellings and PostgreSQL's internal aliases are
// recognized (int4, int8, float8, bool, timestamptz, ...). Unrecognized types
// fall back to "string".
func sqlTypeToGo(sqlType string) string {
	t := strings.ToLower(sqlType)

	// Remove the first "(...)" modifier wherever it appears in the name.
	if open := strings.IndexByte(t, '('); open >= 0 {
		if end := strings.IndexByte(t[open:], ')'); end >= 0 {
			t = t[:open] + t[open+end+1:]
		}
	}
	// Trim and collapse whitespace so multi-word names compare exactly.
	t = strings.Join(strings.Fields(t), " ")

	typeMap := map[string]string{
		// Character types
		"text":              "string",
		"varchar":           "string",
		"character varying": "string",
		"char":              "string",
		"character":         "string",
		"bpchar":            "string",
		"uuid":              "string",

		// Integer types (int2/int4/int8 are PostgreSQL's internal names)
		"smallint":    "int16",
		"int2":        "int16",
		"smallserial": "int16",
		"serial2":     "int16",
		"integer":     "int",
		"int":         "int",
		"int4":        "int",
		"serial":      "int",
		"serial4":     "int",
		"bigint":      "int64",
		"int8":        "int64",
		"bigserial":   "int64",
		"serial8":     "int64",

		// Floating-point and arbitrary-precision types
		"real":             "float32",
		"float4":           "float32",
		"double precision": "float64",
		"float8":           "float64",
		"float":            "float64",
		"numeric":          "float64",
		"decimal":          "float64",

		"boolean": "bool",
		"bool":    "bool",

		// Date/time types
		"timestamp":                   "time.Time",
		"timestamp without time zone": "time.Time",
		"timestamptz":                 "time.Time",
		"timestamp with time zone":    "time.Time",
		"date":                        "time.Time",
		"time":                        "time.Time",
		"time without time zone":      "time.Time",
		"timetz":                      "time.Time",
		"time with time zone":         "time.Time",

		"bytea": "[]byte",
		"json":  "json.RawMessage",
		"jsonb": "json.RawMessage",
	}

	if goType, ok := typeMap[t]; ok {
		return goType
	}
	return "string" // Default
}
|
|
|
|
// isNumeric reports whether sqlType names a PostgreSQL numeric type (integer,
// serial, floating-point or arbitrary-precision).
//
// The comparison is case-insensitive. A type modifier such as "(10,2)" or
// "(53)" is ignored, and repeated whitespace is collapsed. Names are matched
// exactly rather than by substring, so types like "interval" or "point"
// (which contain "int") are not reported as numeric.
func isNumeric(sqlType string) bool {
	t := strings.ToLower(sqlType)
	if open := strings.IndexByte(t, '('); open >= 0 {
		t = t[:open]
	}
	t = strings.Join(strings.Fields(t), " ")

	switch t {
	case "smallint", "integer", "int", "bigint", "int2", "int4", "int8",
		"smallserial", "serial", "bigserial", "serial2", "serial4", "serial8",
		"real", "float", "float4", "float8", "double precision",
		"numeric", "decimal":
		return true
	}
	return false
}
|
|
|
|
// isText reports whether sqlType looks like a character type.
//
// This is a case-insensitive substring test for "text", "varchar", "char",
// "character" or "string". Any type name containing one of those counts,
// e.g. "varchar(255)", "character varying", "citext", and also array forms
// such as "text[]".
func isText(sqlType string) bool {
	lowered := strings.ToLower(sqlType)
	return strings.Contains(lowered, "text") ||
		strings.Contains(lowered, "varchar") ||
		strings.Contains(lowered, "char") ||
		strings.Contains(lowered, "character") ||
		strings.Contains(lowered, "string")
}
|
|
|
|
// Collection helper functions
|
|
|
|
// first returns the first element of a slice or array, or nil when the input
// is empty, nil, or not a slice/array.
//
// Any element type is accepted (e.g. a slice of structs), not just
// []string, []int and []interface{}. The element is returned as its own
// dynamic value.
func first(slice interface{}) interface{} {
	v := reflect.ValueOf(slice)
	switch v.Kind() {
	case reflect.Slice, reflect.Array:
		if v.Len() > 0 {
			return v.Index(0).Interface()
		}
	}
	return nil
}
|
|
|
|
// last returns the final element of a slice or array, or nil when the input
// is empty, nil, or not a slice/array.
//
// Any element type is accepted (e.g. a slice of structs), not just
// []string, []int and []interface{}. The element is returned as its own
// dynamic value.
func last(slice interface{}) interface{} {
	v := reflect.ValueOf(slice)
	switch v.Kind() {
	case reflect.Slice, reflect.Array:
		if n := v.Len(); n > 0 {
			return v.Index(n - 1).Interface()
		}
	}
	return nil
}
|
|
|
|
// filter returns the elements of slice whose field (or map key) fieldName is
// "truthy", using text/template's notion of truth:
//   - true booleans;
//   - non-zero numbers;
//   - non-empty strings, slices, maps and arrays;
//   - non-nil pointers, channels and funcs;
//   - any struct.
//
// Usage in template: {{range filter .Columns "NotNull"}}...{{end}}
//
// Elements may be structs, pointers to structs, or string-keyed maps. Only
// exported struct fields are considered. Elements that are nil, or on which
// fieldName cannot be resolved, are dropped. The result is a new slice with
// the input's element type. If slice is not a slice or array it is returned
// unchanged.
func filter(slice interface{}, fieldName string) interface{} {
	src := reflect.ValueOf(slice)
	if src.Kind() != reflect.Slice && src.Kind() != reflect.Array {
		return slice
	}

	// indirect strips pointers and interfaces; ok is false if it hits a nil.
	indirect := func(v reflect.Value) (reflect.Value, bool) {
		for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
			if v.IsNil() {
				return v, false
			}
			v = v.Elem()
		}
		return v, v.IsValid()
	}

	// lookup resolves fieldName on a struct or string-keyed map element.
	lookup := func(v reflect.Value) (reflect.Value, bool) {
		var ok bool
		if v, ok = indirect(v); !ok {
			return reflect.Value{}, false
		}
		switch v.Kind() {
		case reflect.Struct:
			sf, found := v.Type().FieldByName(fieldName)
			if !found || sf.PkgPath != "" {
				return reflect.Value{}, false
			}
			// Walk the index path by hand so a nil embedded pointer counts
			// as "not found" instead of panicking in FieldByIndex.
			for n, idx := range sf.Index {
				if n > 0 {
					if v, ok = indirect(v); !ok {
						return reflect.Value{}, false
					}
				}
				v = v.Field(idx)
			}
			return v, true
		case reflect.Map:
			keyType := v.Type().Key()
			if keyType.Kind() != reflect.String {
				return reflect.Value{}, false
			}
			val := v.MapIndex(reflect.ValueOf(fieldName).Convert(keyType))
			return val, val.IsValid()
		}
		return reflect.Value{}, false
	}

	// truthy mirrors text/template's truth rules.
	truthy := func(v reflect.Value) bool {
		for v.Kind() == reflect.Interface {
			if v.IsNil() {
				return false
			}
			v = v.Elem()
		}
		switch v.Kind() {
		case reflect.Bool:
			return v.Bool()
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			return v.Int() != 0
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			return v.Uint() != 0
		case reflect.Float32, reflect.Float64:
			return v.Float() != 0
		case reflect.Complex64, reflect.Complex128:
			return v.Complex() != 0
		case reflect.String, reflect.Slice, reflect.Map, reflect.Array:
			return v.Len() > 0
		case reflect.Ptr, reflect.Chan, reflect.Func, reflect.UnsafePointer:
			return !v.IsNil()
		case reflect.Struct:
			return true
		}
		return false
	}

	kept := reflect.MakeSlice(reflect.SliceOf(src.Type().Elem()), 0, src.Len())
	for i := 0; i < src.Len(); i++ {
		elem := src.Index(i)
		if field, ok := lookup(elem); ok && truthy(field) {
			kept = reflect.Append(kept, elem)
		}
	}
	return kept.Interface()
}
|
|
|
|
// mapFunc is a placeholder: it returns value unchanged and ignores funcName.
//
// Despite its name it does not apply any function. In templates, call the
// target function directly instead, e.g. write {{upper .Name}} rather than
// {{mapFunc .Name "upper"}}.
func mapFunc(value interface{}, funcName string) interface{} {
	var unchanged interface{} = value
	return unchanged
}
|
|
|
|
// joinWith concatenates the elements of slice, placing separator between
// consecutive elements. A nil or empty slice yields "".
func joinWith(slice []string, separator string) string {
	var b strings.Builder
	for i, elem := range slice {
		if i > 0 {
			b.WriteString(separator)
		}
		b.WriteString(elem)
	}
	return b.String()
}
|
|
|
|
// Additional helper functions
|
|
|
|
// formatType appends a type modifier to baseType:
//   - length > 0 and precision > 0 gives "base(length,precision)";
//   - length > 0 only gives "base(length)";
//   - otherwise baseType is returned unchanged.
//
// precision is only emitted together with a positive length; a positive
// precision with length <= 0 is ignored.
func formatType(baseType string, length, precision int) string {
	switch {
	case length <= 0:
		return baseType
	case precision <= 0:
		return fmt.Sprintf("%s(%d)", baseType, length)
	default:
		return fmt.Sprintf("%s(%d,%d)", baseType, length, precision)
	}
}
|