diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index 5c3c59c..a5231a9 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -28,18 +28,16 @@ type Handler struct { } type SqlQueryOptions struct { - NoCount bool - BlankParams bool - AllowFilter bool - AllowQueryParamFilters bool + NoCount bool + BlankParams bool + AllowFilter bool } func NewSqlQueryOptions() SqlQueryOptions { return SqlQueryOptions{ - NoCount: false, - BlankParams: true, - AllowFilter: true, - AllowQueryParamFilters: false, + NoCount: false, + BlankParams: true, + AllowFilter: true, } } @@ -140,11 +138,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun // Merge query string parameters sqlquery = h.mergeQueryParams(r, sqlquery, variables, options.AllowFilter, propQry) - // Apply p_-prefixed query params as field filters - if options.AllowQueryParamFilters { - sqlquery = h.applyQueryParamFilters(r, sqlquery) - } - // Merge header parameters sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) @@ -488,11 +481,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp // Merge query string parameters sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry) - // Apply p_-prefixed query params as field filters - if options.AllowQueryParamFilters { - sqlquery = h.applyQueryParamFilters(r, sqlquery) - } - // Merge header parameters sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) hookCtx.ComplexAPI = complexAPI @@ -741,8 +729,9 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m propQry[parmk] = val } - // Apply filters if allowed - if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) { + // Apply filters if allowed — check against string-literal-stripped SQL to avoid + // matching column names that only appear inside quoted arguments (e.g. JSON strings) + if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlStripStringLiterals(sqlquery)), strings.ToLower(parmk)) { if len(parmv) > 1 { // Sanitize each value in the IN clause with appropriate quoting sanitizedValues := make([]string, len(parmv)) @@ -858,35 +847,6 @@ func sqlStripStringLiterals(sql string) string { return re.ReplaceAllString(sql, "''") } -// applyQueryParamFilters applies query parameters as SQL field filters when the param name -// appears as a structural identifier in the SQL (not inside a string literal). -// e.g. ?rid_parent=0 → (rid_parent = 0 OR rid_parent IS NULL) -func (h *Handler) applyQueryParamFilters(r *http.Request, sqlquery string) string { - sqlStructure := strings.ToLower(sqlStripStringLiterals(sqlquery)) - for parmk, parmv := range r.URL.Query() { - if len(parmv) == 0 || !strings.Contains(sqlStructure, strings.ToLower(parmk)) { - continue - } - val := parmv[0] - dec, err := restheadspec.DecodeParam(val) - if err == nil { - val = dec - } - col := ValidSQL(parmk, "colname") - switch { - case val == "0": - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = 0 OR %[1]s IS NULL)", col)) - case val == "": - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", col)) - case IsNumeric(val): - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", col, ValidSQL(val, "colvalue"))) - default: - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", col, ValidSQL(val, "colvalue"))) - } - } - return sqlquery -} - // replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string { if strings.Contains(sqlquery, "[p_meta_default]") { diff --git a/pkg/funcspec/function_api_test.go b/pkg/funcspec/function_api_test.go index c90bf90..505a25b 100644 --- a/pkg/funcspec/function_api_test.go +++ b/pkg/funcspec/function_api_test.go @@ -851,6 +851,231 @@ func TestReplaceMetaVariables(t *testing.T) { } } +// TestSqlStripStringLiterals tests that single-quoted string literals are removed +func TestSqlStripStringLiterals(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "No string literals", + input: "SELECT rid, rid_parent FROM users", + expected: "SELECT rid, rid_parent FROM users", + }, + { + name: "Simple string literal", + input: "SELECT * FROM users WHERE mode = 'admin'", + expected: "SELECT * FROM users WHERE mode = ''", + }, + { + name: "JSON argument containing column names", + input: `SELECT rid, rid_parent FROM crm_get_menu(1,'mode', '{"rid_parent":"[rid_parent]","CF:STARTDATE":"[cf_startdate]"}')`, + expected: `SELECT rid, rid_parent FROM crm_get_menu(1,'', '')`, + }, + { + name: "Escaped single quotes inside literal", + input: "SELECT * FROM t WHERE name = 'O''Brien'", + expected: "SELECT * FROM t WHERE name = ''", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sqlStripStringLiterals(tt.input) + if result != tt.expected { + t.Errorf("sqlStripStringLiterals() =\n %q\nwant\n %q", result, tt.expected) + } + }) + } +} + +// TestAllowFilterDoesNotMatchInsideJsonArgument verifies that AllowFilter will add WHERE +// clauses for real output columns (rid, rid_parent) but not for names that only appear +// inside a JSON string argument (cf_startdate, cf_rid_branch). +func TestAllowFilterDoesNotMatchInsideJsonArgument(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + sqlQuery := `select rid, rid_parent, description + from crm_get_menu([rid_user],'[p_mode]', 0, '', '{"rid_parent":"[rid_parent]", "CF:STARTDATE": "[cf_startdate]", "CF:RID_BRANCH": "[cf_rid_branch]"}')` + + tests := []struct { + name string + queryParams map[string]string + checkResult func(t *testing.T, result string) + }{ + { + name: "rid_parent=0 is a real column — filter applied", + queryParams: map[string]string{"rid_parent": "0"}, + checkResult: func(t *testing.T, result string) { + if !strings.Contains(strings.ToLower(result), "where") { + t.Error("Expected WHERE clause to be added for rid_parent") + } + if !strings.Contains(result, "rid_parent = 0 OR") && !strings.Contains(result, "rid_parent IS NULL") { + t.Errorf("Expected null-safe filter for rid_parent=0, got:\n%s", result) + } + }, + }, + { + name: "cf_startdate only appears in JSON string — no filter applied", + queryParams: map[string]string{"cf_startdate": "2024-01-01"}, + checkResult: func(t *testing.T, result string) { + if strings.Contains(strings.ToLower(result), "where") { + t.Errorf("Expected no WHERE clause for cf_startdate (only in JSON arg), got:\n%s", result) + } + }, + }, + { + name: "cf_rid_branch only appears in JSON string — no filter applied", + queryParams: map[string]string{"cf_rid_branch": "5"}, + checkResult: func(t *testing.T, result string) { + if strings.Contains(strings.ToLower(result), "where") { + t.Errorf("Expected no WHERE clause for cf_rid_branch (only in JSON arg), got:\n%s", result) + } + }, + }, + { + name: "description is a real column — filter applied", + queryParams: map[string]string{"description": "test"}, + checkResult: func(t *testing.T, result string) { + if !strings.Contains(strings.ToLower(result), "where") { + t.Error("Expected WHERE clause for description") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := createTestRequest("GET", "/test", tt.queryParams, nil, nil) + variables := make(map[string]interface{}) + propQry := make(map[string]string) + + result := handler.mergeQueryParams(req, sqlQuery, variables, true, propQry) + tt.checkResult(t, result) + }) + } +} + +// TestGetReplacementForBlankParamDoubleQuote verifies that placeholders surrounded by +// double quotes (as in JSON string values) are blanked to "" not NULL. +func TestGetReplacementForBlankParamDoubleQuote(t *testing.T) { + tests := []struct { + name string + sqlQuery string + param string + expected string + }{ + { + name: "Parameter in double quotes (JSON value)", + sqlQuery: `SELECT * FROM f(1, '{"key":"[myparam]"}')`, + param: "[myparam]", + expected: "", + }, + { + name: "Parameter not in any quotes", + sqlQuery: `SELECT * FROM f([myparam])`, + param: "[myparam]", + expected: "NULL", + }, + { + name: "Parameter in single quotes", + sqlQuery: `SELECT * FROM f('[myparam]')`, + param: "[myparam]", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getReplacementForBlankParam(tt.sqlQuery, tt.param) + if result != tt.expected { + t.Errorf("getReplacementForBlankParam() = %q, want %q\nquery: %s", result, tt.expected, tt.sqlQuery) + } + }) + } +} + +// TestVariableReplacementFromQueryParams verifies that query params matching [placeholder] +// tokens are substituted even when they don't have the p- prefix. +func TestVariableReplacementFromQueryParams(t *testing.T) { + handler := NewHandler(&MockDatabase{}) + + sqlQuery := `select rid, rid_parent from crm_get_menu([rid_user],'[p_mode]', 0, '', '{"rid_parent":"[rid_parent]","CF:STARTDATE":"[cf_startdate]"}')` + + tests := []struct { + name string + queryParams map[string]string + checkResult func(t *testing.T, result string) + }{ + { + name: "rid_parent replaced from query param", + queryParams: map[string]string{"rid_parent": "42"}, + checkResult: func(t *testing.T, result string) { + if strings.Contains(result, "[rid_parent]") { + t.Errorf("Expected [rid_parent] to be replaced, still present in:\n%s", result) + } + if !strings.Contains(result, "42") { + t.Errorf("Expected value 42 in query, got:\n%s", result) + } + }, + }, + { + name: "cf_startdate replaced from query param", + queryParams: map[string]string{"cf_startdate": "2024-01-01"}, + checkResult: func(t *testing.T, result string) { + if strings.Contains(result, "[cf_startdate]") { + t.Errorf("Expected [cf_startdate] to be replaced, still present in:\n%s", result) + } + if !strings.Contains(result, "2024-01-01") { + t.Errorf("Expected date value in query, got:\n%s", result) + } + }, + }, + { + name: "missing param blanked to empty string inside JSON (double-quoted)", + queryParams: map[string]string{}, + checkResult: func(t *testing.T, result string) { + // [cf_startdate] is surrounded by " in the JSON — should blank to "" + if strings.Contains(result, "[cf_startdate]") { + t.Errorf("Expected [cf_startdate] to be blanked, still present in:\n%s", result) + } + if strings.Contains(result, "NULL") && strings.Contains(result, "cf_startdate") { + t.Errorf("Expected empty string (not NULL) for double-quoted placeholder, got:\n%s", result) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputvars := make([]string, 0) + q := handler.extractInputVariables(sqlQuery, &inputvars) + + req := createTestRequest("GET", "/test", tt.queryParams, nil, nil) + variables := make(map[string]interface{}) + propQry := make(map[string]string) + + q = handler.mergeQueryParams(req, q, variables, false, propQry) + + // Simulate the variable replacement + blank-param loop (mirrors function_api.go) + for _, kw := range inputvars { + varName := kw[1 : len(kw)-1] + if val, ok := variables[varName]; ok { + if strVal := strings.TrimSpace(val.(string)); strVal != "" { + q = strings.ReplaceAll(q, kw, ValidSQL(strVal, "colvalue")) + continue + } + } + replacement := getReplacementForBlankParam(q, kw) + q = strings.ReplaceAll(q, kw, replacement) + } + + tt.checkResult(t, q) + }) + } +} + // TestGetReplacementForBlankParam tests the blank parameter replacement logic func TestGetReplacementForBlankParam(t *testing.T) { tests := []struct {