Files
ResolveSpec/pkg/funcspec/function_api.go
copilot-swe-agent[bot] f711bf38d2 fix: Enhanced SQL injection protection in funcspec
- Added sanitization for path parameters in mergePathParams
- Added sanitization for query parameters with p- prefix in mergeQueryParams
- Added sanitization for header parameters in mergeHeaderParams
- Fixed IN clause to sanitize all values individually
- Improved ValidSQL function with better escaping and more injection patterns
- Added backslash escaping to colvalue mode
- Extended dangerous keyword list in select mode

Co-authored-by: warkanum <208308+warkanum@users.noreply.github.com>
2025-12-31 07:19:53 +00:00

1085 lines
33 KiB
Go

package funcspec
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// Handler handles function-based SQL API requests
type Handler struct {
db common.Database
hooks *HookRegistry
variablesCallback func(r *http.Request) map[string]interface{}
}
type SqlQueryOptions struct {
NoCount bool
BlankParams bool
AllowFilter bool
}
func NewSqlQueryOptions() SqlQueryOptions {
return SqlQueryOptions{
NoCount: false,
BlankParams: true,
AllowFilter: true,
}
}
// NewHandler creates a new function API handler
func NewHandler(db common.Database) *Handler {
return &Handler{
db: db,
hooks: NewHookRegistry(),
}
}
// GetDatabase returns the underlying database connection
// Implements common.SpecHandler interface
func (h *Handler) GetDatabase() common.Database {
return h.db
}
func (h *Handler) SetVariablesCallback(callback func(r *http.Request) map[string]interface{}) {
h.variablesCallback = callback
}
func (h *Handler) GetVariablesCallback() func(r *http.Request) map[string]interface{} {
return h.variablesCallback
}
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// HTTPFuncType is a function type for HTTP handlers
type HTTPFuncType func(http.ResponseWriter, *http.Request)
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFuncType {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
stack := debug.Stack()
logger.Error("Panic in SqlQueryList: %v\nStack trace:\n%s", err, string(stack))
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
}
}()
// Create local copy to avoid modifying the captured parameter across requests
sqlquery := sqlquery
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
defer cancel()
var dbobjlist []map[string]interface{}
var total int64
propQry := make(map[string]string)
inputvars := make([]string, 0)
metainfo := make(map[string]interface{})
variables := make(map[string]interface{})
complexAPI := false
// Get user context from security package
userCtx, ok := security.GetUserContext(ctx)
if !ok {
logger.Warn("No user context found in request")
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
}
w.Header().Set("Content-Type", "application/json")
// Initialize hook context
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Request: r,
Writer: w,
SQLQuery: sqlquery,
Variables: variables,
InputVars: inputvars,
MetaInfo: metainfo,
PropQry: propQry,
UserContext: userCtx,
NoCount: options.NoCount,
BlankParams: options.BlankParams,
AllowFilter: options.AllowFilter,
ComplexAPI: complexAPI,
}
// Execute BeforeQueryList hook
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
logger.Error("BeforeQueryList hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Check if hook aborted the operation
if hookCtx.Abort {
if hookCtx.AbortCode == 0 {
hookCtx.AbortCode = http.StatusBadRequest
}
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
return
}
// Use potentially modified SQL query and variables from hooks
sqlquery = hookCtx.SQLQuery
variables = hookCtx.Variables
// complexAPI = hookCtx.ComplexAPI
// Extract input variables from SQL query (placeholders like [variable])
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
// Merge URL path parameters
sqlquery = h.mergePathParams(r, sqlquery, variables)
// Parse comprehensive parameters from headers and query string
reqParams := h.ParseParameters(r)
complexAPI = reqParams.ComplexAPI
// Merge query string parameters
sqlquery = h.mergeQueryParams(r, sqlquery, variables, options.AllowFilter, propQry)
// Merge header parameters
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
if !options.AllowFilter {
sqlquery = h.ApplyFilters(sqlquery, reqParams)
}
// Apply field selection
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
// Apply DISTINCT if requested
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
// Override pNoCount if skipcount is specified
if reqParams.SkipCount {
options.NoCount = true
}
// Build metainfo
metainfo["ipaddress"] = getIPAddress(r)
metainfo["url"] = r.RequestURI
metainfo["user"] = userCtx.UserName
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
metainfo["method"] = r.Method
metainfo["variables"] = variables
// Replace meta variables in SQL
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
// Remove unused input variables
if options.BlankParams {
for _, kw := range inputvars {
replacement := getReplacementForBlankParam(sqlquery, kw)
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
}
}
// Update hook context with latest SQL query and variables
hookCtx.SQLQuery = sqlquery
hookCtx.Variables = variables
hookCtx.InputVars = inputvars
// Execute query within transaction
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
sqlqueryCnt := sqlquery
// Parse sorting and pagination parameters
sortcols, limit, offset := h.parsePaginationParams(r)
// Override with parsed parameters if available
if reqParams.SortColumns != "" {
sortcols = reqParams.SortColumns
}
if reqParams.Limit > 0 {
limit = reqParams.Limit
}
if reqParams.Offset > 0 {
offset = reqParams.Offset
}
hookCtx.SortColumns = sortcols
hookCtx.Limit = limit
hookCtx.Offset = offset
fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
if len(sortcols) > 0 && (orderbyPos < 0 || (orderbyPos > 0 && orderbyPos < fromPos)) {
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
}
if !options.NoCount {
if limit > 0 && offset > 0 {
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
} else if limit > 0 {
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, limit)
} else {
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, 20000)
}
// Get total count
countQuery := fmt.Sprintf("SELECT COUNT(1) FROM (%s) cnts", sqlqueryCnt)
var countResult struct{ Count int64 }
if err := tx.Query(ctx, &countResult, countQuery); err != nil {
sendError(w, http.StatusBadRequest, "count_failed", "Failed to retrieve record count", err)
return err
}
total = countResult.Count
}
// Execute BeforeSQLExec hook
hookCtx.SQLQuery = sqlquery
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
logger.Error("BeforeSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified SQL query from hook
sqlquery = hookCtx.SQLQuery
// Execute main query
rows := make([]map[string]interface{}, 0)
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
return err
}
// Normalize PostgreSQL types for proper JSON marshaling
dbobjlist = normalizePostgresTypesList(rows)
if options.NoCount {
total = int64(len(dbobjlist))
}
// Execute AfterSQLExec hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
logger.Error("AfterSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
total = hookCtx.Total
return nil
})
if err != nil {
logger.Error("Transaction failed: %v", err)
return
}
// Execute AfterQueryList hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
hookCtx.Error = err
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
logger.Error("AfterQueryList hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
total = hookCtx.Total
// Set response headers
respOffset := 0
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil {
respOffset = o
}
}
w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
logger.Info("Serving: Records %d of %d", len(dbobjlist), total)
// Execute BeforeResponse hook
hookCtx.Result = dbobjlist
hookCtx.Total = total
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
logger.Error("BeforeResponse hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
dbobjlist = modifiedResult
}
if len(dbobjlist) == 0 {
_, _ = w.Write([]byte("[]"))
return
}
// Format response based on response format
switch reqParams.ResponseFormat {
case "syncfusion":
// Syncfusion format: { result: data, count: total }
response := map[string]interface{}{
"result": dbobjlist,
"count": total,
}
data, err := json.Marshal(response)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
case "detail":
// Detail format: complex API with metadata
metaobj := map[string]interface{}{
"items": dbobjlist,
"count": fmt.Sprintf("%d", len(dbobjlist)),
"total": fmt.Sprintf("%d", total),
"tablename": r.URL.Path,
"tableprefix": "gsql",
}
data, err := json.Marshal(metaobj)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
default:
// Simple format: just return the data array (or complex API if requested)
if complexAPI {
metaobj := map[string]interface{}{
"items": dbobjlist,
"count": fmt.Sprintf("%d", len(dbobjlist)),
"total": fmt.Sprintf("%d", total),
"tablename": r.URL.Path,
"tableprefix": "gsql",
}
data, err := json.Marshal(metaobj)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
} else {
data, err := json.Marshal(dbobjlist)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
if int64(len(dbobjlist)) < total {
w.WriteHeader(http.StatusPartialContent)
}
_, _ = w.Write(data)
}
}
}
}
}
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncType {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
stack := debug.Stack()
logger.Error("Panic in SqlQuery: %v\nStack trace:\n%s", err, string(stack))
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
}
}()
// Create local copy to avoid modifying the captured parameter across requests
sqlquery := sqlquery
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
defer cancel()
propQry := make(map[string]string)
inputvars := make([]string, 0)
metainfo := make(map[string]interface{})
variables := make(map[string]interface{})
dbobj := make(map[string]interface{})
complexAPI := false
// Get user context from security package
userCtx, ok := security.GetUserContext(ctx)
if !ok {
logger.Warn("No user context found in request")
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
}
w.Header().Set("Content-Type", "application/json")
// Initialize hook context
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Request: r,
Writer: w,
SQLQuery: sqlquery,
Variables: variables,
InputVars: inputvars,
MetaInfo: metainfo,
PropQry: propQry,
UserContext: userCtx,
BlankParams: options.BlankParams,
ComplexAPI: complexAPI,
}
// Execute BeforeQuery hook
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
logger.Error("BeforeQuery hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Check if hook aborted the operation
if hookCtx.Abort {
if hookCtx.AbortCode == 0 {
hookCtx.AbortCode = http.StatusBadRequest
}
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
return
}
// Use potentially modified SQL query and variables from hooks
sqlquery = hookCtx.SQLQuery
variables = hookCtx.Variables
// Extract input variables from SQL query
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
// Merge URL path parameters
sqlquery = h.mergePathParams(r, sqlquery, variables)
// Parse comprehensive parameters from headers and query string
reqParams := h.ParseParameters(r)
complexAPI = reqParams.ComplexAPI
// Merge query string parameters
sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)
// Merge header parameters
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
hookCtx.ComplexAPI = complexAPI
// Apply filters from parsed parameters
sqlquery = h.ApplyFilters(sqlquery, reqParams)
// Apply field selection
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
// Apply DISTINCT if requested
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
// Build metainfo
metainfo["ipaddress"] = getIPAddress(r)
metainfo["url"] = r.RequestURI
metainfo["user"] = userCtx.UserName
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
metainfo["method"] = r.Method
metainfo["variables"] = variables
// Replace meta variables in SQL
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
// Apply field filters from headers
for k, val := range propQry {
kLower := strings.ToLower(k)
if strings.HasPrefix(kLower, "x-fieldfilter-") {
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
if strings.Contains(strings.ToLower(sqlquery), colname) {
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
}
}
// Remove unused input variables
if options.BlankParams {
for _, kw := range inputvars {
replacement := getReplacementForBlankParam(sqlquery, kw)
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
}
}
// Update hook context with latest SQL query and variables
hookCtx.SQLQuery = sqlquery
hookCtx.Variables = variables
hookCtx.InputVars = inputvars
// Execute query within transaction
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Execute BeforeSQLExec hook
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
logger.Error("BeforeSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified SQL query from hook
sqlquery = hookCtx.SQLQuery
// Execute main query
rows := make([]map[string]interface{}, 0)
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
return err
}
if len(rows) > 0 {
dbobj = normalizePostgresTypes(rows[0])
}
// Execute AfterSQLExec hook
hookCtx.Result = dbobj
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
logger.Error("AfterSQLExec hook failed: %v", err)
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return err
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
return nil
})
if err != nil {
logger.Error("Transaction failed: %v", err)
return
}
// Execute AfterQuery hook
hookCtx.Result = dbobj
hookCtx.Error = err
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
logger.Error("AfterQuery hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
// Execute BeforeResponse hook
hookCtx.Result = dbobj
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
logger.Error("BeforeResponse hook failed: %v", err)
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified result from hook
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
dbobj = modifiedResult
}
// Check if response should be root-level data
if val, ok := dbobj["root_as_data"]; ok {
data, err := json.Marshal(val)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
_, _ = w.Write(data)
}
return
}
// Marshal and send response
data, err := json.Marshal(dbobj)
if err != nil {
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
} else {
_, _ = w.Write(data)
}
}
}
// Helper functions
// extractInputVariables extracts placeholders like [variable] from the SQL query
func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) string {
testsqlquery := sqlquery
for i := 0; i <= strings.Count(sqlquery, "[")*4; i++ {
iStart := strings.Index(testsqlquery, "[")
if iStart < 0 {
break
}
iEnd := strings.Index(testsqlquery, "]")
if iEnd < 0 {
break
}
*inputvars = append(*inputvars, testsqlquery[iStart:iEnd+1])
testsqlquery = testsqlquery[iEnd+1:]
}
return sqlquery
}
// mergePathParams merges URL path parameters into the SQL query
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
if h.GetVariablesCallback() != nil {
pathVars := h.GetVariablesCallback()(r)
for k, v := range pathVars {
kword := fmt.Sprintf("[%s]", k)
if strings.Contains(sqlquery, kword) {
// Sanitize the value before replacing
vStr := fmt.Sprintf("%v", v)
sanitized := ValidSQL(vStr, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
variables[k] = v
}
}
return sqlquery
}
// mergeQueryParams merges query string parameters into the SQL query
func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables map[string]interface{}, allowFilter bool, propQry map[string]string) string {
for parmk, parmv := range r.URL.Query() {
if len(parmk) == 0 || len(parmv) == 0 {
continue
}
val := parmv[0]
dec, err := restheadspec.DecodeParam(val)
if err == nil {
val = dec
}
kword := fmt.Sprintf("[%s]", parmk)
variables[parmk] = val
// Replace in SQL if placeholder exists
if strings.Contains(sqlquery, kword) && len(val) > 0 {
if strings.HasPrefix(parmk, "p-") {
// Sanitize the parameter value before replacing
sanitized := ValidSQL(val, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
}
// Add to propQry for x- prefixed params
if strings.HasPrefix(parmk, "x-") {
propQry[parmk] = val
}
// Apply filters if allowed
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
if len(parmv) > 1 {
// Sanitize each value in the IN clause
sanitizedValues := make([]string, len(parmv))
for i, v := range parmv {
sanitizedValues[i] = ValidSQL(v, "colvalue")
}
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ",")))
} else {
if strings.Contains(val, "match=") {
colval := strings.ReplaceAll(val, "match=", "")
if colval != "*" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
}
} else if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
} else {
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
}
}
}
}
}
return sqlquery
}
// mergeHeaderParams merges HTTP header parameters into the SQL query
func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables map[string]interface{}, propQry map[string]string, complexAPI *bool) string {
for kc, v := range r.Header {
k := strings.ToLower(kc)
if !strings.HasPrefix(k, "x-") || len(v) == 0 {
continue
}
val := v[0]
dec, err := restheadspec.DecodeParam(val)
if err == nil {
val = dec
}
variables[k] = val
propQry[k] = val
kword := fmt.Sprintf("[%s]", k)
if strings.Contains(sqlquery, kword) {
// Sanitize the header value before replacing
sanitized := ValidSQL(val, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
// Handle special headers
if strings.Contains(k, "x-fieldfilter-") {
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
if strings.Contains(k, "x-searchfilter-") {
colname := strings.ReplaceAll(k, "x-searchfilter-", "")
sval := strings.ReplaceAll(val, "'", "")
if sval != "" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
}
}
if strings.Contains(k, "x-custom-sql-w") {
colval := ValidSQL(val, "select")
if len(colval) > 0 {
sqlquery = sqlQryWhere(sqlquery, colval)
}
}
if strings.Contains(k, "x-simpleapi") {
*complexAPI = !strings.EqualFold(val, "1") && !strings.EqualFold(val, "true")
}
}
return sqlquery
}
// replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
if strings.Contains(sqlquery, "[p_meta_default]") {
data, _ := json.Marshal(metainfo)
dataStr := strings.ReplaceAll(string(data), "$META$", "/*META*/")
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("$META$%s$META$::jsonb", dataStr))
}
if strings.Contains(sqlquery, "[json_variables]") {
data, _ := json.Marshal(variables)
dataStr := strings.ReplaceAll(string(data), "$VAR$", "/*VAR*/")
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("$VAR$%s$VAR$::jsonb", dataStr))
}
if strings.Contains(sqlquery, "[rid_user]") {
sqlquery = strings.ReplaceAll(sqlquery, "[rid_user]", fmt.Sprintf("%d", userCtx.UserID))
}
if strings.Contains(sqlquery, "[user]") {
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("$USR$%s$USR$", strings.ReplaceAll(userCtx.UserName, "$USR$", "/*USR*/")))
}
if strings.Contains(sqlquery, "[rid_session]") {
sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("%d", userCtx.SessionRID))
}
if strings.Contains(sqlquery, "[id_session]") {
sqlquery = strings.ReplaceAll(sqlquery, "[id_session]", userCtx.SessionID)
}
if strings.Contains(sqlquery, "[method]") {
sqlquery = strings.ReplaceAll(sqlquery, "[method]", fmt.Sprintf("$M$%s$M$", strings.ReplaceAll(r.Method, "$M$", "/*M*/")))
}
if strings.Contains(sqlquery, "[post_body]") {
bodystr := ""
if r.Method == "POST" || r.Method == "PUT" {
if r.Body != nil {
contents, err := io.ReadAll(r.Body)
if err == nil {
bodystr = string(contents)
}
}
}
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("$PBODY$%s$PBODY$", strings.ReplaceAll(bodystr, "$PBODY$", "/*PBODY*/")))
}
return sqlquery
}
// parsePaginationParams extracts sort, limit, and offset parameters from request
func (h *Handler) parsePaginationParams(r *http.Request) (sortcols string, limit, offset int) {
limit = 20
offset = 0
if sortStr := r.URL.Query().Get("sort"); sortStr != "" {
sortcols = sortStr
}
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
}
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
offset = o
}
}
return
}
// ValidSQL validates and sanitizes SQL input to prevent injection
// mode can be: "colname", "colvalue", "select"
func ValidSQL(input, mode string) string {
// Remove dangerous characters based on mode
switch mode {
case "colname":
// For column names, only allow alphanumeric, underscore, and dot
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
return reg.ReplaceAllString(input, "")
case "colvalue":
// For column values, escape single quotes and backslashes
result := strings.ReplaceAll(input, "\\", "\\\\")
result = strings.ReplaceAll(result, "'", "''")
return result
case "select":
// For SELECT clauses, be more permissive but still safe
// Remove semicolons and common SQL injection patterns
dangerous := []string{
";", "--", "/*", "*/", "xp_", "sp_",
"DROP ", "drop ", "Drop ",
"DELETE ", "delete ", "Delete ",
"TRUNCATE ", "truncate ", "Truncate ",
"UPDATE ", "update ", "Update ",
"INSERT ", "insert ", "Insert ",
"EXEC ", "exec ", "Exec ",
"EXECUTE ", "execute ", "Execute ",
"UNION ", "union ", "Union ",
"DECLARE ", "declare ", "Declare ",
"ALTER ", "alter ", "Alter ",
"CREATE ", "create ", "Create ",
}
result := input
for _, d := range dangerous {
result = strings.ReplaceAll(result, d, "")
}
return result
default:
return input
}
}
// sqlQryWhere adds a WHERE clause to a SQL query or appends to existing WHERE with AND
func sqlQryWhere(sqlquery, condition string) string {
lowerQuery := strings.ToLower(sqlquery)
wherePos := strings.Index(lowerQuery, " where ")
groupPos := strings.Index(lowerQuery, " group by")
orderPos := strings.Index(lowerQuery, " order by")
limitPos := strings.Index(lowerQuery, " limit ")
// Find the insertion point (before GROUP BY, ORDER BY, or LIMIT)
insertPos := len(sqlquery)
if groupPos > 0 && groupPos < insertPos {
insertPos = groupPos
}
if orderPos > 0 && orderPos < insertPos {
insertPos = orderPos
}
if limitPos > 0 && limitPos < insertPos {
insertPos = limitPos
}
if wherePos > 0 {
// WHERE exists, add AND condition before GROUP BY / ORDER BY / LIMIT
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s AND %s %s", before, condition, after)
} else {
// No WHERE exists, add it before GROUP BY / ORDER BY / LIMIT
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
}
}
// IsNumeric checks if a string contains only numeric characters
func IsNumeric(s string) bool {
_, err := strconv.ParseFloat(s, 64)
return err == nil
}
// getReplacementForBlankParam determines the replacement value for an unused parameter
// based on whether it appears within quotes in the SQL query.
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
func getReplacementForBlankParam(sqlquery, param string) string {
// Find the parameter in the query
idx := strings.Index(sqlquery, param)
if idx < 0 {
return ""
}
// Check characters immediately before and after the parameter
var charBefore, charAfter byte
if idx > 0 {
charBefore = sqlquery[idx-1]
}
endIdx := idx + len(param)
if endIdx < len(sqlquery) {
charAfter = sqlquery[endIdx]
}
// Check if parameter is surrounded by quotes (single quote or dollar sign for PostgreSQL dollar-quoted strings)
if (charBefore == '\'' || charBefore == '$') && (charAfter == '\'' || charAfter == '$') {
// Parameter is in quotes, return empty string
return ""
}
// Parameter is not in quotes, return NULL
return "NULL"
}
// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows
// func makeResultReceiver(length int) []interface{} {
// result := make([]interface{}, length)
// for i := 0; i < length; i++ {
// var v interface{}
// result[i] = &v
// }
// return result
// }
// getIPAddress extracts the real IP address from the request
func getIPAddress(r *http.Request) string {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For can contain multiple IPs, take the first one
ips := strings.Split(forwarded, ",")
return strings.TrimSpace(ips[0])
}
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return realIP
}
return r.RemoteAddr
}
// sendError sends a JSON error response
func sendError(w http.ResponseWriter, status int, code, message string, err error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
errObj := common.APIError{
Code: code,
Message: message,
}
if err != nil {
errObj.Detail = err.Error()
}
data, _ := json.Marshal(map[string]interface{}{
"success": false,
"error": errObj,
})
_, _ = w.Write(data)
}
// normalizePostgresTypesList normalizes a list of result maps to handle PostgreSQL types correctly
func normalizePostgresTypesList(rows []map[string]interface{}) []map[string]interface{} {
if len(rows) == 0 {
return rows
}
normalized := make([]map[string]interface{}, len(rows))
for i, row := range rows {
normalized[i] = normalizePostgresTypes(row)
}
return normalized
}
// normalizePostgresTypes normalizes a result map to handle PostgreSQL types correctly for JSON marshaling
// This is necessary because when scanning into map[string]interface{}, PostgreSQL types like jsonb, bytea, etc.
// are scanned as []byte which would be base64-encoded when marshaled to JSON.
func normalizePostgresTypes(row map[string]interface{}) map[string]interface{} {
if row == nil {
return nil
}
normalized := make(map[string]interface{}, len(row))
for key, value := range row {
normalized[key] = normalizePostgresValue(value)
}
return normalized
}
// normalizePostgresValue normalizes a single value to the appropriate Go type for JSON marshaling
func normalizePostgresValue(value interface{}) interface{} {
if value == nil {
return nil
}
switch v := value.(type) {
case []byte:
// Check if it's valid JSON (jsonb type)
// Try to unmarshal as JSON first
var jsonObj interface{}
if err := json.Unmarshal(v, &jsonObj); err == nil {
// It's valid JSON, return as json.RawMessage so it's not double-encoded
return json.RawMessage(v)
}
// Not valid JSON, could be bytea - keep as []byte for base64 encoding
return v
case []interface{}:
// Recursively normalize array elements
normalized := make([]interface{}, len(v))
for i, elem := range v {
normalized[i] = normalizePostgresValue(elem)
}
return normalized
case map[string]interface{}:
// Recursively normalize nested maps
return normalizePostgresTypes(v)
default:
// For other types (int, float, string, bool, etc.), return as-is
return v
}
}