From 5f48a197e89a08b3efbbe4e98f879843024e4136 Mon Sep 17 00:00:00 2001 From: Hein Date: Sun, 5 Apr 2026 15:57:34 +0200 Subject: [PATCH] feat(mcp): add SSE transport support and related configuration options --- cmd/amcs-cli/cmd/sse.go | 86 +++++++++++++++++++++ configs/config.example.yaml | 1 + configs/dev.yaml | 1 + configs/docker.yaml | 1 + internal/app/app.go | 8 +- internal/config/config.go | 1 + internal/config/loader.go | 1 + internal/config/validate.go | 8 ++ internal/mcpserver/server.go | 39 +++++++++- internal/mcpserver/sse_test.go | 136 +++++++++++++++++++++++++++++++++ 10 files changed, 276 insertions(+), 6 deletions(-) create mode 100644 cmd/amcs-cli/cmd/sse.go create mode 100644 internal/mcpserver/sse_test.go diff --git a/cmd/amcs-cli/cmd/sse.go b/cmd/amcs-cli/cmd/sse.go new file mode 100644 index 0000000..184e5ce --- /dev/null +++ b/cmd/amcs-cli/cmd/sse.go @@ -0,0 +1,86 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/spf13/cobra" +) + +var sseCmd = &cobra.Command{ + Use: "sse", + Short: "Run a stdio MCP bridge backed by a remote AMCS server using SSE transport (widely supported by hosted MCP clients)", + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + + if err := requireServer(); err != nil { + return err + } + + client := mcp.NewClient(&mcp.Implementation{Name: "amcs-cli", Version: "0.0.1"}, nil) + transport := &mcp.SSEClientTransport{ + Endpoint: sseEndpointURL(), + HTTPClient: newHTTPClient(), + } + + connectCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + remote, err := client.Connect(connectCtx, transport, nil) + if err != nil { + return fmt.Errorf("connect to AMCS SSE endpoint: %w", err) + } + defer func() { _ = remote.Close() }() + + tools, err := remote.ListTools(ctx, nil) + if err != nil { + return fmt.Errorf("load remote tools: %w", err) + } + + server := mcp.NewServer(&mcp.Implementation{ + Name: "amcs-cli", + Title: "AMCS CLI Bridge (SSE)", + Version: "0.0.1", + }, nil) + + for _, tool := range tools.Tools { + remoteTool := tool + server.AddTool(&mcp.Tool{ + Name: remoteTool.Name, + Description: remoteTool.Description, + InputSchema: remoteTool.InputSchema, + OutputSchema: remoteTool.OutputSchema, + Annotations: remoteTool.Annotations, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return remote.CallTool(ctx, &mcp.CallToolParams{ + Name: req.Params.Name, + Arguments: req.Params.Arguments, + }) + }) + } + + session, err := server.Connect(ctx, &mcp.StdioTransport{}, nil) + if err != nil { + return fmt.Errorf("start stdio bridge: %w", err) + } + defer func() { _ = session.Close() }() + + <-ctx.Done() + return nil + }, +} + +func sseEndpointURL() string { + base := strings.TrimRight(strings.TrimSpace(cfg.Server), "/") + if strings.HasSuffix(base, "/sse") { + return base + } + return base + "/sse" +} + +func init() { + rootCmd.AddCommand(sseCmd) +} diff --git a/configs/config.example.yaml b/configs/config.example.yaml index 1559190..abd34dc 100644 --- a/configs/config.example.yaml +++ b/configs/config.example.yaml @@ -9,6 +9,7 @@ server: mcp: path: "/mcp" + sse_path: "/sse" server_name: "amcs" transport: "streamable_http" session_timeout: "10m" diff --git a/configs/dev.yaml b/configs/dev.yaml index 09c7740..6fae16d 100644 --- a/configs/dev.yaml +++ b/configs/dev.yaml @@ -9,6 +9,7 @@ server: mcp: path: "/mcp" + sse_path: "/sse" server_name: "amcs" transport: "streamable_http" session_timeout: "10m" diff --git a/configs/docker.yaml b/configs/docker.yaml index 87151ed..bdacc04 100644 --- a/configs/docker.yaml +++ b/configs/docker.yaml @@ -9,6 +9,7 @@ server: mcp: path: "/mcp" + sse_path: "/sse" server_name: "amcs" transport: "streamable_http" session_timeout: "10m" diff --git a/internal/app/app.go b/internal/app/app.go index adac527..ef4bf00 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -193,11 +193,15 @@ func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *st Describe: tools.NewDescribeTool(db, mcpserver.BuildToolCatalog()), } - mcpHandler, err := mcpserver.New(cfg.MCP, logger, toolSet, activeProjects.Clear) + mcpHandlers, err := mcpserver.NewHandlers(cfg.MCP, logger, toolSet, activeProjects.Clear) if err != nil { return nil, fmt.Errorf("build mcp handler: %w", err) } - mux.Handle(cfg.MCP.Path, authMiddleware(mcpHandler)) + mux.Handle(cfg.MCP.Path, authMiddleware(mcpHandlers.StreamableHTTP)) + if mcpHandlers.SSE != nil { + mux.Handle(cfg.MCP.SSEPath, authMiddleware(mcpHandlers.SSE)) + logger.Info("SSE transport enabled", slog.String("sse_path", cfg.MCP.SSEPath)) + } mux.Handle("/files", authMiddleware(fileHandler(filesTool))) mux.Handle("/files/{id}", authMiddleware(fileHandler(filesTool))) if oauthEnabled { diff --git a/internal/config/config.go b/internal/config/config.go index a6081b3..46f8daa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -32,6 +32,7 @@ type ServerConfig struct { type MCPConfig struct { Path string `yaml:"path"` + SSEPath string `yaml:"sse_path"` ServerName string `yaml:"server_name"` Version string `yaml:"version"` Transport string `yaml:"transport"` diff --git a/internal/config/loader.go b/internal/config/loader.go index 95d849b..d6b0a88 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -58,6 +58,7 @@ func defaultConfig() Config { }, MCP: MCPConfig{ Path: "/mcp", + SSEPath: "/sse", ServerName: "amcs", Version: info.Version, Transport: "streamable_http", diff --git a/internal/config/validate.go b/internal/config/validate.go index ccb2cc9..af40b37 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -33,6 +33,14 @@ func (c Config) Validate() error { if strings.TrimSpace(c.MCP.Path) == "" { return fmt.Errorf("invalid config: mcp.path is required") } + if c.MCP.SSEPath != "" { + if strings.TrimSpace(c.MCP.SSEPath) == "" { + return fmt.Errorf("invalid config: mcp.sse_path must not be blank whitespace") + } + if c.MCP.SSEPath == c.MCP.Path { + return fmt.Errorf("invalid config: mcp.sse_path %q must differ from mcp.path", c.MCP.SSEPath) + } + } if c.MCP.SessionTimeout <= 0 { return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero") } diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index 9d29515..212e931 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -46,7 +46,28 @@ type ToolSet struct { Describe *tools.DescribeTool } +// Handlers groups the HTTP handlers produced for an MCP server instance. +type Handlers struct { + // StreamableHTTP is the primary MCP handler (always non-nil). + StreamableHTTP http.Handler + // SSE is the SSE transport handler; nil when SSEPath is empty. + // SSE is the de facto transport for MCP over the internet and is required by most hosted MCP clients. + SSE http.Handler +} + +// New builds the StreamableHTTP MCP handler. It is a convenience wrapper +// around NewHandlers for callers that only need the primary transport. func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionClosed func(string)) (http.Handler, error) { + h, err := NewHandlers(cfg, logger, toolSet, onSessionClosed) + if err != nil { + return nil, err + } + return h.StreamableHTTP, nil +} + +// NewHandlers builds MCP HTTP handlers for both transports. +// SSE is nil when cfg.SSEPath is empty. +func NewHandlers(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionClosed func(string)) (Handlers, error) { instructions := cfg.Instructions if instructions == "" { instructions = string(amcsllm.MemoryInstructions) @@ -77,7 +98,7 @@ func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionCl registerDescribeTools, } { if err := register(server, logger, toolSet); err != nil { - return nil, err + return Handlers{}, err } } @@ -89,9 +110,19 @@ func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionCl opts.EventStore = newCleanupEventStore(mcp.NewMemoryEventStore(nil), onSessionClosed) } - return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { - return server - }, opts), nil + h := Handlers{ + StreamableHTTP: mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, opts), + } + + if strings.TrimSpace(cfg.SSEPath) != "" { + h.SSE = mcp.NewSSEHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + } + + return h, nil } // buildServerIcons returns icon definitions referencing the server's own /images/icon.png endpoint. diff --git a/internal/mcpserver/sse_test.go b/internal/mcpserver/sse_test.go new file mode 100644 index 0000000..987036c --- /dev/null +++ b/internal/mcpserver/sse_test.go @@ -0,0 +1,136 @@ +package mcpserver + +import ( + "context" + "net/http/httptest" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +func TestNewHandlers_SSEDisabledByDefault(t *testing.T) { + h, err := NewHandlers(config.MCPConfig{ + ServerName: "test", + Version: "0.0.1", + SessionTimeout: time.Minute, + }, nil, streamableTestToolSet(), nil) + if err != nil { + t.Fatalf("NewHandlers() error = %v", err) + } + if h.StreamableHTTP == nil { + t.Fatal("StreamableHTTP handler is nil") + } + if h.SSE != nil { + t.Fatal("SSE handler should be nil when SSEPath is empty") + } +} + +func TestNewHandlers_SSEEnabledWhenPathSet(t *testing.T) { + h, err := NewHandlers(config.MCPConfig{ + ServerName: "test", + Version: "0.0.1", + SessionTimeout: time.Minute, + SSEPath: "/sse", + }, nil, streamableTestToolSet(), nil) + if err != nil { + t.Fatalf("NewHandlers() error = %v", err) + } + if h.StreamableHTTP == nil { + t.Fatal("StreamableHTTP handler is nil") + } + if h.SSE == nil { + t.Fatal("SSE handler is nil when SSEPath is set") + } +} + +func TestNew_BackwardCompatibility(t *testing.T) { + handler, err := New(config.MCPConfig{ + ServerName: "test", + Version: "0.0.1", + SessionTimeout: time.Minute, + }, nil, streamableTestToolSet(), nil) + if err != nil { + t.Fatalf("New() error = %v", err) + } + if handler == nil { + t.Fatal("New() returned nil handler") + } +} + +func TestSSEListTools(t *testing.T) { + h, err := NewHandlers(config.MCPConfig{ + ServerName: "test", + Version: "0.0.1", + SessionTimeout: time.Minute, + SSEPath: "/sse", + }, nil, streamableTestToolSet(), nil) + if err != nil { + t.Fatalf("NewHandlers() error = %v", err) + } + + srv := httptest.NewServer(h.SSE) + t.Cleanup(srv.Close) + + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil) + cs, err := client.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: srv.URL}, nil) + if err != nil { + t.Fatalf("connect SSE client: %v", err) + } + t.Cleanup(func() { _ = cs.Close() }) + + result, err := cs.ListTools(context.Background(), nil) + if err != nil { + t.Fatalf("ListTools() error = %v", err) + } + if len(result.Tools) == 0 { + t.Fatal("ListTools() returned no tools") + } +} + +func TestSSEAndStreamableShareTools(t *testing.T) { + h, err := NewHandlers(config.MCPConfig{ + ServerName: "test", + Version: "0.0.1", + SessionTimeout: time.Minute, + SSEPath: "/sse", + }, nil, streamableTestToolSet(), nil) + if err != nil { + t.Fatalf("NewHandlers() error = %v", err) + } + + sseSrv := httptest.NewServer(h.SSE) + t.Cleanup(sseSrv.Close) + + streamSrv := httptest.NewServer(h.StreamableHTTP) + t.Cleanup(streamSrv.Close) + + sseClient := mcp.NewClient(&mcp.Implementation{Name: "sse-client", Version: "0.0.1"}, nil) + sseSession, err := sseClient.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: sseSrv.URL}, nil) + if err != nil { + t.Fatalf("connect SSE client: %v", err) + } + t.Cleanup(func() { _ = sseSession.Close() }) + + streamClient := mcp.NewClient(&mcp.Implementation{Name: "stream-client", Version: "0.0.1"}, nil) + streamSession, err := streamClient.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: streamSrv.URL}, nil) + if err != nil { + t.Fatalf("connect StreamableHTTP client: %v", err) + } + t.Cleanup(func() { _ = streamSession.Close() }) + + sseTools, err := sseSession.ListTools(context.Background(), nil) + if err != nil { + t.Fatalf("SSE ListTools() error = %v", err) + } + streamTools, err := streamSession.ListTools(context.Background(), nil) + if err != nil { + t.Fatalf("StreamableHTTP ListTools() error = %v", err) + } + + if len(sseTools.Tools) != len(streamTools.Tools) { + t.Fatalf("SSE tool count = %d, StreamableHTTP tool count = %d, want equal", len(sseTools.Tools), len(streamTools.Tools)) + } +}