test(tools): add unit tests for error handling functions

* Implement tests for error functions like errRequiredField, errInvalidField, and errEntityNotFound.
* Ensure proper metadata is returned for various error scenarios.
* Validate error handling in CRM, Files, and other tools.
* Introduce tests for parsing stored file IDs and UUIDs.
* Enhance coverage for helper functions related to project resolution and session management.
This commit is contained in:
Hein
2026-03-31 15:10:07 +02:00
parent acd780ac9c
commit f41c512f36
54 changed files with 1937 additions and 365 deletions

View File

@@ -1,4 +1,3 @@
.git
.gitignore .gitignore
.vscode .vscode
bin bin

View File

@@ -7,7 +7,17 @@ RUN go mod download
COPY . . COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /out/amcs-server ./cmd/amcs-server RUN set -eu; \
VERSION_TAG="$(git describe --tags --exact-match 2>/dev/null || echo dev)"; \
COMMIT_SHA="$(git rev-parse --short HEAD 2>/dev/null || echo unknown)"; \
BUILD_DATE="$(date -u +%Y-%m-%dT%H:%M:%SZ)"; \
CGO_ENABLED=0 GOOS=linux go build -trimpath \
-ldflags="-s -w \
-X git.warky.dev/wdevs/amcs/internal/buildinfo.Version=${VERSION_TAG} \
-X git.warky.dev/wdevs/amcs/internal/buildinfo.TagName=${VERSION_TAG} \
-X git.warky.dev/wdevs/amcs/internal/buildinfo.Commit=${COMMIT_SHA} \
-X git.warky.dev/wdevs/amcs/internal/buildinfo.BuildDate=${BUILD_DATE}" \
-o /out/amcs-server ./cmd/amcs-server
FROM debian:bookworm-slim FROM debian:bookworm-slim

View File

@@ -1,6 +1,15 @@
BIN_DIR := bin BIN_DIR := bin
SERVER_BIN := $(BIN_DIR)/amcs-server SERVER_BIN := $(BIN_DIR)/amcs-server
CMD_SERVER := ./cmd/amcs-server CMD_SERVER := ./cmd/amcs-server
BUILDINFO_PKG := git.warky.dev/wdevs/amcs/internal/buildinfo
VERSION_TAG ?= $(shell git describe --tags --exact-match 2>/dev/null || echo dev)
COMMIT_SHA ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
BUILD_DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
LDFLAGS := -s -w \
-X $(BUILDINFO_PKG).Version=$(VERSION_TAG) \
-X $(BUILDINFO_PKG).TagName=$(VERSION_TAG) \
-X $(BUILDINFO_PKG).Commit=$(COMMIT_SHA) \
-X $(BUILDINFO_PKG).BuildDate=$(BUILD_DATE)
.PHONY: all build clean migrate .PHONY: all build clean migrate
@@ -8,7 +17,7 @@ all: build
build: build:
@mkdir -p $(BIN_DIR) @mkdir -p $(BIN_DIR)
go build -o $(SERVER_BIN) $(CMD_SERVER) go build -ldflags "$(LDFLAGS)" -o $(SERVER_BIN) $(CMD_SERVER)
migrate: migrate:
./scripts/migrate.sh ./scripts/migrate.sh

132
README.md
View File

@@ -34,8 +34,8 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte
| `archive_thought` | Soft delete | | `archive_thought` | Soft delete |
| `create_project` | Register a named project | | `create_project` | Register a named project |
| `list_projects` | List projects with thought counts | | `list_projects` | List projects with thought counts |
| `get_project_context` | Recent + semantic context for a project | | `get_project_context` | Recent + semantic context for a project; uses explicit `project` or the active session project |
| `set_active_project` | Set session project scope | | `set_active_project` | Set session project scope; requires a stateful MCP session |
| `get_active_project` | Get current session project | | `get_active_project` | Get current session project |
| `summarize_thoughts` | LLM prose summary over a filtered set | | `summarize_thoughts` | LLM prose summary over a filtered set |
| `recall_context` | Semantic + recency context block for injection | | `recall_context` | Semantic + recency context block for injection |
@@ -54,18 +54,129 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte
| `add_guardrail` | Store a reusable agent guardrail (constraint or safety rule) | | `add_guardrail` | Store a reusable agent guardrail (constraint or safety rule) |
| `remove_guardrail` | Delete an agent guardrail by id | | `remove_guardrail` | Delete an agent guardrail by id |
| `list_guardrails` | List all agent guardrails, optionally filtered by tag or severity | | `list_guardrails` | List all agent guardrails, optionally filtered by tag or severity |
| `add_project_skill` | Link an agent skill to a project | | `add_project_skill` | Link an agent skill to a project; pass `project` explicitly if your client does not preserve MCP sessions |
| `remove_project_skill` | Unlink an agent skill from a project | | `remove_project_skill` | Unlink an agent skill from a project; pass `project` explicitly if your client does not preserve MCP sessions |
| `list_project_skills` | List all skills linked to a project | | `list_project_skills` | List all skills linked to a project; pass `project` explicitly if your client does not preserve MCP sessions |
| `add_project_guardrail` | Link an agent guardrail to a project | | `add_project_guardrail` | Link an agent guardrail to a project; pass `project` explicitly if your client does not preserve MCP sessions |
| `remove_project_guardrail` | Unlink an agent guardrail from a project | | `remove_project_guardrail` | Unlink an agent guardrail from a project; pass `project` explicitly if your client does not preserve MCP sessions |
| `list_project_guardrails` | List all guardrails linked to a project | | `list_project_guardrails` | List all guardrails linked to a project; pass `project` explicitly if your client does not preserve MCP sessions |
| `get_version_info` | Return the server build version information, including version, tag name, commit, and build date |
## MCP Error Contract
AMCS returns structured JSON-RPC errors for common MCP failures. Clients should branch on both `error.code` and `error.data.type` instead of parsing the human-readable message.
### Stable error codes
| Code | `data.type` | Meaning |
|---|---|---|
| `-32602` | `invalid_arguments` | MCP argument/schema validation failed before the tool handler ran |
| `-32602` | `invalid_input` | Tool-level input validation failed inside the handler |
| `-32050` | `session_required` | Tool requires a stateful MCP session |
| `-32051` | `project_required` | No explicit `project` was provided and no active session project was available |
| `-32052` | `project_not_found` | The referenced project does not exist |
| `-32053` | `invalid_id` | A UUID-like identifier was malformed |
| `-32054` | `entity_not_found` | A referenced entity such as a thought or contact does not exist |
### Error data shape
AMCS may include these fields in `error.data`:
- `type` — stable machine-readable error type
- `field` — single argument name such as `name`, `project`, or `thought_id`
- `fields` — multiple argument names for one-of or mutually-exclusive validation
- `value` — offending value when safe to expose
- `detail` — validation detail such as `required`, `invalid`, `one_of_required`, `mutually_exclusive`, or a schema validation message
- `hint` — remediation guidance
- `entity` — entity name for generic not-found errors
Example schema-level error:
```json
{
"code": -32602,
"message": "invalid tool arguments",
"data": {
"type": "invalid_arguments",
"field": "name",
"detail": "validating root: required: missing properties: [\"name\"]",
"hint": "check the name argument"
}
}
```
Example tool-level error:
```json
{
"code": -32051,
"message": "project is required; pass project explicitly or call set_active_project in this MCP session first",
"data": {
"type": "project_required",
"field": "project",
"hint": "pass project explicitly or call set_active_project in this MCP session first"
}
}
```
### Client example
Go client example handling AMCS MCP errors:
```go
result, err := session.CallTool(ctx, &mcp.CallToolParams{
Name: "get_project_context",
Arguments: map[string]any{},
})
if err != nil {
var rpcErr *jsonrpc.Error
if errors.As(err, &rpcErr) {
var data struct {
Type string `json:"type"`
Field string `json:"field"`
Hint string `json:"hint"`
}
_ = json.Unmarshal(rpcErr.Data, &data)
switch {
case rpcErr.Code == -32051 && data.Type == "project_required":
// Retry with an explicit project, or call set_active_project first.
case rpcErr.Code == -32602 && data.Type == "invalid_arguments":
// Ask the caller to fix the malformed arguments.
}
}
}
_ = result
```
## Build Versioning
AMCS embeds build metadata into the binary at build time.
- `version` is generated from the current git tag when building from a tagged commit
- `tag_name` is the repo tag name, for example `v1.0.1`
- `build_date` is the UTC build timestamp in RFC3339 format
- `commit` is the short git commit SHA
For untagged builds, `version` and `tag_name` fall back to `dev`.
Use `get_version_info` to retrieve the runtime build metadata:
```json
{
"server_name": "amcs",
"version": "v1.0.1",
"tag_name": "v1.0.1",
"commit": "abc1234",
"build_date": "2026-03-31T14:22:10Z"
}
```
## Agent Skills and Guardrails ## Agent Skills and Guardrails
Skills and guardrails are reusable agent behaviour instructions and constraints that can be attached to projects. Skills and guardrails are reusable agent behaviour instructions and constraints that can be attached to projects.
**At the start of every project session, always call `list_project_skills` and `list_project_guardrails` first.** Use the returned skills and guardrails to guide agent behaviour for that project. Only generate or create new skills/guardrails if none are returned. **At the start of every project session, always call `list_project_skills` and `list_project_guardrails` first.** Use the returned skills and guardrails to guide agent behaviour for that project. Only generate or create new skills/guardrails if none are returned. If your MCP client does not preserve sessions across calls, pass `project` explicitly instead of relying on `set_active_project`.
### Skills ### Skills
@@ -102,6 +213,7 @@ Config is YAML-driven. Copy `configs/config.example.yaml` and set:
- `auth.mode``api_keys` or `oauth_client_credentials` - `auth.mode``api_keys` or `oauth_client_credentials`
- `auth.keys` — API keys for MCP access via `x-brain-key` or `Authorization: Bearer <key>` when `auth.mode=api_keys` - `auth.keys` — API keys for MCP access via `x-brain-key` or `Authorization: Bearer <key>` when `auth.mode=api_keys`
- `auth.oauth.clients` — client registry when `auth.mode=oauth_client_credentials` - `auth.oauth.clients` — client registry when `auth.mode=oauth_client_credentials`
- `mcp.version` is build-generated and should not be set in config
**OAuth Client Credentials flow** (`auth.mode=oauth_client_credentials`): **OAuth Client Credentials flow** (`auth.mode=oauth_client_credentials`):
@@ -230,7 +342,7 @@ Returns `{"file": {...}, "uri": "amcs://files/<id>"}`. Pass `thought_id`/`projec
`content_base64` and `content_uri` are mutually exclusive in both tools. `content_base64` and `content_uri` are mutually exclusive in both tools.
**Load a file** — returns metadata, base64 content, and an embedded MCP binary resource (`amcs://files/{id}`): **Load a file** — returns metadata, base64 content, and an embedded MCP binary resource (`amcs://files/{id}`). The `id` field accepts either the bare stored file UUID or the full `amcs://files/{id}` URI:
```json ```json
{ "id": "stored-file-uuid" } { "id": "stored-file-uuid" }

View File

@@ -10,8 +10,8 @@ server:
mcp: mcp:
path: "/mcp" path: "/mcp"
server_name: "amcs" server_name: "amcs"
version: "0.1.0"
transport: "streamable_http" transport: "streamable_http"
session_timeout: "10m"
auth: auth:
header_name: "x-brain-key" header_name: "x-brain-key"

View File

@@ -10,8 +10,8 @@ server:
mcp: mcp:
path: "/mcp" path: "/mcp"
server_name: "amcs" server_name: "amcs"
version: "0.1.0"
transport: "streamable_http" transport: "streamable_http"
session_timeout: "10m"
auth: auth:
header_name: "x-brain-key" header_name: "x-brain-key"

View File

@@ -10,8 +10,8 @@ server:
mcp: mcp:
path: "/mcp" path: "/mcp"
server_name: "amcs" server_name: "amcs"
version: "0.1.0"
transport: "streamable_http" transport: "streamable_http"
session_timeout: "10m"
auth: auth:
header_name: "x-brain-key" header_name: "x-brain-key"

2
go.mod
View File

@@ -3,6 +3,7 @@ module git.warky.dev/wdevs/amcs
go 1.26.1 go 1.26.1
require ( require (
github.com/google/jsonschema-go v0.4.2
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.9.1 github.com/jackc/pgx/v5 v5.9.1
github.com/modelcontextprotocol/go-sdk v1.4.1 github.com/modelcontextprotocol/go-sdk v1.4.1
@@ -12,7 +13,6 @@ require (
) )
require ( require (
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect

View File

@@ -493,7 +493,7 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest
} }
payload, readErr := io.ReadAll(resp.Body) payload, readErr := io.ReadAll(resp.Body)
resp.Body.Close() _ = resp.Body.Close()
if readErr != nil { if readErr != nil {
lastErr = fmt.Errorf("%s read response: %w", c.name, readErr) lastErr = fmt.Errorf("%s read response: %w", c.name, readErr)
if attempt < maxAttempts { if attempt < maxAttempts {
@@ -587,7 +587,9 @@ func (c *Client) doChatCompletions(ctx context.Context, reqBody chatCompletionsR
} }
func (c *Client) decodeChatCompletionsResponse(resp *http.Response) (chatCompletionsResponse, error) { func (c *Client) decodeChatCompletionsResponse(resp *http.Response) (chatCompletionsResponse, error) {
defer resp.Body.Close() defer func() {
_ = resp.Body.Close()
}()
contentType := strings.ToLower(resp.Header.Get("Content-Type")) contentType := strings.ToLower(resp.Header.Get("Content-Type"))
if strings.Contains(contentType, "text/event-stream") { if strings.Contains(contentType, "text/event-stream") {

View File

@@ -15,7 +15,9 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) {
t.Parallel() t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close() defer func() {
_ = r.Body.Close()
}()
var req chatCompletionsRequest var req chatCompletionsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@@ -68,7 +70,9 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) {
plainCalls := 0 plainCalls := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close() defer func() {
_ = r.Body.Close()
}()
var req chatCompletionsRequest var req chatCompletionsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@@ -131,7 +135,9 @@ func TestExtractMetadataBypassesInvalidFallbackModelAfterFirstFailure(t *testing
invalidFallbackCalls := 0 invalidFallbackCalls := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close() defer func() {
_ = r.Body.Close()
}()
var req chatCompletionsRequest var req chatCompletionsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {

View File

@@ -10,6 +10,7 @@ import (
"git.warky.dev/wdevs/amcs/internal/ai" "git.warky.dev/wdevs/amcs/internal/ai"
"git.warky.dev/wdevs/amcs/internal/auth" "git.warky.dev/wdevs/amcs/internal/auth"
"git.warky.dev/wdevs/amcs/internal/buildinfo"
"git.warky.dev/wdevs/amcs/internal/config" "git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/mcpserver" "git.warky.dev/wdevs/amcs/internal/mcpserver"
"git.warky.dev/wdevs/amcs/internal/observability" "git.warky.dev/wdevs/amcs/internal/observability"
@@ -23,6 +24,8 @@ func Run(ctx context.Context, configPath string) error {
if err != nil { if err != nil {
return err return err
} }
info := buildinfo.Current()
cfg.MCP.Version = info.Version
logger, err := observability.NewLogger(cfg.Logging) logger, err := observability.NewLogger(cfg.Logging)
if err != nil { if err != nil {
@@ -32,6 +35,10 @@ func Run(ctx context.Context, configPath string) error {
logger.Info("loaded configuration", logger.Info("loaded configuration",
slog.String("path", loadedFrom), slog.String("path", loadedFrom),
slog.String("provider", cfg.AI.Provider), slog.String("provider", cfg.AI.Provider),
slog.String("version", info.Version),
slog.String("tag_name", info.TagName),
slog.String("build_date", info.BuildDate),
slog.String("commit", info.Commit),
) )
db, err := store.New(ctx, cfg.Database) db, err := store.New(ctx, cfg.Database)
@@ -112,9 +119,14 @@ func Run(ctx context.Context, configPath string) error {
}() }()
} }
handler, err := routes(logger, cfg, info, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects)
if err != nil {
return err
}
server := &http.Server{ server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects), Handler: handler,
ReadTimeout: cfg.Server.ReadTimeout, ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout, WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout, IdleTimeout: cfg.Server.IdleTimeout,
@@ -144,7 +156,7 @@ func Run(ctx context.Context, configPath string) error {
} }
} }
func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) http.Handler { func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) (http.Handler, error) {
mux := http.NewServeMux() mux := http.NewServeMux()
authMiddleware := auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, logger) authMiddleware := auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, logger)
filesTool := tools.NewFilesTool(db, activeProjects) filesTool := tools.NewFilesTool(db, activeProjects)
@@ -160,6 +172,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
Delete: tools.NewDeleteTool(db), Delete: tools.NewDeleteTool(db),
Archive: tools.NewArchiveTool(db), Archive: tools.NewArchiveTool(db),
Projects: tools.NewProjectsTool(db, activeProjects), Projects: tools.NewProjectsTool(db, activeProjects),
Version: tools.NewVersionTool(cfg.MCP.ServerName, info),
Context: tools.NewContextTool(db, provider, cfg.Search, activeProjects), Context: tools.NewContextTool(db, provider, cfg.Search, activeProjects),
Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects), Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects), Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
@@ -176,7 +189,10 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
Skills: tools.NewSkillsTool(db, activeProjects), Skills: tools.NewSkillsTool(db, activeProjects),
} }
mcpHandler := mcpserver.New(cfg.MCP, logger, toolSet) mcpHandler, err := mcpserver.New(cfg.MCP, logger, toolSet, activeProjects.Clear)
if err != nil {
return nil, fmt.Errorf("build mcp handler: %w", err)
}
mux.Handle(cfg.MCP.Path, authMiddleware(mcpHandler)) mux.Handle(cfg.MCP.Path, authMiddleware(mcpHandler))
mux.Handle("/files", authMiddleware(fileHandler(filesTool))) mux.Handle("/files", authMiddleware(fileHandler(filesTool)))
mux.Handle("/files/{id}", authMiddleware(fileHandler(filesTool))) mux.Handle("/files/{id}", authMiddleware(fileHandler(filesTool)))
@@ -268,7 +284,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
observability.Recover(logger), observability.Recover(logger),
observability.AccessLog(logger), observability.AccessLog(logger),
observability.Timeout(cfg.Server.WriteTimeout), observability.Timeout(cfg.Server.WriteTimeout),
) ), nil
} }
func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) { func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) {

View File

@@ -93,7 +93,9 @@ func parseMultipartUpload(r *http.Request) (tools.SaveFileDecodedInput, error) {
if err != nil { if err != nil {
return tools.SaveFileDecodedInput{}, errors.New("multipart upload requires a file field named \"file\"") return tools.SaveFileDecodedInput{}, errors.New("multipart upload requires a file field named \"file\"")
} }
defer file.Close() defer func() {
_ = file.Close()
}()
content, err := io.ReadAll(file) content, err := io.ReadAll(file)
if err != nil { if err != nil {

View File

@@ -15,7 +15,9 @@ func TestServeLLMInstructions(t *testing.T) {
serveLLMInstructions(rec, req) serveLLMInstructions(rec, req)
res := rec.Result() res := rec.Result()
defer res.Body.Close() defer func() {
_ = res.Body.Close()
}()
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", res.StatusCode, http.StatusOK) t.Fatalf("status = %d, want %d", res.StatusCode, http.StatusOK)

View File

@@ -320,7 +320,7 @@ func issueToken(w http.ResponseWriter, keyID string, tokenStore *auth.TokenStore
func serveAuthorizePage(w http.ResponseWriter, clientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope string) { func serveAuthorizePage(w http.ResponseWriter, clientName, clientID, redirectURI, state, codeChallenge, codeChallengeMethod, scope string) {
e := html.EscapeString e := html.EscapeString
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprintf(w, `<!DOCTYPE html> _, _ = fmt.Fprintf(w, `<!DOCTYPE html>
<html> <html>
<head> <head>
<meta charset=utf-8> <meta charset=utf-8>

View File

@@ -89,21 +89,6 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
} }
} }
func extractToken(r *http.Request, headerName string) string {
token := strings.TrimSpace(r.Header.Get(headerName))
if token != "" {
return token
}
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
scheme, credentials, ok := strings.Cut(authHeader, " ")
if !ok || !strings.EqualFold(scheme, "Bearer") {
return ""
}
return strings.TrimSpace(credentials)
}
func extractBearer(r *http.Request) string { func extractBearer(r *http.Request) string {
authHeader := strings.TrimSpace(r.Header.Get("Authorization")) authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
scheme, credentials, ok := strings.Cut(authHeader, " ") scheme, credentials, ok := strings.Cut(authHeader, " ")

View File

@@ -60,7 +60,6 @@ func TestMiddlewareAllowsOAuthBasicAuthAndSetsContext(t *testing.T) {
} }
} }
func TestMiddlewareRejectsOAuthMissingOrInvalidCredentials(t *testing.T) { func TestMiddlewareRejectsOAuthMissingOrInvalidCredentials(t *testing.T) {
oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{ oauthRegistry, err := NewOAuthRegistry([]config.OAuthClient{{
ID: "oauth-client", ID: "oauth-client",

View File

@@ -0,0 +1,41 @@
package buildinfo
import "strings"
var (
Version = "dev"
TagName = "dev"
Commit = "unknown"
BuildDate = "unknown"
)
type Info struct {
Version string `json:"version"`
TagName string `json:"tag_name"`
Commit string `json:"commit"`
BuildDate string `json:"build_date"`
}
func Current() Info {
info := Info{
Version: strings.TrimSpace(Version),
TagName: strings.TrimSpace(TagName),
Commit: strings.TrimSpace(Commit),
BuildDate: strings.TrimSpace(BuildDate),
}
if info.Version == "" {
info.Version = "dev"
}
if info.TagName == "" {
info.TagName = info.Version
}
if info.Commit == "" {
info.Commit = "unknown"
}
if info.BuildDate == "" {
info.BuildDate = "unknown"
}
return info
}

View File

@@ -0,0 +1,32 @@
package buildinfo
import "testing"
func TestCurrentAppliesFallbacks(t *testing.T) {
originalVersion, originalTagName, originalCommit, originalBuildDate := Version, TagName, Commit, BuildDate
t.Cleanup(func() {
Version = originalVersion
TagName = originalTagName
Commit = originalCommit
BuildDate = originalBuildDate
})
Version = ""
TagName = ""
Commit = ""
BuildDate = ""
info := Current()
if info.Version != "dev" {
t.Fatalf("version = %q, want %q", info.Version, "dev")
}
if info.TagName != "dev" {
t.Fatalf("tag_name = %q, want %q", info.TagName, "dev")
}
if info.Commit != "unknown" {
t.Fatalf("commit = %q, want %q", info.Commit, "unknown")
}
if info.BuildDate != "unknown" {
t.Fatalf("build_date = %q, want %q", info.BuildDate, "unknown")
}
}

View File

@@ -35,6 +35,7 @@ type MCPConfig struct {
ServerName string `yaml:"server_name"` ServerName string `yaml:"server_name"`
Version string `yaml:"version"` Version string `yaml:"version"`
Transport string `yaml:"transport"` Transport string `yaml:"transport"`
SessionTimeout time.Duration `yaml:"session_timeout"`
} }
type AuthConfig struct { type AuthConfig struct {

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"time" "time"
"git.warky.dev/wdevs/amcs/internal/buildinfo"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -46,6 +47,7 @@ func ResolvePath(explicitPath string) string {
} }
func defaultConfig() Config { func defaultConfig() Config {
info := buildinfo.Current()
return Config{ return Config{
Server: ServerConfig{ Server: ServerConfig{
Host: "0.0.0.0", Host: "0.0.0.0",
@@ -57,8 +59,9 @@ func defaultConfig() Config {
MCP: MCPConfig{ MCP: MCPConfig{
Path: "/mcp", Path: "/mcp",
ServerName: "amcs", ServerName: "amcs",
Version: "0.1.0", Version: info.Version,
Transport: "streamable_http", Transport: "streamable_http",
SessionTimeout: 10 * time.Minute,
}, },
Auth: AuthConfig{ Auth: AuthConfig{
HeaderName: "x-brain-key", HeaderName: "x-brain-key",

View File

@@ -4,6 +4,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time"
) )
func TestResolvePathPrecedence(t *testing.T) { func TestResolvePathPrecedence(t *testing.T) {
@@ -37,6 +38,7 @@ server:
port: 8080 port: 8080
mcp: mcp:
path: "/mcp" path: "/mcp"
session_timeout: "30m"
auth: auth:
keys: keys:
- id: "test" - id: "test"
@@ -80,6 +82,9 @@ logging:
if cfg.Server.Port != 9090 { if cfg.Server.Port != 9090 {
t.Fatalf("server port = %d, want 9090", cfg.Server.Port) t.Fatalf("server port = %d, want 9090", cfg.Server.Port)
} }
if cfg.MCP.SessionTimeout != 30*time.Minute {
t.Fatalf("mcp session timeout = %v, want 30m", cfg.MCP.SessionTimeout)
}
} }
func TestLoadAppliesOllamaEnvOverrides(t *testing.T) { func TestLoadAppliesOllamaEnvOverrides(t *testing.T) {

View File

@@ -33,6 +33,9 @@ func (c Config) Validate() error {
if strings.TrimSpace(c.MCP.Path) == "" { if strings.TrimSpace(c.MCP.Path) == "" {
return fmt.Errorf("invalid config: mcp.path is required") return fmt.Errorf("invalid config: mcp.path is required")
} }
if c.MCP.SessionTimeout <= 0 {
return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero")
}
switch c.AI.Provider { switch c.AI.Provider {
case "litellm", "ollama", "openrouter": case "litellm", "ollama", "openrouter":

View File

@@ -1,11 +1,14 @@
package config package config
import "testing" import (
"testing"
"time"
)
func validConfig() Config { func validConfig() Config {
return Config{ return Config{
Server: ServerConfig{Port: 8080}, Server: ServerConfig{Port: 8080},
MCP: MCPConfig{Path: "/mcp"}, MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute},
Auth: AuthConfig{ Auth: AuthConfig{
Keys: []APIKey{{ID: "test", Value: "secret"}}, Keys: []APIKey{{ID: "test", Value: "secret"}},
}, },
@@ -121,3 +124,12 @@ func TestValidateRejectsInvalidMetadataRetryConfig(t *testing.T) {
t.Fatal("Validate() error = nil, want error for invalid metadata retry config") t.Fatal("Validate() error = nil, want error for invalid metadata retry config")
} }
} }
func TestValidateRejectsInvalidMCPSessionTimeout(t *testing.T) {
cfg := validConfig()
cfg.MCP.SessionTimeout = 0
if err := cfg.Validate(); err == nil {
t.Fatal("Validate() error = nil, want error for invalid mcp session timeout")
}
}

View File

@@ -0,0 +1,30 @@
package mcperrors
const (
CodeSessionRequired = -32050
CodeProjectRequired = -32051
CodeProjectNotFound = -32052
CodeInvalidID = -32053
CodeEntityNotFound = -32054
)
const (
TypeInvalidArguments = "invalid_arguments"
TypeInvalidInput = "invalid_input"
TypeSessionRequired = "session_required"
TypeProjectRequired = "project_required"
TypeProjectNotFound = "project_not_found"
TypeInvalidID = "invalid_id"
TypeEntityNotFound = "entity_not_found"
)
type Data struct {
Type string `json:"type"`
Field string `json:"field,omitempty"`
Fields []string `json:"fields,omitempty"`
Value string `json:"value,omitempty"`
Detail string `json:"detail,omitempty"`
Hint string `json:"hint,omitempty"`
Project string `json:"project,omitempty"`
Entity string `json:"entity,omitempty"`
}

View File

@@ -0,0 +1,154 @@
package mcpserver
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
"git.warky.dev/wdevs/amcs/internal/session"
"git.warky.dev/wdevs/amcs/internal/tools"
)
func TestToolValidationErrorsRoundTripAsStructuredJSONRPC(t *testing.T) {
server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)
projects := tools.NewProjectsTool(nil, session.NewActiveProjects())
contextTool := tools.NewContextTool(nil, nil, toolsSearchConfig(), session.NewActiveProjects())
if err := addTool(server, nil, &mcp.Tool{
Name: "create_project",
Description: "Create a named project container for thoughts.",
}, projects.Create); err != nil {
t.Fatalf("add create_project tool: %v", err)
}
if err := addTool(server, nil, &mcp.Tool{
Name: "get_project_context",
Description: "Get recent and semantic context for a project.",
}, contextTool.Handle); err != nil {
t.Fatalf("add get_project_context tool: %v", err)
}
ct, st := mcp.NewInMemoryTransports()
_, err := server.Connect(context.Background(), st, nil)
if err != nil {
t.Fatalf("connect server: %v", err)
}
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
cs, err := client.Connect(context.Background(), ct, nil)
if err != nil {
t.Fatalf("connect client: %v", err)
}
defer func() {
_ = cs.Close()
}()
t.Run("required_field", func(t *testing.T) {
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
Name: "create_project",
Arguments: map[string]any{"name": ""},
})
if err == nil {
t.Fatal("CallTool(create_project) error = nil, want error")
}
rpcErr, data := requireWireError(t, err)
if rpcErr.Code != jsonrpc.CodeInvalidParams {
t.Fatalf("create_project code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
}
if data.Type != mcperrors.TypeInvalidInput {
t.Fatalf("create_project data.type = %q, want %q", data.Type, mcperrors.TypeInvalidInput)
}
if data.Field != "name" {
t.Fatalf("create_project data.field = %q, want %q", data.Field, "name")
}
})
t.Run("schema_required_field", func(t *testing.T) {
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
Name: "create_project",
Arguments: map[string]any{},
})
if err == nil {
t.Fatal("CallTool(create_project missing field) error = nil, want error")
}
rpcErr, data := requireWireError(t, err)
if rpcErr.Code != jsonrpc.CodeInvalidParams {
t.Fatalf("create_project schema code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
}
if data.Type != mcperrors.TypeInvalidArguments {
t.Fatalf("create_project schema data.type = %q, want %q", data.Type, mcperrors.TypeInvalidArguments)
}
if data.Field != "name" {
t.Fatalf("create_project schema data.field = %q, want %q", data.Field, "name")
}
if data.Detail == "" {
t.Fatal("create_project schema data.detail = empty, want validation detail")
}
})
t.Run("project_required", func(t *testing.T) {
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
Name: "get_project_context",
Arguments: map[string]any{},
})
if err == nil {
t.Fatal("CallTool(get_project_context) error = nil, want error")
}
rpcErr, data := requireWireError(t, err)
if rpcErr.Code != mcperrors.CodeProjectRequired {
t.Fatalf("get_project_context code = %d, want %d", rpcErr.Code, mcperrors.CodeProjectRequired)
}
if data.Type != mcperrors.TypeProjectRequired {
t.Fatalf("get_project_context data.type = %q, want %q", data.Type, mcperrors.TypeProjectRequired)
}
if data.Field != "project" {
t.Fatalf("get_project_context data.field = %q, want %q", data.Field, "project")
}
if data.Hint == "" {
t.Fatal("get_project_context data.hint = empty, want guidance")
}
})
}
type wireErrorData struct {
Type string `json:"type"`
Field string `json:"field,omitempty"`
Fields []string `json:"fields,omitempty"`
Detail string `json:"detail,omitempty"`
Hint string `json:"hint,omitempty"`
}
func requireWireError(t *testing.T, err error) (*jsonrpc.Error, wireErrorData) {
t.Helper()
var rpcErr *jsonrpc.Error
if !errors.As(err, &rpcErr) {
t.Fatalf("error type = %T, want *jsonrpc.Error", err)
}
var data wireErrorData
if len(rpcErr.Data) > 0 {
if unmarshalErr := json.Unmarshal(rpcErr.Data, &data); unmarshalErr != nil {
t.Fatalf("unmarshal wire error data: %v", unmarshalErr)
}
}
return rpcErr, data
}
func toolsSearchConfig() config.SearchConfig {
return config.SearchConfig{
DefaultLimit: 10,
MaxLimit: 50,
DefaultThreshold: 0.7,
}
}

View File

@@ -0,0 +1,42 @@
package mcpserver
import (
"context"
"iter"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
type cleanupEventStore struct {
base mcp.EventStore
onSessionClosed func(string)
}
func newCleanupEventStore(base mcp.EventStore, onSessionClosed func(string)) mcp.EventStore {
return &cleanupEventStore{
base: base,
onSessionClosed: onSessionClosed,
}
}
func (s *cleanupEventStore) Open(ctx context.Context, sessionID, streamID string) error {
return s.base.Open(ctx, sessionID, streamID)
}
func (s *cleanupEventStore) Append(ctx context.Context, sessionID, streamID string, data []byte) error {
return s.base.Append(ctx, sessionID, streamID, data)
}
func (s *cleanupEventStore) After(ctx context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] {
return s.base.After(ctx, sessionID, streamID, index)
}
func (s *cleanupEventStore) SessionClosed(ctx context.Context, sessionID string) error {
if err := s.base.SessionClosed(ctx, sessionID); err != nil {
return err
}
if s.onSessionClosed != nil {
s.onSessionClosed(sessionID)
}
return nil
}

View File

@@ -0,0 +1,30 @@
package mcpserver
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/session"
)
func TestCleanupEventStoreSessionClosedClearsActiveProject(t *testing.T) {
activeProjects := session.NewActiveProjects()
activeProjects.Set("session-1", uuid.New())
store := newCleanupEventStore(mcp.NewMemoryEventStore(nil), activeProjects.Clear)
if _, ok := activeProjects.Get("session-1"); !ok {
t.Fatal("active project missing before SessionClosed")
}
if err := store.SessionClosed(context.Background(), "session-1"); err != nil {
t.Fatalf("SessionClosed() error = %v", err)
}
if _, ok := activeProjects.Get("session-1"); ok {
t.Fatal("active project still present after SessionClosed")
}
}

View File

@@ -6,15 +6,31 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"reflect" "reflect"
"regexp"
"strings"
"time" "time"
"github.com/google/jsonschema-go/jsonschema" "github.com/google/jsonschema-go/jsonschema"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/mcp" "github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
) )
const maxLoggedArgBytes = 512 const maxLoggedArgBytes = 512
var sensitiveToolArgKeys = map[string]struct{}{
"client_secret": {},
"content": {},
"content_base64": {},
"content_path": {},
"email": {},
"notes": {},
"phone": {},
"value": {},
}
var toolSchemaOptions = &jsonschema.ForOptions{ var toolSchemaOptions = &jsonschema.ForOptions{
TypeSchemas: map[reflect.Type]*jsonschema.Schema{ TypeSchemas: map[reflect.Type]*jsonschema.Schema{
reflect.TypeFor[uuid.UUID](): { reflect.TypeFor[uuid.UUID](): {
@@ -24,19 +40,92 @@ var toolSchemaOptions = &jsonschema.ForOptions{
}, },
} }
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) { var (
missingPropertyPattern = regexp.MustCompile(`missing properties: \["([^"]+)"\]`)
propertyPathPattern = regexp.MustCompile(`validating /properties/([^:]+):`)
structFieldPattern = regexp.MustCompile(`\.([A-Za-z0-9_]+) of type`)
)
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) error {
if err := setToolSchemas[In, Out](tool); err != nil { if err := setToolSchemas[In, Out](tool); err != nil {
panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err)) return fmt.Errorf("configure MCP tool %q schemas: %w", tool.Name, err)
}
mcp.AddTool(server, tool, logToolCall(logger, tool.Name, handler))
} }
func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error) { inputResolved, err := resolveToolSchema(tool.InputSchema)
if err != nil {
return fmt.Errorf("resolve MCP tool %q input schema: %w", tool.Name, err)
}
outputResolved, err := resolveToolSchema(tool.OutputSchema)
if err != nil {
return fmt.Errorf("resolve MCP tool %q output schema: %w", tool.Name, err)
}
server.AddTool(tool, logToolCall(logger, tool.Name, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var input json.RawMessage
if req != nil && req.Params != nil && req.Params.Arguments != nil {
input = req.Params.Arguments
}
input, err = applyToolSchema(input, inputResolved)
if err != nil {
return nil, invalidArgumentsError(err)
}
var in In
if input != nil {
if err := json.Unmarshal(input, &in); err != nil {
return nil, invalidArgumentsError(err)
}
}
result, out, err := handler(ctx, req, in)
if err != nil {
if wireErr, ok := err.(*jsonrpc.Error); ok {
return nil, wireErr
}
var errResult mcp.CallToolResult
errResult.SetError(err)
return &errResult, nil
}
if result == nil {
result = &mcp.CallToolResult{}
}
if outputResolved == nil {
return result, nil
}
outValue := normalizeTypedNil[Out](out)
if outValue == nil {
return result, nil
}
outBytes, err := json.Marshal(outValue)
if err != nil {
return nil, fmt.Errorf("marshaling output: %w", err)
}
outJSON, err := applyToolSchema(json.RawMessage(outBytes), outputResolved)
if err != nil {
return nil, fmt.Errorf("validating tool output: %w", err)
}
result.StructuredContent = outJSON
if result.Content == nil {
result.Content = []mcp.Content{&mcp.TextContent{Text: string(outJSON)}}
}
return result, nil
}))
return nil
}
func logToolCall(logger *slog.Logger, toolName string, handler mcp.ToolHandler) mcp.ToolHandler {
if logger == nil { if logger == nil {
return handler return handler
} }
return func(ctx context.Context, req *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) { return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
start := time.Now() start := time.Now()
attrs := []any{slog.String("tool", toolName)} attrs := []any{slog.String("tool", toolName)}
if req != nil && req.Params != nil { if req != nil && req.Params != nil {
@@ -44,23 +133,27 @@ func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler
} }
logger.Info("mcp tool started", attrs...) logger.Info("mcp tool started", attrs...)
result, out, err := handler(ctx, req, in) result, err := handler(ctx, req)
completionAttrs := append([]any{}, attrs...) completionAttrs := append([]any{}, attrs...)
completionAttrs = append(completionAttrs, slog.String("duration", formatLogDuration(time.Since(start)))) completionAttrs = append(completionAttrs, slog.String("duration", formatLogDuration(time.Since(start))))
if err != nil { if err != nil {
completionAttrs = append(completionAttrs, slog.String("error", err.Error())) completionAttrs = append(completionAttrs, slog.String("error", err.Error()))
logger.Error("mcp tool completed", completionAttrs...) logger.Error("mcp tool completed", completionAttrs...)
return result, out, err return result, err
} }
logger.Info("mcp tool completed", completionAttrs...) logger.Info("mcp tool completed", completionAttrs...)
return result, out, nil return result, nil
} }
} }
func truncateArgs(args any) string { func truncateArgs(args any) string {
b, err := json.Marshal(args) redacted, err := redactArguments(args)
if err != nil {
return "<unserializable>"
}
b, err := json.Marshal(redacted)
if err != nil { if err != nil {
return "<unserializable>" return "<unserializable>"
} }
@@ -70,6 +163,52 @@ func truncateArgs(args any) string {
return string(b[:maxLoggedArgBytes]) + fmt.Sprintf("… (%d bytes total)", len(b)) return string(b[:maxLoggedArgBytes]) + fmt.Sprintf("… (%d bytes total)", len(b))
} }
func redactArguments(args any) (any, error) {
if args == nil {
return nil, nil
}
b, err := json.Marshal(args)
if err != nil {
return nil, err
}
var decoded any
if err := json.Unmarshal(b, &decoded); err != nil {
return nil, err
}
return redactJSONValue("", decoded), nil
}
func redactJSONValue(key string, value any) any {
if isSensitiveToolArgKey(key) {
return "<redacted>"
}
switch typed := value.(type) {
case map[string]any:
redacted := make(map[string]any, len(typed))
for childKey, childValue := range typed {
redacted[childKey] = redactJSONValue(childKey, childValue)
}
return redacted
case []any:
redacted := make([]any, len(typed))
for i, item := range typed {
redacted[i] = redactJSONValue("", item)
}
return redacted
default:
return value
}
}
func isSensitiveToolArgKey(key string) bool {
_, ok := sensitiveToolArgKeys[strings.ToLower(strings.TrimSpace(key))]
return ok
}
func formatLogDuration(d time.Duration) string { func formatLogDuration(d time.Duration) string {
if d < 0 { if d < 0 {
d = -d d = -d
@@ -101,3 +240,104 @@ func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
return nil return nil
} }
func resolveToolSchema(schema any) (*jsonschema.Resolved, error) {
if schema == nil {
return nil, nil
}
if typed, ok := schema.(*jsonschema.Schema); ok {
return typed.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
}
var remarshalTarget *jsonschema.Schema
if err := remarshalJSON(schema, &remarshalTarget); err != nil {
return nil, err
}
return remarshalTarget.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
}
func applyToolSchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) {
if resolved == nil {
return data, nil
}
value := make(map[string]any)
if len(data) > 0 {
if err := json.Unmarshal(data, &value); err != nil {
return nil, fmt.Errorf("unmarshaling arguments: %w", err)
}
}
if err := resolved.ApplyDefaults(&value); err != nil {
return nil, fmt.Errorf("applying schema defaults: %w", err)
}
if err := resolved.Validate(&value); err != nil {
return nil, err
}
normalized, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("marshalling with defaults: %w", err)
}
return normalized, nil
}
func remarshalJSON(src any, dst any) error {
b, err := json.Marshal(src)
if err != nil {
return err
}
return json.Unmarshal(b, dst)
}
func normalizeTypedNil[T any](value T) any {
rt := reflect.TypeFor[T]()
if rt.Kind() == reflect.Pointer {
var zero T
if any(value) == any(zero) {
return reflect.Zero(rt.Elem()).Interface()
}
}
return any(value)
}
func invalidArgumentsError(err error) error {
detail := err.Error()
field := validationField(detail)
payload := mcperrors.Data{
Type: mcperrors.TypeInvalidArguments,
Detail: detail,
}
if field != "" {
payload.Field = field
payload.Hint = "check the " + field + " argument"
}
data, marshalErr := json.Marshal(payload)
if marshalErr != nil {
return &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "invalid tool arguments",
}
}
return &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "invalid tool arguments",
Data: data,
}
}
func validationField(detail string) string {
if matches := missingPropertyPattern.FindStringSubmatch(detail); len(matches) == 2 {
return matches[1]
}
if matches := propertyPathPattern.FindStringSubmatch(detail); len(matches) == 2 {
return matches[1]
}
if matches := structFieldPattern.FindStringSubmatch(detail); len(matches) == 2 {
return strings.ToLower(matches[1])
}
return ""
}

View File

@@ -1,6 +1,9 @@
package mcpserver package mcpserver
import ( import (
"context"
"encoding/json"
"strings"
"testing" "testing"
"github.com/google/jsonschema-go/jsonschema" "github.com/google/jsonschema-go/jsonschema"
@@ -40,3 +43,59 @@ func TestSetToolSchemasUsesStringUUIDsInListOutput(t *testing.T) {
t.Fatalf("id schema format = %q, want %q", idSchema.Format, "uuid") t.Fatalf("id schema format = %q, want %q", idSchema.Format, "uuid")
} }
} }
func TestTruncateArgsRedactsSensitiveFields(t *testing.T) {
args := json.RawMessage(`{
"query": "todo from yesterday",
"content": "private thought body",
"content_base64": "c2VjcmV0LWJpbmFyeQ==",
"email": "user@example.com",
"nested": {
"notes": "private note",
"phone": "+1-555-0100",
"keep": "visible"
}
}`)
got := truncateArgs(args)
for _, secret := range []string{
"private thought body",
"c2VjcmV0LWJpbmFyeQ==",
"user@example.com",
"private note",
"+1-555-0100",
} {
if strings.Contains(got, secret) {
t.Fatalf("truncateArgs leaked sensitive value %q in %s", secret, got)
}
}
for _, expected := range []string{
`"query":"todo from yesterday"`,
`"keep":"visible"`,
`"content":"\u003credacted\u003e"`,
`"content_base64":"\u003credacted\u003e"`,
`"email":"\u003credacted\u003e"`,
`"notes":"\u003credacted\u003e"`,
`"phone":"\u003credacted\u003e"`,
} {
if !strings.Contains(got, expected) {
t.Fatalf("truncateArgs(%s) missing %q", got, expected)
}
}
}
func TestAddToolReturnsSchemaErrorInsteadOfPanicking(t *testing.T) {
server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)
err := addTool(server, nil, &mcp.Tool{Name: "broken"}, func(_ context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, chan int, error) {
return nil, nil, nil
})
if err == nil {
t.Fatal("addTool() error = nil, want schema inference error")
}
if !strings.Contains(err.Error(), `configure MCP tool "broken" schemas`) {
t.Fatalf("addTool() error = %q, want tool context", err.Error())
}
}

View File

@@ -3,7 +3,6 @@ package mcpserver
import ( import (
"log/slog" "log/slog"
"net/http" "net/http"
"time"
"github.com/modelcontextprotocol/go-sdk/mcp" "github.com/modelcontextprotocol/go-sdk/mcp"
@@ -12,6 +11,7 @@ import (
) )
type ToolSet struct { type ToolSet struct {
Version *tools.VersionTool
Capture *tools.CaptureTool Capture *tools.CaptureTool
Search *tools.SearchTool Search *tools.SearchTool
List *tools.ListTool List *tools.ListTool
@@ -37,355 +37,480 @@ type ToolSet struct {
Skills *tools.SkillsTool Skills *tools.SkillsTool
} }
func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet) http.Handler { func New(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onSessionClosed func(string)) (http.Handler, error) {
server := mcp.NewServer(&mcp.Implementation{ server := mcp.NewServer(&mcp.Implementation{
Name: cfg.ServerName, Name: cfg.ServerName,
Version: cfg.Version, Version: cfg.Version,
}, nil) }, nil)
addTool(server, logger, &mcp.Tool{ for _, register := range []func(*mcp.Server, *slog.Logger, ToolSet) error{
registerSystemTools,
registerThoughtTools,
registerProjectTools,
registerFileTools,
registerMaintenanceTools,
registerHouseholdTools,
registerCalendarTools,
registerMealTools,
registerCRMTools,
registerSkillTools,
} {
if err := register(server, logger, toolSet); err != nil {
return nil, err
}
}
opts := &mcp.StreamableHTTPOptions{
JSONResponse: true,
SessionTimeout: cfg.SessionTimeout,
}
if onSessionClosed != nil {
opts.EventStore = newCleanupEventStore(mcp.NewMemoryEventStore(nil), onSessionClosed)
}
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
return server
}, opts), nil
}
func registerSystemTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
if err := addTool(server, logger, &mcp.Tool{
Name: "get_version_info",
Description: "Return the server build version information, including version, tag name, commit, and build date.",
}, toolSet.Version.GetInfo); err != nil {
return err
}
return nil
}
func registerThoughtTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
if err := addTool(server, logger, &mcp.Tool{
Name: "capture_thought", Name: "capture_thought",
Description: "Store a thought with generated embeddings and extracted metadata.", Description: "Store a thought with generated embeddings and extracted metadata.",
}, toolSet.Capture.Handle) }, toolSet.Capture.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "search_thoughts", Name: "search_thoughts",
Description: "Search stored thoughts by semantic similarity.", Description: "Search stored thoughts by semantic similarity.",
}, toolSet.Search.Handle) }, toolSet.Search.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "list_thoughts", Name: "list_thoughts",
Description: "List recent thoughts with optional metadata filters.", Description: "List recent thoughts with optional metadata filters.",
}, toolSet.List.Handle) }, toolSet.List.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "thought_stats", Name: "thought_stats",
Description: "Get counts and top metadata buckets across stored thoughts.", Description: "Get counts and top metadata buckets across stored thoughts.",
}, toolSet.Stats.Handle) }, toolSet.Stats.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "get_thought", Name: "get_thought",
Description: "Retrieve a full thought by id.", Description: "Retrieve a full thought by id.",
}, toolSet.Get.Handle) }, toolSet.Get.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "update_thought", Name: "update_thought",
Description: "Update thought content or merge metadata.", Description: "Update thought content or merge metadata.",
}, toolSet.Update.Handle) }, toolSet.Update.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "delete_thought", Name: "delete_thought",
Description: "Hard-delete a thought by id.", Description: "Hard-delete a thought by id.",
}, toolSet.Delete.Handle) }, toolSet.Delete.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "archive_thought", Name: "archive_thought",
Description: "Archive a thought so it is hidden from default search and listing.", Description: "Archive a thought so it is hidden from default search and listing.",
}, toolSet.Archive.Handle) }, toolSet.Archive.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
Name: "create_project", if err := addTool(server, logger, &mcp.Tool{
Description: "Create a named project container for thoughts.",
}, toolSet.Projects.Create)
addTool(server, logger, &mcp.Tool{
Name: "list_projects",
Description: "List projects and their current thought counts.",
}, toolSet.Projects.List)
addTool(server, logger, &mcp.Tool{
Name: "set_active_project",
Description: "Set the active project for the current MCP session.",
}, toolSet.Projects.SetActive)
addTool(server, logger, &mcp.Tool{
Name: "get_active_project",
Description: "Return the active project for the current MCP session.",
}, toolSet.Projects.GetActive)
addTool(server, logger, &mcp.Tool{
Name: "get_project_context",
Description: "Get recent and semantic context for a project.",
}, toolSet.Context.Handle)
addTool(server, logger, &mcp.Tool{
Name: "recall_context",
Description: "Recall semantically relevant and recent context.",
}, toolSet.Recall.Handle)
addTool(server, logger, &mcp.Tool{
Name: "summarize_thoughts", Name: "summarize_thoughts",
Description: "Summarize a filtered or searched set of thoughts.", Description: "Summarize a filtered or searched set of thoughts.",
}, toolSet.Summarize.Handle) }, toolSet.Summarize.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "recall_context",
Description: "Recall semantically relevant and recent context.",
}, toolSet.Recall.Handle); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "link_thoughts", Name: "link_thoughts",
Description: "Create a typed relationship between two thoughts.", Description: "Create a typed relationship between two thoughts.",
}, toolSet.Links.Link) }, toolSet.Links.Link); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "related_thoughts", Name: "related_thoughts",
Description: "Retrieve explicit links and semantic neighbors for a thought.", Description: "Retrieve explicit links and semantic neighbors for a thought.",
}, toolSet.Links.Related) }, toolSet.Links.Related); err != nil {
return err
}
return nil
}
func registerProjectTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
if err := addTool(server, logger, &mcp.Tool{
Name: "create_project",
Description: "Create a named project container for thoughts.",
}, toolSet.Projects.Create); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "list_projects",
Description: "List projects and their current thought counts.",
}, toolSet.Projects.List); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "set_active_project",
Description: "Set the active project for the current MCP session. Requires a stateful MCP client that reuses the same session across calls.",
}, toolSet.Projects.SetActive); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "get_active_project",
Description: "Return the active project for the current MCP session. If your client does not preserve MCP sessions, pass project explicitly to project-scoped tools instead.",
}, toolSet.Projects.GetActive); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "get_project_context",
Description: "Get recent and semantic context for a project. Uses the explicit project when provided, otherwise the active MCP session project.",
}, toolSet.Context.Handle); err != nil {
return err
}
return nil
}
func registerFileTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
server.AddResourceTemplate(&mcp.ResourceTemplate{ server.AddResourceTemplate(&mcp.ResourceTemplate{
Name: "stored_file", Name: "stored_file",
URITemplate: "amcs://files/{id}", URITemplate: "amcs://files/{id}",
Description: "A stored file. Read a file's raw binary content by its id. Use load_file for metadata.", Description: "A stored file. Read a file's raw binary content by its id. Use load_file for metadata.",
}, toolSet.Files.ReadResource) }, toolSet.Files.ReadResource)
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "upload_file", Name: "upload_file",
Description: "Stage a file and get an amcs://files/{id} resource URI. Provide content_path (absolute server-side path, no size limit) or content_base64 (≤10 MB). Optionally link immediately with thought_id/project, or omit them and pass the returned URI to save_file later.", Description: "Stage a file and get an amcs://files/{id} resource URI. Provide content_path (absolute server-side path, no size limit) or content_base64 (≤10 MB). Optionally link immediately with thought_id/project, or omit them and pass the returned URI to save_file later.",
}, toolSet.Files.Upload) }, toolSet.Files.Upload); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "save_file", Name: "save_file",
Description: "Store a file and optionally link it to a thought. Supply either content_base64 (≤10 MB) or content_uri (amcs://files/{id} from a prior upload_file or POST /files call). For files larger than 10 MB, use upload_file with content_path first.", Description: "Store a file and optionally link it to a thought. Supply either content_base64 (≤10 MB) or content_uri (amcs://files/{id} from a prior upload_file or POST /files call). For files larger than 10 MB, use upload_file with content_path first.",
}, toolSet.Files.Save) }, toolSet.Files.Save); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "load_file", Name: "load_file",
Description: "Load a previously stored file by id and return its metadata and base64 content.", Description: "Load a previously stored file by id and return its metadata and base64 content.",
}, toolSet.Files.Load) }, toolSet.Files.Load); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "list_files", Name: "list_files",
Description: "List stored files, optionally filtered by thought, project, or kind.", Description: "List stored files, optionally filtered by thought, project, or kind.",
}, toolSet.Files.List) }, toolSet.Files.List); err != nil {
return err
}
return nil
}
addTool(server, logger, &mcp.Tool{ func registerMaintenanceTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
if err := addTool(server, logger, &mcp.Tool{
Name: "backfill_embeddings", Name: "backfill_embeddings",
Description: "Generate missing embeddings for stored thoughts using the active embedding model.", Description: "Generate missing embeddings for stored thoughts using the active embedding model.",
}, toolSet.Backfill.Handle) }, toolSet.Backfill.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "reparse_thought_metadata", Name: "reparse_thought_metadata",
Description: "Re-extract and normalize metadata for stored thoughts from their content.", Description: "Re-extract and normalize metadata for stored thoughts from their content.",
}, toolSet.Reparse.Handle) }, toolSet.Reparse.Handle); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "retry_failed_metadata", Name: "retry_failed_metadata",
Description: "Retry metadata extraction for thoughts still marked pending or failed.", Description: "Retry metadata extraction for thoughts still marked pending or failed.",
}, toolSet.RetryMetadata.Handle) }, toolSet.RetryMetadata.Handle); err != nil {
return err
// Household Knowledge }
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "add_household_item",
Description: "Store a household fact (paint color, appliance details, measurement, document, etc.).",
}, toolSet.Household.AddItem)
addTool(server, logger, &mcp.Tool{
Name: "search_household_items",
Description: "Search household items by name, category, or location.",
}, toolSet.Household.SearchItems)
addTool(server, logger, &mcp.Tool{
Name: "get_household_item",
Description: "Retrieve a household item by id.",
}, toolSet.Household.GetItem)
addTool(server, logger, &mcp.Tool{
Name: "add_vendor",
Description: "Add a service provider (plumber, electrician, landscaper, etc.).",
}, toolSet.Household.AddVendor)
addTool(server, logger, &mcp.Tool{
Name: "list_vendors",
Description: "List household service vendors, optionally filtered by service type.",
}, toolSet.Household.ListVendors)
// Home Maintenance
addTool(server, logger, &mcp.Tool{
Name: "add_maintenance_task", Name: "add_maintenance_task",
Description: "Create a recurring or one-time home maintenance task.", Description: "Create a recurring or one-time home maintenance task.",
}, toolSet.Maintenance.AddTask) }, toolSet.Maintenance.AddTask); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "log_maintenance", Name: "log_maintenance",
Description: "Log completed maintenance work; automatically updates the task's next due date.", Description: "Log completed maintenance work; automatically updates the task's next due date.",
}, toolSet.Maintenance.LogWork) }, toolSet.Maintenance.LogWork); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "get_upcoming_maintenance", Name: "get_upcoming_maintenance",
Description: "List maintenance tasks due within the next N days.", Description: "List maintenance tasks due within the next N days.",
}, toolSet.Maintenance.GetUpcoming) }, toolSet.Maintenance.GetUpcoming); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "search_maintenance_history", Name: "search_maintenance_history",
Description: "Search the maintenance log by task name, category, or date range.", Description: "Search the maintenance log by task name, category, or date range.",
}, toolSet.Maintenance.SearchHistory) }, toolSet.Maintenance.SearchHistory); err != nil {
return err
}
return nil
}
// Family Calendar func registerHouseholdTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "add_household_item",
Description: "Store a household fact (paint color, appliance details, measurement, document, etc.).",
}, toolSet.Household.AddItem); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "search_household_items",
Description: "Search household items by name, category, or location.",
}, toolSet.Household.SearchItems); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "get_household_item",
Description: "Retrieve a household item by id.",
}, toolSet.Household.GetItem); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "add_vendor",
Description: "Add a service provider (plumber, electrician, landscaper, etc.).",
}, toolSet.Household.AddVendor); err != nil {
return err
}
if err := addTool(server, logger, &mcp.Tool{
Name: "list_vendors",
Description: "List household service vendors, optionally filtered by service type.",
}, toolSet.Household.ListVendors); err != nil {
return err
}
return nil
}
func registerCalendarTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
if err := addTool(server, logger, &mcp.Tool{
Name: "add_family_member", Name: "add_family_member",
Description: "Add a family member to the household.", Description: "Add a family member to the household.",
}, toolSet.Calendar.AddMember) }, toolSet.Calendar.AddMember); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "list_family_members", Name: "list_family_members",
Description: "List all family members.", Description: "List all family members.",
}, toolSet.Calendar.ListMembers) }, toolSet.Calendar.ListMembers); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "add_activity", Name: "add_activity",
Description: "Schedule a one-time or recurring family activity.", Description: "Schedule a one-time or recurring family activity.",
}, toolSet.Calendar.AddActivity) }, toolSet.Calendar.AddActivity); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "get_week_schedule", Name: "get_week_schedule",
Description: "Get all activities scheduled for a given week.", Description: "Get all activities scheduled for a given week.",
}, toolSet.Calendar.GetWeekSchedule) }, toolSet.Calendar.GetWeekSchedule); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "search_activities", Name: "search_activities",
Description: "Search activities by title, type, or family member.", Description: "Search activities by title, type, or family member.",
}, toolSet.Calendar.SearchActivities) }, toolSet.Calendar.SearchActivities); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "add_important_date", Name: "add_important_date",
Description: "Track a birthday, anniversary, deadline, or other important date.", Description: "Track a birthday, anniversary, deadline, or other important date.",
}, toolSet.Calendar.AddImportantDate) }, toolSet.Calendar.AddImportantDate); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "get_upcoming_dates", Name: "get_upcoming_dates",
Description: "Get important dates coming up in the next N days.", Description: "Get important dates coming up in the next N days.",
}, toolSet.Calendar.GetUpcomingDates) }, toolSet.Calendar.GetUpcomingDates); err != nil {
return err
}
return nil
}
// Meal Planning func registerMealTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "add_recipe", Name: "add_recipe",
Description: "Save a recipe with ingredients and instructions.", Description: "Save a recipe with ingredients and instructions.",
}, toolSet.Meals.AddRecipe) }, toolSet.Meals.AddRecipe); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "search_recipes", Name: "search_recipes",
Description: "Search recipes by name, cuisine, tags, or ingredient.", Description: "Search recipes by name, cuisine, tags, or ingredient.",
}, toolSet.Meals.SearchRecipes) }, toolSet.Meals.SearchRecipes); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "update_recipe", Name: "update_recipe",
Description: "Update an existing recipe.", Description: "Update an existing recipe.",
}, toolSet.Meals.UpdateRecipe) }, toolSet.Meals.UpdateRecipe); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "create_meal_plan", Name: "create_meal_plan",
Description: "Set the meal plan for a week; replaces any existing plan for that week.", Description: "Set the meal plan for a week; replaces any existing plan for that week.",
}, toolSet.Meals.CreateMealPlan) }, toolSet.Meals.CreateMealPlan); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "get_meal_plan", Name: "get_meal_plan",
Description: "Get the meal plan for a given week.", Description: "Get the meal plan for a given week.",
}, toolSet.Meals.GetMealPlan) }, toolSet.Meals.GetMealPlan); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "generate_shopping_list", Name: "generate_shopping_list",
Description: "Auto-generate a shopping list from the meal plan for a given week.", Description: "Auto-generate a shopping list from the meal plan for a given week.",
}, toolSet.Meals.GenerateShoppingList) }, toolSet.Meals.GenerateShoppingList); err != nil {
return err
}
return nil
}
// Professional CRM func registerCRMTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "add_professional_contact", Name: "add_professional_contact",
Description: "Add a professional contact to the CRM.", Description: "Add a professional contact to the CRM.",
}, toolSet.CRM.AddContact) }, toolSet.CRM.AddContact); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "search_contacts", Name: "search_contacts",
Description: "Search professional contacts by name, company, title, notes, or tags.", Description: "Search professional contacts by name, company, title, notes, or tags.",
}, toolSet.CRM.SearchContacts) }, toolSet.CRM.SearchContacts); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "log_interaction", Name: "log_interaction",
Description: "Log an interaction with a professional contact.", Description: "Log an interaction with a professional contact.",
}, toolSet.CRM.LogInteraction) }, toolSet.CRM.LogInteraction); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "get_contact_history", Name: "get_contact_history",
Description: "Get full history (interactions and opportunities) for a contact.", Description: "Get full history (interactions and opportunities) for a contact.",
}, toolSet.CRM.GetHistory) }, toolSet.CRM.GetHistory); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "create_opportunity", Name: "create_opportunity",
Description: "Create a deal, project, or opportunity linked to a contact.", Description: "Create a deal, project, or opportunity linked to a contact.",
}, toolSet.CRM.CreateOpportunity) }, toolSet.CRM.CreateOpportunity); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "get_follow_ups_due", Name: "get_follow_ups_due",
Description: "List contacts with a follow-up date due within the next N days.", Description: "List contacts with a follow-up date due within the next N days.",
}, toolSet.CRM.GetFollowUpsDue) }, toolSet.CRM.GetFollowUpsDue); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "link_thought_to_contact", Name: "link_thought_to_contact",
Description: "Append a stored thought to a contact's notes.", Description: "Append a stored thought to a contact's notes.",
}, toolSet.CRM.LinkThought) }, toolSet.CRM.LinkThought); err != nil {
return err
}
return nil
}
// Agent Skills func registerSkillTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "add_skill", Name: "add_skill",
Description: "Store a reusable agent skill (behavioural instruction or capability prompt).", Description: "Store a reusable agent skill (behavioural instruction or capability prompt).",
}, toolSet.Skills.AddSkill) }, toolSet.Skills.AddSkill); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "remove_skill", Name: "remove_skill",
Description: "Delete an agent skill by id.", Description: "Delete an agent skill by id.",
}, toolSet.Skills.RemoveSkill) }, toolSet.Skills.RemoveSkill); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "list_skills", Name: "list_skills",
Description: "List all agent skills, optionally filtered by tag.", Description: "List all agent skills, optionally filtered by tag.",
}, toolSet.Skills.ListSkills) }, toolSet.Skills.ListSkills); err != nil {
return err
// Agent Guardrails }
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "add_guardrail", Name: "add_guardrail",
Description: "Store a reusable agent guardrail (constraint or safety rule).", Description: "Store a reusable agent guardrail (constraint or safety rule).",
}, toolSet.Skills.AddGuardrail) }, toolSet.Skills.AddGuardrail); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "remove_guardrail", Name: "remove_guardrail",
Description: "Delete an agent guardrail by id.", Description: "Delete an agent guardrail by id.",
}, toolSet.Skills.RemoveGuardrail) }, toolSet.Skills.RemoveGuardrail); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
if err := addTool(server, logger, &mcp.Tool{
Name: "list_guardrails", Name: "list_guardrails",
Description: "List all agent guardrails, optionally filtered by tag or severity.", Description: "List all agent guardrails, optionally filtered by tag or severity.",
}, toolSet.Skills.ListGuardrails) }, toolSet.Skills.ListGuardrails); err != nil {
return err
// Project Skills & Guardrails }
addTool(server, logger, &mcp.Tool{ if err := addTool(server, logger, &mcp.Tool{
Name: "add_project_skill", Name: "add_project_skill",
Description: "Link an agent skill to a project.", Description: "Link an agent skill to a project. Pass project explicitly when your client does not preserve MCP sessions.",
}, toolSet.Skills.AddProjectSkill) }, toolSet.Skills.AddProjectSkill); err != nil {
return err
addTool(server, logger, &mcp.Tool{ }
Name: "remove_project_skill", if err := addTool(server, logger, &mcp.Tool{
Description: "Unlink an agent skill from a project.", Name: "remove_project_skill",
}, toolSet.Skills.RemoveProjectSkill) Description: "Unlink an agent skill from a project. Pass project explicitly when your client does not preserve MCP sessions.",
}, toolSet.Skills.RemoveProjectSkill); err != nil {
addTool(server, logger, &mcp.Tool{ return err
Name: "list_project_skills", }
Description: "List all skills linked to a project. Call this at the start of a project session to load existing agent behaviour instructions before generating new ones.", if err := addTool(server, logger, &mcp.Tool{
}, toolSet.Skills.ListProjectSkills) Name: "list_project_skills",
Description: "List all skills linked to a project. Call this at the start of a project session to load existing agent behaviour instructions before generating new ones. Pass project explicitly when your client does not preserve MCP sessions.",
addTool(server, logger, &mcp.Tool{ }, toolSet.Skills.ListProjectSkills); err != nil {
Name: "add_project_guardrail", return err
Description: "Link an agent guardrail to a project.", }
}, toolSet.Skills.AddProjectGuardrail) if err := addTool(server, logger, &mcp.Tool{
Name: "add_project_guardrail",
addTool(server, logger, &mcp.Tool{ Description: "Link an agent guardrail to a project. Pass project explicitly when your client does not preserve MCP sessions.",
Name: "remove_project_guardrail", }, toolSet.Skills.AddProjectGuardrail); err != nil {
Description: "Unlink an agent guardrail from a project.", return err
}, toolSet.Skills.RemoveProjectGuardrail) }
if err := addTool(server, logger, &mcp.Tool{
addTool(server, logger, &mcp.Tool{ Name: "remove_project_guardrail",
Name: "list_project_guardrails", Description: "Unlink an agent guardrail from a project. Pass project explicitly when your client does not preserve MCP sessions.",
Description: "List all guardrails linked to a project. Call this at the start of a project session to load existing agent constraints before generating new ones.", }, toolSet.Skills.RemoveProjectGuardrail); err != nil {
}, toolSet.Skills.ListProjectGuardrails) return err
}
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { if err := addTool(server, logger, &mcp.Tool{
return server Name: "list_project_guardrails",
}, &mcp.StreamableHTTPOptions{ Description: "List all guardrails linked to a project. Call this at the start of a project session to load existing agent constraints before generating new ones. Pass project explicitly when your client does not preserve MCP sessions.",
JSONResponse: true, }, toolSet.Skills.ListProjectGuardrails); err != nil {
SessionTimeout: 10 * time.Minute, return err
}) }
return nil
} }

View File

@@ -0,0 +1,137 @@
package mcpserver
import (
"context"
"net/http/httptest"
"testing"
"time"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/buildinfo"
"git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
"git.warky.dev/wdevs/amcs/internal/tools"
)
func TestStreamableHTTPReturnsStructuredToolErrors(t *testing.T) {
handler, err := New(config.MCPConfig{
ServerName: "test",
Version: "0.0.1",
SessionTimeout: time.Minute,
}, nil, streamableTestToolSet(), nil)
if err != nil {
t.Fatalf("mcpserver.New() error = %v", err)
}
httpServer := httptest.NewServer(handler)
defer httpServer.Close()
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
cs, err := client.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: httpServer.URL}, nil)
if err != nil {
t.Fatalf("connect client: %v", err)
}
defer func() {
_ = cs.Close()
}()
t.Run("schema_validation", func(t *testing.T) {
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
Name: "create_project",
Arguments: map[string]any{},
})
if err == nil {
t.Fatal("CallTool(create_project) error = nil, want error")
}
rpcErr, data := requireWireError(t, err)
if rpcErr.Code != jsonrpc.CodeInvalidParams {
t.Fatalf("create_project code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
}
if data.Type != mcperrors.TypeInvalidArguments {
t.Fatalf("create_project data.type = %q, want %q", data.Type, mcperrors.TypeInvalidArguments)
}
if data.Field != "name" {
t.Fatalf("create_project data.field = %q, want %q", data.Field, "name")
}
})
t.Run("project_required", func(t *testing.T) {
_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
Name: "get_project_context",
Arguments: map[string]any{},
})
if err == nil {
t.Fatal("CallTool(get_project_context) error = nil, want error")
}
rpcErr, data := requireWireError(t, err)
if rpcErr.Code != mcperrors.CodeProjectRequired {
t.Fatalf("get_project_context code = %d, want %d", rpcErr.Code, mcperrors.CodeProjectRequired)
}
if data.Type != mcperrors.TypeProjectRequired {
t.Fatalf("get_project_context data.type = %q, want %q", data.Type, mcperrors.TypeProjectRequired)
}
if data.Field != "project" {
t.Fatalf("get_project_context data.field = %q, want %q", data.Field, "project")
}
})
t.Run("version_info", func(t *testing.T) {
result, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
Name: "get_version_info",
Arguments: map[string]any{},
})
if err != nil {
t.Fatalf("CallTool(get_version_info) error = %v", err)
}
got, ok := result.StructuredContent.(map[string]any)
if !ok {
t.Fatalf("structured content type = %T, want map[string]any", result.StructuredContent)
}
if got["server_name"] != "test" {
t.Fatalf("server_name = %#v, want %q", got["server_name"], "test")
}
if got["version"] != "0.0.1" {
t.Fatalf("version = %#v, want %q", got["version"], "0.0.1")
}
if got["tag_name"] != "v0.0.1" {
t.Fatalf("tag_name = %#v, want %q", got["tag_name"], "v0.0.1")
}
if got["build_date"] != "2026-03-31T00:00:00Z" {
t.Fatalf("build_date = %#v, want %q", got["build_date"], "2026-03-31T00:00:00Z")
}
})
}
func streamableTestToolSet() ToolSet {
return ToolSet{
Version: tools.NewVersionTool("test", buildinfo.Info{Version: "0.0.1", TagName: "v0.0.1", Commit: "test", BuildDate: "2026-03-31T00:00:00Z"}),
Capture: new(tools.CaptureTool),
Search: new(tools.SearchTool),
List: new(tools.ListTool),
Stats: new(tools.StatsTool),
Get: new(tools.GetTool),
Update: new(tools.UpdateTool),
Delete: new(tools.DeleteTool),
Archive: new(tools.ArchiveTool),
Projects: new(tools.ProjectsTool),
Context: new(tools.ContextTool),
Recall: new(tools.RecallTool),
Summarize: new(tools.SummarizeTool),
Links: new(tools.LinksTool),
Files: new(tools.FilesTool),
Backfill: new(tools.BackfillTool),
Reparse: new(tools.ReparseMetadataTool),
RetryMetadata: new(tools.RetryMetadataTool),
Household: new(tools.HouseholdTool),
Maintenance: new(tools.MaintenanceTool),
Calendar: new(tools.CalendarTool),
Meals: new(tools.MealsTool),
CRM: new(tools.CRMTool),
Skills: new(tools.SkillsTool),
}
}

View File

@@ -136,7 +136,9 @@ func (db *DB) AddThoughtAttachment(ctx context.Context, thoughtID uuid.UUID, att
if err != nil { if err != nil {
return fmt.Errorf("begin transaction: %w", err) return fmt.Errorf("begin transaction: %w", err)
} }
defer tx.Rollback(ctx) defer func() {
_ = tx.Rollback(ctx)
}()
var metadataBytes []byte var metadataBytes []byte
if err := tx.QueryRow(ctx, `select metadata from thoughts where guid = $1 for update`, thoughtID).Scan(&metadataBytes); err != nil { if err := tx.QueryRow(ctx, `select metadata from thoughts where guid = $1 for update`, thoughtID).Scan(&metadataBytes); err != nil {

View File

@@ -25,7 +25,9 @@ func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, e
if err != nil { if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err) return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
} }
defer tx.Rollback(ctx) defer func() {
_ = tx.Rollback(ctx)
}()
row := tx.QueryRow(ctx, ` row := tx.QueryRow(ctx, `
insert into thoughts (content, metadata, project_id) insert into thoughts (content, metadata, project_id)
@@ -240,7 +242,9 @@ func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, e
if err != nil { if err != nil {
return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err) return thoughttypes.Thought{}, fmt.Errorf("begin transaction: %w", err)
} }
defer tx.Rollback(ctx) defer func() {
_ = tx.Rollback(ctx)
}()
tag, err := tx.Exec(ctx, ` tag, err := tx.Exec(ctx, `
update thoughts update thoughts

View File

@@ -35,7 +35,7 @@ type AddFamilyMemberOutput struct {
func (t *CalendarTool) AddMember(ctx context.Context, _ *mcp.CallToolRequest, in AddFamilyMemberInput) (*mcp.CallToolResult, AddFamilyMemberOutput, error) { func (t *CalendarTool) AddMember(ctx context.Context, _ *mcp.CallToolRequest, in AddFamilyMemberInput) (*mcp.CallToolResult, AddFamilyMemberOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddFamilyMemberOutput{}, errInvalidInput("name is required") return nil, AddFamilyMemberOutput{}, errRequiredField("name")
} }
member, err := t.store.AddFamilyMember(ctx, ext.FamilyMember{ member, err := t.store.AddFamilyMember(ctx, ext.FamilyMember{
Name: strings.TrimSpace(in.Name), Name: strings.TrimSpace(in.Name),
@@ -89,7 +89,7 @@ type AddActivityOutput struct {
func (t *CalendarTool) AddActivity(ctx context.Context, _ *mcp.CallToolRequest, in AddActivityInput) (*mcp.CallToolResult, AddActivityOutput, error) { func (t *CalendarTool) AddActivity(ctx context.Context, _ *mcp.CallToolRequest, in AddActivityInput) (*mcp.CallToolResult, AddActivityOutput, error) {
if strings.TrimSpace(in.Title) == "" { if strings.TrimSpace(in.Title) == "" {
return nil, AddActivityOutput{}, errInvalidInput("title is required") return nil, AddActivityOutput{}, errRequiredField("title")
} }
activity, err := t.store.AddActivity(ctx, ext.Activity{ activity, err := t.store.AddActivity(ctx, ext.Activity{
FamilyMemberID: in.FamilyMemberID, FamilyMemberID: in.FamilyMemberID,
@@ -170,7 +170,7 @@ type AddImportantDateOutput struct {
func (t *CalendarTool) AddImportantDate(ctx context.Context, _ *mcp.CallToolRequest, in AddImportantDateInput) (*mcp.CallToolResult, AddImportantDateOutput, error) { func (t *CalendarTool) AddImportantDate(ctx context.Context, _ *mcp.CallToolRequest, in AddImportantDateInput) (*mcp.CallToolResult, AddImportantDateOutput, error) {
if strings.TrimSpace(in.Title) == "" { if strings.TrimSpace(in.Title) == "" {
return nil, AddImportantDateOutput{}, errInvalidInput("title is required") return nil, AddImportantDateOutput{}, errRequiredField("title")
} }
reminder := in.ReminderDaysBefore reminder := in.ReminderDaysBefore
if reminder <= 0 { if reminder <= 0 {

View File

@@ -43,7 +43,7 @@ func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureCo
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) { func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
content := strings.TrimSpace(in.Content) content := strings.TrimSpace(in.Content)
if content == "" { if content == "" {
return nil, CaptureOutput{}, errInvalidInput("content is required") return nil, CaptureOutput{}, errRequiredField("content")
} }
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)

View File

@@ -1,9 +1,13 @@
package tools package tools
import ( import (
"encoding/json"
"fmt" "fmt"
"strings"
"git.warky.dev/wdevs/amcs/internal/config" "git.warky.dev/wdevs/amcs/internal/config"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
) )
func normalizeLimit(limit int, cfg config.SearchConfig) int { func normalizeLimit(limit int, cfg config.SearchConfig) int {
@@ -26,6 +30,116 @@ func normalizeThreshold(value float64, fallback float64) float64 {
return value return value
} }
func errInvalidInput(message string) error { const (
return fmt.Errorf("invalid input: %s", message) codeSessionRequired = mcperrors.CodeSessionRequired
codeProjectRequired = mcperrors.CodeProjectRequired
codeProjectNotFound = mcperrors.CodeProjectNotFound
codeInvalidID = mcperrors.CodeInvalidID
codeEntityNotFound = mcperrors.CodeEntityNotFound
)
type mcpErrorData = mcperrors.Data
func newMCPError(code int64, message string, data mcpErrorData) error {
rpcErr := &jsonrpc.Error{
Code: code,
Message: message,
}
payload, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal mcp error data: %w", err)
}
rpcErr.Data = payload
return rpcErr
}
func errInvalidInput(message string) error {
return newMCPError(
jsonrpc.CodeInvalidParams,
"invalid input: "+message,
mcpErrorData{
Type: mcperrors.TypeInvalidInput,
},
)
}
func errRequiredField(field string) error {
return newMCPError(
jsonrpc.CodeInvalidParams,
field+" is required",
mcpErrorData{
Type: mcperrors.TypeInvalidInput,
Field: field,
Detail: "required",
Hint: "provide " + field,
},
)
}
func errInvalidField(field string, message string, hint string) error {
return newMCPError(
jsonrpc.CodeInvalidParams,
message,
mcpErrorData{
Type: mcperrors.TypeInvalidInput,
Field: field,
Detail: "invalid",
Hint: hint,
},
)
}
func errOneOfRequired(fields ...string) error {
return newMCPError(
jsonrpc.CodeInvalidParams,
joinFields(fields)+" is required",
mcpErrorData{
Type: mcperrors.TypeInvalidInput,
Fields: fields,
Detail: "one_of_required",
Hint: "provide one of: " + strings.Join(fields, ", "),
},
)
}
func errMutuallyExclusiveFields(fields ...string) error {
return newMCPError(
jsonrpc.CodeInvalidParams,
"provide "+joinFields(fields)+", not both",
mcpErrorData{
Type: mcperrors.TypeInvalidInput,
Fields: fields,
Detail: "mutually_exclusive",
Hint: "provide only one of: " + strings.Join(fields, ", "),
},
)
}
func errEntityNotFound(entity string, field string, value string) error {
return newMCPError(
codeEntityNotFound,
entity+" not found",
mcpErrorData{
Type: mcperrors.TypeEntityNotFound,
Entity: entity,
Field: field,
Value: value,
Detail: "not_found",
},
)
}
func joinFields(fields []string) string {
switch len(fields) {
case 0:
return "field"
case 1:
return fields[0]
case 2:
return fields[0] + " or " + fields[1]
default:
return strings.Join(fields[:len(fields)-1], ", ") + ", or " + fields[len(fields)-1]
}
} }

View File

@@ -0,0 +1,84 @@
package tools
import (
"testing"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
)
func TestErrRequiredFieldReturnsFieldMetadata(t *testing.T) {
rpcErr, data := requireRPCError(t, errRequiredField("name"))
if data.Type != mcperrors.TypeInvalidInput {
t.Fatalf("errRequiredField() type = %q, want %q", data.Type, mcperrors.TypeInvalidInput)
}
if data.Field != "name" {
t.Fatalf("errRequiredField() field = %q, want %q", data.Field, "name")
}
if data.Detail != "required" {
t.Fatalf("errRequiredField() detail = %q, want %q", data.Detail, "required")
}
if rpcErr.Message != "name is required" {
t.Fatalf("errRequiredField() message = %q, want %q", rpcErr.Message, "name is required")
}
}
func TestErrInvalidFieldReturnsFieldMetadata(t *testing.T) {
rpcErr, data := requireRPCError(t, errInvalidField("severity", "severity must be one of: low, medium, high, critical", "pass one of: low, medium, high, critical"))
if data.Field != "severity" {
t.Fatalf("errInvalidField() field = %q, want %q", data.Field, "severity")
}
if data.Detail != "invalid" {
t.Fatalf("errInvalidField() detail = %q, want %q", data.Detail, "invalid")
}
if data.Hint == "" {
t.Fatal("errInvalidField() hint = empty, want guidance")
}
if rpcErr.Message == "" {
t.Fatal("errInvalidField() message = empty, want non-empty")
}
}
func TestErrOneOfRequiredReturnsFieldsMetadata(t *testing.T) {
rpcErr, data := requireRPCError(t, errOneOfRequired("content_base64", "content_uri"))
if data.Detail != "one_of_required" {
t.Fatalf("errOneOfRequired() detail = %q, want %q", data.Detail, "one_of_required")
}
if len(data.Fields) != 2 || data.Fields[0] != "content_base64" || data.Fields[1] != "content_uri" {
t.Fatalf("errOneOfRequired() fields = %#v, want [content_base64 content_uri]", data.Fields)
}
if rpcErr.Message != "content_base64 or content_uri is required" {
t.Fatalf("errOneOfRequired() message = %q, want %q", rpcErr.Message, "content_base64 or content_uri is required")
}
}
func TestErrMutuallyExclusiveFieldsReturnsFieldsMetadata(t *testing.T) {
rpcErr, data := requireRPCError(t, errMutuallyExclusiveFields("content_uri", "content_base64"))
if data.Detail != "mutually_exclusive" {
t.Fatalf("errMutuallyExclusiveFields() detail = %q, want %q", data.Detail, "mutually_exclusive")
}
if len(data.Fields) != 2 || data.Fields[0] != "content_uri" || data.Fields[1] != "content_base64" {
t.Fatalf("errMutuallyExclusiveFields() fields = %#v, want [content_uri content_base64]", data.Fields)
}
if rpcErr.Message != "provide content_uri or content_base64, not both" {
t.Fatalf("errMutuallyExclusiveFields() message = %q, want %q", rpcErr.Message, "provide content_uri or content_base64, not both")
}
}
func TestErrEntityNotFoundReturnsEntityMetadata(t *testing.T) {
rpcErr, data := requireRPCError(t, errEntityNotFound("thought", "thought_id", "123"))
if rpcErr.Code != codeEntityNotFound {
t.Fatalf("errEntityNotFound() code = %d, want %d", rpcErr.Code, codeEntityNotFound)
}
if data.Type != mcperrors.TypeEntityNotFound {
t.Fatalf("errEntityNotFound() type = %q, want %q", data.Type, mcperrors.TypeEntityNotFound)
}
if data.Entity != "thought" {
t.Fatalf("errEntityNotFound() entity = %q, want %q", data.Entity, "thought")
}
if data.Field != "thought_id" {
t.Fatalf("errEntityNotFound() field = %q, want %q", data.Field, "thought_id")
}
if data.Value != "123" {
t.Fatalf("errEntityNotFound() value = %q, want %q", data.Value, "123")
}
}

View File

@@ -2,11 +2,13 @@ package tools
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/modelcontextprotocol/go-sdk/mcp" "github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/store" "git.warky.dev/wdevs/amcs/internal/store"
@@ -42,7 +44,7 @@ type AddContactOutput struct {
func (t *CRMTool) AddContact(ctx context.Context, _ *mcp.CallToolRequest, in AddContactInput) (*mcp.CallToolResult, AddContactOutput, error) { func (t *CRMTool) AddContact(ctx context.Context, _ *mcp.CallToolRequest, in AddContactInput) (*mcp.CallToolResult, AddContactOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddContactOutput{}, errInvalidInput("name is required") return nil, AddContactOutput{}, errRequiredField("name")
} }
if in.Tags == nil { if in.Tags == nil {
in.Tags = []string{} in.Tags = []string{}
@@ -104,7 +106,7 @@ type LogInteractionOutput struct {
func (t *CRMTool) LogInteraction(ctx context.Context, _ *mcp.CallToolRequest, in LogInteractionInput) (*mcp.CallToolResult, LogInteractionOutput, error) { func (t *CRMTool) LogInteraction(ctx context.Context, _ *mcp.CallToolRequest, in LogInteractionInput) (*mcp.CallToolResult, LogInteractionOutput, error) {
if strings.TrimSpace(in.Summary) == "" { if strings.TrimSpace(in.Summary) == "" {
return nil, LogInteractionOutput{}, errInvalidInput("summary is required") return nil, LogInteractionOutput{}, errRequiredField("summary")
} }
occurredAt := time.Now() occurredAt := time.Now()
if in.OccurredAt != nil { if in.OccurredAt != nil {
@@ -160,7 +162,7 @@ type CreateOpportunityOutput struct {
func (t *CRMTool) CreateOpportunity(ctx context.Context, _ *mcp.CallToolRequest, in CreateOpportunityInput) (*mcp.CallToolResult, CreateOpportunityOutput, error) { func (t *CRMTool) CreateOpportunity(ctx context.Context, _ *mcp.CallToolRequest, in CreateOpportunityInput) (*mcp.CallToolResult, CreateOpportunityOutput, error) {
if strings.TrimSpace(in.Title) == "" { if strings.TrimSpace(in.Title) == "" {
return nil, CreateOpportunityOutput{}, errInvalidInput("title is required") return nil, CreateOpportunityOutput{}, errRequiredField("title")
} }
stage := strings.TrimSpace(in.Stage) stage := strings.TrimSpace(in.Stage)
if stage == "" { if stage == "" {
@@ -216,7 +218,10 @@ type LinkThoughtToContactOutput struct {
func (t *CRMTool) LinkThought(ctx context.Context, _ *mcp.CallToolRequest, in LinkThoughtToContactInput) (*mcp.CallToolResult, LinkThoughtToContactOutput, error) { func (t *CRMTool) LinkThought(ctx context.Context, _ *mcp.CallToolRequest, in LinkThoughtToContactInput) (*mcp.CallToolResult, LinkThoughtToContactOutput, error) {
thought, err := t.store.GetThought(ctx, in.ThoughtID) thought, err := t.store.GetThought(ctx, in.ThoughtID)
if err != nil { if err != nil {
return nil, LinkThoughtToContactOutput{}, fmt.Errorf("thought not found: %w", err) if errors.Is(err, pgx.ErrNoRows) {
return nil, LinkThoughtToContactOutput{}, errEntityNotFound("thought", "thought_id", in.ThoughtID.String())
}
return nil, LinkThoughtToContactOutput{}, err
} }
appendText := fmt.Sprintf("\n\n[Linked thought %s]: %s", thought.ID, thought.Content) appendText := fmt.Sprintf("\n\n[Linked thought %s]: %s", thought.ID, thought.Content)
@@ -226,6 +231,9 @@ func (t *CRMTool) LinkThought(ctx context.Context, _ *mcp.CallToolRequest, in Li
contact, err := t.store.GetContact(ctx, in.ContactID) contact, err := t.store.GetContact(ctx, in.ContactID)
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, LinkThoughtToContactOutput{}, errEntityNotFound("contact", "contact_id", in.ContactID.String())
}
return nil, LinkThoughtToContactOutput{}, err return nil, LinkThoughtToContactOutput{}, err
} }
return nil, LinkThoughtToContactOutput{Contact: contact}, nil return nil, LinkThoughtToContactOutput{Contact: contact}, nil

View File

@@ -52,7 +52,7 @@ type SaveFileOutput struct {
} }
type LoadFileInput struct { type LoadFileInput struct {
ID string `json:"id" jsonschema:"the stored file id"` ID string `json:"id" jsonschema:"the stored file id or amcs://files/{id} URI"`
} }
type LoadFileOutput struct { type LoadFileOutput struct {
@@ -95,7 +95,7 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
b64 := strings.TrimSpace(in.ContentBase64) b64 := strings.TrimSpace(in.ContentBase64)
if path != "" && b64 != "" { if path != "" && b64 != "" {
return nil, UploadFileOutput{}, errInvalidInput("provide content_path or content_base64, not both") return nil, UploadFileOutput{}, errMutuallyExclusiveFields("content_path", "content_base64")
} }
var content []byte var content []byte
@@ -103,7 +103,11 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
if path != "" { if path != "" {
if !filepath.IsAbs(path) { if !filepath.IsAbs(path) {
return nil, UploadFileOutput{}, errInvalidInput("content_path must be an absolute path") return nil, UploadFileOutput{}, errInvalidField(
"content_path",
"content_path must be an absolute path",
"pass an absolute path on the server filesystem",
)
} }
var err error var err error
content, err = os.ReadFile(path) content, err = os.ReadFile(path)
@@ -112,7 +116,7 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
} }
} else { } else {
if b64 == "" { if b64 == "" {
return nil, UploadFileOutput{}, errInvalidInput("content_path or content_base64 is required") return nil, UploadFileOutput{}, errOneOfRequired("content_path", "content_base64")
} }
if len(b64) > maxBase64ToolBytes { if len(b64) > maxBase64ToolBytes {
return nil, UploadFileOutput{}, errInvalidInput( return nil, UploadFileOutput{}, errInvalidInput(
@@ -123,7 +127,11 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
var err error var err error
content, err = decodeBase64(raw) content, err = decodeBase64(raw)
if err != nil { if err != nil {
return nil, UploadFileOutput{}, errInvalidInput("content_base64 must be valid base64") return nil, UploadFileOutput{}, errInvalidField(
"content_base64",
"content_base64 must be valid base64",
"pass valid base64 data or a data URL",
)
} }
mediaTypeFromSource = dataURLMediaType mediaTypeFromSource = dataURLMediaType
} }
@@ -149,7 +157,7 @@ func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveF
b64 := strings.TrimSpace(in.ContentBase64) b64 := strings.TrimSpace(in.ContentBase64)
if uri != "" && b64 != "" { if uri != "" && b64 != "" {
return nil, SaveFileOutput{}, errInvalidInput("provide content_uri or content_base64, not both") return nil, SaveFileOutput{}, errMutuallyExclusiveFields("content_uri", "content_base64")
} }
if len(b64) > maxBase64ToolBytes { if len(b64) > maxBase64ToolBytes {
return nil, SaveFileOutput{}, errInvalidInput( return nil, SaveFileOutput{}, errInvalidInput(
@@ -162,28 +170,44 @@ func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveF
if uri != "" { if uri != "" {
if !strings.HasPrefix(uri, fileURIPrefix) { if !strings.HasPrefix(uri, fileURIPrefix) {
return nil, SaveFileOutput{}, errInvalidInput("content_uri must be an amcs://files/{id} URI") return nil, SaveFileOutput{}, errInvalidField(
"content_uri",
"content_uri must be an amcs://files/{id} URI",
"pass an amcs://files/{id} URI returned by upload_file or POST /files",
)
} }
rawID := strings.TrimPrefix(uri, fileURIPrefix) rawID := strings.TrimPrefix(uri, fileURIPrefix)
id, err := parseUUID(rawID) id, err := parseUUID(rawID)
if err != nil { if err != nil {
return nil, SaveFileOutput{}, errInvalidInput("content_uri contains an invalid file id") return nil, SaveFileOutput{}, errInvalidField(
"content_uri",
"content_uri contains an invalid file id",
"pass a valid amcs://files/{id} URI",
)
} }
file, err := t.store.GetStoredFile(ctx, id) file, err := t.store.GetStoredFile(ctx, id)
if err != nil { if err != nil {
return nil, SaveFileOutput{}, errInvalidInput("content_uri references a file that does not exist") return nil, SaveFileOutput{}, errInvalidField(
"content_uri",
"content_uri references a file that does not exist",
"upload the file first or pass an existing amcs://files/{id} URI",
)
} }
content = file.Content content = file.Content
mediaTypeFromSource = file.MediaType mediaTypeFromSource = file.MediaType
} else { } else {
contentBase64, mediaTypeFromDataURL := splitDataURL(b64) contentBase64, mediaTypeFromDataURL := splitDataURL(b64)
if contentBase64 == "" { if contentBase64 == "" {
return nil, SaveFileOutput{}, errInvalidInput("content_base64 or content_uri is required") return nil, SaveFileOutput{}, errOneOfRequired("content_base64", "content_uri")
} }
var err error var err error
content, err = decodeBase64(contentBase64) content, err = decodeBase64(contentBase64)
if err != nil { if err != nil {
return nil, SaveFileOutput{}, errInvalidInput("content_base64 must be valid base64") return nil, SaveFileOutput{}, errInvalidField(
"content_base64",
"content_base64 must be valid base64",
"pass valid base64 data or a data URL",
)
} }
mediaTypeFromSource = mediaTypeFromDataURL mediaTypeFromSource = mediaTypeFromDataURL
} }
@@ -205,7 +229,7 @@ func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveF
const fileURIPrefix = "amcs://files/" const fileURIPrefix = "amcs://files/"
func (t *FilesTool) GetRaw(ctx context.Context, rawID string) (thoughttypes.StoredFile, error) { func (t *FilesTool) GetRaw(ctx context.Context, rawID string) (thoughttypes.StoredFile, error) {
id, err := parseUUID(strings.TrimSpace(rawID)) id, err := parseStoredFileID(rawID)
if err != nil { if err != nil {
return thoughttypes.StoredFile{}, err return thoughttypes.StoredFile{}, err
} }
@@ -213,7 +237,7 @@ func (t *FilesTool) GetRaw(ctx context.Context, rawID string) (thoughttypes.Stor
} }
func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFileInput) (*mcp.CallToolResult, LoadFileOutput, error) { func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFileInput) (*mcp.CallToolResult, LoadFileOutput, error) {
id, err := parseUUID(in.ID) id, err := parseStoredFileID(in.ID)
if err != nil { if err != nil {
return nil, LoadFileOutput{}, err return nil, LoadFileOutput{}, err
} }
@@ -243,8 +267,7 @@ func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFil
} }
func (t *FilesTool) ReadResource(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { func (t *FilesTool) ReadResource(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
rawID := strings.TrimPrefix(req.Params.URI, fileURIPrefix) id, err := parseStoredFileID(req.Params.URI)
id, err := parseUUID(strings.TrimSpace(rawID))
if err != nil { if err != nil {
return nil, mcp.ResourceNotFoundError(req.Params.URI) return nil, mcp.ResourceNotFoundError(req.Params.URI)
} }
@@ -309,7 +332,7 @@ func (t *FilesTool) List(ctx context.Context, req *mcp.CallToolRequest, in ListF
func (t *FilesTool) SaveDecoded(ctx context.Context, req *mcp.CallToolRequest, in SaveFileDecodedInput) (SaveFileOutput, error) { func (t *FilesTool) SaveDecoded(ctx context.Context, req *mcp.CallToolRequest, in SaveFileDecodedInput) (SaveFileOutput, error) {
name := strings.TrimSpace(in.Name) name := strings.TrimSpace(in.Name)
if name == "" { if name == "" {
return SaveFileOutput{}, errInvalidInput("name is required") return SaveFileOutput{}, errRequiredField("name")
} }
if len(in.Content) == 0 { if len(in.Content) == 0 {
return SaveFileOutput{}, errInvalidInput("decoded file content must not be empty") return SaveFileOutput{}, errInvalidInput("decoded file content must not be empty")
@@ -492,3 +515,9 @@ func normalizeFileLimit(limit int) int {
return limit return limit
} }
} }
func parseStoredFileID(raw string) (uuid.UUID, error) {
value := strings.TrimSpace(raw)
value = strings.TrimPrefix(value, fileURIPrefix)
return parseUUID(value)
}

View File

@@ -1,6 +1,10 @@
package tools package tools
import "testing" import (
"testing"
"github.com/google/uuid"
)
func TestDecodeBase64AcceptsWhitespaceAndMultipleVariants(t *testing.T) { func TestDecodeBase64AcceptsWhitespaceAndMultipleVariants(t *testing.T) {
tests := []struct { tests := []struct {
@@ -27,3 +31,45 @@ func TestDecodeBase64AcceptsWhitespaceAndMultipleVariants(t *testing.T) {
}) })
} }
} }
func TestParseStoredFileIDAcceptsUUIDAndURI(t *testing.T) {
id := uuid.New()
tests := []struct {
name string
input string
want uuid.UUID
}{
{name: "bare uuid", input: id.String(), want: id},
{name: "resource uri", input: fileURIPrefix + id.String(), want: id},
{name: "resource uri with surrounding whitespace", input: " " + fileURIPrefix + id.String() + " ", want: id},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := parseStoredFileID(tc.input)
if err != nil {
t.Fatalf("parseStoredFileID(%q) error = %v", tc.input, err)
}
if got != tc.want {
t.Fatalf("parseStoredFileID(%q) = %v, want %v", tc.input, got, tc.want)
}
})
}
}
func TestParseStoredFileIDRejectsInvalidValues(t *testing.T) {
tests := []string{
"",
"not-a-uuid",
fileURIPrefix + "not-a-uuid",
}
for _, input := range tests {
t.Run(input, func(t *testing.T) {
if _, err := parseStoredFileID(input); err == nil {
t.Fatalf("parseStoredFileID(%q) = nil error, want error", input)
}
})
}
}

View File

@@ -9,22 +9,41 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/modelcontextprotocol/go-sdk/mcp" "github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
"git.warky.dev/wdevs/amcs/internal/session" "git.warky.dev/wdevs/amcs/internal/session"
"git.warky.dev/wdevs/amcs/internal/store" "git.warky.dev/wdevs/amcs/internal/store"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types" thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
) )
func parseUUID(id string) (uuid.UUID, error) { func parseUUID(id string) (uuid.UUID, error) {
parsed, err := uuid.Parse(strings.TrimSpace(id)) trimmed := strings.TrimSpace(id)
parsed, err := uuid.Parse(trimmed)
if err != nil { if err != nil {
return uuid.Nil, fmt.Errorf("invalid id %q: %w", id, err) return uuid.Nil, newMCPError(
codeInvalidID,
fmt.Sprintf("invalid id %q", id),
mcpErrorData{
Type: mcperrors.TypeInvalidID,
Field: "id",
Value: trimmed,
Detail: err.Error(),
Hint: "pass a valid UUID",
},
)
} }
return parsed, nil return parsed, nil
} }
func sessionID(req *mcp.CallToolRequest) (string, error) { func sessionID(req *mcp.CallToolRequest) (string, error) {
if req == nil || req.Session == nil || req.Session.ID() == "" { if req == nil || req.Session == nil || req.Session.ID() == "" {
return "", fmt.Errorf("tool requires an MCP session") return "", newMCPError(
codeSessionRequired,
"tool requires an MCP session; use a stateful MCP client for session-scoped operations",
mcpErrorData{
Type: mcperrors.TypeSessionRequired,
Hint: "use a stateful MCP client for session-scoped operations",
},
)
} }
return req.Session.ID(), nil return req.Session.ID(), nil
} }
@@ -45,7 +64,15 @@ func resolveProject(ctx context.Context, db *store.DB, sessions *session.ActiveP
if projectRef == "" { if projectRef == "" {
if required { if required {
return nil, fmt.Errorf("project is required") return nil, newMCPError(
codeProjectRequired,
"project is required; pass project explicitly or call set_active_project in this MCP session first",
mcpErrorData{
Type: mcperrors.TypeProjectRequired,
Field: "project",
Hint: "pass project explicitly or call set_active_project in this MCP session first",
},
)
} }
return nil, nil return nil, nil
} }
@@ -53,7 +80,15 @@ func resolveProject(ctx context.Context, db *store.DB, sessions *session.ActiveP
project, err := db.GetProject(ctx, projectRef) project, err := db.GetProject(ctx, projectRef)
if err != nil { if err != nil {
if err == pgx.ErrNoRows { if err == pgx.ErrNoRows {
return nil, fmt.Errorf("project %q not found", projectRef) return nil, newMCPError(
codeProjectNotFound,
fmt.Sprintf("project %q not found", projectRef),
mcpErrorData{
Type: mcperrors.TypeProjectNotFound,
Field: "project",
Project: projectRef,
},
)
} }
return nil, err return nil, err
} }

View File

@@ -0,0 +1,107 @@
package tools
import (
"context"
"encoding/json"
"errors"
"testing"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)
func TestResolveProjectRequiredErrorGuidesCaller(t *testing.T) {
_, err := resolveProject(context.Background(), nil, nil, nil, "", true)
if err == nil {
t.Fatal("resolveProject() error = nil, want error")
}
rpcErr, data := requireRPCError(t, err)
if rpcErr.Code != codeProjectRequired {
t.Fatalf("resolveProject() code = %d, want %d", rpcErr.Code, codeProjectRequired)
}
if data.Type != mcperrors.TypeProjectRequired {
t.Fatalf("resolveProject() type = %q, want %q", data.Type, mcperrors.TypeProjectRequired)
}
if data.Field != "project" {
t.Fatalf("resolveProject() field = %q, want %q", data.Field, "project")
}
if data.Hint == "" {
t.Fatal("resolveProject() hint = empty, want guidance")
}
}
func TestSessionIDErrorGuidesCaller(t *testing.T) {
_, err := sessionID(nil)
if err == nil {
t.Fatal("sessionID() error = nil, want error")
}
rpcErr, data := requireRPCError(t, err)
if rpcErr.Code != codeSessionRequired {
t.Fatalf("sessionID() code = %d, want %d", rpcErr.Code, codeSessionRequired)
}
if data.Type != mcperrors.TypeSessionRequired {
t.Fatalf("sessionID() type = %q, want %q", data.Type, mcperrors.TypeSessionRequired)
}
if data.Hint == "" {
t.Fatal("sessionID() hint = empty, want guidance")
}
}
func TestParseUUIDReturnsTypedError(t *testing.T) {
_, err := parseUUID("not-a-uuid")
if err == nil {
t.Fatal("parseUUID() error = nil, want error")
}
rpcErr, data := requireRPCError(t, err)
if rpcErr.Code != codeInvalidID {
t.Fatalf("parseUUID() code = %d, want %d", rpcErr.Code, codeInvalidID)
}
if data.Type != mcperrors.TypeInvalidID {
t.Fatalf("parseUUID() type = %q, want %q", data.Type, mcperrors.TypeInvalidID)
}
if data.Field != "id" {
t.Fatalf("parseUUID() field = %q, want %q", data.Field, "id")
}
if data.Value != "not-a-uuid" {
t.Fatalf("parseUUID() value = %q, want %q", data.Value, "not-a-uuid")
}
if data.Detail == "" {
t.Fatal("parseUUID() detail = empty, want parse failure detail")
}
}
func TestErrInvalidInputReturnsTypedError(t *testing.T) {
err := errInvalidInput("name is required")
if err == nil {
t.Fatal("errInvalidInput() error = nil, want error")
}
rpcErr, data := requireRPCError(t, err)
if rpcErr.Code != jsonrpc.CodeInvalidParams {
t.Fatalf("errInvalidInput() code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
}
if data.Type != mcperrors.TypeInvalidInput {
t.Fatalf("errInvalidInput() type = %q, want %q", data.Type, mcperrors.TypeInvalidInput)
}
}
func requireRPCError(t *testing.T, err error) (*jsonrpc.Error, mcpErrorData) {
t.Helper()
var rpcErr *jsonrpc.Error
if !errors.As(err, &rpcErr) {
t.Fatalf("error type = %T, want *jsonrpc.Error", err)
}
var data mcpErrorData
if len(rpcErr.Data) > 0 {
if unmarshalErr := json.Unmarshal(rpcErr.Data, &data); unmarshalErr != nil {
t.Fatalf("unmarshal error data: %v", unmarshalErr)
}
}
return rpcErr, data
}

View File

@@ -35,7 +35,7 @@ type AddHouseholdItemOutput struct {
func (t *HouseholdTool) AddItem(ctx context.Context, _ *mcp.CallToolRequest, in AddHouseholdItemInput) (*mcp.CallToolResult, AddHouseholdItemOutput, error) { func (t *HouseholdTool) AddItem(ctx context.Context, _ *mcp.CallToolRequest, in AddHouseholdItemInput) (*mcp.CallToolResult, AddHouseholdItemOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddHouseholdItemOutput{}, errInvalidInput("name is required") return nil, AddHouseholdItemOutput{}, errRequiredField("name")
} }
if in.Details == nil { if in.Details == nil {
in.Details = map[string]any{} in.Details = map[string]any{}
@@ -112,7 +112,7 @@ type AddVendorOutput struct {
func (t *HouseholdTool) AddVendor(ctx context.Context, _ *mcp.CallToolRequest, in AddVendorInput) (*mcp.CallToolResult, AddVendorOutput, error) { func (t *HouseholdTool) AddVendor(ctx context.Context, _ *mcp.CallToolRequest, in AddVendorInput) (*mcp.CallToolResult, AddVendorOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddVendorOutput{}, errInvalidInput("name is required") return nil, AddVendorOutput{}, errRequiredField("name")
} }
vendor, err := t.store.AddVendor(ctx, ext.HouseholdVendor{ vendor, err := t.store.AddVendor(ctx, ext.HouseholdVendor{
Name: strings.TrimSpace(in.Name), Name: strings.TrimSpace(in.Name),

View File

@@ -62,7 +62,7 @@ func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInp
} }
relation := strings.TrimSpace(in.Relation) relation := strings.TrimSpace(in.Relation)
if relation == "" { if relation == "" {
return nil, LinkOutput{}, errInvalidInput("relation is required") return nil, LinkOutput{}, errRequiredField("relation")
} }
if _, err := t.store.GetThought(ctx, fromID); err != nil { if _, err := t.store.GetThought(ctx, fromID); err != nil {
return nil, LinkOutput{}, err return nil, LinkOutput{}, err

View File

@@ -37,7 +37,7 @@ type AddMaintenanceTaskOutput struct {
func (t *MaintenanceTool) AddTask(ctx context.Context, _ *mcp.CallToolRequest, in AddMaintenanceTaskInput) (*mcp.CallToolResult, AddMaintenanceTaskOutput, error) { func (t *MaintenanceTool) AddTask(ctx context.Context, _ *mcp.CallToolRequest, in AddMaintenanceTaskInput) (*mcp.CallToolResult, AddMaintenanceTaskOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddMaintenanceTaskOutput{}, errInvalidInput("name is required") return nil, AddMaintenanceTaskOutput{}, errRequiredField("name")
} }
priority := strings.TrimSpace(in.Priority) priority := strings.TrimSpace(in.Priority)
if priority == "" { if priority == "" {

View File

@@ -41,7 +41,7 @@ type AddRecipeOutput struct {
func (t *MealsTool) AddRecipe(ctx context.Context, _ *mcp.CallToolRequest, in AddRecipeInput) (*mcp.CallToolResult, AddRecipeOutput, error) { func (t *MealsTool) AddRecipe(ctx context.Context, _ *mcp.CallToolRequest, in AddRecipeInput) (*mcp.CallToolResult, AddRecipeOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddRecipeOutput{}, errInvalidInput("name is required") return nil, AddRecipeOutput{}, errRequiredField("name")
} }
if in.Ingredients == nil { if in.Ingredients == nil {
in.Ingredients = []ext.Ingredient{} in.Ingredients = []ext.Ingredient{}
@@ -116,7 +116,7 @@ type UpdateRecipeOutput struct {
func (t *MealsTool) UpdateRecipe(ctx context.Context, _ *mcp.CallToolRequest, in UpdateRecipeInput) (*mcp.CallToolResult, UpdateRecipeOutput, error) { func (t *MealsTool) UpdateRecipe(ctx context.Context, _ *mcp.CallToolRequest, in UpdateRecipeInput) (*mcp.CallToolResult, UpdateRecipeOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, UpdateRecipeOutput{}, errInvalidInput("name is required") return nil, UpdateRecipeOutput{}, errRequiredField("name")
} }
if in.Ingredients == nil { if in.Ingredients == nil {
in.Ingredients = []ext.Ingredient{} in.Ingredients = []ext.Ingredient{}

View File

@@ -52,7 +52,7 @@ func NewProjectsTool(db *store.DB, sessions *session.ActiveProjects) *ProjectsTo
func (t *ProjectsTool) Create(ctx context.Context, _ *mcp.CallToolRequest, in CreateProjectInput) (*mcp.CallToolResult, CreateProjectOutput, error) { func (t *ProjectsTool) Create(ctx context.Context, _ *mcp.CallToolRequest, in CreateProjectInput) (*mcp.CallToolResult, CreateProjectOutput, error) {
name := strings.TrimSpace(in.Name) name := strings.TrimSpace(in.Name)
if name == "" { if name == "" {
return nil, CreateProjectOutput{}, errInvalidInput("name is required") return nil, CreateProjectOutput{}, errRequiredField("name")
} }
project, err := t.store.CreateProject(ctx, name, strings.TrimSpace(in.Description)) project, err := t.store.CreateProject(ctx, name, strings.TrimSpace(in.Description))
if err != nil { if err != nil {

View File

@@ -39,7 +39,7 @@ func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfi
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) { func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
query := strings.TrimSpace(in.Query) query := strings.TrimSpace(in.Query)
if query == "" { if query == "" {
return nil, RecallOutput{}, errInvalidInput("query is required") return nil, RecallOutput{}, errRequiredField("query")
} }
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)

View File

@@ -39,7 +39,7 @@ func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfi
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) { func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
query := strings.TrimSpace(in.Query) query := strings.TrimSpace(in.Query)
if query == "" { if query == "" {
return nil, SearchOutput{}, errInvalidInput("query is required") return nil, SearchOutput{}, errRequiredField("query")
} }
limit := normalizeLimit(in.Limit, t.search) limit := normalizeLimit(in.Limit, t.search)

View File

@@ -36,10 +36,10 @@ type AddSkillOutput struct {
func (t *SkillsTool) AddSkill(ctx context.Context, _ *mcp.CallToolRequest, in AddSkillInput) (*mcp.CallToolResult, AddSkillOutput, error) { func (t *SkillsTool) AddSkill(ctx context.Context, _ *mcp.CallToolRequest, in AddSkillInput) (*mcp.CallToolResult, AddSkillOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddSkillOutput{}, errInvalidInput("name is required") return nil, AddSkillOutput{}, errRequiredField("name")
} }
if strings.TrimSpace(in.Content) == "" { if strings.TrimSpace(in.Content) == "" {
return nil, AddSkillOutput{}, errInvalidInput("content is required") return nil, AddSkillOutput{}, errRequiredField("content")
} }
if in.Tags == nil { if in.Tags == nil {
in.Tags = []string{} in.Tags = []string{}
@@ -110,10 +110,10 @@ type AddGuardrailOutput struct {
func (t *SkillsTool) AddGuardrail(ctx context.Context, _ *mcp.CallToolRequest, in AddGuardrailInput) (*mcp.CallToolResult, AddGuardrailOutput, error) { func (t *SkillsTool) AddGuardrail(ctx context.Context, _ *mcp.CallToolRequest, in AddGuardrailInput) (*mcp.CallToolResult, AddGuardrailOutput, error) {
if strings.TrimSpace(in.Name) == "" { if strings.TrimSpace(in.Name) == "" {
return nil, AddGuardrailOutput{}, errInvalidInput("name is required") return nil, AddGuardrailOutput{}, errRequiredField("name")
} }
if strings.TrimSpace(in.Content) == "" { if strings.TrimSpace(in.Content) == "" {
return nil, AddGuardrailOutput{}, errInvalidInput("content is required") return nil, AddGuardrailOutput{}, errRequiredField("content")
} }
severity := strings.TrimSpace(in.Severity) severity := strings.TrimSpace(in.Severity)
if severity == "" { if severity == "" {
@@ -122,7 +122,11 @@ func (t *SkillsTool) AddGuardrail(ctx context.Context, _ *mcp.CallToolRequest, i
switch severity { switch severity {
case "low", "medium", "high", "critical": case "low", "medium", "high", "critical":
default: default:
return nil, AddGuardrailOutput{}, errInvalidInput("severity must be one of: low, medium, high, critical") return nil, AddGuardrailOutput{}, errInvalidField(
"severity",
"severity must be one of: low, medium, high, critical",
"pass one of: low, medium, high, critical",
)
} }
if in.Tags == nil { if in.Tags == nil {
in.Tags = []string{} in.Tags = []string{}

45
internal/tools/version.go Normal file
View File

@@ -0,0 +1,45 @@
package tools
import (
"context"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/buildinfo"
)
type VersionTool struct {
serverName string
info buildinfo.Info
}
type GetVersionInfoInput struct{}
type GetVersionInfoOutput struct {
ServerName string `json:"server_name"`
Version string `json:"version"`
TagName string `json:"tag_name"`
Commit string `json:"commit"`
BuildDate string `json:"build_date"`
}
func NewVersionTool(serverName string, info buildinfo.Info) *VersionTool {
return &VersionTool{
serverName: serverName,
info: info,
}
}
func (t *VersionTool) GetInfo(_ context.Context, _ *mcp.CallToolRequest, _ GetVersionInfoInput) (*mcp.CallToolResult, GetVersionInfoOutput, error) {
if t == nil {
return nil, GetVersionInfoOutput{}, nil
}
return nil, GetVersionInfoOutput{
ServerName: t.serverName,
Version: t.info.Version,
TagName: t.info.TagName,
Commit: t.info.Commit,
BuildDate: t.info.BuildDate,
}, nil
}

View File

@@ -0,0 +1,38 @@
package tools
import (
"context"
"testing"
"git.warky.dev/wdevs/amcs/internal/buildinfo"
)
func TestVersionToolReturnsBuildInformation(t *testing.T) {
tool := NewVersionTool("amcs", buildinfo.Info{
Version: "v1.2.3",
TagName: "v1.2.3",
Commit: "abc1234",
BuildDate: "2026-03-31T12:34:56Z",
})
_, out, err := tool.GetInfo(context.Background(), nil, GetVersionInfoInput{})
if err != nil {
t.Fatalf("GetInfo() error = %v", err)
}
if out.ServerName != "amcs" {
t.Fatalf("server_name = %q, want %q", out.ServerName, "amcs")
}
if out.Version != "v1.2.3" {
t.Fatalf("version = %q, want %q", out.Version, "v1.2.3")
}
if out.TagName != "v1.2.3" {
t.Fatalf("tag_name = %q, want %q", out.TagName, "v1.2.3")
}
if out.Commit != "abc1234" {
t.Fatalf("commit = %q, want %q", out.Commit, "abc1234")
}
if out.BuildDate != "2026-03-31T12:34:56Z" {
t.Fatalf("build_date = %q, want %q", out.BuildDate, "2026-03-31T12:34:56Z")
}
}

View File

@@ -69,4 +69,4 @@ At the start of every project session, after setting the active project:
## Short Operational Form ## Short Operational Form
Use AMCS memory in project scope when the current work matches a known project. If no clear project matches, global notebook memory is allowed for non-project-specific information. At the start of every project session call `list_project_skills` and `list_project_guardrails` and apply what is returned; only create new skills or guardrails if none exist. Store durable notes with `capture_thought`. For binary files or files larger than 10 MB, call `upload_file` with `content_path` to stage the file and get an `amcs://files/{id}` URI, then pass that URI to `save_file` as `content_uri` to link it to a thought. For small files, use `save_file` or `upload_file` with `content_base64` directly. Browse stored files with `list_files`, and load them with `load_file` only when their contents are needed. Stored files can also be read as raw binary via MCP resources at `amcs://files/{id}`. Never store project-specific memory globally when a matching project exists, and never store memory in the wrong project. If project matching is ambiguous, ask the user. Use AMCS memory in project scope when the current work matches a known project. If no clear project matches, global notebook memory is allowed for non-project-specific information. At the start of every project session call `list_project_skills` and `list_project_guardrails` and apply what is returned; only create new skills or guardrails if none exist. If your MCP client does not preserve sessions across calls, pass `project` explicitly instead of relying on `set_active_project`. Store durable notes with `capture_thought`. For binary files or files larger than 10 MB, call `upload_file` with `content_path` to stage the file and get an `amcs://files/{id}` URI, then pass that URI to `save_file` as `content_uri` to link it to a thought. For small files, use `save_file` or `upload_file` with `content_base64` directly. Browse stored files with `list_files`, and load them with `load_file` only when their contents are needed. Stored files can also be read as raw binary via MCP resources at `amcs://files/{id}`. Never store project-specific memory globally when a matching project exists, and never store memory in the wrong project. If project matching is ambiguous, ask the user.