package pgsql import ( "fmt" "regexp" "strings" "unicode" ) // TemplateFunctions returns a map of custom template functions func TemplateFunctions() map[string]interface{} { return map[string]interface{}{ // String manipulation "upper": strings.ToUpper, "lower": strings.ToLower, "snake_case": toSnakeCase, "camelCase": toCamelCase, // SQL formatting "indent": indent, "quote": quote, "escape": escape, "safe_identifier": safeIdentifier, // Type conversion "goTypeToSQL": goTypeToSQL, "sqlTypeToGo": sqlTypeToGo, "isNumeric": isNumeric, "isText": isText, // Collection helpers "first": first, "last": last, "filter": filter, "mapFunc": mapFunc, "join_with": joinWith, // Built-in Go template function (for convenience) "join": strings.Join, } } // String manipulation functions // toSnakeCase converts a string to snake_case func toSnakeCase(s string) string { // Insert underscore before uppercase letters var result strings.Builder for i, r := range s { if unicode.IsUpper(r) { if i > 0 { result.WriteRune('_') } result.WriteRune(unicode.ToLower(r)) } else { result.WriteRune(r) } } return result.String() } // toCamelCase converts a string to camelCase func toCamelCase(s string) string { // Split by underscore parts := strings.Split(s, "_") if len(parts) == 0 { return s } // First part stays lowercase result := strings.ToLower(parts[0]) // Capitalize first letter of remaining parts for _, part := range parts[1:] { if len(part) > 0 { result += strings.ToUpper(part[0:1]) + strings.ToLower(part[1:]) } } return result } // SQL formatting functions // indent indents each line of text by the specified number of spaces func indent(spaces int, text string) string { prefix := strings.Repeat(" ", spaces) lines := strings.Split(text, "\n") for i, line := range lines { if line != "" { lines[i] = prefix + line } } return strings.Join(lines, "\n") } // quote quotes a string value for SQL (escapes single quotes) func quote(s string) string { return "'" + strings.ReplaceAll(s, "'", "''") + "'" } 
// sqlEscaper doubles backslashes and single quotes in a single pass.
var sqlEscaper = strings.NewReplacer("\\", "\\\\", "'", "''")

// escape doubles every backslash and single quote in s. Unlike quote it does
// not add surrounding quotes.
//
// NOTE(review): doubling backslashes suits escape-string syntax (E'...'); in
// a plain '...' literal it would change the stored value when the server
// treats backslashes literally. Confirm how templates embed the result.
func escape(s string) string {
	return sqlEscaper.Replace(s)
}

// unsafeIdentifierChars matches every character that is not an ASCII letter,
// digit or underscore. It is compiled once instead of on every call.
var unsafeIdentifierChars = regexp.MustCompile(`[^a-zA-Z0-9_]`)

// safeIdentifier turns s into a lowercase identifier made only of ASCII
// letters, digits and underscores. Every other character becomes "_", and a
// leading digit is prefixed with "_".
//
// An empty input yields an empty string, and reserved words are not detected;
// callers that need either guarantee must check for it themselves.
func safeIdentifier(s string) string {
	safe := unsafeIdentifierChars.ReplaceAllString(s, "_")

	// After the replacement the string is pure ASCII, so inspecting the first
	// byte is enough to detect a leading digit.
	if safe != "" && unicode.IsDigit(rune(safe[0])) {
		safe = "_" + safe
	}

	// Lowercase by PostgreSQL convention.
	return strings.ToLower(safe)
}

// Type conversion functions

// goTypeToSQL maps a Go type name to a PostgreSQL column type. Unknown type
// names fall back to "text".
func goTypeToSQL(goType string) string {
	switch goType {
	case "string":
		return "text"
	case "int", "int32":
		return "integer"
	case "int16":
		// Counterpart of sqlTypeToGo's "smallint" -> "int16".
		return "smallint"
	case "int64":
		return "bigint"
	case "float32":
		return "real"
	case "float64":
		return "double precision"
	case "bool":
		return "boolean"
	case "time.Time":
		return "timestamp"
	case "[]byte":
		return "bytea"
	default:
		return "text"
	}
}

// baseSQLType lowercases sqlType, drops a parenthesized modifier such as
// "(255)" or "(10,2)" together with anything after it, and trims surrounding
// whitespace.
func baseSQLType(sqlType string) string {
	base := strings.ToLower(sqlType)
	if i := strings.IndexByte(base, '('); i >= 0 {
		base = base[:i]
	}
	return strings.TrimSpace(base)
}

// sqlTypeToGo maps a PostgreSQL type name to a Go type name. Matching is
// case-insensitive and ignores length/precision modifiers, so "VARCHAR(255)"
// and "numeric(10,2)" are recognized. Unknown types fall back to "string".
func sqlTypeToGo(sqlType string) string {
	switch baseSQLType(sqlType) {
	case "text", "varchar", "char", "uuid":
		return "string"
	case "integer", "int", "int4", "serial":
		return "int"
	case "bigint", "int8", "bigserial":
		return "int64"
	case "smallint", "int2":
		return "int16"
	case "real", "float4":
		return "float32"
	case "double precision", "float8", "numeric", "decimal":
		return "float64"
	case "boolean", "bool":
		return "bool"
	case "timestamp", "timestamptz", "date", "time":
		return "time.Time"
	case "bytea":
		return "[]byte"
	case "json", "jsonb":
		return "json.RawMessage"
	default:
		return "string"
	}
}

// isNumeric reports whether sqlType names a numeric SQL type.
//
// The type string is split into words and each word is compared against the
// known numeric type names. Decorated forms such as "numeric(10,2)",
// "integer[]" or "integer not null" are accepted, while unrelated types that
// merely contain a numeric name ("interval", "point", "int4range") are not.
func isNumeric(sqlType string) bool {
	lowered := strings.ToLower(sqlType)

	// The only two-word numeric type name; check it before splitting.
	if strings.Contains(lowered, "double precision") {
		return true
	}

	words := strings.FieldsFunc(lowered, func(r rune) bool {
		return !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_'
	})
	for _, w := range words {
		switch w {
		case "integer", "int", "int2", "int4", "int8", "bigint", "smallint",
			"serial", "serial2", "serial4", "serial8", "smallserial", "bigserial",
			"real", "numeric", "decimal", "float", "float4", "float8":
			return true
		}
	}
	return false
}

// isText reports whether sqlType looks like a text type: its lowercased form
// contains "text", "char" or "string". The "char" test also covers "varchar"
// and "character", which contain it.
func isText(sqlType string) bool {
	lowered := strings.ToLower(sqlType)
	return strings.Contains(lowered, "text") ||
		strings.Contains(lowered, "char") ||
		strings.Contains(lowered, "string")
}

// Collection helper functions

// first returns the first element of slice, or nil when the slice is empty.
// Only []string, []int and []interface{} are supported; any other type also
// yields nil.
func first(slice interface{}) interface{} {
	switch v := slice.(type) {
	case []string:
		if len(v) > 0 {
			return v[0]
		}
	case []int:
		if len(v) > 0 {
			return v[0]
		}
	case []interface{}:
		if len(v) > 0 {
			return v[0]
		}
	}
	return nil
}

// last returns the last element of slice, or nil when the slice is empty.
// Only []string, []int and []interface{} are supported; any other type also
// yields nil.
func last(slice interface{}) interface{} {
	switch v := slice.(type) {
	case []string:
		if n := len(v); n > 0 {
			return v[n-1]
		}
	case []int:
		if n := len(v); n > 0 {
			return v[n-1]
		}
	case []interface{}:
		if n := len(v); n > 0 {
			return v[n-1]
		}
	}
	return nil
}

// filter is a placeholder: it returns slice unchanged and ignores fieldName.
//
// Templates should filter inline instead, for example:
//
//	{{range $col := .Columns}}{{if $col.NotNull}}...{{end}}{{end}}
func filter(slice interface{}, fieldName string) interface{} {
	return slice
}

// mapFunc is a placeholder: it returns value unchanged and ignores funcName.
//
// Templates should call the target function directly, for example
// {{upper .Name}}.
func mapFunc(value interface{}, funcName string) interface{} {
	return value
}

// joinWith concatenates the elements of slice, placing separator between
// consecutive elements.
func joinWith(slice []string, separator string) string {
	return strings.Join(slice, separator)
}

// Additional helper functions

// formatType renders a SQL type name with its optional size modifiers.
//
// With length > 0 and precision > 0 the result is "base(length,precision)";
// with only length > 0 it is "base(length)"; otherwise baseType is returned
// unchanged. A precision supplied without a positive length is ignored.
func formatType(baseType string, length, precision int) string {
	switch {
	case length <= 0:
		return baseType
	case precision > 0:
		return fmt.Sprintf("%s(%d,%d)", baseType, length, precision)
	default:
		return fmt.Sprintf("%s(%d)", baseType, length)
	}
}