diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index 3e4890b..cf8787c 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -84,7 +84,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun // Create local copy to avoid modifying the captured parameter across requests sqlquery := sqlquery - ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute) defer cancel() var dbobjlist []map[string]interface{} @@ -423,7 +423,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp // Create local copy to avoid modifying the captured parameter across requests sqlquery := sqlquery - ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute) defer cancel() propQry := make(map[string]string) @@ -522,10 +522,17 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp if strings.HasPrefix(kLower, "x-fieldfilter-") { colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "") if strings.Contains(strings.ToLower(sqlquery), colname) { - if val == "" || val == "0" { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) - } else { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + switch val { + case "0": + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname"))) + case "": + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname"))) + default: + if IsNumeric(val) { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } else { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } } } } @@ -662,7 +669,10 @@ func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables ma for k, v := range pathVars { kword := fmt.Sprintf("[%s]", k) if strings.Contains(sqlquery, kword) { - sqlquery = strings.ReplaceAll(sqlquery, kword, fmt.Sprintf("%v", v)) + // Sanitize the value before replacing + vStr := fmt.Sprintf("%v", v) + sanitized := ValidSQL(vStr, "colvalue") + sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized) } variables[k] = v @@ -690,7 +700,9 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m // Replace in SQL if placeholder exists if strings.Contains(sqlquery, kword) && len(val) > 0 { if strings.HasPrefix(parmk, "p-") { - sqlquery = strings.ReplaceAll(sqlquery, kword, val) + // Sanitize the parameter value before replacing + sanitized := ValidSQL(val, "colvalue") + sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized) } } @@ -702,15 +714,36 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m // Apply filters if allowed if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) { if len(parmv) > 1 { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ","))) + // Sanitize each value in the IN clause with appropriate quoting + sanitizedValues := make([]string, len(parmv)) + for i, v := range parmv { + if IsNumeric(v) { + // Numeric values don't need quotes + sanitizedValues[i] = ValidSQL(v, "colvalue") + } else { + // String values need quotes + sanitized := ValidSQL(v, "colvalue") + sanitizedValues[i] = fmt.Sprintf("'%s'", sanitized) + } + } + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ","))) } else { if strings.Contains(val, "match=") { colval := strings.ReplaceAll(val, "match=", "") + // Escape single quotes and backslashes for LIKE patterns + // But don't escape wildcards % and _ which are intentional + colval = strings.ReplaceAll(colval, "\\", "\\\\") + colval = strings.ReplaceAll(colval, "'", "''") if colval != "*" { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue"))) + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval)) } } else if val == "" || val == "0" { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) + // For empty/zero values, treat as literal 0 or empty string with quotes + if val == "0" { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = 0 OR %[1]s IS NULL)", ValidSQL(parmk, "colname"))) + } else { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(parmk, "colname"))) + } } else { if IsNumeric(val) { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) @@ -743,16 +776,25 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables kword := fmt.Sprintf("[%s]", k) if strings.Contains(sqlquery, kword) { - sqlquery = strings.ReplaceAll(sqlquery, kword, val) + // Sanitize the header value before replacing + sanitized := ValidSQL(val, "colvalue") + sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized) } // Handle special headers if strings.Contains(k, "x-fieldfilter-") { colname := strings.ReplaceAll(k, "x-fieldfilter-", "") - if val == "" || val == "0" { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) - } else { - sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + switch val { + case "0": + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname"))) + case "": + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname"))) + default: + if IsNumeric(val) { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } else { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } } } @@ -782,12 +824,15 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string { if strings.Contains(sqlquery, "[p_meta_default]") { data, _ := json.Marshal(metainfo) - sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data))) + dataStr := strings.ReplaceAll(string(data), "$META$", "/*META*/") + sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("$META$%s$META$::jsonb", dataStr)) } if strings.Contains(sqlquery, "[json_variables]") { data, _ := json.Marshal(variables) - sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data))) + dataStr := strings.ReplaceAll(string(data), "$VAR$", "/*VAR*/") + + sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("$VAR$%s$VAR$::jsonb", dataStr)) } if strings.Contains(sqlquery, "[rid_user]") { @@ -795,7 +840,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx } if strings.Contains(sqlquery, "[user]") { - sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName)) + sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("$USR$%s$USR$", strings.ReplaceAll(userCtx.UserName, "$USR$", "/*USR*/"))) } if strings.Contains(sqlquery, "[rid_session]") { @@ -806,7 +851,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx } if strings.Contains(sqlquery, "[method]") { - sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method) + sqlquery = strings.ReplaceAll(sqlquery, "[method]", fmt.Sprintf("$M$%s$M$", strings.ReplaceAll(r.Method, "$M$", "/*M*/"))) } if strings.Contains(sqlquery, "[post_body]") { @@ -819,7 +864,7 @@ func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx } } } - sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr)) + sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("$PBODY$%s$PBODY$", strings.ReplaceAll(bodystr, "$PBODY$", "/*PBODY*/"))) } return sqlquery @@ -859,19 +904,23 @@ func ValidSQL(input, mode string) string { reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`) return reg.ReplaceAllString(input, "") case "colvalue": - // For column values, escape single quotes - return strings.ReplaceAll(input, "'", "''") + // For column values, escape single quotes and backslashes + // Note: Backslashes must be escaped first, then single quotes + result := strings.ReplaceAll(input, "\\", "\\\\") + result = strings.ReplaceAll(result, "'", "''") + return result case "select": // For SELECT clauses, be more permissive but still safe - // Remove semicolons and common SQL injection patterns - dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "} - result := input - for _, d := range dangerous { - result = strings.ReplaceAll(result, d, "") - result = strings.ReplaceAll(result, strings.ToLower(d), "") - result = strings.ReplaceAll(result, strings.ToUpper(d), "") + // Remove semicolons and common SQL injection patterns (case-insensitive) + dangerous := []string{ + ";", "--", "/\\*", "\\*/", "xp_", "sp_", + "drop ", "delete ", "truncate ", "update ", "insert ", + "exec ", "execute ", "union ", "declare ", "alter ", "create ", } - return result + // Build a single regex pattern with all dangerous keywords + pattern := "(?i)(" + strings.Join(dangerous, "|") + ")" + re := regexp.MustCompile(pattern) + return re.ReplaceAllString(input, "") default: return input }