test(tools): add unit tests for error handling functions
* Implement tests for error functions like errRequiredField, errInvalidField, and errEntityNotFound. * Ensure proper metadata is returned for various error scenarios. * Validate error handling in CRM, Files, and other tools. * Introduce tests for parsing stored file IDs and UUIDs. * Enhance coverage for helper functions related to project resolution and session management.
This commit is contained in:
154
internal/mcpserver/error_integration_test.go
Normal file
154
internal/mcpserver/error_integration_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/tools"
|
||||
)
|
||||
|
||||
func TestToolValidationErrorsRoundTripAsStructuredJSONRPC(t *testing.T) {
|
||||
server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)
|
||||
|
||||
projects := tools.NewProjectsTool(nil, session.NewActiveProjects())
|
||||
contextTool := tools.NewContextTool(nil, nil, toolsSearchConfig(), session.NewActiveProjects())
|
||||
|
||||
if err := addTool(server, nil, &mcp.Tool{
|
||||
Name: "create_project",
|
||||
Description: "Create a named project container for thoughts.",
|
||||
}, projects.Create); err != nil {
|
||||
t.Fatalf("add create_project tool: %v", err)
|
||||
}
|
||||
if err := addTool(server, nil, &mcp.Tool{
|
||||
Name: "get_project_context",
|
||||
Description: "Get recent and semantic context for a project.",
|
||||
}, contextTool.Handle); err != nil {
|
||||
t.Fatalf("add get_project_context tool: %v", err)
|
||||
}
|
||||
|
||||
ct, st := mcp.NewInMemoryTransports()
|
||||
_, err := server.Connect(context.Background(), st, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("connect server: %v", err)
|
||||
}
|
||||
|
||||
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
|
||||
cs, err := client.Connect(context.Background(), ct, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("connect client: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = cs.Close()
|
||||
}()
|
||||
|
||||
t.Run("required_field", func(t *testing.T) {
|
||||
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
|
||||
Name: "create_project",
|
||||
Arguments: map[string]any{"name": ""},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("CallTool(create_project) error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireWireError(t, err)
|
||||
if rpcErr.Code != jsonrpc.CodeInvalidParams {
|
||||
t.Fatalf("create_project code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
|
||||
}
|
||||
if data.Type != mcperrors.TypeInvalidInput {
|
||||
t.Fatalf("create_project data.type = %q, want %q", data.Type, mcperrors.TypeInvalidInput)
|
||||
}
|
||||
if data.Field != "name" {
|
||||
t.Fatalf("create_project data.field = %q, want %q", data.Field, "name")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("schema_required_field", func(t *testing.T) {
|
||||
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
|
||||
Name: "create_project",
|
||||
Arguments: map[string]any{},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("CallTool(create_project missing field) error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireWireError(t, err)
|
||||
if rpcErr.Code != jsonrpc.CodeInvalidParams {
|
||||
t.Fatalf("create_project schema code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
|
||||
}
|
||||
if data.Type != mcperrors.TypeInvalidArguments {
|
||||
t.Fatalf("create_project schema data.type = %q, want %q", data.Type, mcperrors.TypeInvalidArguments)
|
||||
}
|
||||
if data.Field != "name" {
|
||||
t.Fatalf("create_project schema data.field = %q, want %q", data.Field, "name")
|
||||
}
|
||||
if data.Detail == "" {
|
||||
t.Fatal("create_project schema data.detail = empty, want validation detail")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("project_required", func(t *testing.T) {
|
||||
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
|
||||
Name: "get_project_context",
|
||||
Arguments: map[string]any{},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("CallTool(get_project_context) error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireWireError(t, err)
|
||||
if rpcErr.Code != mcperrors.CodeProjectRequired {
|
||||
t.Fatalf("get_project_context code = %d, want %d", rpcErr.Code, mcperrors.CodeProjectRequired)
|
||||
}
|
||||
if data.Type != mcperrors.TypeProjectRequired {
|
||||
t.Fatalf("get_project_context data.type = %q, want %q", data.Type, mcperrors.TypeProjectRequired)
|
||||
}
|
||||
if data.Field != "project" {
|
||||
t.Fatalf("get_project_context data.field = %q, want %q", data.Field, "project")
|
||||
}
|
||||
if data.Hint == "" {
|
||||
t.Fatal("get_project_context data.hint = empty, want guidance")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type wireErrorData struct {
|
||||
Type string `json:"type"`
|
||||
Field string `json:"field,omitempty"`
|
||||
Fields []string `json:"fields,omitempty"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
Hint string `json:"hint,omitempty"`
|
||||
}
|
||||
|
||||
func requireWireError(t *testing.T, err error) (*jsonrpc.Error, wireErrorData) {
|
||||
t.Helper()
|
||||
|
||||
var rpcErr *jsonrpc.Error
|
||||
if !errors.As(err, &rpcErr) {
|
||||
t.Fatalf("error type = %T, want *jsonrpc.Error", err)
|
||||
}
|
||||
|
||||
var data wireErrorData
|
||||
if len(rpcErr.Data) > 0 {
|
||||
if unmarshalErr := json.Unmarshal(rpcErr.Data, &data); unmarshalErr != nil {
|
||||
t.Fatalf("unmarshal wire error data: %v", unmarshalErr)
|
||||
}
|
||||
}
|
||||
|
||||
return rpcErr, data
|
||||
}
|
||||
|
||||
func toolsSearchConfig() config.SearchConfig {
|
||||
return config.SearchConfig{
|
||||
DefaultLimit: 10,
|
||||
MaxLimit: 50,
|
||||
DefaultThreshold: 0.7,
|
||||
}
|
||||
}
|
||||
42
internal/mcpserver/eventstore.go
Normal file
42
internal/mcpserver/eventstore.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
type cleanupEventStore struct {
|
||||
base mcp.EventStore
|
||||
onSessionClosed func(string)
|
||||
}
|
||||
|
||||
func newCleanupEventStore(base mcp.EventStore, onSessionClosed func(string)) mcp.EventStore {
|
||||
return &cleanupEventStore{
|
||||
base: base,
|
||||
onSessionClosed: onSessionClosed,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *cleanupEventStore) Open(ctx context.Context, sessionID, streamID string) error {
|
||||
return s.base.Open(ctx, sessionID, streamID)
|
||||
}
|
||||
|
||||
func (s *cleanupEventStore) Append(ctx context.Context, sessionID, streamID string, data []byte) error {
|
||||
return s.base.Append(ctx, sessionID, streamID, data)
|
||||
}
|
||||
|
||||
func (s *cleanupEventStore) After(ctx context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] {
|
||||
return s.base.After(ctx, sessionID, streamID, index)
|
||||
}
|
||||
|
||||
func (s *cleanupEventStore) SessionClosed(ctx context.Context, sessionID string) error {
|
||||
if err := s.base.SessionClosed(ctx, sessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.onSessionClosed != nil {
|
||||
s.onSessionClosed(sessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
30
internal/mcpserver/eventstore_test.go
Normal file
30
internal/mcpserver/eventstore_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
)
|
||||
|
||||
func TestCleanupEventStoreSessionClosedClearsActiveProject(t *testing.T) {
|
||||
activeProjects := session.NewActiveProjects()
|
||||
activeProjects.Set("session-1", uuid.New())
|
||||
|
||||
store := newCleanupEventStore(mcp.NewMemoryEventStore(nil), activeProjects.Clear)
|
||||
|
||||
if _, ok := activeProjects.Get("session-1"); !ok {
|
||||
t.Fatal("active project missing before SessionClosed")
|
||||
}
|
||||
|
||||
if err := store.SessionClosed(context.Background(), "session-1"); err != nil {
|
||||
t.Fatalf("SessionClosed() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := activeProjects.Get("session-1"); ok {
|
||||
t.Fatal("active project still present after SessionClosed")
|
||||
}
|
||||
}
|
||||
@@ -6,15 +6,31 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
)
|
||||
|
||||
const maxLoggedArgBytes = 512
|
||||
|
||||
var sensitiveToolArgKeys = map[string]struct{}{
|
||||
"client_secret": {},
|
||||
"content": {},
|
||||
"content_base64": {},
|
||||
"content_path": {},
|
||||
"email": {},
|
||||
"notes": {},
|
||||
"phone": {},
|
||||
"value": {},
|
||||
}
|
||||
|
||||
var toolSchemaOptions = &jsonschema.ForOptions{
|
||||
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
|
||||
reflect.TypeFor[uuid.UUID](): {
|
||||
@@ -24,19 +40,92 @@ var toolSchemaOptions = &jsonschema.ForOptions{
|
||||
},
|
||||
}
|
||||
|
||||
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
|
||||
var (
|
||||
missingPropertyPattern = regexp.MustCompile(`missing properties: \["([^"]+)"\]`)
|
||||
propertyPathPattern = regexp.MustCompile(`validating /properties/([^:]+):`)
|
||||
structFieldPattern = regexp.MustCompile(`\.([A-Za-z0-9_]+) of type`)
|
||||
)
|
||||
|
||||
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) error {
|
||||
if err := setToolSchemas[In, Out](tool); err != nil {
|
||||
panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
|
||||
return fmt.Errorf("configure MCP tool %q schemas: %w", tool.Name, err)
|
||||
}
|
||||
mcp.AddTool(server, tool, logToolCall(logger, tool.Name, handler))
|
||||
|
||||
inputResolved, err := resolveToolSchema(tool.InputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve MCP tool %q input schema: %w", tool.Name, err)
|
||||
}
|
||||
outputResolved, err := resolveToolSchema(tool.OutputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve MCP tool %q output schema: %w", tool.Name, err)
|
||||
}
|
||||
|
||||
server.AddTool(tool, logToolCall(logger, tool.Name, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var input json.RawMessage
|
||||
if req != nil && req.Params != nil && req.Params.Arguments != nil {
|
||||
input = req.Params.Arguments
|
||||
}
|
||||
|
||||
input, err = applyToolSchema(input, inputResolved)
|
||||
if err != nil {
|
||||
return nil, invalidArgumentsError(err)
|
||||
}
|
||||
|
||||
var in In
|
||||
if input != nil {
|
||||
if err := json.Unmarshal(input, &in); err != nil {
|
||||
return nil, invalidArgumentsError(err)
|
||||
}
|
||||
}
|
||||
|
||||
result, out, err := handler(ctx, req, in)
|
||||
if err != nil {
|
||||
if wireErr, ok := err.(*jsonrpc.Error); ok {
|
||||
return nil, wireErr
|
||||
}
|
||||
var errResult mcp.CallToolResult
|
||||
errResult.SetError(err)
|
||||
return &errResult, nil
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = &mcp.CallToolResult{}
|
||||
}
|
||||
|
||||
if outputResolved == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
outValue := normalizeTypedNil[Out](out)
|
||||
if outValue == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
outBytes, err := json.Marshal(outValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling output: %w", err)
|
||||
}
|
||||
outJSON, err := applyToolSchema(json.RawMessage(outBytes), outputResolved)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating tool output: %w", err)
|
||||
}
|
||||
|
||||
result.StructuredContent = outJSON
|
||||
if result.Content == nil {
|
||||
result.Content = []mcp.Content{&mcp.TextContent{Text: string(outJSON)}}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error) {
|
||||
func logToolCall(logger *slog.Logger, toolName string, handler mcp.ToolHandler) mcp.ToolHandler {
|
||||
if logger == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
|
||||
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
start := time.Now()
|
||||
attrs := []any{slog.String("tool", toolName)}
|
||||
if req != nil && req.Params != nil {
|
||||
@@ -44,23 +133,27 @@ func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler
|
||||
}
|
||||
|
||||
logger.Info("mcp tool started", attrs...)
|
||||
result, out, err := handler(ctx, req, in)
|
||||
result, err := handler(ctx, req)
|
||||
|
||||
completionAttrs := append([]any{}, attrs...)
|
||||
completionAttrs = append(completionAttrs, slog.String("duration", formatLogDuration(time.Since(start))))
|
||||
if err != nil {
|
||||
completionAttrs = append(completionAttrs, slog.String("error", err.Error()))
|
||||
logger.Error("mcp tool completed", completionAttrs...)
|
||||
return result, out, err
|
||||
return result, err
|
||||
}
|
||||
|
||||
logger.Info("mcp tool completed", completionAttrs...)
|
||||
return result, out, nil
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func truncateArgs(args any) string {
|
||||
b, err := json.Marshal(args)
|
||||
redacted, err := redactArguments(args)
|
||||
if err != nil {
|
||||
return "<unserializable>"
|
||||
}
|
||||
b, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return "<unserializable>"
|
||||
}
|
||||
@@ -70,6 +163,52 @@ func truncateArgs(args any) string {
|
||||
return string(b[:maxLoggedArgBytes]) + fmt.Sprintf("… (%d bytes total)", len(b))
|
||||
}
|
||||
|
||||
func redactArguments(args any) (any, error) {
|
||||
if args == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
b, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(b, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return redactJSONValue("", decoded), nil
|
||||
}
|
||||
|
||||
func redactJSONValue(key string, value any) any {
|
||||
if isSensitiveToolArgKey(key) {
|
||||
return "<redacted>"
|
||||
}
|
||||
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
redacted := make(map[string]any, len(typed))
|
||||
for childKey, childValue := range typed {
|
||||
redacted[childKey] = redactJSONValue(childKey, childValue)
|
||||
}
|
||||
return redacted
|
||||
case []any:
|
||||
redacted := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
redacted[i] = redactJSONValue("", item)
|
||||
}
|
||||
return redacted
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitiveToolArgKey(key string) bool {
|
||||
_, ok := sensitiveToolArgKeys[strings.ToLower(strings.TrimSpace(key))]
|
||||
return ok
|
||||
}
|
||||
|
||||
func formatLogDuration(d time.Duration) string {
|
||||
if d < 0 {
|
||||
d = -d
|
||||
@@ -101,3 +240,104 @@ func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveToolSchema(schema any) (*jsonschema.Resolved, error) {
|
||||
if schema == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if typed, ok := schema.(*jsonschema.Schema); ok {
|
||||
return typed.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
|
||||
}
|
||||
|
||||
var remarshalTarget *jsonschema.Schema
|
||||
if err := remarshalJSON(schema, &remarshalTarget); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remarshalTarget.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
|
||||
}
|
||||
|
||||
func applyToolSchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) {
|
||||
if resolved == nil {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
value := make(map[string]any)
|
||||
if len(data) > 0 {
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling arguments: %w", err)
|
||||
}
|
||||
}
|
||||
if err := resolved.ApplyDefaults(&value); err != nil {
|
||||
return nil, fmt.Errorf("applying schema defaults: %w", err)
|
||||
}
|
||||
if err := resolved.Validate(&value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalized, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling with defaults: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func remarshalJSON(src any, dst any) error {
|
||||
b, err := json.Marshal(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, dst)
|
||||
}
|
||||
|
||||
func normalizeTypedNil[T any](value T) any {
|
||||
rt := reflect.TypeFor[T]()
|
||||
if rt.Kind() == reflect.Pointer {
|
||||
var zero T
|
||||
if any(value) == any(zero) {
|
||||
return reflect.Zero(rt.Elem()).Interface()
|
||||
}
|
||||
}
|
||||
return any(value)
|
||||
}
|
||||
|
||||
func invalidArgumentsError(err error) error {
|
||||
detail := err.Error()
|
||||
field := validationField(detail)
|
||||
|
||||
payload := mcperrors.Data{
|
||||
Type: mcperrors.TypeInvalidArguments,
|
||||
Detail: detail,
|
||||
}
|
||||
if field != "" {
|
||||
payload.Field = field
|
||||
payload.Hint = "check the " + field + " argument"
|
||||
}
|
||||
|
||||
data, marshalErr := json.Marshal(payload)
|
||||
if marshalErr != nil {
|
||||
return &jsonrpc.Error{
|
||||
Code: jsonrpc.CodeInvalidParams,
|
||||
Message: "invalid tool arguments",
|
||||
}
|
||||
}
|
||||
|
||||
return &jsonrpc.Error{
|
||||
Code: jsonrpc.CodeInvalidParams,
|
||||
Message: "invalid tool arguments",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func validationField(detail string) string {
|
||||
if matches := missingPropertyPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
||||
return matches[1]
|
||||
}
|
||||
if matches := propertyPathPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
||||
return matches[1]
|
||||
}
|
||||
if matches := structFieldPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
||||
return strings.ToLower(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
@@ -40,3 +43,59 @@ func TestSetToolSchemasUsesStringUUIDsInListOutput(t *testing.T) {
|
||||
t.Fatalf("id schema format = %q, want %q", idSchema.Format, "uuid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateArgsRedactsSensitiveFields(t *testing.T) {
|
||||
args := json.RawMessage(`{
|
||||
"query": "todo from yesterday",
|
||||
"content": "private thought body",
|
||||
"content_base64": "c2VjcmV0LWJpbmFyeQ==",
|
||||
"email": "user@example.com",
|
||||
"nested": {
|
||||
"notes": "private note",
|
||||
"phone": "+1-555-0100",
|
||||
"keep": "visible"
|
||||
}
|
||||
}`)
|
||||
|
||||
got := truncateArgs(args)
|
||||
|
||||
for _, secret := range []string{
|
||||
"private thought body",
|
||||
"c2VjcmV0LWJpbmFyeQ==",
|
||||
"user@example.com",
|
||||
"private note",
|
||||
"+1-555-0100",
|
||||
} {
|
||||
if strings.Contains(got, secret) {
|
||||
t.Fatalf("truncateArgs leaked sensitive value %q in %s", secret, got)
|
||||
}
|
||||
}
|
||||
|
||||
for _, expected := range []string{
|
||||
`"query":"todo from yesterday"`,
|
||||
`"keep":"visible"`,
|
||||
`"content":"\u003credacted\u003e"`,
|
||||
`"content_base64":"\u003credacted\u003e"`,
|
||||
`"email":"\u003credacted\u003e"`,
|
||||
`"notes":"\u003credacted\u003e"`,
|
||||
`"phone":"\u003credacted\u003e"`,
|
||||
} {
|
||||
if !strings.Contains(got, expected) {
|
||||
t.Fatalf("truncateArgs(%s) missing %q", got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddToolReturnsSchemaErrorInsteadOfPanicking(t *testing.T) {
|
||||
server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)
|
||||
|
||||
err := addTool(server, nil, &mcp.Tool{Name: "broken"}, func(_ context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, chan int, error) {
|
||||
return nil, nil, nil
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("addTool() error = nil, want schema inference error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `configure MCP tool "broken" schemas`) {
|
||||
t.Fatalf("addTool() error = %q, want tool context", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package mcpserver
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type ToolSet struct {
|
||||
Version *tools.VersionTool
|
||||
Capture *tools.CaptureTool
|
||||
Search *tools.SearchTool
|
||||
List *tools.ListTool
|
||||
@@ -37,355 +37,480 @@ type ToolSet struct {
|
||||
Skills *tools.SkillsTool
|
||||
}
|
||||
|
||||
func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet) http.Handler {
|
||||
func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionClosed func(string)) (http.Handler, error) {
|
||||
server := mcp.NewServer(&mcp.Implementation{
|
||||
Name: cfg.ServerName,
|
||||
Version: cfg.Version,
|
||||
}, nil)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
for _, register := range []func(*mcp.Server, *slog.Logger, ToolSet) error{
|
||||
registerSystemTools,
|
||||
registerThoughtTools,
|
||||
registerProjectTools,
|
||||
registerFileTools,
|
||||
registerMaintenanceTools,
|
||||
registerHouseholdTools,
|
||||
registerCalendarTools,
|
||||
registerMealTools,
|
||||
registerCRMTools,
|
||||
registerSkillTools,
|
||||
} {
|
||||
if err := register(server, logger, toolSet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
opts := &mcp.StreamableHTTPOptions{
|
||||
JSONResponse: true,
|
||||
SessionTimeout: cfg.SessionTimeout,
|
||||
}
|
||||
if onSessionClosed != nil {
|
||||
opts.EventStore = newCleanupEventStore(mcp.NewMemoryEventStore(nil), onSessionClosed)
|
||||
}
|
||||
|
||||
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
||||
return server
|
||||
}, opts), nil
|
||||
}
|
||||
|
||||
func registerSystemTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_version_info",
|
||||
Description: "Return the server build version information, including version, tag name, commit, and build date.",
|
||||
}, toolSet.Version.GetInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerThoughtTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "capture_thought",
|
||||
Description: "Store a thought with generated embeddings and extracted metadata.",
|
||||
}, toolSet.Capture.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Capture.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "search_thoughts",
|
||||
Description: "Search stored thoughts by semantic similarity.",
|
||||
}, toolSet.Search.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Search.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_thoughts",
|
||||
Description: "List recent thoughts with optional metadata filters.",
|
||||
}, toolSet.List.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.List.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "thought_stats",
|
||||
Description: "Get counts and top metadata buckets across stored thoughts.",
|
||||
}, toolSet.Stats.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Stats.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_thought",
|
||||
Description: "Retrieve a full thought by id.",
|
||||
}, toolSet.Get.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Get.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "update_thought",
|
||||
Description: "Update thought content or merge metadata.",
|
||||
}, toolSet.Update.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Update.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "delete_thought",
|
||||
Description: "Hard-delete a thought by id.",
|
||||
}, toolSet.Delete.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Delete.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "archive_thought",
|
||||
Description: "Archive a thought so it is hidden from default search and listing.",
|
||||
}, toolSet.Archive.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "create_project",
|
||||
Description: "Create a named project container for thoughts.",
|
||||
}, toolSet.Projects.Create)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_projects",
|
||||
Description: "List projects and their current thought counts.",
|
||||
}, toolSet.Projects.List)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "set_active_project",
|
||||
Description: "Set the active project for the current MCP session.",
|
||||
}, toolSet.Projects.SetActive)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_active_project",
|
||||
Description: "Return the active project for the current MCP session.",
|
||||
}, toolSet.Projects.GetActive)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_project_context",
|
||||
Description: "Get recent and semantic context for a project.",
|
||||
}, toolSet.Context.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "recall_context",
|
||||
Description: "Recall semantically relevant and recent context.",
|
||||
}, toolSet.Recall.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Archive.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "summarize_thoughts",
|
||||
Description: "Summarize a filtered or searched set of thoughts.",
|
||||
}, toolSet.Summarize.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Summarize.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "recall_context",
|
||||
Description: "Recall semantically relevant and recent context.",
|
||||
}, toolSet.Recall.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "link_thoughts",
|
||||
Description: "Create a typed relationship between two thoughts.",
|
||||
}, toolSet.Links.Link)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Links.Link); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "related_thoughts",
|
||||
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
||||
}, toolSet.Links.Related)
|
||||
}, toolSet.Links.Related); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerProjectTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "create_project",
|
||||
Description: "Create a named project container for thoughts.",
|
||||
}, toolSet.Projects.Create); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_projects",
|
||||
Description: "List projects and their current thought counts.",
|
||||
}, toolSet.Projects.List); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "set_active_project",
|
||||
Description: "Set the active project for the current MCP session. Requires a stateful MCP client that reuses the same session across calls.",
|
||||
}, toolSet.Projects.SetActive); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_active_project",
|
||||
Description: "Return the active project for the current MCP session. If your client does not preserve MCP sessions, pass project explicitly to project-scoped tools instead.",
|
||||
}, toolSet.Projects.GetActive); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_project_context",
|
||||
Description: "Get recent and semantic context for a project. Uses the explicit project when provided, otherwise the active MCP session project.",
|
||||
}, toolSet.Context.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerFileTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
server.AddResourceTemplate(&mcp.ResourceTemplate{
|
||||
Name: "stored_file",
|
||||
URITemplate: "amcs://files/{id}",
|
||||
Description: "A stored file. Read a file's raw binary content by its id. Use load_file for metadata.",
|
||||
}, toolSet.Files.ReadResource)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "upload_file",
|
||||
Description: "Stage a file and get an amcs://files/{id} resource URI. Provide content_path (absolute server-side path, no size limit) or content_base64 (≤10 MB). Optionally link immediately with thought_id/project, or omit them and pass the returned URI to save_file later.",
|
||||
}, toolSet.Files.Upload)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Files.Upload); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "save_file",
|
||||
Description: "Store a file and optionally link it to a thought. Supply either content_base64 (≤10 MB) or content_uri (amcs://files/{id} from a prior upload_file or POST /files call). For files larger than 10 MB, use upload_file with content_path first.",
|
||||
}, toolSet.Files.Save)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Files.Save); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "load_file",
|
||||
Description: "Load a previously stored file by id and return its metadata and base64 content.",
|
||||
}, toolSet.Files.Load)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Files.Load); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_files",
|
||||
Description: "List stored files, optionally filtered by thought, project, or kind.",
|
||||
}, toolSet.Files.List)
|
||||
}, toolSet.Files.List); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
func registerMaintenanceTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "backfill_embeddings",
|
||||
Description: "Generate missing embeddings for stored thoughts using the active embedding model.",
|
||||
}, toolSet.Backfill.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Backfill.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "reparse_thought_metadata",
|
||||
Description: "Re-extract and normalize metadata for stored thoughts from their content.",
|
||||
}, toolSet.Reparse.Handle)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Reparse.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "retry_failed_metadata",
|
||||
Description: "Retry metadata extraction for thoughts still marked pending or failed.",
|
||||
}, toolSet.RetryMetadata.Handle)
|
||||
|
||||
// Household Knowledge
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_household_item",
|
||||
Description: "Store a household fact (paint color, appliance details, measurement, document, etc.).",
|
||||
}, toolSet.Household.AddItem)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "search_household_items",
|
||||
Description: "Search household items by name, category, or location.",
|
||||
}, toolSet.Household.SearchItems)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_household_item",
|
||||
Description: "Retrieve a household item by id.",
|
||||
}, toolSet.Household.GetItem)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_vendor",
|
||||
Description: "Add a service provider (plumber, electrician, landscaper, etc.).",
|
||||
}, toolSet.Household.AddVendor)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_vendors",
|
||||
Description: "List household service vendors, optionally filtered by service type.",
|
||||
}, toolSet.Household.ListVendors)
|
||||
|
||||
// Home Maintenance
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.RetryMetadata.Handle); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_maintenance_task",
|
||||
Description: "Create a recurring or one-time home maintenance task.",
|
||||
}, toolSet.Maintenance.AddTask)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Maintenance.AddTask); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "log_maintenance",
|
||||
Description: "Log completed maintenance work; automatically updates the task's next due date.",
|
||||
}, toolSet.Maintenance.LogWork)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Maintenance.LogWork); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_upcoming_maintenance",
|
||||
Description: "List maintenance tasks due within the next N days.",
|
||||
}, toolSet.Maintenance.GetUpcoming)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Maintenance.GetUpcoming); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "search_maintenance_history",
|
||||
Description: "Search the maintenance log by task name, category, or date range.",
|
||||
}, toolSet.Maintenance.SearchHistory)
|
||||
}, toolSet.Maintenance.SearchHistory); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Family Calendar
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
func registerHouseholdTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_household_item",
|
||||
Description: "Store a household fact (paint color, appliance details, measurement, document, etc.).",
|
||||
}, toolSet.Household.AddItem); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "search_household_items",
|
||||
Description: "Search household items by name, category, or location.",
|
||||
}, toolSet.Household.SearchItems); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_household_item",
|
||||
Description: "Retrieve a household item by id.",
|
||||
}, toolSet.Household.GetItem); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_vendor",
|
||||
Description: "Add a service provider (plumber, electrician, landscaper, etc.).",
|
||||
}, toolSet.Household.AddVendor); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_vendors",
|
||||
Description: "List household service vendors, optionally filtered by service type.",
|
||||
}, toolSet.Household.ListVendors); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func registerCalendarTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_family_member",
|
||||
Description: "Add a family member to the household.",
|
||||
}, toolSet.Calendar.AddMember)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Calendar.AddMember); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_family_members",
|
||||
Description: "List all family members.",
|
||||
}, toolSet.Calendar.ListMembers)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Calendar.ListMembers); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_activity",
|
||||
Description: "Schedule a one-time or recurring family activity.",
|
||||
}, toolSet.Calendar.AddActivity)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Calendar.AddActivity); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_week_schedule",
|
||||
Description: "Get all activities scheduled for a given week.",
|
||||
}, toolSet.Calendar.GetWeekSchedule)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Calendar.GetWeekSchedule); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "search_activities",
|
||||
Description: "Search activities by title, type, or family member.",
|
||||
}, toolSet.Calendar.SearchActivities)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Calendar.SearchActivities); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_important_date",
|
||||
Description: "Track a birthday, anniversary, deadline, or other important date.",
|
||||
}, toolSet.Calendar.AddImportantDate)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Calendar.AddImportantDate); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_upcoming_dates",
|
||||
Description: "Get important dates coming up in the next N days.",
|
||||
}, toolSet.Calendar.GetUpcomingDates)
|
||||
}, toolSet.Calendar.GetUpcomingDates); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Meal Planning
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
func registerMealTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_recipe",
|
||||
Description: "Save a recipe with ingredients and instructions.",
|
||||
}, toolSet.Meals.AddRecipe)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Meals.AddRecipe); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "search_recipes",
|
||||
Description: "Search recipes by name, cuisine, tags, or ingredient.",
|
||||
}, toolSet.Meals.SearchRecipes)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Meals.SearchRecipes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "update_recipe",
|
||||
Description: "Update an existing recipe.",
|
||||
}, toolSet.Meals.UpdateRecipe)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Meals.UpdateRecipe); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "create_meal_plan",
|
||||
Description: "Set the meal plan for a week; replaces any existing plan for that week.",
|
||||
}, toolSet.Meals.CreateMealPlan)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Meals.CreateMealPlan); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_meal_plan",
|
||||
Description: "Get the meal plan for a given week.",
|
||||
}, toolSet.Meals.GetMealPlan)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Meals.GetMealPlan); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "generate_shopping_list",
|
||||
Description: "Auto-generate a shopping list from the meal plan for a given week.",
|
||||
}, toolSet.Meals.GenerateShoppingList)
|
||||
}, toolSet.Meals.GenerateShoppingList); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Professional CRM
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
func registerCRMTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_professional_contact",
|
||||
Description: "Add a professional contact to the CRM.",
|
||||
}, toolSet.CRM.AddContact)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.CRM.AddContact); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "search_contacts",
|
||||
Description: "Search professional contacts by name, company, title, notes, or tags.",
|
||||
}, toolSet.CRM.SearchContacts)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.CRM.SearchContacts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "log_interaction",
|
||||
Description: "Log an interaction with a professional contact.",
|
||||
}, toolSet.CRM.LogInteraction)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.CRM.LogInteraction); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_contact_history",
|
||||
Description: "Get full history (interactions and opportunities) for a contact.",
|
||||
}, toolSet.CRM.GetHistory)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.CRM.GetHistory); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "create_opportunity",
|
||||
Description: "Create a deal, project, or opportunity linked to a contact.",
|
||||
}, toolSet.CRM.CreateOpportunity)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.CRM.CreateOpportunity); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "get_follow_ups_due",
|
||||
Description: "List contacts with a follow-up date due within the next N days.",
|
||||
}, toolSet.CRM.GetFollowUpsDue)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.CRM.GetFollowUpsDue); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "link_thought_to_contact",
|
||||
Description: "Append a stored thought to a contact's notes.",
|
||||
}, toolSet.CRM.LinkThought)
|
||||
}, toolSet.CRM.LinkThought); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Agent Skills
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
func registerSkillTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_skill",
|
||||
Description: "Store a reusable agent skill (behavioural instruction or capability prompt).",
|
||||
}, toolSet.Skills.AddSkill)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Skills.AddSkill); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "remove_skill",
|
||||
Description: "Delete an agent skill by id.",
|
||||
}, toolSet.Skills.RemoveSkill)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Skills.RemoveSkill); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_skills",
|
||||
Description: "List all agent skills, optionally filtered by tag.",
|
||||
}, toolSet.Skills.ListSkills)
|
||||
|
||||
// Agent Guardrails
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Skills.ListSkills); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_guardrail",
|
||||
Description: "Store a reusable agent guardrail (constraint or safety rule).",
|
||||
}, toolSet.Skills.AddGuardrail)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Skills.AddGuardrail); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "remove_guardrail",
|
||||
Description: "Delete an agent guardrail by id.",
|
||||
}, toolSet.Skills.RemoveGuardrail)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Skills.RemoveGuardrail); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_guardrails",
|
||||
Description: "List all agent guardrails, optionally filtered by tag or severity.",
|
||||
}, toolSet.Skills.ListGuardrails)
|
||||
|
||||
// Project Skills & Guardrails
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
}, toolSet.Skills.ListGuardrails); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_project_skill",
|
||||
Description: "Link an agent skill to a project.",
|
||||
}, toolSet.Skills.AddProjectSkill)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Description: "Link an agent skill to a project. Pass project explicitly when your client does not preserve MCP sessions.",
|
||||
}, toolSet.Skills.AddProjectSkill); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "remove_project_skill",
|
||||
Description: "Unlink an agent skill from a project.",
|
||||
}, toolSet.Skills.RemoveProjectSkill)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Description: "Unlink an agent skill from a project. Pass project explicitly when your client does not preserve MCP sessions.",
|
||||
}, toolSet.Skills.RemoveProjectSkill); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_project_skills",
|
||||
Description: "List all skills linked to a project. Call this at the start of a project session to load existing agent behaviour instructions before generating new ones.",
|
||||
}, toolSet.Skills.ListProjectSkills)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Description: "List all skills linked to a project. Call this at the start of a project session to load existing agent behaviour instructions before generating new ones. Pass project explicitly when your client does not preserve MCP sessions.",
|
||||
}, toolSet.Skills.ListProjectSkills); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "add_project_guardrail",
|
||||
Description: "Link an agent guardrail to a project.",
|
||||
}, toolSet.Skills.AddProjectGuardrail)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Description: "Link an agent guardrail to a project. Pass project explicitly when your client does not preserve MCP sessions.",
|
||||
}, toolSet.Skills.AddProjectGuardrail); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "remove_project_guardrail",
|
||||
Description: "Unlink an agent guardrail from a project.",
|
||||
}, toolSet.Skills.RemoveProjectGuardrail)
|
||||
|
||||
addTool(server, logger, &mcp.Tool{
|
||||
Description: "Unlink an agent guardrail from a project. Pass project explicitly when your client does not preserve MCP sessions.",
|
||||
}, toolSet.Skills.RemoveProjectGuardrail); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addTool(server, logger, &mcp.Tool{
|
||||
Name: "list_project_guardrails",
|
||||
Description: "List all guardrails linked to a project. Call this at the start of a project session to load existing agent constraints before generating new ones.",
|
||||
}, toolSet.Skills.ListProjectGuardrails)
|
||||
|
||||
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
||||
return server
|
||||
}, &mcp.StreamableHTTPOptions{
|
||||
JSONResponse: true,
|
||||
SessionTimeout: 10 * time.Minute,
|
||||
})
|
||||
Description: "List all guardrails linked to a project. Call this at the start of a project session to load existing agent constraints before generating new ones. Pass project explicitly when your client does not preserve MCP sessions.",
|
||||
}, toolSet.Skills.ListProjectGuardrails); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
137
internal/mcpserver/streamable_integration_test.go
Normal file
137
internal/mcpserver/streamable_integration_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/buildinfo"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
"git.warky.dev/wdevs/amcs/internal/tools"
|
||||
)
|
||||
|
||||
func TestStreamableHTTPReturnsStructuredToolErrors(t *testing.T) {
|
||||
handler, err := New(config.MCPConfig{
|
||||
ServerName: "test",
|
||||
Version: "0.0.1",
|
||||
SessionTimeout: time.Minute,
|
||||
}, nil, streamableTestToolSet(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("mcpserver.New() error = %v", err)
|
||||
}
|
||||
|
||||
httpServer := httptest.NewServer(handler)
|
||||
defer httpServer.Close()
|
||||
|
||||
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
|
||||
cs, err := client.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: httpServer.URL}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("connect client: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = cs.Close()
|
||||
}()
|
||||
|
||||
t.Run("schema_validation", func(t *testing.T) {
|
||||
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
|
||||
Name: "create_project",
|
||||
Arguments: map[string]any{},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("CallTool(create_project) error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireWireError(t, err)
|
||||
if rpcErr.Code != jsonrpc.CodeInvalidParams {
|
||||
t.Fatalf("create_project code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
|
||||
}
|
||||
if data.Type != mcperrors.TypeInvalidArguments {
|
||||
t.Fatalf("create_project data.type = %q, want %q", data.Type, mcperrors.TypeInvalidArguments)
|
||||
}
|
||||
if data.Field != "name" {
|
||||
t.Fatalf("create_project data.field = %q, want %q", data.Field, "name")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("project_required", func(t *testing.T) {
|
||||
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
|
||||
Name: "get_project_context",
|
||||
Arguments: map[string]any{},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("CallTool(get_project_context) error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireWireError(t, err)
|
||||
if rpcErr.Code != mcperrors.CodeProjectRequired {
|
||||
t.Fatalf("get_project_context code = %d, want %d", rpcErr.Code, mcperrors.CodeProjectRequired)
|
||||
}
|
||||
if data.Type != mcperrors.TypeProjectRequired {
|
||||
t.Fatalf("get_project_context data.type = %q, want %q", data.Type, mcperrors.TypeProjectRequired)
|
||||
}
|
||||
if data.Field != "project" {
|
||||
t.Fatalf("get_project_context data.field = %q, want %q", data.Field, "project")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("version_info", func(t *testing.T) {
|
||||
result, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
|
||||
Name: "get_version_info",
|
||||
Arguments: map[string]any{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool(get_version_info) error = %v", err)
|
||||
}
|
||||
|
||||
got, ok := result.StructuredContent.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("structured content type = %T, want map[string]any", result.StructuredContent)
|
||||
}
|
||||
if got["server_name"] != "test" {
|
||||
t.Fatalf("server_name = %#v, want %q", got["server_name"], "test")
|
||||
}
|
||||
if got["version"] != "0.0.1" {
|
||||
t.Fatalf("version = %#v, want %q", got["version"], "0.0.1")
|
||||
}
|
||||
if got["tag_name"] != "v0.0.1" {
|
||||
t.Fatalf("tag_name = %#v, want %q", got["tag_name"], "v0.0.1")
|
||||
}
|
||||
if got["build_date"] != "2026-03-31T00:00:00Z" {
|
||||
t.Fatalf("build_date = %#v, want %q", got["build_date"], "2026-03-31T00:00:00Z")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func streamableTestToolSet() ToolSet {
|
||||
return ToolSet{
|
||||
Version: tools.NewVersionTool("test", buildinfo.Info{Version: "0.0.1", TagName: "v0.0.1", Commit: "test", BuildDate: "2026-03-31T00:00:00Z"}),
|
||||
Capture: new(tools.CaptureTool),
|
||||
Search: new(tools.SearchTool),
|
||||
List: new(tools.ListTool),
|
||||
Stats: new(tools.StatsTool),
|
||||
Get: new(tools.GetTool),
|
||||
Update: new(tools.UpdateTool),
|
||||
Delete: new(tools.DeleteTool),
|
||||
Archive: new(tools.ArchiveTool),
|
||||
Projects: new(tools.ProjectsTool),
|
||||
Context: new(tools.ContextTool),
|
||||
Recall: new(tools.RecallTool),
|
||||
Summarize: new(tools.SummarizeTool),
|
||||
Links: new(tools.LinksTool),
|
||||
Files: new(tools.FilesTool),
|
||||
Backfill: new(tools.BackfillTool),
|
||||
Reparse: new(tools.ReparseMetadataTool),
|
||||
RetryMetadata: new(tools.RetryMetadataTool),
|
||||
Household: new(tools.HouseholdTool),
|
||||
Maintenance: new(tools.MaintenanceTool),
|
||||
Calendar: new(tools.CalendarTool),
|
||||
Meals: new(tools.MealsTool),
|
||||
CRM: new(tools.CRMTool),
|
||||
Skills: new(tools.SkillsTool),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user