mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
584 lines
13 KiB
Go
584 lines
13 KiB
Go
package security
|
|
|
|
import (
|
|
"context"
|
|
"reflect"
|
|
"testing"
|
|
)
|
|
|
|
// mockSecurityContext is an in-memory stand-in for the security context
// handed to the hook functions under test. Each accessor returns, or
// overwrites, exactly one field, so a test can preset inputs and later
// inspect what a hook stored.
type mockSecurityContext struct {
	ctx     context.Context
	userID  int
	hasUser bool // reported as the second value of GetUserID

	schema, entity string

	model, query, result interface{}
}

// GetContext returns the context stored in the mock.
func (m *mockSecurityContext) GetContext() context.Context {
	return m.ctx
}

// GetUserID returns the configured user ID together with the hasUser flag.
func (m *mockSecurityContext) GetUserID() (int, bool) {
	return m.userID, m.hasUser
}

// GetSchema returns the configured schema name.
func (m *mockSecurityContext) GetSchema() string {
	return m.schema
}

// GetEntity returns the configured entity name.
func (m *mockSecurityContext) GetEntity() string {
	return m.entity
}

// GetModel returns the configured model value; nil when none was set.
func (m *mockSecurityContext) GetModel() interface{} {
	return m.model
}

// GetQuery returns the query value currently held by the mock.
func (m *mockSecurityContext) GetQuery() interface{} {
	return m.query
}

// SetQuery replaces the query value held by the mock.
func (m *mockSecurityContext) SetQuery(q interface{}) {
	m.query = q
}

// GetResult returns the result value currently held by the mock.
func (m *mockSecurityContext) GetResult() interface{} {
	return m.result
}

// SetResult replaces the result value held by the mock.
func (m *mockSecurityContext) SetResult(r interface{}) {
	m.result = r
}
|
|
|
|
// Test helper functions
|
|
func TestContains(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
s string
|
|
substr string
|
|
expected bool
|
|
}{
|
|
{"substring at start", "hello world", "hello", true},
|
|
{"substring at end", "hello world", "world", true},
|
|
{"substring in middle", "hello world", "lo wo", false}, // contains only checks prefix/suffix
|
|
{"substring not present", "hello world", "xyz", false},
|
|
{"exact match", "test", "test", true},
|
|
{"empty substring", "test", "", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := contains(tt.s, tt.substr)
|
|
if result != tt.expected {
|
|
t.Errorf("contains(%q, %q) = %v, want %v", tt.s, tt.substr, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractSQLName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tag string
|
|
expected string
|
|
}{
|
|
{"simple name", "user_id", "user_id"},
|
|
{"column prefix", "column:email", "column:email"}, // Implementation doesn't strip prefix in all cases
|
|
{"with other tags", "id,pk,autoincrement", "id"},
|
|
{"column with comma", "column:user_name,notnull", "column:user_name"}, // Implementation behavior
|
|
{"empty tag", "", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := extractSQLName(tt.tag)
|
|
if result != tt.expected {
|
|
t.Errorf("extractSQLName(%q) = %q, want %q", tt.tag, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSplitTag(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tag string
|
|
sep rune
|
|
expected []string
|
|
}{
|
|
{"single part", "id", ',', []string{"id"}},
|
|
{"multiple parts", "id,pk,autoincrement", ',', []string{"id", "pk", "autoincrement"}},
|
|
{"empty parts filtered", "id,,pk", ',', []string{"id", "pk"}},
|
|
{"no separator", "singlepart", ',', []string{"singlepart"}},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := splitTag(tt.tag, tt.sep)
|
|
if len(result) != len(tt.expected) {
|
|
t.Errorf("splitTag(%q) returned %d parts, want %d", tt.tag, len(result), len(tt.expected))
|
|
return
|
|
}
|
|
for i, part := range tt.expected {
|
|
if result[i] != part {
|
|
t.Errorf("splitTag(%q)[%d] = %q, want %q", tt.tag, i, result[i], part)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test loadSecurityRules
|
|
func TestLoadSecurityRules(t *testing.T) {
|
|
t.Run("load rules successfully", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{
|
|
columnSecurity: []ColumnSecurity{
|
|
{Schema: "public", Tablename: "users", Path: []string{"email"}},
|
|
},
|
|
rowSecurity: RowSecurity{
|
|
Schema: "public",
|
|
Tablename: "users",
|
|
Template: "id = {UserID}",
|
|
},
|
|
}
|
|
secList, _ := NewSecurityList(provider)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "users",
|
|
}
|
|
|
|
err := LoadSecurityRules(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
// Verify column security was loaded
|
|
key := "public.users@1"
|
|
if _, ok := secList.ColumnSecurity[key]; !ok {
|
|
t.Error("expected column security to be loaded")
|
|
}
|
|
|
|
// Verify row security was loaded
|
|
if _, ok := secList.RowSecurity[key]; !ok {
|
|
t.Error("expected row security to be loaded")
|
|
}
|
|
})
|
|
|
|
t.Run("no user in context", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{}
|
|
secList, _ := NewSecurityList(provider)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
hasUser: false,
|
|
schema: "public",
|
|
entity: "users",
|
|
}
|
|
|
|
err := LoadSecurityRules(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error with no user, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test applyRowSecurity
|
|
func TestApplyRowSecurity(t *testing.T) {
|
|
type TestModel struct {
|
|
ID int `bun:"id,pk"`
|
|
}
|
|
|
|
t.Run("apply row security template", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{
|
|
rowSecurity: RowSecurity{
|
|
Schema: "public",
|
|
Tablename: "orders",
|
|
Template: "user_id = {UserID}",
|
|
HasBlock: false,
|
|
},
|
|
}
|
|
secList, _ := NewSecurityList(provider)
|
|
ctx := context.Background()
|
|
|
|
// Load row security
|
|
_, _ = secList.LoadRowSecurity(ctx, 1, "public", "orders", false)
|
|
|
|
// Mock query that supports Where
|
|
type MockQuery struct {
|
|
whereClause string
|
|
}
|
|
mockQuery := &MockQuery{}
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: ctx,
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "orders",
|
|
model: &TestModel{},
|
|
query: mockQuery,
|
|
}
|
|
|
|
err := ApplyRowSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
// Note: The actual WHERE clause application requires a query type that implements Where()
|
|
// In a real scenario, this would be a bun.SelectQuery or similar
|
|
})
|
|
|
|
t.Run("block access", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{
|
|
rowSecurity: RowSecurity{
|
|
Schema: "public",
|
|
Tablename: "secrets",
|
|
HasBlock: true,
|
|
},
|
|
}
|
|
secList, _ := NewSecurityList(provider)
|
|
ctx := context.Background()
|
|
|
|
// Load row security
|
|
_, _ = secList.LoadRowSecurity(ctx, 1, "public", "secrets", false)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: ctx,
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "secrets",
|
|
}
|
|
|
|
err := ApplyRowSecurity(secCtx, secList)
|
|
if err == nil {
|
|
t.Fatal("expected error for blocked access")
|
|
}
|
|
})
|
|
|
|
t.Run("no user in context", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{}
|
|
secList, _ := NewSecurityList(provider)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
hasUser: false,
|
|
schema: "public",
|
|
entity: "orders",
|
|
}
|
|
|
|
err := ApplyRowSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error with no user, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("no row security defined", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{}
|
|
secList, _ := NewSecurityList(provider)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "unknown_table",
|
|
}
|
|
|
|
err := ApplyRowSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error with no security, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test applyColumnSecurity
|
|
func TestApplyColumnSecurityHook(t *testing.T) {
|
|
type User struct {
|
|
ID int `bun:"id,pk"`
|
|
Email string `bun:"email"`
|
|
}
|
|
|
|
t.Run("apply column security to results", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{
|
|
columnSecurity: []ColumnSecurity{
|
|
{
|
|
Schema: "public",
|
|
Tablename: "users",
|
|
Path: []string{"email"},
|
|
Accesstype: "mask",
|
|
UserID: 1,
|
|
MaskStart: 3,
|
|
MaskEnd: 0,
|
|
MaskChar: "*",
|
|
},
|
|
},
|
|
}
|
|
secList, _ := NewSecurityList(provider)
|
|
ctx := context.Background()
|
|
|
|
// Load column security
|
|
_ = secList.LoadColumnSecurity(ctx, 1, "public", "users", false)
|
|
|
|
users := []User{
|
|
{ID: 1, Email: "test@example.com"},
|
|
{ID: 2, Email: "user@test.com"},
|
|
}
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: ctx,
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "users",
|
|
model: &User{},
|
|
result: users,
|
|
}
|
|
|
|
err := ApplyColumnSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
|
|
// Check that result was updated with masked data
|
|
maskedResult := secCtx.GetResult()
|
|
if maskedResult == nil {
|
|
t.Error("expected result to be set")
|
|
}
|
|
})
|
|
|
|
t.Run("no user in context", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{}
|
|
secList, _ := NewSecurityList(provider)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
hasUser: false,
|
|
schema: "public",
|
|
entity: "users",
|
|
}
|
|
|
|
err := ApplyColumnSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error with no user, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("nil result", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{}
|
|
secList, _ := NewSecurityList(provider)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "users",
|
|
result: nil,
|
|
}
|
|
|
|
err := ApplyColumnSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error with nil result, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("nil model", func(t *testing.T) {
|
|
provider := &mockSecurityProvider{}
|
|
secList, _ := NewSecurityList(provider)
|
|
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "users",
|
|
model: nil,
|
|
result: []interface{}{},
|
|
}
|
|
|
|
err := ApplyColumnSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("expected no error with nil model, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test logDataAccess
|
|
func TestLogDataAccess(t *testing.T) {
|
|
t.Run("log access with user", func(t *testing.T) {
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "users",
|
|
}
|
|
|
|
err := LogDataAccess(secCtx)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("log access without user", func(t *testing.T) {
|
|
secCtx := &mockSecurityContext{
|
|
ctx: context.Background(),
|
|
hasUser: false,
|
|
schema: "public",
|
|
entity: "users",
|
|
}
|
|
|
|
err := LogDataAccess(secCtx)
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test integration: loading and applying all security
|
|
func TestSecurityIntegration(t *testing.T) {
|
|
type Order struct {
|
|
ID int `bun:"id,pk"`
|
|
UserID int `bun:"user_id"`
|
|
Amount int `bun:"amount"`
|
|
Description string `bun:"description"`
|
|
}
|
|
|
|
provider := &mockSecurityProvider{
|
|
columnSecurity: []ColumnSecurity{
|
|
{
|
|
Schema: "public",
|
|
Tablename: "orders",
|
|
Path: []string{"amount"},
|
|
Accesstype: "mask",
|
|
UserID: 1,
|
|
},
|
|
},
|
|
rowSecurity: RowSecurity{
|
|
Schema: "public",
|
|
Tablename: "orders",
|
|
Template: "user_id = {UserID}",
|
|
HasBlock: false,
|
|
},
|
|
}
|
|
|
|
secList, _ := NewSecurityList(provider)
|
|
ctx := context.Background()
|
|
|
|
t.Run("complete security flow", func(t *testing.T) {
|
|
secCtx := &mockSecurityContext{
|
|
ctx: ctx,
|
|
userID: 1,
|
|
hasUser: true,
|
|
schema: "public",
|
|
entity: "orders",
|
|
model: &Order{},
|
|
}
|
|
|
|
// Step 1: Load security rules
|
|
err := LoadSecurityRules(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("LoadSecurityRules failed: %v", err)
|
|
}
|
|
|
|
// Step 2: Apply row security
|
|
err = ApplyRowSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("ApplyRowSecurity failed: %v", err)
|
|
}
|
|
|
|
// Step 3: Set some results
|
|
orders := []Order{
|
|
{ID: 1, UserID: 1, Amount: 1000, Description: "Order 1"},
|
|
{ID: 2, UserID: 1, Amount: 2000, Description: "Order 2"},
|
|
}
|
|
secCtx.SetResult(orders)
|
|
|
|
// Step 4: Apply column security
|
|
err = ApplyColumnSecurity(secCtx, secList)
|
|
if err != nil {
|
|
t.Fatalf("ApplyColumnSecurity failed: %v", err)
|
|
}
|
|
|
|
// Step 5: Log access
|
|
err = LogDataAccess(secCtx)
|
|
if err != nil {
|
|
t.Fatalf("LogDataAccess failed: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("security without user context", func(t *testing.T) {
|
|
secCtx := &mockSecurityContext{
|
|
ctx: ctx,
|
|
hasUser: false,
|
|
schema: "public",
|
|
entity: "orders",
|
|
}
|
|
|
|
// All security operations should handle missing user gracefully
|
|
_ = LoadSecurityRules(secCtx, secList)
|
|
_ = ApplyRowSecurity(secCtx, secList)
|
|
_ = ApplyColumnSecurity(secCtx, secList)
|
|
_ = LogDataAccess(secCtx)
|
|
|
|
// If we reach here without panics, the test passes
|
|
})
|
|
}
|
|
|
|
// Test RowSecurity GetTemplate with various placeholders
|
|
func TestRowSecurityGetTemplateIntegration(t *testing.T) {
|
|
type Model struct {
|
|
OrderID int `bun:"order_id,pk"`
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
rowSec RowSecurity
|
|
pkName string
|
|
expectedPart string // Part of the expected output
|
|
}{
|
|
{
|
|
name: "with all placeholders",
|
|
rowSec: RowSecurity{
|
|
Schema: "sales",
|
|
Tablename: "orders",
|
|
UserID: 42,
|
|
Template: "{PrimaryKeyName} IN (SELECT {PrimaryKeyName} FROM {SchemaName}.{TableName}_access WHERE user_id = {UserID})",
|
|
},
|
|
pkName: "order_id",
|
|
expectedPart: "order_id IN (SELECT order_id FROM sales.orders_access WHERE user_id = 42)",
|
|
},
|
|
{
|
|
name: "simple user filter",
|
|
rowSec: RowSecurity{
|
|
Schema: "public",
|
|
Tablename: "orders",
|
|
UserID: 1,
|
|
Template: "user_id = {UserID}",
|
|
},
|
|
pkName: "id",
|
|
expectedPart: "user_id = 1",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
modelType := reflect.TypeOf(Model{})
|
|
result := tt.rowSec.GetTemplate(tt.pkName, modelType)
|
|
|
|
if result != tt.expectedPart {
|
|
t.Errorf("GetTemplate() = %q, want %q", result, tt.expectedPart)
|
|
}
|
|
})
|
|
}
|
|
}
|