test(tools): add unit tests for error handling functions
* Implement tests for error functions like errRequiredField, errInvalidField, and errEntityNotFound. * Ensure proper metadata is returned for various error scenarios. * Validate error handling in CRM, Files, and other tools. * Introduce tests for parsing stored file IDs and UUIDs. * Enhance coverage for helper functions related to project resolution and session management.
This commit is contained in:
@@ -6,15 +6,31 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
)
|
||||
|
||||
const maxLoggedArgBytes = 512
|
||||
|
||||
var sensitiveToolArgKeys = map[string]struct{}{
|
||||
"client_secret": {},
|
||||
"content": {},
|
||||
"content_base64": {},
|
||||
"content_path": {},
|
||||
"email": {},
|
||||
"notes": {},
|
||||
"phone": {},
|
||||
"value": {},
|
||||
}
|
||||
|
||||
var toolSchemaOptions = &jsonschema.ForOptions{
|
||||
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
|
||||
reflect.TypeFor[uuid.UUID](): {
|
||||
@@ -24,19 +40,92 @@ var toolSchemaOptions = &jsonschema.ForOptions{
|
||||
},
|
||||
}
|
||||
|
||||
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
|
||||
var (
|
||||
missingPropertyPattern = regexp.MustCompile(`missing properties: \["([^"]+)"\]`)
|
||||
propertyPathPattern = regexp.MustCompile(`validating /properties/([^:]+):`)
|
||||
structFieldPattern = regexp.MustCompile(`\.([A-Za-z0-9_]+) of type`)
|
||||
)
|
||||
|
||||
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) error {
|
||||
if err := setToolSchemas[In, Out](tool); err != nil {
|
||||
panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
|
||||
return fmt.Errorf("configure MCP tool %q schemas: %w", tool.Name, err)
|
||||
}
|
||||
mcp.AddTool(server, tool, logToolCall(logger, tool.Name, handler))
|
||||
|
||||
inputResolved, err := resolveToolSchema(tool.InputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve MCP tool %q input schema: %w", tool.Name, err)
|
||||
}
|
||||
outputResolved, err := resolveToolSchema(tool.OutputSchema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve MCP tool %q output schema: %w", tool.Name, err)
|
||||
}
|
||||
|
||||
server.AddTool(tool, logToolCall(logger, tool.Name, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var input json.RawMessage
|
||||
if req != nil && req.Params != nil && req.Params.Arguments != nil {
|
||||
input = req.Params.Arguments
|
||||
}
|
||||
|
||||
input, err = applyToolSchema(input, inputResolved)
|
||||
if err != nil {
|
||||
return nil, invalidArgumentsError(err)
|
||||
}
|
||||
|
||||
var in In
|
||||
if input != nil {
|
||||
if err := json.Unmarshal(input, &in); err != nil {
|
||||
return nil, invalidArgumentsError(err)
|
||||
}
|
||||
}
|
||||
|
||||
result, out, err := handler(ctx, req, in)
|
||||
if err != nil {
|
||||
if wireErr, ok := err.(*jsonrpc.Error); ok {
|
||||
return nil, wireErr
|
||||
}
|
||||
var errResult mcp.CallToolResult
|
||||
errResult.SetError(err)
|
||||
return &errResult, nil
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = &mcp.CallToolResult{}
|
||||
}
|
||||
|
||||
if outputResolved == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
outValue := normalizeTypedNil[Out](out)
|
||||
if outValue == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
outBytes, err := json.Marshal(outValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling output: %w", err)
|
||||
}
|
||||
outJSON, err := applyToolSchema(json.RawMessage(outBytes), outputResolved)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating tool output: %w", err)
|
||||
}
|
||||
|
||||
result.StructuredContent = outJSON
|
||||
if result.Content == nil {
|
||||
result.Content = []mcp.Content{&mcp.TextContent{Text: string(outJSON)}}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error) {
|
||||
func logToolCall(logger *slog.Logger, toolName string, handler mcp.ToolHandler) mcp.ToolHandler {
|
||||
if logger == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
|
||||
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
start := time.Now()
|
||||
attrs := []any{slog.String("tool", toolName)}
|
||||
if req != nil && req.Params != nil {
|
||||
@@ -44,23 +133,27 @@ func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler
|
||||
}
|
||||
|
||||
logger.Info("mcp tool started", attrs...)
|
||||
result, out, err := handler(ctx, req, in)
|
||||
result, err := handler(ctx, req)
|
||||
|
||||
completionAttrs := append([]any{}, attrs...)
|
||||
completionAttrs = append(completionAttrs, slog.String("duration", formatLogDuration(time.Since(start))))
|
||||
if err != nil {
|
||||
completionAttrs = append(completionAttrs, slog.String("error", err.Error()))
|
||||
logger.Error("mcp tool completed", completionAttrs...)
|
||||
return result, out, err
|
||||
return result, err
|
||||
}
|
||||
|
||||
logger.Info("mcp tool completed", completionAttrs...)
|
||||
return result, out, nil
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func truncateArgs(args any) string {
|
||||
b, err := json.Marshal(args)
|
||||
redacted, err := redactArguments(args)
|
||||
if err != nil {
|
||||
return "<unserializable>"
|
||||
}
|
||||
b, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return "<unserializable>"
|
||||
}
|
||||
@@ -70,6 +163,52 @@ func truncateArgs(args any) string {
|
||||
return string(b[:maxLoggedArgBytes]) + fmt.Sprintf("… (%d bytes total)", len(b))
|
||||
}
|
||||
|
||||
func redactArguments(args any) (any, error) {
|
||||
if args == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
b, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(b, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return redactJSONValue("", decoded), nil
|
||||
}
|
||||
|
||||
func redactJSONValue(key string, value any) any {
|
||||
if isSensitiveToolArgKey(key) {
|
||||
return "<redacted>"
|
||||
}
|
||||
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
redacted := make(map[string]any, len(typed))
|
||||
for childKey, childValue := range typed {
|
||||
redacted[childKey] = redactJSONValue(childKey, childValue)
|
||||
}
|
||||
return redacted
|
||||
case []any:
|
||||
redacted := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
redacted[i] = redactJSONValue("", item)
|
||||
}
|
||||
return redacted
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitiveToolArgKey(key string) bool {
|
||||
_, ok := sensitiveToolArgKeys[strings.ToLower(strings.TrimSpace(key))]
|
||||
return ok
|
||||
}
|
||||
|
||||
func formatLogDuration(d time.Duration) string {
|
||||
if d < 0 {
|
||||
d = -d
|
||||
@@ -101,3 +240,104 @@ func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveToolSchema(schema any) (*jsonschema.Resolved, error) {
|
||||
if schema == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if typed, ok := schema.(*jsonschema.Schema); ok {
|
||||
return typed.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
|
||||
}
|
||||
|
||||
var remarshalTarget *jsonschema.Schema
|
||||
if err := remarshalJSON(schema, &remarshalTarget); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remarshalTarget.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
|
||||
}
|
||||
|
||||
func applyToolSchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) {
|
||||
if resolved == nil {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
value := make(map[string]any)
|
||||
if len(data) > 0 {
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling arguments: %w", err)
|
||||
}
|
||||
}
|
||||
if err := resolved.ApplyDefaults(&value); err != nil {
|
||||
return nil, fmt.Errorf("applying schema defaults: %w", err)
|
||||
}
|
||||
if err := resolved.Validate(&value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalized, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling with defaults: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func remarshalJSON(src any, dst any) error {
|
||||
b, err := json.Marshal(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, dst)
|
||||
}
|
||||
|
||||
func normalizeTypedNil[T any](value T) any {
|
||||
rt := reflect.TypeFor[T]()
|
||||
if rt.Kind() == reflect.Pointer {
|
||||
var zero T
|
||||
if any(value) == any(zero) {
|
||||
return reflect.Zero(rt.Elem()).Interface()
|
||||
}
|
||||
}
|
||||
return any(value)
|
||||
}
|
||||
|
||||
func invalidArgumentsError(err error) error {
|
||||
detail := err.Error()
|
||||
field := validationField(detail)
|
||||
|
||||
payload := mcperrors.Data{
|
||||
Type: mcperrors.TypeInvalidArguments,
|
||||
Detail: detail,
|
||||
}
|
||||
if field != "" {
|
||||
payload.Field = field
|
||||
payload.Hint = "check the " + field + " argument"
|
||||
}
|
||||
|
||||
data, marshalErr := json.Marshal(payload)
|
||||
if marshalErr != nil {
|
||||
return &jsonrpc.Error{
|
||||
Code: jsonrpc.CodeInvalidParams,
|
||||
Message: "invalid tool arguments",
|
||||
}
|
||||
}
|
||||
|
||||
return &jsonrpc.Error{
|
||||
Code: jsonrpc.CodeInvalidParams,
|
||||
Message: "invalid tool arguments",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func validationField(detail string) string {
|
||||
if matches := missingPropertyPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
||||
return matches[1]
|
||||
}
|
||||
if matches := propertyPathPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
||||
return matches[1]
|
||||
}
|
||||
if matches := structFieldPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
||||
return strings.ToLower(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user