mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-11-13 09:53:53 +00:00
Added cursor filters and hooks
This commit is contained in:
parent
fc82a9bc50
commit
c8704c07dd
@ -1,6 +1,11 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
// For example: "public.users" -> ("public", "users")
|
// For example: "public.users" -> ("public", "users")
|
||||||
@ -11,3 +16,157 @@ func parseTableName(fullTableName string) (schema, table string) {
|
|||||||
}
|
}
|
||||||
return "", fullTableName
|
return "", fullTableName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||||
|
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||||
|
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||||
|
func GetPrimaryKeyName(model any) string {
|
||||||
|
// Check if model implements PrimaryKeyNameProvider
|
||||||
|
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
||||||
|
return provider.GetIDName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try Bun tag first
|
||||||
|
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
|
||||||
|
return pkName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to GORM tag
|
||||||
|
return getPrimaryKeyFromReflection(model, "gorm")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelColumns extracts all column names from a model using reflection
|
||||||
|
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||||
|
func GetModelColumns(model any) []string {
|
||||||
|
var columns []string
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Get column name using the same logic as primary key extraction
|
||||||
|
columnName := getColumnNameFromField(field)
|
||||||
|
|
||||||
|
if columnName != "" {
|
||||||
|
columns = append(columns, columnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// getColumnNameFromField extracts the column name from a struct field
|
||||||
|
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||||
|
func getColumnNameFromField(field reflect.StructField) string {
|
||||||
|
// Try bun tag first
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && bunTag != "-" {
|
||||||
|
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try gorm tag
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if gormTag != "" && gormTag != "-" {
|
||||||
|
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to json tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
// Extract just the field name before any options
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last resort: use field name in lowercase
|
||||||
|
return strings.ToLower(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||||
|
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||||
|
val := reflect.ValueOf(model)
|
||||||
|
if val.Kind() == reflect.Pointer {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if val.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := val.Type()
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
switch ormType {
|
||||||
|
case "gorm":
|
||||||
|
// Check for gorm tag with primaryKey
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "primaryKey") {
|
||||||
|
// Try to extract column name from gorm tag
|
||||||
|
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
// Fall back to json tag
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||||
|
return strings.Split(jsonTag, ",")[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "bun":
|
||||||
|
// Check for bun tag with pk flag
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.Contains(bunTag, "pk") {
|
||||||
|
// Extract column name from bun tag
|
||||||
|
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
// Fall back to json tag
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||||
|
return strings.Split(jsonTag, ",")[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractColumnFromGormTag extracts the column name from a gorm tag
|
||||||
|
// Example: "column:id;primaryKey" -> "id"
|
||||||
|
func extractColumnFromGormTag(tag string) string {
|
||||||
|
parts := strings.Split(tag, ";")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if colName, found := strings.CutPrefix(part, "column:"); found {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractColumnFromBunTag extracts the column name from a bun tag
|
||||||
|
// Example: "id,pk" -> "id"
|
||||||
|
// Example: ",pk" -> "" (will fall back to json tag)
|
||||||
|
func extractColumnFromBunTag(tag string) string {
|
||||||
|
parts := strings.Split(tag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
233
pkg/common/adapters/database/utils_test.go
Normal file
233
pkg/common/adapters/database/utils_test.go
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test models for GORM
|
||||||
|
type GormModelWithGetIDName struct {
|
||||||
|
ID int `gorm:"column:rid_test;primaryKey" json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m GormModelWithGetIDName) GetIDName() string {
|
||||||
|
return "rid_test"
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormModelWithColumnTag struct {
|
||||||
|
ID int `gorm:"column:custom_id;primaryKey" json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormModelWithJSONFallback struct {
|
||||||
|
ID int `gorm:"primaryKey" json:"user_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test models for Bun
|
||||||
|
type BunModelWithGetIDName struct {
|
||||||
|
ID int `bun:"rid_test,pk" json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m BunModelWithGetIDName) GetIDName() string {
|
||||||
|
return "rid_test"
|
||||||
|
}
|
||||||
|
|
||||||
|
type BunModelWithColumnTag struct {
|
||||||
|
ID int `bun:"custom_id,pk" json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BunModelWithJSONFallback struct {
|
||||||
|
ID int `bun:",pk" json:"user_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "GORM model with GetIDName method",
|
||||||
|
model: GormModelWithGetIDName{},
|
||||||
|
expected: "rid_test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with column tag",
|
||||||
|
model: GormModelWithColumnTag{},
|
||||||
|
expected: "custom_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with JSON fallback",
|
||||||
|
model: GormModelWithJSONFallback{},
|
||||||
|
expected: "user_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model pointer with GetIDName",
|
||||||
|
model: &GormModelWithGetIDName{},
|
||||||
|
expected: "rid_test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model pointer with column tag",
|
||||||
|
model: &GormModelWithColumnTag{},
|
||||||
|
expected: "custom_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with GetIDName method",
|
||||||
|
model: BunModelWithGetIDName{},
|
||||||
|
expected: "rid_test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with column tag",
|
||||||
|
model: BunModelWithColumnTag{},
|
||||||
|
expected: "custom_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with JSON fallback",
|
||||||
|
model: BunModelWithJSONFallback{},
|
||||||
|
expected: "user_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model pointer with GetIDName",
|
||||||
|
model: &BunModelWithGetIDName{},
|
||||||
|
expected: "rid_test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model pointer with column tag",
|
||||||
|
model: &BunModelWithColumnTag{},
|
||||||
|
expected: "custom_id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPrimaryKeyName(tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractColumnFromGormTag(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tag string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "column tag with primaryKey",
|
||||||
|
tag: "column:rid_test;primaryKey",
|
||||||
|
expected: "rid_test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column tag with spaces",
|
||||||
|
tag: "column:user_id ; primaryKey ; autoIncrement",
|
||||||
|
expected: "user_id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no column tag",
|
||||||
|
tag: "primaryKey;autoIncrement",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractColumnFromGormTag(tt.tag)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractColumnFromBunTag(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tag string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "column name with pk flag",
|
||||||
|
tag: "rid_test,pk",
|
||||||
|
expected: "rid_test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only pk flag",
|
||||||
|
tag: ",pk",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column with multiple flags",
|
||||||
|
tag: "user_id,pk,autoincrement",
|
||||||
|
expected: "user_id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractColumnFromBunTag(tt.tag)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelColumns(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with multiple columns",
|
||||||
|
model: BunModelWithColumnTag{},
|
||||||
|
expected: []string{"custom_id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with multiple columns",
|
||||||
|
model: GormModelWithColumnTag{},
|
||||||
|
expected: []string{"custom_id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model pointer",
|
||||||
|
model: &BunModelWithColumnTag{},
|
||||||
|
expected: []string{"custom_id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model pointer",
|
||||||
|
model: &GormModelWithColumnTag{},
|
||||||
|
expected: []string{"custom_id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with JSON fallback",
|
||||||
|
model: BunModelWithJSONFallback{},
|
||||||
|
expected: []string{"user_id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with JSON fallback",
|
||||||
|
model: GormModelWithJSONFallback{},
|
||||||
|
expected: []string{"user_id", "name"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetModelColumns(tt.model)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expected[i] {
|
||||||
|
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -131,6 +131,11 @@ type TableNameProvider interface {
|
|||||||
TableName() string
|
TableName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrimaryKeyNameProvider interface for models that provide primary key column names
|
||||||
|
type PrimaryKeyNameProvider interface {
|
||||||
|
GetIDName() string
|
||||||
|
}
|
||||||
|
|
||||||
// SchemaProvider interface for models that provide schema names
|
// SchemaProvider interface for models that provide schema names
|
||||||
type SchemaProvider interface {
|
type SchemaProvider interface {
|
||||||
SchemaName() string
|
SchemaName() string
|
||||||
|
|||||||
223
pkg/restheadspec/cursor.go
Normal file
223
pkg/restheadspec/cursor.go
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CursorDirection defines pagination direction
|
||||||
|
type CursorDirection int
|
||||||
|
|
||||||
|
const (
|
||||||
|
CursorForward CursorDirection = 1
|
||||||
|
CursorBackward CursorDirection = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
|
||||||
|
// It uses the current request's sort, cursor, joins (via Expand), and CQL (via ComputedQL).
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tableName: name of the main table (e.g. "post")
|
||||||
|
// - pkName: primary key column (e.g. "id")
|
||||||
|
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||||
|
// - expandJoins: optional map[alias]string of JOIN clauses (e.g. "user": "LEFT JOIN user ON ...")
|
||||||
|
//
|
||||||
|
// Returns SQL snippet to embed in WHERE clause.
|
||||||
|
func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||||
|
tableName string,
|
||||||
|
pkName string,
|
||||||
|
modelColumns []string, // optional: for validation
|
||||||
|
expandJoins map[string]string, // optional: alias → JOIN SQL
|
||||||
|
) (string, error) {
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
// 1. Determine active cursor
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
cursorID, direction := opts.getActiveCursor()
|
||||||
|
if cursorID == "" {
|
||||||
|
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
// 2. Extract sort columns
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
sortItems := opts.getSortColumns()
|
||||||
|
if len(sortItems) == 0 {
|
||||||
|
return "", fmt.Errorf("no sort columns defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
// 3. Prepare
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
var whereClauses []string
|
||||||
|
joinSQL := ""
|
||||||
|
reverse := direction < 0
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
// 4. Process each sort column
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
for _, s := range sortItems {
|
||||||
|
col := strings.TrimSpace(s.Column)
|
||||||
|
if col == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse: "user.name desc nulls last"
|
||||||
|
parts := strings.Split(col, ".")
|
||||||
|
field := strings.TrimSpace(parts[len(parts)-1])
|
||||||
|
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||||
|
|
||||||
|
// Direction from struct or string
|
||||||
|
desc := strings.EqualFold(s.Direction, "desc") ||
|
||||||
|
strings.Contains(strings.ToLower(field), "desc")
|
||||||
|
field = opts.cleanSortField(field)
|
||||||
|
|
||||||
|
if reverse {
|
||||||
|
desc = !desc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve column
|
||||||
|
cursorCol, targetCol, isJoin, err := opts.resolveColumn(
|
||||||
|
field, prefix, tableName, modelColumns,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle joins
|
||||||
|
if isJoin && expandJoins != nil {
|
||||||
|
if joinClause, ok := expandJoins[prefix]; ok {
|
||||||
|
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||||
|
joinSQL = jSQL
|
||||||
|
cursorCol = cRef + "." + field
|
||||||
|
targetCol = prefix + "." + field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build inequality
|
||||||
|
op := "<"
|
||||||
|
if desc {
|
||||||
|
op = ">"
|
||||||
|
}
|
||||||
|
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(whereClauses) == 0 {
|
||||||
|
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
// 5. Build priority OR-AND chain
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
orSQL := buildPriorityChain(whereClauses)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
// 6. Final EXISTS subquery
|
||||||
|
// --------------------------------------------------------------------- //
|
||||||
|
query := fmt.Sprintf(`EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM %s cursor_select
|
||||||
|
%s
|
||||||
|
WHERE cursor_select.%s = %s
|
||||||
|
AND (%s)
|
||||||
|
)`,
|
||||||
|
tableName,
|
||||||
|
joinSQL,
|
||||||
|
pkName,
|
||||||
|
cursorID,
|
||||||
|
orSQL,
|
||||||
|
)
|
||||||
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------- //
|
||||||
|
// Helper: get active cursor (forward or backward)
|
||||||
|
func (opts *ExtendedRequestOptions) getActiveCursor() (id string, direction CursorDirection) {
|
||||||
|
if opts.CursorForward != "" {
|
||||||
|
return opts.CursorForward, CursorForward
|
||||||
|
}
|
||||||
|
if opts.CursorBackward != "" {
|
||||||
|
return opts.CursorBackward, CursorBackward
|
||||||
|
}
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: extract sort columns
|
||||||
|
func (opts *ExtendedRequestOptions) getSortColumns() []common.SortOption {
|
||||||
|
if opts.RequestOptions.Sort != nil {
|
||||||
|
return opts.RequestOptions.Sort
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: clean sort field (remove desc, asc, nulls)
|
||||||
|
func (opts *ExtendedRequestOptions) cleanSortField(field string) string {
|
||||||
|
f := strings.ToLower(field)
|
||||||
|
for _, token := range []string{"desc", "asc", "nulls last", "nulls first"} {
|
||||||
|
f = strings.ReplaceAll(f, token, "")
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: resolve column (main, JSON, CQL, join)
|
||||||
|
func (opts *ExtendedRequestOptions) resolveColumn(
|
||||||
|
field, prefix, tableName string,
|
||||||
|
modelColumns []string,
|
||||||
|
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||||
|
|
||||||
|
// JSON field
|
||||||
|
if strings.Contains(field, "->") {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CQL via ComputedQL
|
||||||
|
if strings.Contains(strings.ToLower(field), "cql") && opts.ComputedQL != nil {
|
||||||
|
if expr, ok := opts.ComputedQL[field]; ok {
|
||||||
|
return "cursor_select." + expr, expr, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main table column
|
||||||
|
if modelColumns != nil {
|
||||||
|
for _, col := range modelColumns {
|
||||||
|
if strings.EqualFold(col, field) {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No validation → allow all main-table fields
|
||||||
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Joined column
|
||||||
|
if prefix != "" && prefix != tableName {
|
||||||
|
return "", "", true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------- //
|
||||||
|
// Helper: rewrite JOIN clause for cursor subquery
|
||||||
|
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||||
|
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||||
|
cursorAlias = "cursor_select_" + alias
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||||
|
return joinSQL, cursorAlias
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------- //
|
||||||
|
// Helper: build OR-AND priority chain
|
||||||
|
func buildPriorityChain(clauses []string) string {
|
||||||
|
var or []string
|
||||||
|
for i := 0; i < len(clauses); i++ {
|
||||||
|
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||||
|
or = append(or, "("+and+")")
|
||||||
|
}
|
||||||
|
return strings.Join(or, "\n OR ")
|
||||||
|
}
|
||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database"
|
||||||
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -18,6 +19,7 @@ import (
|
|||||||
type Handler struct {
|
type Handler struct {
|
||||||
db common.Database
|
db common.Database
|
||||||
registry common.ModelRegistry
|
registry common.ModelRegistry
|
||||||
|
hooks *HookRegistry
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new API handler with database and registry abstractions
|
// NewHandler creates a new API handler with database and registry abstractions
|
||||||
@ -25,9 +27,16 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
|||||||
return &Handler{
|
return &Handler{
|
||||||
db: db,
|
db: db,
|
||||||
registry: registry,
|
registry: registry,
|
||||||
|
hooks: NewHookRegistry(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hooks returns the hook registry for this handler
|
||||||
|
// Use this to register custom hooks for operations
|
||||||
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
|
return h.hooks
|
||||||
|
}
|
||||||
|
|
||||||
// handlePanic is a helper function to handle panics with stack traces
|
// handlePanic is a helper function to handle panics with stack traces
|
||||||
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
||||||
stack := debug.Stack()
|
stack := debug.Stack()
|
||||||
@ -184,6 +193,25 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
model := GetModel(ctx)
|
model := GetModel(ctx)
|
||||||
|
|
||||||
|
// Execute BeforeRead hooks
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
ID: id,
|
||||||
|
Writer: w,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeRead hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Validate and unwrap model type to get base struct
|
// Validate and unwrap model type to get base struct
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
@ -310,6 +338,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Offset(*options.Offset)
|
query = query.Offset(*options.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply cursor-based pagination
|
||||||
|
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
|
||||||
|
logger.Debug("Applying cursor pagination")
|
||||||
|
|
||||||
|
// Get primary key name
|
||||||
|
pkName := database.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
// Extract model columns for validation using the generic database function
|
||||||
|
modelColumns := database.GetModelColumns(model)
|
||||||
|
|
||||||
|
// Build expand joins map (if needed in future)
|
||||||
|
var expandJoins map[string]string
|
||||||
|
if len(options.Expand) > 0 {
|
||||||
|
expandJoins = make(map[string]string)
|
||||||
|
// TODO: Build actual JOIN SQL for each expand relation
|
||||||
|
// For now, pass empty map as joins are handled via Preload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get cursor filter SQL
|
||||||
|
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error building cursor filter: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply cursor filter to query
|
||||||
|
if cursorFilter != "" {
|
||||||
|
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||||
|
query = query.Where(cursorFilter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Execute query - modelPtr was already created earlier
|
// Execute query - modelPtr was already created earlier
|
||||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||||
logger.Error("Error executing query: %v", err)
|
logger.Error("Error executing query: %v", err)
|
||||||
@ -333,6 +394,16 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
Offset: offset,
|
Offset: offset,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterRead hooks
|
||||||
|
hookCtx.Result = modelPtr
|
||||||
|
hookCtx.Error = nil
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterRead hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -351,6 +422,28 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
|
|
||||||
logger.Info("Creating record in %s.%s", schema, entity)
|
logger.Info("Creating record in %s.%s", schema, entity)
|
||||||
|
|
||||||
|
// Execute BeforeCreate hooks
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
Data: data,
|
||||||
|
Writer: w,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeCreate hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data from hook context
|
||||||
|
data = hookCtx.Data
|
||||||
|
|
||||||
// Handle batch creation
|
// Handle batch creation
|
||||||
dataValue := reflect.ValueOf(data)
|
dataValue := reflect.ValueOf(data)
|
||||||
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
if dataValue.Kind() == reflect.Slice || dataValue.Kind() == reflect.Array {
|
||||||
@ -385,6 +478,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterCreate hooks for batch creation
|
||||||
|
hookCtx.Result = map[string]interface{}{"created": dataValue.Len()}
|
||||||
|
hookCtx.Error = nil
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterCreate hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
|
h.sendResponse(w, map[string]interface{}{"created": dataValue.Len()}, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -410,6 +513,16 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute AfterCreate hooks for single record creation
|
||||||
|
hookCtx.Result = modelValue
|
||||||
|
hookCtx.Error = nil
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterCreate hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.sendResponse(w, modelValue, nil)
|
h.sendResponse(w, modelValue, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -424,9 +537,33 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
|
model := GetModel(ctx)
|
||||||
|
|
||||||
logger.Info("Updating record in %s.%s", schema, entity)
|
logger.Info("Updating record in %s.%s", schema, entity)
|
||||||
|
|
||||||
|
// Execute BeforeUpdate hooks
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
ID: id,
|
||||||
|
Data: data,
|
||||||
|
Writer: w,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeUpdate hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data from hook context
|
||||||
|
data = hookCtx.Data
|
||||||
|
|
||||||
// Convert data to map
|
// Convert data to map
|
||||||
dataMap, ok := data.(map[string]interface{})
|
dataMap, ok := data.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -462,9 +599,20 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendResponse(w, map[string]interface{}{
|
// Execute AfterUpdate hooks
|
||||||
|
responseData := map[string]interface{}{
|
||||||
"updated": result.RowsAffected(),
|
"updated": result.RowsAffected(),
|
||||||
}, nil)
|
}
|
||||||
|
hookCtx.Result = responseData
|
||||||
|
hookCtx.Error = nil
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterUpdate hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.sendResponse(w, responseData, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
|
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
|
||||||
@ -478,9 +626,28 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
|
model := GetModel(ctx)
|
||||||
|
|
||||||
logger.Info("Deleting record from %s.%s", schema, entity)
|
logger.Info("Deleting record from %s.%s", schema, entity)
|
||||||
|
|
||||||
|
// Execute BeforeDelete hooks
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Model: model,
|
||||||
|
ID: id,
|
||||||
|
Writer: w,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeDelete hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
query := h.db.NewDelete().Table(tableName)
|
query := h.db.NewDelete().Table(tableName)
|
||||||
|
|
||||||
if id == "" {
|
if id == "" {
|
||||||
@ -497,9 +664,20 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendResponse(w, map[string]interface{}{
|
// Execute AfterDelete hooks
|
||||||
|
responseData := map[string]interface{}{
|
||||||
"deleted": result.RowsAffected(),
|
"deleted": result.RowsAffected(),
|
||||||
}, nil)
|
}
|
||||||
|
hookCtx.Result = responseData
|
||||||
|
hookCtx.Error = nil
|
||||||
|
|
||||||
|
if err := h.hooks.Execute(AfterDelete, hookCtx); err != nil {
|
||||||
|
logger.Error("AfterDelete hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.sendResponse(w, responseData, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
// qualifyColumnName ensures column name is fully qualified with table name if not already
|
||||||
|
|||||||
140
pkg/restheadspec/hooks.go
Normal file
140
pkg/restheadspec/hooks.go
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookType defines the type of hook to execute
|
||||||
|
type HookType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Read operation hooks
|
||||||
|
BeforeRead HookType = "before_read"
|
||||||
|
AfterRead HookType = "after_read"
|
||||||
|
|
||||||
|
// Create operation hooks
|
||||||
|
BeforeCreate HookType = "before_create"
|
||||||
|
AfterCreate HookType = "after_create"
|
||||||
|
|
||||||
|
// Update operation hooks
|
||||||
|
BeforeUpdate HookType = "before_update"
|
||||||
|
AfterUpdate HookType = "after_update"
|
||||||
|
|
||||||
|
// Delete operation hooks
|
||||||
|
BeforeDelete HookType = "before_delete"
|
||||||
|
AfterDelete HookType = "after_delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookContext contains all the data available to a hook
|
||||||
|
type HookContext struct {
|
||||||
|
Context context.Context
|
||||||
|
Handler *Handler // Reference to the handler for accessing database, registry, etc.
|
||||||
|
Schema string
|
||||||
|
Entity string
|
||||||
|
TableName string
|
||||||
|
Model interface{}
|
||||||
|
Options ExtendedRequestOptions
|
||||||
|
|
||||||
|
// Operation-specific fields
|
||||||
|
ID string
|
||||||
|
Data interface{} // For create/update operations
|
||||||
|
Result interface{} // For after hooks
|
||||||
|
Error error // For after hooks
|
||||||
|
QueryFilter string // For read operations
|
||||||
|
|
||||||
|
// Response writer - allows hooks to modify response
|
||||||
|
Writer common.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookFunc is the signature for hook functions
|
||||||
|
// It receives a HookContext and can modify it or return an error
|
||||||
|
// If an error is returned, the operation will be aborted
|
||||||
|
type HookFunc func(*HookContext) error
|
||||||
|
|
||||||
|
// HookRegistry manages all registered hooks
|
||||||
|
type HookRegistry struct {
|
||||||
|
hooks map[HookType][]HookFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHookRegistry creates a new hook registry
|
||||||
|
func NewHookRegistry() *HookRegistry {
|
||||||
|
return &HookRegistry{
|
||||||
|
hooks: make(map[HookType][]HookFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a new hook for the specified hook type
|
||||||
|
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||||
|
if r.hooks == nil {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
}
|
||||||
|
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||||
|
logger.Info("Registered hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterMultiple registers a hook for multiple hook types
|
||||||
|
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||||
|
for _, hookType := range hookTypes {
|
||||||
|
r.Register(hookType, hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs all hooks for the specified type in order
|
||||||
|
// If any hook returns an error, execution stops and the error is returned
|
||||||
|
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||||
|
hooks, exists := r.hooks[hookType]
|
||||||
|
if !exists || len(hooks) == 0 {
|
||||||
|
logger.Debug("No hooks registered for %s", hookType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Executing %d hook(s) for %s", len(hooks), hookType)
|
||||||
|
|
||||||
|
for i, hook := range hooks {
|
||||||
|
if err := hook(ctx); err != nil {
|
||||||
|
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
||||||
|
return fmt.Errorf("hook execution failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("All hooks for %s executed successfully", hookType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all hooks for the specified type
|
||||||
|
func (r *HookRegistry) Clear(hookType HookType) {
|
||||||
|
delete(r.hooks, hookType)
|
||||||
|
logger.Info("Cleared all hooks for %s", hookType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAll removes all registered hooks
|
||||||
|
func (r *HookRegistry) ClearAll() {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
logger.Info("Cleared all hooks")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of hooks registered for a specific type
|
||||||
|
func (r *HookRegistry) Count(hookType HookType) int {
|
||||||
|
if hooks, exists := r.hooks[hookType]; exists {
|
||||||
|
return len(hooks)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasHooks returns true if there are any hooks registered for the specified type
|
||||||
|
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||||
|
return r.Count(hookType) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllHookTypes returns all hook types that have registered hooks
|
||||||
|
func (r *HookRegistry) GetAllHookTypes() []HookType {
|
||||||
|
types := make([]HookType, 0, len(r.hooks))
|
||||||
|
for hookType := range r.hooks {
|
||||||
|
types = append(types, hookType)
|
||||||
|
}
|
||||||
|
return types
|
||||||
|
}
|
||||||
197
pkg/restheadspec/hooks_example.go
Normal file
197
pkg/restheadspec/hooks_example.go
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This file contains example implementations showing how to use hooks
|
||||||
|
// These are just examples - you can implement hooks as needed for your application
|
||||||
|
|
||||||
|
// ExampleLoggingHook logs before and after operations
|
||||||
|
func ExampleLoggingHook(hookType HookType) HookFunc {
|
||||||
|
return func(ctx *HookContext) error {
|
||||||
|
logger.Info("[%s] Operation: %s.%s, ID: %s", hookType, ctx.Schema, ctx.Entity, ctx.ID)
|
||||||
|
if ctx.Data != nil {
|
||||||
|
logger.Debug("[%s] Data: %+v", hookType, ctx.Data)
|
||||||
|
}
|
||||||
|
if ctx.Result != nil {
|
||||||
|
logger.Debug("[%s] Result: %+v", hookType, ctx.Result)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleValidationHook validates data before create/update operations
|
||||||
|
func ExampleValidationHook(ctx *HookContext) error {
|
||||||
|
// Example: Ensure certain fields are present
|
||||||
|
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||||
|
// Check for required fields
|
||||||
|
requiredFields := []string{"name"} // Add your required fields here
|
||||||
|
for _, field := range requiredFields {
|
||||||
|
if _, exists := dataMap[field]; !exists {
|
||||||
|
return fmt.Errorf("required field missing: %s", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleAuthorizationHook checks if the user has permission to perform the operation
|
||||||
|
func ExampleAuthorizationHook(ctx *HookContext) error {
|
||||||
|
// Example: Check user permissions from context
|
||||||
|
// userID, ok := ctx.Context.Value("user_id").(string)
|
||||||
|
// if !ok {
|
||||||
|
// return fmt.Errorf("unauthorized: no user in context")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// You can access the handler's database or registry if needed
|
||||||
|
// For example, to check permissions in the database:
|
||||||
|
// query := ctx.Handler.db.NewSelect().Table("permissions")...
|
||||||
|
|
||||||
|
// Add your authorization logic here
|
||||||
|
logger.Debug("Authorization check for %s.%s", ctx.Schema, ctx.Entity)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleDataTransformHook modifies data before create/update
|
||||||
|
func ExampleDataTransformHook(ctx *HookContext) error {
|
||||||
|
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||||
|
// Example: Add a timestamp or user ID
|
||||||
|
// dataMap["updated_at"] = time.Now()
|
||||||
|
// dataMap["updated_by"] = ctx.Context.Value("user_id")
|
||||||
|
|
||||||
|
// Update the context with modified data
|
||||||
|
ctx.Data = dataMap
|
||||||
|
logger.Debug("Data transformed for %s.%s", ctx.Schema, ctx.Entity)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleAuditLogHook creates audit log entries for operations
|
||||||
|
func ExampleAuditLogHook(hookType HookType) HookFunc {
|
||||||
|
return func(ctx *HookContext) error {
|
||||||
|
// Example: Log to audit system
|
||||||
|
auditEntry := map[string]interface{}{
|
||||||
|
"operation": hookType,
|
||||||
|
"schema": ctx.Schema,
|
||||||
|
"entity": ctx.Entity,
|
||||||
|
"table_name": ctx.TableName,
|
||||||
|
"id": ctx.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Error != nil {
|
||||||
|
auditEntry["error"] = ctx.Error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Audit log: %+v", auditEntry)
|
||||||
|
|
||||||
|
// In a real application, you would save this to a database using the handler
|
||||||
|
// Example:
|
||||||
|
// query := ctx.Handler.db.NewInsert().Table("audit_logs").Model(&auditEntry)
|
||||||
|
// if _, err := query.Exec(ctx.Context); err != nil {
|
||||||
|
// logger.Error("Failed to save audit log: %v", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleCacheInvalidationHook invalidates cache after create/update/delete
|
||||||
|
func ExampleCacheInvalidationHook(ctx *HookContext) error {
|
||||||
|
// Example: Invalidate cache for the entity
|
||||||
|
cacheKey := fmt.Sprintf("%s.%s", ctx.Schema, ctx.Entity)
|
||||||
|
logger.Info("Invalidating cache for: %s", cacheKey)
|
||||||
|
|
||||||
|
// Add your cache invalidation logic here
|
||||||
|
// cache.Delete(cacheKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleFilterSensitiveDataHook removes sensitive data from responses
|
||||||
|
func ExampleFilterSensitiveDataHook(ctx *HookContext) error {
|
||||||
|
// Example: Remove password fields from results
|
||||||
|
// This would be called in AfterRead hooks
|
||||||
|
logger.Debug("Filtering sensitive data for %s.%s", ctx.Schema, ctx.Entity)
|
||||||
|
|
||||||
|
// Add your data filtering logic here
|
||||||
|
// You would iterate through ctx.Result and remove sensitive fields
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleRelatedDataHook fetches related data using the handler's database
|
||||||
|
func ExampleRelatedDataHook(ctx *HookContext) error {
|
||||||
|
// Example: Fetch related data after reading the main entity
|
||||||
|
// This hook demonstrates using ctx.Handler to access the database
|
||||||
|
|
||||||
|
if ctx.Entity == "users" && ctx.Result != nil {
|
||||||
|
// Example: Fetch user's recent activity
|
||||||
|
// userID := ... extract from ctx.Result
|
||||||
|
|
||||||
|
// Use the handler's database to query related data
|
||||||
|
// query := ctx.Handler.db.NewSelect().Table("user_activity").Where("user_id = ?", userID)
|
||||||
|
// var activities []Activity
|
||||||
|
// if err := query.Scan(ctx.Context, &activities); err != nil {
|
||||||
|
// logger.Error("Failed to fetch user activities: %v", err)
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Optionally modify the result to include the related data
|
||||||
|
// if resultMap, ok := ctx.Result.(map[string]interface{}); ok {
|
||||||
|
// resultMap["recent_activities"] = activities
|
||||||
|
// }
|
||||||
|
|
||||||
|
logger.Debug("Fetched related data for user entity")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupExampleHooks demonstrates how to register hooks on a handler
|
||||||
|
func SetupExampleHooks(handler *Handler) {
|
||||||
|
hooks := handler.Hooks()
|
||||||
|
|
||||||
|
// Register logging hooks for all operations
|
||||||
|
hooks.Register(BeforeRead, ExampleLoggingHook(BeforeRead))
|
||||||
|
hooks.Register(AfterRead, ExampleLoggingHook(AfterRead))
|
||||||
|
hooks.Register(BeforeCreate, ExampleLoggingHook(BeforeCreate))
|
||||||
|
hooks.Register(AfterCreate, ExampleLoggingHook(AfterCreate))
|
||||||
|
hooks.Register(BeforeUpdate, ExampleLoggingHook(BeforeUpdate))
|
||||||
|
hooks.Register(AfterUpdate, ExampleLoggingHook(AfterUpdate))
|
||||||
|
hooks.Register(BeforeDelete, ExampleLoggingHook(BeforeDelete))
|
||||||
|
hooks.Register(AfterDelete, ExampleLoggingHook(AfterDelete))
|
||||||
|
|
||||||
|
// Register validation hooks for create/update
|
||||||
|
hooks.Register(BeforeCreate, ExampleValidationHook)
|
||||||
|
hooks.Register(BeforeUpdate, ExampleValidationHook)
|
||||||
|
|
||||||
|
// Register authorization hooks for all operations
|
||||||
|
hooks.RegisterMultiple([]HookType{
|
||||||
|
BeforeRead, BeforeCreate, BeforeUpdate, BeforeDelete,
|
||||||
|
}, ExampleAuthorizationHook)
|
||||||
|
|
||||||
|
// Register data transform hook for create/update
|
||||||
|
hooks.Register(BeforeCreate, ExampleDataTransformHook)
|
||||||
|
hooks.Register(BeforeUpdate, ExampleDataTransformHook)
|
||||||
|
|
||||||
|
// Register audit log hooks for after operations
|
||||||
|
hooks.Register(AfterCreate, ExampleAuditLogHook(AfterCreate))
|
||||||
|
hooks.Register(AfterUpdate, ExampleAuditLogHook(AfterUpdate))
|
||||||
|
hooks.Register(AfterDelete, ExampleAuditLogHook(AfterDelete))
|
||||||
|
|
||||||
|
// Register cache invalidation for after operations
|
||||||
|
hooks.Register(AfterCreate, ExampleCacheInvalidationHook)
|
||||||
|
hooks.Register(AfterUpdate, ExampleCacheInvalidationHook)
|
||||||
|
hooks.Register(AfterDelete, ExampleCacheInvalidationHook)
|
||||||
|
|
||||||
|
// Register sensitive data filtering for read operations
|
||||||
|
hooks.Register(AfterRead, ExampleFilterSensitiveDataHook)
|
||||||
|
|
||||||
|
// Register related data fetching for read operations
|
||||||
|
hooks.Register(AfterRead, ExampleRelatedDataHook)
|
||||||
|
|
||||||
|
logger.Info("Example hooks registered successfully")
|
||||||
|
}
|
||||||
347
pkg/restheadspec/hooks_test.go
Normal file
347
pkg/restheadspec/hooks_test.go
Normal file
@ -0,0 +1,347 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestHookRegistry tests the hook registry functionality
|
||||||
|
func TestHookRegistry(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
// Test registering a hook
|
||||||
|
called := false
|
||||||
|
hook := func(ctx *HookContext) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeRead, hook)
|
||||||
|
|
||||||
|
if registry.Count(BeforeRead) != 1 {
|
||||||
|
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test executing a hook
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Schema: "test",
|
||||||
|
Entity: "users",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := registry.Execute(BeforeRead, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("Hook was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookExecution tests hook execution order
|
||||||
|
func TestHookExecutionOrder(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
order := []int{}
|
||||||
|
|
||||||
|
hook1 := func(ctx *HookContext) error {
|
||||||
|
order = append(order, 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hook2 := func(ctx *HookContext) error {
|
||||||
|
order = append(order, 2)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hook3 := func(ctx *HookContext) error {
|
||||||
|
order = append(order, 3)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeCreate, hook1)
|
||||||
|
registry.Register(BeforeCreate, hook2)
|
||||||
|
registry.Register(BeforeCreate, hook3)
|
||||||
|
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Schema: "test",
|
||||||
|
Entity: "users",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := registry.Execute(BeforeCreate, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(order) != 3 {
|
||||||
|
t.Errorf("Expected 3 hooks to be called, got %d", len(order))
|
||||||
|
}
|
||||||
|
|
||||||
|
if order[0] != 1 || order[1] != 2 || order[2] != 3 {
|
||||||
|
t.Errorf("Hooks executed in wrong order: %v", order)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookError tests hook error handling
|
||||||
|
func TestHookError(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
executed := []string{}
|
||||||
|
|
||||||
|
hook1 := func(ctx *HookContext) error {
|
||||||
|
executed = append(executed, "hook1")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hook2 := func(ctx *HookContext) error {
|
||||||
|
executed = append(executed, "hook2")
|
||||||
|
return fmt.Errorf("hook2 error")
|
||||||
|
}
|
||||||
|
|
||||||
|
hook3 := func(ctx *HookContext) error {
|
||||||
|
executed = append(executed, "hook3")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeUpdate, hook1)
|
||||||
|
registry.Register(BeforeUpdate, hook2)
|
||||||
|
registry.Register(BeforeUpdate, hook3)
|
||||||
|
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Schema: "test",
|
||||||
|
Entity: "users",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := registry.Execute(BeforeUpdate, ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error from hook execution")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(executed) != 2 {
|
||||||
|
t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed))
|
||||||
|
}
|
||||||
|
|
||||||
|
if executed[0] != "hook1" || executed[1] != "hook2" {
|
||||||
|
t.Errorf("Unexpected execution order: %v", executed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookDataModification tests modifying data in hooks
|
||||||
|
func TestHookDataModification(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
modifyHook := func(ctx *HookContext) error {
|
||||||
|
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||||
|
dataMap["modified"] = true
|
||||||
|
ctx.Data = dataMap
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeCreate, modifyHook)
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"name": "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Schema: "test",
|
||||||
|
Entity: "users",
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := registry.Execute(BeforeCreate, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modifiedData := ctx.Data.(map[string]interface{})
|
||||||
|
if !modifiedData["modified"].(bool) {
|
||||||
|
t.Error("Data was not modified by hook")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRegisterMultiple tests registering a hook for multiple types
|
||||||
|
func TestRegisterMultiple(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
called := 0
|
||||||
|
hook := func(ctx *HookContext) error {
|
||||||
|
called++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.RegisterMultiple([]HookType{
|
||||||
|
BeforeRead,
|
||||||
|
BeforeCreate,
|
||||||
|
BeforeUpdate,
|
||||||
|
}, hook)
|
||||||
|
|
||||||
|
if registry.Count(BeforeRead) != 1 {
|
||||||
|
t.Error("Hook not registered for BeforeRead")
|
||||||
|
}
|
||||||
|
if registry.Count(BeforeCreate) != 1 {
|
||||||
|
t.Error("Hook not registered for BeforeCreate")
|
||||||
|
}
|
||||||
|
if registry.Count(BeforeUpdate) != 1 {
|
||||||
|
t.Error("Hook not registered for BeforeUpdate")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Schema: "test",
|
||||||
|
Entity: "users",
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Execute(BeforeRead, ctx)
|
||||||
|
registry.Execute(BeforeCreate, ctx)
|
||||||
|
registry.Execute(BeforeUpdate, ctx)
|
||||||
|
|
||||||
|
if called != 3 {
|
||||||
|
t.Errorf("Expected hook to be called 3 times, got %d", called)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClearHooks tests clearing hooks
|
||||||
|
func TestClearHooks(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
hook := func(ctx *HookContext) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeRead, hook)
|
||||||
|
registry.Register(BeforeCreate, hook)
|
||||||
|
|
||||||
|
if registry.Count(BeforeRead) != 1 {
|
||||||
|
t.Error("Hook not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Clear(BeforeRead)
|
||||||
|
|
||||||
|
if registry.Count(BeforeRead) != 0 {
|
||||||
|
t.Error("Hook not cleared")
|
||||||
|
}
|
||||||
|
|
||||||
|
if registry.Count(BeforeCreate) != 1 {
|
||||||
|
t.Error("Wrong hook was cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClearAllHooks tests clearing all hooks
|
||||||
|
func TestClearAllHooks(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
hook := func(ctx *HookContext) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeRead, hook)
|
||||||
|
registry.Register(BeforeCreate, hook)
|
||||||
|
registry.Register(BeforeUpdate, hook)
|
||||||
|
|
||||||
|
registry.ClearAll()
|
||||||
|
|
||||||
|
if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 {
|
||||||
|
t.Error("Not all hooks were cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHasHooks tests checking if hooks exist
|
||||||
|
func TestHasHooks(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
if registry.HasHooks(BeforeRead) {
|
||||||
|
t.Error("Should not have hooks initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
hook := func(ctx *HookContext) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeRead, hook)
|
||||||
|
|
||||||
|
if !registry.HasHooks(BeforeRead) {
|
||||||
|
t.Error("Should have hooks after registration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAllHookTypes tests getting all registered hook types
|
||||||
|
func TestGetAllHookTypes(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
hook := func(ctx *HookContext) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeRead, hook)
|
||||||
|
registry.Register(BeforeCreate, hook)
|
||||||
|
registry.Register(AfterUpdate, hook)
|
||||||
|
|
||||||
|
types := registry.GetAllHookTypes()
|
||||||
|
|
||||||
|
if len(types) != 3 {
|
||||||
|
t.Errorf("Expected 3 hook types, got %d", len(types))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all expected types are present
|
||||||
|
expectedTypes := map[HookType]bool{
|
||||||
|
BeforeRead: true,
|
||||||
|
BeforeCreate: true,
|
||||||
|
AfterUpdate: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, hookType := range types {
|
||||||
|
if !expectedTypes[hookType] {
|
||||||
|
t.Errorf("Unexpected hook type: %s", hookType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHookContextHandler tests that hooks can access the handler
|
||||||
|
func TestHookContextHandler(t *testing.T) {
|
||||||
|
registry := NewHookRegistry()
|
||||||
|
|
||||||
|
var capturedHandler *Handler
|
||||||
|
|
||||||
|
hook := func(ctx *HookContext) error {
|
||||||
|
// Verify that the handler is accessible from the context
|
||||||
|
if ctx.Handler == nil {
|
||||||
|
return fmt.Errorf("handler is nil in hook context")
|
||||||
|
}
|
||||||
|
capturedHandler = ctx.Handler
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Register(BeforeRead, hook)
|
||||||
|
|
||||||
|
// Create a mock handler
|
||||||
|
handler := &Handler{
|
||||||
|
hooks: registry,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &HookContext{
|
||||||
|
Context: context.Background(),
|
||||||
|
Handler: handler,
|
||||||
|
Schema: "test",
|
||||||
|
Entity: "users",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := registry.Execute(BeforeRead, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Hook execution failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturedHandler == nil {
|
||||||
|
t.Error("Handler was not captured from hook context")
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturedHandler != handler {
|
||||||
|
t.Error("Captured handler does not match original handler")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user