package funcspec import ( "context" "encoding/json" "fmt" "io" "net/http" "regexp" "runtime/debug" "strconv" "strings" "time" "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/restheadspec" "github.com/bitechdev/ResolveSpec/pkg/security" ) // Handler handles function-based SQL API requests type Handler struct { db common.Database hooks *HookRegistry } // NewHandler creates a new function API handler func NewHandler(db common.Database) *Handler { return &Handler{ db: db, hooks: NewHookRegistry(), } } // Hooks returns the hook registry for this handler // Use this to register custom hooks for operations func (h *Handler) Hooks() *HookRegistry { return h.hooks } // HTTPFuncType is a function type for HTTP handlers type HTTPFuncType func(http.ResponseWriter, *http.Request) // SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType { return func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { stack := debug.Stack() logger.Error("Panic in SqlQueryList: %v\nStack trace:\n%s", err, string(stack)) http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError) } }() ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second) defer cancel() var dbobjlist []map[string]interface{} var total int64 propQry := make(map[string]string) inputvars := make([]string, 0) metainfo := make(map[string]interface{}) variables := make(map[string]interface{}) complexAPI := false // Get user context from security package userCtx, ok := security.GetUserContext(ctx) if !ok { logger.Warn("No user context found in request") userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"} } w.Header().Set("Content-Type", "application/json") // Initialize hook context hookCtx := &HookContext{ Context: ctx, Handler: h, Request: r, Writer: w, SQLQuery: sqlquery, Variables: variables, InputVars: inputvars, MetaInfo: metainfo, PropQry: propQry, UserContext: userCtx, NoCount: pNoCount, BlankParams: pBlankparms, AllowFilter: pAllowFilter, ComplexAPI: complexAPI, } // Execute BeforeQueryList hook if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil { logger.Error("BeforeQueryList hook failed: %v", err) sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) return } // Check if hook aborted the operation if hookCtx.Abort { if hookCtx.AbortCode == 0 { hookCtx.AbortCode = http.StatusBadRequest } sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) return } // Use potentially modified SQL query and variables from hooks sqlquery = hookCtx.SQLQuery variables = hookCtx.Variables // complexAPI = hookCtx.ComplexAPI // Extract input variables from SQL query (placeholders like [variable]) sqlquery = h.extractInputVariables(sqlquery, &inputvars) // Merge URL path parameters sqlquery = h.mergePathParams(r, sqlquery, variables) // Parse comprehensive parameters from headers and query string reqParams := h.ParseParameters(r) complexAPI = reqParams.ComplexAPI // Merge query string parameters sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry) // Merge header parameters sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) // Apply filters from parsed parameters (if not already applied by pAllowFilter) if !pAllowFilter { sqlquery = h.ApplyFilters(sqlquery, reqParams) } // Apply field selection sqlquery = h.ApplyFieldSelection(sqlquery, reqParams) // Apply DISTINCT if requested sqlquery = h.ApplyDistinct(sqlquery, reqParams) // Override pNoCount if skipcount is specified if reqParams.SkipCount { pNoCount = true } // Build metainfo metainfo["ipaddress"] = getIPAddress(r) metainfo["url"] = r.RequestURI metainfo["user"] = userCtx.UserName metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID) metainfo["method"] = r.Method metainfo["variables"] = variables // Replace meta variables in SQL sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables) // Remove unused input variables if pBlankparms { for _, kw := range inputvars { sqlquery = strings.ReplaceAll(sqlquery, kw, "") logger.Debug("Removed unused variable: %s", kw) } } // Update hook context with latest SQL query and variables hookCtx.SQLQuery = sqlquery hookCtx.Variables = variables hookCtx.InputVars = inputvars // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { sqlqueryCnt := sqlquery // Parse sorting and pagination parameters sortcols, limit, offset := h.parsePaginationParams(r) // Override with parsed parameters if available if reqParams.SortColumns != "" { sortcols = reqParams.SortColumns } if reqParams.Limit > 0 { limit = reqParams.Limit } if reqParams.Offset > 0 { offset = reqParams.Offset } hookCtx.SortColumns = sortcols hookCtx.Limit = limit hookCtx.Offset = offset fromPos := strings.Index(strings.ToLower(sqlquery), "from ") orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by") if len(sortcols) > 0 && (orderbyPos < 0 || (orderbyPos > 0 && orderbyPos < fromPos)) { sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select")) } if !pNoCount { if limit > 0 && offset > 0 { sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset) } else if limit > 0 { sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, limit) } else { sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, 20000) } // Get total count countQuery := fmt.Sprintf("SELECT COUNT(1) FROM (%s) cnts", sqlqueryCnt) var countResult struct{ Count int64 } if err := tx.Query(ctx, &countResult, countQuery); err != nil { sendError(w, http.StatusBadRequest, "count_failed", "Failed to retrieve record count", err) return err } total = countResult.Count } // Execute BeforeSQLExec hook hookCtx.SQLQuery = sqlquery if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil { logger.Error("BeforeSQLExec hook failed: %v", err) sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) return err } // Use potentially modified SQL query from hook sqlquery = hookCtx.SQLQuery // Execute main query rows := make([]map[string]interface{}, 0) if err := tx.Query(ctx, &rows, sqlquery); err != nil { sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err) return err } dbobjlist = rows if pNoCount { total = int64(len(dbobjlist)) } // Execute AfterSQLExec hook hookCtx.Result = dbobjlist hookCtx.Total = total if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil { logger.Error("AfterSQLExec hook failed: %v", err) sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) return err } // Use potentially modified result from hook if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { dbobjlist = modifiedResult } total = hookCtx.Total return nil }) if err != nil { logger.Error("Transaction failed: %v", err) return } // Execute AfterQueryList hook hookCtx.Result = dbobjlist hookCtx.Total = total hookCtx.Error = err if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil { logger.Error("AfterQueryList hook failed: %v", err) sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) return } // Use potentially modified result from hook if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { dbobjlist = modifiedResult } total = hookCtx.Total // Set response headers respOffset := 0 if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { if o, err := strconv.Atoi(offsetStr); err == nil { respOffset = o } } w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total)) logger.Info("Serving: Records %d of %d", len(dbobjlist), total) // Execute BeforeResponse hook hookCtx.Result = dbobjlist hookCtx.Total = total if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil { logger.Error("BeforeResponse hook failed: %v", err) sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) return } // Use potentially modified result from hook if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { dbobjlist = modifiedResult } if len(dbobjlist) == 0 { _, _ = w.Write([]byte("[]")) return } // Format response based on response format switch reqParams.ResponseFormat { case "syncfusion": // Syncfusion format: { result: data, count: total } response := map[string]interface{}{ "result": dbobjlist, "count": total, } data, err := json.Marshal(response) if err != nil { sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) } else { if int64(len(dbobjlist)) < total { w.WriteHeader(http.StatusPartialContent) } _, _ = w.Write(data) } case "detail": // Detail format: complex API with metadata metaobj := map[string]interface{}{ "items": dbobjlist, "count": fmt.Sprintf("%d", len(dbobjlist)), "total": fmt.Sprintf("%d", total), "tablename": r.URL.Path, "tableprefix": "gsql", } data, err := json.Marshal(metaobj) if err != nil { sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) } else { if int64(len(dbobjlist)) < total { w.WriteHeader(http.StatusPartialContent) } _, _ = w.Write(data) } default: // Simple format: just return the data array (or complex API if requested) if complexAPI { metaobj := map[string]interface{}{ "items": dbobjlist, "count": fmt.Sprintf("%d", len(dbobjlist)), "total": fmt.Sprintf("%d", total), "tablename": r.URL.Path, "tableprefix": "gsql", } data, err := json.Marshal(metaobj) if err != nil { sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) } else { if int64(len(dbobjlist)) < total { w.WriteHeader(http.StatusPartialContent) } _, _ = w.Write(data) } } else { data, err := json.Marshal(dbobjlist) if err != nil { sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) } else { if int64(len(dbobjlist)) < total { w.WriteHeader(http.StatusPartialContent) } _, _ = w.Write(data) } } } } } // SqlQuery creates an HTTP handler that executes a SQL query and returns a single record func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { return func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { stack := debug.Stack() logger.Error("Panic in SqlQuery: %v\nStack trace:\n%s", err, string(stack)) http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError) } }() ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second) defer cancel() propQry := make(map[string]string) inputvars := make([]string, 0) metainfo := make(map[string]interface{}) variables := make(map[string]interface{}) dbobj := make(map[string]interface{}) complexAPI := false // Get user context from security package userCtx, ok := security.GetUserContext(ctx) if !ok { logger.Warn("No user context found in request") userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"} } w.Header().Set("Content-Type", "application/json") // Initialize hook context hookCtx := &HookContext{ Context: ctx, Handler: h, Request: r, Writer: w, SQLQuery: sqlquery, Variables: variables, InputVars: inputvars, MetaInfo: metainfo, PropQry: propQry, UserContext: userCtx, BlankParams: pBlankparms, ComplexAPI: complexAPI, } // Execute BeforeQuery hook if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil { logger.Error("BeforeQuery hook failed: %v", err) sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) return } // Check if hook aborted the operation if hookCtx.Abort { if hookCtx.AbortCode == 0 { hookCtx.AbortCode = http.StatusBadRequest } sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) return } // Use potentially modified SQL query and variables from hooks sqlquery = hookCtx.SQLQuery variables = hookCtx.Variables // Extract input variables from SQL query sqlquery = h.extractInputVariables(sqlquery, &inputvars) // Merge URL path parameters sqlquery = h.mergePathParams(r, sqlquery, variables) // Parse comprehensive parameters from headers and query string reqParams := h.ParseParameters(r) complexAPI = reqParams.ComplexAPI // Merge query string parameters sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry) // Merge header parameters sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) hookCtx.ComplexAPI = complexAPI // Apply filters from parsed parameters sqlquery = h.ApplyFilters(sqlquery, reqParams) // Apply field selection sqlquery = h.ApplyFieldSelection(sqlquery, reqParams) // Apply DISTINCT if requested sqlquery = h.ApplyDistinct(sqlquery, reqParams) // Build metainfo metainfo["ipaddress"] = getIPAddress(r) metainfo["url"] = r.RequestURI metainfo["user"] = userCtx.UserName metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID) metainfo["method"] = r.Method metainfo["variables"] = variables // Replace meta variables in SQL sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables) // Apply field filters from headers for k, val := range propQry { kLower := strings.ToLower(k) if strings.HasPrefix(kLower, "x-fieldfilter-") { colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "") if strings.Contains(strings.ToLower(sqlquery), colname) { if val == "" || val == "0" { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) } else { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) } } } } // Remove unused input variables if pBlankparms { for _, kw := range inputvars { sqlquery = strings.ReplaceAll(sqlquery, kw, "") logger.Debug("Removed unused variable: %s", kw) } } // Update hook context with latest SQL query and variables hookCtx.SQLQuery = sqlquery hookCtx.Variables = variables hookCtx.InputVars = inputvars // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { // Execute BeforeSQLExec hook if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil { logger.Error("BeforeSQLExec hook failed: %v", err) sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) return err } // Use potentially modified SQL query from hook sqlquery = hookCtx.SQLQuery // Execute main query rows := make([]map[string]interface{}, 0) if err := tx.Query(ctx, &rows, sqlquery); err != nil { sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err) return err } if len(rows) > 0 { dbobj = rows[0] } // Execute AfterSQLExec hook hookCtx.Result = dbobj if err := h.hooks.Execute(AfterSQLExec, hookCtx); err != nil { logger.Error("AfterSQLExec hook failed: %v", err) sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) return err } // Use potentially modified result from hook if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { dbobj = modifiedResult } return nil }) if err != nil { logger.Error("Transaction failed: %v", err) return } // Execute AfterQuery hook hookCtx.Result = dbobj hookCtx.Error = err if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil { logger.Error("AfterQuery hook failed: %v", err) sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) return } // Use potentially modified result from hook if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { dbobj = modifiedResult } // Execute BeforeResponse hook hookCtx.Result = dbobj if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil { logger.Error("BeforeResponse hook failed: %v", err) sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) return } // Use potentially modified result from hook if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { dbobj = modifiedResult } // Check if response should be root-level data if val, ok := dbobj["root_as_data"]; ok { data, err := json.Marshal(val) if err != nil { sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) } else { _, _ = w.Write(data) } return } // Marshal and send response data, err := json.Marshal(dbobj) if err != nil { sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) } else { _, _ = w.Write(data) } } } // Helper functions // extractInputVariables extracts placeholders like [variable] from the SQL query func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) string { testsqlquery := sqlquery for i := 0; i <= strings.Count(sqlquery, "[")*4; i++ { iStart := strings.Index(testsqlquery, "[") if iStart < 0 { break } iEnd := strings.Index(testsqlquery, "]") if iEnd < 0 { break } *inputvars = append(*inputvars, testsqlquery[iStart:iEnd+1]) testsqlquery = testsqlquery[iEnd+1:] } return sqlquery } // mergePathParams merges URL path parameters into the SQL query func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string { // Note: Path parameters would typically come from a router like gorilla/mux // For now, this is a placeholder for path parameter extraction return sqlquery } // mergeQueryParams merges query string parameters into the SQL query func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables map[string]interface{}, allowFilter bool, propQry map[string]string) string { for parmk, parmv := range r.URL.Query() { if len(parmk) == 0 || len(parmv) == 0 { continue } val := parmv[0] dec, err := restheadspec.DecodeParam(val) if err == nil { val = dec } kword := fmt.Sprintf("[%s]", parmk) variables[parmk] = val // Replace in SQL if placeholder exists if strings.Contains(sqlquery, kword) && len(val) > 0 { if strings.HasPrefix(parmk, "p-") { sqlquery = strings.ReplaceAll(sqlquery, kword, val) } } // Add to propQry for x- prefixed params if strings.HasPrefix(parmk, "x-") { propQry[parmk] = val } // Apply filters if allowed if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) { if len(parmv) > 1 { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ","))) } else { if strings.Contains(val, "match=") { colval := strings.ReplaceAll(val, "match=", "") if colval != "*" { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue"))) } } else if val == "" || val == "0" { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) } else { if IsNumeric(val) { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) } else { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) } } } } } return sqlquery } // mergeHeaderParams merges HTTP header parameters into the SQL query func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables map[string]interface{}, propQry map[string]string, complexAPI *bool) string { for kc, v := range r.Header { k := strings.ToLower(kc) if !strings.HasPrefix(k, "x-") || len(v) == 0 { continue } val := v[0] dec, err := restheadspec.DecodeParam(val) if err == nil { val = dec } variables[k] = val propQry[k] = val kword := fmt.Sprintf("[%s]", k) if strings.Contains(sqlquery, kword) { sqlquery = strings.ReplaceAll(sqlquery, kword, val) } // Handle special headers if strings.Contains(k, "x-fieldfilter-") { colname := strings.ReplaceAll(k, "x-fieldfilter-", "") if val == "" || val == "0" { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) } else { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) } } if strings.Contains(k, "x-searchfilter-") { colname := strings.ReplaceAll(k, "x-searchfilter-", "") sval := strings.ReplaceAll(val, "'", "") if sval != "" { sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue"))) } } if strings.Contains(k, "x-custom-sql-w") { colval := ValidSQL(val, "select") if len(colval) > 0 { sqlquery = sqlQryWhere(sqlquery, colval) } } if strings.Contains(k, "x-simpleapi") { *complexAPI = !strings.EqualFold(val, "1") && !strings.EqualFold(val, "true") } } return sqlquery } // replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string { if strings.Contains(sqlquery, "[p_meta_default]") { data, _ := json.Marshal(metainfo) sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data))) } if strings.Contains(sqlquery, "[json_variables]") { data, _ := json.Marshal(variables) sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data))) } if strings.Contains(sqlquery, "[rid_user]") { sqlquery = strings.ReplaceAll(sqlquery, "[rid_user]", fmt.Sprintf("%d", userCtx.UserID)) } if strings.Contains(sqlquery, "[user]") { sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName)) } if strings.Contains(sqlquery, "[rid_session]") { sessionID := userCtx.SessionID sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("'%s'", sessionID)) } if strings.Contains(sqlquery, "[method]") { sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method) } if strings.Contains(sqlquery, "[post_body]") { bodystr := "" if r.Method == "POST" || r.Method == "PUT" { if r.Body != nil { contents, err := io.ReadAll(r.Body) if err == nil { bodystr = string(contents) } } } sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr)) } return sqlquery } // parsePaginationParams extracts sort, limit, and offset parameters from request func (h *Handler) parsePaginationParams(r *http.Request) (sortcols string, limit, offset int) { limit = 20 offset = 0 if sortStr := r.URL.Query().Get("sort"); sortStr != "" { sortcols = sortStr } if limitStr := r.URL.Query().Get("limit"); limitStr != "" { if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { limit = l } } if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 { offset = o } } return } // ValidSQL validates and sanitizes SQL input to prevent injection // mode can be: "colname", "colvalue", "select" func ValidSQL(input, mode string) string { // Remove dangerous characters based on mode switch mode { case "colname": // For column names, only allow alphanumeric, underscore, and dot reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`) return reg.ReplaceAllString(input, "") case "colvalue": // For column values, escape single quotes return strings.ReplaceAll(input, "'", "''") case "select": // For SELECT clauses, be more permissive but still safe // Remove semicolons and common SQL injection patterns dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "} result := input for _, d := range dangerous { result = strings.ReplaceAll(result, d, "") result = strings.ReplaceAll(result, strings.ToLower(d), "") result = strings.ReplaceAll(result, strings.ToUpper(d), "") } return result default: return input } } // sqlQryWhere adds a WHERE clause to a SQL query or appends to existing WHERE with AND func sqlQryWhere(sqlquery, condition string) string { lowerQuery := strings.ToLower(sqlquery) wherePos := strings.Index(lowerQuery, " where ") groupPos := strings.Index(lowerQuery, " group by") orderPos := strings.Index(lowerQuery, " order by") limitPos := strings.Index(lowerQuery, " limit ") // Find the insertion point (before GROUP BY, ORDER BY, or LIMIT) insertPos := len(sqlquery) if groupPos > 0 && groupPos < insertPos { insertPos = groupPos } if orderPos > 0 && orderPos < insertPos { insertPos = orderPos } if limitPos > 0 && limitPos < insertPos { insertPos = limitPos } if wherePos > 0 { // WHERE exists, add AND condition before GROUP BY / ORDER BY / LIMIT before := sqlquery[:insertPos] after := sqlquery[insertPos:] return fmt.Sprintf("%s AND %s %s", before, condition, after) } else { // No WHERE exists, add it before GROUP BY / ORDER BY / LIMIT before := sqlquery[:insertPos] after := sqlquery[insertPos:] return fmt.Sprintf("%s WHERE %s %s", before, condition, after) } } // IsNumeric checks if a string contains only numeric characters func IsNumeric(s string) bool { _, err := strconv.ParseFloat(s, 64) return err == nil } // makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows // func makeResultReceiver(length int) []interface{} { // result := make([]interface{}, length) // for i := 0; i < length; i++ { // var v interface{} // result[i] = &v // } // return result // } // getIPAddress extracts the real IP address from the request func getIPAddress(r *http.Request) string { if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { // X-Forwarded-For can contain multiple IPs, take the first one ips := strings.Split(forwarded, ",") return strings.TrimSpace(ips[0]) } if realIP := r.Header.Get("X-Real-IP"); realIP != "" { return realIP } return r.RemoteAddr } // sendError sends a JSON error response func sendError(w http.ResponseWriter, status int, code, message string, err error) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) errObj := common.APIError{ Code: code, Message: message, } if err != nil { errObj.Detail = err.Error() } data, _ := json.Marshal(map[string]interface{}{ "success": false, "error": errObj, }) _, _ = w.Write(data) }