A lot more tests

This commit is contained in:
Hein
2025-12-08 16:56:48 +02:00
parent aeae9d7e0c
commit c9eaf84125
28 changed files with 4036 additions and 444 deletions

View File

@@ -0,0 +1,138 @@
package resolvespec
import (
"context"
"testing"
)
func TestContextOperations(t *testing.T) {
ctx := context.Background()
// Test Schema
t.Run("WithSchema and GetSchema", func(t *testing.T) {
ctx = WithSchema(ctx, "public")
schema := GetSchema(ctx)
if schema != "public" {
t.Errorf("Expected schema 'public', got '%s'", schema)
}
})
// Test Entity
t.Run("WithEntity and GetEntity", func(t *testing.T) {
ctx = WithEntity(ctx, "users")
entity := GetEntity(ctx)
if entity != "users" {
t.Errorf("Expected entity 'users', got '%s'", entity)
}
})
// Test TableName
t.Run("WithTableName and GetTableName", func(t *testing.T) {
ctx = WithTableName(ctx, "public.users")
tableName := GetTableName(ctx)
if tableName != "public.users" {
t.Errorf("Expected tableName 'public.users', got '%s'", tableName)
}
})
// Test Model
t.Run("WithModel and GetModel", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
ctx = WithModel(ctx, model)
retrieved := GetModel(ctx)
if retrieved == nil {
t.Error("Expected model to be retrieved, got nil")
}
if retrievedModel, ok := retrieved.(*TestModel); ok {
if retrievedModel.ID != 1 || retrievedModel.Name != "test" {
t.Errorf("Expected model with ID=1 and Name='test', got ID=%d, Name='%s'", retrievedModel.ID, retrievedModel.Name)
}
} else {
t.Error("Retrieved model is not of expected type")
}
})
// Test ModelPtr
t.Run("WithModelPtr and GetModelPtr", func(t *testing.T) {
type TestModel struct {
ID int
}
models := []*TestModel{}
ctx = WithModelPtr(ctx, &models)
retrieved := GetModelPtr(ctx)
if retrieved == nil {
t.Error("Expected modelPtr to be retrieved, got nil")
}
})
// Test WithRequestData
t.Run("WithRequestData", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
modelPtr := &[]*TestModel{}
ctx = WithRequestData(ctx, "test_schema", "test_entity", "test_schema.test_entity", model, modelPtr)
if GetSchema(ctx) != "test_schema" {
t.Errorf("Expected schema 'test_schema', got '%s'", GetSchema(ctx))
}
if GetEntity(ctx) != "test_entity" {
t.Errorf("Expected entity 'test_entity', got '%s'", GetEntity(ctx))
}
if GetTableName(ctx) != "test_schema.test_entity" {
t.Errorf("Expected tableName 'test_schema.test_entity', got '%s'", GetTableName(ctx))
}
if GetModel(ctx) == nil {
t.Error("Expected model to be set")
}
if GetModelPtr(ctx) == nil {
t.Error("Expected modelPtr to be set")
}
})
}
func TestEmptyContext(t *testing.T) {
ctx := context.Background()
t.Run("GetSchema with empty context", func(t *testing.T) {
schema := GetSchema(ctx)
if schema != "" {
t.Errorf("Expected empty schema, got '%s'", schema)
}
})
t.Run("GetEntity with empty context", func(t *testing.T) {
entity := GetEntity(ctx)
if entity != "" {
t.Errorf("Expected empty entity, got '%s'", entity)
}
})
t.Run("GetTableName with empty context", func(t *testing.T) {
tableName := GetTableName(ctx)
if tableName != "" {
t.Errorf("Expected empty tableName, got '%s'", tableName)
}
})
t.Run("GetModel with empty context", func(t *testing.T) {
model := GetModel(ctx)
if model != nil {
t.Errorf("Expected nil model, got %v", model)
}
})
t.Run("GetModelPtr with empty context", func(t *testing.T) {
modelPtr := GetModelPtr(ctx)
if modelPtr != nil {
t.Errorf("Expected nil modelPtr, got %v", modelPtr)
}
})
}

View File

@@ -0,0 +1,367 @@
package resolvespec
import (
"reflect"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestNewHandler(t *testing.T) {
// Note: We can't create a real handler without actual DB and registry
// But we can test that the constructor doesn't panic with nil values
handler := NewHandler(nil, nil)
if handler == nil {
t.Error("Expected handler to be created, got nil")
}
if handler.hooks == nil {
t.Error("Expected hooks registry to be initialized")
}
}
func TestHandlerHooks(t *testing.T) {
handler := NewHandler(nil, nil)
hooks := handler.Hooks()
if hooks == nil {
t.Error("Expected hooks registry, got nil")
}
}
func TestSetFallbackHandler(t *testing.T) {
handler := NewHandler(nil, nil)
// We can't directly call the fallback without mocks, but we can verify it's set
handler.SetFallbackHandler(func(w common.ResponseWriter, r common.Request, params map[string]string) {
// Fallback handler implementation
})
if handler.fallbackHandler == nil {
t.Error("Expected fallback handler to be set")
}
}
func TestGetDatabase(t *testing.T) {
handler := NewHandler(nil, nil)
db := handler.GetDatabase()
// Should return nil since we passed nil
if db != nil {
t.Error("Expected nil database")
}
}
func TestParseTableName(t *testing.T) {
handler := NewHandler(nil, nil)
tests := []struct {
name string
fullTableName string
expectedSchema string
expectedTable string
}{
{
name: "Table with schema",
fullTableName: "public.users",
expectedSchema: "public",
expectedTable: "users",
},
{
name: "Table without schema",
fullTableName: "users",
expectedSchema: "",
expectedTable: "users",
},
{
name: "Multiple dots (use last)",
fullTableName: "db.public.users",
expectedSchema: "db.public",
expectedTable: "users",
},
{
name: "Empty string",
fullTableName: "",
expectedSchema: "",
expectedTable: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema, table := handler.parseTableName(tt.fullTableName)
if schema != tt.expectedSchema {
t.Errorf("Expected schema '%s', got '%s'", tt.expectedSchema, schema)
}
if table != tt.expectedTable {
t.Errorf("Expected table '%s', got '%s'", tt.expectedTable, table)
}
})
}
}
func TestGetColumnType(t *testing.T) {
tests := []struct {
name string
field reflect.StructField
expectedType string
}{
{
name: "String field",
field: reflect.StructField{
Name: "Name",
Type: reflect.TypeOf(""),
},
expectedType: "string",
},
{
name: "Int field",
field: reflect.StructField{
Name: "Count",
Type: reflect.TypeOf(int(0)),
},
expectedType: "integer",
},
{
name: "Int32 field",
field: reflect.StructField{
Name: "ID",
Type: reflect.TypeOf(int32(0)),
},
expectedType: "integer",
},
{
name: "Int64 field",
field: reflect.StructField{
Name: "BigID",
Type: reflect.TypeOf(int64(0)),
},
expectedType: "bigint",
},
{
name: "Float32 field",
field: reflect.StructField{
Name: "Price",
Type: reflect.TypeOf(float32(0)),
},
expectedType: "float",
},
{
name: "Float64 field",
field: reflect.StructField{
Name: "Amount",
Type: reflect.TypeOf(float64(0)),
},
expectedType: "double",
},
{
name: "Bool field",
field: reflect.StructField{
Name: "Active",
Type: reflect.TypeOf(false),
},
expectedType: "boolean",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
colType := getColumnType(tt.field)
if colType != tt.expectedType {
t.Errorf("Expected column type '%s', got '%s'", tt.expectedType, colType)
}
})
}
}
func TestIsNullable(t *testing.T) {
tests := []struct {
name string
field reflect.StructField
nullable bool
}{
{
name: "Pointer type is nullable",
field: reflect.StructField{
Name: "Name",
Type: reflect.TypeOf((*string)(nil)),
},
nullable: true,
},
{
name: "Non-pointer type without explicit 'not null' tag",
field: reflect.StructField{
Name: "ID",
Type: reflect.TypeOf(int(0)),
},
nullable: true, // isNullable returns true if there's no explicit "not null" tag
},
{
name: "Field with 'not null' tag is not nullable",
field: reflect.StructField{
Name: "Email",
Type: reflect.TypeOf(""),
Tag: `gorm:"not null"`,
},
nullable: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNullable(tt.field)
if result != tt.nullable {
t.Errorf("Expected nullable=%v, got %v", tt.nullable, result)
}
})
}
}
func TestToSnakeCase(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
input: "UserID",
expected: "user_id",
},
{
input: "DepartmentName",
expected: "department_name",
},
{
input: "ID",
expected: "id",
},
{
input: "HTTPServer",
expected: "http_server",
},
{
input: "createdAt",
expected: "created_at",
},
{
input: "name",
expected: "name",
},
{
input: "",
expected: "",
},
{
input: "A",
expected: "a",
},
{
input: "APIKey",
expected: "api_key",
},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := toSnakeCase(tt.input)
if result != tt.expected {
t.Errorf("toSnakeCase(%q) = %q, expected %q", tt.input, result, tt.expected)
}
})
}
}
func TestExtractTagValue(t *testing.T) {
handler := NewHandler(nil, nil)
tests := []struct {
name string
tag string
key string
expected string
}{
{
name: "Extract foreignKey",
tag: "foreignKey:UserID;references:ID",
key: "foreignKey",
expected: "UserID",
},
{
name: "Extract references",
tag: "foreignKey:UserID;references:ID",
key: "references",
expected: "ID",
},
{
name: "Key not found",
tag: "foreignKey:UserID;references:ID",
key: "notfound",
expected: "",
},
{
name: "Empty tag",
tag: "",
key: "foreignKey",
expected: "",
},
{
name: "Single value",
tag: "many2many:user_roles",
key: "many2many",
expected: "user_roles",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.extractTagValue(tt.tag, tt.key)
if result != tt.expected {
t.Errorf("extractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected)
}
})
}
}
func TestApplyFilter(t *testing.T) {
// Note: Without a real database, we can't fully test query execution
// But we can test that the method exists
_ = NewHandler(nil, nil)
// The applyFilter method exists and can be tested with actual queries
// but requires database setup which is beyond unit test scope
t.Log("applyFilter method exists and is used in handler operations")
}
func TestShouldUseNestedProcessor(t *testing.T) {
handler := NewHandler(nil, nil)
tests := []struct {
name string
data map[string]interface{}
expected bool
}{
{
name: "Has _request field",
data: map[string]interface{}{
"_request": "nested",
"name": "test",
},
expected: true,
},
{
name: "No special fields",
data: map[string]interface{}{
"name": "test",
"email": "test@example.com",
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Note: Without a real model, we can't fully test this
// But we can verify the function exists
result := handler.shouldUseNestedProcessor(tt.data, nil)
// The actual result depends on the model structure
_ = result
})
}
}

View File

@@ -0,0 +1,400 @@
package resolvespec
import (
"context"
"fmt"
"testing"
)
func TestHookRegistry(t *testing.T) {
registry := NewHookRegistry()
// Test registering a hook
called := false
hook := func(ctx *HookContext) error {
called = true
return nil
}
registry.Register(BeforeRead, hook)
if registry.Count(BeforeRead) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead))
}
// Test executing a hook
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !called {
t.Error("Hook was not called")
}
}
func TestHookExecutionOrder(t *testing.T) {
registry := NewHookRegistry()
order := []int{}
hook1 := func(ctx *HookContext) error {
order = append(order, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
order = append(order, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
order = append(order, 3)
return nil
}
registry.Register(BeforeCreate, hook1)
registry.Register(BeforeCreate, hook2)
registry.Register(BeforeCreate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if len(order) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(order))
}
if order[0] != 1 || order[1] != 2 || order[2] != 3 {
t.Errorf("Hooks executed in wrong order: %v", order)
}
}
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
executed := []string{}
hook1 := func(ctx *HookContext) error {
executed = append(executed, "hook1")
return nil
}
hook2 := func(ctx *HookContext) error {
executed = append(executed, "hook2")
return fmt.Errorf("hook2 error")
}
hook3 := func(ctx *HookContext) error {
executed = append(executed, "hook3")
return nil
}
registry.Register(BeforeUpdate, hook1)
registry.Register(BeforeUpdate, hook2)
registry.Register(BeforeUpdate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeUpdate, ctx)
if err == nil {
t.Error("Expected error from hook execution")
}
if len(executed) != 2 {
t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed))
}
if executed[0] != "hook1" || executed[1] != "hook2" {
t.Errorf("Unexpected execution order: %v", executed)
}
}
func TestHookDataModification(t *testing.T) {
registry := NewHookRegistry()
modifyHook := func(ctx *HookContext) error {
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
dataMap["modified"] = true
ctx.Data = dataMap
}
return nil
}
registry.Register(BeforeCreate, modifyHook)
data := map[string]interface{}{
"name": "test",
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
Data: data,
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
modifiedData := ctx.Data.(map[string]interface{})
if !modifiedData["modified"].(bool) {
t.Error("Data was not modified by hook")
}
}
func TestRegisterMultiple(t *testing.T) {
registry := NewHookRegistry()
called := 0
hook := func(ctx *HookContext) error {
called++
return nil
}
registry.RegisterMultiple([]HookType{
BeforeRead,
BeforeCreate,
BeforeUpdate,
}, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered for BeforeRead")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Hook not registered for BeforeCreate")
}
if registry.Count(BeforeUpdate) != 1 {
t.Error("Hook not registered for BeforeUpdate")
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
registry.Execute(BeforeRead, ctx)
registry.Execute(BeforeCreate, ctx)
registry.Execute(BeforeUpdate, ctx)
if called != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", called)
}
}
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered")
}
registry.Clear(BeforeRead)
if registry.Count(BeforeRead) != 0 {
t.Error("Hook not cleared")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Wrong hook was cleared")
}
}
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(BeforeUpdate, hook)
registry.ClearAll()
if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 {
t.Error("Not all hooks were cleared")
}
}
func TestHasHooks(t *testing.T) {
registry := NewHookRegistry()
if registry.HasHooks(BeforeRead) {
t.Error("Should not have hooks initially")
}
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
if !registry.HasHooks(BeforeRead) {
t.Error("Should have hooks after registration")
}
}
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(AfterUpdate, hook)
types := registry.GetAllHookTypes()
if len(types) != 3 {
t.Errorf("Expected 3 hook types, got %d", len(types))
}
// Verify all expected types are present
expectedTypes := map[HookType]bool{
BeforeRead: true,
BeforeCreate: true,
AfterUpdate: true,
}
for _, hookType := range types {
if !expectedTypes[hookType] {
t.Errorf("Unexpected hook type: %s", hookType)
}
}
}
func TestHookContextHandler(t *testing.T) {
registry := NewHookRegistry()
var capturedHandler *Handler
hook := func(ctx *HookContext) error {
if ctx.Handler == nil {
return fmt.Errorf("handler is nil in hook context")
}
capturedHandler = ctx.Handler
return nil
}
registry.Register(BeforeRead, hook)
handler := &Handler{
hooks: registry,
}
ctx := &HookContext{
Context: context.Background(),
Handler: handler,
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if capturedHandler == nil {
t.Error("Handler was not captured from hook context")
}
if capturedHandler != handler {
t.Error("Captured handler does not match original handler")
}
}
func TestHookAbort(t *testing.T) {
registry := NewHookRegistry()
abortHook := func(ctx *HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "Operation aborted by hook"
ctx.AbortCode = 403
return nil
}
registry.Register(BeforeCreate, abortHook)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeCreate, ctx)
if err == nil {
t.Error("Expected error when hook sets Abort=true")
}
if err.Error() != "operation aborted by hook: Operation aborted by hook" {
t.Errorf("Expected abort error message, got: %v", err)
}
}
func TestHookTypes(t *testing.T) {
// Test all hook type constants
hookTypes := []HookType{
BeforeRead,
AfterRead,
BeforeCreate,
AfterCreate,
BeforeUpdate,
AfterUpdate,
BeforeDelete,
AfterDelete,
BeforeScan,
}
for _, hookType := range hookTypes {
if string(hookType) == "" {
t.Errorf("Hook type should not be empty: %v", hookType)
}
}
}
func TestExecuteWithNoHooks(t *testing.T) {
registry := NewHookRegistry()
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
// Executing with no registered hooks should not cause an error
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Execute should not fail with no hooks, got: %v", err)
}
}

View File

@@ -0,0 +1,508 @@
// +build integration
package resolvespec
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/gorilla/mux"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test models
type TestUser struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
Age int `json:"age"`
Active bool `gorm:"default:true" json:"active"`
CreatedAt time.Time `json:"created_at"`
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
}
func (TestUser) TableName() string {
return "test_users"
}
type TestPost struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"not null" json:"user_id"`
Title string `gorm:"not null" json:"title"`
Content string `json:"content"`
Published bool `gorm:"default:false" json:"published"`
CreatedAt time.Time `json:"created_at"`
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
}
func (TestPost) TableName() string {
return "test_posts"
}
type TestComment struct {
ID uint `gorm:"primaryKey" json:"id"`
PostID uint `gorm:"not null" json:"post_id"`
Content string `gorm:"not null" json:"content"`
CreatedAt time.Time `json:"created_at"`
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
}
func (TestComment) TableName() string {
return "test_comments"
}
// Test helper functions
func setupTestDB(t *testing.T) *gorm.DB {
// Get connection string from environment or use default
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
dsn = "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Skipf("Skipping integration test: database not available: %v", err)
return nil
}
// Run migrations
err = db.AutoMigrate(&TestUser{}, &TestPost{}, &TestComment{})
if err != nil {
t.Skipf("Skipping integration test: failed to migrate database: %v", err)
return nil
}
return db
}
func cleanupTestDB(t *testing.T, db *gorm.DB) {
// Clean up test data
db.Exec("TRUNCATE TABLE test_comments CASCADE")
db.Exec("TRUNCATE TABLE test_posts CASCADE")
db.Exec("TRUNCATE TABLE test_users CASCADE")
}
func createTestData(t *testing.T, db *gorm.DB) {
users := []TestUser{
{Name: "John Doe", Email: "john@example.com", Age: 30, Active: true},
{Name: "Jane Smith", Email: "jane@example.com", Age: 25, Active: true},
{Name: "Bob Johnson", Email: "bob@example.com", Age: 35, Active: false},
}
for i := range users {
if err := db.Create(&users[i]).Error; err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
}
posts := []TestPost{
{UserID: users[0].ID, Title: "First Post", Content: "Hello World", Published: true},
{UserID: users[0].ID, Title: "Second Post", Content: "More content", Published: true},
{UserID: users[1].ID, Title: "Jane's Post", Content: "Jane's content", Published: false},
}
for i := range posts {
if err := db.Create(&posts[i]).Error; err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
}
comments := []TestComment{
{PostID: posts[0].ID, Content: "Great post!"},
{PostID: posts[0].ID, Content: "Thanks for sharing"},
{PostID: posts[1].ID, Content: "Interesting"},
}
for i := range comments {
if err := db.Create(&comments[i]).Error; err != nil {
t.Fatalf("Failed to create test comment: %v", err)
}
}
}
// Integration tests
func TestIntegration_CreateOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Create a new user
requestBody := map[string]interface{}{
"operation": "create",
"data": map[string]interface{}{
"name": "Test User",
"email": "test@example.com",
"age": 28,
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v. Error: %v", response.Success, response.Error)
}
// Verify user was created
var user TestUser
if err := db.Where("email = ?", "test@example.com").First(&user).Error; err != nil {
t.Errorf("Failed to find created user: %v", err)
}
if user.Name != "Test User" {
t.Errorf("Expected name 'Test User', got '%s'", user.Name)
}
}
func TestIntegration_ReadOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Read all users
requestBody := map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"limit": 10,
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
if response.Metadata == nil {
t.Fatal("Expected metadata, got nil")
}
if response.Metadata.Total != 3 {
t.Errorf("Expected 3 users, got %d", response.Metadata.Total)
}
}
func TestIntegration_ReadWithFilters(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Read users with age > 25
requestBody := map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"filters": []map[string]interface{}{
{
"column": "age",
"operator": "gt",
"value": 25,
},
},
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Should return 2 users (John: 30, Bob: 35)
if response.Metadata.Total != 2 {
t.Errorf("Expected 2 filtered users, got %d", response.Metadata.Total)
}
}
func TestIntegration_UpdateOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Get user ID
var user TestUser
db.Where("email = ?", "john@example.com").First(&user)
// Update user
requestBody := map[string]interface{}{
"operation": "update",
"data": map[string]interface{}{
"id": user.ID,
"age": 31,
"name": "John Doe Updated",
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/public/test_users/%d", user.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
// Verify update
var updatedUser TestUser
db.First(&updatedUser, user.ID)
if updatedUser.Age != 31 {
t.Errorf("Expected age 31, got %d", updatedUser.Age)
}
if updatedUser.Name != "John Doe Updated" {
t.Errorf("Expected name 'John Doe Updated', got '%s'", updatedUser.Name)
}
}
func TestIntegration_DeleteOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Get user ID
var user TestUser
db.Where("email = ?", "bob@example.com").First(&user)
// Delete user
requestBody := map[string]interface{}{
"operation": "delete",
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/public/test_users/%d", user.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
// Verify deletion
var count int64
db.Model(&TestUser{}).Where("id = ?", user.ID).Count(&count)
if count != 0 {
t.Errorf("Expected user to be deleted, but found %d records", count)
}
}
func TestIntegration_MetadataOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Get metadata
requestBody := map[string]interface{}{
"operation": "meta",
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Check that metadata includes columns
// The response.Data is an interface{}, we need to unmarshal it properly
dataBytes, _ := json.Marshal(response.Data)
var metadata common.TableMetadata
if err := json.Unmarshal(dataBytes, &metadata); err != nil {
t.Fatalf("Failed to unmarshal metadata: %v. Raw data: %+v", err, response.Data)
}
if len(metadata.Columns) == 0 {
t.Error("Expected metadata to contain columns")
}
// Verify some expected columns
hasID := false
hasName := false
hasEmail := false
for _, col := range metadata.Columns {
if col.Name == "id" {
hasID = true
if !col.IsPrimary {
t.Error("Expected 'id' column to be primary key")
}
}
if col.Name == "name" {
hasName = true
}
if col.Name == "email" {
hasEmail = true
}
}
if !hasID || !hasName || !hasEmail {
t.Error("Expected metadata to contain 'id', 'name', and 'email' columns")
}
}
func TestIntegration_ReadWithPreload(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
handler.RegisterModel("public", "test_posts", TestPost{})
handler.RegisterModel("public", "test_comments", TestComment{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Read users with posts preloaded
requestBody := map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"filters": []map[string]interface{}{
{
"column": "email",
"operator": "eq",
"value": "john@example.com",
},
},
"preload": []map[string]interface{}{
{"relation": "posts"},
},
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Verify posts are preloaded
dataBytes, _ := json.Marshal(response.Data)
var users []TestUser
json.Unmarshal(dataBytes, &users)
if len(users) == 0 {
t.Fatal("Expected at least one user")
}
if len(users[0].Posts) == 0 {
t.Error("Expected posts to be preloaded")
}
}

View File

@@ -0,0 +1,114 @@
package resolvespec
import (
"testing"
)
func TestParseModelName(t *testing.T) {
tests := []struct {
name string
fullName string
expectedSchema string
expectedEntity string
}{
{
name: "Model with schema",
fullName: "public.users",
expectedSchema: "public",
expectedEntity: "users",
},
{
name: "Model without schema",
fullName: "users",
expectedSchema: "",
expectedEntity: "users",
},
{
name: "Model with custom schema",
fullName: "myschema.products",
expectedSchema: "myschema",
expectedEntity: "products",
},
{
name: "Empty string",
fullName: "",
expectedSchema: "",
expectedEntity: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema, entity := parseModelName(tt.fullName)
if schema != tt.expectedSchema {
t.Errorf("Expected schema '%s', got '%s'", tt.expectedSchema, schema)
}
if entity != tt.expectedEntity {
t.Errorf("Expected entity '%s', got '%s'", tt.expectedEntity, entity)
}
})
}
}
func TestBuildRoutePath(t *testing.T) {
tests := []struct {
name string
schema string
entity string
expectedPath string
}{
{
name: "With schema",
schema: "public",
entity: "users",
expectedPath: "/public/users",
},
{
name: "Without schema",
schema: "",
entity: "users",
expectedPath: "/users",
},
{
name: "Custom schema",
schema: "admin",
entity: "logs",
expectedPath: "/admin/logs",
},
{
name: "Empty entity with schema",
schema: "public",
entity: "",
expectedPath: "/public/",
},
{
name: "Both empty",
schema: "",
entity: "",
expectedPath: "/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := buildRoutePath(tt.schema, tt.entity)
if path != tt.expectedPath {
t.Errorf("Expected path '%s', got '%s'", tt.expectedPath, path)
}
})
}
}
func TestNewStandardMuxRouter(t *testing.T) {
router := NewStandardMuxRouter()
if router == nil {
t.Error("Expected router to be created, got nil")
}
}
func TestNewStandardBunRouter(t *testing.T) {
router := NewStandardBunRouter()
if router == nil {
t.Error("Expected router to be created, got nil")
}
}

View File

@@ -0,0 +1,181 @@
package restheadspec
import (
"context"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestContextOperations(t *testing.T) {
ctx := context.Background()
// Test Schema
t.Run("WithSchema and GetSchema", func(t *testing.T) {
ctx = WithSchema(ctx, "public")
schema := GetSchema(ctx)
if schema != "public" {
t.Errorf("Expected schema 'public', got '%s'", schema)
}
})
// Test Entity
t.Run("WithEntity and GetEntity", func(t *testing.T) {
ctx = WithEntity(ctx, "users")
entity := GetEntity(ctx)
if entity != "users" {
t.Errorf("Expected entity 'users', got '%s'", entity)
}
})
// Test TableName
t.Run("WithTableName and GetTableName", func(t *testing.T) {
ctx = WithTableName(ctx, "public.users")
tableName := GetTableName(ctx)
if tableName != "public.users" {
t.Errorf("Expected tableName 'public.users', got '%s'", tableName)
}
})
// Test Model
t.Run("WithModel and GetModel", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
ctx = WithModel(ctx, model)
retrieved := GetModel(ctx)
if retrieved == nil {
t.Error("Expected model to be retrieved, got nil")
}
if retrievedModel, ok := retrieved.(*TestModel); ok {
if retrievedModel.ID != 1 || retrievedModel.Name != "test" {
t.Errorf("Expected model with ID=1 and Name='test', got ID=%d, Name='%s'", retrievedModel.ID, retrievedModel.Name)
}
} else {
t.Error("Retrieved model is not of expected type")
}
})
// Test ModelPtr
t.Run("WithModelPtr and GetModelPtr", func(t *testing.T) {
type TestModel struct {
ID int
}
models := []*TestModel{}
ctx = WithModelPtr(ctx, &models)
retrieved := GetModelPtr(ctx)
if retrieved == nil {
t.Error("Expected modelPtr to be retrieved, got nil")
}
})
// Test Options
t.Run("WithOptions and GetOptions", func(t *testing.T) {
limit := 10
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Limit: &limit,
},
}
ctx = WithOptions(ctx, options)
retrieved := GetOptions(ctx)
if retrieved == nil {
t.Error("Expected options to be retrieved, got nil")
return
}
if retrieved.Limit == nil || *retrieved.Limit != 10 {
t.Error("Expected options to be retrieved with limit=10")
}
})
// Test WithRequestData
t.Run("WithRequestData", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
modelPtr := &[]*TestModel{}
limit := 20
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Limit: &limit,
},
}
ctx = WithRequestData(ctx, "test_schema", "test_entity", "test_schema.test_entity", model, modelPtr, options)
if GetSchema(ctx) != "test_schema" {
t.Errorf("Expected schema 'test_schema', got '%s'", GetSchema(ctx))
}
if GetEntity(ctx) != "test_entity" {
t.Errorf("Expected entity 'test_entity', got '%s'", GetEntity(ctx))
}
if GetTableName(ctx) != "test_schema.test_entity" {
t.Errorf("Expected tableName 'test_schema.test_entity', got '%s'", GetTableName(ctx))
}
if GetModel(ctx) == nil {
t.Error("Expected model to be set")
}
if GetModelPtr(ctx) == nil {
t.Error("Expected modelPtr to be set")
}
opts := GetOptions(ctx)
if opts == nil {
t.Error("Expected options to be set")
return
}
if opts.Limit == nil || *opts.Limit != 20 {
t.Error("Expected options to be set with limit=20")
}
})
}
func TestEmptyContext(t *testing.T) {
ctx := context.Background()
t.Run("GetSchema with empty context", func(t *testing.T) {
schema := GetSchema(ctx)
if schema != "" {
t.Errorf("Expected empty schema, got '%s'", schema)
}
})
t.Run("GetEntity with empty context", func(t *testing.T) {
entity := GetEntity(ctx)
if entity != "" {
t.Errorf("Expected empty entity, got '%s'", entity)
}
})
t.Run("GetTableName with empty context", func(t *testing.T) {
tableName := GetTableName(ctx)
if tableName != "" {
t.Errorf("Expected empty tableName, got '%s'", tableName)
}
})
t.Run("GetModel with empty context", func(t *testing.T) {
model := GetModel(ctx)
if model != nil {
t.Errorf("Expected nil model, got %v", model)
}
})
t.Run("GetModelPtr with empty context", func(t *testing.T) {
modelPtr := GetModelPtr(ctx)
if modelPtr != nil {
t.Errorf("Expected nil modelPtr, got %v", modelPtr)
}
})
t.Run("GetOptions with empty context", func(t *testing.T) {
options := GetOptions(ctx)
// GetOptions returns nil when context is empty
if options != nil {
t.Errorf("Expected nil options in empty context, got %v", options)
}
})
}

View File

@@ -226,8 +226,20 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
return
}
metadata := h.generateMetadata(schema, entity, model)
h.sendResponse(w, metadata, nil)
// Parse request options from headers to get response format settings
options := h.parseOptionsFromHeaders(r, model)
tableMetadata := h.generateMetadata(schema, entity, model)
// Send with formatted response to respect DetailApi/SimpleApi/Syncfusion format
// Create empty metadata for response wrapper
responseMetadata := &common.Metadata{
Total: 0,
Filtered: 0,
Count: 0,
Limit: 0,
Offset: 0,
}
h.sendFormattedResponse(w, tableMetadata, responseMetadata, options)
}
// handleMeta processes meta operation requests

View File

@@ -1,3 +1,5 @@
// +build !integration
package restheadspec
import (

View File

@@ -0,0 +1,46 @@
package restheadspec
import (
"testing"
)
func TestDecodeHeaderValue(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Normal string",
input: "test",
expected: "test",
},
{
name: "String without encoding prefix",
input: "hello world",
expected: "hello world",
},
{
name: "Empty string",
input: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := decodeHeaderValue(tt.input)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
// - parseSelectFields
// - parseFieldFilter
// - mapSearchOperator
// - parseCommaSeparated
// - parseSorting
// These are tested indirectly through parseOptionsFromHeaders in query_params_test.go

View File

@@ -0,0 +1,556 @@
// +build integration
package restheadspec
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/gorilla/mux"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test models
type TestUser struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
Age int `json:"age"`
Active bool `gorm:"default:true" json:"active"`
CreatedAt time.Time `json:"created_at"`
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
}
func (TestUser) TableName() string {
return "test_users"
}
type TestPost struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"not null" json:"user_id"`
Title string `gorm:"not null" json:"title"`
Content string `json:"content"`
Published bool `gorm:"default:false" json:"published"`
CreatedAt time.Time `json:"created_at"`
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
}
func (TestPost) TableName() string {
return "test_posts"
}
type TestComment struct {
ID uint `gorm:"primaryKey" json:"id"`
PostID uint `gorm:"not null" json:"post_id"`
Content string `gorm:"not null" json:"content"`
CreatedAt time.Time `json:"created_at"`
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
}
func (TestComment) TableName() string {
return "test_comments"
}
// Test helper functions
func setupTestDB(t *testing.T) *gorm.DB {
// Get connection string from environment or use default
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
dsn = "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5434 sslmode=disable"
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Skipf("Skipping integration test: database not available: %v", err)
return nil
}
// Run migrations
err = db.AutoMigrate(&TestUser{}, &TestPost{}, &TestComment{})
if err != nil {
t.Skipf("Skipping integration test: failed to migrate database: %v", err)
return nil
}
return db
}
func cleanupTestDB(t *testing.T, db *gorm.DB) {
// Clean up test data
db.Exec("TRUNCATE TABLE test_comments CASCADE")
db.Exec("TRUNCATE TABLE test_posts CASCADE")
db.Exec("TRUNCATE TABLE test_users CASCADE")
}
func createTestData(t *testing.T, db *gorm.DB) {
users := []TestUser{
{Name: "John Doe", Email: "john@example.com", Age: 30, Active: true},
{Name: "Jane Smith", Email: "jane@example.com", Age: 25, Active: true},
{Name: "Bob Johnson", Email: "bob@example.com", Age: 35, Active: false},
}
for i := range users {
if err := db.Create(&users[i]).Error; err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
}
posts := []TestPost{
{UserID: users[0].ID, Title: "First Post", Content: "Hello World", Published: true},
{UserID: users[0].ID, Title: "Second Post", Content: "More content", Published: true},
{UserID: users[1].ID, Title: "Jane's Post", Content: "Jane's content", Published: false},
}
for i := range posts {
if err := db.Create(&posts[i]).Error; err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
}
comments := []TestComment{
{PostID: posts[0].ID, Content: "Great post!"},
{PostID: posts[0].ID, Content: "Thanks for sharing"},
{PostID: posts[1].ID, Content: "Interesting"},
}
for i := range comments {
if err := db.Create(&comments[i]).Error; err != nil {
t.Fatalf("Failed to create test comment: %v", err)
}
}
}
// Integration tests
func TestIntegration_GetAllUsers(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
req := httptest.NewRequest("GET", "/public/test_users", nil)
req.Header.Set("X-DetailApi", "true")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
if response.Metadata == nil {
t.Fatal("Expected metadata, got nil")
}
if response.Metadata.Total != 3 {
t.Errorf("Expected 3 users, got %d", response.Metadata.Total)
}
}
func TestIntegration_GetUsersWithFilters(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Filter: age > 25
req := httptest.NewRequest("GET", "/public/test_users", nil)
req.Header.Set("X-DetailApi", "true")
req.Header.Set("X-SearchOp-Gt-Age", "25")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Should return 2 users (John: 30, Bob: 35)
if response.Metadata.Total != 2 {
t.Errorf("Expected 2 filtered users, got %d", response.Metadata.Total)
}
}
func TestIntegration_GetUsersWithPagination(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
req := httptest.NewRequest("GET", "/public/test_users", nil)
req.Header.Set("X-DetailApi", "true")
req.Header.Set("X-Limit", "2")
req.Header.Set("X-Offset", "1")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Total should still be 3, but we're only retrieving 2 records starting from offset 1
if response.Metadata.Total != 3 {
t.Errorf("Expected total 3 users, got %d", response.Metadata.Total)
}
if response.Metadata.Limit != 2 {
t.Errorf("Expected limit 2, got %d", response.Metadata.Limit)
}
if response.Metadata.Offset != 1 {
t.Errorf("Expected offset 1, got %d", response.Metadata.Offset)
}
}
func TestIntegration_GetUsersWithSorting(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Sort by age descending
req := httptest.NewRequest("GET", "/public/test_users", nil)
req.Header.Set("X-DetailApi", "true")
req.Header.Set("X-Sort", "-age")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Parse data to verify sort order
dataBytes, _ := json.Marshal(response.Data)
var users []TestUser
json.Unmarshal(dataBytes, &users)
if len(users) < 3 {
t.Fatal("Expected at least 3 users")
}
// Check that users are sorted by age descending (Bob:35, John:30, Jane:25)
if users[0].Age < users[1].Age || users[1].Age < users[2].Age {
t.Error("Expected users to be sorted by age descending")
}
}
func TestIntegration_GetUsersWithColumnsSelection(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
req := httptest.NewRequest("GET", "/public/test_users", nil)
req.Header.Set("X-DetailApi", "true")
req.Header.Set("X-Columns", "id,name,email")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Verify data was returned (column selection doesn't affect metadata count)
if response.Metadata.Total != 3 {
t.Errorf("Expected 3 users, got %d", response.Metadata.Total)
}
}
func TestIntegration_GetUsersWithPreload(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
req := httptest.NewRequest("GET", "/public/test_users?x-fieldfilter-email=john@example.com", nil)
req.Header.Set("X-DetailApi", "true")
req.Header.Set("X-Preload", "Posts")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Verify posts are preloaded
dataBytes, _ := json.Marshal(response.Data)
var users []TestUser
json.Unmarshal(dataBytes, &users)
if len(users) == 0 {
t.Fatal("Expected at least one user")
}
if len(users[0].Posts) == 0 {
t.Error("Expected posts to be preloaded")
}
}
func TestIntegration_GetMetadata(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
req := httptest.NewRequest("GET", "/public/test_users/metadata", nil)
req.Header.Set("X-DetailApi", "true")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v. Body: %s, Error: %v", response.Success, w.Body.String(), response.Error)
}
// Check that metadata includes columns
metadataBytes, _ := json.Marshal(response.Data)
var metadata common.TableMetadata
json.Unmarshal(metadataBytes, &metadata)
if len(metadata.Columns) == 0 {
t.Error("Expected metadata to contain columns")
}
// Verify some expected columns
hasID := false
hasName := false
hasEmail := false
for _, col := range metadata.Columns {
if col.Name == "id" {
hasID = true
}
if col.Name == "name" {
hasName = true
}
if col.Name == "email" {
hasEmail = true
}
}
if !hasID || !hasName || !hasEmail {
t.Error("Expected metadata to contain 'id', 'name', and 'email' columns")
}
}
func TestIntegration_OptionsRequest(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
req := httptest.NewRequest("OPTIONS", "/public/test_users", nil)
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check CORS headers
if w.Header().Get("Access-Control-Allow-Origin") == "" {
t.Error("Expected Access-Control-Allow-Origin header")
}
if w.Header().Get("Access-Control-Allow-Methods") == "" {
t.Error("Expected Access-Control-Allow-Methods header")
}
}
func TestIntegration_QueryParamsOverHeaders(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Query params should override headers
req := httptest.NewRequest("GET", "/public/test_users?x-limit=1", nil)
req.Header.Set("X-DetailApi", "true")
req.Header.Set("X-Limit", "10")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Query param should win (limit=1)
if response.Metadata.Limit != 1 {
t.Errorf("Expected limit 1 from query param, got %d", response.Metadata.Limit)
}
}
func TestIntegration_GetSingleRecord(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.registry.RegisterModel("public.test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Get first user ID
var user TestUser
db.Where("email = ?", "john@example.com").First(&user)
req := httptest.NewRequest("GET", fmt.Sprintf("/public/test_users/%d", user.ID), nil)
req.Header.Set("X-DetailApi", "true")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Verify it's a single record
dataBytes, _ := json.Marshal(response.Data)
var resultUser TestUser
json.Unmarshal(dataBytes, &resultUser)
if resultUser.Email != "john@example.com" {
t.Errorf("Expected user with email 'john@example.com', got '%s'", resultUser.Email)
}
}

View File

@@ -128,15 +128,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
}
// Register routes for this entity
// IMPORTANT: Register more specific routes before wildcard routes
// GET, POST for /{schema}/{entity}
muxRouter.Handle(entityPath, entityHandler).Methods("GET", "POST")
// GET for metadata (using HandleGet) - MUST be registered before /{id} route
muxRouter.Handle(metadataPath, metadataHandler).Methods("GET")
// GET, PUT, PATCH, DELETE, POST for /{schema}/{entity}/{id}
muxRouter.Handle(entityWithIDPath, entityWithIDHandler).Methods("GET", "PUT", "PATCH", "DELETE", "POST")
// GET for metadata (using HandleGet)
muxRouter.Handle(metadataPath, metadataHandler).Methods("GET")
// OPTIONS for CORS preflight - returns metadata
muxRouter.Handle(entityPath, optionsEntityHandler).Methods("OPTIONS")
muxRouter.Handle(entityWithIDPath, optionsEntityWithIDHandler).Methods("OPTIONS")

View File

@@ -0,0 +1,114 @@
package restheadspec
import (
"testing"
)
func TestParseModelName(t *testing.T) {
tests := []struct {
name string
fullName string
expectedSchema string
expectedEntity string
}{
{
name: "Model with schema",
fullName: "public.users",
expectedSchema: "public",
expectedEntity: "users",
},
{
name: "Model without schema",
fullName: "users",
expectedSchema: "",
expectedEntity: "users",
},
{
name: "Model with custom schema",
fullName: "myschema.products",
expectedSchema: "myschema",
expectedEntity: "products",
},
{
name: "Empty string",
fullName: "",
expectedSchema: "",
expectedEntity: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema, entity := parseModelName(tt.fullName)
if schema != tt.expectedSchema {
t.Errorf("Expected schema '%s', got '%s'", tt.expectedSchema, schema)
}
if entity != tt.expectedEntity {
t.Errorf("Expected entity '%s', got '%s'", tt.expectedEntity, entity)
}
})
}
}
func TestBuildRoutePath(t *testing.T) {
tests := []struct {
name string
schema string
entity string
expectedPath string
}{
{
name: "With schema",
schema: "public",
entity: "users",
expectedPath: "/public/users",
},
{
name: "Without schema",
schema: "",
entity: "users",
expectedPath: "/users",
},
{
name: "Custom schema",
schema: "admin",
entity: "logs",
expectedPath: "/admin/logs",
},
{
name: "Empty entity with schema",
schema: "public",
entity: "",
expectedPath: "/public/",
},
{
name: "Both empty",
schema: "",
entity: "",
expectedPath: "/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := buildRoutePath(tt.schema, tt.entity)
if path != tt.expectedPath {
t.Errorf("Expected path '%s', got '%s'", tt.expectedPath, path)
}
})
}
}
func TestNewStandardMuxRouter(t *testing.T) {
router := NewStandardMuxRouter()
if router == nil {
t.Error("Expected router to be created, got nil")
}
}
func TestNewStandardBunRouter(t *testing.T) {
router := NewStandardBunRouter()
if router == nil {
t.Error("Expected router to be created, got nil")
}
}

View File

@@ -0,0 +1,440 @@
# Security Features: Blacklist & Rate Limit Inspection
## IP Blacklist
The IP blacklist middleware allows you to block specific IP addresses or CIDR ranges from accessing your application.
### Basic Usage
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Create blacklist (UseProxy=true if behind a proxy)
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
UseProxy: true, // Checks X-Forwarded-For and X-Real-IP headers
})
// Block individual IP
blacklist.BlockIP("192.168.1.100", "Suspicious activity detected")
// Block entire CIDR range
blacklist.BlockCIDR("10.0.0.0/8", "Private network blocked")
// Apply middleware
http.Handle("/api/", blacklist.Middleware(yourHandler))
```
### Managing Blacklist
```go
// Unblock an IP
blacklist.UnblockIP("192.168.1.100")
// Unblock a CIDR range
blacklist.UnblockCIDR("10.0.0.0/8")
// Get all blacklisted IPs and CIDRs
ips, cidrs := blacklist.GetBlacklist()
fmt.Printf("Blocked IPs: %v\n", ips)
fmt.Printf("Blocked CIDRs: %v\n", cidrs)
// Check if specific IP is blocked
blocked, reason := blacklist.IsBlocked("192.168.1.100")
if blocked {
fmt.Printf("IP blocked: %s\n", reason)
}
```
### Blacklist Statistics Endpoint
Expose blacklist statistics via HTTP:
```go
// Add stats endpoint
http.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
```
**Example Response:**
```json
{
"blocked_ips": ["192.168.1.100", "192.168.1.101"],
"blocked_cidrs": ["10.0.0.0/8"],
"total_ips": 2,
"total_cidrs": 1
}
```
### Integration Example
```go
func main() {
// Create blacklist
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
UseProxy: true,
})
// Block known malicious IPs
blacklist.BlockIP("203.0.113.1", "Known scanner")
blacklist.BlockCIDR("198.51.100.0/24", "Spam network")
// Create your router
mux := http.NewServeMux()
// Protected routes
mux.Handle("/api/", blacklist.Middleware(apiHandler))
// Admin endpoint to manage blacklist
mux.HandleFunc("/admin/block-ip", func(w http.ResponseWriter, r *http.Request) {
ip := r.URL.Query().Get("ip")
reason := r.URL.Query().Get("reason")
if err := blacklist.BlockIP(ip, reason); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "Blocked %s: %s", ip, reason)
})
// Stats endpoint
mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
http.ListenAndServe(":8080", mux)
}
```
---
## Rate Limit Inspection
Monitor and inspect rate limit status per IP address in real-time.
### Basic Usage
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Create rate limiter (10 req/sec, burst of 20)
rateLimiter := middleware.NewRateLimiter(10, 20)
// Apply middleware
http.Handle("/api/", rateLimiter.Middleware(yourHandler))
```
### Programmatic Inspection
```go
// Get all tracked IPs
trackedIPs := rateLimiter.GetTrackedIPs()
fmt.Printf("Currently tracking %d IPs\n", len(trackedIPs))
// Get rate limit info for specific IP
info := rateLimiter.GetRateLimitInfo("192.168.1.1")
fmt.Printf("IP: %s\n", info.IP)
fmt.Printf("Tokens Remaining: %.2f\n", info.TokensRemaining)
fmt.Printf("Limit: %.2f req/sec\n", info.Limit)
fmt.Printf("Burst: %d\n", info.Burst)
// Get info for all tracked IPs
allInfo := rateLimiter.GetAllRateLimitInfo()
for _, info := range allInfo {
fmt.Printf("%s: %.2f tokens remaining\n", info.IP, info.TokensRemaining)
}
```
### Rate Limit Stats Endpoint
Expose rate limit statistics via HTTP:
```go
// Add stats endpoint
http.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler())
```
**Example Response (all IPs):**
```json
{
"total_tracked_ips": 3,
"rate_limit_config": {
"requests_per_second": 10,
"burst": 20
},
"tracked_ips": [
{
"ip": "192.168.1.1",
"tokens_remaining": 15.5,
"limit": 10,
"burst": 20
},
{
"ip": "192.168.1.2",
"tokens_remaining": 18.2,
"limit": 10,
"burst": 20
}
]
}
```
**Example Response (specific IP):**
```bash
GET /admin/rate-limit-stats?ip=192.168.1.1
```
```json
{
"ip": "192.168.1.1",
"tokens_remaining": 15.5,
"limit": 10,
"burst": 20
}
```
### Complete Integration Example
```go
package main
import (
"encoding/json"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
)
func main() {
// Create rate limiter
rateLimiter := middleware.NewRateLimiter(10, 20)
// Create blacklist
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
UseProxy: true,
})
mux := http.NewServeMux()
// API handler with both middlewares (blacklist first, then rate limit)
apiHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"message": "Success",
})
})
// Apply middleware chain: blacklist -> rate limit -> handler
mux.Handle("/api/", blacklist.Middleware(rateLimiter.Middleware(apiHandler)))
// Admin endpoints
mux.Handle("/admin/rate-limit-stats", rateLimiter.StatsHandler())
mux.Handle("/admin/blacklist-stats", blacklist.StatsHandler())
// Custom monitoring endpoint
mux.HandleFunc("/admin/monitor", func(w http.ResponseWriter, r *http.Request) {
// Get rate limit stats
rateLimitInfo := rateLimiter.GetAllRateLimitInfo()
// Get blacklist stats
blockedIPs, blockedCIDRs := blacklist.GetBlacklist()
response := map[string]interface{}{
"rate_limits": rateLimitInfo,
"blacklist": map[string]interface{}{
"ips": blockedIPs,
"cidrs": blockedCIDRs,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
})
// Dynamic blacklist management
mux.HandleFunc("/admin/block", func(w http.ResponseWriter, r *http.Request) {
ip := r.URL.Query().Get("ip")
reason := r.URL.Query().Get("reason")
if ip == "" {
http.Error(w, "IP required", http.StatusBadRequest)
return
}
if err := blacklist.BlockIP(ip, reason); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
fmt.Fprintf(w, "Blocked %s: %s", ip, reason)
})
mux.HandleFunc("/admin/unblock", func(w http.ResponseWriter, r *http.Request) {
ip := r.URL.Query().Get("ip")
if ip == "" {
http.Error(w, "IP required", http.StatusBadRequest)
return
}
blacklist.UnblockIP(ip)
fmt.Fprintf(w, "Unblocked %s", ip)
})
// Auto-block IPs that exceed rate limit
mux.HandleFunc("/admin/auto-block-heavy-users", func(w http.ResponseWriter, r *http.Request) {
blocked := 0
for _, info := range rateLimiter.GetAllRateLimitInfo() {
// If tokens are very low, IP is making many requests
if info.TokensRemaining < 1.0 {
blacklist.BlockIP(info.IP, "Exceeded rate limit")
blocked++
}
}
fmt.Fprintf(w, "Blocked %d IPs exceeding rate limits", blocked)
})
fmt.Println("Server starting on :8080")
fmt.Println("Rate limit stats: http://localhost:8080/admin/rate-limit-stats")
fmt.Println("Blacklist stats: http://localhost:8080/admin/blacklist-stats")
http.ListenAndServe(":8080", mux)
}
```
---
## Monitoring Dashboard Example
Create a simple monitoring page:
```go
mux.HandleFunc("/admin/dashboard", func(w http.ResponseWriter, r *http.Request) {
html := `
<html>
<head>
<title>Security Dashboard</title>
<script>
async function loadStats() {
const rateLimitRes = await fetch('/admin/rate-limit-stats');
const rateLimitData = await rateLimitRes.json();
const blacklistRes = await fetch('/admin/blacklist-stats');
const blacklistData = await blacklistRes.json();
document.getElementById('rate-limit').innerHTML =
JSON.stringify(rateLimitData, null, 2);
document.getElementById('blacklist').innerHTML =
JSON.stringify(blacklistData, null, 2);
}
setInterval(loadStats, 5000); // Refresh every 5 seconds
loadStats();
</script>
</head>
<body>
<h1>Security Dashboard</h1>
<h2>Rate Limits</h2>
<pre id="rate-limit">Loading...</pre>
<h2>Blacklist</h2>
<pre id="blacklist">Loading...</pre>
</body>
</html>
`
w.Header().Set("Content-Type", "text/html")
w.Write([]byte(html))
})
```
---
## Best Practices
### 1. Proxy Configuration
Always set `UseProxy: true` when running behind a reverse proxy (nginx, Cloudflare, etc.):
```go
blacklist := middleware.NewIPBlacklist(middleware.BlacklistConfig{
UseProxy: true, // Checks X-Forwarded-For headers
})
```
### 2. Middleware Order
Apply blacklist before rate limiting to save resources:
```go
// Correct order: blacklist -> rate limit -> handler
handler := blacklist.Middleware(
rateLimiter.Middleware(yourHandler)
)
```
### 3. Secure Admin Endpoints
Protect admin endpoints with authentication:
```go
mux.Handle("/admin/", authMiddleware(adminHandler))
```
### 4. Monitoring
Set up alerts when:
- Many IPs are being rate limited
- Blacklist grows too large
- Specific IPs are repeatedly blocked
### 5. Dynamic Response
Automatically block IPs that consistently exceed rate limits:
```go
// Check every minute
ticker := time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
for _, info := range rateLimiter.GetAllRateLimitInfo() {
if info.TokensRemaining < 0.5 {
blacklist.BlockIP(info.IP, "Automated block: rate limit exceeded")
}
}
}
}()
```
### 6. CIDR for Network Blocks
Use CIDR ranges to block entire networks efficiently:
```go
// Block entire subnets
blacklist.BlockCIDR("10.0.0.0/8", "Private network")
blacklist.BlockCIDR("192.168.0.0/16", "Local network")
```
---
## API Reference
### IPBlacklist
#### Methods
- `BlockIP(ip, reason string) error` - Block a single IP address
- `BlockCIDR(cidr, reason string) error` - Block a CIDR range
- `UnblockIP(ip string)` - Remove IP from blacklist
- `UnblockCIDR(cidr string)` - Remove CIDR from blacklist
- `IsBlocked(ip string) (blocked bool, reason string)` - Check if IP is blocked
- `GetBlacklist() (ips, cidrs []string)` - Get all blocked IPs and CIDRs
- `Middleware(next http.Handler) http.Handler` - HTTP middleware
- `StatsHandler() http.Handler` - HTTP handler for statistics
### RateLimiter
#### Methods
- `GetTrackedIPs() []string` - Get all tracked IP addresses
- `GetRateLimitInfo(ip string) *RateLimitInfo` - Get info for specific IP
- `GetAllRateLimitInfo() []*RateLimitInfo` - Get info for all tracked IPs
- `Middleware(next http.Handler) http.Handler` - HTTP middleware
- `StatsHandler() http.Handler` - HTTP handler for statistics
#### RateLimitInfo Structure
```go
type RateLimitInfo struct {
IP string `json:"ip"`
TokensRemaining float64 `json:"tokens_remaining"`
Limit float64 `json:"limit"`
Burst int `json:"burst"`
}
```