test(tools): add unit tests for error handling functions

* Implement tests for error functions like errRequiredField, errInvalidField, and errEntityNotFound.
* Ensure proper metadata is returned for various error scenarios.
* Validate error handling in CRM, Files, and other tools.
* Introduce tests for parsing stored file IDs and UUIDs.
* Enhance coverage for helper functions related to project resolution and session management.
This commit is contained in:
Hein
2026-03-31 15:10:07 +02:00
parent acd780ac9c
commit f41c512f36
54 changed files with 1937 additions and 365 deletions

View File

@@ -6,15 +6,31 @@ import (
"fmt"
"log/slog"
"reflect"
"regexp"
"strings"
"time"
"github.com/google/jsonschema-go/jsonschema"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/mcperrors"
)
const maxLoggedArgBytes = 512
var sensitiveToolArgKeys = map[string]struct{}{
"client_secret": {},
"content": {},
"content_base64": {},
"content_path": {},
"email": {},
"notes": {},
"phone": {},
"value": {},
}
var toolSchemaOptions = &jsonschema.ForOptions{
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
reflect.TypeFor[uuid.UUID](): {
@@ -24,19 +40,92 @@ var toolSchemaOptions = &jsonschema.ForOptions{
},
}
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
var (
missingPropertyPattern = regexp.MustCompile(`missing properties: \["([^"]+)"\]`)
propertyPathPattern = regexp.MustCompile(`validating /properties/([^:]+):`)
structFieldPattern = regexp.MustCompile(`\.([A-Za-z0-9_]+) of type`)
)
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) error {
if err := setToolSchemas[In, Out](tool); err != nil {
panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
return fmt.Errorf("configure MCP tool %q schemas: %w", tool.Name, err)
}
mcp.AddTool(server, tool, logToolCall(logger, tool.Name, handler))
inputResolved, err := resolveToolSchema(tool.InputSchema)
if err != nil {
return fmt.Errorf("resolve MCP tool %q input schema: %w", tool.Name, err)
}
outputResolved, err := resolveToolSchema(tool.OutputSchema)
if err != nil {
return fmt.Errorf("resolve MCP tool %q output schema: %w", tool.Name, err)
}
server.AddTool(tool, logToolCall(logger, tool.Name, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var input json.RawMessage
if req != nil && req.Params != nil && req.Params.Arguments != nil {
input = req.Params.Arguments
}
input, err = applyToolSchema(input, inputResolved)
if err != nil {
return nil, invalidArgumentsError(err)
}
var in In
if input != nil {
if err := json.Unmarshal(input, &in); err != nil {
return nil, invalidArgumentsError(err)
}
}
result, out, err := handler(ctx, req, in)
if err != nil {
if wireErr, ok := err.(*jsonrpc.Error); ok {
return nil, wireErr
}
var errResult mcp.CallToolResult
errResult.SetError(err)
return &errResult, nil
}
if result == nil {
result = &mcp.CallToolResult{}
}
if outputResolved == nil {
return result, nil
}
outValue := normalizeTypedNil[Out](out)
if outValue == nil {
return result, nil
}
outBytes, err := json.Marshal(outValue)
if err != nil {
return nil, fmt.Errorf("marshaling output: %w", err)
}
outJSON, err := applyToolSchema(json.RawMessage(outBytes), outputResolved)
if err != nil {
return nil, fmt.Errorf("validating tool output: %w", err)
}
result.StructuredContent = outJSON
if result.Content == nil {
result.Content = []mcp.Content{&mcp.TextContent{Text: string(outJSON)}}
}
return result, nil
}))
return nil
}
func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error) {
func logToolCall(logger *slog.Logger, toolName string, handler mcp.ToolHandler) mcp.ToolHandler {
if logger == nil {
return handler
}
return func(ctx context.Context, req *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
start := time.Now()
attrs := []any{slog.String("tool", toolName)}
if req != nil && req.Params != nil {
@@ -44,23 +133,27 @@ func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler
}
logger.Info("mcp tool started", attrs...)
result, out, err := handler(ctx, req, in)
result, err := handler(ctx, req)
completionAttrs := append([]any{}, attrs...)
completionAttrs = append(completionAttrs, slog.String("duration", formatLogDuration(time.Since(start))))
if err != nil {
completionAttrs = append(completionAttrs, slog.String("error", err.Error()))
logger.Error("mcp tool completed", completionAttrs...)
return result, out, err
return result, err
}
logger.Info("mcp tool completed", completionAttrs...)
return result, out, nil
return result, nil
}
}
func truncateArgs(args any) string {
b, err := json.Marshal(args)
redacted, err := redactArguments(args)
if err != nil {
return "<unserializable>"
}
b, err := json.Marshal(redacted)
if err != nil {
return "<unserializable>"
}
@@ -70,6 +163,52 @@ func truncateArgs(args any) string {
return string(b[:maxLoggedArgBytes]) + fmt.Sprintf("… (%d bytes total)", len(b))
}
func redactArguments(args any) (any, error) {
if args == nil {
return nil, nil
}
b, err := json.Marshal(args)
if err != nil {
return nil, err
}
var decoded any
if err := json.Unmarshal(b, &decoded); err != nil {
return nil, err
}
return redactJSONValue("", decoded), nil
}
func redactJSONValue(key string, value any) any {
if isSensitiveToolArgKey(key) {
return "<redacted>"
}
switch typed := value.(type) {
case map[string]any:
redacted := make(map[string]any, len(typed))
for childKey, childValue := range typed {
redacted[childKey] = redactJSONValue(childKey, childValue)
}
return redacted
case []any:
redacted := make([]any, len(typed))
for i, item := range typed {
redacted[i] = redactJSONValue("", item)
}
return redacted
default:
return value
}
}
func isSensitiveToolArgKey(key string) bool {
_, ok := sensitiveToolArgKeys[strings.ToLower(strings.TrimSpace(key))]
return ok
}
func formatLogDuration(d time.Duration) string {
if d < 0 {
d = -d
@@ -101,3 +240,104 @@ func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
return nil
}
func resolveToolSchema(schema any) (*jsonschema.Resolved, error) {
if schema == nil {
return nil, nil
}
if typed, ok := schema.(*jsonschema.Schema); ok {
return typed.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
}
var remarshalTarget *jsonschema.Schema
if err := remarshalJSON(schema, &remarshalTarget); err != nil {
return nil, err
}
return remarshalTarget.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
}
func applyToolSchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) {
if resolved == nil {
return data, nil
}
value := make(map[string]any)
if len(data) > 0 {
if err := json.Unmarshal(data, &value); err != nil {
return nil, fmt.Errorf("unmarshaling arguments: %w", err)
}
}
if err := resolved.ApplyDefaults(&value); err != nil {
return nil, fmt.Errorf("applying schema defaults: %w", err)
}
if err := resolved.Validate(&value); err != nil {
return nil, err
}
normalized, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("marshalling with defaults: %w", err)
}
return normalized, nil
}
func remarshalJSON(src any, dst any) error {
b, err := json.Marshal(src)
if err != nil {
return err
}
return json.Unmarshal(b, dst)
}
func normalizeTypedNil[T any](value T) any {
rt := reflect.TypeFor[T]()
if rt.Kind() == reflect.Pointer {
var zero T
if any(value) == any(zero) {
return reflect.Zero(rt.Elem()).Interface()
}
}
return any(value)
}
func invalidArgumentsError(err error) error {
detail := err.Error()
field := validationField(detail)
payload := mcperrors.Data{
Type: mcperrors.TypeInvalidArguments,
Detail: detail,
}
if field != "" {
payload.Field = field
payload.Hint = "check the " + field + " argument"
}
data, marshalErr := json.Marshal(payload)
if marshalErr != nil {
return &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "invalid tool arguments",
}
}
return &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "invalid tool arguments",
Data: data,
}
}
func validationField(detail string) string {
if matches := missingPropertyPattern.FindStringSubmatch(detail); len(matches) == 2 {
return matches[1]
}
if matches := propertyPathPattern.FindStringSubmatch(detail); len(matches) == 2 {
return matches[1]
}
if matches := structFieldPattern.FindStringSubmatch(detail); len(matches) == 2 {
return strings.ToLower(matches[1])
}
return ""
}