333 lines
6.9 KiB
Go
333 lines
6.9 KiB
Go
package pgsql
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestToSnakeCase(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"UserId", "user_id"},
|
|
{"UserID", "user_i_d"},
|
|
{"HTTPResponse", "h_t_t_p_response"},
|
|
{"already_snake", "already_snake"},
|
|
{"", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := toSnakeCase(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("toSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestToCamelCase(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"user_id", "userId"},
|
|
{"user_name", "userName"},
|
|
{"http_response", "httpResponse"},
|
|
{"", ""},
|
|
{"alreadycamel", "alreadycamel"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := toCamelCase(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("toCamelCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestQuote(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"hello", "'hello'"},
|
|
{"O'Brien", "'O''Brien'"},
|
|
{"", "''"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := quote(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("quote(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestEscape(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"hello", "hello"},
|
|
{"O'Brien", "O''Brien"},
|
|
{"path\\to\\file", "path\\\\to\\\\file"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := escape(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("escape(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSafeIdentifier(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"User-Id", "user_id"},
|
|
{"123column", "_123column"},
|
|
{"valid_name", "valid_name"},
|
|
{"Column@Name!", "column_name_"},
|
|
{"UPPERCASE", "uppercase"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := safeIdentifier(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("safeIdentifier(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGoTypeToSQL(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"string", "text"},
|
|
{"int", "integer"},
|
|
{"int64", "bigint"},
|
|
{"bool", "boolean"},
|
|
{"time.Time", "timestamp"},
|
|
{"unknown", "text"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := goTypeToSQL(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("goTypeToSQL(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSQLTypeToGo(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"text", "string"},
|
|
{"integer", "int"},
|
|
{"bigint", "int64"},
|
|
{"boolean", "bool"},
|
|
{"timestamp", "time.Time"},
|
|
{"unknown", "string"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := sqlTypeToGo(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("sqlTypeToGo(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIsNumeric(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected bool
|
|
}{
|
|
{"integer", true},
|
|
{"bigint", true},
|
|
{"numeric(10,2)", true},
|
|
{"text", false},
|
|
{"varchar", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := isNumeric(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("isNumeric(%q) = %v, want %v", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIsText(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected bool
|
|
}{
|
|
{"text", true},
|
|
{"varchar(255)", true},
|
|
{"character varying", true},
|
|
{"integer", false},
|
|
{"bigint", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := isText(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("isText(%q) = %v, want %v", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestIndent(t *testing.T) {
|
|
input := "line1\nline2\nline3"
|
|
expected := " line1\n line2\n line3"
|
|
result := indent(2, input)
|
|
if result != expected {
|
|
t.Errorf("indent(2, %q) = %q, want %q", input, result, expected)
|
|
}
|
|
}
|
|
|
|
func TestFirst(t *testing.T) {
|
|
tests := []struct {
|
|
input interface{}
|
|
expected interface{}
|
|
}{
|
|
{[]string{"a", "b", "c"}, "a"},
|
|
{[]string{}, nil},
|
|
{[]int{1, 2, 3}, 1},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := first(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("first(%v) = %v, want %v", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestLast(t *testing.T) {
|
|
tests := []struct {
|
|
input interface{}
|
|
expected interface{}
|
|
}{
|
|
{[]string{"a", "b", "c"}, "c"},
|
|
{[]string{}, nil},
|
|
{[]int{1, 2, 3}, 3},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := last(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("last(%v) = %v, want %v", tt.input, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestJoinWith(t *testing.T) {
|
|
input := []string{"a", "b", "c"}
|
|
expected := "a, b, c"
|
|
result := joinWith(input, ", ")
|
|
if result != expected {
|
|
t.Errorf("joinWith(%v, \", \") = %q, want %q", input, result, expected)
|
|
}
|
|
}
|
|
|
|
func TestTemplateFunctions(t *testing.T) {
|
|
funcs := TemplateFunctions()
|
|
|
|
// Check that all expected functions are registered
|
|
expectedFuncs := []string{
|
|
"upper", "lower", "snake_case", "camelCase",
|
|
"indent", "quote", "escape", "safe_identifier",
|
|
"goTypeToSQL", "sqlTypeToGo", "isNumeric", "isText",
|
|
"first", "last", "filter", "mapFunc", "join_with",
|
|
"join",
|
|
}
|
|
|
|
for _, name := range expectedFuncs {
|
|
if _, ok := funcs[name]; !ok {
|
|
t.Errorf("Expected function %q not found in TemplateFunctions()", name)
|
|
}
|
|
}
|
|
|
|
// Test that they're callable
|
|
if upperFunc, ok := funcs["upper"].(func(string) string); ok {
|
|
result := upperFunc("hello")
|
|
if result != "HELLO" {
|
|
t.Errorf("upper function not working correctly")
|
|
}
|
|
} else {
|
|
t.Error("upper function has wrong type")
|
|
}
|
|
}
|
|
|
|
func TestFormatType(t *testing.T) {
|
|
tests := []struct {
|
|
baseType string
|
|
length int
|
|
precision int
|
|
expected string
|
|
}{
|
|
{"varchar", 255, 0, "varchar(255)"},
|
|
{"numeric", 10, 2, "numeric(10,2)"},
|
|
{"integer", 0, 0, "integer"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := formatType(tt.baseType, tt.length, tt.precision)
|
|
if result != tt.expected {
|
|
t.Errorf("formatType(%q, %d, %d) = %q, want %q",
|
|
tt.baseType, tt.length, tt.precision, result, tt.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Test that template functions work in actual templates
|
|
func TestTemplateFunctionsInTemplate(t *testing.T) {
|
|
executor, err := NewTemplateExecutor()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create executor: %v", err)
|
|
}
|
|
|
|
// Create a simple test template
|
|
tmpl, err := executor.templates.New("test").Parse(`
|
|
{{- upper .Name -}}
|
|
{{- lower .Type -}}
|
|
{{- snake_case .CamelName -}}
|
|
{{- safe_identifier .UnsafeName -}}
|
|
`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse test template: %v", err)
|
|
}
|
|
|
|
data := struct {
|
|
Name string
|
|
Type string
|
|
CamelName string
|
|
UnsafeName string
|
|
}{
|
|
Name: "hello",
|
|
Type: "TEXT",
|
|
CamelName: "UserId",
|
|
UnsafeName: "user-id!",
|
|
}
|
|
|
|
var buf strings.Builder
|
|
err = tmpl.Execute(&buf, data)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute template: %v", err)
|
|
}
|
|
|
|
result := buf.String()
|
|
expected := "HELLOtextuser_iduser_id_"
|
|
|
|
if result != expected {
|
|
t.Errorf("Template output = %q, want %q", result, expected)
|
|
}
|
|
}
|