diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index ba97b32..6e5591d 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -24,6 +24,13 @@ type Handler struct { hooks *HookRegistry } +type SqlQueryOptions struct { + GetVariablesCallback func(w http.ResponseWriter, r *http.Request) map[string]interface{} + NoCount bool + BlankParams bool + AllowFilter bool +} + // NewHandler creates a new function API handler func NewHandler(db common.Database) *Handler { return &Handler{ @@ -48,7 +55,7 @@ func (h *Handler) Hooks() *HookRegistry { type HTTPFuncType func(http.ResponseWriter, *http.Request) // SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination -func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType { +func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFuncType { return func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { @@ -70,6 +77,9 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil inputvars := make([]string, 0) metainfo := make(map[string]interface{}) variables := make(map[string]interface{}) + if options.GetVariablesCallback != nil { + variables = options.GetVariablesCallback(w, r) + } complexAPI := false // Get user context from security package @@ -93,9 +103,9 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil MetaInfo: metainfo, PropQry: propQry, UserContext: userCtx, - NoCount: pNoCount, - BlankParams: pBlankparms, - AllowFilter: pAllowFilter, + NoCount: options.NoCount, + BlankParams: options.BlankParams, + AllowFilter: options.AllowFilter, ComplexAPI: complexAPI, } @@ -131,13 +141,13 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil complexAPI = reqParams.ComplexAPI // Merge query string parameters - sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry) + sqlquery = h.mergeQueryParams(r, sqlquery, variables, options.AllowFilter, propQry) // Merge header parameters sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI) // Apply filters from parsed parameters (if not already applied by pAllowFilter) - if !pAllowFilter { + if !options.AllowFilter { sqlquery = h.ApplyFilters(sqlquery, reqParams) } @@ -149,7 +159,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil // Override pNoCount if skipcount is specified if reqParams.SkipCount { - pNoCount = true + options.NoCount = true } // Build metainfo @@ -164,7 +174,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables) // Remove unused input variables - if pBlankparms { + if options.BlankParams { for _, kw := range inputvars { replacement := getReplacementForBlankParam(sqlquery, kw) sqlquery = strings.ReplaceAll(sqlquery, kw, replacement) @@ -205,7 +215,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select")) } - if !pNoCount { + if !options.NoCount { if limit > 0 && offset > 0 { sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset) } else if limit > 0 { @@ -244,7 +254,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil // Normalize PostgreSQL types for proper JSON marshaling dbobjlist = normalizePostgresTypesList(rows) - if pNoCount { + if options.NoCount { total = int64(len(dbobjlist)) } @@ -386,7 +396,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil } // SqlQuery creates an HTTP handler that executes a SQL query and returns a single record -func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { +func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncType { return func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { @@ -406,6 +416,9 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { inputvars := make([]string, 0) metainfo := make(map[string]interface{}) variables := make(map[string]interface{}) + if options.GetVariablesCallback != nil { + variables = options.GetVariablesCallback(w, r) + } dbobj := make(map[string]interface{}) complexAPI := false @@ -430,7 +443,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { MetaInfo: metainfo, PropQry: propQry, UserContext: userCtx, - BlankParams: pBlankparms, + BlankParams: options.BlankParams, ComplexAPI: complexAPI, } @@ -507,7 +520,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType { } // Remove unused input variables - if pBlankparms { + if options.BlankParams { for _, kw := range inputvars { replacement := getReplacementForBlankParam(sqlquery, kw) sqlquery = strings.ReplaceAll(sqlquery, kw, replacement) diff --git a/pkg/funcspec/function_api_test.go b/pkg/funcspec/function_api_test.go index d1849b9..c653a43 100644 --- a/pkg/funcspec/function_api_test.go +++ b/pkg/funcspec/function_api_test.go @@ -532,7 +532,7 @@ func TestSqlQuery(t *testing.T) { req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil) w := httptest.NewRecorder() - handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams) + handlerFunc := handler.SqlQuery(tt.sqlQuery, SqlQueryOptions{BlankParams: tt.blankParams}) handlerFunc(w, req) if w.Code != tt.expectedStatus { @@ -655,7 +655,7 @@ func TestSqlQueryList(t *testing.T) { req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil) w := httptest.NewRecorder() - handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter) + handlerFunc := handler.SqlQueryList(tt.sqlQuery, SqlQueryOptions{NoCount: tt.noCount, BlankParams: tt.blankParams, AllowFilter: tt.allowFilter}) handlerFunc(w, req) if w.Code != tt.expectedStatus { diff --git a/pkg/funcspec/hooks_test.go b/pkg/funcspec/hooks_test.go index fb9055a..0013d65 100644 --- a/pkg/funcspec/hooks_test.go +++ b/pkg/funcspec/hooks_test.go @@ -576,7 +576,7 @@ func TestHookIntegrationWithHandler(t *testing.T) { req := createTestRequest("GET", "/test", nil, nil, nil) w := httptest.NewRecorder() - handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false) + handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", SqlQueryOptions{}) handlerFunc(w, req) if !hookCalled {