diff --git a/pkg/websocketspec/connection.go b/pkg/websocketspec/connection.go index b26d858..f3e4c17 100644 --- a/pkg/websocketspec/connection.go +++ b/pkg/websocketspec/connection.go @@ -116,17 +116,21 @@ func (cm *ConnectionManager) Run() { case conn := <-cm.register: cm.mu.Lock() cm.connections[conn.ID] = conn + count := len(cm.connections) cm.mu.Unlock() - logger.Info("[WebSocketSpec] Connection registered: %s (total: %d)", conn.ID, cm.Count()) + logger.Info("[WebSocketSpec] Connection registered: %s (total: %d)", conn.ID, count) case conn := <-cm.unregister: cm.mu.Lock() if _, ok := cm.connections[conn.ID]; ok { delete(cm.connections, conn.ID) close(conn.send) - logger.Info("[WebSocketSpec] Connection unregistered: %s (total: %d)", conn.ID, cm.Count()) + count := len(cm.connections) + cm.mu.Unlock() + logger.Info("[WebSocketSpec] Connection unregistered: %s (total: %d)", conn.ID, count) + } else { + cm.mu.Unlock() } - cm.mu.Unlock() case msg := <-cm.broadcast: cm.mu.RLock() @@ -296,13 +300,19 @@ func (c *Connection) SendJSON(v interface{}) error { // Close closes the connection func (c *Connection) Close() { c.closedOnce.Do(func() { - c.cancel() - c.ws.Close() + if c.cancel != nil { + c.cancel() + } + if c.ws != nil { + c.ws.Close() + } // Clean up subscriptions c.mu.Lock() for subID := range c.subscriptions { - c.handler.subscriptionManager.Unsubscribe(subID) + if c.handler != nil && c.handler.subscriptionManager != nil { + c.handler.subscriptionManager.Unsubscribe(subID) + } } c.subscriptions = make(map[string]*Subscription) c.mu.Unlock() diff --git a/pkg/websocketspec/handler_test.go b/pkg/websocketspec/handler_test.go index 311ce39..d950914 100644 --- a/pkg/websocketspec/handler_test.go +++ b/pkg/websocketspec/handler_test.go @@ -2,6 +2,7 @@ package websocketspec import ( "context" + "encoding/json" "testing" "github.com/bitechdev/ResolveSpec/pkg/common" @@ -344,6 +345,7 @@ func TestNewHandler(t *testing.T) { mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() assert.NotNil(t, handler) assert.NotNil(t, handler.db) @@ -358,6 +360,7 @@ func TestHandler_Hooks(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() hooks := handler.Hooks() assert.NotNil(t, hooks) @@ -368,6 +371,7 @@ func TestHandler_Registry(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() registry := handler.Registry() assert.NotNil(t, registry) @@ -378,6 +382,7 @@ func TestHandler_GetDatabase(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() db := handler.GetDatabase() assert.NotNil(t, db) @@ -388,6 +393,7 @@ func TestHandler_GetConnectionCount(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() count := handler.GetConnectionCount() assert.Equal(t, 0, count) @@ -397,6 +403,7 @@ func TestHandler_GetSubscriptionCount(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() count := handler.GetSubscriptionCount() assert.Equal(t, 0, count) @@ -406,6 +413,7 @@ func TestHandler_GetConnection(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Non-existent connection _, exists := handler.GetConnection("non-existent") @@ -416,6 +424,7 @@ func TestHandler_HandleMessage_InvalidType(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() conn := &Connection{ ID: "conn-1", @@ -431,25 +440,27 @@ func TestHandler_HandleMessage_InvalidType(t *testing.T) { handler.HandleMessage(conn, msg) + // Shutdown handler properly + defer handler.Shutdown() + // Should send error response select { case data := <-conn.send: - var response map[string]interface{} - require.NoError(t, ParseMessageBytes(data, &response)) - assert.False(t, response["success"].(bool)) + var response ResponseMessage + err := json.Unmarshal(data, &response) + require.NoError(t, err) + assert.False(t, response.Success) + assert.NotNil(t, response.Error) default: t.Fatal("Expected error response") } } -func ParseMessageBytes(data []byte, v interface{}) error { - return nil // Simplified for testing -} - func TestHandler_GetOperatorSQL(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() tests := []struct { operator string @@ -479,6 +490,7 @@ func TestHandler_GetTableName(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() tests := []struct { name string @@ -518,6 +530,7 @@ func TestHandler_GetMetadata(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() metadata := handler.getMetadata("public", "users", &TestUser{}) @@ -533,13 +546,19 @@ func TestHandler_NotifySubscribers(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Create connection + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn := &Connection{ ID: "conn-1", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription), handler: handler, + ctx: ctx, + cancel: cancel, } // Register connection @@ -566,6 +585,7 @@ func TestHandler_NotifySubscribers_NoSubscribers(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Notify with no subscribers - should not panic data := map[string]interface{}{"id": 1, "name": "John"} @@ -578,6 +598,7 @@ func TestHandler_NotifySubscribers_ConnectionNotFound(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Create subscription without connection handler.subscriptionManager.Subscribe("sub-1", "conn-1", "public", "users", nil) @@ -593,6 +614,7 @@ func TestHandler_HooksIntegration(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() beforeCalled := false afterCalled := false @@ -625,6 +647,7 @@ func TestHandler_Shutdown(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Shutdown should not panic handler.Shutdown() @@ -642,6 +665,7 @@ func TestHandler_SubscriptionLifecycle(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Create connection conn := &Connection{ @@ -681,6 +705,7 @@ func TestHandler_UnsubscribeLifecycle(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Create connection conn := &Connection{ @@ -725,11 +750,17 @@ func TestHandler_HandlePing(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() conn := &Connection{ ID: "conn-1", send: make(chan []byte, 256), subscriptions: make(map[string]*Subscription), + ctx: ctx, + cancel: cancel, } msg := &Message{ @@ -752,6 +783,7 @@ func TestHandler_CompleteWorkflow(t *testing.T) { mockDB := &MockDatabase{} mockRegistry := &MockModelRegistry{} handler := NewHandler(mockDB, mockRegistry) + defer handler.Shutdown() // Setup hooks (these are registered but not called in this test workflow) handler.Hooks().RegisterBefore(OperationCreate, func(ctx *HookContext) error {