sql writer
This commit is contained in:
285
pkg/writers/pgsql/template_functions.go
Normal file
285
pkg/writers/pgsql/template_functions.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// TemplateFunctions returns a map of custom template functions
|
||||
func TemplateFunctions() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
// String manipulation
|
||||
"upper": strings.ToUpper,
|
||||
"lower": strings.ToLower,
|
||||
"snake_case": toSnakeCase,
|
||||
"camelCase": toCamelCase,
|
||||
|
||||
// SQL formatting
|
||||
"indent": indent,
|
||||
"quote": quote,
|
||||
"escape": escape,
|
||||
"safe_identifier": safeIdentifier,
|
||||
|
||||
// Type conversion
|
||||
"goTypeToSQL": goTypeToSQL,
|
||||
"sqlTypeToGo": sqlTypeToGo,
|
||||
"isNumeric": isNumeric,
|
||||
"isText": isText,
|
||||
|
||||
// Collection helpers
|
||||
"first": first,
|
||||
"last": last,
|
||||
"filter": filter,
|
||||
"mapFunc": mapFunc,
|
||||
"join_with": joinWith,
|
||||
|
||||
// Built-in Go template function (for convenience)
|
||||
"join": strings.Join,
|
||||
}
|
||||
}
|
||||
|
||||
// String manipulation functions
|
||||
|
||||
// toSnakeCase converts a string to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
// Insert underscore before uppercase letters
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if unicode.IsUpper(r) {
|
||||
if i > 0 {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(unicode.ToLower(r))
|
||||
} else {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// toCamelCase converts a string to camelCase
|
||||
func toCamelCase(s string) string {
|
||||
// Split by underscore
|
||||
parts := strings.Split(s, "_")
|
||||
if len(parts) == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
// First part stays lowercase
|
||||
result := strings.ToLower(parts[0])
|
||||
|
||||
// Capitalize first letter of remaining parts
|
||||
for _, part := range parts[1:] {
|
||||
if len(part) > 0 {
|
||||
result += strings.ToUpper(part[0:1]) + strings.ToLower(part[1:])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// SQL formatting functions
|
||||
|
||||
// indent indents each line of text by the specified number of spaces
|
||||
func indent(spaces int, text string) string {
|
||||
prefix := strings.Repeat(" ", spaces)
|
||||
lines := strings.Split(text, "\n")
|
||||
for i, line := range lines {
|
||||
if line != "" {
|
||||
lines[i] = prefix + line
|
||||
}
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// quote quotes a string value for SQL (escapes single quotes)
|
||||
func quote(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
|
||||
}
|
||||
|
||||
// escape escapes a string for SQL (escapes single quotes and backslashes)
|
||||
func escape(s string) string {
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "'", "''")
|
||||
return s
|
||||
}
|
||||
|
||||
// safeIdentifier makes a string safe to use as a SQL identifier
|
||||
func safeIdentifier(s string) string {
|
||||
// Remove or replace dangerous characters
|
||||
// Allow: letters, numbers, underscore
|
||||
reg := regexp.MustCompile(`[^a-zA-Z0-9_]`)
|
||||
safe := reg.ReplaceAllString(s, "_")
|
||||
|
||||
// Ensure it doesn't start with a number
|
||||
if len(safe) > 0 && unicode.IsDigit(rune(safe[0])) {
|
||||
safe = "_" + safe
|
||||
}
|
||||
|
||||
// Convert to lowercase (PostgreSQL convention)
|
||||
return strings.ToLower(safe)
|
||||
}
|
||||
|
||||
// Type conversion functions
|
||||
|
||||
// goTypeToSQL converts Go type to PostgreSQL type
|
||||
func goTypeToSQL(goType string) string {
|
||||
typeMap := map[string]string{
|
||||
"string": "text",
|
||||
"int": "integer",
|
||||
"int32": "integer",
|
||||
"int64": "bigint",
|
||||
"float32": "real",
|
||||
"float64": "double precision",
|
||||
"bool": "boolean",
|
||||
"time.Time": "timestamp",
|
||||
"[]byte": "bytea",
|
||||
}
|
||||
|
||||
if sqlType, ok := typeMap[goType]; ok {
|
||||
return sqlType
|
||||
}
|
||||
return "text" // Default
|
||||
}
|
||||
|
||||
// sqlTypeToGo converts PostgreSQL type to Go type
|
||||
func sqlTypeToGo(sqlType string) string {
|
||||
sqlType = strings.ToLower(sqlType)
|
||||
|
||||
typeMap := map[string]string{
|
||||
"text": "string",
|
||||
"varchar": "string",
|
||||
"char": "string",
|
||||
"integer": "int",
|
||||
"int": "int",
|
||||
"bigint": "int64",
|
||||
"smallint": "int16",
|
||||
"serial": "int",
|
||||
"bigserial": "int64",
|
||||
"real": "float32",
|
||||
"double precision": "float64",
|
||||
"numeric": "float64",
|
||||
"decimal": "float64",
|
||||
"boolean": "bool",
|
||||
"timestamp": "time.Time",
|
||||
"timestamptz": "time.Time",
|
||||
"date": "time.Time",
|
||||
"time": "time.Time",
|
||||
"bytea": "[]byte",
|
||||
"json": "json.RawMessage",
|
||||
"jsonb": "json.RawMessage",
|
||||
"uuid": "string",
|
||||
}
|
||||
|
||||
if goType, ok := typeMap[sqlType]; ok {
|
||||
return goType
|
||||
}
|
||||
return "string" // Default
|
||||
}
|
||||
|
||||
// isNumeric checks if a SQL type is numeric
|
||||
func isNumeric(sqlType string) bool {
|
||||
sqlType = strings.ToLower(sqlType)
|
||||
numericTypes := []string{
|
||||
"integer", "int", "bigint", "smallint", "serial", "bigserial",
|
||||
"real", "double precision", "numeric", "decimal", "float",
|
||||
}
|
||||
for _, t := range numericTypes {
|
||||
if strings.Contains(sqlType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isText checks if a SQL type is text-based
|
||||
func isText(sqlType string) bool {
|
||||
sqlType = strings.ToLower(sqlType)
|
||||
textTypes := []string{
|
||||
"text", "varchar", "char", "character", "string",
|
||||
}
|
||||
for _, t := range textTypes {
|
||||
if strings.Contains(sqlType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Collection helper functions
|
||||
|
||||
// first returns the first element of a slice, or nil if empty
|
||||
func first(slice interface{}) interface{} {
|
||||
switch v := slice.(type) {
|
||||
case []string:
|
||||
if len(v) > 0 {
|
||||
return v[0]
|
||||
}
|
||||
case []int:
|
||||
if len(v) > 0 {
|
||||
return v[0]
|
||||
}
|
||||
case []interface{}:
|
||||
if len(v) > 0 {
|
||||
return v[0]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// last returns the last element of a slice, or nil if empty
|
||||
func last(slice interface{}) interface{} {
|
||||
switch v := slice.(type) {
|
||||
case []string:
|
||||
if len(v) > 0 {
|
||||
return v[len(v)-1]
|
||||
}
|
||||
case []int:
|
||||
if len(v) > 0 {
|
||||
return v[len(v)-1]
|
||||
}
|
||||
case []interface{}:
|
||||
if len(v) > 0 {
|
||||
return v[len(v)-1]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// filter filters a slice based on a predicate (simplified version)
|
||||
// Usage in template: {{filter .Columns "NotNull"}}
|
||||
func filter(slice interface{}, fieldName string) interface{} {
|
||||
// This is a simplified implementation
|
||||
// In templates, you'd use: {{range $col := .Columns}}{{if $col.NotNull}}...{{end}}{{end}}
|
||||
// This function is mainly for documentation purposes
|
||||
return slice
|
||||
}
|
||||
|
||||
// mapFunc maps a function over a slice (simplified version)
|
||||
// Usage in template: {{range .Columns}}{{mapFunc .Name "upper"}}{{end}}
|
||||
func mapFunc(value interface{}, funcName string) interface{} {
|
||||
// This is a simplified implementation
|
||||
// In templates, you'd directly call: {{upper .Name}}
|
||||
// This function is mainly for documentation purposes
|
||||
return value
|
||||
}
|
||||
|
||||
// joinWith joins a slice of strings with a separator
|
||||
func joinWith(slice []string, separator string) string {
|
||||
return strings.Join(slice, separator)
|
||||
}
|
||||
|
||||
// Additional helper functions
|
||||
|
||||
// formatType formats a SQL type with length/precision
|
||||
func formatType(baseType string, length, precision int) string {
|
||||
if length > 0 && precision > 0 {
|
||||
return fmt.Sprintf("%s(%d,%d)", baseType, length, precision)
|
||||
}
|
||||
if length > 0 {
|
||||
return fmt.Sprintf("%s(%d)", baseType, length)
|
||||
}
|
||||
return baseType
|
||||
}
|
||||
Reference in New Issue
Block a user