mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-11-13 09:53:53 +00:00
Added cursor filters and hooks
This commit is contained in:
parent
fc82a9bc50
commit
c8704c07dd
@ -1,6 +1,11 @@
|
||||
package database
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
// For example: "public.users" -> ("public", "users")
|
||||
@ -11,3 +16,157 @@ func parseTableName(fullTableName string) (schema, table string) {
|
||||
}
|
||||
return "", fullTableName
|
||||
}
|
||||
|
||||
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||
func GetPrimaryKeyName(model any) string {
|
||||
// Check if model implements PrimaryKeyNameProvider
|
||||
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
||||
return provider.GetIDName()
|
||||
}
|
||||
|
||||
// Try Bun tag first
|
||||
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
return getPrimaryKeyFromReflection(model, "gorm")
|
||||
}
|
||||
|
||||
// GetModelColumns extracts all column names from a model using reflection
|
||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||
func GetModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return columns
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Get column name using the same logic as primary key extraction
|
||||
columnName := getColumnNameFromField(field)
|
||||
|
||||
if columnName != "" {
|
||||
columns = append(columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// getColumnNameFromField extracts the column name from a struct field
|
||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||
func getColumnNameFromField(field reflect.StructField) string {
|
||||
// Try bun tag first
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Try gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to json tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract just the field name before any options
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: use field name in lowercase
|
||||
return strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
switch ormType {
|
||||
case "gorm":
|
||||
// Check for gorm tag with primaryKey
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") {
|
||||
// Try to extract column name from gorm tag
|
||||
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
case "bun":
|
||||
// Check for bun tag with pk flag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") {
|
||||
// Extract column name from bun tag
|
||||
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractColumnFromGormTag extracts the column name from a gorm tag
|
||||
// Example: "column:id;primaryKey" -> "id"
|
||||
func extractColumnFromGormTag(tag string) string {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if colName, found := strings.CutPrefix(part, "column:"); found {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractColumnFromBunTag extracts the column name from a bun tag
|
||||
// Example: "id,pk" -> "id"
|
||||
// Example: ",pk" -> "" (will fall back to json tag)
|
||||
func extractColumnFromBunTag(tag string) string {
|
||||
parts := strings.Split(tag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
233
pkg/common/adapters/database/utils_test.go
Normal file
233
pkg/common/adapters/database/utils_test.go
Normal file
@ -0,0 +1,233 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test models for GORM
|
||||
type GormModelWithGetIDName struct {
|
||||
ID int `gorm:"column:rid_test;primaryKey" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (m GormModelWithGetIDName) GetIDName() string {
|
||||
return "rid_test"
|
||||
}
|
||||
|
||||
type GormModelWithColumnTag struct {
|
||||
ID int `gorm:"column:custom_id;primaryKey" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type GormModelWithJSONFallback struct {
|
||||
ID int `gorm:"primaryKey" json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Test models for Bun
|
||||
type BunModelWithGetIDName struct {
|
||||
ID int `bun:"rid_test,pk" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (m BunModelWithGetIDName) GetIDName() string {
|
||||
return "rid_test"
|
||||
}
|
||||
|
||||
type BunModelWithColumnTag struct {
|
||||
ID int `bun:"custom_id,pk" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type BunModelWithJSONFallback struct {
|
||||
ID int `bun:",pk" json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "GORM model with GetIDName method",
|
||||
model: GormModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "GORM model with column tag",
|
||||
model: GormModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
{
|
||||
name: "GORM model with JSON fallback",
|
||||
model: GormModelWithJSONFallback{},
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "GORM model pointer with GetIDName",
|
||||
model: &GormModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "GORM model pointer with column tag",
|
||||
model: &GormModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
{
|
||||
name: "Bun model with GetIDName method",
|
||||
model: BunModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "Bun model with column tag",
|
||||
model: BunModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
{
|
||||
name: "Bun model with JSON fallback",
|
||||
model: BunModelWithJSONFallback{},
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "Bun model pointer with GetIDName",
|
||||
model: &BunModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "Bun model pointer with column tag",
|
||||
model: &BunModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyName(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractColumnFromGormTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "column tag with primaryKey",
|
||||
tag: "column:rid_test;primaryKey",
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "column tag with spaces",
|
||||
tag: "column:user_id ; primaryKey ; autoIncrement",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "no column tag",
|
||||
tag: "primaryKey;autoIncrement",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractColumnFromGormTag(tt.tag)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractColumnFromBunTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "column name with pk flag",
|
||||
tag: "rid_test,pk",
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "only pk flag",
|
||||
tag: ",pk",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "column with multiple flags",
|
||||
tag: "user_id,pk,autoincrement",
|
||||
expected: "user_id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractColumnFromBunTag(tt.tag)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with multiple columns",
|
||||
model: BunModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with multiple columns",
|
||||
model: GormModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "Bun model pointer",
|
||||
model: &BunModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "GORM model pointer",
|
||||
model: &GormModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "Bun model with JSON fallback",
|
||||
model: BunModelWithJSONFallback{},
|
||||
expected: []string{"user_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with JSON fallback",
|
||||
model: GormModelWithJSONFallback{},
|
||||
expected: []string{"user_id", "name"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected))
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -131,6 +131,11 @@ type TableNameProvider interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
// PrimaryKeyNameProvider interface for models that provide primary key column names
|
||||
type PrimaryKeyNameProvider interface {
|
||||
GetIDName() string
|
||||
}
|
||||
|
||||
// SchemaProvider interface for models that provide schema names
|
||||
type SchemaProvider interface {
|
||||
SchemaName() string
|
||||
|
||||
223
pkg/restheadspec/cursor.go
Normal file
223
pkg/restheadspec/cursor.go
Normal file
@ -0,0 +1,223 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// CursorDirection defines pagination direction
|
||||
type CursorDirection int
|
||||
|
||||
const (
|
||||
CursorForward CursorDirection = 1
|
||||
CursorBackward CursorDirection = -1
|
||||
)
|
||||
|
||||
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
|
||||
// It uses the current request's sort, cursor, joins (via Expand), and CQL (via ComputedQL).
|
||||
//
|
||||
// Parameters:
|
||||
// - tableName: name of the main table (e.g. "post")
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - expandJoins: optional map[alias]string of JOIN clauses (e.g. "user": "LEFT JOIN user ON ...")
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string, // optional: for validation
|
||||
expandJoins map[string]string, // optional: alias → JOIN SQL
|
||||
) (string, error) {
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 1. Determine active cursor
|
||||
// --------------------------------------------------------------------- //
|
||||
cursorID, direction := opts.getActiveCursor()
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 2. Extract sort columns
|
||||
// --------------------------------------------------------------------- //
|
||||
sortItems := opts.getSortColumns()
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "user.name desc nulls last"
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
// Direction from struct or string
|
||||
desc := strings.EqualFold(s.Direction, "desc") ||
|
||||
strings.Contains(strings.ToLower(field), "desc")
|
||||
field = opts.cleanSortField(field)
|
||||
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, isJoin, err := opts.resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin && expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 5. Build priority OR-AND chain
|
||||
// --------------------------------------------------------------------- //
|
||||
orSQL := buildPriorityChain(whereClauses)
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 6. Final EXISTS subquery
|
||||
// --------------------------------------------------------------------- //
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: get active cursor (forward or backward)
|
||||
func (opts *ExtendedRequestOptions) getActiveCursor() (id string, direction CursorDirection) {
|
||||
if opts.CursorForward != "" {
|
||||
return opts.CursorForward, CursorForward
|
||||
}
|
||||
if opts.CursorBackward != "" {
|
||||
return opts.CursorBackward, CursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: extract sort columns
|
||||
func (opts *ExtendedRequestOptions) getSortColumns() []common.SortOption {
|
||||
if opts.RequestOptions.Sort != nil {
|
||||
return opts.RequestOptions.Sort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper: clean sort field (remove desc, asc, nulls)
|
||||
func (opts *ExtendedRequestOptions) cleanSortField(field string) string {
|
||||
f := strings.ToLower(field)
|
||||
for _, token := range []string{"desc", "asc", "nulls last", "nulls first"} {
|
||||
f = strings.ReplaceAll(f, token, "")
|
||||
}
|
||||
return strings.TrimSpace(f)
|
||||
}
|
||||
|
||||
// Helper: resolve column (main, JSON, CQL, join)
|
||||
func (opts *ExtendedRequestOptions) resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// CQL via ComputedQL
|
||||
if strings.Contains(strings.ToLower(field), "cql") && opts.ComputedQL != nil {
|
||||
if expr, ok := opts.ComputedQL[field]; ok {
|
||||
return "cursor_select." + expr, expr, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Joined column
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: rewrite JOIN clause for cursor subquery
|
||||
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: build OR-AND priority chain
|
||||
func buildPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
@ -18,6 +19,7 @@ import (
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
}
|
||||
|
||||
// NewHandler creates a new API handler with database and registry abstractions
|
||||
@ -25,9 +27,16 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
||||
return &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
}
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry for this handler
|
||||
// Use this to register custom hooks for operations
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
return h.hooks
|
||||
}
|
||||
|
||||
// handlePanic is a helper function to handle panics with stack traces
|
||||
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
||||
stack := debug.Stack()
|
||||
@ -184,6 +193,25 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
tableName := GetTableName(ctx)
|
||||
model := GetModel(ctx)
|
||||
|
||||
// Execute BeforeRead hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: id,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
logger.Error("BeforeRead hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and unwrap model type to get base struct
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
@ -310,6 +338,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
|
||||
// Apply cursor-based pagination
|
||||
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
||||
logger.Debug("Applying cursor pagination")
|
||||
|
||||
// Get primary key name
|
||||
pkName := database.GetPrimaryKeyName(model)
|
||||
|
||||
// Extract model columns for validation using the generic database function
|
||||
modelColumns := database.GetModelColumns(model)
|
||||
|
||||
// Build expand joins map (if needed in future)
|
||||
var expandJoins map[string]string
|
||||
if len(options.Expand) > 0 {
|
||||
expandJoins = make(map[string]string)
|
||||
// TODO: Build actual JOIN SQL for each expand relation
|
||||
// For now, pass empty map as joins are handled via Preload
|
||||
}
|
||||
|
||||
// Get cursor filter SQL
|
||||
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||
if err != nil {
|
||||
logger.Error("Error building cursor filter: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply cursor filter to query
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
query = query.Where(cursorFilter)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query - modelPtr was already created earlier
|
||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||
logger.Error("Error executing query: %v", err)
|
||||
@ -333,6 +394,16 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// Execute AfterRead hooks
|
||||
hookCtx.Result = modelPtr
|
||||
hookCtx.Error = nil
|
||||
|
||||
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||
logger.Error("AfterRead hook failed: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
||||
}
|
||||
|
||||
@ -351,6 +422,28 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
|
||||
logger.Info("Creating record in %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeCreate hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
Options: options,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
logger.Error("BeforeCreate hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
data = hookCtx.Data
|
||||
|
||||
// Handle batch creation
|
||||
dataValue := reflect.ValueOf(data)
|
||||
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
||||
@ -385,6 +478,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
|
||||
// Execute AfterCreate hooks for batch creation
|
||||
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
|
||||
hookCtx.Error = nil
|
||||
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
logger.Error("AfterCreate hook failed: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
|
||||
return
|
||||
}
|
||||
@ -410,6 +513,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
return
|
||||
}
|
||||
|
||||
// Execute AfterCreate hooks for single record creation
|
||||
hookCtx.Result = modelValue
|
||||
hookCtx.Error = nil
|
||||
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
logger.Error("AfterCreate hook failed: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(w, modelValue, nil)
|
||||
}
|
||||
|
||||
@ -424,9 +537,33 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
schema := GetSchema(ctx)
|
||||
entity := GetEntity(ctx)
|
||||
tableName := GetTableName(ctx)
|
||||
model := GetModel(ctx)
|
||||
|
||||
logger.Info("Updating record in %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeUpdate hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: id,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
logger.Error("BeforeUpdate hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
data = hookCtx.Data
|
||||
|
||||
// Convert data to map
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
@ -462,9 +599,20 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(w, map[string]interface{}{
|
||||
// Execute AfterUpdate hooks
|
||||
responseData := map[string]interface{}{
|
||||
"updated": result.RowsAffected(),
|
||||
}, nil)
|
||||
}
|
||||
hookCtx.Result = responseData
|
||||
hookCtx.Error = nil
|
||||
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
logger.Error("AfterUpdate hook failed: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(w, responseData, nil)
|
||||
}
|
||||
|
||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
|
||||
@ -478,9 +626,28 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
schema := GetSchema(ctx)
|
||||
entity := GetEntity(ctx)
|
||||
tableName := GetTableName(ctx)
|
||||
model := GetModel(ctx)
|
||||
|
||||
logger.Info("Deleting record from %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeDelete hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
ID: id,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Error("BeforeDelete hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
query := h.db.NewDelete().Table(tableName)
|
||||
|
||||
if id == "" {
|
||||
@ -497,9 +664,20 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(w, map[string]interface{}{
|
||||
// Execute AfterDelete hooks
|
||||
responseData := map[string]interface{}{
|
||||
"deleted": result.RowsAffected(),
|
||||
}, nil)
|
||||
}
|
||||
hookCtx.Result = responseData
|
||||
hookCtx.Error = nil
|
||||
|
||||
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
|
||||
logger.Error("AfterDelete hook failed: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.sendResponse(w, responseData, nil)
|
||||
}
|
||||
|
||||
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
||||
|
||||
140
pkg/restheadspec/hooks.go
Normal file
140
pkg/restheadspec/hooks.go
Normal file
@ -0,0 +1,140 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// HookType defines the type of hook to execute
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
|
||||
// Create operation hooks
|
||||
BeforeCreate HookType = "before_create"
|
||||
AfterCreate HookType = "after_create"
|
||||
|
||||
// Update operation hooks
|
||||
BeforeUpdate HookType = "before_update"
|
||||
AfterUpdate HookType = "after_update"
|
||||
|
||||
// Delete operation hooks
|
||||
BeforeDelete HookType = "before_delete"
|
||||
AfterDelete HookType = "after_delete"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
type HookContext struct {
|
||||
Context context.Context
|
||||
Handler *Handler // Reference to the handler for accessing database, registry, etc.
|
||||
Schema string
|
||||
Entity string
|
||||
TableName string
|
||||
Model interface{}
|
||||
Options ExtendedRequestOptions
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
Result interface{} // For after hooks
|
||||
Error error // For after hooks
|
||||
QueryFilter string // For read operations
|
||||
|
||||
// Response writer - allows hooks to modify response
|
||||
Writer common.ResponseWriter
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
// It receives a HookContext and can modify it or return an error
|
||||
// If an error is returned, the operation will be aborted
|
||||
type HookFunc func(*HookContext) error
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
type HookRegistry struct {
|
||||
hooks map[HookType][]HookFunc
|
||||
}
|
||||
|
||||
// NewHookRegistry creates a new hook registry
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return &HookRegistry{
|
||||
hooks: make(map[HookType][]HookFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a new hook for the specified hook type
|
||||
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||
if r.hooks == nil {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||
logger.Info("Registered hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||
}
|
||||
|
||||
// RegisterMultiple registers a hook for multiple hook types
|
||||
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||
for _, hookType := range hookTypes {
|
||||
r.Register(hookType, hook)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs all hooks for the specified type in order
|
||||
// If any hook returns an error, execution stops and the error is returned
|
||||
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
if !exists || len(hooks) == 0 {
|
||||
logger.Debug("No hooks registered for %s", hookType)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Executing %d hook(s) for %s", len(hooks), hookType)
|
||||
|
||||
for i, hook := range hooks {
|
||||
if err := hook(ctx); err != nil {
|
||||
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("All hooks for %s executed successfully", hookType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all hooks for the specified type
|
||||
func (r *HookRegistry) Clear(hookType HookType) {
|
||||
delete(r.hooks, hookType)
|
||||
logger.Info("Cleared all hooks for %s", hookType)
|
||||
}
|
||||
|
||||
// ClearAll removes all registered hooks
|
||||
func (r *HookRegistry) ClearAll() {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
logger.Info("Cleared all hooks")
|
||||
}
|
||||
|
||||
// Count returns the number of hooks registered for a specific type
|
||||
func (r *HookRegistry) Count(hookType HookType) int {
|
||||
if hooks, exists := r.hooks[hookType]; exists {
|
||||
return len(hooks)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// HasHooks returns true if there are any hooks registered for the specified type
|
||||
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||
return r.Count(hookType) > 0
|
||||
}
|
||||
|
||||
// GetAllHookTypes returns all hook types that have registered hooks
|
||||
func (r *HookRegistry) GetAllHookTypes() []HookType {
|
||||
types := make([]HookType, 0, len(r.hooks))
|
||||
for hookType := range r.hooks {
|
||||
types = append(types, hookType)
|
||||
}
|
||||
return types
|
||||
}
|
||||
197
pkg/restheadspec/hooks_example.go
Normal file
197
pkg/restheadspec/hooks_example.go
Normal file
@ -0,0 +1,197 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// This file contains example implementations showing how to use hooks
|
||||
// These are just examples - you can implement hooks as needed for your application
|
||||
|
||||
// ExampleLoggingHook logs before and after operations
|
||||
func ExampleLoggingHook(hookType HookType) HookFunc {
|
||||
return func(ctx *HookContext) error {
|
||||
logger.Info("[%s] Operation: %s.%s, ID: %s", hookType, ctx.Schema, ctx.Entity, ctx.ID)
|
||||
if ctx.Data != nil {
|
||||
logger.Debug("[%s] Data: %+v", hookType, ctx.Data)
|
||||
}
|
||||
if ctx.Result != nil {
|
||||
logger.Debug("[%s] Result: %+v", hookType, ctx.Result)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleValidationHook validates data before create/update operations
|
||||
func ExampleValidationHook(ctx *HookContext) error {
|
||||
// Example: Ensure certain fields are present
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
// Check for required fields
|
||||
requiredFields := []string{"name"} // Add your required fields here
|
||||
for _, field := range requiredFields {
|
||||
if _, exists := dataMap[field]; !exists {
|
||||
return fmt.Errorf("required field missing: %s", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleAuthorizationHook checks if the user has permission to perform the operation
|
||||
func ExampleAuthorizationHook(ctx *HookContext) error {
|
||||
// Example: Check user permissions from context
|
||||
// userID, ok := ctx.Context.Value("user_id").(string)
|
||||
// if !ok {
|
||||
// return fmt.Errorf("unauthorized: no user in context")
|
||||
// }
|
||||
|
||||
// You can access the handler's database or registry if needed
|
||||
// For example, to check permissions in the database:
|
||||
// query := ctx.Handler.db.NewSelect().Table("permissions")...
|
||||
|
||||
// Add your authorization logic here
|
||||
logger.Debug("Authorization check for %s.%s", ctx.Schema, ctx.Entity)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleDataTransformHook modifies data before create/update
|
||||
func ExampleDataTransformHook(ctx *HookContext) error {
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
// Example: Add a timestamp or user ID
|
||||
// dataMap["updated_at"] = time.Now()
|
||||
// dataMap["updated_by"] = ctx.Context.Value("user_id")
|
||||
|
||||
// Update the context with modified data
|
||||
ctx.Data = dataMap
|
||||
logger.Debug("Data transformed for %s.%s", ctx.Schema, ctx.Entity)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleAuditLogHook creates audit log entries for operations
|
||||
func ExampleAuditLogHook(hookType HookType) HookFunc {
|
||||
return func(ctx *HookContext) error {
|
||||
// Example: Log to audit system
|
||||
auditEntry := map[string]interface{}{
|
||||
"operation": hookType,
|
||||
"schema": ctx.Schema,
|
||||
"entity": ctx.Entity,
|
||||
"table_name": ctx.TableName,
|
||||
"id": ctx.ID,
|
||||
}
|
||||
|
||||
if ctx.Error != nil {
|
||||
auditEntry["error"] = ctx.Error.Error()
|
||||
}
|
||||
|
||||
logger.Info("Audit log: %+v", auditEntry)
|
||||
|
||||
// In a real application, you would save this to a database using the handler
|
||||
// Example:
|
||||
// query := ctx.Handler.db.NewInsert().Table("audit_logs").Model(&auditEntry)
|
||||
// if _, err := query.Exec(ctx.Context); err != nil {
|
||||
// logger.Error("Failed to save audit log: %v", err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleCacheInvalidationHook invalidates cache after create/update/delete
|
||||
func ExampleCacheInvalidationHook(ctx *HookContext) error {
|
||||
// Example: Invalidate cache for the entity
|
||||
cacheKey := fmt.Sprintf("%s.%s", ctx.Schema, ctx.Entity)
|
||||
logger.Info("Invalidating cache for: %s", cacheKey)
|
||||
|
||||
// Add your cache invalidation logic here
|
||||
// cache.Delete(cacheKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleFilterSensitiveDataHook removes sensitive data from responses
|
||||
func ExampleFilterSensitiveDataHook(ctx *HookContext) error {
|
||||
// Example: Remove password fields from results
|
||||
// This would be called in AfterRead hooks
|
||||
logger.Debug("Filtering sensitive data for %s.%s", ctx.Schema, ctx.Entity)
|
||||
|
||||
// Add your data filtering logic here
|
||||
// You would iterate through ctx.Result and remove sensitive fields
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleRelatedDataHook fetches related data using the handler's database
|
||||
func ExampleRelatedDataHook(ctx *HookContext) error {
|
||||
// Example: Fetch related data after reading the main entity
|
||||
// This hook demonstrates using ctx.Handler to access the database
|
||||
|
||||
if ctx.Entity == "users" && ctx.Result != nil {
|
||||
// Example: Fetch user's recent activity
|
||||
// userID := ... extract from ctx.Result
|
||||
|
||||
// Use the handler's database to query related data
|
||||
// query := ctx.Handler.db.NewSelect().Table("user_activity").Where("user_id = ?", userID)
|
||||
// var activities []Activity
|
||||
// if err := query.Scan(ctx.Context, &activities); err != nil {
|
||||
// logger.Error("Failed to fetch user activities: %v", err)
|
||||
// return err
|
||||
// }
|
||||
|
||||
// Optionally modify the result to include the related data
|
||||
// if resultMap, ok := ctx.Result.(map[string]interface{}); ok {
|
||||
// resultMap["recent_activities"] = activities
|
||||
// }
|
||||
|
||||
logger.Debug("Fetched related data for user entity")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupExampleHooks demonstrates how to register hooks on a handler
|
||||
func SetupExampleHooks(handler *Handler) {
|
||||
hooks := handler.Hooks()
|
||||
|
||||
// Register logging hooks for all operations
|
||||
hooks.Register(BeforeRead, ExampleLoggingHook(BeforeRead))
|
||||
hooks.Register(AfterRead, ExampleLoggingHook(AfterRead))
|
||||
hooks.Register(BeforeCreate, ExampleLoggingHook(BeforeCreate))
|
||||
hooks.Register(AfterCreate, ExampleLoggingHook(AfterCreate))
|
||||
hooks.Register(BeforeUpdate, ExampleLoggingHook(BeforeUpdate))
|
||||
hooks.Register(AfterUpdate, ExampleLoggingHook(AfterUpdate))
|
||||
hooks.Register(BeforeDelete, ExampleLoggingHook(BeforeDelete))
|
||||
hooks.Register(AfterDelete, ExampleLoggingHook(AfterDelete))
|
||||
|
||||
// Register validation hooks for create/update
|
||||
hooks.Register(BeforeCreate, ExampleValidationHook)
|
||||
hooks.Register(BeforeUpdate, ExampleValidationHook)
|
||||
|
||||
// Register authorization hooks for all operations
|
||||
hooks.RegisterMultiple([]HookType{
|
||||
BeforeRead, BeforeCreate, BeforeUpdate, BeforeDelete,
|
||||
}, ExampleAuthorizationHook)
|
||||
|
||||
// Register data transform hook for create/update
|
||||
hooks.Register(BeforeCreate, ExampleDataTransformHook)
|
||||
hooks.Register(BeforeUpdate, ExampleDataTransformHook)
|
||||
|
||||
// Register audit log hooks for after operations
|
||||
hooks.Register(AfterCreate, ExampleAuditLogHook(AfterCreate))
|
||||
hooks.Register(AfterUpdate, ExampleAuditLogHook(AfterUpdate))
|
||||
hooks.Register(AfterDelete, ExampleAuditLogHook(AfterDelete))
|
||||
|
||||
// Register cache invalidation for after operations
|
||||
hooks.Register(AfterCreate, ExampleCacheInvalidationHook)
|
||||
hooks.Register(AfterUpdate, ExampleCacheInvalidationHook)
|
||||
hooks.Register(AfterDelete, ExampleCacheInvalidationHook)
|
||||
|
||||
// Register sensitive data filtering for read operations
|
||||
hooks.Register(AfterRead, ExampleFilterSensitiveDataHook)
|
||||
|
||||
// Register related data fetching for read operations
|
||||
hooks.Register(AfterRead, ExampleRelatedDataHook)
|
||||
|
||||
logger.Info("Example hooks registered successfully")
|
||||
}
|
||||
347
pkg/restheadspec/hooks_test.go
Normal file
347
pkg/restheadspec/hooks_test.go
Normal file
@ -0,0 +1,347 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestHookRegistry tests the hook registry functionality
|
||||
func TestHookRegistry(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
// Test registering a hook
|
||||
called := false
|
||||
hook := func(ctx *HookContext) error {
|
||||
called = true
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
|
||||
if registry.Count(BeforeRead) != 1 {
|
||||
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead))
|
||||
}
|
||||
|
||||
// Test executing a hook
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeRead, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("Hook was not called")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookExecution tests hook execution order
|
||||
func TestHookExecutionOrder(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
order := []int{}
|
||||
|
||||
hook1 := func(ctx *HookContext) error {
|
||||
order = append(order, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
hook2 := func(ctx *HookContext) error {
|
||||
order = append(order, 2)
|
||||
return nil
|
||||
}
|
||||
|
||||
hook3 := func(ctx *HookContext) error {
|
||||
order = append(order, 3)
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeCreate, hook1)
|
||||
registry.Register(BeforeCreate, hook2)
|
||||
registry.Register(BeforeCreate, hook3)
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeCreate, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if len(order) != 3 {
|
||||
t.Errorf("Expected 3 hooks to be called, got %d", len(order))
|
||||
}
|
||||
|
||||
if order[0] != 1 || order[1] != 2 || order[2] != 3 {
|
||||
t.Errorf("Hooks executed in wrong order: %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookError tests hook error handling
|
||||
func TestHookError(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
executed := []string{}
|
||||
|
||||
hook1 := func(ctx *HookContext) error {
|
||||
executed = append(executed, "hook1")
|
||||
return nil
|
||||
}
|
||||
|
||||
hook2 := func(ctx *HookContext) error {
|
||||
executed = append(executed, "hook2")
|
||||
return fmt.Errorf("hook2 error")
|
||||
}
|
||||
|
||||
hook3 := func(ctx *HookContext) error {
|
||||
executed = append(executed, "hook3")
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeUpdate, hook1)
|
||||
registry.Register(BeforeUpdate, hook2)
|
||||
registry.Register(BeforeUpdate, hook3)
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeUpdate, ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error from hook execution")
|
||||
}
|
||||
|
||||
if len(executed) != 2 {
|
||||
t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed))
|
||||
}
|
||||
|
||||
if executed[0] != "hook1" || executed[1] != "hook2" {
|
||||
t.Errorf("Unexpected execution order: %v", executed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookDataModification tests modifying data in hooks
|
||||
func TestHookDataModification(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
modifyHook := func(ctx *HookContext) error {
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["modified"] = true
|
||||
ctx.Data = dataMap
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeCreate, modifyHook)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "test",
|
||||
}
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
Data: data,
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeCreate, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
modifiedData := ctx.Data.(map[string]interface{})
|
||||
if !modifiedData["modified"].(bool) {
|
||||
t.Error("Data was not modified by hook")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterMultiple tests registering a hook for multiple types
|
||||
func TestRegisterMultiple(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
called := 0
|
||||
hook := func(ctx *HookContext) error {
|
||||
called++
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.RegisterMultiple([]HookType{
|
||||
BeforeRead,
|
||||
BeforeCreate,
|
||||
BeforeUpdate,
|
||||
}, hook)
|
||||
|
||||
if registry.Count(BeforeRead) != 1 {
|
||||
t.Error("Hook not registered for BeforeRead")
|
||||
}
|
||||
if registry.Count(BeforeCreate) != 1 {
|
||||
t.Error("Hook not registered for BeforeCreate")
|
||||
}
|
||||
if registry.Count(BeforeUpdate) != 1 {
|
||||
t.Error("Hook not registered for BeforeUpdate")
|
||||
}
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
registry.Execute(BeforeRead, ctx)
|
||||
registry.Execute(BeforeCreate, ctx)
|
||||
registry.Execute(BeforeUpdate, ctx)
|
||||
|
||||
if called != 3 {
|
||||
t.Errorf("Expected hook to be called 3 times, got %d", called)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClearHooks tests clearing hooks
|
||||
func TestClearHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
registry.Register(BeforeCreate, hook)
|
||||
|
||||
if registry.Count(BeforeRead) != 1 {
|
||||
t.Error("Hook not registered")
|
||||
}
|
||||
|
||||
registry.Clear(BeforeRead)
|
||||
|
||||
if registry.Count(BeforeRead) != 0 {
|
||||
t.Error("Hook not cleared")
|
||||
}
|
||||
|
||||
if registry.Count(BeforeCreate) != 1 {
|
||||
t.Error("Wrong hook was cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClearAllHooks tests clearing all hooks
|
||||
func TestClearAllHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
registry.Register(BeforeCreate, hook)
|
||||
registry.Register(BeforeUpdate, hook)
|
||||
|
||||
registry.ClearAll()
|
||||
|
||||
if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 {
|
||||
t.Error("Not all hooks were cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasHooks tests checking if hooks exist
|
||||
func TestHasHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
if registry.HasHooks(BeforeRead) {
|
||||
t.Error("Should not have hooks initially")
|
||||
}
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
|
||||
if !registry.HasHooks(BeforeRead) {
|
||||
t.Error("Should have hooks after registration")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAllHookTypes tests getting all registered hook types
|
||||
func TestGetAllHookTypes(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
registry.Register(BeforeCreate, hook)
|
||||
registry.Register(AfterUpdate, hook)
|
||||
|
||||
types := registry.GetAllHookTypes()
|
||||
|
||||
if len(types) != 3 {
|
||||
t.Errorf("Expected 3 hook types, got %d", len(types))
|
||||
}
|
||||
|
||||
// Verify all expected types are present
|
||||
expectedTypes := map[HookType]bool{
|
||||
BeforeRead: true,
|
||||
BeforeCreate: true,
|
||||
AfterUpdate: true,
|
||||
}
|
||||
|
||||
for _, hookType := range types {
|
||||
if !expectedTypes[hookType] {
|
||||
t.Errorf("Unexpected hook type: %s", hookType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookContextHandler tests that hooks can access the handler
|
||||
func TestHookContextHandler(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
var capturedHandler *Handler
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
// Verify that the handler is accessible from the context
|
||||
if ctx.Handler == nil {
|
||||
return fmt.Errorf("handler is nil in hook context")
|
||||
}
|
||||
capturedHandler = ctx.Handler
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
|
||||
// Create a mock handler
|
||||
handler := &Handler{
|
||||
hooks: registry,
|
||||
}
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: handler,
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeRead, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if capturedHandler == nil {
|
||||
t.Error("Handler was not captured from hook context")
|
||||
}
|
||||
|
||||
if capturedHandler != handler {
|
||||
t.Error("Captured handler does not match original handler")
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user