137 lines
3.7 KiB
Go
137 lines
3.7 KiB
Go
package mcpserver
|
|
|
|
import (
|
|
"context"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/config"
|
|
)
|
|
|
|
func TestNewHandlers_SSEDisabledByDefault(t *testing.T) {
|
|
h, err := NewHandlers(config.MCPConfig{
|
|
ServerName: "test",
|
|
Version: "0.0.1",
|
|
SessionTimeout: time.Minute,
|
|
}, nil, streamableTestToolSet(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewHandlers() error = %v", err)
|
|
}
|
|
if h.StreamableHTTP == nil {
|
|
t.Fatal("StreamableHTTP handler is nil")
|
|
}
|
|
if h.SSE != nil {
|
|
t.Fatal("SSE handler should be nil when SSEPath is empty")
|
|
}
|
|
}
|
|
|
|
func TestNewHandlers_SSEEnabledWhenPathSet(t *testing.T) {
|
|
h, err := NewHandlers(config.MCPConfig{
|
|
ServerName: "test",
|
|
Version: "0.0.1",
|
|
SessionTimeout: time.Minute,
|
|
SSEPath: "/sse",
|
|
}, nil, streamableTestToolSet(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewHandlers() error = %v", err)
|
|
}
|
|
if h.StreamableHTTP == nil {
|
|
t.Fatal("StreamableHTTP handler is nil")
|
|
}
|
|
if h.SSE == nil {
|
|
t.Fatal("SSE handler is nil when SSEPath is set")
|
|
}
|
|
}
|
|
|
|
func TestNew_BackwardCompatibility(t *testing.T) {
|
|
handler, err := New(config.MCPConfig{
|
|
ServerName: "test",
|
|
Version: "0.0.1",
|
|
SessionTimeout: time.Minute,
|
|
}, nil, streamableTestToolSet(), nil)
|
|
if err != nil {
|
|
t.Fatalf("New() error = %v", err)
|
|
}
|
|
if handler == nil {
|
|
t.Fatal("New() returned nil handler")
|
|
}
|
|
}
|
|
|
|
func TestSSEListTools(t *testing.T) {
|
|
h, err := NewHandlers(config.MCPConfig{
|
|
ServerName: "test",
|
|
Version: "0.0.1",
|
|
SessionTimeout: time.Minute,
|
|
SSEPath: "/sse",
|
|
}, nil, streamableTestToolSet(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewHandlers() error = %v", err)
|
|
}
|
|
|
|
srv := httptest.NewServer(h.SSE)
|
|
t.Cleanup(srv.Close)
|
|
|
|
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
|
|
cs, err := client.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: srv.URL}, nil)
|
|
if err != nil {
|
|
t.Fatalf("connect SSE client: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = cs.Close() })
|
|
|
|
result, err := cs.ListTools(context.Background(), nil)
|
|
if err != nil {
|
|
t.Fatalf("ListTools() error = %v", err)
|
|
}
|
|
if len(result.Tools) == 0 {
|
|
t.Fatal("ListTools() returned no tools")
|
|
}
|
|
}
|
|
|
|
func TestSSEAndStreamableShareTools(t *testing.T) {
|
|
h, err := NewHandlers(config.MCPConfig{
|
|
ServerName: "test",
|
|
Version: "0.0.1",
|
|
SessionTimeout: time.Minute,
|
|
SSEPath: "/sse",
|
|
}, nil, streamableTestToolSet(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewHandlers() error = %v", err)
|
|
}
|
|
|
|
sseSrv := httptest.NewServer(h.SSE)
|
|
t.Cleanup(sseSrv.Close)
|
|
|
|
streamSrv := httptest.NewServer(h.StreamableHTTP)
|
|
t.Cleanup(streamSrv.Close)
|
|
|
|
sseClient := mcp.NewClient(&mcp.Implementation{Name: "sse-client", Version: "0.0.1"}, nil)
|
|
sseSession, err := sseClient.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: sseSrv.URL}, nil)
|
|
if err != nil {
|
|
t.Fatalf("connect SSE client: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = sseSession.Close() })
|
|
|
|
streamClient := mcp.NewClient(&mcp.Implementation{Name: "stream-client", Version: "0.0.1"}, nil)
|
|
streamSession, err := streamClient.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: streamSrv.URL}, nil)
|
|
if err != nil {
|
|
t.Fatalf("connect StreamableHTTP client: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = streamSession.Close() })
|
|
|
|
sseTools, err := sseSession.ListTools(context.Background(), nil)
|
|
if err != nil {
|
|
t.Fatalf("SSE ListTools() error = %v", err)
|
|
}
|
|
streamTools, err := streamSession.ListTools(context.Background(), nil)
|
|
if err != nil {
|
|
t.Fatalf("StreamableHTTP ListTools() error = %v", err)
|
|
}
|
|
|
|
if len(sseTools.Tools) != len(streamTools.Tools) {
|
|
t.Fatalf("SSE tool count = %d, StreamableHTTP tool count = %d, want equal", len(sseTools.Tools), len(streamTools.Tools))
|
|
}
|
|
}
|