Added cursor filters and hooks

This commit is contained in:
Hein 2025-11-10 10:22:55 +02:00
parent fc82a9bc50
commit c8704c07dd
8 changed files with 1487 additions and 5 deletions

View File

@ -1,6 +1,11 @@
package database
import "strings"
import (
"reflect"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
)
// parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users")
@ -11,3 +16,157 @@ func parseTableName(fullTableName string) (schema, table string) {
}
return "", fullTableName
}
// GetPrimaryKeyName extracts the primary key column name from a model
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
func GetPrimaryKeyName(model any) string {
// Check if model implements PrimaryKeyNameProvider
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
return provider.GetIDName()
}
// Try Bun tag first
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
return pkName
}
// Fall back to GORM tag
return getPrimaryKeyFromReflection(model, "gorm")
}
// GetModelColumns extracts all column names from a model using reflection
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
func GetModelColumns(model any) []string {
var columns []string
modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return columns
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Get column name using the same logic as primary key extraction
columnName := getColumnNameFromField(field)
if columnName != "" {
columns = append(columns, columnName)
}
}
return columns
}
// getColumnNameFromField extracts the column name from a struct field
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
func getColumnNameFromField(field reflect.StructField) string {
// Try bun tag first
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
if colName := extractColumnFromBunTag(bunTag); colName != "" {
return colName
}
}
// Try gorm tag
gormTag := field.Tag.Get("gorm")
if gormTag != "" && gormTag != "-" {
if colName := extractColumnFromGormTag(gormTag); colName != "" {
return colName
}
}
// Fall back to json tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract just the field name before any options
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
}
// Last resort: use field name in lowercase
return strings.ToLower(field.Name)
}
// getPrimaryKeyFromReflection uses reflection to find the primary key field
func getPrimaryKeyFromReflection(model any, ormType string) string {
val := reflect.ValueOf(model)
if val.Kind() == reflect.Pointer {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return ""
}
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
switch ormType {
case "gorm":
// Check for gorm tag with primaryKey
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "primaryKey") {
// Try to extract column name from gorm tag
if colName := extractColumnFromGormTag(gormTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
case "bun":
// Check for bun tag with pk flag
bunTag := field.Tag.Get("bun")
if strings.Contains(bunTag, "pk") {
// Extract column name from bun tag
if colName := extractColumnFromBunTag(bunTag); colName != "" {
return colName
}
// Fall back to json tag
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
return strings.Split(jsonTag, ",")[0]
}
}
}
}
return ""
}
// extractColumnFromGormTag extracts the column name from a gorm tag
// Example: "column:id;primaryKey" -> "id"
func extractColumnFromGormTag(tag string) string {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if colName, found := strings.CutPrefix(part, "column:"); found {
return colName
}
}
return ""
}
// extractColumnFromBunTag extracts the column name from a bun tag
// Example: "id,pk" -> "id"
// Example: ",pk" -> "" (will fall back to json tag)
func extractColumnFromBunTag(tag string) string {
parts := strings.Split(tag, ",")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}

View File

@ -0,0 +1,233 @@
package database
import (
"testing"
)
// Test models for GORM
type GormModelWithGetIDName struct {
ID int `gorm:"column:rid_test;primaryKey" json:"id"`
Name string `json:"name"`
}
func (m GormModelWithGetIDName) GetIDName() string {
return "rid_test"
}
type GormModelWithColumnTag struct {
ID int `gorm:"column:custom_id;primaryKey" json:"id"`
Name string `json:"name"`
}
type GormModelWithJSONFallback struct {
ID int `gorm:"primaryKey" json:"user_id"`
Name string `json:"name"`
}
// Test models for Bun
type BunModelWithGetIDName struct {
ID int `bun:"rid_test,pk" json:"id"`
Name string `json:"name"`
}
func (m BunModelWithGetIDName) GetIDName() string {
return "rid_test"
}
type BunModelWithColumnTag struct {
ID int `bun:"custom_id,pk" json:"id"`
Name string `json:"name"`
}
type BunModelWithJSONFallback struct {
ID int `bun:",pk" json:"user_id"`
Name string `json:"name"`
}
func TestGetPrimaryKeyName(t *testing.T) {
tests := []struct {
name string
model any
expected string
}{
{
name: "GORM model with GetIDName method",
model: GormModelWithGetIDName{},
expected: "rid_test",
},
{
name: "GORM model with column tag",
model: GormModelWithColumnTag{},
expected: "custom_id",
},
{
name: "GORM model with JSON fallback",
model: GormModelWithJSONFallback{},
expected: "user_id",
},
{
name: "GORM model pointer with GetIDName",
model: &GormModelWithGetIDName{},
expected: "rid_test",
},
{
name: "GORM model pointer with column tag",
model: &GormModelWithColumnTag{},
expected: "custom_id",
},
{
name: "Bun model with GetIDName method",
model: BunModelWithGetIDName{},
expected: "rid_test",
},
{
name: "Bun model with column tag",
model: BunModelWithColumnTag{},
expected: "custom_id",
},
{
name: "Bun model with JSON fallback",
model: BunModelWithJSONFallback{},
expected: "user_id",
},
{
name: "Bun model pointer with GetIDName",
model: &BunModelWithGetIDName{},
expected: "rid_test",
},
{
name: "Bun model pointer with column tag",
model: &BunModelWithColumnTag{},
expected: "custom_id",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPrimaryKeyName(tt.model)
if result != tt.expected {
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
}
})
}
}
func TestExtractColumnFromGormTag(t *testing.T) {
tests := []struct {
name string
tag string
expected string
}{
{
name: "column tag with primaryKey",
tag: "column:rid_test;primaryKey",
expected: "rid_test",
},
{
name: "column tag with spaces",
tag: "column:user_id ; primaryKey ; autoIncrement",
expected: "user_id",
},
{
name: "no column tag",
tag: "primaryKey;autoIncrement",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractColumnFromGormTag(tt.tag)
if result != tt.expected {
t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected)
}
})
}
}
func TestExtractColumnFromBunTag(t *testing.T) {
tests := []struct {
name string
tag string
expected string
}{
{
name: "column name with pk flag",
tag: "rid_test,pk",
expected: "rid_test",
},
{
name: "only pk flag",
tag: ",pk",
expected: "",
},
{
name: "column with multiple flags",
tag: "user_id,pk,autoincrement",
expected: "user_id",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractColumnFromBunTag(tt.tag)
if result != tt.expected {
t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetModelColumns(t *testing.T) {
tests := []struct {
name string
model any
expected []string
}{
{
name: "Bun model with multiple columns",
model: BunModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "GORM model with multiple columns",
model: GormModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "Bun model pointer",
model: &BunModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "GORM model pointer",
model: &GormModelWithColumnTag{},
expected: []string{"custom_id", "name"},
},
{
name: "Bun model with JSON fallback",
model: BunModelWithJSONFallback{},
expected: []string{"user_id", "name"},
},
{
name: "GORM model with JSON fallback",
model: GormModelWithJSONFallback{},
expected: []string{"user_id", "name"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetModelColumns(tt.model)
if len(result) != len(tt.expected) {
t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected))
return
}
for i, col := range result {
if col != tt.expected[i] {
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
}
}
})
}
}

View File

@ -131,6 +131,11 @@ type TableNameProvider interface {
TableName() string
}
// PrimaryKeyNameProvider interface for models that provide primary key column names
type PrimaryKeyNameProvider interface {
GetIDName() string
}
// SchemaProvider interface for models that provide schema names
type SchemaProvider interface {
SchemaName() string

223
pkg/restheadspec/cursor.go Normal file
View File

@ -0,0 +1,223 @@
package restheadspec
import (
"fmt"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
)
// CursorDirection defines pagination direction
type CursorDirection int
const (
CursorForward CursorDirection = 1
CursorBackward CursorDirection = -1
)
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
// It uses the current request's sort, cursor, joins (via Expand), and CQL (via ComputedQL).
//
// Parameters:
// - tableName: name of the main table (e.g. "post")
// - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - expandJoins: optional map[alias]string of JOIN clauses (e.g. "user": "LEFT JOIN user ON ...")
//
// Returns SQL snippet to embed in WHERE clause.
func (opts *ExtendedRequestOptions) GetCursorFilter(
tableName string,
pkName string,
modelColumns []string, // optional: for validation
expandJoins map[string]string, // optional: alias → JOIN SQL
) (string, error) {
// --------------------------------------------------------------------- //
// 1. Determine active cursor
// --------------------------------------------------------------------- //
cursorID, direction := opts.getActiveCursor()
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
// --------------------------------------------------------------------- //
// 2. Extract sort columns
// --------------------------------------------------------------------- //
sortItems := opts.getSortColumns()
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
// --------------------------------------------------------------------- //
// 3. Prepare
// --------------------------------------------------------------------- //
var whereClauses []string
joinSQL := ""
reverse := direction < 0
// --------------------------------------------------------------------- //
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
if col == "" {
continue
}
// Parse: "user.name desc nulls last"
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
// Direction from struct or string
desc := strings.EqualFold(s.Direction, "desc") ||
strings.Contains(strings.ToLower(field), "desc")
field = opts.cleanSortField(field)
if reverse {
desc = !desc
}
// Resolve column
cursorCol, targetCol, isJoin, err := opts.resolveColumn(
field, prefix, tableName, modelColumns,
)
if err != nil {
fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err)
continue
}
// Handle joins
if isJoin && expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
// Build inequality
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
// --------------------------------------------------------------------- //
// 5. Build priority OR-AND chain
// --------------------------------------------------------------------- //
orSQL := buildPriorityChain(whereClauses)
// --------------------------------------------------------------------- //
// 6. Final EXISTS subquery
// --------------------------------------------------------------------- //
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
joinSQL,
pkName,
cursorID,
orSQL,
)
return query, nil
}
// ------------------------------------------------------------------------- //
// Helper: get active cursor (forward or backward)
func (opts *ExtendedRequestOptions) getActiveCursor() (id string, direction CursorDirection) {
if opts.CursorForward != "" {
return opts.CursorForward, CursorForward
}
if opts.CursorBackward != "" {
return opts.CursorBackward, CursorBackward
}
return "", 0
}
// Helper: extract sort columns
func (opts *ExtendedRequestOptions) getSortColumns() []common.SortOption {
if opts.RequestOptions.Sort != nil {
return opts.RequestOptions.Sort
}
return nil
}
// Helper: clean sort field (remove desc, asc, nulls)
func (opts *ExtendedRequestOptions) cleanSortField(field string) string {
f := strings.ToLower(field)
for _, token := range []string{"desc", "asc", "nulls last", "nulls first"} {
f = strings.ReplaceAll(f, token, "")
}
return strings.TrimSpace(f)
}
// Helper: resolve column (main, JSON, CQL, join)
func (opts *ExtendedRequestOptions) resolveColumn(
field, prefix, tableName string,
modelColumns []string,
) (cursorCol, targetCol string, isJoin bool, err error) {
// JSON field
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, false, nil
}
// CQL via ComputedQL
if strings.Contains(strings.ToLower(field), "cql") && opts.ComputedQL != nil {
if expr, ok := opts.ComputedQL[field]; ok {
return "cursor_select." + expr, expr, false, nil
}
}
// Main table column
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
// No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Joined column
if prefix != "" && prefix != tableName {
return "", "", true, nil
}
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
// ------------------------------------------------------------------------- //
// Helper: rewrite JOIN clause for cursor subquery
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
// ------------------------------------------------------------------------- //
// Helper: build OR-AND priority chain
func buildPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}

View File

@ -10,6 +10,7 @@ import (
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
)
@ -18,6 +19,7 @@ import (
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
}
// NewHandler creates a new API handler with database and registry abstractions
@ -25,9 +27,16 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
return &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
}
}
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
@ -184,6 +193,25 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
tableName := GetTableName(ctx)
model := GetModel(ctx)
// Execute BeforeRead hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
ID: id,
Writer: w,
}
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
logger.Error("BeforeRead hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
@ -310,6 +338,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Offset(*options.Offset)
}
// Apply cursor-based pagination
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
logger.Debug("Applying cursor pagination")
// Get primary key name
pkName := database.GetPrimaryKeyName(model)
// Extract model columns for validation using the generic database function
modelColumns := database.GetModelColumns(model)
// Build expand joins map (if needed in future)
var expandJoins map[string]string
if len(options.Expand) > 0 {
expandJoins = make(map[string]string)
// TODO: Build actual JOIN SQL for each expand relation
// For now, pass empty map as joins are handled via Preload
}
// Get cursor filter SQL
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
return
}
// Apply cursor filter to query
if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter)
query = query.Where(cursorFilter)
}
}
// Execute query - modelPtr was already created earlier
if err := query.Scan(ctx, modelPtr); err != nil {
logger.Error("Error executing query: %v", err)
@ -333,6 +394,16 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset,
}
// Execute AfterRead hooks
hookCtx.Result = modelPtr
hookCtx.Error = nil
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
logger.Error("AfterRead hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendFormattedResponse(w, modelPtr, metadata, options)
}
@ -351,6 +422,28 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
logger.Info("Creating record in %s.%s", schema, entity)
// Execute BeforeCreate hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
Data: data,
Writer: w,
}
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
logger.Error("BeforeCreate hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified data from hook context
data = hookCtx.Data
// Handle batch creation
dataValue := reflect.ValueOf(data)
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
@ -385,6 +478,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
// Execute AfterCreate hooks for batch creation
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
return
}
@ -410,6 +513,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
return
}
// Execute AfterCreate hooks for single record creation
hookCtx.Result = modelValue
hookCtx.Error = nil
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
logger.Error("AfterCreate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, modelValue, nil)
}
@ -424,9 +537,33 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
model := GetModel(ctx)
logger.Info("Updating record in %s.%s", schema, entity)
// Execute BeforeUpdate hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
Options: options,
ID: id,
Data: data,
Writer: w,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
logger.Error("BeforeUpdate hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
// Use potentially modified data from hook context
data = hookCtx.Data
// Convert data to map
dataMap, ok := data.(map[string]interface{})
if !ok {
@ -462,9 +599,20 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
return
}
h.sendResponse(w, map[string]interface{}{
// Execute AfterUpdate hooks
responseData := map[string]interface{}{
"updated": result.RowsAffected(),
}, nil)
}
hookCtx.Result = responseData
hookCtx.Error = nil
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
logger.Error("AfterUpdate hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, responseData, nil)
}
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
@ -478,9 +626,28 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
model := GetModel(ctx)
logger.Info("Deleting record from %s.%s", schema, entity)
// Execute BeforeDelete hooks
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
TableName: tableName,
Model: model,
ID: id,
Writer: w,
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
logger.Error("BeforeDelete hook failed: %v", err)
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
return
}
query := h.db.NewDelete().Table(tableName)
if id == "" {
@ -497,9 +664,20 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
return
}
h.sendResponse(w, map[string]interface{}{
// Execute AfterDelete hooks
responseData := map[string]interface{}{
"deleted": result.RowsAffected(),
}, nil)
}
hookCtx.Result = responseData
hookCtx.Error = nil
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
logger.Error("AfterDelete hook failed: %v", err)
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
return
}
h.sendResponse(w, responseData, nil)
}
// qualifyColumnName ensures column name is fully qualified with table name if not already

140
pkg/restheadspec/hooks.go Normal file
View File

@ -0,0 +1,140 @@
package restheadspec
import (
"context"
"fmt"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
// Create operation hooks
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
// Update operation hooks
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
// Delete operation hooks
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database, registry, etc.
Schema string
Entity string
TableName string
Model interface{}
Options ExtendedRequestOptions
// Operation-specific fields
ID string
Data interface{} // For create/update operations
Result interface{} // For after hooks
Error error // For after hooks
QueryFilter string // For read operations
// Response writer - allows hooks to modify response
Writer common.ResponseWriter
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
logger.Debug("No hooks registered for %s", hookType)
return nil
}
logger.Debug("Executing %d hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
}
logger.Debug("All hooks for %s executed successfully", hookType)
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@ -0,0 +1,197 @@
package restheadspec
import (
"fmt"
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
)
// This file contains example implementations showing how to use hooks
// These are just examples - you can implement hooks as needed for your application
// ExampleLoggingHook logs before and after operations
func ExampleLoggingHook(hookType HookType) HookFunc {
return func(ctx *HookContext) error {
logger.Info("[%s] Operation: %s.%s, ID: %s", hookType, ctx.Schema, ctx.Entity, ctx.ID)
if ctx.Data != nil {
logger.Debug("[%s] Data: %+v", hookType, ctx.Data)
}
if ctx.Result != nil {
logger.Debug("[%s] Result: %+v", hookType, ctx.Result)
}
return nil
}
}
// ExampleValidationHook validates data before create/update operations
func ExampleValidationHook(ctx *HookContext) error {
// Example: Ensure certain fields are present
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
// Check for required fields
requiredFields := []string{"name"} // Add your required fields here
for _, field := range requiredFields {
if _, exists := dataMap[field]; !exists {
return fmt.Errorf("required field missing: %s", field)
}
}
}
return nil
}
// ExampleAuthorizationHook checks if the user has permission to perform the operation
func ExampleAuthorizationHook(ctx *HookContext) error {
// Example: Check user permissions from context
// userID, ok := ctx.Context.Value("user_id").(string)
// if !ok {
// return fmt.Errorf("unauthorized: no user in context")
// }
// You can access the handler's database or registry if needed
// For example, to check permissions in the database:
// query := ctx.Handler.db.NewSelect().Table("permissions")...
// Add your authorization logic here
logger.Debug("Authorization check for %s.%s", ctx.Schema, ctx.Entity)
return nil
}
// ExampleDataTransformHook modifies data before create/update
func ExampleDataTransformHook(ctx *HookContext) error {
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
// Example: Add a timestamp or user ID
// dataMap["updated_at"] = time.Now()
// dataMap["updated_by"] = ctx.Context.Value("user_id")
// Update the context with modified data
ctx.Data = dataMap
logger.Debug("Data transformed for %s.%s", ctx.Schema, ctx.Entity)
}
return nil
}
// ExampleAuditLogHook creates audit log entries for operations
func ExampleAuditLogHook(hookType HookType) HookFunc {
return func(ctx *HookContext) error {
// Example: Log to audit system
auditEntry := map[string]interface{}{
"operation": hookType,
"schema": ctx.Schema,
"entity": ctx.Entity,
"table_name": ctx.TableName,
"id": ctx.ID,
}
if ctx.Error != nil {
auditEntry["error"] = ctx.Error.Error()
}
logger.Info("Audit log: %+v", auditEntry)
// In a real application, you would save this to a database using the handler
// Example:
// query := ctx.Handler.db.NewInsert().Table("audit_logs").Model(&auditEntry)
// if _, err := query.Exec(ctx.Context); err != nil {
// logger.Error("Failed to save audit log: %v", err)
// }
return nil
}
}
// ExampleCacheInvalidationHook invalidates cache after create/update/delete
func ExampleCacheInvalidationHook(ctx *HookContext) error {
// Example: Invalidate cache for the entity
cacheKey := fmt.Sprintf("%s.%s", ctx.Schema, ctx.Entity)
logger.Info("Invalidating cache for: %s", cacheKey)
// Add your cache invalidation logic here
// cache.Delete(cacheKey)
return nil
}
// ExampleFilterSensitiveDataHook removes sensitive data from responses
func ExampleFilterSensitiveDataHook(ctx *HookContext) error {
// Example: Remove password fields from results
// This would be called in AfterRead hooks
logger.Debug("Filtering sensitive data for %s.%s", ctx.Schema, ctx.Entity)
// Add your data filtering logic here
// You would iterate through ctx.Result and remove sensitive fields
return nil
}
// ExampleRelatedDataHook fetches related data using the handler's database
func ExampleRelatedDataHook(ctx *HookContext) error {
// Example: Fetch related data after reading the main entity
// This hook demonstrates using ctx.Handler to access the database
if ctx.Entity == "users" && ctx.Result != nil {
// Example: Fetch user's recent activity
// userID := ... extract from ctx.Result
// Use the handler's database to query related data
// query := ctx.Handler.db.NewSelect().Table("user_activity").Where("user_id = ?", userID)
// var activities []Activity
// if err := query.Scan(ctx.Context, &activities); err != nil {
// logger.Error("Failed to fetch user activities: %v", err)
// return err
// }
// Optionally modify the result to include the related data
// if resultMap, ok := ctx.Result.(map[string]interface{}); ok {
// resultMap["recent_activities"] = activities
// }
logger.Debug("Fetched related data for user entity")
}
return nil
}
// SetupExampleHooks demonstrates how to register hooks on a handler
func SetupExampleHooks(handler *Handler) {
hooks := handler.Hooks()
// Register logging hooks for all operations
hooks.Register(BeforeRead, ExampleLoggingHook(BeforeRead))
hooks.Register(AfterRead, ExampleLoggingHook(AfterRead))
hooks.Register(BeforeCreate, ExampleLoggingHook(BeforeCreate))
hooks.Register(AfterCreate, ExampleLoggingHook(AfterCreate))
hooks.Register(BeforeUpdate, ExampleLoggingHook(BeforeUpdate))
hooks.Register(AfterUpdate, ExampleLoggingHook(AfterUpdate))
hooks.Register(BeforeDelete, ExampleLoggingHook(BeforeDelete))
hooks.Register(AfterDelete, ExampleLoggingHook(AfterDelete))
// Register validation hooks for create/update
hooks.Register(BeforeCreate, ExampleValidationHook)
hooks.Register(BeforeUpdate, ExampleValidationHook)
// Register authorization hooks for all operations
hooks.RegisterMultiple([]HookType{
BeforeRead, BeforeCreate, BeforeUpdate, BeforeDelete,
}, ExampleAuthorizationHook)
// Register data transform hook for create/update
hooks.Register(BeforeCreate, ExampleDataTransformHook)
hooks.Register(BeforeUpdate, ExampleDataTransformHook)
// Register audit log hooks for after operations
hooks.Register(AfterCreate, ExampleAuditLogHook(AfterCreate))
hooks.Register(AfterUpdate, ExampleAuditLogHook(AfterUpdate))
hooks.Register(AfterDelete, ExampleAuditLogHook(AfterDelete))
// Register cache invalidation for after operations
hooks.Register(AfterCreate, ExampleCacheInvalidationHook)
hooks.Register(AfterUpdate, ExampleCacheInvalidationHook)
hooks.Register(AfterDelete, ExampleCacheInvalidationHook)
// Register sensitive data filtering for read operations
hooks.Register(AfterRead, ExampleFilterSensitiveDataHook)
// Register related data fetching for read operations
hooks.Register(AfterRead, ExampleRelatedDataHook)
logger.Info("Example hooks registered successfully")
}

View File

@ -0,0 +1,347 @@
package restheadspec
import (
"context"
"fmt"
"testing"
)
// TestHookRegistry tests the hook registry functionality
func TestHookRegistry(t *testing.T) {
registry := NewHookRegistry()
// Test registering a hook
called := false
hook := func(ctx *HookContext) error {
called = true
return nil
}
registry.Register(BeforeRead, hook)
if registry.Count(BeforeRead) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead))
}
// Test executing a hook
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !called {
t.Error("Hook was not called")
}
}
// TestHookExecution tests hook execution order
func TestHookExecutionOrder(t *testing.T) {
registry := NewHookRegistry()
order := []int{}
hook1 := func(ctx *HookContext) error {
order = append(order, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
order = append(order, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
order = append(order, 3)
return nil
}
registry.Register(BeforeCreate, hook1)
registry.Register(BeforeCreate, hook2)
registry.Register(BeforeCreate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if len(order) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(order))
}
if order[0] != 1 || order[1] != 2 || order[2] != 3 {
t.Errorf("Hooks executed in wrong order: %v", order)
}
}
// TestHookError tests hook error handling
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
executed := []string{}
hook1 := func(ctx *HookContext) error {
executed = append(executed, "hook1")
return nil
}
hook2 := func(ctx *HookContext) error {
executed = append(executed, "hook2")
return fmt.Errorf("hook2 error")
}
hook3 := func(ctx *HookContext) error {
executed = append(executed, "hook3")
return nil
}
registry.Register(BeforeUpdate, hook1)
registry.Register(BeforeUpdate, hook2)
registry.Register(BeforeUpdate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeUpdate, ctx)
if err == nil {
t.Error("Expected error from hook execution")
}
if len(executed) != 2 {
t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed))
}
if executed[0] != "hook1" || executed[1] != "hook2" {
t.Errorf("Unexpected execution order: %v", executed)
}
}
// TestHookDataModification tests modifying data in hooks
func TestHookDataModification(t *testing.T) {
registry := NewHookRegistry()
modifyHook := func(ctx *HookContext) error {
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
dataMap["modified"] = true
ctx.Data = dataMap
}
return nil
}
registry.Register(BeforeCreate, modifyHook)
data := map[string]interface{}{
"name": "test",
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
Data: data,
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
modifiedData := ctx.Data.(map[string]interface{})
if !modifiedData["modified"].(bool) {
t.Error("Data was not modified by hook")
}
}
// TestRegisterMultiple tests registering a hook for multiple types
func TestRegisterMultiple(t *testing.T) {
registry := NewHookRegistry()
called := 0
hook := func(ctx *HookContext) error {
called++
return nil
}
registry.RegisterMultiple([]HookType{
BeforeRead,
BeforeCreate,
BeforeUpdate,
}, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered for BeforeRead")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Hook not registered for BeforeCreate")
}
if registry.Count(BeforeUpdate) != 1 {
t.Error("Hook not registered for BeforeUpdate")
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
registry.Execute(BeforeRead, ctx)
registry.Execute(BeforeCreate, ctx)
registry.Execute(BeforeUpdate, ctx)
if called != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", called)
}
}
// TestClearHooks tests clearing hooks
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered")
}
registry.Clear(BeforeRead)
if registry.Count(BeforeRead) != 0 {
t.Error("Hook not cleared")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Wrong hook was cleared")
}
}
// TestClearAllHooks tests clearing all hooks
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(BeforeUpdate, hook)
registry.ClearAll()
if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 {
t.Error("Not all hooks were cleared")
}
}
// TestHasHooks tests checking if hooks exist
func TestHasHooks(t *testing.T) {
registry := NewHookRegistry()
if registry.HasHooks(BeforeRead) {
t.Error("Should not have hooks initially")
}
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
if !registry.HasHooks(BeforeRead) {
t.Error("Should have hooks after registration")
}
}
// TestGetAllHookTypes tests getting all registered hook types
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(AfterUpdate, hook)
types := registry.GetAllHookTypes()
if len(types) != 3 {
t.Errorf("Expected 3 hook types, got %d", len(types))
}
// Verify all expected types are present
expectedTypes := map[HookType]bool{
BeforeRead: true,
BeforeCreate: true,
AfterUpdate: true,
}
for _, hookType := range types {
if !expectedTypes[hookType] {
t.Errorf("Unexpected hook type: %s", hookType)
}
}
}
// TestHookContextHandler tests that hooks can access the handler
func TestHookContextHandler(t *testing.T) {
registry := NewHookRegistry()
var capturedHandler *Handler
hook := func(ctx *HookContext) error {
// Verify that the handler is accessible from the context
if ctx.Handler == nil {
return fmt.Errorf("handler is nil in hook context")
}
capturedHandler = ctx.Handler
return nil
}
registry.Register(BeforeRead, hook)
// Create a mock handler
handler := &Handler{
hooks: registry,
}
ctx := &HookContext{
Context: context.Background(),
Handler: handler,
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if capturedHandler == nil {
t.Error("Handler was not captured from hook context")
}
if capturedHandler != handler {
t.Error("Captured handler does not match original handler")
}
}