ResolveSpec/pkg/funcspec/function_api.go
2025-12-03 12:10:13 +02:00

1014 lines
31 KiB
Go

package funcspec
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// Handler serves function-based SQL API requests. The HTTP handlers built by
// SqlQueryList and SqlQuery run their queries through it.
type Handler struct {
// db is the database every generated handler runs its queries against.
db common.Database
// hooks holds the callbacks executed at the lifecycle points of a request.
hooks *HookRegistry
}
// NewHandler builds a function API handler bound to db, with a freshly
// created hook registry.
func NewHandler(db common.Database) *Handler {
	handler := new(Handler)
	handler.db = db
	handler.hooks = NewHookRegistry()
	return handler
}
// GetDatabase returns the database connection the handler was created with.
// It implements the common.SpecHandler interface.
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// Hooks returns the hook registry owned by this handler.
// Register custom hooks on the returned registry to intercept the lifecycle
// points executed by SqlQueryList and SqlQuery.
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// HTTPFuncType is the signature of the HTTP handlers returned by SqlQueryList
// and SqlQuery; it has the same shape as http.HandlerFunc.
type HTTPFuncType func(http.ResponseWriter, *http.Request)
// SqlQueryList creates an HTTP handler that executes sqlquery as a list query
// and writes the resulting rows as JSON.
//
// Parameters:
//   - sqlquery: SQL template. [name] placeholders are filled from the query
//     string, the request headers and request metadata.
//   - pNoCount: when true the COUNT query and the LIMIT/OFFSET clause are
//     skipped and the reported total is the number of rows returned. A request
//     can force this through the parsed SkipCount parameter.
//   - pBlankparms: when true, placeholders still present after substitution
//     are replaced by NULL or an empty string.
//   - pAllowFilter: when true, query-string parameters are turned into WHERE
//     conditions by mergeQueryParams; otherwise ApplyFilters is used.
//
// The handler runs the BeforeQueryList, BeforeSQLExec, AfterSQLExec,
// AfterQueryList and BeforeResponse hooks in that order. The response body is
// a bare array, a {"result","count"} object ("syncfusion" format) or a
// metadata object ("detail" format or complex API). Status 206 is used when
// fewer rows than the total are returned.
func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType {
	return func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			if err := recover(); err != nil {
				stack := debug.Stack()
				logger.Error("Panic in SqlQueryList: %v\nStack trace:\n%s", err, string(stack))
				http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
			}
		}()

		// Work on per-request copies. The handler rewrites both values below;
		// assigning to the captured parameters would leak one request's
		// substituted SQL into the next and race between concurrent requests.
		sqlquery := sqlquery
		pNoCount := pNoCount

		ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second)
		defer cancel()

		var dbobjlist []map[string]interface{}
		var total int64
		propQry := make(map[string]string)
		inputvars := make([]string, 0)
		metainfo := make(map[string]interface{})
		variables := make(map[string]interface{})
		complexAPI := false

		// Get user context from security package
		userCtx, ok := security.GetUserContext(ctx)
		if !ok {
			logger.Warn("No user context found in request")
			userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
		}

		w.Header().Set("Content-Type", "application/json")

		// Initialize hook context
		hookCtx := &HookContext{
			Context:     ctx,
			Handler:     h,
			Request:     r,
			Writer:      w,
			SQLQuery:    sqlquery,
			Variables:   variables,
			InputVars:   inputvars,
			MetaInfo:    metainfo,
			PropQry:     propQry,
			UserContext: userCtx,
			NoCount:     pNoCount,
			BlankParams: pBlankparms,
			AllowFilter: pAllowFilter,
			ComplexAPI:  complexAPI,
		}

		// Execute BeforeQueryList hook
		if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
			logger.Error("BeforeQueryList hook failed: %v", err)
			sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
			return
		}

		// Check if hook aborted the operation
		if hookCtx.Abort {
			if hookCtx.AbortCode == 0 {
				hookCtx.AbortCode = http.StatusBadRequest
			}
			sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
			return
		}

		// Use potentially modified SQL query and variables from hooks
		sqlquery = hookCtx.SQLQuery
		variables = hookCtx.Variables

		// Extract input variables from SQL query (placeholders like [variable])
		sqlquery = h.extractInputVariables(sqlquery, &inputvars)

		// Merge URL path parameters
		sqlquery = h.mergePathParams(r, sqlquery, variables)

		// Parse comprehensive parameters from headers and query string
		reqParams := h.ParseParameters(r)
		complexAPI = reqParams.ComplexAPI

		// Merge query string parameters
		sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)

		// Merge header parameters
		sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
		// Keep the hook context in step with the resolved flag, as SqlQuery does.
		hookCtx.ComplexAPI = complexAPI

		// Apply filters from parsed parameters (if not already applied by pAllowFilter)
		if !pAllowFilter {
			sqlquery = h.ApplyFilters(sqlquery, reqParams)
		}

		// Apply field selection
		sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)

		// Apply DISTINCT if requested
		sqlquery = h.ApplyDistinct(sqlquery, reqParams)

		// Override pNoCount if skipcount is specified
		if reqParams.SkipCount {
			pNoCount = true
			hookCtx.NoCount = true
		}

		// Build metainfo
		metainfo["ipaddress"] = getIPAddress(r)
		metainfo["url"] = r.RequestURI
		metainfo["user"] = userCtx.UserName
		metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
		metainfo["method"] = r.Method
		metainfo["variables"] = variables

		// Replace meta variables in SQL
		sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)

		// Remove unused input variables
		if pBlankparms {
			for _, kw := range inputvars {
				replacement := getReplacementForBlankParam(sqlquery, kw)
				sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
				logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
			}
		}

		// Update hook context with latest SQL query and variables
		hookCtx.SQLQuery = sqlquery
		hookCtx.Variables = variables
		hookCtx.InputVars = inputvars

		// Execute query within transaction
		err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
			// The count query uses the statement before ORDER BY / LIMIT are added.
			sqlqueryCnt := sqlquery

			// Parse sorting and pagination parameters
			sortcols, limit, offset := h.parsePaginationParams(r)

			// Override with parsed parameters if available
			if reqParams.SortColumns != "" {
				sortcols = reqParams.SortColumns
			}
			if reqParams.Limit > 0 {
				limit = reqParams.Limit
			}
			if reqParams.Offset > 0 {
				offset = reqParams.Offset
			}

			hookCtx.SortColumns = sortcols
			hookCtx.Limit = limit
			hookCtx.Offset = offset

			// Append ORDER BY only when the query has none, or when the only
			// "order by" found sits before FROM (i.e. inside the select list).
			fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
			orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
			if len(sortcols) > 0 && (orderbyPos < 0 || (orderbyPos > 0 && orderbyPos < fromPos)) {
				sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
			}

			if !pNoCount {
				if limit > 0 && offset > 0 {
					sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
				} else if limit > 0 {
					sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, limit)
				} else {
					sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, 20000)
				}

				// Get total count
				countQuery := fmt.Sprintf("SELECT COUNT(1) FROM (%s) cnts", sqlqueryCnt)
				var countResult struct{ Count int64 }
				if err := tx.Query(ctx, &countResult, countQuery); err != nil {
					sendError(w, http.StatusBadRequest, "count_failed", "Failed to retrieve record count", err)
					return err
				}
				total = countResult.Count
			}

			// Execute BeforeSQLExec hook
			hookCtx.SQLQuery = sqlquery
			if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
				logger.Error("BeforeSQLExec hook failed: %v", err)
				sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
				return err
			}
			// Use potentially modified SQL query from hook
			sqlquery = hookCtx.SQLQuery

			// Execute main query
			rows := make([]map[string]interface{}, 0)
			if err := tx.Query(ctx, &rows, sqlquery); err != nil {
				sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
				return err
			}

			// Normalize PostgreSQL types for proper JSON marshaling
			dbobjlist = normalizePostgresTypesList(rows)

			if pNoCount {
				total = int64(len(dbobjlist))
			}

			// Execute AfterSQLExec hook
			hookCtx.Result = dbobjlist
			hookCtx.Total = total
			if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
				logger.Error("AfterSQLExec hook failed: %v", err)
				sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
				return err
			}
			// Use potentially modified result from hook
			if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
				dbobjlist = modifiedResult
			}
			total = hookCtx.Total

			return nil
		})
		if err != nil {
			// Errors raised inside the transaction body have already been
			// reported to the client there.
			logger.Error("Transaction failed: %v", err)
			return
		}

		// Execute AfterQueryList hook
		hookCtx.Result = dbobjlist
		hookCtx.Total = total
		hookCtx.Error = err
		if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
			logger.Error("AfterQueryList hook failed: %v", err)
			sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
			return
		}
		// Use potentially modified result from hook
		if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
			dbobjlist = modifiedResult
		}
		total = hookCtx.Total

		// Set response headers
		// NOTE(review): the range end is offset+len (exclusive) and the offset is
		// read only from the URL query string, not from hookCtx.Offset. Confirm
		// what clients expect before changing the header format.
		respOffset := 0
		if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
			if o, err := strconv.Atoi(offsetStr); err == nil {
				respOffset = o
			}
		}
		w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
		logger.Info("Serving: Records %d of %d", len(dbobjlist), total)

		// Execute BeforeResponse hook
		hookCtx.Result = dbobjlist
		hookCtx.Total = total
		if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
			logger.Error("BeforeResponse hook failed: %v", err)
			sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
			return
		}
		// Use potentially modified result from hook
		if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
			dbobjlist = modifiedResult
		}

		// An empty result is always sent as a bare empty array.
		if len(dbobjlist) == 0 {
			_, _ = w.Write([]byte("[]"))
			return
		}

		// writeJSON marshals payload and sends it, answering 206 Partial
		// Content when the page holds fewer rows than the total.
		writeJSON := func(payload interface{}) {
			data, err := json.Marshal(payload)
			if err != nil {
				sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
				return
			}
			if int64(len(dbobjlist)) < total {
				w.WriteHeader(http.StatusPartialContent)
			}
			_, _ = w.Write(data)
		}

		// Format response based on response format
		switch {
		case reqParams.ResponseFormat == "syncfusion":
			// Syncfusion format: { result: data, count: total }
			writeJSON(map[string]interface{}{
				"result": dbobjlist,
				"count":  total,
			})
		case reqParams.ResponseFormat == "detail" || complexAPI:
			// Detail format / complex API: rows wrapped with metadata
			writeJSON(map[string]interface{}{
				"items":       dbobjlist,
				"count":       fmt.Sprintf("%d", len(dbobjlist)),
				"total":       fmt.Sprintf("%d", total),
				"tablename":   r.URL.Path,
				"tableprefix": "gsql",
			})
		default:
			// Simple format: just the data array
			writeJSON(dbobjlist)
		}
	}
}
// SqlQuery creates an HTTP handler that executes sqlquery and writes the first
// returned row as a JSON object (an empty object when no row is returned).
//
// Parameters:
//   - sqlquery: SQL template. [name] placeholders are filled from the query
//     string, the request headers and request metadata.
//   - pBlankparms: when true, placeholders still present after substitution
//     are replaced by NULL or an empty string.
//
// The handler runs the BeforeQuery, BeforeSQLExec, AfterSQLExec, AfterQuery
// and BeforeResponse hooks in that order. When the row has a "root_as_data"
// column, only that column's value is written as the response body.
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
	return func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			if err := recover(); err != nil {
				stack := debug.Stack()
				logger.Error("Panic in SqlQuery: %v\nStack trace:\n%s", err, string(stack))
				http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
			}
		}()

		// Work on a per-request copy. The handler rewrites the query below;
		// assigning to the captured parameter would leak one request's
		// substituted SQL into the next and race between concurrent requests.
		sqlquery := sqlquery

		ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second)
		defer cancel()

		propQry := make(map[string]string)
		inputvars := make([]string, 0)
		metainfo := make(map[string]interface{})
		variables := make(map[string]interface{})
		dbobj := make(map[string]interface{})
		complexAPI := false

		// Get user context from security package
		userCtx, ok := security.GetUserContext(ctx)
		if !ok {
			logger.Warn("No user context found in request")
			userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
		}

		w.Header().Set("Content-Type", "application/json")

		// Initialize hook context
		hookCtx := &HookContext{
			Context:     ctx,
			Handler:     h,
			Request:     r,
			Writer:      w,
			SQLQuery:    sqlquery,
			Variables:   variables,
			InputVars:   inputvars,
			MetaInfo:    metainfo,
			PropQry:     propQry,
			UserContext: userCtx,
			BlankParams: pBlankparms,
			ComplexAPI:  complexAPI,
		}

		// Execute BeforeQuery hook
		if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
			logger.Error("BeforeQuery hook failed: %v", err)
			sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
			return
		}

		// Check if hook aborted the operation
		if hookCtx.Abort {
			if hookCtx.AbortCode == 0 {
				hookCtx.AbortCode = http.StatusBadRequest
			}
			sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
			return
		}

		// Use potentially modified SQL query and variables from hooks
		sqlquery = hookCtx.SQLQuery
		variables = hookCtx.Variables

		// Extract input variables from SQL query
		sqlquery = h.extractInputVariables(sqlquery, &inputvars)

		// Merge URL path parameters
		sqlquery = h.mergePathParams(r, sqlquery, variables)

		// Parse comprehensive parameters from headers and query string
		reqParams := h.ParseParameters(r)
		complexAPI = reqParams.ComplexAPI

		// Merge query string parameters
		sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)

		// Merge header parameters
		sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
		hookCtx.ComplexAPI = complexAPI

		// Apply filters from parsed parameters
		sqlquery = h.ApplyFilters(sqlquery, reqParams)

		// Apply field selection
		sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)

		// Apply DISTINCT if requested
		sqlquery = h.ApplyDistinct(sqlquery, reqParams)

		// Build metainfo
		metainfo["ipaddress"] = getIPAddress(r)
		metainfo["url"] = r.RequestURI
		metainfo["user"] = userCtx.UserName
		metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
		metainfo["method"] = r.Method
		metainfo["variables"] = variables

		// Replace meta variables in SQL
		sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)

		// Apply field filters collected in propQry (x-fieldfilter-<column>).
		for k, val := range propQry {
			kLower := strings.ToLower(k)
			if !strings.HasPrefix(kLower, "x-fieldfilter-") {
				continue
			}
			colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
			if !strings.Contains(strings.ToLower(sqlquery), colname) {
				continue
			}
			safeCol := ValidSQL(colname, "colname")
			switch {
			case val == "" || val == "0":
				// Blank and "0" both mean "zero or NULL". Emitting the literal 0
				// avoids the dangling "= " a blank value used to produce.
				sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", safeCol))
			default:
				// Numbers go in as-is; anything else is written as a quoted,
				// escaped string so it cannot be read as SQL.
				literal := "'" + ValidSQL(val, "colvalue") + "'"
				if _, numErr := strconv.ParseFloat(val, 64); numErr == nil {
					literal = val
				}
				sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", safeCol, literal))
			}
		}

		// Remove unused input variables
		if pBlankparms {
			for _, kw := range inputvars {
				replacement := getReplacementForBlankParam(sqlquery, kw)
				sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
				logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
			}
		}

		// Update hook context with latest SQL query and variables
		hookCtx.SQLQuery = sqlquery
		hookCtx.Variables = variables
		hookCtx.InputVars = inputvars

		// Execute query within transaction
		err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
			// Execute BeforeSQLExec hook
			if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
				logger.Error("BeforeSQLExec hook failed: %v", err)
				sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
				return err
			}
			// Use potentially modified SQL query from hook
			sqlquery = hookCtx.SQLQuery

			// Execute main query
			rows := make([]map[string]interface{}, 0)
			if err := tx.Query(ctx, &rows, sqlquery); err != nil {
				sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
				return err
			}

			// Only the first row is used.
			if len(rows) > 0 {
				dbobj = normalizePostgresTypes(rows[0])
			}

			// Execute AfterSQLExec hook
			hookCtx.Result = dbobj
			if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
				logger.Error("AfterSQLExec hook failed: %v", err)
				sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
				return err
			}
			// Use potentially modified result from hook
			if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
				dbobj = modifiedResult
			}

			return nil
		})
		if err != nil {
			// Errors raised inside the transaction body have already been
			// reported to the client there.
			logger.Error("Transaction failed: %v", err)
			return
		}

		// Execute AfterQuery hook
		hookCtx.Result = dbobj
		hookCtx.Error = err
		if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
			logger.Error("AfterQuery hook failed: %v", err)
			sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
			return
		}
		// Use potentially modified result from hook
		if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
			dbobj = modifiedResult
		}

		// Execute BeforeResponse hook
		hookCtx.Result = dbobj
		if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
			logger.Error("BeforeResponse hook failed: %v", err)
			sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
			return
		}
		// Use potentially modified result from hook
		if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
			dbobj = modifiedResult
		}

		// Send the "root_as_data" column alone when present, else the whole row.
		var payload interface{} = dbobj
		if val, ok := dbobj["root_as_data"]; ok {
			payload = val
		}

		// Marshal and send response
		data, err := json.Marshal(payload)
		if err != nil {
			sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
			return
		}
		_, _ = w.Write(data)
	}
}
// Helper functions
// extractInputVariables collects every bracketed placeholder (for example
// "[variable]", brackets included) found in sqlquery and appends it to
// *inputvars in order of appearance. The query itself is returned unchanged.
//
// A "]" that appears before the next "[" is skipped, and scanning stops at the
// first "[" that has no closing bracket.
func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) string {
	rest := sqlquery
	for {
		open := strings.Index(rest, "[")
		if open < 0 {
			break
		}
		// Search for the closing bracket only after the opening one. Searching
		// from the start could find an earlier "]" and produce an inverted
		// slice range, which panics.
		length := strings.Index(rest[open:], "]")
		if length < 0 {
			break
		}
		end := open + length + 1
		*inputvars = append(*inputvars, rest[open:end])
		rest = rest[end:]
	}
	return sqlquery
}
// mergePathParams is the step where URL path parameters would be merged into
// the SQL query. It is currently a placeholder: r and variables are not read
// and sqlquery is returned unchanged.
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
// Note: Path parameters would typically come from a router like gorilla/mux
// For now, this is a placeholder for path parameter extraction
return sqlquery
}
// mergeQueryParams merges the URL query-string parameters into the SQL query
// and returns the updated query.
//
// For every parameter that has a name and at least one value:
//   - its first value (decoded with restheadspec.DecodeParam when that
//     succeeds) is stored in variables under the parameter name;
//   - a "p-" prefixed parameter replaces its "[name]" placeholder in the query;
//   - an "x-" prefixed parameter is also recorded in propQry;
//   - when allowFilter is true and the parameter name occurs in the query
//     text, a WHERE condition on that column is added.
//
// Query parameters are visited in map order, so the order of the added
// conditions is not deterministic.
func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables map[string]interface{}, allowFilter bool, propQry map[string]string) string {
	for parmk, parmv := range r.URL.Query() {
		if len(parmk) == 0 || len(parmv) == 0 {
			continue
		}

		val := parmv[0]
		if dec, err := restheadspec.DecodeParam(val); err == nil {
			val = dec
		}

		kword := fmt.Sprintf("[%s]", parmk)
		variables[parmk] = val

		// Replace in SQL if placeholder exists.
		// NOTE(review): the value is substituted verbatim. Query templates must
		// only use [p-...] placeholders where arbitrary client text is safe.
		if strings.HasPrefix(parmk, "p-") && len(val) > 0 && strings.Contains(sqlquery, kword) {
			sqlquery = strings.ReplaceAll(sqlquery, kword, val)
		}

		// Add to propQry for x- prefixed params
		if strings.HasPrefix(parmk, "x-") {
			propQry[parmk] = val
		}

		// Apply filters if allowed
		if !allowFilter || len(parmk) <= 1 || !strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
			continue
		}

		colname := ValidSQL(parmk, "colname")
		switch {
		case len(parmv) > 1:
			// Each list item is emitted as a number or as a quoted, escaped
			// string so request values cannot inject SQL into the IN list.
			items := make([]string, 0, len(parmv))
			for _, item := range parmv {
				if IsNumeric(item) {
					items = append(items, item)
				} else {
					items = append(items, "'"+ValidSQL(item, "colvalue")+"'")
				}
			}
			sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", colname, strings.Join(items, ",")))
		case strings.Contains(val, "match="):
			// "match=<text>" requests a substring search; "match=*" matches all.
			colval := strings.ReplaceAll(val, "match=", "")
			if colval != "*" {
				sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", colname, ValidSQL(colval, "colvalue")))
			}
		case val == "" || val == "0":
			// NOTE(review): a blank value yields "col =  OR col IS NULL", which is
			// not valid SQL. The intended literal depends on the column type, so
			// this is left unchanged.
			sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", colname, ValidSQL(val, "colvalue")))
		case IsNumeric(val):
			sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", colname, ValidSQL(val, "colvalue")))
		default:
			sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", colname, ValidSQL(val, "colvalue")))
		}
	}
	return sqlquery
}
// mergeHeaderParams merges the "x-" prefixed HTTP request headers into the SQL
// query and returns the updated query.
//
// For every such header, its lower-cased name and first value (decoded with
// restheadspec.DecodeParam when that succeeds) are stored in variables and
// propQry, and a "[name]" placeholder in the query is replaced by the value.
// Special headers:
//   - x-fieldfilter-<col>: adds an equality condition on <col>;
//   - x-searchfilter-<col>: adds an ILIKE substring condition on <col>;
//   - x-custom-sql-w...: adds the header value as a WHERE condition;
//   - x-simpleapi: sets *complexAPI to false for "1"/"true", true otherwise.
func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables map[string]interface{}, propQry map[string]string, complexAPI *bool) string {
	// sqlLiteral renders a filter value as SQL: numbers are used as-is, any
	// other text becomes a quoted, escaped string so it cannot be read as SQL.
	sqlLiteral := func(val string) string {
		if _, err := strconv.ParseFloat(val, 64); err == nil {
			return val
		}
		return "'" + ValidSQL(val, "colvalue") + "'"
	}

	for kc, v := range r.Header {
		k := strings.ToLower(kc)
		if !strings.HasPrefix(k, "x-") || len(v) == 0 {
			continue
		}

		val := v[0]
		if dec, err := restheadspec.DecodeParam(val); err == nil {
			val = dec
		}

		variables[k] = val
		propQry[k] = val

		// NOTE(review): the header value is substituted verbatim. Query
		// templates must only use [x-...] placeholders where arbitrary client
		// text is safe.
		kword := fmt.Sprintf("[%s]", k)
		if strings.Contains(sqlquery, kword) {
			sqlquery = strings.ReplaceAll(sqlquery, kword, val)
		}

		// Handle special headers
		if strings.Contains(k, "x-fieldfilter-") {
			colname := ValidSQL(strings.ReplaceAll(k, "x-fieldfilter-", ""), "colname")
			if val == "" || val == "0" {
				// Blank and "0" both mean "zero or NULL". Emitting the literal 0
				// avoids the dangling "= " a blank value used to produce.
				sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", colname))
			} else {
				sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", colname, sqlLiteral(val)))
			}
		}

		if strings.Contains(k, "x-searchfilter-") {
			colname := strings.ReplaceAll(k, "x-searchfilter-", "")
			sval := strings.ReplaceAll(val, "'", "")
			if sval != "" {
				sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
			}
		}

		if strings.Contains(k, "x-custom-sql-w") {
			// NOTE(review): this lets the client supply a raw WHERE fragment,
			// filtered only by ValidSQL's "select" blacklist.
			colval := ValidSQL(val, "select")
			if len(colval) > 0 {
				sqlquery = sqlQryWhere(sqlquery, colval)
			}
		}

		if strings.Contains(k, "x-simpleapi") {
			*complexAPI = !strings.EqualFold(val, "1") && !strings.EqualFold(val, "true")
		}
	}
	return sqlquery
}
// replaceMetaVariables substitutes the built-in placeholders in sqlquery and
// returns the updated query:
//
//   - [p_meta_default]: metainfo as a JSON string literal cast to jsonb
//   - [json_variables]: variables as a JSON string literal cast to jsonb
//   - [rid_user]:       userCtx.UserID as a number
//   - [user]:           userCtx.UserName as a string literal
//   - [rid_session]:    userCtx.SessionID parsed as an integer (0 if not numeric)
//   - [method]:         the HTTP method, inserted without quotes
//   - [post_body]:      the request body of a POST/PUT as a string literal
//
// String literals have embedded single quotes doubled, so request-controlled
// text cannot terminate the literal. Reading [post_body] consumes r.Body.
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
	// quote renders s as a single-quoted SQL string literal.
	quote := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", "''") + "'"
	}

	if strings.Contains(sqlquery, "[p_meta_default]") {
		// A marshal failure leaves data empty and yields ''::jsonb.
		data, _ := json.Marshal(metainfo)
		sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", quote(string(data))+"::jsonb")
	}

	if strings.Contains(sqlquery, "[json_variables]") {
		// A marshal failure leaves data empty and yields ''::jsonb.
		data, _ := json.Marshal(variables)
		sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", quote(string(data))+"::jsonb")
	}

	if strings.Contains(sqlquery, "[rid_user]") {
		sqlquery = strings.ReplaceAll(sqlquery, "[rid_user]", fmt.Sprintf("%d", userCtx.UserID))
	}

	if strings.Contains(sqlquery, "[user]") {
		sqlquery = strings.ReplaceAll(sqlquery, "[user]", quote(userCtx.UserName))
	}

	if strings.Contains(sqlquery, "[rid_session]") {
		// A non-numeric session ID deliberately falls back to 0.
		sessionID, _ := strconv.ParseInt(userCtx.SessionID, 10, 64)
		sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("%d", sessionID))
	}

	if strings.Contains(sqlquery, "[method]") {
		sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method)
	}

	if strings.Contains(sqlquery, "[post_body]") {
		bodystr := ""
		if (r.Method == "POST" || r.Method == "PUT") && r.Body != nil {
			// Best effort: an unreadable body is treated as empty.
			if contents, err := io.ReadAll(r.Body); err == nil {
				bodystr = string(contents)
			}
		}
		sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", quote(bodystr))
	}

	return sqlquery
}
// parsePaginationParams reads the "sort", "limit" and "offset" query-string
// parameters of the request. limit defaults to 20 and offset to 0; a limit
// that is not a positive integer or an offset that is not a non-negative
// integer leaves the default in place.
func (h *Handler) parsePaginationParams(r *http.Request) (sortcols string, limit, offset int) {
	query := r.URL.Query()

	sortcols = query.Get("sort")

	limit = 20
	if parsed, err := strconv.Atoi(query.Get("limit")); err == nil && parsed > 0 {
		limit = parsed
	}

	offset = 0
	if parsed, err := strconv.Atoi(query.Get("offset")); err == nil && parsed >= 0 {
		offset = parsed
	}

	return sortcols, limit, offset
}
// Patterns used by ValidSQL, compiled once instead of on every call.
var (
	// validSQLColNameStrip matches every character not allowed in a column name.
	validSQLColNameStrip = regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
	// validSQLDangerous matches the blacklisted tokens of the "select" mode in
	// any letter case.
	validSQLDangerous = regexp.MustCompile(`(?i);|--|/\*|\*/|xp_|sp_|DROP |DELETE |TRUNCATE |UPDATE |INSERT `)
)

// ValidSQL sanitizes input before it is embedded in a SQL statement.
//
// mode selects the rule set:
//   - "colname":  keeps only letters, digits, underscore and dot;
//   - "colvalue": doubles single quotes so the value can sit inside '...';
//   - "select":   removes statement separators, comment markers and a small
//     blacklist of keywords, in any letter case, repeating until none remain
//     so that removing one token cannot splice another together;
//   - any other mode returns input unchanged.
//
// The "select" mode is a best-effort blacklist, not a substitute for
// parameterized queries.
func ValidSQL(input, mode string) string {
	switch mode {
	case "colname":
		return validSQLColNameStrip.ReplaceAllString(input, "")
	case "colvalue":
		return strings.ReplaceAll(input, "'", "''")
	case "select":
		result := input
		for {
			cleaned := validSQLDangerous.ReplaceAllString(result, "")
			if cleaned == result {
				return result
			}
			result = cleaned
		}
	default:
		return input
	}
}
// Clause keywords located by sqlQryWhere. They are matched case-insensitively
// and may be preceded by any whitespace (space, tab or newline).
var (
	sqlWhereKeyword   = regexp.MustCompile(`(?i)\swhere\s`)
	sqlGroupByKeyword = regexp.MustCompile(`(?i)\sgroup\s+by\b`)
	sqlOrderByKeyword = regexp.MustCompile(`(?i)\sorder\s+by\b`)
	sqlLimitKeyword   = regexp.MustCompile(`(?i)\slimit\s`)
)

// sqlQryWhere adds condition to sqlquery and returns the new query.
//
// The condition is inserted before the first GROUP BY, ORDER BY or LIMIT
// keyword, or at the end when none is present. It is introduced with AND when
// the query already contains a WHERE keyword and with WHERE otherwise.
//
// Keywords are located textually on the original string, not by parsing, so a
// keyword inside a subquery or string literal is treated like a top-level one.
func sqlQryWhere(sqlquery, condition string) string {
	// firstPos returns the byte offset of the first match of re, or -1.
	firstPos := func(re *regexp.Regexp) int {
		if loc := re.FindStringIndex(sqlquery); loc != nil {
			return loc[0]
		}
		return -1
	}

	// Find the insertion point (before GROUP BY, ORDER BY, or LIMIT).
	insertPos := len(sqlquery)
	for _, re := range []*regexp.Regexp{sqlGroupByKeyword, sqlOrderByKeyword, sqlLimitKeyword} {
		if pos := firstPos(re); pos > 0 && pos < insertPos {
			insertPos = pos
		}
	}

	joiner := "WHERE"
	if firstPos(sqlWhereKeyword) > 0 {
		joiner = "AND"
	}

	return fmt.Sprintf("%s %s %s %s", sqlquery[:insertPos], joiner, condition, sqlquery[insertPos:])
}
// IsNumeric reports whether s is a plain decimal number: an optional sign,
// digits with an optional decimal point, and an optional exponent (for
// example "42", "-1.5", "1e5").
//
// Forms that strconv.ParseFloat would also accept but that are not usable as
// an unquoted SQL numeric literal are rejected: "NaN", "Inf"/"Infinity",
// hexadecimal floats and digit-separating underscores.
func IsNumeric(s string) bool {
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case c >= '0' && c <= '9':
		case c == '.', c == '+', c == '-', c == 'e', c == 'E':
		default:
			return false
		}
	}
	// The character filter only narrows the alphabet; ParseFloat validates the
	// actual number syntax and range.
	_, err := strconv.ParseFloat(s, 64)
	return err == nil
}
// getReplacementForBlankParam picks the text that should replace an unused
// placeholder param in sqlquery.
//
// Only the first occurrence of param is inspected. When the characters
// directly before and after it are each a single quote (') or a dollar sign
// ($, as used by PostgreSQL dollar-quoted strings), the placeholder sits
// inside a quoted string and the replacement is the empty string. Otherwise
// the replacement is "NULL". When param does not occur at all, the empty
// string is returned.
func getReplacementForBlankParam(sqlquery, param string) string {
	start := strings.Index(sqlquery, param)
	if start == -1 {
		return ""
	}

	// isQuoteAt reports whether the byte at pos exists and is ' or $.
	isQuoteAt := func(pos int) bool {
		if pos < 0 || pos >= len(sqlquery) {
			return false
		}
		switch sqlquery[pos] {
		case '\'', '$':
			return true
		}
		return false
	}

	if isQuoteAt(start-1) && isQuoteAt(start+len(param)) {
		// Quoted placeholder: blank it out.
		return ""
	}
	// Bare placeholder: use SQL NULL.
	return "NULL"
}
// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows
// func makeResultReceiver(length int) []interface{} {
// result := make([]interface{}, length)
// for i := 0; i < length; i++ {
// var v interface{}
// result[i] = &v
// }
// return result
// }
// getIPAddress returns the client address for the request. It prefers the
// first entry of the X-Forwarded-For header (trimmed of surrounding spaces),
// then the X-Real-IP header, and falls back to r.RemoteAddr.
func getIPAddress(r *http.Request) string {
	forwarded := r.Header.Get("X-Forwarded-For")
	realIP := r.Header.Get("X-Real-IP")

	switch {
	case forwarded != "":
		// The header may list several comma-separated addresses; keep the first.
		if comma := strings.IndexByte(forwarded, ','); comma >= 0 {
			forwarded = forwarded[:comma]
		}
		return strings.TrimSpace(forwarded)
	case realIP != "":
		return realIP
	default:
		return r.RemoteAddr
	}
}
// sendError writes a JSON error response with the given HTTP status. The body
// has the shape {"success": false, "error": {...}}, where the error object
// carries code, message and, when err is non-nil, err's text as detail.
func sendError(w http.ResponseWriter, status int, code, message string, err error) {
	apiErr := common.APIError{Code: code, Message: message}
	if err != nil {
		apiErr.Detail = err.Error()
	}

	// A marshal failure is ignored; the body is then written empty.
	payload, _ := json.Marshal(map[string]interface{}{
		"success": false,
		"error":   apiErr,
	})

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	_, _ = w.Write(payload)
}
// normalizePostgresTypesList applies normalizePostgresTypes to every row and
// returns the results in a new slice of the same length. An empty or nil
// input is returned as-is.
func normalizePostgresTypesList(rows []map[string]interface{}) []map[string]interface{} {
	if len(rows) == 0 {
		return rows
	}

	out := make([]map[string]interface{}, 0, len(rows))
	for idx := range rows {
		out = append(out, normalizePostgresTypes(rows[idx]))
	}
	return out
}
// normalizePostgresTypes returns a copy of row in which every value has been
// passed through normalizePostgresValue, so the row marshals to JSON as
// intended. A nil row yields nil.
//
// This is needed because values scanned into map[string]interface{} can
// arrive as []byte (for example jsonb or bytea columns), which encoding/json
// would otherwise emit as a base64 string.
func normalizePostgresTypes(row map[string]interface{}) map[string]interface{} {
	if row == nil {
		return nil
	}

	normalized := make(map[string]interface{}, len(row))
	for key, value := range row {
		normalized[key] = normalizePostgresValue(value)
	}
	return normalized
}

// normalizePostgresValue converts a single scanned value into the form that
// should be handed to encoding/json:
//
//   - []byte holding syntactically valid JSON becomes json.RawMessage, so it
//     is embedded verbatim instead of being base64-encoded;
//   - any other []byte is returned unchanged (and is base64-encoded on marshal);
//   - []interface{} and map[string]interface{} are normalized recursively;
//   - every other value, including nil, is returned as-is.
func normalizePostgresValue(value interface{}) interface{} {
	if value == nil {
		return nil
	}

	switch v := value.(type) {
	case []byte:
		// json.Valid checks the syntax without decoding into a throwaway value.
		if json.Valid(v) {
			return json.RawMessage(v)
		}
		// Not valid JSON, could be bytea - keep as []byte for base64 encoding
		return v

	case []interface{}:
		// Recursively normalize array elements
		normalized := make([]interface{}, len(v))
		for i, elem := range v {
			normalized[i] = normalizePostgresValue(elem)
		}
		return normalized

	case map[string]interface{}:
		// Recursively normalize nested maps
		return normalizePostgresTypes(v)

	default:
		// For other types (int, float, string, bool, etc.), return as-is
		return v
	}
}