ResolveSpec/pkg/websocketspec/connection.go
2025-12-12 16:14:47 +02:00

370 lines
8.9 KiB
Go

package websocketspec
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/gorilla/websocket"
)
// Connection represents a single WebSocket connection together with the
// per-connection state kept alongside it: the outbound message queue, the
// active subscriptions, a cancellable context and free-form metadata.
type Connection struct {
	// ID uniquely identifies this connection.
	ID string

	// ws is the underlying WebSocket connection.
	ws *websocket.Conn

	// send carries outbound messages waiting to be written to ws.
	send chan []byte

	// subscriptions holds this connection's active subscriptions, keyed by
	// subscription ID. Guarded by mu.
	subscriptions map[string]*Subscription

	// mu guards the subscriptions map.
	mu sync.RWMutex

	// ctx is the connection's context; it is cancelled through cancel.
	ctx context.Context

	// cancel cancels ctx.
	cancel context.CancelFunc

	// handler is the WebSocket handler this connection belongs to.
	handler *Handler

	// metadata stores connection-specific values (e.g. user info, auth
	// state). Guarded by metaMu.
	metadata map[string]interface{}

	// metaMu guards the metadata map.
	metaMu sync.RWMutex

	// closedOnce makes sure cleanup runs at most once.
	closedOnce sync.Once
}
// ConnectionManager keeps the registry of active WebSocket connections and
// the channels through which registration, unregistration and broadcast
// requests are submitted.
type ConnectionManager struct {
	// connections maps connection ID to the active connection. Guarded by mu.
	connections map[string]*Connection

	// mu guards the connections map.
	mu sync.RWMutex

	// register receives connections that should be added to the registry.
	register chan *Connection

	// unregister receives connections that should be removed from the registry.
	unregister chan *Connection

	// broadcast receives messages that should be fanned out to connections.
	broadcast chan *BroadcastMessage

	// ctx is the manager's context; it is cancelled through cancel.
	ctx context.Context

	// cancel cancels ctx.
	cancel context.CancelFunc
}
// BroadcastMessage is a payload addressed to multiple connections, with an
// optional predicate selecting the recipients.
type BroadcastMessage struct {
	// Message is the raw payload to deliver.
	Message []byte

	// Filter, when non-nil, is called per connection; only connections for
	// which it returns true receive Message. A nil Filter matches everyone.
	Filter func(*Connection) bool
}
// NewConnection builds a Connection around an established WebSocket.
//
// The connection gets its own cancellable context derived from
// context.Background, an outbound queue with room for 256 messages, and
// empty subscription and metadata maps.
func NewConnection(id string, ws *websocket.Conn, handler *Handler) *Connection {
	connCtx, stop := context.WithCancel(context.Background())

	c := new(Connection)
	c.ID = id
	c.ws = ws
	c.send = make(chan []byte, 256)
	c.subscriptions = make(map[string]*Subscription)
	c.ctx = connCtx
	c.cancel = stop
	c.handler = handler
	c.metadata = make(map[string]interface{})
	return c
}
// NewConnectionManager builds a ConnectionManager whose lifetime is bound to
// a cancellable child of ctx. The register, unregister and broadcast channels
// are unbuffered, and the registry starts out empty.
func NewConnectionManager(ctx context.Context) *ConnectionManager {
	managerCtx, stop := context.WithCancel(ctx)

	cm := new(ConnectionManager)
	cm.connections = make(map[string]*Connection)
	cm.register = make(chan *Connection)
	cm.unregister = make(chan *Connection)
	cm.broadcast = make(chan *BroadcastMessage)
	cm.ctx = managerCtx
	cm.cancel = stop
	return cm
}
// Run starts the connection manager event loop.
//
// It services register, unregister and broadcast requests one at a time and
// returns once the manager's context is cancelled. It is the only receiver on
// those three channels, so it must never send on them itself.
//
// NOTE(review): unregistering closes the connection's send channel, so a
// concurrent Connection.Send on that connection races with the close.
func (cm *ConnectionManager) Run() {
	// drop removes conn from the registry (matched by ID) and closes its send
	// channel. The total is read from the map directly while the write lock
	// is held: calling cm.Count() here would try to take the read lock on the
	// same goroutine, and sync.RWMutex is not reentrant.
	drop := func(conn *Connection) {
		cm.mu.Lock()
		_, known := cm.connections[conn.ID]
		if known {
			delete(cm.connections, conn.ID)
			close(conn.send)
		}
		total := len(cm.connections)
		cm.mu.Unlock()

		if known {
			logger.Info("[WebSocketSpec] Connection unregistered: %s (total: %d)", conn.ID, total)
		}
	}

	for {
		select {
		case conn := <-cm.register:
			cm.mu.Lock()
			cm.connections[conn.ID] = conn
			total := len(cm.connections)
			cm.mu.Unlock()
			logger.Info("[WebSocketSpec] Connection registered: %s (total: %d)", conn.ID, total)

		case conn := <-cm.unregister:
			drop(conn)

		case msg := <-cm.broadcast:
			// Connections whose buffer is full are collected and removed
			// after the iteration. Sending them to cm.unregister from here
			// would block forever because this loop is that channel's only
			// receiver.
			var stalled []*Connection

			cm.mu.RLock()
			for _, conn := range cm.connections {
				if msg.Filter != nil && !msg.Filter(conn) {
					continue
				}
				select {
				case conn.send <- msg.Message:
				default:
					// Channel full, connection is slow - close it
					logger.Warn("[WebSocketSpec] Connection %s send buffer full, closing", conn.ID)
					stalled = append(stalled, conn)
				}
			}
			cm.mu.RUnlock()

			for _, conn := range stalled {
				drop(conn)
			}

		case <-cm.ctx.Done():
			logger.Info("[WebSocketSpec] Connection manager shutting down")
			return
		}
	}
}
// Register hands conn to the event loop for registration.
//
// The call blocks until the event loop accepts the connection. If the
// manager's context has been cancelled, the request is dropped instead of
// blocking forever on a channel that no longer has a receiver.
func (cm *ConnectionManager) Register(conn *Connection) {
	select {
	case cm.register <- conn:
	case <-cm.ctx.Done():
		// Manager stopped; nothing will ever receive from cm.register.
	}
}
// Unregister asks the event loop to remove conn from the registry.
//
// The call blocks until the event loop accepts the request. If the manager's
// context has been cancelled, the request is dropped: the event loop has
// stopped (or is stopping), and blocking here would leak the calling
// goroutine, e.g. the deferred cleanup in Connection.ReadPump after Shutdown.
func (cm *ConnectionManager) Unregister(conn *Connection) {
	select {
	case cm.unregister <- conn:
	case <-cm.ctx.Done():
		// Manager stopped; nothing will ever receive from cm.unregister.
	}
}
// Broadcast queues message for delivery to every connection accepted by
// filter; a nil filter selects all connections.
//
// The call blocks until the event loop accepts the request. If the manager's
// context has been cancelled, the message is dropped instead of blocking
// forever on a channel that no longer has a receiver.
func (cm *ConnectionManager) Broadcast(message []byte, filter func(*Connection) bool) {
	req := &BroadcastMessage{
		Message: message,
		Filter:  filter,
	}
	select {
	case cm.broadcast <- req:
	case <-cm.ctx.Done():
		// Manager stopped; nothing will ever receive from cm.broadcast.
	}
}
// Count reports how many connections are currently in the registry. It takes
// the read lock for the duration of the lookup.
func (cm *ConnectionManager) Count() int {
	cm.mu.RLock()
	total := len(cm.connections)
	cm.mu.RUnlock()
	return total
}
// GetConnection looks up a registered connection by its ID. The boolean
// result is false when no connection with that ID is registered.
func (cm *ConnectionManager) GetConnection(id string) (*Connection, bool) {
	cm.mu.RLock()
	found, exists := cm.connections[id]
	cm.mu.RUnlock()
	return found, exists
}
// Shutdown cancels the manager's context, which stops the event loop, and
// then closes every connection that is registered at that moment.
//
// The registry is only read here, so a read lock is sufficient. The
// connections are snapshotted and closed after the lock is released, because
// Connection.Close performs socket I/O and calls into the subscription
// manager; holding the registry lock across that would stall Count,
// GetConnection and the event loop. Entries are not removed from the registry.
func (cm *ConnectionManager) Shutdown() {
	cm.cancel()

	cm.mu.RLock()
	active := make([]*Connection, 0, len(cm.connections))
	for _, conn := range cm.connections {
		active = append(active, conn)
	}
	cm.mu.RUnlock()

	// Close all connections
	for _, conn := range active {
		conn.Close()
	}
}
// ReadPump reads inbound messages from the WebSocket and passes each one to
// handleMessage until a read fails.
//
// The read deadline starts at 60 seconds and is pushed out by another 60
// seconds every time a pong arrives. On exit the connection is unregistered
// from the handler's connection manager and then closed.
func (c *Connection) ReadPump() {
	defer func() {
		c.handler.connManager.Unregister(c)
		c.Close()
	}()

	const pongWait = 60 * time.Second

	// extendDeadline moves the read deadline pongWait into the future.
	extendDeadline := func() {
		c.ws.SetReadDeadline(time.Now().Add(pongWait))
	}

	extendDeadline()
	c.ws.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	for {
		_, payload, readErr := c.ws.ReadMessage()
		if readErr == nil {
			// Parse and handle the message
			c.handleMessage(payload)
			continue
		}

		// Going-away and abnormal closures are not logged as errors.
		if websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
			logger.Error("[WebSocketSpec] Connection %s read error: %v", c.ID, readErr)
		}
		return
	}
}
// WritePump drains the send queue onto the WebSocket and emits a ping every
// 54 seconds. Each write is given a 10 second deadline.
//
// When a message is dequeued, everything else already waiting in the queue
// is appended to the same text frame, separated by newlines. The pump stops
// when the send channel is closed (after writing a close message), when a
// write fails, or when the connection context is cancelled; on exit the
// ticker is stopped and the connection is closed.
func (c *Connection) WritePump() {
	const (
		pingEvery    = 54 * time.Second
		writeTimeout = 10 * time.Second
	)

	pinger := time.NewTicker(pingEvery)
	defer func() {
		pinger.Stop()
		c.Close()
	}()

	for {
		select {
		case <-c.ctx.Done():
			return

		case <-pinger.C:
			c.ws.SetWriteDeadline(time.Now().Add(writeTimeout))
			if pingErr := c.ws.WriteMessage(websocket.PingMessage, nil); pingErr != nil {
				return
			}

		case payload, open := <-c.send:
			c.ws.SetWriteDeadline(time.Now().Add(writeTimeout))
			if !open {
				// Channel closed
				c.ws.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			frame, openErr := c.ws.NextWriter(websocket.TextMessage)
			if openErr != nil {
				return
			}
			frame.Write(payload)

			// Write any queued messages
			for pending := len(c.send); pending > 0; pending-- {
				frame.Write([]byte{'\n'})
				frame.Write(<-c.send)
			}

			if closeErr := frame.Close(); closeErr != nil {
				return
			}
		}
	}
}
// Send queues message for delivery on this connection without blocking.
//
// It returns a "connection closed" error if the connection context has been
// cancelled or the send channel has already been closed, and a "send buffer
// full" error if the outbound queue has no room. A nil result means the
// message was queued, not that it was written to the socket.
func (c *Connection) Send(message []byte) (err error) {
	// The connection manager closes c.send when it unregisters the
	// connection, which can happen before the context is cancelled. A send on
	// a closed channel panics, so convert that into the ordinary "closed"
	// error rather than crashing the caller.
	// NOTE(review): this is a stopgap; the durable fix is to stop closing
	// c.send and signal shutdown through the context only.
	defer func() {
		if recover() != nil {
			err = fmt.Errorf("connection closed")
		}
	}()

	// Check for closure first. In a single select, a cancelled context and a
	// free buffer slot are both ready and one is picked at random, so a
	// message could be queued on a dead connection and reported as success.
	select {
	case <-c.ctx.Done():
		return fmt.Errorf("connection closed")
	default:
	}

	select {
	case c.send <- message:
		return nil
	default:
		return fmt.Errorf("send buffer full")
	}
}
// SendJSON marshals v to JSON and queues the result through Send. A marshal
// failure is returned wrapped; otherwise the result of Send is returned.
func (c *Connection) SendJSON(v interface{}) error {
	payload, marshalErr := json.Marshal(v)
	if marshalErr == nil {
		return c.Send(payload)
	}
	return fmt.Errorf("failed to marshal message: %w", marshalErr)
}
// Close tears the connection down. The teardown runs at most once no matter
// how many times Close is called: it cancels the connection context, closes
// the WebSocket, unsubscribes every subscription held by the connection and
// resets the subscription map.
func (c *Connection) Close() {
	teardown := func() {
		c.cancel()
		c.ws.Close()

		// Clean up subscriptions
		c.mu.Lock()
		for id := range c.subscriptions {
			c.handler.subscriptionManager.Unsubscribe(id)
		}
		c.subscriptions = make(map[string]*Subscription)
		c.mu.Unlock()

		logger.Info("[WebSocketSpec] Connection %s closed", c.ID)
	}

	c.closedOnce.Do(teardown)
}
// AddSubscription records sub on this connection under sub.ID, replacing any
// subscription already stored under that ID.
func (c *Connection) AddSubscription(sub *Subscription) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.subscriptions[sub.ID] = sub
}
// RemoveSubscription drops the subscription stored under subID from this
// connection. It does nothing when no such subscription is recorded.
func (c *Connection) RemoveSubscription(subID string) {
	c.mu.Lock()
	delete(c.subscriptions, subID)
	c.mu.Unlock()
}
// GetSubscription looks up one of this connection's subscriptions by ID. The
// boolean result is false when the connection holds no such subscription.
func (c *Connection) GetSubscription(subID string) (*Subscription, bool) {
	c.mu.RLock()
	found, exists := c.subscriptions[subID]
	c.mu.RUnlock()
	return found, exists
}
// SetMetadata stores value under key in the connection's metadata,
// overwriting any value previously stored under that key.
func (c *Connection) SetMetadata(key string, value interface{}) {
	c.metaMu.Lock()
	defer c.metaMu.Unlock()

	c.metadata[key] = value
}
// GetMetadata returns the metadata value stored under key. The boolean
// result is false when nothing is stored under that key.
func (c *Connection) GetMetadata(key string) (interface{}, bool) {
	c.metaMu.RLock()
	stored, exists := c.metadata[key]
	c.metaMu.RUnlock()
	return stored, exists
}
// handleMessage processes one raw inbound message.
//
// The payload is parsed and validated. On either failure an
// "invalid_message" error response is sent back to the client and the message
// is dropped; otherwise the message is routed to the handler. A failure to
// deliver the error response is logged rather than silently discarded.
func (c *Connection) handleMessage(data []byte) {
	// reply sends an error response to the client and logs a delivery
	// failure (for example a full send buffer or a closed connection).
	reply := func(resp interface{}) {
		if sendErr := c.SendJSON(resp); sendErr != nil {
			logger.Warn("[WebSocketSpec] Connection %s: failed to send error response: %v", c.ID, sendErr)
		}
	}

	msg, err := ParseMessage(data)
	if err != nil {
		logger.Error("[WebSocketSpec] Failed to parse message: %v", err)
		// No message ID is available when parsing fails.
		reply(NewErrorResponse("", "invalid_message", "Failed to parse message"))
		return
	}

	if !msg.IsValid() {
		logger.Error("[WebSocketSpec] Invalid message received")
		reply(NewErrorResponse(msg.ID, "invalid_message", "Message validation failed"))
		return
	}

	// Route message to appropriate handler
	c.handler.HandleMessage(c, msg)
}