From 6bbe0ec8b0407b9f89e0fda808038d38eda68dea Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 24 Nov 2025 17:00:15 +0200 Subject: [PATCH] Added function api prototype --- pkg/funcspec/function_api.go | 629 +++++++++++++++++++++++++++++++++++ 1 file changed, 629 insertions(+) create mode 100644 pkg/funcspec/function_api.go diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go new file mode 100644 index 0000000..e56d8ac --- /dev/null +++ b/pkg/funcspec/function_api.go @@ -0,0 +1,629 @@ +package funcspec + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "runtime/debug" + "strconv" + "strings" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/restheadspec" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// Handler handles function-based SQL API requests +type Handler struct { + db common.Database +} + +// NewHandler creates a new function API handler +func NewHandler(db common.Database) *Handler { + return &Handler{ + db: db, + } +} + +// HTTPFuncType is a function type for HTTP handlers +type HTTPFuncType func(http.ResponseWriter, *http.Request) + +// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination +func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + stack := debug.Stack() + logger.Error("Panic in SqlQueryList: %v\nStack trace:\n%s", err, string(stack)) + http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError) + } + }() + + ctx, cancel := context.WithTimeout(r.Context(), 900*time.Second) + defer cancel() + + var dbobjlist []map[string]interface{} + var total int64 + propQry := make(map[string]string) + inputvars := make([]string, 0) + metainfo := make(map[string]interface{}) + variables := make(map[string]interface{}) + complexAPI := false + + // Get user context from security package + userCtx, ok := security.GetUserContext(ctx) + if !ok { + logger.Warn("No user context found in request") + userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"} + } + + w.Header().Set("Content-Type", "application/json") + + // Extract input variables from SQL query (placeholders like [variable]) + sqlquery = h.extractInputVariables(sqlquery, &inputvars) + + // Merge URL path parameters + sqlquery = h.mergePathParams(r, sqlquery, variables) + + // Merge query string parameters + sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry) + + // Merge header parameters + sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) + + // Build metainfo + metainfo["ipaddress"] = getIPAddress(r) + metainfo["url"] = r.RequestURI + metainfo["user"] = userCtx.UserName + metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID) + metainfo["method"] = r.Method + metainfo["variables"] = variables + + // Replace meta variables in SQL + sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables) + + // Remove unused input variables + if pBlankparms { + for _, kw := range inputvars { + sqlquery = strings.ReplaceAll(sqlquery, kw, "") + logger.Debug("Removed unused variable: %s", kw) + } + } + + // Execute query within transaction + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + sqlqueryCnt := sqlquery + + // Parse sorting and pagination parameters + sortcols, limit, offset := h.parsePaginationParams(r) + fromPos := strings.Index(strings.ToLower(sqlquery), "from ") + orderbyPos := strings.Index(strings.ToLower(sqlquery), "order by") + + if len(sortcols) > 0 && (orderbyPos < 0 || (orderbyPos > 0 && orderbyPos < fromPos)) { + sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select")) + } + + if !pNoCount { + if limit > 0 && offset > 0 { + sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset) + } else if limit > 0 { + sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, limit) + } else { + sqlquery = fmt.Sprintf("%s \nLIMIT %d", sqlquery, 20000) + } + + // Get total count + countQuery := fmt.Sprintf("SELECT COUNT(1) FROM (%s) cnts", sqlqueryCnt) + var countResult struct{ Count int64 } + if err := tx.Query(ctx, &countResult, countQuery); err != nil { + sendError(w, http.StatusBadRequest, "count_failed", "Failed to retrieve record count", err) + return err + } + total = countResult.Count + } + + // Execute main query + rows := make([]map[string]interface{}, 0) + if err := tx.Query(ctx, &rows, sqlquery); err != nil { + sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err) + return err + } + + dbobjlist = rows + + if pNoCount { + total = int64(len(dbobjlist)) + } + + return nil + }) + + if err != nil { + logger.Error("Transaction failed: %v", err) + return + } + + // Set response headers + respOffset := 0 + if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { + if o, err := strconv.Atoi(offsetStr); err == nil { + respOffset = o + } + } + + w.Header().Set("Content-Range", fmt.Sprintf("items %d-%d/%d", respOffset, respOffset+len(dbobjlist), total)) + logger.Info("Serving: Records %d of %d", len(dbobjlist), total) + + if len(dbobjlist) == 0 { + w.Write([]byte("[]")) + return + } + + if complexAPI { + metaobj := map[string]interface{}{ + "items": dbobjlist, + "count": fmt.Sprintf("%d", len(dbobjlist)), + "total": fmt.Sprintf("%d", total), + "tablename": r.URL.Path, + "tableprefix": "gsql", + } + + data, err := json.Marshal(metaobj) + if err != nil { + sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) + } else { + if int64(len(dbobjlist)) < total { + w.WriteHeader(http.StatusPartialContent) + } + w.Write(data) + } + } else { + data, err := json.Marshal(dbobjlist) + if err != nil { + sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) + } else { + if int64(len(dbobjlist)) < total { + w.WriteHeader(http.StatusPartialContent) + } + w.Write(data) + } + } + } +} + +// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record +func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + stack := debug.Stack() + logger.Error("Panic in SqlQuery: %v\nStack trace:\n%s", err, string(stack)) + http.Error(w, fmt.Sprintf("Internal server error: %v", err), http.StatusInternalServerError) + } + }() + + ctx, cancel := context.WithTimeout(r.Context(), 600*time.Second) + defer cancel() + + propQry := make(map[string]string) + inputvars := make([]string, 0) + metainfo := make(map[string]interface{}) + variables := make(map[string]interface{}) + dbobj := make(map[string]interface{}) + + // Get user context from security package + userCtx, ok := security.GetUserContext(ctx) + if !ok { + logger.Warn("No user context found in request") + userCtx = &security.UserContext{UserID: 0, UserName: "anonymous"} + } + + w.Header().Set("Content-Type", "application/json") + + // Extract input variables from SQL query + sqlquery = h.extractInputVariables(sqlquery, &inputvars) + + // Merge URL path parameters + sqlquery = h.mergePathParams(r, sqlquery, variables) + + // Merge query string parameters + sqlquery = h.mergeQueryParams(r, sqlquery, variables, false, propQry) + + // Merge header parameters + complexAPI := false + sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) + + // Build metainfo + metainfo["ipaddress"] = getIPAddress(r) + metainfo["url"] = r.RequestURI + metainfo["user"] = userCtx.UserName + metainfo["rid_user"] = fmt.Sprintf("%d", userCtx.UserID) + metainfo["method"] = r.Method + metainfo["variables"] = variables + + // Replace meta variables in SQL + sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables) + + // Apply field filters from headers + for k, val := range propQry { + kLower := strings.ToLower(k) + if strings.HasPrefix(kLower, "x-fieldfilter-") { + colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "") + if strings.Contains(strings.ToLower(sqlquery), colname) { + if val == "" || val == "0" { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } else { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } + } + } + } + + // Remove unused input variables + if pBlankparms { + for _, kw := range inputvars { + sqlquery = strings.ReplaceAll(sqlquery, kw, "") + logger.Debug("Removed unused variable: %s", kw) + } + } + + // Execute query within transaction + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Execute main query + rows := make([]map[string]interface{}, 0) + if err := tx.Query(ctx, &rows, sqlquery); err != nil { + sendError(w, http.StatusBadRequest, "query_failed", "Failed to retrieve records", err) + return err + } + + if len(rows) > 0 { + dbobj = rows[0] + } + + return nil + }) + + if err != nil { + logger.Error("Transaction failed: %v", err) + return + } + + // Check if response should be root-level data + if val, ok := dbobj["root_as_data"]; ok { + data, err := json.Marshal(val) + if err != nil { + sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) + } else { + w.Write(data) + } + return + } + + // Marshal and send response + data, err := json.Marshal(dbobj) + if err != nil { + sendError(w, http.StatusInternalServerError, "json_error", "Could not marshal response", err) + } else { + w.Write(data) + } + } +} + +// Helper functions + +// extractInputVariables extracts placeholders like [variable] from the SQL query +func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) string { + max := strings.Count(sqlquery, "[") * 4 + testsqlquery := sqlquery + for i := 0; i <= max; i++ { + iStart := strings.Index(testsqlquery, "[") + if iStart < 0 { + break + } + iEnd := strings.Index(testsqlquery, "]") + if iEnd < 0 { + break + } + *inputvars = append(*inputvars, testsqlquery[iStart:iEnd+1]) + testsqlquery = testsqlquery[iEnd+1:] + } + return sqlquery +} + +// mergePathParams merges URL path parameters into the SQL query +func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string { + // Note: Path parameters would typically come from a router like gorilla/mux + // For now, this is a placeholder for path parameter extraction + return sqlquery +} + +// mergeQueryParams merges query string parameters into the SQL query +func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables map[string]interface{}, allowFilter bool, propQry map[string]string) string { + for parmk, parmv := range r.URL.Query() { + if len(parmk) == 0 || len(parmv) == 0 { + continue + } + + val := parmv[0] + dec, err := restheadspec.DecodeParam(val) + if err == nil { + val = dec + } + + kword := fmt.Sprintf("[%s]", parmk) + variables[parmk] = val + + // Replace in SQL if placeholder exists + if strings.Contains(sqlquery, kword) && len(val) > 0 { + if strings.HasPrefix(parmk, "p-") { + sqlquery = strings.ReplaceAll(sqlquery, kword, val) + } + } + + // Add to propQry for x- prefixed params + if strings.HasPrefix(parmk, "x-") { + propQry[parmk] = val + } + + // Apply filters if allowed + if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) { + if len(parmv) > 1 { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ","))) + } else { + if strings.Contains(val, "match=") { + colval := strings.ReplaceAll(val, "match=", "") + if colval != "*" { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue"))) + } + } else if val == "" || val == "0" { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) + } else { + if IsNumeric(val) { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) + } else { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue"))) + } + } + } + } + } + return sqlquery +} + +// mergeHeaderParams merges HTTP header parameters into the SQL query +func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables map[string]interface{}, propQry map[string]string, complexAPI *bool) string { + for kc, v := range r.Header { + k := strings.ToLower(kc) + if !strings.HasPrefix(k, "x-") || len(v) == 0 { + continue + } + + val := v[0] + dec, err := restheadspec.DecodeParam(val) + if err == nil { + val = dec + } + + variables[k] = val + propQry[k] = val + + kword := fmt.Sprintf("[%s]", k) + if strings.Contains(sqlquery, kword) { + sqlquery = strings.ReplaceAll(sqlquery, kword, val) + } + + // Handle special headers + if strings.Contains(k, "x-fieldfilter-") { + colname := strings.ReplaceAll(k, "x-fieldfilter-", "") + if val == "" || val == "0" { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } else { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue"))) + } + } + + if strings.Contains(k, "x-searchfilter-") { + colname := strings.ReplaceAll(k, "x-searchfilter-", "") + sval := strings.ReplaceAll(val, "'", "") + if sval != "" { + sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue"))) + } + } + + if strings.Contains(k, "x-custom-sql-w") { + colval := ValidSQL(val, "select") + if len(colval) > 0 { + sqlquery = sqlQryWhere(sqlquery, colval) + } + } + + if strings.Contains(k, "x-simpleapi") { + *complexAPI = !(val == "1" || strings.ToLower(val) == "true") + } + } + return sqlquery +} + +// replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query +func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string { + if strings.Contains(sqlquery, "[p_meta_default]") { + data, _ := json.Marshal(metainfo) + sqlquery = strings.ReplaceAll(sqlquery, "[p_meta_default]", fmt.Sprintf("'%s'::jsonb", string(data))) + } + + if strings.Contains(sqlquery, "[json_variables]") { + data, _ := json.Marshal(variables) + sqlquery = strings.ReplaceAll(sqlquery, "[json_variables]", fmt.Sprintf("'%s'::jsonb", string(data))) + } + + if strings.Contains(sqlquery, "[rid_user]") { + sqlquery = strings.ReplaceAll(sqlquery, "[rid_user]", fmt.Sprintf("%d", userCtx.UserID)) + } + + if strings.Contains(sqlquery, "[user]") { + sqlquery = strings.ReplaceAll(sqlquery, "[user]", fmt.Sprintf("'%s'", userCtx.UserName)) + } + + if strings.Contains(sqlquery, "[rid_session]") { + sessionID := userCtx.SessionID + sqlquery = strings.ReplaceAll(sqlquery, "[rid_session]", fmt.Sprintf("'%s'", sessionID)) + } + + if strings.Contains(sqlquery, "[method]") { + sqlquery = strings.ReplaceAll(sqlquery, "[method]", r.Method) + } + + if strings.Contains(sqlquery, "[post_body]") { + bodystr := "" + if r.Method == "POST" || r.Method == "PUT" { + if r.Body != nil { + contents, err := io.ReadAll(r.Body) + if err == nil { + bodystr = string(contents) + } + } + } + sqlquery = strings.ReplaceAll(sqlquery, "[post_body]", fmt.Sprintf("'%s'", bodystr)) + } + + return sqlquery +} + +// parsePaginationParams extracts sort, limit, and offset parameters from request +func (h *Handler) parsePaginationParams(r *http.Request) (sortcols string, limit, offset int) { + limit = 20 + offset = 0 + + if sortStr := r.URL.Query().Get("sort"); sortStr != "" { + sortcols = sortStr + } + + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + } + + if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { + if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 { + offset = o + } + } + + return +} + +// ValidSQL validates and sanitizes SQL input to prevent injection +// mode can be: "colname", "colvalue", "select" +func ValidSQL(input, mode string) string { + // Remove dangerous characters based on mode + switch mode { + case "colname": + // For column names, only allow alphanumeric, underscore, and dot + reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`) + return reg.ReplaceAllString(input, "") + case "colvalue": + // For column values, escape single quotes + return strings.ReplaceAll(input, "'", "''") + case "select": + // For SELECT clauses, be more permissive but still safe + // Remove semicolons and common SQL injection patterns + dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "} + result := input + for _, d := range dangerous { + result = strings.ReplaceAll(result, d, "") + result = strings.ReplaceAll(result, strings.ToLower(d), "") + result = strings.ReplaceAll(result, strings.ToUpper(d), "") + } + return result + default: + return input + } +} + +// sqlQryWhere adds a WHERE clause to a SQL query or appends to existing WHERE with AND +func sqlQryWhere(sqlquery, condition string) string { + lowerQuery := strings.ToLower(sqlquery) + wherePos := strings.Index(lowerQuery, " where ") + groupPos := strings.Index(lowerQuery, " group by") + orderPos := strings.Index(lowerQuery, " order by") + limitPos := strings.Index(lowerQuery, " limit ") + + // Find the insertion point (before GROUP BY, ORDER BY, or LIMIT) + insertPos := len(sqlquery) + if groupPos > 0 && groupPos < insertPos { + insertPos = groupPos + } + if orderPos > 0 && orderPos < insertPos { + insertPos = orderPos + } + if limitPos > 0 && limitPos < insertPos { + insertPos = limitPos + } + + if wherePos > 0 { + // WHERE exists, add AND condition before GROUP BY / ORDER BY / LIMIT + before := sqlquery[:insertPos] + after := sqlquery[insertPos:] + return fmt.Sprintf("%s AND %s %s", before, condition, after) + } else { + // No WHERE exists, add it before GROUP BY / ORDER BY / LIMIT + before := sqlquery[:insertPos] + after := sqlquery[insertPos:] + return fmt.Sprintf("%s WHERE %s %s", before, condition, after) + } +} + +// IsNumeric checks if a string contains only numeric characters +func IsNumeric(s string) bool { + _, err := strconv.ParseFloat(s, 64) + return err == nil +} + +// makeResultReceiver creates a slice of interface{} pointers for scanning SQL rows +// func makeResultReceiver(length int) []interface{} { +// result := make([]interface{}, length) +// for i := 0; i < length; i++ { +// var v interface{} +// result[i] = &v +// } +// return result +// } + +// getIPAddress extracts the real IP address from the request +func getIPAddress(r *http.Request) string { + if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { + // X-Forwarded-For can contain multiple IPs, take the first one + ips := strings.Split(forwarded, ",") + return strings.TrimSpace(ips[0]) + } + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + return realIP + } + return r.RemoteAddr +} + +// sendError sends a JSON error response +func sendError(w http.ResponseWriter, status int, code, message string, err error) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + + errObj := common.APIError{ + Code: code, + Message: message, + } + if err != nil { + errObj.Detail = err.Error() + } + + data, _ := json.Marshal(map[string]interface{}{ + "success": false, + "error": errObj, + }) + w.Write(data) +}