mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 22:36:23 +00:00
915 lines
28 KiB
Go
915 lines
28 KiB
Go
package funcspec
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"regexp"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
|
)
|
|
|
|
// Handler handles function-based SQL API requests
|
|
type Handler struct {
|
|
db common.Database
|
|
hooks *HookRegistry
|
|
}
|
|
|
|
// NewHandler creates a new function API handler
|
|
func NewHandler(db common.Database) *Handler {
|
|
return &Handler{
|
|
db: db,
|
|
hooks: NewHookRegistry(),
|
|
}
|
|
}
|
|
|
|
// GetDatabase returns the underlying database connection
|
|
// Implements common.SpecHandler interface
|
|
func (h *Handler) GetDatabase() common.Database {
|
|
return h.db
|
|
}
|
|
|
|
// Hooks returns the hook registry for this handler
|
|
// Use this to register custom hooks for operations
|
|
func (h *Handler) Hooks() *HookRegistry {
|
|
return h.hooks
|
|
}
|
|
|
|
// HTTPFuncType is the signature of the HTTP handlers produced by this package:
// a function receiving the response writer and the incoming request.
type HTTPFuncType func(w http.ResponseWriter, r *http.Request)
|
|
|
|
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
|
|
func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
stack := debug.Stack()
|
|
logger.Error("Panic in SqlQueryList: %v\nStack trace:\n%s", err, string(stack))
|
|
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
|
|
}
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second)
|
|
defer cancel()
|
|
|
|
var dbobjlist []map[string]interface{}
|
|
var total int64
|
|
propQry := make(map[string]string)
|
|
inputvars := make([]string, 0)
|
|
metainfo := make(map[string]interface{})
|
|
variables := make(map[string]interface{})
|
|
complexAPI := false
|
|
|
|
// Get user context from security package
|
|
userCtx, ok := security.GetUserContext(ctx)
|
|
if !ok {
|
|
logger.Warn("No user context found in request")
|
|
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
// Initialize hook context
|
|
hookCtx := &HookContext{
|
|
Context: ctx,
|
|
Handler: h,
|
|
Request: r,
|
|
Writer: w,
|
|
SQLQuery: sqlquery,
|
|
Variables: variables,
|
|
InputVars: inputvars,
|
|
MetaInfo: metainfo,
|
|
PropQry: propQry,
|
|
UserContext: userCtx,
|
|
NoCount: pNoCount,
|
|
BlankParams: pBlankparms,
|
|
AllowFilter: pAllowFilter,
|
|
ComplexAPI: complexAPI,
|
|
}
|
|
|
|
// Execute BeforeQueryList hook
|
|
if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil {
|
|
logger.Error("BeforeQueryList hook failed: %v", err)
|
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
return
|
|
}
|
|
|
|
// Check if hook aborted the operation
|
|
if hookCtx.Abort {
|
|
if hookCtx.AbortCode == 0 {
|
|
hookCtx.AbortCode = http.StatusBadRequest
|
|
}
|
|
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
|
return
|
|
}
|
|
|
|
// Use potentially modified SQL query and variables from hooks
|
|
sqlquery = hookCtx.SQLQuery
|
|
variables = hookCtx.Variables
|
|
// complexAPI = hookCtx.ComplexAPI
|
|
|
|
// Extract input variables from SQL query (placeholders like [variable])
|
|
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
|
|
|
// Merge URL path parameters
|
|
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
|
|
|
// Parse comprehensive parameters from headers and query string
|
|
reqParams := h.ParseParameters(r)
|
|
complexAPI = reqParams.ComplexAPI
|
|
|
|
// Merge query string parameters
|
|
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
|
|
|
|
// Merge header parameters
|
|
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
|
|
|
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
|
|
if !pAllowFilter {
|
|
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
|
}
|
|
|
|
// Apply field selection
|
|
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
|
|
|
|
// Apply DISTINCT if requested
|
|
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
|
|
|
|
// Override pNoCount if skipcount is specified
|
|
if reqParams.SkipCount {
|
|
pNoCount = true
|
|
}
|
|
|
|
// Build metainfo
|
|
metainfo["ipaddress"] = getIPAddress(r)
|
|
metainfo["url"] = r.RequestURI
|
|
metainfo["user"] = userCtx.UserName
|
|
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
|
|
metainfo["method"] = r.Method
|
|
metainfo["variables"] = variables
|
|
|
|
// Replace meta variables in SQL
|
|
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
|
|
|
// Remove unused input variables
|
|
if pBlankparms {
|
|
for _, kw := range inputvars {
|
|
sqlquery = strings.ReplaceAll(sqlquery, kw, "")
|
|
logger.Debug("Removed unused variable: %s", kw)
|
|
}
|
|
}
|
|
|
|
// Update hook context with latest SQL query and variables
|
|
hookCtx.SQLQuery = sqlquery
|
|
hookCtx.Variables = variables
|
|
hookCtx.InputVars = inputvars
|
|
|
|
// Execute query within transaction
|
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
sqlqueryCnt := sqlquery
|
|
|
|
// Parse sorting and pagination parameters
|
|
sortcols, limit, offset := h.parsePaginationParams(r)
|
|
|
|
// Override with parsed parameters if available
|
|
if reqParams.SortColumns != "" {
|
|
sortcols = reqParams.SortColumns
|
|
}
|
|
if reqParams.Limit > 0 {
|
|
limit = reqParams.Limit
|
|
}
|
|
if reqParams.Offset > 0 {
|
|
offset = reqParams.Offset
|
|
}
|
|
|
|
hookCtx.SortColumns = sortcols
|
|
hookCtx.Limit = limit
|
|
hookCtx.Offset = offset
|
|
fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
|
|
orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
|
|
|
|
if len(sortcols) > 0 && (orderbyPos < 0 || (orderbyPos > 0 && orderbyPos < fromPos)) {
|
|
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
|
|
}
|
|
|
|
if !pNoCount {
|
|
if limit > 0 && offset > 0 {
|
|
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
|
|
} else if limit > 0 {
|
|
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, limit)
|
|
} else {
|
|
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, 20000)
|
|
}
|
|
|
|
// Get total count
|
|
countQuery := fmt.Sprintf("SELECT COUNT(1) FROM (%s) cnts", sqlqueryCnt)
|
|
var countResult struct{ Count int64 }
|
|
if err := tx.Query(ctx, &countResult, countQuery); err != nil {
|
|
sendError(w, http.StatusBadRequest, "count_failed", "Failed to retrieve record count", err)
|
|
return err
|
|
}
|
|
total = countResult.Count
|
|
}
|
|
|
|
// Execute BeforeSQLExec hook
|
|
hookCtx.SQLQuery = sqlquery
|
|
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
|
|
logger.Error("BeforeSQLExec hook failed: %v", err)
|
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
return err
|
|
}
|
|
// Use potentially modified SQL query from hook
|
|
sqlquery = hookCtx.SQLQuery
|
|
|
|
// Execute main query
|
|
rows := make([]map[string]interface{}, 0)
|
|
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
|
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
|
|
return err
|
|
}
|
|
|
|
dbobjlist = rows
|
|
|
|
if pNoCount {
|
|
total = int64(len(dbobjlist))
|
|
}
|
|
|
|
// Execute AfterSQLExec hook
|
|
hookCtx.Result = dbobjlist
|
|
hookCtx.Total = total
|
|
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
|
|
logger.Error("AfterSQLExec hook failed: %v", err)
|
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
return err
|
|
}
|
|
// Use potentially modified result from hook
|
|
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
|
dbobjlist = modifiedResult
|
|
}
|
|
total = hookCtx.Total
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
logger.Error("Transaction failed: %v", err)
|
|
return
|
|
}
|
|
|
|
// Execute AfterQueryList hook
|
|
hookCtx.Result = dbobjlist
|
|
hookCtx.Total = total
|
|
hookCtx.Error = err
|
|
if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil {
|
|
logger.Error("AfterQueryList hook failed: %v", err)
|
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
return
|
|
}
|
|
// Use potentially modified result from hook
|
|
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
|
dbobjlist = modifiedResult
|
|
}
|
|
total = hookCtx.Total
|
|
|
|
// Set response headers
|
|
respOffset := 0
|
|
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
|
if o, err := strconv.Atoi(offsetStr); err == nil {
|
|
respOffset = o
|
|
}
|
|
}
|
|
|
|
w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
|
|
logger.Info("Serving: Records %d of %d", len(dbobjlist), total)
|
|
|
|
// Execute BeforeResponse hook
|
|
hookCtx.Result = dbobjlist
|
|
hookCtx.Total = total
|
|
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
|
|
logger.Error("BeforeResponse hook failed: %v", err)
|
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
return
|
|
}
|
|
// Use potentially modified result from hook
|
|
if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok {
|
|
dbobjlist = modifiedResult
|
|
}
|
|
|
|
if len(dbobjlist) == 0 {
|
|
_, _ = w.Write([]byte("[]"))
|
|
return
|
|
}
|
|
|
|
// Format response based on response format
|
|
switch reqParams.ResponseFormat {
|
|
case "syncfusion":
|
|
// Syncfusion format: { result: data, count: total }
|
|
response := map[string]interface{}{
|
|
"result": dbobjlist,
|
|
"count": total,
|
|
}
|
|
data, err := json.Marshal(response)
|
|
if err != nil {
|
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
|
} else {
|
|
if int64(len(dbobjlist)) < total {
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
}
|
|
_, _ = w.Write(data)
|
|
}
|
|
|
|
case "detail":
|
|
// Detail format: complex API with metadata
|
|
metaobj := map[string]interface{}{
|
|
"items": dbobjlist,
|
|
"count": fmt.Sprintf("%d", len(dbobjlist)),
|
|
"total": fmt.Sprintf("%d", total),
|
|
"tablename": r.URL.Path,
|
|
"tableprefix": "gsql",
|
|
}
|
|
data, err := json.Marshal(metaobj)
|
|
if err != nil {
|
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
|
} else {
|
|
if int64(len(dbobjlist)) < total {
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
}
|
|
_, _ = w.Write(data)
|
|
}
|
|
|
|
default:
|
|
// Simple format: just return the data array (or complex API if requested)
|
|
if complexAPI {
|
|
metaobj := map[string]interface{}{
|
|
"items": dbobjlist,
|
|
"count": fmt.Sprintf("%d", len(dbobjlist)),
|
|
"total": fmt.Sprintf("%d", total),
|
|
"tablename": r.URL.Path,
|
|
"tableprefix": "gsql",
|
|
}
|
|
data, err := json.Marshal(metaobj)
|
|
if err != nil {
|
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
|
} else {
|
|
if int64(len(dbobjlist)) < total {
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
}
|
|
_, _ = w.Write(data)
|
|
}
|
|
} else {
|
|
data, err := json.Marshal(dbobjlist)
|
|
if err != nil {
|
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
|
} else {
|
|
if int64(len(dbobjlist)) < total {
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
}
|
|
_, _ = w.Write(data)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
|
|
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
stack := debug.Stack()
|
|
logger.Error("Panic in SqlQuery: %v\nStack trace:\n%s", err, string(stack))
|
|
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
|
|
}
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second)
|
|
defer cancel()
|
|
|
|
propQry := make(map[string]string)
|
|
inputvars := make([]string, 0)
|
|
metainfo := make(map[string]interface{})
|
|
variables := make(map[string]interface{})
|
|
dbobj := make(map[string]interface{})
|
|
complexAPI := false
|
|
|
|
// Get user context from security package
|
|
userCtx, ok := security.GetUserContext(ctx)
|
|
if !ok {
|
|
logger.Warn("No user context found in request")
|
|
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
// Initialize hook context
|
|
hookCtx := &HookContext{
|
|
Context: ctx,
|
|
Handler: h,
|
|
Request: r,
|
|
Writer: w,
|
|
SQLQuery: sqlquery,
|
|
Variables: variables,
|
|
InputVars: inputvars,
|
|
MetaInfo: metainfo,
|
|
PropQry: propQry,
|
|
UserContext: userCtx,
|
|
BlankParams: pBlankparms,
|
|
ComplexAPI: complexAPI,
|
|
}
|
|
|
|
// Execute BeforeQuery hook
|
|
if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil {
|
|
logger.Error("BeforeQuery hook failed: %v", err)
|
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
return
|
|
}
|
|
|
|
// Check if hook aborted the operation
|
|
if hookCtx.Abort {
|
|
if hookCtx.AbortCode == 0 {
|
|
hookCtx.AbortCode = http.StatusBadRequest
|
|
}
|
|
sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil)
|
|
return
|
|
}
|
|
|
|
// Use potentially modified SQL query and variables from hooks
|
|
sqlquery = hookCtx.SQLQuery
|
|
variables = hookCtx.Variables
|
|
|
|
// Extract input variables from SQL query
|
|
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
|
|
|
// Merge URL path parameters
|
|
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
|
|
|
// Parse comprehensive parameters from headers and query string
|
|
reqParams := h.ParseParameters(r)
|
|
complexAPI = reqParams.ComplexAPI
|
|
|
|
// Merge query string parameters
|
|
sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)
|
|
|
|
// Merge header parameters
|
|
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
|
hookCtx.ComplexAPI = complexAPI
|
|
|
|
// Apply filters from parsed parameters
|
|
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
|
|
|
// Apply field selection
|
|
sqlquery = h.ApplyFieldSelection(sqlquery, reqParams)
|
|
|
|
// Apply DISTINCT if requested
|
|
sqlquery = h.ApplyDistinct(sqlquery, reqParams)
|
|
|
|
// Build metainfo
|
|
metainfo["ipaddress"] = getIPAddress(r)
|
|
metainfo["url"] = r.RequestURI
|
|
metainfo["user"] = userCtx.UserName
|
|
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
|
|
metainfo["method"] = r.Method
|
|
metainfo["variables"] = variables
|
|
|
|
// Replace meta variables in SQL
|
|
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
|
|
|
// Apply field filters from headers
|
|
for k, val := range propQry {
|
|
kLower := strings.ToLower(k)
|
|
if strings.HasPrefix(kLower, "x-fieldfilter-") {
|
|
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
|
|
if strings.Contains(strings.ToLower(sqlquery), colname) {
|
|
if val == "" || val == "0" {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
|
} else {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove unused input variables
|
|
if pBlankparms {
|
|
for _, kw := range inputvars {
|
|
sqlquery = strings.ReplaceAll(sqlquery, kw, "")
|
|
logger.Debug("Removed unused variable: %s", kw)
|
|
}
|
|
}
|
|
|
|
// Update hook context with latest SQL query and variables
|
|
hookCtx.SQLQuery = sqlquery
|
|
hookCtx.Variables = variables
|
|
hookCtx.InputVars = inputvars
|
|
|
|
// Execute query within transaction
|
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
// Execute BeforeSQLExec hook
|
|
if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil {
|
|
logger.Error("BeforeSQLExec hook failed: %v", err)
|
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
return err
|
|
}
|
|
// Use potentially modified SQL query from hook
|
|
sqlquery = hookCtx.SQLQuery
|
|
|
|
// Execute main query
|
|
rows := make([]map[string]interface{}, 0)
|
|
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
|
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
|
|
return err
|
|
}
|
|
|
|
if len(rows) > 0 {
|
|
dbobj = rows[0]
|
|
}
|
|
|
|
// Execute AfterSQLExec hook
|
|
hookCtx.Result = dbobj
|
|
if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil {
|
|
logger.Error("AfterSQLExec hook failed: %v", err)
|
|
sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
|
return err
|
|
}
|
|
// Use potentially modified result from hook
|
|
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
|
dbobj = modifiedResult
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
logger.Error("Transaction failed: %v", err)
|
|
return
|
|
}
|
|
|
|
// Execute AfterQuery hook
|
|
hookCtx.Result = dbobj
|
|
hookCtx.Error = err
|
|
if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil {
|
|
logger.Error("AfterQuery hook failed: %v", err)
|
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
return
|
|
}
|
|
// Use potentially modified result from hook
|
|
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
|
dbobj = modifiedResult
|
|
}
|
|
|
|
// Execute BeforeResponse hook
|
|
hookCtx.Result = dbobj
|
|
if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil {
|
|
logger.Error("BeforeResponse hook failed: %v", err)
|
|
sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
|
return
|
|
}
|
|
// Use potentially modified result from hook
|
|
if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok {
|
|
dbobj = modifiedResult
|
|
}
|
|
|
|
// Check if response should be root-level data
|
|
if val, ok := dbobj["root_as_data"]; ok {
|
|
data, err := json.Marshal(val)
|
|
if err != nil {
|
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
|
} else {
|
|
_, _ = w.Write(data)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Marshal and send response
|
|
data, err := json.Marshal(dbobj)
|
|
if err != nil {
|
|
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
|
} else {
|
|
_, _ = w.Write(data)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// extractInputVariables extracts placeholders like [variable] from the SQL query
|
|
func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) string {
|
|
|
|
testsqlquery := sqlquery
|
|
for i := 0; i <= strings.Count(sqlquery, "[")*4; i++ {
|
|
iStart := strings.Index(testsqlquery, "[")
|
|
if iStart < 0 {
|
|
break
|
|
}
|
|
iEnd := strings.Index(testsqlquery, "]")
|
|
if iEnd < 0 {
|
|
break
|
|
}
|
|
*inputvars = append(*inputvars, testsqlquery[iStart:iEnd+1])
|
|
testsqlquery = testsqlquery[iEnd+1:]
|
|
}
|
|
return sqlquery
|
|
}
|
|
|
|
// mergePathParams merges URL path parameters into the SQL query
|
|
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
|
|
// Note: Path parameters would typically come from a router like gorilla/mux
|
|
// For now, this is a placeholder for path parameter extraction
|
|
return sqlquery
|
|
}
|
|
|
|
// mergeQueryParams merges query string parameters into the SQL query
|
|
func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables map[string]interface{}, allowFilter bool, propQry map[string]string) string {
|
|
for parmk, parmv := range r.URL.Query() {
|
|
if len(parmk) == 0 || len(parmv) == 0 {
|
|
continue
|
|
}
|
|
|
|
val := parmv[0]
|
|
dec, err := restheadspec.DecodeParam(val)
|
|
if err == nil {
|
|
val = dec
|
|
}
|
|
|
|
kword := fmt.Sprintf("[%s]", parmk)
|
|
variables[parmk] = val
|
|
|
|
// Replace in SQL if placeholder exists
|
|
if strings.Contains(sqlquery, kword) && len(val) > 0 {
|
|
if strings.HasPrefix(parmk, "p-") {
|
|
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
|
}
|
|
}
|
|
|
|
// Add to propQry for x- prefixed params
|
|
if strings.HasPrefix(parmk, "x-") {
|
|
propQry[parmk] = val
|
|
}
|
|
|
|
// Apply filters if allowed
|
|
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
|
|
if len(parmv) > 1 {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ",")))
|
|
} else {
|
|
if strings.Contains(val, "match=") {
|
|
colval := strings.ReplaceAll(val, "match=", "")
|
|
if colval != "*" {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
|
|
}
|
|
} else if val == "" || val == "0" {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
|
} else {
|
|
if IsNumeric(val) {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
|
} else {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return sqlquery
|
|
}
|
|
|
|
// mergeHeaderParams merges HTTP header parameters into the SQL query
|
|
func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables map[string]interface{}, propQry map[string]string, complexAPI *bool) string {
|
|
for kc, v := range r.Header {
|
|
k := strings.ToLower(kc)
|
|
if !strings.HasPrefix(k, "x-") || len(v) == 0 {
|
|
continue
|
|
}
|
|
|
|
val := v[0]
|
|
dec, err := restheadspec.DecodeParam(val)
|
|
if err == nil {
|
|
val = dec
|
|
}
|
|
|
|
variables[k] = val
|
|
propQry[k] = val
|
|
|
|
kword := fmt.Sprintf("[%s]", k)
|
|
if strings.Contains(sqlquery, kword) {
|
|
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
|
}
|
|
|
|
// Handle special headers
|
|
if strings.Contains(k, "x-fieldfilter-") {
|
|
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
|
|
if val == "" || val == "0" {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
|
} else {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
|
}
|
|
}
|
|
|
|
if strings.Contains(k, "x-searchfilter-") {
|
|
colname := strings.ReplaceAll(k, "x-searchfilter-", "")
|
|
sval := strings.ReplaceAll(val, "'", "")
|
|
if sval != "" {
|
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
|
|
}
|
|
}
|
|
|
|
if strings.Contains(k, "x-custom-sql-w") {
|
|
colval := ValidSQL(val, "select")
|
|
if len(colval) > 0 {
|
|
sqlquery = sqlQryWhere(sqlquery, colval)
|
|
}
|
|
}
|
|
|
|
if strings.Contains(k, "x-simpleapi") {
|
|
*complexAPI = !strings.EqualFold(val, "1") && !strings.EqualFold(val, "true")
|
|
}
|
|
}
|
|
return sqlquery
|
|
}
|
|
|
|
// replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query
|
|
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
|
|
if strings.Contains(sqlquery, "[p_meta_default]") {
|
|
data, _ := json.Marshal(metainfo)
|
|
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
|
}
|
|
|
|
if strings.Contains(sqlquery, "[json_variables]") {
|
|
data, _ := json.Marshal(variables)
|
|
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
|
}
|
|
|
|
if strings.Contains(sqlquery, "[rid_user]") {
|
|
sqlquery = strings.ReplaceAll(sqlquery, "[rid_user]", fmt.Sprintf("%d", userCtx.UserID))
|
|
}
|
|
|
|
if strings.Contains(sqlquery, "[user]") {
|
|
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName))
|
|
}
|
|
|
|
if strings.Contains(sqlquery, "[rid_session]") {
|
|
sessionID := userCtx.SessionID
|
|
sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("'%s'", sessionID))
|
|
}
|
|
|
|
if strings.Contains(sqlquery, "[method]") {
|
|
sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method)
|
|
}
|
|
|
|
if strings.Contains(sqlquery, "[post_body]") {
|
|
bodystr := ""
|
|
if r.Method == "POST" || r.Method == "PUT" {
|
|
if r.Body != nil {
|
|
contents, err := io.ReadAll(r.Body)
|
|
if err == nil {
|
|
bodystr = string(contents)
|
|
}
|
|
}
|
|
}
|
|
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr))
|
|
}
|
|
|
|
return sqlquery
|
|
}
|
|
|
|
// parsePaginationParams extracts sort, limit, and offset parameters from request
|
|
func (h *Handler) parsePaginationParams(r *http.Request) (sortcols string, limit, offset int) {
|
|
limit = 20
|
|
offset = 0
|
|
|
|
if sortStr := r.URL.Query().Get("sort"); sortStr != "" {
|
|
sortcols = sortStr
|
|
}
|
|
|
|
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
|
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
|
limit = l
|
|
}
|
|
}
|
|
|
|
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
|
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
|
offset = o
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
var (
	// validSQLColNameRe matches every character that is not allowed in a
	// column name (anything but letters, digits, underscore and dot).
	validSQLColNameRe = regexp.MustCompile(`[^a-zA-Z0-9_\.]`)

	// validSQLDangerRe matches the patterns stripped in "select" mode:
	// statement separators, comment markers, the xp_/sp_ procedure prefixes
	// and data-changing keywords followed by whitespace. Prefixes and keywords
	// are matched case-insensitively and only at the start of a word, so
	// identifiers such as resp_id or last_update are left alone.
	validSQLDangerRe = regexp.MustCompile(`(?i)(?:;|--|/\*|\*/|\b(?:xp_|sp_)|\b(?:drop|delete|truncate|update|insert)\s)`)
)

// ValidSQL sanitizes a piece of SQL input according to mode and returns the
// cleaned text:
//
//   - "colname": removes every character other than letters, digits,
//     underscore and dot.
//   - "colvalue": doubles single quotes so that the value can be placed inside
//     a single-quoted SQL literal.
//   - "select": removes ";", "--", "/*", "*/", the xp_/sp_ prefixes and the
//     keywords DROP, DELETE, TRUNCATE, UPDATE and INSERT (in any letter case,
//     when followed by whitespace). Removal is repeated until nothing matches,
//     so a pattern cannot be re-formed from the text around a removed one.
//
// Any other mode returns input unchanged.
func ValidSQL(input, mode string) string {
	switch mode {
	case "colname":
		return validSQLColNameRe.ReplaceAllString(input, "")
	case "colvalue":
		return strings.ReplaceAll(input, "'", "''")
	case "select":
		result := input
		for {
			// Each pass can only shorten the string, so the loop terminates.
			cleaned := validSQLDangerRe.ReplaceAllString(result, "")
			if cleaned == result {
				return result
			}
			result = cleaned
		}
	default:
		return input
	}
}
|
|
|
|
// isSQLSpace reports whether c is a whitespace byte that can separate SQL tokens.
func isSQLSpace(c byte) bool {
	return c == ' ' || c == '\t' || c == '\n' || c == '\r'
}

// isSQLIdentChar reports whether c can be part of an unquoted SQL identifier.
func isSQLIdentChar(c byte) bool {
	return c == '_' || (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
}

// sqlKeywordAt reports whether the keyword made of words (for example "group",
// "by") starts at byte offset pos of s. Matching is case-insensitive, the words
// must be separated by whitespace and the keyword must end at a word boundary.
func sqlKeywordAt(s string, pos int, words []string) bool {
	for i, word := range words {
		if i > 0 {
			start := pos
			for pos < len(s) && isSQLSpace(s[pos]) {
				pos++
			}
			if pos == start {
				return false
			}
		}
		if len(s)-pos < len(word) || !strings.EqualFold(s[pos:pos+len(word)], word) {
			return false
		}
		pos += len(word)
	}
	return pos == len(s) || !isSQLIdentChar(s[pos])
}

// sqlTopLevelKeywordPos returns the offset of the whitespace byte preceding the
// first occurrence of the keyword that is outside parentheses and outside
// single-quoted string literals, or -1 when there is none.
func sqlTopLevelKeywordPos(sqlquery string, words ...string) int {
	depth := 0
	inString := false
	for i := 0; i < len(sqlquery); i++ {
		c := sqlquery[i]
		switch {
		case inString:
			// A doubled quote ('') leaves and immediately re-enters the literal.
			if c == '\'' {
				inString = false
			}
		case c == '\'':
			inString = true
		case c == '(':
			depth++
		case c == ')':
			if depth > 0 {
				depth--
			}
		case depth == 0 && isSQLSpace(c) && sqlKeywordAt(sqlquery, i+1, words):
			return i
		}
	}
	return -1
}

// sqlQryWhere adds condition to the outermost query of sqlquery. When that
// query already has a WHERE clause the condition is appended with AND,
// otherwise a WHERE clause is created. The condition is inserted before the
// first top-level GROUP BY, HAVING, ORDER BY, LIMIT or OFFSET, or at the end
// when none is present.
//
// Keywords are recognised in any letter case and with any whitespace around
// them (including newlines). Keywords inside subqueries, other parenthesised
// expressions or string literals are ignored.
func sqlQryWhere(sqlquery, condition string) string {
	insertPos := len(sqlquery)
	trailingClauses := [][]string{{"group", "by"}, {"having"}, {"order", "by"}, {"limit"}, {"offset"}}
	for _, clause := range trailingClauses {
		if pos := sqlTopLevelKeywordPos(sqlquery, clause...); pos >= 0 && pos < insertPos {
			insertPos = pos
		}
	}

	joiner := "WHERE"
	if sqlTopLevelKeywordPos(sqlquery, "where") >= 0 {
		joiner = "AND"
	}

	return fmt.Sprintf("%s %s %s %s", sqlquery[:insertPos], joiner, condition, sqlquery[insertPos:])
}
|
|
|
|
// IsNumeric reports whether s is a plain decimal number: an optional sign,
// digits with an optional decimal point, and an optional exponent (for example
// "42", "-3.5" or "1e5").
//
// Other spellings that strconv.ParseFloat accepts, such as "NaN", "Inf",
// hexadecimal floats and digit-separating underscores, are rejected because
// callers embed numeric values into SQL without quoting.
func IsNumeric(s string) bool {
	for i := 0; i < len(s); i++ {
		c := s[i]
		isDigit := c >= '0' && c <= '9'
		if !isDigit && !strings.ContainsRune("+-.eE", rune(c)) {
			return false
		}
	}
	// ParseFloat validates the syntax; it also fails for the empty string.
	_, err := strconv.ParseFloat(s, 64)
	return err == nil
}
|
|
|
|
// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows
|
|
// func makeResultReceiver(length int) []interface{} {
|
|
// result := make([]interface{}, length)
|
|
// for i := 0; i < length; i++ {
|
|
// var v interface{}
|
|
// result[i] = &v
|
|
// }
|
|
// return result
|
|
// }
|
|
|
|
// getIPAddress determines the client address of a request. The sources are
// consulted in this order: the first entry of the X-Forwarded-For header
// (trimmed of surrounding whitespace), the X-Real-IP header, and finally the
// request's RemoteAddr.
func getIPAddress(r *http.Request) string {
	forwarded := r.Header.Get("X-Forwarded-For")
	if forwarded == "" {
		if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
			return realIP
		}
		return r.RemoteAddr
	}

	// The header may list several addresses; only the first one is used.
	first := forwarded
	if comma := strings.Index(forwarded, ","); comma >= 0 {
		first = forwarded[:comma]
	}
	return strings.TrimSpace(first)
}
|
|
|
|
// sendError sends a JSON error response
|
|
func sendError(w http.ResponseWriter, status int, code, message string, err error) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
|
|
errObj := common.APIError{
|
|
Code: code,
|
|
Message: message,
|
|
}
|
|
if err != nil {
|
|
errObj.Detail = err.Error()
|
|
}
|
|
|
|
data, _ := json.Marshal(map[string]interface{}{
|
|
"success": false,
|
|
"error": errObj,
|
|
})
|
|
_, _ = w.Write(data)
|
|
}
|