mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 22:36:23 +00:00
590 lines
14 KiB
Go
590 lines
14 KiB
Go
package funcspec
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
|
)
|
|
|
|
// TestNewHookRegistry tests hook registry creation
|
|
func TestNewHookRegistry(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
if registry == nil {
|
|
t.Fatal("Expected registry to be created, got nil")
|
|
}
|
|
|
|
if registry.hooks == nil {
|
|
t.Error("Expected hooks map to be initialized")
|
|
}
|
|
}
|
|
|
|
// TestRegisterHook tests registering a single hook
|
|
func TestRegisterHook(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
hookCalled := false
|
|
testHook := func(ctx *HookContext) error {
|
|
hookCalled = true
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, testHook)
|
|
|
|
if !registry.HasHooks(BeforeQuery) {
|
|
t.Error("Expected hook to be registered")
|
|
}
|
|
|
|
if registry.Count(BeforeQuery) != 1 {
|
|
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
|
|
}
|
|
|
|
// Execute the hook
|
|
ctx := &HookContext{}
|
|
err := registry.Execute(BeforeQuery, ctx)
|
|
if err != nil {
|
|
t.Errorf("Hook execution failed: %v", err)
|
|
}
|
|
|
|
if !hookCalled {
|
|
t.Error("Expected hook to be called")
|
|
}
|
|
}
|
|
|
|
// TestRegisterMultipleHooks tests registering multiple hooks for same type
|
|
func TestRegisterMultipleHooks(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
callOrder := []int{}
|
|
|
|
hook1 := func(ctx *HookContext) error {
|
|
callOrder = append(callOrder, 1)
|
|
return nil
|
|
}
|
|
|
|
hook2 := func(ctx *HookContext) error {
|
|
callOrder = append(callOrder, 2)
|
|
return nil
|
|
}
|
|
|
|
hook3 := func(ctx *HookContext) error {
|
|
callOrder = append(callOrder, 3)
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, hook1)
|
|
registry.Register(BeforeQuery, hook2)
|
|
registry.Register(BeforeQuery, hook3)
|
|
|
|
if registry.Count(BeforeQuery) != 3 {
|
|
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
|
|
}
|
|
|
|
// Execute hooks
|
|
ctx := &HookContext{}
|
|
err := registry.Execute(BeforeQuery, ctx)
|
|
if err != nil {
|
|
t.Errorf("Hook execution failed: %v", err)
|
|
}
|
|
|
|
// Verify hooks were called in order
|
|
if len(callOrder) != 3 {
|
|
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
|
|
}
|
|
|
|
for i, expected := range []int{1, 2, 3} {
|
|
if callOrder[i] != expected {
|
|
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
|
|
func TestRegisterMultipleHookTypes(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
callCount := 0
|
|
testHook := func(ctx *HookContext) error {
|
|
callCount++
|
|
return nil
|
|
}
|
|
|
|
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
|
|
registry.RegisterMultiple(hookTypes, testHook)
|
|
|
|
// Verify hook is registered for all types
|
|
for _, hookType := range hookTypes {
|
|
if !registry.HasHooks(hookType) {
|
|
t.Errorf("Expected hook to be registered for %s", hookType)
|
|
}
|
|
|
|
if registry.Count(hookType) != 1 {
|
|
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
|
|
}
|
|
}
|
|
|
|
// Execute each hook type
|
|
ctx := &HookContext{}
|
|
for _, hookType := range hookTypes {
|
|
if err := registry.Execute(hookType, ctx); err != nil {
|
|
t.Errorf("Hook execution failed for %s: %v", hookType, err)
|
|
}
|
|
}
|
|
|
|
if callCount != 3 {
|
|
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
|
|
}
|
|
}
|
|
|
|
// TestHookError tests hook error handling
|
|
func TestHookError(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
expectedError := fmt.Errorf("test error")
|
|
errorHook := func(ctx *HookContext) error {
|
|
return expectedError
|
|
}
|
|
|
|
registry.Register(BeforeQuery, errorHook)
|
|
|
|
ctx := &HookContext{}
|
|
err := registry.Execute(BeforeQuery, ctx)
|
|
|
|
if err == nil {
|
|
t.Error("Expected error from hook, got nil")
|
|
}
|
|
|
|
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
|
|
t.Errorf("Expected error message to contain hook error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestHookAbort tests hook abort functionality
|
|
func TestHookAbort(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
abortHook := func(ctx *HookContext) error {
|
|
ctx.Abort = true
|
|
ctx.AbortMessage = "Operation aborted by hook"
|
|
ctx.AbortCode = 403
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, abortHook)
|
|
|
|
ctx := &HookContext{}
|
|
err := registry.Execute(BeforeQuery, ctx)
|
|
|
|
if err == nil {
|
|
t.Error("Expected error when hook aborts, got nil")
|
|
}
|
|
|
|
if !ctx.Abort {
|
|
t.Error("Expected Abort to be true")
|
|
}
|
|
|
|
if ctx.AbortMessage != "Operation aborted by hook" {
|
|
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
|
|
}
|
|
|
|
if ctx.AbortCode != 403 {
|
|
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
|
|
}
|
|
}
|
|
|
|
// TestHookChainWithError tests that hook chain stops on first error
|
|
func TestHookChainWithError(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
callOrder := []int{}
|
|
|
|
hook1 := func(ctx *HookContext) error {
|
|
callOrder = append(callOrder, 1)
|
|
return nil
|
|
}
|
|
|
|
hook2 := func(ctx *HookContext) error {
|
|
callOrder = append(callOrder, 2)
|
|
return fmt.Errorf("error in hook 2")
|
|
}
|
|
|
|
hook3 := func(ctx *HookContext) error {
|
|
callOrder = append(callOrder, 3)
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, hook1)
|
|
registry.Register(BeforeQuery, hook2)
|
|
registry.Register(BeforeQuery, hook3)
|
|
|
|
ctx := &HookContext{}
|
|
err := registry.Execute(BeforeQuery, ctx)
|
|
|
|
if err == nil {
|
|
t.Error("Expected error from hook chain")
|
|
}
|
|
|
|
// Only first two hooks should have been called
|
|
if len(callOrder) != 2 {
|
|
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
|
|
}
|
|
|
|
if callOrder[0] != 1 || callOrder[1] != 2 {
|
|
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
|
|
}
|
|
}
|
|
|
|
// TestClearHooks tests clearing hooks
|
|
func TestClearHooks(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
testHook := func(ctx *HookContext) error {
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, testHook)
|
|
registry.Register(AfterQuery, testHook)
|
|
|
|
if !registry.HasHooks(BeforeQuery) {
|
|
t.Error("Expected BeforeQuery hook to be registered")
|
|
}
|
|
|
|
registry.Clear(BeforeQuery)
|
|
|
|
if registry.HasHooks(BeforeQuery) {
|
|
t.Error("Expected BeforeQuery hooks to be cleared")
|
|
}
|
|
|
|
if !registry.HasHooks(AfterQuery) {
|
|
t.Error("Expected AfterQuery hook to still be registered")
|
|
}
|
|
}
|
|
|
|
// TestClearAllHooks tests clearing all hooks
|
|
func TestClearAllHooks(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
testHook := func(ctx *HookContext) error {
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, testHook)
|
|
registry.Register(AfterQuery, testHook)
|
|
registry.Register(BeforeSQLExec, testHook)
|
|
|
|
registry.ClearAll()
|
|
|
|
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
|
|
t.Error("Expected all hooks to be cleared")
|
|
}
|
|
}
|
|
|
|
// TestGetAllHookTypes tests getting all registered hook types
|
|
func TestGetAllHookTypes(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
testHook := func(ctx *HookContext) error {
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, testHook)
|
|
registry.Register(AfterQuery, testHook)
|
|
|
|
types := registry.GetAllHookTypes()
|
|
|
|
if len(types) != 2 {
|
|
t.Errorf("Expected 2 hook types, got %d", len(types))
|
|
}
|
|
|
|
// Verify the types are present
|
|
foundBefore := false
|
|
foundAfter := false
|
|
for _, hookType := range types {
|
|
if hookType == BeforeQuery {
|
|
foundBefore = true
|
|
}
|
|
if hookType == AfterQuery {
|
|
foundAfter = true
|
|
}
|
|
}
|
|
|
|
if !foundBefore || !foundAfter {
|
|
t.Error("Expected both BeforeQuery and AfterQuery hook types")
|
|
}
|
|
}
|
|
|
|
// TestHookContextModification tests that hooks can modify the context
|
|
func TestHookContextModification(t *testing.T) {
|
|
registry := NewHookRegistry()
|
|
|
|
// Hook that modifies SQL query
|
|
modifyHook := func(ctx *HookContext) error {
|
|
ctx.SQLQuery = "SELECT * FROM modified_table"
|
|
ctx.Variables["new_var"] = "new_value"
|
|
return nil
|
|
}
|
|
|
|
registry.Register(BeforeQuery, modifyHook)
|
|
|
|
ctx := &HookContext{
|
|
SQLQuery: "SELECT * FROM original_table",
|
|
Variables: make(map[string]interface{}),
|
|
}
|
|
|
|
err := registry.Execute(BeforeQuery, ctx)
|
|
if err != nil {
|
|
t.Errorf("Hook execution failed: %v", err)
|
|
}
|
|
|
|
if ctx.SQLQuery != "SELECT * FROM modified_table" {
|
|
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
|
|
}
|
|
|
|
if ctx.Variables["new_var"] != "new_value" {
|
|
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
|
|
}
|
|
}
|
|
|
|
// TestExampleHooks tests the example hooks
|
|
func TestExampleLoggingHook(t *testing.T) {
|
|
ctx := &HookContext{
|
|
Context: context.Background(),
|
|
SQLQuery: "SELECT * FROM test",
|
|
UserContext: &security.UserContext{
|
|
UserName: "testuser",
|
|
},
|
|
}
|
|
|
|
err := ExampleLoggingHook(ctx)
|
|
if err != nil {
|
|
t.Errorf("ExampleLoggingHook failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestExampleSecurityHook(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sqlQuery string
|
|
userID int
|
|
shouldAbort bool
|
|
}{
|
|
{
|
|
name: "Admin accessing sensitive table",
|
|
sqlQuery: "SELECT * FROM sensitive_table",
|
|
userID: 1,
|
|
shouldAbort: false,
|
|
},
|
|
{
|
|
name: "Non-admin accessing sensitive table",
|
|
sqlQuery: "SELECT * FROM sensitive_table",
|
|
userID: 2,
|
|
shouldAbort: true,
|
|
},
|
|
{
|
|
name: "Non-admin accessing normal table",
|
|
sqlQuery: "SELECT * FROM users",
|
|
userID: 2,
|
|
shouldAbort: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := &HookContext{
|
|
Context: context.Background(),
|
|
SQLQuery: tt.sqlQuery,
|
|
UserContext: &security.UserContext{
|
|
UserID: tt.userID,
|
|
},
|
|
}
|
|
|
|
_ = ExampleSecurityHook(ctx)
|
|
|
|
if tt.shouldAbort {
|
|
if !ctx.Abort {
|
|
t.Error("Expected security hook to abort operation")
|
|
}
|
|
if ctx.AbortCode != 403 {
|
|
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
|
|
}
|
|
} else {
|
|
if ctx.Abort {
|
|
t.Error("Expected security hook not to abort operation")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExampleResultFilterHook(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
userID int
|
|
result interface{}
|
|
validate func(t *testing.T, result interface{})
|
|
}{
|
|
{
|
|
name: "Admin user - no filtering",
|
|
userID: 1,
|
|
result: map[string]interface{}{
|
|
"id": 1,
|
|
"name": "Test",
|
|
"password": "secret",
|
|
},
|
|
validate: func(t *testing.T, result interface{}) {
|
|
m := result.(map[string]interface{})
|
|
if _, exists := m["password"]; !exists {
|
|
t.Error("Expected password field to remain for admin")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "Regular user - sensitive fields removed",
|
|
userID: 2,
|
|
result: map[string]interface{}{
|
|
"id": 1,
|
|
"name": "Test",
|
|
"password": "secret",
|
|
"ssn": "123-45-6789",
|
|
},
|
|
validate: func(t *testing.T, result interface{}) {
|
|
m := result.(map[string]interface{})
|
|
if _, exists := m["password"]; exists {
|
|
t.Error("Expected password field to be removed")
|
|
}
|
|
if _, exists := m["ssn"]; exists {
|
|
t.Error("Expected ssn field to be removed")
|
|
}
|
|
if _, exists := m["name"]; !exists {
|
|
t.Error("Expected name field to remain")
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "Regular user - list results filtered",
|
|
userID: 2,
|
|
result: []map[string]interface{}{
|
|
{"id": 1, "name": "User 1", "password": "secret1"},
|
|
{"id": 2, "name": "User 2", "password": "secret2"},
|
|
},
|
|
validate: func(t *testing.T, result interface{}) {
|
|
list := result.([]map[string]interface{})
|
|
for _, m := range list {
|
|
if _, exists := m["password"]; exists {
|
|
t.Error("Expected password field to be removed from list")
|
|
}
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := &HookContext{
|
|
Context: context.Background(),
|
|
Result: tt.result,
|
|
UserContext: &security.UserContext{
|
|
UserID: tt.userID,
|
|
},
|
|
}
|
|
|
|
err := ExampleResultFilterHook(ctx)
|
|
if err != nil {
|
|
t.Errorf("Hook failed: %v", err)
|
|
}
|
|
|
|
if tt.validate != nil {
|
|
tt.validate(t, ctx.Result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExampleAuditHook(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
|
|
ctx := &HookContext{
|
|
Context: context.Background(),
|
|
Request: req,
|
|
UserContext: &security.UserContext{
|
|
UserID: 123,
|
|
UserName: "testuser",
|
|
},
|
|
}
|
|
|
|
err := ExampleAuditHook(ctx)
|
|
if err != nil {
|
|
t.Errorf("ExampleAuditHook failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestExampleErrorHandlingHook(t *testing.T) {
|
|
ctx := &HookContext{
|
|
Context: context.Background(),
|
|
SQLQuery: "SELECT * FROM test",
|
|
Error: fmt.Errorf("test error"),
|
|
UserContext: &security.UserContext{
|
|
UserName: "testuser",
|
|
},
|
|
}
|
|
|
|
err := ExampleErrorHandlingHook(ctx)
|
|
if err != nil {
|
|
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestHookIntegrationWithHandler tests hooks integrated with the handler
|
|
func TestHookIntegrationWithHandler(t *testing.T) {
|
|
db := &MockDatabase{
|
|
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
|
queryDB := &MockDatabase{
|
|
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
|
rows := dest.(*[]map[string]interface{})
|
|
*rows = []map[string]interface{}{
|
|
{"id": float64(1), "name": "Test User"},
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
return fn(queryDB)
|
|
},
|
|
}
|
|
|
|
handler := NewHandler(db)
|
|
|
|
// Register a hook that modifies the SQL query
|
|
hookCalled := false
|
|
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
|
|
hookCalled = true
|
|
// Verify we can access context data
|
|
if ctx.SQLQuery == "" {
|
|
t.Error("Expected SQL query to be set")
|
|
}
|
|
if ctx.UserContext == nil {
|
|
t.Error("Expected user context to be set")
|
|
}
|
|
return nil
|
|
})
|
|
|
|
// Execute a query
|
|
req := createTestRequest("GET", "/test", nil, nil, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
|
|
handlerFunc(w, req)
|
|
|
|
if !hookCalled {
|
|
t.Error("Expected hook to be called during query execution")
|
|
}
|
|
|
|
if w.Code != 200 {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
}
|