From c8704c07dd07043a4717c08d04b1f48e6abf4621 Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 10 Nov 2025 10:22:55 +0200 Subject: [PATCH] Added cursor filters and hooks --- pkg/common/adapters/database/utils.go | 161 +++++++++- pkg/common/adapters/database/utils_test.go | 233 ++++++++++++++ pkg/common/interfaces.go | 5 + pkg/restheadspec/cursor.go | 223 +++++++++++++ pkg/restheadspec/handler.go | 186 ++++++++++- pkg/restheadspec/hooks.go | 140 +++++++++ pkg/restheadspec/hooks_example.go | 197 ++++++++++++ pkg/restheadspec/hooks_test.go | 347 +++++++++++++++++++++ 8 files changed, 1487 insertions(+), 5 deletions(-) create mode 100644 pkg/common/adapters/database/utils_test.go create mode 100644 pkg/restheadspec/cursor.go create mode 100644 pkg/restheadspec/hooks.go create mode 100644 pkg/restheadspec/hooks_example.go create mode 100644 pkg/restheadspec/hooks_test.go diff --git a/pkg/common/adapters/database/utils.go b/pkg/common/adapters/database/utils.go index e1e710d..7360fc7 100644 --- a/pkg/common/adapters/database/utils.go +++ b/pkg/common/adapters/database/utils.go @@ -1,6 +1,11 @@ package database -import "strings" +import ( + "reflect" + "strings" + + "github.com/Warky-Devs/ResolveSpec/pkg/common" +) // parseTableName splits a table name that may contain schema into separate schema and table // For example: "public.users" -> ("public", "users") @@ -11,3 +16,157 @@ func parseTableName(fullTableName string) (schema, table string) { } return "", fullTableName } + +// GetPrimaryKeyName extracts the primary key column name from a model +// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method) +// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag +func GetPrimaryKeyName(model any) string { + // Check if model implements PrimaryKeyNameProvider + if provider, ok := model.(common.PrimaryKeyNameProvider); ok { + return provider.GetIDName() + } + + // Try Bun tag first + if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" { + return pkName + } + + // Fall back to GORM tag + return getPrimaryKeyFromReflection(model, "gorm") +} + +// GetModelColumns extracts all column names from a model using reflection +// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names +func GetModelColumns(model any) []string { + var columns []string + + modelType := reflect.TypeOf(model) + + // Unwrap pointers, slices, and arrays to get to the base struct type + for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + // Validate that we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + return columns + } + + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + // Get column name using the same logic as primary key extraction + columnName := getColumnNameFromField(field) + + if columnName != "" { + columns = append(columns, columnName) + } + } + + return columns +} + +// getColumnNameFromField extracts the column name from a struct field +// Priority: bun tag -> gorm tag -> json tag -> lowercase field name +func getColumnNameFromField(field reflect.StructField) string { + // Try bun tag first + bunTag := field.Tag.Get("bun") + if bunTag != "" && bunTag != "-" { + if colName := extractColumnFromBunTag(bunTag); colName != "" { + return colName + } + } + + // Try gorm tag + gormTag := field.Tag.Get("gorm") + if gormTag != "" && gormTag != "-" { + if colName := extractColumnFromGormTag(gormTag); colName != "" { + return colName + } + } + + // Fall back to json tag + jsonTag := field.Tag.Get("json") + if jsonTag != "" && jsonTag != "-" { + // Extract just the field name before any options + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + } + + // Last resort: use field name in lowercase + return strings.ToLower(field.Name) +} + +// getPrimaryKeyFromReflection uses reflection to find the primary key field +func getPrimaryKeyFromReflection(model any, ormType string) string { + val := reflect.ValueOf(model) + if val.Kind() == reflect.Pointer { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return "" + } + + typ := val.Type() + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + switch ormType { + case "gorm": + // Check for gorm tag with primaryKey + gormTag := field.Tag.Get("gorm") + if strings.Contains(gormTag, "primaryKey") { + // Try to extract column name from gorm tag + if colName := extractColumnFromGormTag(gormTag); colName != "" { + return colName + } + // Fall back to json tag + if jsonTag := field.Tag.Get("json"); jsonTag != "" { + return strings.Split(jsonTag, ",")[0] + } + } + case "bun": + // Check for bun tag with pk flag + bunTag := field.Tag.Get("bun") + if strings.Contains(bunTag, "pk") { + // Extract column name from bun tag + if colName := extractColumnFromBunTag(bunTag); colName != "" { + return colName + } + // Fall back to json tag + if jsonTag := field.Tag.Get("json"); jsonTag != "" { + return strings.Split(jsonTag, ",")[0] + } + } + } + } + + return "" +} + +// extractColumnFromGormTag extracts the column name from a gorm tag +// Example: "column:id;primaryKey" -> "id" +func extractColumnFromGormTag(tag string) string { + parts := strings.Split(tag, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if colName, found := strings.CutPrefix(part, "column:"); found { + return colName + } + } + return "" +} + +// extractColumnFromBunTag extracts the column name from a bun tag +// Example: "id,pk" -> "id" +// Example: ",pk" -> "" (will fall back to json tag) +func extractColumnFromBunTag(tag string) string { + parts := strings.Split(tag, ",") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + return "" +} diff --git a/pkg/common/adapters/database/utils_test.go b/pkg/common/adapters/database/utils_test.go new file mode 100644 index 0000000..0be46bf --- /dev/null +++ b/pkg/common/adapters/database/utils_test.go @@ -0,0 +1,233 @@ +package database + +import ( + "testing" +) + +// Test models for GORM +type GormModelWithGetIDName struct { + ID int `gorm:"column:rid_test;primaryKey" json:"id"` + Name string `json:"name"` +} + +func (m GormModelWithGetIDName) GetIDName() string { + return "rid_test" +} + +type GormModelWithColumnTag struct { + ID int `gorm:"column:custom_id;primaryKey" json:"id"` + Name string `json:"name"` +} + +type GormModelWithJSONFallback struct { + ID int `gorm:"primaryKey" json:"user_id"` + Name string `json:"name"` +} + +// Test models for Bun +type BunModelWithGetIDName struct { + ID int `bun:"rid_test,pk" json:"id"` + Name string `json:"name"` +} + +func (m BunModelWithGetIDName) GetIDName() string { + return "rid_test" +} + +type BunModelWithColumnTag struct { + ID int `bun:"custom_id,pk" json:"id"` + Name string `json:"name"` +} + +type BunModelWithJSONFallback struct { + ID int `bun:",pk" json:"user_id"` + Name string `json:"name"` +} + +func TestGetPrimaryKeyName(t *testing.T) { + tests := []struct { + name string + model any + expected string + }{ + { + name: "GORM model with GetIDName method", + model: GormModelWithGetIDName{}, + expected: "rid_test", + }, + { + name: "GORM model with column tag", + model: GormModelWithColumnTag{}, + expected: "custom_id", + }, + { + name: "GORM model with JSON fallback", + model: GormModelWithJSONFallback{}, + expected: "user_id", + }, + { + name: "GORM model pointer with GetIDName", + model: &GormModelWithGetIDName{}, + expected: "rid_test", + }, + { + name: "GORM model pointer with column tag", + model: &GormModelWithColumnTag{}, + expected: "custom_id", + }, + { + name: "Bun model with GetIDName method", + model: BunModelWithGetIDName{}, + expected: "rid_test", + }, + { + name: "Bun model with column tag", + model: BunModelWithColumnTag{}, + expected: "custom_id", + }, + { + name: "Bun model with JSON fallback", + model: BunModelWithJSONFallback{}, + expected: "user_id", + }, + { + name: "Bun model pointer with GetIDName", + model: &BunModelWithGetIDName{}, + expected: "rid_test", + }, + { + name: "Bun model pointer with column tag", + model: &BunModelWithColumnTag{}, + expected: "custom_id", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPrimaryKeyName(tt.model) + if result != tt.expected { + t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestExtractColumnFromGormTag(t *testing.T) { + tests := []struct { + name string + tag string + expected string + }{ + { + name: "column tag with primaryKey", + tag: "column:rid_test;primaryKey", + expected: "rid_test", + }, + { + name: "column tag with spaces", + tag: "column:user_id ; primaryKey ; autoIncrement", + expected: "user_id", + }, + { + name: "no column tag", + tag: "primaryKey;autoIncrement", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractColumnFromGormTag(tt.tag) + if result != tt.expected { + t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestExtractColumnFromBunTag(t *testing.T) { + tests := []struct { + name string + tag string + expected string + }{ + { + name: "column name with pk flag", + tag: "rid_test,pk", + expected: "rid_test", + }, + { + name: "only pk flag", + tag: ",pk", + expected: "", + }, + { + name: "column with multiple flags", + tag: "user_id,pk,autoincrement", + expected: "user_id", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractColumnFromBunTag(tt.tag) + if result != tt.expected { + t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestGetModelColumns(t *testing.T) { + tests := []struct { + name string + model any + expected []string + }{ + { + name: "Bun model with multiple columns", + model: BunModelWithColumnTag{}, + expected: []string{"custom_id", "name"}, + }, + { + name: "GORM model with multiple columns", + model: GormModelWithColumnTag{}, + expected: []string{"custom_id", "name"}, + }, + { + name: "Bun model pointer", + model: &BunModelWithColumnTag{}, + expected: []string{"custom_id", "name"}, + }, + { + name: "GORM model pointer", + model: &GormModelWithColumnTag{}, + expected: []string{"custom_id", "name"}, + }, + { + name: "Bun model with JSON fallback", + model: BunModelWithJSONFallback{}, + expected: []string{"user_id", "name"}, + }, + { + name: "GORM model with JSON fallback", + model: GormModelWithJSONFallback{}, + expected: []string{"user_id", "name"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetModelColumns(tt.model) + if len(result) != len(tt.expected) { + t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected)) + return + } + for i, col := range result { + if col != tt.expected[i] { + t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i]) + } + } + }) + } +} diff --git a/pkg/common/interfaces.go b/pkg/common/interfaces.go index bc3bb1d..5b668c4 100644 --- a/pkg/common/interfaces.go +++ b/pkg/common/interfaces.go @@ -131,6 +131,11 @@ type TableNameProvider interface { TableName() string } +// PrimaryKeyNameProvider interface for models that provide primary key column names +type PrimaryKeyNameProvider interface { + GetIDName() string +} + // SchemaProvider interface for models that provide schema names type SchemaProvider interface { SchemaName() string diff --git a/pkg/restheadspec/cursor.go b/pkg/restheadspec/cursor.go new file mode 100644 index 0000000..6bcd18b --- /dev/null +++ b/pkg/restheadspec/cursor.go @@ -0,0 +1,223 @@ +package restheadspec + +import ( + "fmt" + "strings" + + "github.com/Warky-Devs/ResolveSpec/pkg/common" +) + +// CursorDirection defines pagination direction +type CursorDirection int + +const ( + CursorForward CursorDirection = 1 + CursorBackward CursorDirection = -1 +) + +// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination. +// It uses the current request's sort, cursor, joins (via Expand), and CQL (via ComputedQL). +// +// Parameters: +// - tableName: name of the main table (e.g. "post") +// - pkName: primary key column (e.g. "id") +// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip. +// - expandJoins: optional map[alias]string of JOIN clauses (e.g. "user": "LEFT JOIN user ON ...") +// +// Returns SQL snippet to embed in WHERE clause. +func (opts *ExtendedRequestOptions) GetCursorFilter( + tableName string, + pkName string, + modelColumns []string, // optional: for validation + expandJoins map[string]string, // optional: alias → JOIN SQL +) (string, error) { + + // --------------------------------------------------------------------- // + // 1. Determine active cursor + // --------------------------------------------------------------------- // + cursorID, direction := opts.getActiveCursor() + if cursorID == "" { + return "", fmt.Errorf("no cursor provided for table %s", tableName) + } + + // --------------------------------------------------------------------- // + // 2. Extract sort columns + // --------------------------------------------------------------------- // + sortItems := opts.getSortColumns() + if len(sortItems) == 0 { + return "", fmt.Errorf("no sort columns defined") + } + + // --------------------------------------------------------------------- // + // 3. Prepare + // --------------------------------------------------------------------- // + var whereClauses []string + joinSQL := "" + reverse := direction < 0 + + // --------------------------------------------------------------------- // + // 4. Process each sort column + // --------------------------------------------------------------------- // + for _, s := range sortItems { + col := strings.TrimSpace(s.Column) + if col == "" { + continue + } + + // Parse: "user.name desc nulls last" + parts := strings.Split(col, ".") + field := strings.TrimSpace(parts[len(parts)-1]) + prefix := strings.Join(parts[:len(parts)-1], ".") + + // Direction from struct or string + desc := strings.EqualFold(s.Direction, "desc") || + strings.Contains(strings.ToLower(field), "desc") + field = opts.cleanSortField(field) + + if reverse { + desc = !desc + } + + // Resolve column + cursorCol, targetCol, isJoin, err := opts.resolveColumn( + field, prefix, tableName, modelColumns, + ) + if err != nil { + fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err) + continue + } + + // Handle joins + if isJoin && expandJoins != nil { + if joinClause, ok := expandJoins[prefix]; ok { + jSQL, cRef := rewriteJoin(joinClause, tableName, prefix) + joinSQL = jSQL + cursorCol = cRef + "." + field + targetCol = prefix + "." + field + } + } + + // Build inequality + op := "<" + if desc { + op = ">" + } + whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol)) + } + + if len(whereClauses) == 0 { + return "", fmt.Errorf("no valid sort columns after filtering") + } + + // --------------------------------------------------------------------- // + // 5. Build priority OR-AND chain + // --------------------------------------------------------------------- // + orSQL := buildPriorityChain(whereClauses) + + // --------------------------------------------------------------------- // + // 6. Final EXISTS subquery + // --------------------------------------------------------------------- // + query := fmt.Sprintf(`EXISTS ( + SELECT 1 + FROM %s cursor_select + %s + WHERE cursor_select.%s = %s + AND (%s) +)`, + tableName, + joinSQL, + pkName, + cursorID, + orSQL, + ) + + return query, nil +} + +// ------------------------------------------------------------------------- // +// Helper: get active cursor (forward or backward) +func (opts *ExtendedRequestOptions) getActiveCursor() (id string, direction CursorDirection) { + if opts.CursorForward != "" { + return opts.CursorForward, CursorForward + } + if opts.CursorBackward != "" { + return opts.CursorBackward, CursorBackward + } + return "", 0 +} + +// Helper: extract sort columns +func (opts *ExtendedRequestOptions) getSortColumns() []common.SortOption { + if opts.RequestOptions.Sort != nil { + return opts.RequestOptions.Sort + } + return nil +} + +// Helper: clean sort field (remove desc, asc, nulls) +func (opts *ExtendedRequestOptions) cleanSortField(field string) string { + f := strings.ToLower(field) + for _, token := range []string{"desc", "asc", "nulls last", "nulls first"} { + f = strings.ReplaceAll(f, token, "") + } + return strings.TrimSpace(f) +} + +// Helper: resolve column (main, JSON, CQL, join) +func (opts *ExtendedRequestOptions) resolveColumn( + field, prefix, tableName string, + modelColumns []string, +) (cursorCol, targetCol string, isJoin bool, err error) { + + // JSON field + if strings.Contains(field, "->") { + return "cursor_select." + field, tableName + "." + field, false, nil + } + + // CQL via ComputedQL + if strings.Contains(strings.ToLower(field), "cql") && opts.ComputedQL != nil { + if expr, ok := opts.ComputedQL[field]; ok { + return "cursor_select." + expr, expr, false, nil + } + } + + // Main table column + if modelColumns != nil { + for _, col := range modelColumns { + if strings.EqualFold(col, field) { + return "cursor_select." + field, tableName + "." + field, false, nil + } + } + } else { + // No validation → allow all main-table fields + return "cursor_select." + field, tableName + "." + field, false, nil + } + + // Joined column + if prefix != "" && prefix != tableName { + return "", "", true, nil + } + + return "", "", false, fmt.Errorf("invalid column: %s", field) +} + +// ------------------------------------------------------------------------- // +// Helper: rewrite JOIN clause for cursor subquery +func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) { + joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.") + cursorAlias = "cursor_select_" + alias + joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ") + joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".") + return joinSQL, cursorAlias +} + +// ------------------------------------------------------------------------- // +// Helper: build OR-AND priority chain +func buildPriorityChain(clauses []string) string { + var or []string + for i := 0; i < len(clauses); i++ { + and := strings.Join(clauses[:i+1], "\n AND ") + or = append(or, "("+and+")") + } + return strings.Join(or, "\n OR ") +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 5b39d2a..bc78b2e 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/Warky-Devs/ResolveSpec/pkg/common" + "github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database" "github.com/Warky-Devs/ResolveSpec/pkg/logger" ) @@ -18,6 +19,7 @@ import ( type Handler struct { db common.Database registry common.ModelRegistry + hooks *HookRegistry } // NewHandler creates a new API handler with database and registry abstractions @@ -25,9 +27,16 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler { return &Handler{ db: db, registry: registry, + hooks: NewHookRegistry(), } } +// Hooks returns the hook registry for this handler +// Use this to register custom hooks for operations +func (h *Handler) Hooks() *HookRegistry { + return h.hooks +} + // handlePanic is a helper function to handle panics with stack traces func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) { stack := debug.Stack() @@ -184,6 +193,25 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st tableName := GetTableName(ctx) model := GetModel(ctx) + // Execute BeforeRead hooks + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Model: model, + Options: options, + ID: id, + Writer: w, + } + + if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil { + logger.Error("BeforeRead hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + // Validate and unwrap model type to get base struct modelType := reflect.TypeOf(model) for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { @@ -310,6 +338,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st query = query.Offset(*options.Offset) } + // Apply cursor-based pagination + if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 { + logger.Debug("Applying cursor pagination") + + // Get primary key name + pkName := database.GetPrimaryKeyName(model) + + // Extract model columns for validation using the generic database function + modelColumns := database.GetModelColumns(model) + + // Build expand joins map (if needed in future) + var expandJoins map[string]string + if len(options.Expand) > 0 { + expandJoins = make(map[string]string) + // TODO: Build actual JOIN SQL for each expand relation + // For now, pass empty map as joins are handled via Preload + } + + // Get cursor filter SQL + cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins) + if err != nil { + logger.Error("Error building cursor filter: %v", err) + h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err) + return + } + + // Apply cursor filter to query + if cursorFilter != "" { + logger.Debug("Applying cursor filter: %s", cursorFilter) + query = query.Where(cursorFilter) + } + } + // Execute query - modelPtr was already created earlier if err := query.Scan(ctx, modelPtr); err != nil { logger.Error("Error executing query: %v", err) @@ -333,6 +394,16 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st Offset: offset, } + // Execute AfterRead hooks + hookCtx.Result = modelPtr + hookCtx.Error = nil + + if err := h.hooks.Execute(AfterRead, hookCtx); err != nil { + logger.Error("AfterRead hook failed: %v", err) + h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + h.sendFormattedResponse(w, modelPtr, metadata, options) } @@ -351,6 +422,28 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat logger.Info("Creating record in %s.%s", schema, entity) + // Execute BeforeCreate hooks + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Model: model, + Options: options, + Data: data, + Writer: w, + } + + if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil { + logger.Error("BeforeCreate hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Use potentially modified data from hook context + data = hookCtx.Data + // Handle batch creation dataValue := reflect.ValueOf(data) if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array { @@ -385,6 +478,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat return } + // Execute AfterCreate hooks for batch creation + hookCtx.Result = map[string]interface{}{"created": dataValue.Len()} + hookCtx.Error = nil + + if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { + logger.Error("AfterCreate hook failed: %v", err) + h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil) return } @@ -410,6 +513,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat return } + // Execute AfterCreate hooks for single record creation + hookCtx.Result = modelValue + hookCtx.Error = nil + + if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil { + logger.Error("AfterCreate hook failed: %v", err) + h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + h.sendResponse(w, modelValue, nil) } @@ -424,9 +537,33 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id schema := GetSchema(ctx) entity := GetEntity(ctx) tableName := GetTableName(ctx) + model := GetModel(ctx) logger.Info("Updating record in %s.%s", schema, entity) + // Execute BeforeUpdate hooks + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Model: model, + Options: options, + ID: id, + Data: data, + Writer: w, + } + + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + logger.Error("BeforeUpdate hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Use potentially modified data from hook context + data = hookCtx.Data + // Convert data to map dataMap, ok := data.(map[string]interface{}) if !ok { @@ -462,9 +599,20 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id return } - h.sendResponse(w, map[string]interface{}{ + // Execute AfterUpdate hooks + responseData := map[string]interface{}{ "updated": result.RowsAffected(), - }, nil) + } + hookCtx.Result = responseData + hookCtx.Error = nil + + if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { + logger.Error("AfterUpdate hook failed: %v", err) + h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + + h.sendResponse(w, responseData, nil) } func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) { @@ -478,9 +626,28 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id schema := GetSchema(ctx) entity := GetEntity(ctx) tableName := GetTableName(ctx) + model := GetModel(ctx) logger.Info("Deleting record from %s.%s", schema, entity) + // Execute BeforeDelete hooks + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Model: model, + ID: id, + Writer: w, + } + + if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil { + logger.Error("BeforeDelete hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + query := h.db.NewDelete().Table(tableName) if id == "" { @@ -497,9 +664,20 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id return } - h.sendResponse(w, map[string]interface{}{ + // Execute AfterDelete hooks + responseData := map[string]interface{}{ "deleted": result.RowsAffected(), - }, nil) + } + hookCtx.Result = responseData + hookCtx.Error = nil + + if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil { + logger.Error("AfterDelete hook failed: %v", err) + h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return + } + + h.sendResponse(w, responseData, nil) } // qualifyColumnName ensures column name is fully qualified with table name if not already diff --git a/pkg/restheadspec/hooks.go b/pkg/restheadspec/hooks.go new file mode 100644 index 0000000..0d72c48 --- /dev/null +++ b/pkg/restheadspec/hooks.go @@ -0,0 +1,140 @@ +package restheadspec + +import ( + "context" + "fmt" + + "github.com/Warky-Devs/ResolveSpec/pkg/common" + "github.com/Warky-Devs/ResolveSpec/pkg/logger" +) + +// HookType defines the type of hook to execute +type HookType string + +const ( + // Read operation hooks + BeforeRead HookType = "before_read" + AfterRead HookType = "after_read" + + // Create operation hooks + BeforeCreate HookType = "before_create" + AfterCreate HookType = "after_create" + + // Update operation hooks + BeforeUpdate HookType = "before_update" + AfterUpdate HookType = "after_update" + + // Delete operation hooks + BeforeDelete HookType = "before_delete" + AfterDelete HookType = "after_delete" +) + +// HookContext contains all the data available to a hook +type HookContext struct { + Context context.Context + Handler *Handler // Reference to the handler for accessing database, registry, etc. + Schema string + Entity string + TableName string + Model interface{} + Options ExtendedRequestOptions + + // Operation-specific fields + ID string + Data interface{} // For create/update operations + Result interface{} // For after hooks + Error error // For after hooks + QueryFilter string // For read operations + + // Response writer - allows hooks to modify response + Writer common.ResponseWriter +} + +// HookFunc is the signature for hook functions +// It receives a HookContext and can modify it or return an error +// If an error is returned, the operation will be aborted +type HookFunc func(*HookContext) error + +// HookRegistry manages all registered hooks +type HookRegistry struct { + hooks map[HookType][]HookFunc +} + +// NewHookRegistry creates a new hook registry +func NewHookRegistry() *HookRegistry { + return &HookRegistry{ + hooks: make(map[HookType][]HookFunc), + } +} + +// Register adds a new hook for the specified hook type +func (r *HookRegistry) Register(hookType HookType, hook HookFunc) { + if r.hooks == nil { + r.hooks = make(map[HookType][]HookFunc) + } + r.hooks[hookType] = append(r.hooks[hookType], hook) + logger.Info("Registered hook for %s (total: %d)", hookType, len(r.hooks[hookType])) +} + +// RegisterMultiple registers a hook for multiple hook types +func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) { + for _, hookType := range hookTypes { + r.Register(hookType, hook) + } +} + +// Execute runs all hooks for the specified type in order +// If any hook returns an error, execution stops and the error is returned +func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error { + hooks, exists := r.hooks[hookType] + if !exists || len(hooks) == 0 { + logger.Debug("No hooks registered for %s", hookType) + return nil + } + + logger.Debug("Executing %d hook(s) for %s", len(hooks), hookType) + + for i, hook := range hooks { + if err := hook(ctx); err != nil { + logger.Error("Hook %d for %s failed: %v", i+1, hookType, err) + return fmt.Errorf("hook execution failed: %w", err) + } + } + + logger.Debug("All hooks for %s executed successfully", hookType) + return nil +} + +// Clear removes all hooks for the specified type +func (r *HookRegistry) Clear(hookType HookType) { + delete(r.hooks, hookType) + logger.Info("Cleared all hooks for %s", hookType) +} + +// ClearAll removes all registered hooks +func (r *HookRegistry) ClearAll() { + r.hooks = make(map[HookType][]HookFunc) + logger.Info("Cleared all hooks") +} + +// Count returns the number of hooks registered for a specific type +func (r *HookRegistry) Count(hookType HookType) int { + if hooks, exists := r.hooks[hookType]; exists { + return len(hooks) + } + return 0 +} + +// HasHooks returns true if there are any hooks registered for the specified type +func (r *HookRegistry) HasHooks(hookType HookType) bool { + return r.Count(hookType) > 0 +} + +// GetAllHookTypes returns all hook types that have registered hooks +func (r *HookRegistry) GetAllHookTypes() []HookType { + types := make([]HookType, 0, len(r.hooks)) + for hookType := range r.hooks { + types = append(types, hookType) + } + return types +} diff --git a/pkg/restheadspec/hooks_example.go b/pkg/restheadspec/hooks_example.go new file mode 100644 index 0000000..6f3f17d --- /dev/null +++ b/pkg/restheadspec/hooks_example.go @@ -0,0 +1,197 @@ +package restheadspec + +import ( + "fmt" + + "github.com/Warky-Devs/ResolveSpec/pkg/logger" +) + +// This file contains example implementations showing how to use hooks +// These are just examples - you can implement hooks as needed for your application + +// ExampleLoggingHook logs before and after operations +func ExampleLoggingHook(hookType HookType) HookFunc { + return func(ctx *HookContext) error { + logger.Info("[%s] Operation: %s.%s, ID: %s", hookType, ctx.Schema, ctx.Entity, ctx.ID) + if ctx.Data != nil { + logger.Debug("[%s] Data: %+v", hookType, ctx.Data) + } + if ctx.Result != nil { + logger.Debug("[%s] Result: %+v", hookType, ctx.Result) + } + return nil + } +} + +// ExampleValidationHook validates data before create/update operations +func ExampleValidationHook(ctx *HookContext) error { + // Example: Ensure certain fields are present + if dataMap, ok := ctx.Data.(map[string]interface{}); ok { + // Check for required fields + requiredFields := []string{"name"} // Add your required fields here + for _, field := range requiredFields { + if _, exists := dataMap[field]; !exists { + return fmt.Errorf("required field missing: %s", field) + } + } + } + return nil +} + +// ExampleAuthorizationHook checks if the user has permission to perform the operation +func ExampleAuthorizationHook(ctx *HookContext) error { + // Example: Check user permissions from context + // userID, ok := ctx.Context.Value("user_id").(string) + // if !ok { + // return fmt.Errorf("unauthorized: no user in context") + // } + + // You can access the handler's database or registry if needed + // For example, to check permissions in the database: + // query := ctx.Handler.db.NewSelect().Table("permissions")... + + // Add your authorization logic here + logger.Debug("Authorization check for %s.%s", ctx.Schema, ctx.Entity) + return nil +} + +// ExampleDataTransformHook modifies data before create/update +func ExampleDataTransformHook(ctx *HookContext) error { + if dataMap, ok := ctx.Data.(map[string]interface{}); ok { + // Example: Add a timestamp or user ID + // dataMap["updated_at"] = time.Now() + // dataMap["updated_by"] = ctx.Context.Value("user_id") + + // Update the context with modified data + ctx.Data = dataMap + logger.Debug("Data transformed for %s.%s", ctx.Schema, ctx.Entity) + } + return nil +} + +// ExampleAuditLogHook creates audit log entries for operations +func ExampleAuditLogHook(hookType HookType) HookFunc { + return func(ctx *HookContext) error { + // Example: Log to audit system + auditEntry := map[string]interface{}{ + "operation": hookType, + "schema": ctx.Schema, + "entity": ctx.Entity, + "table_name": ctx.TableName, + "id": ctx.ID, + } + + if ctx.Error != nil { + auditEntry["error"] = ctx.Error.Error() + } + + logger.Info("Audit log: %+v", auditEntry) + + // In a real application, you would save this to a database using the handler + // Example: + // query := ctx.Handler.db.NewInsert().Table("audit_logs").Model(&auditEntry) + // if _, err := query.Exec(ctx.Context); err != nil { + // logger.Error("Failed to save audit log: %v", err) + // } + + return nil + } +} + +// ExampleCacheInvalidationHook invalidates cache after create/update/delete +func ExampleCacheInvalidationHook(ctx *HookContext) error { + // Example: Invalidate cache for the entity + cacheKey := fmt.Sprintf("%s.%s", ctx.Schema, ctx.Entity) + logger.Info("Invalidating cache for: %s", cacheKey) + + // Add your cache invalidation logic here + // cache.Delete(cacheKey) + + return nil +} + +// ExampleFilterSensitiveDataHook removes sensitive data from responses +func ExampleFilterSensitiveDataHook(ctx *HookContext) error { + // Example: Remove password fields from results + // This would be called in AfterRead hooks + logger.Debug("Filtering sensitive data for %s.%s", ctx.Schema, ctx.Entity) + + // Add your data filtering logic here + // You would iterate through ctx.Result and remove sensitive fields + + return nil +} + +// ExampleRelatedDataHook fetches related data using the handler's database +func ExampleRelatedDataHook(ctx *HookContext) error { + // Example: Fetch related data after reading the main entity + // This hook demonstrates using ctx.Handler to access the database + + if ctx.Entity == "users" && ctx.Result != nil { + // Example: Fetch user's recent activity + // userID := ... extract from ctx.Result + + // Use the handler's database to query related data + // query := ctx.Handler.db.NewSelect().Table("user_activity").Where("user_id = ?", userID) + // var activities []Activity + // if err := query.Scan(ctx.Context, &activities); err != nil { + // logger.Error("Failed to fetch user activities: %v", err) + // return err + // } + + // Optionally modify the result to include the related data + // if resultMap, ok := ctx.Result.(map[string]interface{}); ok { + // resultMap["recent_activities"] = activities + // } + + logger.Debug("Fetched related data for user entity") + } + + return nil +} + +// SetupExampleHooks demonstrates how to register hooks on a handler +func SetupExampleHooks(handler *Handler) { + hooks := handler.Hooks() + + // Register logging hooks for all operations + hooks.Register(BeforeRead, ExampleLoggingHook(BeforeRead)) + hooks.Register(AfterRead, ExampleLoggingHook(AfterRead)) + hooks.Register(BeforeCreate, ExampleLoggingHook(BeforeCreate)) + hooks.Register(AfterCreate, ExampleLoggingHook(AfterCreate)) + hooks.Register(BeforeUpdate, ExampleLoggingHook(BeforeUpdate)) + hooks.Register(AfterUpdate, ExampleLoggingHook(AfterUpdate)) + hooks.Register(BeforeDelete, ExampleLoggingHook(BeforeDelete)) + hooks.Register(AfterDelete, ExampleLoggingHook(AfterDelete)) + + // Register validation hooks for create/update + hooks.Register(BeforeCreate, ExampleValidationHook) + hooks.Register(BeforeUpdate, ExampleValidationHook) + + // Register authorization hooks for all operations + hooks.RegisterMultiple([]HookType{ + BeforeRead, BeforeCreate, BeforeUpdate, BeforeDelete, + }, ExampleAuthorizationHook) + + // Register data transform hook for create/update + hooks.Register(BeforeCreate, ExampleDataTransformHook) + hooks.Register(BeforeUpdate, ExampleDataTransformHook) + + // Register audit log hooks for after operations + hooks.Register(AfterCreate, ExampleAuditLogHook(AfterCreate)) + hooks.Register(AfterUpdate, ExampleAuditLogHook(AfterUpdate)) + hooks.Register(AfterDelete, ExampleAuditLogHook(AfterDelete)) + + // Register cache invalidation for after operations + hooks.Register(AfterCreate, ExampleCacheInvalidationHook) + hooks.Register(AfterUpdate, ExampleCacheInvalidationHook) + hooks.Register(AfterDelete, ExampleCacheInvalidationHook) + + // Register sensitive data filtering for read operations + hooks.Register(AfterRead, ExampleFilterSensitiveDataHook) + + // Register related data fetching for read operations + hooks.Register(AfterRead, ExampleRelatedDataHook) + + logger.Info("Example hooks registered successfully") +} diff --git a/pkg/restheadspec/hooks_test.go b/pkg/restheadspec/hooks_test.go new file mode 100644 index 0000000..d688a67 --- /dev/null +++ b/pkg/restheadspec/hooks_test.go @@ -0,0 +1,347 @@ +package restheadspec + +import ( + "context" + "fmt" + "testing" +) + +// TestHookRegistry tests the hook registry functionality +func TestHookRegistry(t *testing.T) { + registry := NewHookRegistry() + + // Test registering a hook + called := false + hook := func(ctx *HookContext) error { + called = true + return nil + } + + registry.Register(BeforeRead, hook) + + if registry.Count(BeforeRead) != 1 { + t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead)) + } + + // Test executing a hook + ctx := &HookContext{ + Context: context.Background(), + Schema: "test", + Entity: "users", + } + + err := registry.Execute(BeforeRead, ctx) + if err != nil { + t.Errorf("Hook execution failed: %v", err) + } + + if !called { + t.Error("Hook was not called") + } +} + +// TestHookExecution tests hook execution order +func TestHookExecutionOrder(t *testing.T) { + registry := NewHookRegistry() + + order := []int{} + + hook1 := func(ctx *HookContext) error { + order = append(order, 1) + return nil + } + + hook2 := func(ctx *HookContext) error { + order = append(order, 2) + return nil + } + + hook3 := func(ctx *HookContext) error { + order = append(order, 3) + return nil + } + + registry.Register(BeforeCreate, hook1) + registry.Register(BeforeCreate, hook2) + registry.Register(BeforeCreate, hook3) + + ctx := &HookContext{ + Context: context.Background(), + Schema: "test", + Entity: "users", + } + + err := registry.Execute(BeforeCreate, ctx) + if err != nil { + t.Errorf("Hook execution failed: %v", err) + } + + if len(order) != 3 { + t.Errorf("Expected 3 hooks to be called, got %d", len(order)) + } + + if order[0] != 1 || order[1] != 2 || order[2] != 3 { + t.Errorf("Hooks executed in wrong order: %v", order) + } +} + +// TestHookError tests hook error handling +func TestHookError(t *testing.T) { + registry := NewHookRegistry() + + executed := []string{} + + hook1 := func(ctx *HookContext) error { + executed = append(executed, "hook1") + return nil + } + + hook2 := func(ctx *HookContext) error { + executed = append(executed, "hook2") + return fmt.Errorf("hook2 error") + } + + hook3 := func(ctx *HookContext) error { + executed = append(executed, "hook3") + return nil + } + + registry.Register(BeforeUpdate, hook1) + registry.Register(BeforeUpdate, hook2) + registry.Register(BeforeUpdate, hook3) + + ctx := &HookContext{ + Context: context.Background(), + Schema: "test", + Entity: "users", + } + + err := registry.Execute(BeforeUpdate, ctx) + if err == nil { + t.Error("Expected error from hook execution") + } + + if len(executed) != 2 { + t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed)) + } + + if executed[0] != "hook1" || executed[1] != "hook2" { + t.Errorf("Unexpected execution order: %v", executed) + } +} + +// TestHookDataModification tests modifying data in hooks +func TestHookDataModification(t *testing.T) { + registry := NewHookRegistry() + + modifyHook := func(ctx *HookContext) error { + if dataMap, ok := ctx.Data.(map[string]interface{}); ok { + dataMap["modified"] = true + ctx.Data = dataMap + } + return nil + } + + registry.Register(BeforeCreate, modifyHook) + + data := map[string]interface{}{ + "name": "test", + } + + ctx := &HookContext{ + Context: context.Background(), + Schema: "test", + Entity: "users", + Data: data, + } + + err := registry.Execute(BeforeCreate, ctx) + if err != nil { + t.Errorf("Hook execution failed: %v", err) + } + + modifiedData := ctx.Data.(map[string]interface{}) + if !modifiedData["modified"].(bool) { + t.Error("Data was not modified by hook") + } +} + +// TestRegisterMultiple tests registering a hook for multiple types +func TestRegisterMultiple(t *testing.T) { + registry := NewHookRegistry() + + called := 0 + hook := func(ctx *HookContext) error { + called++ + return nil + } + + registry.RegisterMultiple([]HookType{ + BeforeRead, + BeforeCreate, + BeforeUpdate, + }, hook) + + if registry.Count(BeforeRead) != 1 { + t.Error("Hook not registered for BeforeRead") + } + if registry.Count(BeforeCreate) != 1 { + t.Error("Hook not registered for BeforeCreate") + } + if registry.Count(BeforeUpdate) != 1 { + t.Error("Hook not registered for BeforeUpdate") + } + + ctx := &HookContext{ + Context: context.Background(), + Schema: "test", + Entity: "users", + } + + registry.Execute(BeforeRead, ctx) + registry.Execute(BeforeCreate, ctx) + registry.Execute(BeforeUpdate, ctx) + + if called != 3 { + t.Errorf("Expected hook to be called 3 times, got %d", called) + } +} + +// TestClearHooks tests clearing hooks +func TestClearHooks(t *testing.T) { + registry := NewHookRegistry() + + hook := func(ctx *HookContext) error { + return nil + } + + registry.Register(BeforeRead, hook) + registry.Register(BeforeCreate, hook) + + if registry.Count(BeforeRead) != 1 { + t.Error("Hook not registered") + } + + registry.Clear(BeforeRead) + + if registry.Count(BeforeRead) != 0 { + t.Error("Hook not cleared") + } + + if registry.Count(BeforeCreate) != 1 { + t.Error("Wrong hook was cleared") + } +} + +// TestClearAllHooks tests clearing all hooks +func TestClearAllHooks(t *testing.T) { + registry := NewHookRegistry() + + hook := func(ctx *HookContext) error { + return nil + } + + registry.Register(BeforeRead, hook) + registry.Register(BeforeCreate, hook) + registry.Register(BeforeUpdate, hook) + + registry.ClearAll() + + if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 { + t.Error("Not all hooks were cleared") + } +} + +// TestHasHooks tests checking if hooks exist +func TestHasHooks(t *testing.T) { + registry := NewHookRegistry() + + if registry.HasHooks(BeforeRead) { + t.Error("Should not have hooks initially") + } + + hook := func(ctx *HookContext) error { + return nil + } + + registry.Register(BeforeRead, hook) + + if !registry.HasHooks(BeforeRead) { + t.Error("Should have hooks after registration") + } +} + +// TestGetAllHookTypes tests getting all registered hook types +func TestGetAllHookTypes(t *testing.T) { + registry := NewHookRegistry() + + hook := func(ctx *HookContext) error { + return nil + } + + registry.Register(BeforeRead, hook) + registry.Register(BeforeCreate, hook) + registry.Register(AfterUpdate, hook) + + types := registry.GetAllHookTypes() + + if len(types) != 3 { + t.Errorf("Expected 3 hook types, got %d", len(types)) + } + + // Verify all expected types are present + expectedTypes := map[HookType]bool{ + BeforeRead: true, + BeforeCreate: true, + AfterUpdate: true, + } + + for _, hookType := range types { + if !expectedTypes[hookType] { + t.Errorf("Unexpected hook type: %s", hookType) + } + } +} + +// TestHookContextHandler tests that hooks can access the handler +func TestHookContextHandler(t *testing.T) { + registry := NewHookRegistry() + + var capturedHandler *Handler + + hook := func(ctx *HookContext) error { + // Verify that the handler is accessible from the context + if ctx.Handler == nil { + return fmt.Errorf("handler is nil in hook context") + } + capturedHandler = ctx.Handler + return nil + } + + registry.Register(BeforeRead, hook) + + // Create a mock handler + handler := &Handler{ + hooks: registry, + } + + ctx := &HookContext{ + Context: context.Background(), + Handler: handler, + Schema: "test", + Entity: "users", + } + + err := registry.Execute(BeforeRead, ctx) + if err != nil { + t.Errorf("Hook execution failed: %v", err) + } + + if capturedHandler == nil { + t.Error("Handler was not captured from hook context") + } + + if capturedHandler != handler { + t.Error("Captured handler does not match original handler") + } +}