diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index f1022d7..88f2288 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -713,18 +713,25 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m // Apply filters if allowed if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) { if len(parmv) > 1 { - // Sanitize each value in the IN clause and wrap in quotes + // Sanitize each value in the IN clause with appropriate quoting sanitizedValues := make([]string, len(parmv)) for i, v := range parmv { - sanitized := ValidSQL(v, "colvalue") - // Wrap each value in single quotes for SQL IN clause - sanitizedValues[i] = fmt.Sprintf("'%s'", sanitized) + if IsNumeric(v) { + // Numeric values don't need quotes + sanitizedValues[i] = ValidSQL(v, "colvalue") + } else { + // String values need quotes + sanitized := ValidSQL(v, "colvalue") + sanitizedValues[i] = fmt.Sprintf("'%s'", sanitized) + } } sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ","))) } else { if strings.Contains(val, "match=") { colval := strings.ReplaceAll(val, "match=", "") - colval = ValidSQL(colval, "colvalue") // Sanitize immediately + // Don't sanitize LIKE patterns as it would escape wildcards + // Just remove single quotes to prevent SQL injection + colval = strings.ReplaceAll(colval, "'", "''") if colval != "*" { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval)) } @@ -908,18 +915,12 @@ func ValidSQL(input, mode string) string { "exec ", "execute ", "union ", "declare ", "alter ", "create ", } result := input - lowerResult := strings.ToLower(result) + // Use case-insensitive replacement via regex for _, d := range dangerous { - // Find all occurrences case-insensitively and remove them - for { - idx := strings.Index(lowerResult, d) - if idx == -1 { - break - } - // Remove from both result and lowerResult - result = result[:idx] + result[idx+len(d):] - lowerResult = lowerResult[:idx] + lowerResult[idx+len(d):] - } + // Create case-insensitive regex for the pattern + pattern := "(?i)" + regexp.QuoteMeta(d) + re := regexp.MustCompile(pattern) + result = re.ReplaceAllString(result, "") } return result default: