feat(mcp): add SSE transport support and related configuration options
Some checks failed
CI / build-and-test (push) Failing after -30m37s
Some checks failed
CI / build-and-test (push) Failing after -30m37s
This commit is contained in:
86
cmd/amcs-cli/cmd/sse.go
Normal file
86
cmd/amcs-cli/cmd/sse.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var sseCmd = &cobra.Command{
|
||||||
|
Use: "sse",
|
||||||
|
Short: "Run a stdio MCP bridge backed by a remote AMCS server using SSE transport (widely supported by hosted MCP clients)",
|
||||||
|
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||||
|
ctx := cmd.Context()
|
||||||
|
|
||||||
|
if err := requireServer(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := mcp.NewClient(&mcp.Implementation{Name: "amcs-cli", Version: "0.0.1"}, nil)
|
||||||
|
transport := &mcp.SSEClientTransport{
|
||||||
|
Endpoint: sseEndpointURL(),
|
||||||
|
HTTPClient: newHTTPClient(),
|
||||||
|
}
|
||||||
|
|
||||||
|
connectCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
remote, err := client.Connect(connectCtx, transport, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to AMCS SSE endpoint: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = remote.Close() }()
|
||||||
|
|
||||||
|
tools, err := remote.ListTools(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load remote tools: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := mcp.NewServer(&mcp.Implementation{
|
||||||
|
Name: "amcs-cli",
|
||||||
|
Title: "AMCS CLI Bridge (SSE)",
|
||||||
|
Version: "0.0.1",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
for _, tool := range tools.Tools {
|
||||||
|
remoteTool := tool
|
||||||
|
server.AddTool(&mcp.Tool{
|
||||||
|
Name: remoteTool.Name,
|
||||||
|
Description: remoteTool.Description,
|
||||||
|
InputSchema: remoteTool.InputSchema,
|
||||||
|
OutputSchema: remoteTool.OutputSchema,
|
||||||
|
Annotations: remoteTool.Annotations,
|
||||||
|
}, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
return remote.CallTool(ctx, &mcp.CallToolParams{
|
||||||
|
Name: req.Params.Name,
|
||||||
|
Arguments: req.Params.Arguments,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := server.Connect(ctx, &mcp.StdioTransport{}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("start stdio bridge: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = session.Close() }()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func sseEndpointURL() string {
|
||||||
|
base := strings.TrimRight(strings.TrimSpace(cfg.Server), "/")
|
||||||
|
if strings.HasSuffix(base, "/sse") {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return base + "/sse"
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(sseCmd)
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ server:
|
|||||||
|
|
||||||
mcp:
|
mcp:
|
||||||
path: "/mcp"
|
path: "/mcp"
|
||||||
|
sse_path: "/sse"
|
||||||
server_name: "amcs"
|
server_name: "amcs"
|
||||||
transport: "streamable_http"
|
transport: "streamable_http"
|
||||||
session_timeout: "10m"
|
session_timeout: "10m"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ server:
|
|||||||
|
|
||||||
mcp:
|
mcp:
|
||||||
path: "/mcp"
|
path: "/mcp"
|
||||||
|
sse_path: "/sse"
|
||||||
server_name: "amcs"
|
server_name: "amcs"
|
||||||
transport: "streamable_http"
|
transport: "streamable_http"
|
||||||
session_timeout: "10m"
|
session_timeout: "10m"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ server:
|
|||||||
|
|
||||||
mcp:
|
mcp:
|
||||||
path: "/mcp"
|
path: "/mcp"
|
||||||
|
sse_path: "/sse"
|
||||||
server_name: "amcs"
|
server_name: "amcs"
|
||||||
transport: "streamable_http"
|
transport: "streamable_http"
|
||||||
session_timeout: "10m"
|
session_timeout: "10m"
|
||||||
|
|||||||
@@ -193,11 +193,15 @@ func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *st
|
|||||||
Describe: tools.NewDescribeTool(db, mcpserver.BuildToolCatalog()),
|
Describe: tools.NewDescribeTool(db, mcpserver.BuildToolCatalog()),
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpHandler, err := mcpserver.New(cfg.MCP, logger, toolSet, activeProjects.Clear)
|
mcpHandlers, err := mcpserver.NewHandlers(cfg.MCP, logger, toolSet, activeProjects.Clear)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("build mcp handler: %w", err)
|
return nil, fmt.Errorf("build mcp handler: %w", err)
|
||||||
}
|
}
|
||||||
mux.Handle(cfg.MCP.Path, authMiddleware(mcpHandler))
|
mux.Handle(cfg.MCP.Path, authMiddleware(mcpHandlers.StreamableHTTP))
|
||||||
|
if mcpHandlers.SSE != nil {
|
||||||
|
mux.Handle(cfg.MCP.SSEPath, authMiddleware(mcpHandlers.SSE))
|
||||||
|
logger.Info("SSE transport enabled", slog.String("sse_path", cfg.MCP.SSEPath))
|
||||||
|
}
|
||||||
mux.Handle("/files", authMiddleware(fileHandler(filesTool)))
|
mux.Handle("/files", authMiddleware(fileHandler(filesTool)))
|
||||||
mux.Handle("/files/{id}", authMiddleware(fileHandler(filesTool)))
|
mux.Handle("/files/{id}", authMiddleware(fileHandler(filesTool)))
|
||||||
if oauthEnabled {
|
if oauthEnabled {
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ type ServerConfig struct {
|
|||||||
|
|
||||||
type MCPConfig struct {
|
type MCPConfig struct {
|
||||||
Path string `yaml:"path"`
|
Path string `yaml:"path"`
|
||||||
|
SSEPath string `yaml:"sse_path"`
|
||||||
ServerName string `yaml:"server_name"`
|
ServerName string `yaml:"server_name"`
|
||||||
Version string `yaml:"version"`
|
Version string `yaml:"version"`
|
||||||
Transport string `yaml:"transport"`
|
Transport string `yaml:"transport"`
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ func defaultConfig() Config {
|
|||||||
},
|
},
|
||||||
MCP: MCPConfig{
|
MCP: MCPConfig{
|
||||||
Path: "/mcp",
|
Path: "/mcp",
|
||||||
|
SSEPath: "/sse",
|
||||||
ServerName: "amcs",
|
ServerName: "amcs",
|
||||||
Version: info.Version,
|
Version: info.Version,
|
||||||
Transport: "streamable_http",
|
Transport: "streamable_http",
|
||||||
|
|||||||
@@ -33,6 +33,14 @@ func (c Config) Validate() error {
|
|||||||
if strings.TrimSpace(c.MCP.Path) == "" {
|
if strings.TrimSpace(c.MCP.Path) == "" {
|
||||||
return fmt.Errorf("invalid config: mcp.path is required")
|
return fmt.Errorf("invalid config: mcp.path is required")
|
||||||
}
|
}
|
||||||
|
if c.MCP.SSEPath != "" {
|
||||||
|
if strings.TrimSpace(c.MCP.SSEPath) == "" {
|
||||||
|
return fmt.Errorf("invalid config: mcp.sse_path must not be blank whitespace")
|
||||||
|
}
|
||||||
|
if c.MCP.SSEPath == c.MCP.Path {
|
||||||
|
return fmt.Errorf("invalid config: mcp.sse_path %q must differ from mcp.path", c.MCP.SSEPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
if c.MCP.SessionTimeout <= 0 {
|
if c.MCP.SessionTimeout <= 0 {
|
||||||
return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero")
|
return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,28 @@ type ToolSet struct {
|
|||||||
Describe *tools.DescribeTool
|
Describe *tools.DescribeTool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handlers groups the HTTP handlers produced for an MCP server instance.
|
||||||
|
type Handlers struct {
|
||||||
|
// StreamableHTTP is the primary MCP handler (always non-nil).
|
||||||
|
StreamableHTTP http.Handler
|
||||||
|
// SSE is the SSE transport handler; nil when SSEPath is empty.
|
||||||
|
// SSE is the de facto transport for MCP over the internet and is required by most hosted MCP clients.
|
||||||
|
SSE http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// New builds the StreamableHTTP MCP handler. It is a convenience wrapper
|
||||||
|
// around NewHandlers for callers that only need the primary transport.
|
||||||
func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionClosed func(string)) (http.Handler, error) {
|
func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionClosed func(string)) (http.Handler, error) {
|
||||||
|
h, err := NewHandlers(cfg, logger, toolSet, onSessionClosed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h.StreamableHTTP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandlers builds MCP HTTP handlers for both transports.
|
||||||
|
// SSE is nil when cfg.SSEPath is empty.
|
||||||
|
func NewHandlers(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionClosed func(string)) (Handlers, error) {
|
||||||
instructions := cfg.Instructions
|
instructions := cfg.Instructions
|
||||||
if instructions == "" {
|
if instructions == "" {
|
||||||
instructions = string(amcsllm.MemoryInstructions)
|
instructions = string(amcsllm.MemoryInstructions)
|
||||||
@@ -77,7 +98,7 @@ func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionCl
|
|||||||
registerDescribeTools,
|
registerDescribeTools,
|
||||||
} {
|
} {
|
||||||
if err := register(server, logger, toolSet); err != nil {
|
if err := register(server, logger, toolSet); err != nil {
|
||||||
return nil, err
|
return Handlers{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,9 +110,19 @@ func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionCl
|
|||||||
opts.EventStore = newCleanupEventStore(mcp.NewMemoryEventStore(nil), onSessionClosed)
|
opts.EventStore = newCleanupEventStore(mcp.NewMemoryEventStore(nil), onSessionClosed)
|
||||||
}
|
}
|
||||||
|
|
||||||
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
h := Handlers{
|
||||||
|
StreamableHTTP: mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
||||||
return server
|
return server
|
||||||
}, opts), nil
|
}, opts),
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(cfg.SSEPath) != "" {
|
||||||
|
h.SSE = mcp.NewSSEHandler(func(*http.Request) *mcp.Server {
|
||||||
|
return server
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildServerIcons returns icon definitions referencing the server's own /images/icon.png endpoint.
|
// buildServerIcons returns icon definitions referencing the server's own /images/icon.png endpoint.
|
||||||
|
|||||||
136
internal/mcpserver/sse_test.go
Normal file
136
internal/mcpserver/sse_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package mcpserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewHandlers_SSEDisabledByDefault(t *testing.T) {
|
||||||
|
h, err := NewHandlers(config.MCPConfig{
|
||||||
|
ServerName: "test",
|
||||||
|
Version: "0.0.1",
|
||||||
|
SessionTimeout: time.Minute,
|
||||||
|
}, nil, streamableTestToolSet(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewHandlers() error = %v", err)
|
||||||
|
}
|
||||||
|
if h.StreamableHTTP == nil {
|
||||||
|
t.Fatal("StreamableHTTP handler is nil")
|
||||||
|
}
|
||||||
|
if h.SSE != nil {
|
||||||
|
t.Fatal("SSE handler should be nil when SSEPath is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHandlers_SSEEnabledWhenPathSet(t *testing.T) {
|
||||||
|
h, err := NewHandlers(config.MCPConfig{
|
||||||
|
ServerName: "test",
|
||||||
|
Version: "0.0.1",
|
||||||
|
SessionTimeout: time.Minute,
|
||||||
|
SSEPath: "/sse",
|
||||||
|
}, nil, streamableTestToolSet(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewHandlers() error = %v", err)
|
||||||
|
}
|
||||||
|
if h.StreamableHTTP == nil {
|
||||||
|
t.Fatal("StreamableHTTP handler is nil")
|
||||||
|
}
|
||||||
|
if h.SSE == nil {
|
||||||
|
t.Fatal("SSE handler is nil when SSEPath is set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_BackwardCompatibility(t *testing.T) {
|
||||||
|
handler, err := New(config.MCPConfig{
|
||||||
|
ServerName: "test",
|
||||||
|
Version: "0.0.1",
|
||||||
|
SessionTimeout: time.Minute,
|
||||||
|
}, nil, streamableTestToolSet(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
if handler == nil {
|
||||||
|
t.Fatal("New() returned nil handler")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSEListTools(t *testing.T) {
|
||||||
|
h, err := NewHandlers(config.MCPConfig{
|
||||||
|
ServerName: "test",
|
||||||
|
Version: "0.0.1",
|
||||||
|
SessionTimeout: time.Minute,
|
||||||
|
SSEPath: "/sse",
|
||||||
|
}, nil, streamableTestToolSet(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewHandlers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(h.SSE)
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
|
||||||
|
cs, err := client.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: srv.URL}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect SSE client: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = cs.Close() })
|
||||||
|
|
||||||
|
result, err := cs.ListTools(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTools() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(result.Tools) == 0 {
|
||||||
|
t.Fatal("ListTools() returned no tools")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSEAndStreamableShareTools(t *testing.T) {
|
||||||
|
h, err := NewHandlers(config.MCPConfig{
|
||||||
|
ServerName: "test",
|
||||||
|
Version: "0.0.1",
|
||||||
|
SessionTimeout: time.Minute,
|
||||||
|
SSEPath: "/sse",
|
||||||
|
}, nil, streamableTestToolSet(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewHandlers() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sseSrv := httptest.NewServer(h.SSE)
|
||||||
|
t.Cleanup(sseSrv.Close)
|
||||||
|
|
||||||
|
streamSrv := httptest.NewServer(h.StreamableHTTP)
|
||||||
|
t.Cleanup(streamSrv.Close)
|
||||||
|
|
||||||
|
sseClient := mcp.NewClient(&mcp.Implementation{Name: "sse-client", Version: "0.0.1"}, nil)
|
||||||
|
sseSession, err := sseClient.Connect(context.Background(), &mcp.SSEClientTransport{Endpoint: sseSrv.URL}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect SSE client: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = sseSession.Close() })
|
||||||
|
|
||||||
|
streamClient := mcp.NewClient(&mcp.Implementation{Name: "stream-client", Version: "0.0.1"}, nil)
|
||||||
|
streamSession, err := streamClient.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: streamSrv.URL}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect StreamableHTTP client: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = streamSession.Close() })
|
||||||
|
|
||||||
|
sseTools, err := sseSession.ListTools(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SSE ListTools() error = %v", err)
|
||||||
|
}
|
||||||
|
streamTools, err := streamSession.ListTools(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StreamableHTTP ListTools() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sseTools.Tools) != len(streamTools.Tools) {
|
||||||
|
t.Fatalf("SSE tool count = %d, StreamableHTTP tool count = %d, want equal", len(sseTools.Tools), len(streamTools.Tools))
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user