package mcpserver

import (
	"context"
	"net/http/httptest"
	"slices"
	"testing"
	"time"

	"github.com/modelcontextprotocol/go-sdk/mcp"

	"git.warky.dev/wdevs/amcs/internal/config"
)

// sseTestTimeout bounds every client/server round trip in this file so that a
// stalled SSE or streamable connection fails the test instead of hanging until
// the global `go test` timeout.
const sseTestTimeout = 10 * time.Second

// sseTestConfig returns the MCP configuration shared by the tests in this file.
// The tests below expect an empty ssePath to leave the SSE handler disabled and
// a non-empty one to enable it.
func sseTestConfig(ssePath string) config.MCPConfig {
	return config.MCPConfig{
		ServerName:     "test",
		Version:        "0.0.1",
		SessionTimeout: time.Minute,
		SSEPath:        ssePath,
	}
}

// sseTestContext returns a context that is cancelled after sseTestTimeout or at
// test cleanup, whichever comes first.
//
// Call it AFTER starting any httptest servers: cleanups run last-in-first-out,
// so the context is then cancelled (ending long-lived client streams) before
// httptest.Server.Close, which blocks until outstanding requests finish.
func sseTestContext(t *testing.T) context.Context {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), sseTestTimeout)
	t.Cleanup(cancel)
	return ctx
}

// connectTestSession creates an MCP client called name, connects it over
// transport, and registers the session to be closed at test cleanup. label
// identifies the transport in the failure message only.
func connectTestSession(t *testing.T, ctx context.Context, name, label string, transport mcp.Transport) *mcp.ClientSession {
	t.Helper()
	client := mcp.NewClient(&mcp.Implementation{Name: name, Version: "0.0.1"}, nil)
	session, err := client.Connect(ctx, transport, nil)
	if err != nil {
		t.Fatalf("connect %s client: %v", label, err)
	}
	t.Cleanup(func() {
		// A close error during teardown has no bearing on what the test asserts.
		_ = session.Close()
	})
	return session
}

// sseTestToolNames returns the names of tools in sorted order, so tool lists
// from different transports can be compared regardless of listing order.
func sseTestToolNames(tools []*mcp.Tool) []string {
	names := make([]string, 0, len(tools))
	for _, tool := range tools {
		names = append(names, tool.Name)
	}
	slices.Sort(names)
	return names
}

func TestNewHandlers_SSEDisabledByDefault(t *testing.T) {
	h, err := NewHandlers(sseTestConfig(""), nil, streamableTestToolSet(), nil)
	if err != nil {
		t.Fatalf("NewHandlers() error = %v", err)
	}
	if h.StreamableHTTP == nil {
		t.Fatal("StreamableHTTP handler is nil")
	}
	if h.SSE != nil {
		t.Fatal("SSE handler should be nil when SSEPath is empty")
	}
}

func TestNewHandlers_SSEEnabledWhenPathSet(t *testing.T) {
	h, err := NewHandlers(sseTestConfig("/sse"), nil, streamableTestToolSet(), nil)
	if err != nil {
		t.Fatalf("NewHandlers() error = %v", err)
	}
	if h.StreamableHTTP == nil {
		t.Fatal("StreamableHTTP handler is nil")
	}
	if h.SSE == nil {
		t.Fatal("SSE handler is nil when SSEPath is set")
	}
}

func TestNew_BackwardCompatibility(t *testing.T) {
	handler, err := New(sseTestConfig(""), nil, streamableTestToolSet(), nil)
	if err != nil {
		t.Fatalf("New() error = %v", err)
	}
	if handler == nil {
		t.Fatal("New() returned nil handler")
	}
}

func TestSSEListTools(t *testing.T) {
	h, err := NewHandlers(sseTestConfig("/sse"), nil, streamableTestToolSet(), nil)
	if err != nil {
		t.Fatalf("NewHandlers() error = %v", err)
	}
	// Guard explicitly: serving a nil handler would turn this into an obscure
	// connection failure instead of a clear message.
	if h.SSE == nil {
		t.Fatal("SSE handler is nil when SSEPath is set")
	}

	srv := httptest.NewServer(h.SSE)
	t.Cleanup(srv.Close)

	ctx := sseTestContext(t)
	cs := connectTestSession(t, ctx, "client", "SSE", &mcp.SSEClientTransport{Endpoint: srv.URL})

	result, err := cs.ListTools(ctx, nil)
	if err != nil {
		t.Fatalf("ListTools() error = %v", err)
	}
	if len(result.Tools) == 0 {
		t.Fatal("ListTools() returned no tools")
	}
}

func TestSSEAndStreamableShareTools(t *testing.T) {
	h, err := NewHandlers(sseTestConfig("/sse"), nil, streamableTestToolSet(), nil)
	if err != nil {
		t.Fatalf("NewHandlers() error = %v", err)
	}
	if h.SSE == nil {
		t.Fatal("SSE handler is nil when SSEPath is set")
	}

	sseSrv := httptest.NewServer(h.SSE)
	t.Cleanup(sseSrv.Close)
	streamSrv := httptest.NewServer(h.StreamableHTTP)
	t.Cleanup(streamSrv.Close)

	ctx := sseTestContext(t)
	sseSession := connectTestSession(t, ctx, "sse-client", "SSE", &mcp.SSEClientTransport{Endpoint: sseSrv.URL})
	streamSession := connectTestSession(t, ctx, "stream-client", "StreamableHTTP", &mcp.StreamableClientTransport{Endpoint: streamSrv.URL})

	sseTools, err := sseSession.ListTools(ctx, nil)
	if err != nil {
		t.Fatalf("SSE ListTools() error = %v", err)
	}
	streamTools, err := streamSession.ListTools(ctx, nil)
	if err != nil {
		t.Fatalf("StreamableHTTP ListTools() error = %v", err)
	}

	// Compare the actual tool names, not just the counts: equal counts would
	// still pass if the two transports exposed different tools.
	sseNames := sseTestToolNames(sseTools.Tools)
	streamNames := sseTestToolNames(streamTools.Tools)
	if !slices.Equal(sseNames, streamNames) {
		t.Fatalf("SSE tools = %v, StreamableHTTP tools = %v, want equal", sseNames, streamNames)
	}
}