mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 14:26:22 +00:00
Added function api prototype
This commit is contained in:
parent
e32ec9e17e
commit
6bbe0ec8b0
629
pkg/funcspec/function_api.go
Normal file
629
pkg/funcspec/function_api.go
Normal file
@ -0,0 +1,629 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// Handler handles function-based SQL API requests
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
}
|
||||
|
||||
// NewHandler creates a new function API handler
|
||||
func NewHandler(db common.Database) *Handler {
|
||||
return &Handler{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPFuncType is a function type for HTTP handlers
|
||||
type HTTPFuncType func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
|
||||
func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
stack := debug.Stack()
|
||||
logger.Error("Panic in SqlQueryList: %v\nStack trace:\n%s", err, string(stack))
|
||||
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var dbobjlist []map[string]interface{}
|
||||
var total int64
|
||||
propQry := make(map[string]string)
|
||||
inputvars := make([]string, 0)
|
||||
metainfo := make(map[string]interface{})
|
||||
variables := make(map[string]interface{})
|
||||
complexAPI := false
|
||||
|
||||
// Get user context from security package
|
||||
userCtx, ok := security.GetUserContext(ctx)
|
||||
if !ok {
|
||||
logger.Warn("No user context found in request")
|
||||
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Extract input variables from SQL query (placeholders like [variable])
|
||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||
|
||||
// Merge URL path parameters
|
||||
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
||||
|
||||
// Merge query string parameters
|
||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
|
||||
|
||||
// Merge header parameters
|
||||
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
||||
|
||||
// Build metainfo
|
||||
metainfo["ipaddress"] = getIPAddress(r)
|
||||
metainfo["url"] = r.RequestURI
|
||||
metainfo["user"] = userCtx.UserName
|
||||
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
|
||||
metainfo["method"] = r.Method
|
||||
metainfo["variables"] = variables
|
||||
|
||||
// Replace meta variables in SQL
|
||||
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
||||
|
||||
// Remove unused input variables
|
||||
if pBlankparms {
|
||||
for _, kw := range inputvars {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, "")
|
||||
logger.Debug("Removed unused variable: %s", kw)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query within transaction
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
sqlqueryCnt := sqlquery
|
||||
|
||||
// Parse sorting and pagination parameters
|
||||
sortcols, limit, offset := h.parsePaginationParams(r)
|
||||
fromPos := strings.Index(strings.ToLower(sqlquery), "from ")
|
||||
orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by")
|
||||
|
||||
if len(sortcols) > 0 && (orderbyPos < 0 || (orderbyPos > 0 && orderbyPos < fromPos)) {
|
||||
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
|
||||
}
|
||||
|
||||
if !pNoCount {
|
||||
if limit > 0 && offset > 0 {
|
||||
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
|
||||
} else if limit > 0 {
|
||||
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, limit)
|
||||
} else {
|
||||
sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, 20000)
|
||||
}
|
||||
|
||||
// Get total count
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(1) FROM (%s) cnts", sqlqueryCnt)
|
||||
var countResult struct{ Count int64 }
|
||||
if err := tx.Query(ctx, &countResult, countQuery); err != nil {
|
||||
sendError(w, http.StatusBadRequest, "count_failed", "Failed to retrieve record count", err)
|
||||
return err
|
||||
}
|
||||
total = countResult.Count
|
||||
}
|
||||
|
||||
// Execute main query
|
||||
rows := make([]map[string]interface{}, 0)
|
||||
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
||||
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
|
||||
return err
|
||||
}
|
||||
|
||||
dbobjlist = rows
|
||||
|
||||
if pNoCount {
|
||||
total = int64(len(dbobjlist))
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Transaction failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Set response headers
|
||||
respOffset := 0
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
if o, err := strconv.Atoi(offsetStr); err == nil {
|
||||
respOffset = o
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total))
|
||||
logger.Info("Serving: Records %d of %d", len(dbobjlist), total)
|
||||
|
||||
if len(dbobjlist) == 0 {
|
||||
w.Write([]byte("[]"))
|
||||
return
|
||||
}
|
||||
|
||||
if complexAPI {
|
||||
metaobj := map[string]interface{}{
|
||||
"items": dbobjlist,
|
||||
"count": fmt.Sprintf("%d", len(dbobjlist)),
|
||||
"total": fmt.Sprintf("%d", total),
|
||||
"tablename": r.URL.Path,
|
||||
"tableprefix": "gsql",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(metaobj)
|
||||
if err != nil {
|
||||
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||
} else {
|
||||
if int64(len(dbobjlist)) < total {
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
}
|
||||
w.Write(data)
|
||||
}
|
||||
} else {
|
||||
data, err := json.Marshal(dbobjlist)
|
||||
if err != nil {
|
||||
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||
} else {
|
||||
if int64(len(dbobjlist)) < total {
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
}
|
||||
w.Write(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
|
||||
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
stack := debug.Stack()
|
||||
logger.Error("Panic in SqlQuery: %v\nStack trace:\n%s", err, string(stack))
|
||||
http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second)
|
||||
defer cancel()
|
||||
|
||||
propQry := make(map[string]string)
|
||||
inputvars := make([]string, 0)
|
||||
metainfo := make(map[string]interface{})
|
||||
variables := make(map[string]interface{})
|
||||
dbobj := make(map[string]interface{})
|
||||
|
||||
// Get user context from security package
|
||||
userCtx, ok := security.GetUserContext(ctx)
|
||||
if !ok {
|
||||
logger.Warn("No user context found in request")
|
||||
userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Extract input variables from SQL query
|
||||
sqlquery = h.extractInputVariables(sqlquery, &inputvars)
|
||||
|
||||
// Merge URL path parameters
|
||||
sqlquery = h.mergePathParams(r, sqlquery, variables)
|
||||
|
||||
// Merge query string parameters
|
||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry)
|
||||
|
||||
// Merge header parameters
|
||||
complexAPI := false
|
||||
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
||||
|
||||
// Build metainfo
|
||||
metainfo["ipaddress"] = getIPAddress(r)
|
||||
metainfo["url"] = r.RequestURI
|
||||
metainfo["user"] = userCtx.UserName
|
||||
metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID)
|
||||
metainfo["method"] = r.Method
|
||||
metainfo["variables"] = variables
|
||||
|
||||
// Replace meta variables in SQL
|
||||
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
||||
|
||||
// Apply field filters from headers
|
||||
for k, val := range propQry {
|
||||
kLower := strings.ToLower(k)
|
||||
if strings.HasPrefix(kLower, "x-fieldfilter-") {
|
||||
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
|
||||
if strings.Contains(strings.ToLower(sqlquery), colname) {
|
||||
if val == "" || val == "0" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unused input variables
|
||||
if pBlankparms {
|
||||
for _, kw := range inputvars {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, "")
|
||||
logger.Debug("Removed unused variable: %s", kw)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query within transaction
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Execute main query
|
||||
rows := make([]map[string]interface{}, 0)
|
||||
if err := tx.Query(ctx, &rows, sqlquery); err != nil {
|
||||
sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(rows) > 0 {
|
||||
dbobj = rows[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Transaction failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if response should be root-level data
|
||||
if val, ok := dbobj["root_as_data"]; ok {
|
||||
data, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||
} else {
|
||||
w.Write(data)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal and send response
|
||||
data, err := json.Marshal(dbobj)
|
||||
if err != nil {
|
||||
sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err)
|
||||
} else {
|
||||
w.Write(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// extractInputVariables extracts placeholders like [variable] from the SQL query
|
||||
func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) string {
|
||||
max := strings.Count(sqlquery, "[") * 4
|
||||
testsqlquery := sqlquery
|
||||
for i := 0; i <= max; i++ {
|
||||
iStart := strings.Index(testsqlquery, "[")
|
||||
if iStart < 0 {
|
||||
break
|
||||
}
|
||||
iEnd := strings.Index(testsqlquery, "]")
|
||||
if iEnd < 0 {
|
||||
break
|
||||
}
|
||||
*inputvars = append(*inputvars, testsqlquery[iStart:iEnd+1])
|
||||
testsqlquery = testsqlquery[iEnd+1:]
|
||||
}
|
||||
return sqlquery
|
||||
}
|
||||
|
||||
// mergePathParams merges URL path parameters into the SQL query
|
||||
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
|
||||
// Note: Path parameters would typically come from a router like gorilla/mux
|
||||
// For now, this is a placeholder for path parameter extraction
|
||||
return sqlquery
|
||||
}
|
||||
|
||||
// mergeQueryParams merges query string parameters into the SQL query
|
||||
func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables map[string]interface{}, allowFilter bool, propQry map[string]string) string {
|
||||
for parmk, parmv := range r.URL.Query() {
|
||||
if len(parmk) == 0 || len(parmv) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
val := parmv[0]
|
||||
dec, err := restheadspec.DecodeParam(val)
|
||||
if err == nil {
|
||||
val = dec
|
||||
}
|
||||
|
||||
kword := fmt.Sprintf("[%s]", parmk)
|
||||
variables[parmk] = val
|
||||
|
||||
// Replace in SQL if placeholder exists
|
||||
if strings.Contains(sqlquery, kword) && len(val) > 0 {
|
||||
if strings.HasPrefix(parmk, "p-") {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
||||
}
|
||||
}
|
||||
|
||||
// Add to propQry for x- prefixed params
|
||||
if strings.HasPrefix(parmk, "x-") {
|
||||
propQry[parmk] = val
|
||||
}
|
||||
|
||||
// Apply filters if allowed
|
||||
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
|
||||
if len(parmv) > 1 {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ",")))
|
||||
} else {
|
||||
if strings.Contains(val, "match=") {
|
||||
colval := strings.ReplaceAll(val, "match=", "")
|
||||
if colval != "*" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
|
||||
}
|
||||
} else if val == "" || val == "0" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
if IsNumeric(val) {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return sqlquery
|
||||
}
|
||||
|
||||
// mergeHeaderParams merges HTTP header parameters into the SQL query
|
||||
func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables map[string]interface{}, propQry map[string]string, complexAPI *bool) string {
|
||||
for kc, v := range r.Header {
|
||||
k := strings.ToLower(kc)
|
||||
if !strings.HasPrefix(k, "x-") || len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
val := v[0]
|
||||
dec, err := restheadspec.DecodeParam(val)
|
||||
if err == nil {
|
||||
val = dec
|
||||
}
|
||||
|
||||
variables[k] = val
|
||||
propQry[k] = val
|
||||
|
||||
kword := fmt.Sprintf("[%s]", k)
|
||||
if strings.Contains(sqlquery, kword) {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
|
||||
}
|
||||
|
||||
// Handle special headers
|
||||
if strings.Contains(k, "x-fieldfilter-") {
|
||||
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
|
||||
if val == "" || val == "0" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
} else {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(k, "x-searchfilter-") {
|
||||
colname := strings.ReplaceAll(k, "x-searchfilter-", "")
|
||||
sval := strings.ReplaceAll(val, "'", "")
|
||||
if sval != "" {
|
||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(k, "x-custom-sql-w") {
|
||||
colval := ValidSQL(val, "select")
|
||||
if len(colval) > 0 {
|
||||
sqlquery = sqlQryWhere(sqlquery, colval)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(k, "x-simpleapi") {
|
||||
*complexAPI = !(val == "1" || strings.ToLower(val) == "true")
|
||||
}
|
||||
}
|
||||
return sqlquery
|
||||
}
|
||||
|
||||
// replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query
|
||||
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
|
||||
if strings.Contains(sqlquery, "[p_meta_default]") {
|
||||
data, _ := json.Marshal(metainfo)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[json_variables]") {
|
||||
data, _ := json.Marshal(variables)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data)))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[rid_user]") {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[rid_user]", fmt.Sprintf("%d", userCtx.UserID))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[user]") {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[rid_session]") {
|
||||
sessionID := userCtx.SessionID
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("'%s'", sessionID))
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[method]") {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method)
|
||||
}
|
||||
|
||||
if strings.Contains(sqlquery, "[post_body]") {
|
||||
bodystr := ""
|
||||
if r.Method == "POST" || r.Method == "PUT" {
|
||||
if r.Body != nil {
|
||||
contents, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
bodystr = string(contents)
|
||||
}
|
||||
}
|
||||
}
|
||||
sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr))
|
||||
}
|
||||
|
||||
return sqlquery
|
||||
}
|
||||
|
||||
// parsePaginationParams extracts sort, limit, and offset parameters from request
|
||||
func (h *Handler) parsePaginationParams(r *http.Request) (sortcols string, limit, offset int) {
|
||||
limit = 20
|
||||
offset = 0
|
||||
|
||||
if sortStr := r.URL.Query().Get("sort"); sortStr != "" {
|
||||
sortcols = sortStr
|
||||
}
|
||||
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 {
|
||||
offset = o
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ValidSQL validates and sanitizes SQL input to prevent injection
|
||||
// mode can be: "colname", "colvalue", "select"
|
||||
func ValidSQL(input, mode string) string {
|
||||
// Remove dangerous characters based on mode
|
||||
switch mode {
|
||||
case "colname":
|
||||
// For column names, only allow alphanumeric, underscore, and dot
|
||||
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
|
||||
return reg.ReplaceAllString(input, "")
|
||||
case "colvalue":
|
||||
// For column values, escape single quotes
|
||||
return strings.ReplaceAll(input, "'", "''")
|
||||
case "select":
|
||||
// For SELECT clauses, be more permissive but still safe
|
||||
// Remove semicolons and common SQL injection patterns
|
||||
dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "}
|
||||
result := input
|
||||
for _, d := range dangerous {
|
||||
result = strings.ReplaceAll(result, d, "")
|
||||
result = strings.ReplaceAll(result, strings.ToLower(d), "")
|
||||
result = strings.ReplaceAll(result, strings.ToUpper(d), "")
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return input
|
||||
}
|
||||
}
|
||||
|
||||
// sqlQryWhere adds a WHERE clause to a SQL query or appends to existing WHERE with AND
|
||||
func sqlQryWhere(sqlquery, condition string) string {
|
||||
lowerQuery := strings.ToLower(sqlquery)
|
||||
wherePos := strings.Index(lowerQuery, " where ")
|
||||
groupPos := strings.Index(lowerQuery, " group by")
|
||||
orderPos := strings.Index(lowerQuery, " order by")
|
||||
limitPos := strings.Index(lowerQuery, " limit ")
|
||||
|
||||
// Find the insertion point (before GROUP BY, ORDER BY, or LIMIT)
|
||||
insertPos := len(sqlquery)
|
||||
if groupPos > 0 && groupPos < insertPos {
|
||||
insertPos = groupPos
|
||||
}
|
||||
if orderPos > 0 && orderPos < insertPos {
|
||||
insertPos = orderPos
|
||||
}
|
||||
if limitPos > 0 && limitPos < insertPos {
|
||||
insertPos = limitPos
|
||||
}
|
||||
|
||||
if wherePos > 0 {
|
||||
// WHERE exists, add AND condition before GROUP BY / ORDER BY / LIMIT
|
||||
before := sqlquery[:insertPos]
|
||||
after := sqlquery[insertPos:]
|
||||
return fmt.Sprintf("%s AND %s %s", before, condition, after)
|
||||
} else {
|
||||
// No WHERE exists, add it before GROUP BY / ORDER BY / LIMIT
|
||||
before := sqlquery[:insertPos]
|
||||
after := sqlquery[insertPos:]
|
||||
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
|
||||
}
|
||||
}
|
||||
|
||||
// IsNumeric checks if a string contains only numeric characters
|
||||
func IsNumeric(s string) bool {
|
||||
_, err := strconv.ParseFloat(s, 64)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows
|
||||
// func makeResultReceiver(length int) []interface{} {
|
||||
// result := make([]interface{}, length)
|
||||
// for i := 0; i < length; i++ {
|
||||
// var v interface{}
|
||||
// result[i] = &v
|
||||
// }
|
||||
// return result
|
||||
// }
|
||||
|
||||
// getIPAddress extracts the real IP address from the request
|
||||
func getIPAddress(r *http.Request) string {
|
||||
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
|
||||
// X-Forwarded-For can contain multiple IPs, take the first one
|
||||
ips := strings.Split(forwarded, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
|
||||
return realIP
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// sendError sends a JSON error response
|
||||
func sendError(w http.ResponseWriter, status int, code, message string, err error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
errObj := common.APIError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
if err != nil {
|
||||
errObj.Detail = err.Error()
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"success": false,
|
||||
"error": errObj,
|
||||
})
|
||||
w.Write(data)
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user