Files
ResolveSpec/pkg/funcspec/function_api_test.go
Hein 1ebe0d7ac3 fix(funcspec): refine filter application logic for SQL queries
* update filter checks to only consider SELECT list
* add test for function parameters not matching filters
2026-05-15 14:28:12 +02:00

1194 lines
33 KiB
Go

package funcspec
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// MockDatabase is a stand-in for the common.Database interface used by these
// tests. Query, Exec and RunInTransaction forward to the matching *Func field
// when it is set and otherwise fall back to a harmless default; the
// query-builder constructors always return nil and the transaction-control
// methods always succeed.
type MockDatabase struct {
	QueryFunc            func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	ExecFunc             func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
	RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
}

// NewSelect returns nil; the mock provides no select builder.
func (md *MockDatabase) NewSelect() common.SelectQuery { return nil }

// NewInsert returns nil; the mock provides no insert builder.
func (md *MockDatabase) NewInsert() common.InsertQuery { return nil }

// NewUpdate returns nil; the mock provides no update builder.
func (md *MockDatabase) NewUpdate() common.UpdateQuery { return nil }

// NewDelete returns nil; the mock provides no delete builder.
func (md *MockDatabase) NewDelete() common.DeleteQuery { return nil }

// Exec delegates to ExecFunc, or reports zero affected rows when it is unset.
func (md *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
	if md.ExecFunc == nil {
		return &MockResult{rows: 0}, nil
	}
	return md.ExecFunc(ctx, query, args...)
}

// Query delegates to QueryFunc, or leaves dest untouched and succeeds when it
// is unset.
func (md *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	if md.QueryFunc == nil {
		return nil
	}
	return md.QueryFunc(ctx, dest, query, args...)
}

// BeginTx hands back the mock itself as the "transaction".
func (md *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) { return md, nil }

// CommitTx always succeeds.
func (md *MockDatabase) CommitTx(ctx context.Context) error { return nil }

// RollbackTx always succeeds.
func (md *MockDatabase) RollbackTx(ctx context.Context) error { return nil }

// RunInTransaction delegates to RunInTransactionFunc, or simply invokes fn with
// the mock itself when it is unset.
func (md *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
	if md.RunInTransactionFunc == nil {
		return fn(md)
	}
	return md.RunInTransactionFunc(ctx, fn)
}

// GetUnderlyingDB returns the mock itself.
func (md *MockDatabase) GetUnderlyingDB() interface{} { return md }

// DriverName always reports "postgres".
func (md *MockDatabase) DriverName() string { return "postgres" }
// MockResult is a fixed-value implementation of the common.Result interface
// for tests: it reports exactly the row count and id it was built with.
type MockResult struct {
	rows int64 // value returned by RowsAffected
	id   int64 // value returned by LastInsertId
}

// RowsAffected reports the configured row count.
func (r *MockResult) RowsAffected() int64 { return r.rows }

// LastInsertId reports the configured id and never returns an error.
func (r *MockResult) LastInsertId() (int64, error) { return r.id, nil }
// createTestRequest builds an httptest request for method and path. When
// queryParams is non-nil its entries are set on the URL query; headers are
// copied onto the request; body becomes the payload (empty when nil). The
// request context carries a fixed security.UserContext (UserID 1, "testuser",
// session "test-session-123") stored under security.UserContextKey.
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
	target, _ := url.Parse(path) // parse error ignored: callers pass literal paths
	if queryParams != nil {
		values := target.Query()
		for key, val := range queryParams {
			values.Set(key, val)
		}
		target.RawQuery = values.Encode()
	}

	// bytes.NewReader(nil) is an empty reader, so nil needs no special case.
	req := httptest.NewRequest(method, target.String(), bytes.NewReader(body))
	for key, val := range headers {
		req.Header.Set(key, val)
	}

	user := &security.UserContext{
		UserID:    1,
		UserName:  "testuser",
		SessionID: "test-session-123",
	}
	return req.WithContext(context.WithValue(req.Context(), security.UserContextKey, user))
}
// TestNewHandler checks that NewHandler keeps the supplied database and
// initialises a hook registry.
func TestNewHandler(t *testing.T) {
	mock := &MockDatabase{}
	h := NewHandler(mock)
	if h == nil {
		t.Fatal("Expected handler to be created, got nil")
	}

	if h.db != mock {
		t.Error("Expected handler to have the provided database")
	}
	if h.hooks == nil {
		t.Error("Expected handler to have a hook registry")
	}
}
// TestHandlerHooks checks that Hooks returns a non-nil registry and hands back
// the very same instance on every call.
func TestHandlerHooks(t *testing.T) {
	h := NewHandler(&MockDatabase{})

	first := h.Hooks()
	if first == nil {
		t.Fatal("Expected hooks registry to be non-nil")
	}

	if second := h.Hooks(); first != second {
		t.Error("Expected Hooks() to return the same registry instance")
	}
}
// TestExtractInputVariables checks that extractInputVariables collects every
// [name] placeholder token in order of appearance and returns the SQL text
// unchanged.
func TestExtractInputVariables(t *testing.T) {
	h := NewHandler(&MockDatabase{})

	cases := []struct {
		name  string
		query string
		want  []string
	}{
		{"No variables", "SELECT * FROM users", []string{}},
		{"Single variable", "SELECT * FROM users WHERE id = [user_id]", []string{"[user_id]"}},
		{"Multiple variables", "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]", []string{"[user_id]", "[user_name]"}},
		// A bracketed token inside a quoted literal is still reported.
		{"Nested brackets", "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]", []string{"[field]", "[user_id]"}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			found := make([]string, 0)
			if out := h.extractInputVariables(tc.query, &found); out != tc.query {
				t.Errorf("Expected SQL query to be unchanged, got %s", out)
			}

			if len(found) != len(tc.want) {
				t.Errorf("Expected %d variables, got %d: %v", len(tc.want), len(found), found)
				return
			}
			for idx, want := range tc.want {
				if found[idx] != want {
					t.Errorf("Expected variable %d to be %s, got %s", idx, want, found[idx])
				}
			}
		})
	}
}
// TestValidSQL checks ValidSQL sanitisation in its "colname", "colvalue" and
// "select" modes, including injection-style inputs.
func TestValidSQL(t *testing.T) {
	type sanitizeCase struct {
		name, input, mode, want string
	}

	cases := []sanitizeCase{
		{"Column name with valid characters", "user_id", "colname", "user_id"},
		{"Column name with dots (table.column)", "users.user_id", "colname", "users.user_id"},
		{"Column name with SQL injection attempt", "id'; DROP TABLE users--", "colname", "idDROPTABLEusers"},
		{"Column value with single quotes", "O'Brien", "colvalue", "O''Brien"},
		{"Select with dangerous keywords", "name, email; DROP TABLE users", "select", "name, email TABLE users"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := ValidSQL(c.input, c.mode); got != c.want {
				t.Errorf("ValidSQL(%q, %q) = %q, expected %q", c.input, c.mode, got, c.want)
			}
		})
	}
}
// TestIsNumeric checks IsNumeric against inputs that should be accepted
// (optionally signed, optionally fractional decimals) and inputs that should
// be rejected (letters, multiple dots, empty string).
func TestIsNumeric(t *testing.T) {
	accepted := []string{"123", "123.45", "-123", "-123.45", "0"}
	rejected := []string{"abc", "12.34.56", "", "123abc"}

	// expect runs one subtest asserting IsNumeric(input) == want.
	expect := func(input string, want bool) {
		t.Run(input, func(t *testing.T) {
			if got := IsNumeric(input); got != want {
				t.Errorf("IsNumeric(%q) = %v, expected %v", input, got, want)
			}
		})
	}

	for _, in := range accepted {
		expect(in, true)
	}
	for _, in := range rejected {
		expect(in, false)
	}
}
// TestSqlQryWhere checks that sqlQryWhere adds a condition as a new WHERE
// clause, appends it with AND to an existing one, and places it ahead of a
// trailing ORDER BY, GROUP BY or LIMIT.
func TestSqlQryWhere(t *testing.T) {
	const active = "status = 'active'"

	cases := []struct {
		name, query, cond, want string
	}{
		{"Add WHERE to query without WHERE", "SELECT * FROM users", active, "SELECT * FROM users WHERE status = 'active' "},
		{"Add AND to query with existing WHERE", "SELECT * FROM users WHERE id > 0", active, "SELECT * FROM users WHERE id > 0 AND status = 'active' "},
		{"Add WHERE before ORDER BY", "SELECT * FROM users ORDER BY name", active, "SELECT * FROM users WHERE status = 'active' ORDER BY name"},
		{"Add WHERE before GROUP BY", "SELECT COUNT(*) FROM users GROUP BY department", active, "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department"},
		{"Add WHERE before LIMIT", "SELECT * FROM users LIMIT 10", active, "SELECT * FROM users WHERE status = 'active' LIMIT 10"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := sqlQryWhere(c.query, c.cond); got != c.want {
				t.Errorf("sqlQryWhere() = %q, expected %q", got, c.want)
			}
		})
	}
}
// TestGetIPAddress checks each request source getIPAddress reads from: the
// first X-Forwarded-For entry, the X-Real-IP header, and the raw RemoteAddr
// (port included).
func TestGetIPAddress(t *testing.T) {
	// requestWith returns a factory for a GET /test request customised by tweak.
	requestWith := func(tweak func(r *http.Request)) func() *http.Request {
		return func() *http.Request {
			r := httptest.NewRequest("GET", "/test", nil)
			tweak(r)
			return r
		}
	}

	cases := []struct {
		name    string
		request func() *http.Request
		want    string
	}{
		{
			name: "X-Forwarded-For header",
			request: requestWith(func(r *http.Request) {
				r.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
			}),
			want: "192.168.1.100",
		},
		{
			name: "X-Real-IP header",
			request: requestWith(func(r *http.Request) {
				r.Header.Set("X-Real-IP", "192.168.1.200")
			}),
			want: "192.168.1.200",
		},
		{
			name: "RemoteAddr fallback",
			request: requestWith(func(r *http.Request) {
				r.RemoteAddr = "192.168.1.1:12345"
			}),
			want: "192.168.1.1:12345",
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := getIPAddress(c.request()); got != c.want {
				t.Errorf("getIPAddress() = %q, expected %q", got, c.want)
			}
		})
	}
}
// TestParsePaginationParams checks that parsePaginationParams reads sort,
// limit and offset from the query string and falls back to defaults for
// missing, unparsable or negative values.
func TestParsePaginationParams(t *testing.T) {
	h := NewHandler(&MockDatabase{})

	// Values the handler is expected to fall back to.
	const (
		defaultLimit  = 20
		defaultOffset = 0
	)

	cases := []struct {
		name       string
		params     map[string]string
		wantSort   string
		wantLimit  int
		wantOffset int
	}{
		{"No parameters - defaults", map[string]string{}, "", defaultLimit, defaultOffset},
		{"All parameters provided", map[string]string{"sort": "name,-created_at", "limit": "100", "offset": "50"}, "name,-created_at", 100, 50},
		{"Invalid limit - use default", map[string]string{"limit": "invalid"}, "", defaultLimit, defaultOffset},
		{"Negative offset - use default", map[string]string{"offset": "-10"}, "", defaultLimit, defaultOffset},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := createTestRequest("GET", "/test", tc.params, nil, nil)
			gotSort, gotLimit, gotOffset := h.parsePaginationParams(req)

			if gotSort != tc.wantSort {
				t.Errorf("Expected sort=%q, got %q", tc.wantSort, gotSort)
			}
			if gotLimit != tc.wantLimit {
				t.Errorf("Expected limit=%d, got %d", tc.wantLimit, gotLimit)
			}
			if gotOffset != tc.wantOffset {
				t.Errorf("Expected offset=%d, got %d", tc.wantOffset, gotOffset)
			}
		})
	}
}
// TestSqlQuery drives the single-record SqlQuery handler against a mocked
// database and checks the HTTP status and the decoded JSON object.
func TestSqlQuery(t *testing.T) {
	// mockWithQuery returns a database factory whose transactions receive a
	// fresh MockDatabase that answers every Query call with queryFn.
	mockWithQuery := func(queryFn func(ctx context.Context, dest interface{}, query string, args ...interface{}) error) func() *MockDatabase {
		return func() *MockDatabase {
			return &MockDatabase{
				RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
					return fn(&MockDatabase{QueryFunc: queryFn})
				},
			}
		}
	}

	// decodeObject unmarshals body into a JSON object, failing the test on error.
	decodeObject := func(t *testing.T, body []byte) map[string]interface{} {
		var obj map[string]interface{}
		if err := json.Unmarshal(body, &obj); err != nil {
			t.Fatalf("Failed to unmarshal response: %v", err)
		}
		return obj
	}

	cases := []struct {
		name       string
		query      string
		blank      bool
		params     map[string]string
		headers    map[string]string
		newDB      func() *MockDatabase
		wantStatus int
		check      func(t *testing.T, body []byte)
	}{
		{
			name:  "Basic query - returns single record",
			query: "SELECT * FROM users WHERE id = 1",
			newDB: mockWithQuery(func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
				rows := dest.(*[]map[string]interface{})
				*rows = []map[string]interface{}{
					{"id": float64(1), "name": "Test User", "email": "test@example.com"},
				}
				return nil
			}),
			wantStatus: 200,
			check: func(t *testing.T, body []byte) {
				if got := decodeObject(t, body); got["name"] != "Test User" {
					t.Errorf("Expected name='Test User', got %v", got["name"])
				}
			},
		},
		{
			name:  "Query with no results",
			query: "SELECT * FROM users WHERE id = 999",
			// Leave dest untouched so the handler sees no rows.
			newDB: mockWithQuery(func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
				return nil
			}),
			wantStatus: 200,
			check: func(t *testing.T, body []byte) {
				if got := decodeObject(t, body); len(got) != 0 {
					t.Errorf("Expected empty result, got %v", got)
				}
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			h := NewHandler(tc.newDB())
			req := createTestRequest("GET", "/test", tc.params, tc.headers, nil)
			rec := httptest.NewRecorder()

			serve := h.SqlQuery(tc.query, SqlQueryOptions{BlankParams: tc.blank})
			serve(rec, req)

			if rec.Code != tc.wantStatus {
				t.Errorf("Expected status %d, got %d", tc.wantStatus, rec.Code)
			}
			if tc.check != nil {
				tc.check(t, rec.Body.Bytes())
			}
		})
	}
}
// TestSqlQueryList drives the list handler SqlQueryList against a mocked
// database, checking status, decoded rows and (when counting) the
// Content-Range header.
//
// setupDB receives the subtest's *testing.T so that assertions made inside the
// mock (e.g. "no COUNT query when noCount is set") are reported against the
// subtest that triggered them rather than against the parent test.
func TestSqlQueryList(t *testing.T) {
	tests := []struct {
		name           string
		sqlQuery       string
		noCount        bool
		blankParams    bool
		allowFilter    bool
		queryParams    map[string]string
		headers        map[string]string
		setupDB        func(t *testing.T) *MockDatabase
		expectedStatus int
		validateResp   func(t *testing.T, w *httptest.ResponseRecorder)
	}{
		{
			name:        "Basic list query",
			sqlQuery:    "SELECT * FROM users",
			noCount:     false,
			blankParams: false,
			allowFilter: false,
			setupDB: func(t *testing.T) *MockDatabase {
				return &MockDatabase{
					RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
						db := &MockDatabase{
							QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
								if strings.Contains(query, "COUNT") {
									// Count query: report a total of 2 rows.
									countResult, ok := dest.(*struct{ Count int64 })
									if !ok {
										t.Errorf("Unexpected destination type %T for count query", dest)
										return nil
									}
									countResult.Count = 2
									return nil
								}
								// Main query: return two rows.
								rows, ok := dest.(*[]map[string]interface{})
								if !ok {
									t.Errorf("Unexpected destination type %T for list query", dest)
									return nil
								}
								*rows = []map[string]interface{}{
									{"id": float64(1), "name": "User 1"},
									{"id": float64(2), "name": "User 2"},
								}
								return nil
							},
						}
						return fn(db)
					},
				}
			},
			expectedStatus: 200,
			validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
				var result []map[string]interface{}
				if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
					t.Fatalf("Failed to unmarshal response: %v", err)
				}
				if len(result) != 2 {
					t.Errorf("Expected 2 results, got %d", len(result))
				}
				// Check Content-Range header
				contentRange := w.Header().Get("Content-Range")
				if !strings.Contains(contentRange, "2") {
					t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
				}
			},
		},
		{
			name:        "List query with noCount",
			sqlQuery:    "SELECT * FROM users",
			noCount:     true,
			blankParams: false,
			allowFilter: false,
			setupDB: func(t *testing.T) *MockDatabase {
				return &MockDatabase{
					RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
						db := &MockDatabase{
							QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
								if strings.Contains(query, "COUNT") {
									t.Error("Count query should not be executed when noCount is true")
									// Do not fall through to the row-slice assertion below.
									return nil
								}
								rows, ok := dest.(*[]map[string]interface{})
								if !ok {
									t.Errorf("Unexpected destination type %T for list query", dest)
									return nil
								}
								*rows = []map[string]interface{}{
									{"id": float64(1), "name": "User 1"},
								}
								return nil
							},
						}
						return fn(db)
					},
				}
			},
			expectedStatus: 200,
			validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
				var result []map[string]interface{}
				if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
					t.Fatalf("Failed to unmarshal response: %v", err)
				}
				if len(result) != 1 {
					t.Errorf("Expected 1 result, got %d", len(result))
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Pass the subtest's t so mock assertions are attributed correctly.
			db := tt.setupDB(t)
			handler := NewHandler(db)
			req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
			w := httptest.NewRecorder()
			handlerFunc := handler.SqlQueryList(tt.sqlQuery, SqlQueryOptions{NoCount: tt.noCount, BlankParams: tt.blankParams, AllowFilter: tt.allowFilter})
			handlerFunc(w, req)
			if w.Code != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
			}
			if tt.validateResp != nil {
				tt.validateResp(t, w)
			}
		})
	}
}
// TestMergeQueryParams checks that mergeQueryParams records request query
// parameters in the variables map. It only asserts that a non-empty SQL string
// comes back; the rewritten query itself is not pinned by this test.
func TestMergeQueryParams(t *testing.T) {
	handler := NewHandler(&MockDatabase{})
	tests := []struct {
		name        string
		sqlQuery    string
		queryParams map[string]string
		allowFilter bool
		checkVars   func(t *testing.T, vars map[string]interface{})
	}{
		{
			name:        "Replace placeholder with parameter",
			sqlQuery:    "SELECT * FROM users WHERE id = [user_id]",
			queryParams: map[string]string{"p-user_id": "123"},
			allowFilter: false,
			checkVars: func(t *testing.T, vars map[string]interface{}) {
				if vars["p-user_id"] != "123" {
					t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
				}
			},
		},
		{
			name:        "Add filter when allowed",
			sqlQuery:    "SELECT * FROM users",
			queryParams: map[string]string{"status": "active"},
			allowFilter: true,
			checkVars: func(t *testing.T, vars map[string]interface{}) {
				if vars["status"] != "active" {
					t.Errorf("Expected status=active, got %v", vars["status"])
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
			variables := make(map[string]interface{})
			propQry := make(map[string]string)
			result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
			if result == "" {
				t.Error("Expected non-empty SQL query result")
			}
			if tt.checkVars != nil {
				tt.checkVars(t, variables)
			}
		})
	}
}
// TestMergeHeaderParams checks that mergeHeaderParams records X-FieldFilter-*
// and X-SearchFilter-* request headers in the variables map under lower-cased
// keys. It only asserts that a non-empty SQL string comes back; the rewritten
// query itself is not pinned by this test.
func TestMergeHeaderParams(t *testing.T) {
	handler := NewHandler(&MockDatabase{})
	tests := []struct {
		name      string
		sqlQuery  string
		headers   map[string]string
		checkVars func(t *testing.T, vars map[string]interface{})
	}{
		{
			name:     "Field filter header",
			sqlQuery: "SELECT * FROM users",
			headers:  map[string]string{"X-FieldFilter-Status": "1"},
			checkVars: func(t *testing.T, vars map[string]interface{}) {
				if vars["x-fieldfilter-status"] != "1" {
					t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
				}
			},
		},
		{
			name:     "Search filter header",
			sqlQuery: "SELECT * FROM users",
			headers:  map[string]string{"X-SearchFilter-Name": "john"},
			checkVars: func(t *testing.T, vars map[string]interface{}) {
				if vars["x-searchfilter-name"] != "john" {
					t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := createTestRequest("GET", "/test", nil, tt.headers, nil)
			variables := make(map[string]interface{})
			propQry := make(map[string]string)
			complexAPI := false
			result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
			if result == "" {
				t.Error("Expected non-empty SQL query result")
			}
			if tt.checkVars != nil {
				tt.checkVars(t, variables)
			}
		})
	}
}
// TestReplaceMetaVariables checks that replaceMetaVariables substitutes the
// user/session placeholders ([rid_user], [user], [rid_session], [id_session])
// with values from the supplied security.UserContext, and that the placeholder
// token itself no longer appears in the result.
func TestReplaceMetaVariables(t *testing.T) {
	handler := NewHandler(&MockDatabase{})
	userCtx := &security.UserContext{
		UserID:     123,
		UserName:   "testuser",
		SessionID:  "ABC456",
		SessionRID: 456,
	}
	metainfo := map[string]interface{}{
		"ipaddress": "192.168.1.1",
		"url":       "/api/test",
	}
	variables := map[string]interface{}{
		"param1": "value1",
	}
	tests := []struct {
		name          string
		sqlQuery      string
		placeholder   string
		expectedCheck func(result string) bool
	}{
		{
			name:        "Replace [rid_user]",
			sqlQuery:    "SELECT * FROM users WHERE created_by = [rid_user]",
			placeholder: "[rid_user]",
			expectedCheck: func(result string) bool {
				return strings.Contains(result, "123")
			},
		},
		{
			name:        "Replace [user]",
			sqlQuery:    "SELECT * FROM audit WHERE username = [user]",
			placeholder: "[user]",
			expectedCheck: func(result string) bool {
				return strings.Contains(result, "'testuser'")
			},
		},
		{
			name:        "Replace [rid_session]",
			sqlQuery:    "SELECT * FROM sessions WHERE session_id = [rid_session]",
			placeholder: "[rid_session]",
			expectedCheck: func(result string) bool {
				// SessionID "ABC456" also contains "456", so also make sure the
				// string session id was not substituted for the numeric one.
				return strings.Contains(result, "456") && !strings.Contains(result, "ABC")
			},
		},
		{
			name:        "Replace [id_session]",
			sqlQuery:    "SELECT * FROM sessions WHERE session_id = [id_session]",
			placeholder: "[id_session]",
			expectedCheck: func(result string) bool {
				return strings.Contains(result, "ABC456")
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := createTestRequest("GET", "/test", nil, nil, nil)
			result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
			if strings.Contains(result, tt.placeholder) {
				t.Errorf("Placeholder %s was not replaced. Query: %s", tt.placeholder, result)
			}
			if !tt.expectedCheck(result) {
				t.Errorf("Meta variable replacement failed. Query: %s", result)
			}
		})
	}
}
// TestSqlStripStringLiterals checks that sqlStripStringLiterals empties every
// single-quoted literal (including JSON payloads and literals with doubled ''
// escapes) while leaving the surrounding SQL intact.
func TestSqlStripStringLiterals(t *testing.T) {
	cases := []struct {
		name, in, want string
	}{
		{"No string literals", "SELECT rid, rid_parent FROM users", "SELECT rid, rid_parent FROM users"},
		{"Simple string literal", "SELECT * FROM users WHERE mode = 'admin'", "SELECT * FROM users WHERE mode = ''"},
		{
			"JSON argument containing column names",
			`SELECT rid, rid_parent FROM crm_get_menu(1,'mode', '{"rid_parent":"[rid_parent]","CF:STARTDATE":"[cf_startdate]"}')`,
			`SELECT rid, rid_parent FROM crm_get_menu(1,'', '')`,
		},
		{"Escaped single quotes inside literal", "SELECT * FROM t WHERE name = 'O''Brien'", "SELECT * FROM t WHERE name = ''"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := sqlStripStringLiterals(c.in)
			if got == c.want {
				return
			}
			t.Errorf("sqlStripStringLiterals() =\n %q\nwant\n %q", got, c.want)
		})
	}
}
// TestAllowFilterDoesNotMatchInsideJsonArgument checks that, with AllowFilter
// enabled, mergeQueryParams adds a WHERE clause for output columns (rid_parent,
// description) but not for names that occur only inside the JSON string
// argument of the FROM-clause function (cf_startdate, cf_rid_branch).
func TestAllowFilterDoesNotMatchInsideJsonArgument(t *testing.T) {
	h := NewHandler(&MockDatabase{})
	menuQuery := `select rid, rid_parent, description
from crm_get_menu([rid_user],'[p_mode]', 0, '', '{"rid_parent":"[rid_parent]", "CF:STARTDATE": "[cf_startdate]", "CF:RID_BRANCH": "[cf_rid_branch]"}')`

	// hasWhere reports whether sql contains "where" in any letter case.
	hasWhere := func(sql string) bool {
		return strings.Contains(strings.ToLower(sql), "where")
	}

	cases := []struct {
		name   string
		params map[string]string
		verify func(t *testing.T, sql string)
	}{
		{
			name:   "rid_parent=0 is a real column — filter applied",
			params: map[string]string{"rid_parent": "0"},
			verify: func(t *testing.T, sql string) {
				if !hasWhere(sql) {
					t.Error("Expected WHERE clause to be added for rid_parent")
				}
				// Either half of the null-safe "= 0 OR ... IS NULL" form is accepted.
				nullSafe := strings.Contains(sql, "rid_parent = 0 OR") || strings.Contains(sql, "rid_parent IS NULL")
				if !nullSafe {
					t.Errorf("Expected null-safe filter for rid_parent=0, got:\n%s", sql)
				}
			},
		},
		{
			name:   "cf_startdate only appears in JSON string — no filter applied",
			params: map[string]string{"cf_startdate": "2024-01-01"},
			verify: func(t *testing.T, sql string) {
				if hasWhere(sql) {
					t.Errorf("Expected no WHERE clause for cf_startdate (only in JSON arg), got:\n%s", sql)
				}
			},
		},
		{
			name:   "cf_rid_branch only appears in JSON string — no filter applied",
			params: map[string]string{"cf_rid_branch": "5"},
			verify: func(t *testing.T, sql string) {
				if hasWhere(sql) {
					t.Errorf("Expected no WHERE clause for cf_rid_branch (only in JSON arg), got:\n%s", sql)
				}
			},
		},
		{
			name:   "description is a real column — filter applied",
			params: map[string]string{"description": "test"},
			verify: func(t *testing.T, sql string) {
				if !hasWhere(sql) {
					t.Error("Expected WHERE clause for description")
				}
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := createTestRequest("GET", "/test", tc.params, nil, nil)
			vars := make(map[string]interface{})
			props := make(map[string]string)
			tc.verify(t, h.mergeQueryParams(req, menuQuery, vars, true, props))
		})
	}
}
// TestAllowFilterDoesNotMatchFunctionParams checks that query params which
// appear only as function-call arguments in the FROM clause (e.g.
// [p_rid_doctype]) are not turned into column filters, while a name from the
// SELECT list (rid) is.
func TestAllowFilterDoesNotMatchFunctionParams(t *testing.T) {
	h := NewHandler(&MockDatabase{})
	docMenuQuery := `select rid, rid_parent, description, row_cnt, filterstring, tableprefix, rid_table, tooltip, additionalfilter, haschildren
from crm_get_doc_menu($JQ$[p_tableprefix]$JQ$,[p_rid_parent],[p_rid_doctype],[p_removedup],[p_showall]) r`

	// hasWhere reports whether sql contains "where" in any letter case.
	hasWhere := func(sql string) bool {
		return strings.Contains(strings.ToLower(sql), "where")
	}

	cases := []struct {
		name   string
		params map[string]string
		verify func(t *testing.T, sql string)
	}{
		{
			name:   "p_rid_doctype is a function param, not a column — no filter applied",
			params: map[string]string{"p_rid_doctype": "0"},
			verify: func(t *testing.T, sql string) {
				if hasWhere(sql) {
					t.Errorf("Expected no WHERE clause for p_rid_doctype (function arg, not SELECT column), got:\n%s", sql)
				}
			},
		},
		{
			name:   "p_showall is a function param, not a column — no filter applied",
			params: map[string]string{"p_showall": "1"},
			verify: func(t *testing.T, sql string) {
				if hasWhere(sql) {
					t.Errorf("Expected no WHERE clause for p_showall (function arg, not SELECT column), got:\n%s", sql)
				}
			},
		},
		{
			name:   "rid is a SELECT column — filter applied",
			params: map[string]string{"rid": "42"},
			verify: func(t *testing.T, sql string) {
				if !hasWhere(sql) {
					t.Error("Expected WHERE clause for rid (real SELECT column)")
				}
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := createTestRequest("GET", "/test", tc.params, nil, nil)
			vars := make(map[string]interface{})
			props := make(map[string]string)
			tc.verify(t, h.mergeQueryParams(req, docMenuQuery, vars, true, props))
		})
	}
}
// TestGetReplacementForBlankParamDoubleQuote checks that a placeholder wrapped
// in double quotes (a JSON string value) or single quotes blanks to "" and an
// unquoted one blanks to NULL.
func TestGetReplacementForBlankParamDoubleQuote(t *testing.T) {
	const param = "[myparam]"

	cases := []struct {
		name, query, want string
	}{
		{"Parameter in double quotes (JSON value)", `SELECT * FROM f(1, '{"key":"[myparam]"}')`, ""},
		{"Parameter not in any quotes", `SELECT * FROM f([myparam])`, "NULL"},
		{"Parameter in single quotes", `SELECT * FROM f('[myparam]')`, ""},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := getReplacementForBlankParam(c.query, param); got != c.want {
				t.Errorf("getReplacementForBlankParam() = %q, want %q\nquery: %s", got, c.want, c.query)
			}
		})
	}
}
// TestVariableReplacementFromQueryParams verifies that query params matching
// [placeholder] tokens are substituted even when they don't have the p- prefix,
// and that a placeholder left without a value inside a double-quoted JSON string
// is blanked to an empty string rather than NULL.
func TestVariableReplacementFromQueryParams(t *testing.T) {
	handler := NewHandler(&MockDatabase{})
	sqlQuery := `select rid, rid_parent from crm_get_menu([rid_user],'[p_mode]', 0, '', '{"rid_parent":"[rid_parent]","CF:STARTDATE":"[cf_startdate]"}')`
	tests := []struct {
		name        string
		queryParams map[string]string
		checkResult func(t *testing.T, result string)
	}{
		{
			name:        "rid_parent replaced from query param",
			queryParams: map[string]string{"rid_parent": "42"},
			checkResult: func(t *testing.T, result string) {
				if strings.Contains(result, "[rid_parent]") {
					t.Errorf("Expected [rid_parent] to be replaced, still present in:\n%s", result)
				}
				if !strings.Contains(result, "42") {
					t.Errorf("Expected value 42 in query, got:\n%s", result)
				}
			},
		},
		{
			name:        "cf_startdate replaced from query param",
			queryParams: map[string]string{"cf_startdate": "2024-01-01"},
			checkResult: func(t *testing.T, result string) {
				if strings.Contains(result, "[cf_startdate]") {
					t.Errorf("Expected [cf_startdate] to be replaced, still present in:\n%s", result)
				}
				if !strings.Contains(result, "2024-01-01") {
					t.Errorf("Expected date value in query, got:\n%s", result)
				}
			},
		},
		{
			name:        "missing param blanked to empty string inside JSON (double-quoted)",
			queryParams: map[string]string{},
			checkResult: func(t *testing.T, result string) {
				// [cf_startdate] is surrounded by " in the JSON — should blank to ""
				if strings.Contains(result, "[cf_startdate]") {
					t.Errorf("Expected [cf_startdate] to be blanked, still present in:\n%s", result)
				}
				// Inspect the JSON value the placeholder occupied. Searching for the
				// name "cf_startdate" cannot work: once the placeholder is replaced
				// that text no longer occurs (the key is "CF:STARTDATE"), and the
				// unquoted [rid_user] is expected to become NULL elsewhere in the query.
				if !strings.Contains(result, `"CF:STARTDATE":""`) {
					t.Errorf("Expected empty string (not NULL) for double-quoted placeholder, got:\n%s", result)
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			inputvars := make([]string, 0)
			q := handler.extractInputVariables(sqlQuery, &inputvars)
			req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
			variables := make(map[string]interface{})
			propQry := make(map[string]string)
			q = handler.mergeQueryParams(req, q, variables, false, propQry)
			// Simulate the variable replacement + blank-param loop (mirrors function_api.go)
			for _, kw := range inputvars {
				varName := kw[1 : len(kw)-1]
				if val, ok := variables[varName]; ok {
					if strVal := strings.TrimSpace(val.(string)); strVal != "" {
						q = strings.ReplaceAll(q, kw, ValidSQL(strVal, "colvalue"))
						continue
					}
				}
				replacement := getReplacementForBlankParam(q, kw)
				q = strings.ReplaceAll(q, kw, replacement)
			}
			tt.checkResult(t, q)
		})
	}
}
// TestGetReplacementForBlankParam checks the blank-value chosen for a missing
// placeholder: "" when it sits inside single quotes or a dollar-quoted string
// (with or without a tag), NULL when it is unquoted.
func TestGetReplacementForBlankParam(t *testing.T) {
	cases := []struct {
		name, query, param, want string
	}{
		{"Parameter in single quotes", "SELECT * FROM users WHERE name = '[username]'", "[username]", ""},
		{"Parameter in dollar quotes", "SELECT * FROM users WHERE data = $[jsondata]$", "[jsondata]", ""},
		{"Parameter not in quotes", "SELECT * FROM users WHERE id = [user_id]", "[user_id]", "NULL"},
		{"Parameter not in quotes with AND", "SELECT * FROM users WHERE id = [user_id] AND status = 1", "[user_id]", "NULL"},
		{"Parameter in mixed quote context - before quote", "SELECT * FROM users WHERE id = [user_id] AND name = 'test'", "[user_id]", "NULL"},
		{"Parameter in mixed quote context - in quotes", "SELECT * FROM users WHERE name = '[username]' AND id = 1", "[username]", ""},
		{"Parameter with dollar quote tag", "SELECT * FROM users WHERE body = $tag$[content]$tag$", "[content]", ""},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := getReplacementForBlankParam(c.query, c.param)
			if got != c.want {
				t.Errorf("Expected replacement '%s', got '%s' for query: %s", c.want, got, c.query)
			}
		})
	}
}