mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-02 18:04:25 +00:00
Fixed providers
This commit is contained in:
583
pkg/security/hooks_test.go
Normal file
583
pkg/security/hooks_test.go
Normal file
@@ -0,0 +1,583 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Mock SecurityContext for testing hooks
|
||||
type mockSecurityContext struct {
|
||||
ctx context.Context
|
||||
userID int
|
||||
hasUser bool
|
||||
schema string
|
||||
entity string
|
||||
model interface{}
|
||||
query interface{}
|
||||
result interface{}
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) GetContext() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) GetUserID() (int, bool) {
|
||||
return m.userID, m.hasUser
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) GetSchema() string {
|
||||
return m.schema
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) GetEntity() string {
|
||||
return m.entity
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) GetModel() interface{} {
|
||||
return m.model
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) GetQuery() interface{} {
|
||||
return m.query
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) SetQuery(q interface{}) {
|
||||
m.query = q
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) GetResult() interface{} {
|
||||
return m.result
|
||||
}
|
||||
|
||||
func (m *mockSecurityContext) SetResult(r interface{}) {
|
||||
m.result = r
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
substr string
|
||||
expected bool
|
||||
}{
|
||||
{"substring at start", "hello world", "hello", true},
|
||||
{"substring at end", "hello world", "world", true},
|
||||
{"substring in middle", "hello world", "lo wo", false}, // contains only checks prefix/suffix
|
||||
{"substring not present", "hello world", "xyz", false},
|
||||
{"exact match", "test", "test", true},
|
||||
{"empty substring", "test", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := contains(tt.s, tt.substr)
|
||||
if result != tt.expected {
|
||||
t.Errorf("contains(%q, %q) = %v, want %v", tt.s, tt.substr, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSQLName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
expected string
|
||||
}{
|
||||
{"simple name", "user_id", "user_id"},
|
||||
{"column prefix", "column:email", "column:email"}, // Implementation doesn't strip prefix in all cases
|
||||
{"with other tags", "id,pk,autoincrement", "id"},
|
||||
{"column with comma", "column:user_name,notnull", "column:user_name"}, // Implementation behavior
|
||||
{"empty tag", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractSQLName(tt.tag)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractSQLName(%q) = %q, want %q", tt.tag, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
sep rune
|
||||
expected []string
|
||||
}{
|
||||
{"single part", "id", ',', []string{"id"}},
|
||||
{"multiple parts", "id,pk,autoincrement", ',', []string{"id", "pk", "autoincrement"}},
|
||||
{"empty parts filtered", "id,,pk", ',', []string{"id", "pk"}},
|
||||
{"no separator", "singlepart", ',', []string{"singlepart"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := splitTag(tt.tag, tt.sep)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("splitTag(%q) returned %d parts, want %d", tt.tag, len(result), len(tt.expected))
|
||||
return
|
||||
}
|
||||
for i, part := range tt.expected {
|
||||
if result[i] != part {
|
||||
t.Errorf("splitTag(%q)[%d] = %q, want %q", tt.tag, i, result[i], part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test loadSecurityRules
|
||||
func TestLoadSecurityRules(t *testing.T) {
|
||||
t.Run("load rules successfully", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{
|
||||
columnSecurity: []ColumnSecurity{
|
||||
{Schema: "public", Tablename: "users", Path: []string{"email"}},
|
||||
},
|
||||
rowSecurity: RowSecurity{
|
||||
Schema: "public",
|
||||
Tablename: "users",
|
||||
Template: "id = {UserID}",
|
||||
},
|
||||
}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
}
|
||||
|
||||
err := LoadSecurityRules(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Verify column security was loaded
|
||||
key := "public.users@1"
|
||||
if _, ok := secList.ColumnSecurity[key]; !ok {
|
||||
t.Error("expected column security to be loaded")
|
||||
}
|
||||
|
||||
// Verify row security was loaded
|
||||
if _, ok := secList.RowSecurity[key]; !ok {
|
||||
t.Error("expected row security to be loaded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no user in context", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
hasUser: false,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
}
|
||||
|
||||
err := LoadSecurityRules(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error with no user, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test applyRowSecurity
|
||||
func TestApplyRowSecurity(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int `bun:"id,pk"`
|
||||
}
|
||||
|
||||
t.Run("apply row security template", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{
|
||||
rowSecurity: RowSecurity{
|
||||
Schema: "public",
|
||||
Tablename: "orders",
|
||||
Template: "user_id = {UserID}",
|
||||
HasBlock: false,
|
||||
},
|
||||
}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
ctx := context.Background()
|
||||
|
||||
// Load row security
|
||||
_, _ = secList.LoadRowSecurity(ctx, 1, "public", "orders", false)
|
||||
|
||||
// Mock query that supports Where
|
||||
type MockQuery struct {
|
||||
whereClause string
|
||||
}
|
||||
mockQuery := &MockQuery{}
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: ctx,
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "orders",
|
||||
model: &TestModel{},
|
||||
query: mockQuery,
|
||||
}
|
||||
|
||||
err := ApplyRowSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Note: The actual WHERE clause application requires a query type that implements Where()
|
||||
// In a real scenario, this would be a bun.SelectQuery or similar
|
||||
})
|
||||
|
||||
t.Run("block access", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{
|
||||
rowSecurity: RowSecurity{
|
||||
Schema: "public",
|
||||
Tablename: "secrets",
|
||||
HasBlock: true,
|
||||
},
|
||||
}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
ctx := context.Background()
|
||||
|
||||
// Load row security
|
||||
_, _ = secList.LoadRowSecurity(ctx, 1, "public", "secrets", false)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: ctx,
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "secrets",
|
||||
}
|
||||
|
||||
err := ApplyRowSecurity(secCtx, secList)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for blocked access")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no user in context", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
hasUser: false,
|
||||
schema: "public",
|
||||
entity: "orders",
|
||||
}
|
||||
|
||||
err := ApplyRowSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error with no user, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no row security defined", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "unknown_table",
|
||||
}
|
||||
|
||||
err := ApplyRowSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error with no security, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test applyColumnSecurity
|
||||
func TestApplyColumnSecurityHook(t *testing.T) {
|
||||
type User struct {
|
||||
ID int `bun:"id,pk"`
|
||||
Email string `bun:"email"`
|
||||
}
|
||||
|
||||
t.Run("apply column security to results", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{
|
||||
columnSecurity: []ColumnSecurity{
|
||||
{
|
||||
Schema: "public",
|
||||
Tablename: "users",
|
||||
Path: []string{"email"},
|
||||
Accesstype: "mask",
|
||||
UserID: 1,
|
||||
MaskStart: 3,
|
||||
MaskEnd: 0,
|
||||
MaskChar: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
ctx := context.Background()
|
||||
|
||||
// Load column security
|
||||
_ = secList.LoadColumnSecurity(ctx, 1, "public", "users", false)
|
||||
|
||||
users := []User{
|
||||
{ID: 1, Email: "test@example.com"},
|
||||
{ID: 2, Email: "user@test.com"},
|
||||
}
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: ctx,
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
model: &User{},
|
||||
result: users,
|
||||
}
|
||||
|
||||
err := ApplyColumnSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Check that result was updated with masked data
|
||||
maskedResult := secCtx.GetResult()
|
||||
if maskedResult == nil {
|
||||
t.Error("expected result to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no user in context", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
hasUser: false,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
}
|
||||
|
||||
err := ApplyColumnSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error with no user, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil result", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
result: nil,
|
||||
}
|
||||
|
||||
err := ApplyColumnSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error with nil result, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil model", func(t *testing.T) {
|
||||
provider := &mockSecurityProvider{}
|
||||
secList, _ := NewSecurityList(provider)
|
||||
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
model: nil,
|
||||
result: []interface{}{},
|
||||
}
|
||||
|
||||
err := ApplyColumnSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error with nil model, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test logDataAccess
|
||||
func TestLogDataAccess(t *testing.T) {
|
||||
t.Run("log access with user", func(t *testing.T) {
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
}
|
||||
|
||||
err := LogDataAccess(secCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("log access without user", func(t *testing.T) {
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: context.Background(),
|
||||
hasUser: false,
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
}
|
||||
|
||||
err := LogDataAccess(secCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test integration: loading and applying all security
|
||||
func TestSecurityIntegration(t *testing.T) {
|
||||
type Order struct {
|
||||
ID int `bun:"id,pk"`
|
||||
UserID int `bun:"user_id"`
|
||||
Amount int `bun:"amount"`
|
||||
Description string `bun:"description"`
|
||||
}
|
||||
|
||||
provider := &mockSecurityProvider{
|
||||
columnSecurity: []ColumnSecurity{
|
||||
{
|
||||
Schema: "public",
|
||||
Tablename: "orders",
|
||||
Path: []string{"amount"},
|
||||
Accesstype: "mask",
|
||||
UserID: 1,
|
||||
},
|
||||
},
|
||||
rowSecurity: RowSecurity{
|
||||
Schema: "public",
|
||||
Tablename: "orders",
|
||||
Template: "user_id = {UserID}",
|
||||
HasBlock: false,
|
||||
},
|
||||
}
|
||||
|
||||
secList, _ := NewSecurityList(provider)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("complete security flow", func(t *testing.T) {
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: ctx,
|
||||
userID: 1,
|
||||
hasUser: true,
|
||||
schema: "public",
|
||||
entity: "orders",
|
||||
model: &Order{},
|
||||
}
|
||||
|
||||
// Step 1: Load security rules
|
||||
err := LoadSecurityRules(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadSecurityRules failed: %v", err)
|
||||
}
|
||||
|
||||
// Step 2: Apply row security
|
||||
err = ApplyRowSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyRowSecurity failed: %v", err)
|
||||
}
|
||||
|
||||
// Step 3: Set some results
|
||||
orders := []Order{
|
||||
{ID: 1, UserID: 1, Amount: 1000, Description: "Order 1"},
|
||||
{ID: 2, UserID: 1, Amount: 2000, Description: "Order 2"},
|
||||
}
|
||||
secCtx.SetResult(orders)
|
||||
|
||||
// Step 4: Apply column security
|
||||
err = ApplyColumnSecurity(secCtx, secList)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyColumnSecurity failed: %v", err)
|
||||
}
|
||||
|
||||
// Step 5: Log access
|
||||
err = LogDataAccess(secCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("LogDataAccess failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("security without user context", func(t *testing.T) {
|
||||
secCtx := &mockSecurityContext{
|
||||
ctx: ctx,
|
||||
hasUser: false,
|
||||
schema: "public",
|
||||
entity: "orders",
|
||||
}
|
||||
|
||||
// All security operations should handle missing user gracefully
|
||||
_ = LoadSecurityRules(secCtx, secList)
|
||||
_ = ApplyRowSecurity(secCtx, secList)
|
||||
_ = ApplyColumnSecurity(secCtx, secList)
|
||||
_ = LogDataAccess(secCtx)
|
||||
|
||||
// If we reach here without panics, the test passes
|
||||
})
|
||||
}
|
||||
|
||||
// Test RowSecurity GetTemplate with various placeholders
|
||||
func TestRowSecurityGetTemplateIntegration(t *testing.T) {
|
||||
type Model struct {
|
||||
OrderID int `bun:"order_id,pk"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rowSec RowSecurity
|
||||
pkName string
|
||||
expectedPart string // Part of the expected output
|
||||
}{
|
||||
{
|
||||
name: "with all placeholders",
|
||||
rowSec: RowSecurity{
|
||||
Schema: "sales",
|
||||
Tablename: "orders",
|
||||
UserID: 42,
|
||||
Template: "{PrimaryKeyName} IN (SELECT {PrimaryKeyName} FROM {SchemaName}.{TableName}_access WHERE user_id = {UserID})",
|
||||
},
|
||||
pkName: "order_id",
|
||||
expectedPart: "order_id IN (SELECT order_id FROM sales.orders_access WHERE user_id = 42)",
|
||||
},
|
||||
{
|
||||
name: "simple user filter",
|
||||
rowSec: RowSecurity{
|
||||
Schema: "public",
|
||||
Tablename: "orders",
|
||||
UserID: 1,
|
||||
Template: "user_id = {UserID}",
|
||||
},
|
||||
pkName: "id",
|
||||
expectedPart: "user_id = 1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
modelType := reflect.TypeOf(Model{})
|
||||
result := tt.rowSec.GetTemplate(tt.pkName, modelType)
|
||||
|
||||
if result != tt.expectedPart {
|
||||
t.Errorf("GetTemplate() = %q, want %q", result, tt.expectedPart)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user