Files
ResolveSpec/pkg/websocketspec/connection_test.go
2025-12-23 17:27:29 +02:00

597 lines
14 KiB
Go

package websocketspec
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper function to create a test connection with proper initialization
func createTestConnection(id string) *Connection {
ctx, cancel := context.WithCancel(context.Background())
return &Connection{
ID: id,
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: make(map[string]interface{}),
ctx: ctx,
cancel: cancel,
}
}
func TestNewConnectionManager(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
assert.NotNil(t, cm)
assert.NotNil(t, cm.connections)
assert.NotNil(t, cm.register)
assert.NotNil(t, cm.unregister)
assert.NotNil(t, cm.broadcast)
assert.Equal(t, 0, cm.Count())
}
func TestConnectionManager_Count(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer func() {
// Cancel context without calling Shutdown which tries to close connections
cm.cancel()
}()
// Initially empty
assert.Equal(t, 0, cm.Count())
// Add a connection
conn := createTestConnection("conn-1")
cm.Register(conn)
time.Sleep(10 * time.Millisecond) // Give time for registration
assert.Equal(t, 1, cm.Count())
}
func TestConnectionManager_Register(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
conn := createTestConnection("conn-1")
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
// Verify connection was registered
retrievedConn, exists := cm.GetConnection("conn-1")
assert.True(t, exists)
assert.Equal(t, "conn-1", retrievedConn.ID)
}
func TestConnectionManager_Unregister(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
}
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 1, cm.Count())
cm.Unregister(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 0, cm.Count())
// Verify connection was removed
_, exists := cm.GetConnection("conn-1")
assert.False(t, exists)
}
func TestConnectionManager_GetConnection(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Non-existent connection
_, exists := cm.GetConnection("non-existent")
assert.False(t, exists)
// Register connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
}
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
// Get existing connection
retrievedConn, exists := cm.GetConnection("conn-1")
assert.True(t, exists)
assert.Equal(t, "conn-1", retrievedConn.ID)
}
func TestConnectionManager_MultipleConnections(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Register multiple connections
conn1 := &Connection{ID: "conn-1", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
conn2 := &Connection{ID: "conn-2", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
conn3 := &Connection{ID: "conn-3", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
cm.Register(conn1)
cm.Register(conn2)
cm.Register(conn3)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 3, cm.Count())
// Verify all connections exist
_, exists := cm.GetConnection("conn-1")
assert.True(t, exists)
_, exists = cm.GetConnection("conn-2")
assert.True(t, exists)
_, exists = cm.GetConnection("conn-3")
assert.True(t, exists)
// Unregister one
cm.Unregister(conn2)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 2, cm.Count())
// Verify conn-2 is gone but others remain
_, exists = cm.GetConnection("conn-2")
assert.False(t, exists)
_, exists = cm.GetConnection("conn-1")
assert.True(t, exists)
_, exists = cm.GetConnection("conn-3")
assert.True(t, exists)
}
func TestConnectionManager_Shutdown(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
// Register connections
conn1 := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
ctx: context.Background(),
}
conn1.ctx, conn1.cancel = context.WithCancel(context.Background())
cm.Register(conn1)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 1, cm.Count())
// Shutdown
cm.Shutdown()
time.Sleep(10 * time.Millisecond)
// Verify context was cancelled
select {
case <-cm.ctx.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Fatal("Context not cancelled after shutdown")
}
}
func TestConnection_SetMetadata(t *testing.T) {
conn := &Connection{
metadata: make(map[string]interface{}),
}
conn.SetMetadata("user_id", 123)
conn.SetMetadata("username", "john")
// Verify metadata was set
userID, exists := conn.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 123, userID)
username, exists := conn.GetMetadata("username")
assert.True(t, exists)
assert.Equal(t, "john", username)
}
func TestConnection_GetMetadata(t *testing.T) {
conn := &Connection{
metadata: map[string]interface{}{
"user_id": 123,
"role": "admin",
},
}
// Get existing metadata
userID, exists := conn.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 123, userID)
// Get non-existent metadata
_, exists = conn.GetMetadata("non_existent")
assert.False(t, exists)
}
func TestConnection_AddSubscription(t *testing.T) {
conn := &Connection{
subscriptions: make(map[string]*Subscription),
}
sub := &Subscription{
ID: "sub-1",
ConnectionID: "conn-1",
Entity: "users",
Active: true,
}
conn.AddSubscription(sub)
// Verify subscription was added
retrievedSub, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
assert.Equal(t, "sub-1", retrievedSub.ID)
}
func TestConnection_RemoveSubscription(t *testing.T) {
sub := &Subscription{
ID: "sub-1",
ConnectionID: "conn-1",
Entity: "users",
Active: true,
}
conn := &Connection{
subscriptions: map[string]*Subscription{
"sub-1": sub,
},
}
// Verify subscription exists
_, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
// Remove subscription
conn.RemoveSubscription("sub-1")
// Verify subscription was removed
_, exists = conn.GetSubscription("sub-1")
assert.False(t, exists)
}
func TestConnection_GetSubscription(t *testing.T) {
sub1 := &Subscription{ID: "sub-1", Entity: "users"}
sub2 := &Subscription{ID: "sub-2", Entity: "posts"}
conn := &Connection{
subscriptions: map[string]*Subscription{
"sub-1": sub1,
"sub-2": sub2,
},
}
// Get existing subscription
retrievedSub, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
assert.Equal(t, "sub-1", retrievedSub.ID)
// Get non-existent subscription
_, exists = conn.GetSubscription("non-existent")
assert.False(t, exists)
}
func TestConnection_MultipleSubscriptions(t *testing.T) {
conn := &Connection{
subscriptions: make(map[string]*Subscription),
}
sub1 := &Subscription{ID: "sub-1", Entity: "users"}
sub2 := &Subscription{ID: "sub-2", Entity: "posts"}
sub3 := &Subscription{ID: "sub-3", Entity: "comments"}
conn.AddSubscription(sub1)
conn.AddSubscription(sub2)
conn.AddSubscription(sub3)
// Verify all subscriptions exist
_, exists := conn.GetSubscription("sub-1")
assert.True(t, exists)
_, exists = conn.GetSubscription("sub-2")
assert.True(t, exists)
_, exists = conn.GetSubscription("sub-3")
assert.True(t, exists)
// Remove one subscription
conn.RemoveSubscription("sub-2")
// Verify sub-2 is gone but others remain
_, exists = conn.GetSubscription("sub-2")
assert.False(t, exists)
_, exists = conn.GetSubscription("sub-1")
assert.True(t, exists)
_, exists = conn.GetSubscription("sub-3")
assert.True(t, exists)
}
func TestBroadcastMessage_Structure(t *testing.T) {
msg := &BroadcastMessage{
Message: []byte("test message"),
Filter: func(conn *Connection) bool {
return true
},
}
assert.NotNil(t, msg.Message)
assert.NotNil(t, msg.Filter)
assert.Equal(t, "test message", string(msg.Message))
}
func TestBroadcastMessage_Filter(t *testing.T) {
// Filter that only allows specific connection
filter := func(conn *Connection) bool {
return conn.ID == "conn-1"
}
msg := &BroadcastMessage{
Message: []byte("test"),
Filter: filter,
}
conn1 := &Connection{ID: "conn-1"}
conn2 := &Connection{ID: "conn-2"}
assert.True(t, msg.Filter(conn1))
assert.False(t, msg.Filter(conn2))
}
func TestConnectionManager_Broadcast(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Register connections
conn1 := &Connection{ID: "conn-1", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
conn2 := &Connection{ID: "conn-2", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription)}
cm.Register(conn1)
cm.Register(conn2)
time.Sleep(10 * time.Millisecond)
// Broadcast message
message := []byte("test broadcast")
cm.Broadcast(message, nil)
time.Sleep(10 * time.Millisecond)
// Verify both connections received the message
select {
case msg := <-conn1.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn1 did not receive message")
}
select {
case msg := <-conn2.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn2 did not receive message")
}
}
func TestConnectionManager_BroadcastWithFilter(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Register connections with metadata
conn1 := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: map[string]interface{}{"role": "admin"},
}
conn2 := &Connection{
ID: "conn-2",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: map[string]interface{}{"role": "user"},
}
cm.Register(conn1)
cm.Register(conn2)
time.Sleep(10 * time.Millisecond)
// Broadcast only to admins
filter := func(conn *Connection) bool {
role, _ := conn.GetMetadata("role")
return role == "admin"
}
message := []byte("admin message")
cm.Broadcast(message, filter)
time.Sleep(10 * time.Millisecond)
// Verify only conn1 received the message
select {
case msg := <-conn1.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn1 (admin) did not receive message")
}
// Verify conn2 did not receive the message
select {
case <-conn2.send:
t.Fatal("conn2 (user) should not have received admin message")
case <-time.After(50 * time.Millisecond):
// Expected - no message
}
}
func TestConnection_ConcurrentMetadataAccess(t *testing.T) {
// This test verifies that concurrent metadata access doesn't cause race conditions
// Run with: go test -race
conn := &Connection{
metadata: make(map[string]interface{}),
}
done := make(chan bool)
// Goroutine 1: Write metadata
go func() {
for i := 0; i < 100; i++ {
conn.SetMetadata("key", i)
}
done <- true
}()
// Goroutine 2: Read metadata
go func() {
for i := 0; i < 100; i++ {
conn.GetMetadata("key")
}
done <- true
}()
// Wait for completion
<-done
<-done
}
func TestConnection_ConcurrentSubscriptionAccess(t *testing.T) {
// This test verifies that concurrent subscription access doesn't cause race conditions
// Run with: go test -race
conn := &Connection{
subscriptions: make(map[string]*Subscription),
}
done := make(chan bool)
// Goroutine 1: Add subscriptions
go func() {
for i := 0; i < 100; i++ {
sub := &Subscription{ID: "sub-" + string(rune(i)), Entity: "users"}
conn.AddSubscription(sub)
}
done <- true
}()
// Goroutine 2: Get subscriptions
go func() {
for i := 0; i < 100; i++ {
conn.GetSubscription("sub-" + string(rune(i)))
}
done <- true
}()
// Wait for completion
<-done
<-done
}
func TestConnectionManager_CompleteLifecycle(t *testing.T) {
ctx := context.Background()
cm := NewConnectionManager(ctx)
// Start manager
go cm.Run()
defer cm.cancel()
// Create and register connection
conn := &Connection{
ID: "conn-1",
send: make(chan []byte, 256),
subscriptions: make(map[string]*Subscription),
metadata: make(map[string]interface{}),
}
// Set metadata
conn.SetMetadata("user_id", 123)
// Add subscriptions
sub1 := &Subscription{ID: "sub-1", Entity: "users"}
sub2 := &Subscription{ID: "sub-2", Entity: "posts"}
conn.AddSubscription(sub1)
conn.AddSubscription(sub2)
// Register connection
cm.Register(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 1, cm.Count())
// Verify connection exists
retrievedConn, exists := cm.GetConnection("conn-1")
require.True(t, exists)
assert.Equal(t, "conn-1", retrievedConn.ID)
// Verify metadata
userID, exists := retrievedConn.GetMetadata("user_id")
assert.True(t, exists)
assert.Equal(t, 123, userID)
// Verify subscriptions
_, exists = retrievedConn.GetSubscription("sub-1")
assert.True(t, exists)
_, exists = retrievedConn.GetSubscription("sub-2")
assert.True(t, exists)
// Broadcast message
message := []byte("test message")
cm.Broadcast(message, nil)
time.Sleep(10 * time.Millisecond)
select {
case msg := <-retrievedConn.send:
assert.Equal(t, message, msg)
case <-time.After(100 * time.Millisecond):
t.Fatal("Connection did not receive broadcast")
}
// Unregister connection
cm.Unregister(conn)
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 0, cm.Count())
// Verify connection is gone
_, exists = cm.GetConnection("conn-1")
assert.False(t, exists)
}