From f711bf38d2fe009dd55f4a659524e6cfa616dce4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 31 Dec 2025 07:19:53 +0000 Subject: [PATCH] fix: Enhanced SQL injection protection in funcspec - Added sanitization for path parameters in mergePathParams - Added sanitization for query parameters with p- prefix in mergeQueryParams - Added sanitization for header parameters in mergeHeaderParams - Fixed IN clause to sanitize all values individually - Improved ValidSQL function with better escaping and more injection patterns - Added backslash escaping to colvalue mode - Extended dangerous keyword list in select mode Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com> --- pkg/funcspec/function_api.go | 43 ++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index 5577cae..82a41df 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -662,7 +662,10 @@ func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables ma for k, v := range pathVars { kword := fmt.Sprintf("[%s]", k) if strings.Contains(sqlquery, kword) { - sqlquery = strings.ReplaceAll(sqlquery, kword, fmt.Sprintf("%v", v)) + // Sanitize the value before replacing + vStr := fmt.Sprintf("%v", v) + sanitized := ValidSQL(vStr, "colvalue") + sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized) } variables[k] = v @@ -690,7 +693,9 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m // Replace in SQL if placeholder exists if strings.Contains(sqlquery, kword) && len(val) > 0 { if strings.HasPrefix(parmk, "p-") { - sqlquery = strings.ReplaceAll(sqlquery, kword, val) + // Sanitize the parameter value before replacing + sanitized := ValidSQL(val, "colvalue") + sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized) } } @@ -702,7 +707,12 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m // Apply filters if allowed if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) { if len(parmv) > 1 { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ","))) + // Sanitize each value in the IN clause + sanitizedValues := make([]string, len(parmv)) + for i, v := range parmv { + sanitizedValues[i] = ValidSQL(v, "colvalue") + } + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ","))) } else { if strings.Contains(val, "match=") { colval := strings.ReplaceAll(val, "match=", "") @@ -743,7 +753,9 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables kword := fmt.Sprintf("[%s]", k) if strings.Contains(sqlquery, kword) { - sqlquery = strings.ReplaceAll(sqlquery, kword, val) + // Sanitize the header value before replacing + sanitized := ValidSQL(val, "colvalue") + sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized) } // Handle special headers @@ -862,17 +874,30 @@ func ValidSQL(input, mode string) string { reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`) return reg.ReplaceAllString(input, "") case "colvalue": - // For column values, escape single quotes - return strings.ReplaceAll(input, "'", "''") + // For column values, escape single quotes and backslashes + result := strings.ReplaceAll(input, "\\", "\\\\") + result = strings.ReplaceAll(result, "'", "''") + return result case "select": // For SELECT clauses, be more permissive but still safe // Remove semicolons and common SQL injection patterns - dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "} + dangerous := []string{ + ";", "--", "/*", "*/", "xp_", "sp_", + "DROP ", "drop ", "Drop ", + "DELETE ", "delete ", "Delete ", + "TRUNCATE ", "truncate ", "Truncate ", + "UPDATE ", "update ", "Update ", + "INSERT ", "insert ", "Insert ", + "EXEC ", "exec ", "Exec ", + "EXECUTE ", "execute ", "Execute ", + "UNION ", "union ", "Union ", + "DECLARE ", "declare ", "Declare ", + "ALTER ", "alter ", "Alter ", + "CREATE ", "create ", "Create ", + } result := input for _, d := range dangerous { result = strings.ReplaceAll(result, d, "") - result = strings.ReplaceAll(result, strings.ToLower(d), "") - result = strings.ReplaceAll(result, strings.ToUpper(d), "") } return result default: