test(tools): add unit tests for error handling functions
* Implement tests for error functions like errRequiredField, errInvalidField, and errEntityNotFound. * Ensure proper metadata is returned for various error scenarios. * Validate error handling in CRM, Files, and other tools. * Introduce tests for parsing stored file IDs and UUIDs. * Enhance coverage for helper functions related to project resolution and session management.
This commit is contained in:
@@ -35,7 +35,7 @@ type AddFamilyMemberOutput struct {
|
||||
|
||||
func (t *CalendarTool) AddMember(ctx context.Context, _ *mcp.CallToolRequest, in AddFamilyMemberInput) (*mcp.CallToolResult, AddFamilyMemberOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddFamilyMemberOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddFamilyMemberOutput{}, errRequiredField("name")
|
||||
}
|
||||
member, err := t.store.AddFamilyMember(ctx, ext.FamilyMember{
|
||||
Name: strings.TrimSpace(in.Name),
|
||||
@@ -89,7 +89,7 @@ type AddActivityOutput struct {
|
||||
|
||||
func (t *CalendarTool) AddActivity(ctx context.Context, _ *mcp.CallToolRequest, in AddActivityInput) (*mcp.CallToolResult, AddActivityOutput, error) {
|
||||
if strings.TrimSpace(in.Title) == "" {
|
||||
return nil, AddActivityOutput{}, errInvalidInput("title is required")
|
||||
return nil, AddActivityOutput{}, errRequiredField("title")
|
||||
}
|
||||
activity, err := t.store.AddActivity(ctx, ext.Activity{
|
||||
FamilyMemberID: in.FamilyMemberID,
|
||||
@@ -170,7 +170,7 @@ type AddImportantDateOutput struct {
|
||||
|
||||
func (t *CalendarTool) AddImportantDate(ctx context.Context, _ *mcp.CallToolRequest, in AddImportantDateInput) (*mcp.CallToolResult, AddImportantDateOutput, error) {
|
||||
if strings.TrimSpace(in.Title) == "" {
|
||||
return nil, AddImportantDateOutput{}, errInvalidInput("title is required")
|
||||
return nil, AddImportantDateOutput{}, errRequiredField("title")
|
||||
}
|
||||
reminder := in.ReminderDaysBefore
|
||||
if reminder <= 0 {
|
||||
|
||||
@@ -43,7 +43,7 @@ func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureCo
|
||||
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
|
||||
content := strings.TrimSpace(in.Content)
|
||||
if content == "" {
|
||||
return nil, CaptureOutput{}, errInvalidInput("content is required")
|
||||
return nil, CaptureOutput{}, errRequiredField("content")
|
||||
}
|
||||
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
)
|
||||
|
||||
func normalizeLimit(limit int, cfg config.SearchConfig) int {
|
||||
@@ -26,6 +30,116 @@ func normalizeThreshold(value float64, fallback float64) float64 {
|
||||
return value
|
||||
}
|
||||
|
||||
func errInvalidInput(message string) error {
|
||||
return fmt.Errorf("invalid input: %s", message)
|
||||
const (
|
||||
codeSessionRequired = mcperrors.CodeSessionRequired
|
||||
codeProjectRequired = mcperrors.CodeProjectRequired
|
||||
codeProjectNotFound = mcperrors.CodeProjectNotFound
|
||||
codeInvalidID = mcperrors.CodeInvalidID
|
||||
codeEntityNotFound = mcperrors.CodeEntityNotFound
|
||||
)
|
||||
|
||||
type mcpErrorData = mcperrors.Data
|
||||
|
||||
func newMCPError(code int64, message string, data mcpErrorData) error {
|
||||
rpcErr := &jsonrpc.Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal mcp error data: %w", err)
|
||||
}
|
||||
rpcErr.Data = payload
|
||||
|
||||
return rpcErr
|
||||
}
|
||||
|
||||
func errInvalidInput(message string) error {
|
||||
return newMCPError(
|
||||
jsonrpc.CodeInvalidParams,
|
||||
"invalid input: "+message,
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeInvalidInput,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func errRequiredField(field string) error {
|
||||
return newMCPError(
|
||||
jsonrpc.CodeInvalidParams,
|
||||
field+" is required",
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeInvalidInput,
|
||||
Field: field,
|
||||
Detail: "required",
|
||||
Hint: "provide " + field,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func errInvalidField(field string, message string, hint string) error {
|
||||
return newMCPError(
|
||||
jsonrpc.CodeInvalidParams,
|
||||
message,
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeInvalidInput,
|
||||
Field: field,
|
||||
Detail: "invalid",
|
||||
Hint: hint,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func errOneOfRequired(fields ...string) error {
|
||||
return newMCPError(
|
||||
jsonrpc.CodeInvalidParams,
|
||||
joinFields(fields)+" is required",
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeInvalidInput,
|
||||
Fields: fields,
|
||||
Detail: "one_of_required",
|
||||
Hint: "provide one of: " + strings.Join(fields, ", "),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func errMutuallyExclusiveFields(fields ...string) error {
|
||||
return newMCPError(
|
||||
jsonrpc.CodeInvalidParams,
|
||||
"provide "+joinFields(fields)+", not both",
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeInvalidInput,
|
||||
Fields: fields,
|
||||
Detail: "mutually_exclusive",
|
||||
Hint: "provide only one of: " + strings.Join(fields, ", "),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func errEntityNotFound(entity string, field string, value string) error {
|
||||
return newMCPError(
|
||||
codeEntityNotFound,
|
||||
entity+" not found",
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeEntityNotFound,
|
||||
Entity: entity,
|
||||
Field: field,
|
||||
Value: value,
|
||||
Detail: "not_found",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func joinFields(fields []string) string {
|
||||
switch len(fields) {
|
||||
case 0:
|
||||
return "field"
|
||||
case 1:
|
||||
return fields[0]
|
||||
case 2:
|
||||
return fields[0] + " or " + fields[1]
|
||||
default:
|
||||
return strings.Join(fields[:len(fields)-1], ", ") + ", or " + fields[len(fields)-1]
|
||||
}
|
||||
}
|
||||
|
||||
84
internal/tools/common_test.go
Normal file
84
internal/tools/common_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
)
|
||||
|
||||
func TestErrRequiredFieldReturnsFieldMetadata(t *testing.T) {
|
||||
rpcErr, data := requireRPCError(t, errRequiredField("name"))
|
||||
if data.Type != mcperrors.TypeInvalidInput {
|
||||
t.Fatalf("errRequiredField() type = %q, want %q", data.Type, mcperrors.TypeInvalidInput)
|
||||
}
|
||||
if data.Field != "name" {
|
||||
t.Fatalf("errRequiredField() field = %q, want %q", data.Field, "name")
|
||||
}
|
||||
if data.Detail != "required" {
|
||||
t.Fatalf("errRequiredField() detail = %q, want %q", data.Detail, "required")
|
||||
}
|
||||
if rpcErr.Message != "name is required" {
|
||||
t.Fatalf("errRequiredField() message = %q, want %q", rpcErr.Message, "name is required")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrInvalidFieldReturnsFieldMetadata(t *testing.T) {
|
||||
rpcErr, data := requireRPCError(t, errInvalidField("severity", "severity must be one of: low, medium, high, critical", "pass one of: low, medium, high, critical"))
|
||||
if data.Field != "severity" {
|
||||
t.Fatalf("errInvalidField() field = %q, want %q", data.Field, "severity")
|
||||
}
|
||||
if data.Detail != "invalid" {
|
||||
t.Fatalf("errInvalidField() detail = %q, want %q", data.Detail, "invalid")
|
||||
}
|
||||
if data.Hint == "" {
|
||||
t.Fatal("errInvalidField() hint = empty, want guidance")
|
||||
}
|
||||
if rpcErr.Message == "" {
|
||||
t.Fatal("errInvalidField() message = empty, want non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrOneOfRequiredReturnsFieldsMetadata(t *testing.T) {
|
||||
rpcErr, data := requireRPCError(t, errOneOfRequired("content_base64", "content_uri"))
|
||||
if data.Detail != "one_of_required" {
|
||||
t.Fatalf("errOneOfRequired() detail = %q, want %q", data.Detail, "one_of_required")
|
||||
}
|
||||
if len(data.Fields) != 2 || data.Fields[0] != "content_base64" || data.Fields[1] != "content_uri" {
|
||||
t.Fatalf("errOneOfRequired() fields = %#v, want [content_base64 content_uri]", data.Fields)
|
||||
}
|
||||
if rpcErr.Message != "content_base64 or content_uri is required" {
|
||||
t.Fatalf("errOneOfRequired() message = %q, want %q", rpcErr.Message, "content_base64 or content_uri is required")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrMutuallyExclusiveFieldsReturnsFieldsMetadata(t *testing.T) {
|
||||
rpcErr, data := requireRPCError(t, errMutuallyExclusiveFields("content_uri", "content_base64"))
|
||||
if data.Detail != "mutually_exclusive" {
|
||||
t.Fatalf("errMutuallyExclusiveFields() detail = %q, want %q", data.Detail, "mutually_exclusive")
|
||||
}
|
||||
if len(data.Fields) != 2 || data.Fields[0] != "content_uri" || data.Fields[1] != "content_base64" {
|
||||
t.Fatalf("errMutuallyExclusiveFields() fields = %#v, want [content_uri content_base64]", data.Fields)
|
||||
}
|
||||
if rpcErr.Message != "provide content_uri or content_base64, not both" {
|
||||
t.Fatalf("errMutuallyExclusiveFields() message = %q, want %q", rpcErr.Message, "provide content_uri or content_base64, not both")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrEntityNotFoundReturnsEntityMetadata(t *testing.T) {
|
||||
rpcErr, data := requireRPCError(t, errEntityNotFound("thought", "thought_id", "123"))
|
||||
if rpcErr.Code != codeEntityNotFound {
|
||||
t.Fatalf("errEntityNotFound() code = %d, want %d", rpcErr.Code, codeEntityNotFound)
|
||||
}
|
||||
if data.Type != mcperrors.TypeEntityNotFound {
|
||||
t.Fatalf("errEntityNotFound() type = %q, want %q", data.Type, mcperrors.TypeEntityNotFound)
|
||||
}
|
||||
if data.Entity != "thought" {
|
||||
t.Fatalf("errEntityNotFound() entity = %q, want %q", data.Entity, "thought")
|
||||
}
|
||||
if data.Field != "thought_id" {
|
||||
t.Fatalf("errEntityNotFound() field = %q, want %q", data.Field, "thought_id")
|
||||
}
|
||||
if data.Value != "123" {
|
||||
t.Fatalf("errEntityNotFound() value = %q, want %q", data.Value, "123")
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,13 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
@@ -24,15 +26,15 @@ func NewCRMTool(db *store.DB) *CRMTool {
|
||||
// add_professional_contact
|
||||
|
||||
type AddContactInput struct {
|
||||
Name string `json:"name" jsonschema:"contact's full name"`
|
||||
Company string `json:"company,omitempty"`
|
||||
Title string `json:"title,omitempty" jsonschema:"job title"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Phone string `json:"phone,omitempty"`
|
||||
LinkedInURL string `json:"linkedin_url,omitempty"`
|
||||
HowWeMet string `json:"how_we_met,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
Name string `json:"name" jsonschema:"contact's full name"`
|
||||
Company string `json:"company,omitempty"`
|
||||
Title string `json:"title,omitempty" jsonschema:"job title"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Phone string `json:"phone,omitempty"`
|
||||
LinkedInURL string `json:"linkedin_url,omitempty"`
|
||||
HowWeMet string `json:"how_we_met,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Notes string `json:"notes,omitempty"`
|
||||
FollowUpDate *time.Time `json:"follow_up_date,omitempty"`
|
||||
}
|
||||
|
||||
@@ -42,7 +44,7 @@ type AddContactOutput struct {
|
||||
|
||||
func (t *CRMTool) AddContact(ctx context.Context, _ *mcp.CallToolRequest, in AddContactInput) (*mcp.CallToolResult, AddContactOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddContactOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddContactOutput{}, errRequiredField("name")
|
||||
}
|
||||
if in.Tags == nil {
|
||||
in.Tags = []string{}
|
||||
@@ -104,7 +106,7 @@ type LogInteractionOutput struct {
|
||||
|
||||
func (t *CRMTool) LogInteraction(ctx context.Context, _ *mcp.CallToolRequest, in LogInteractionInput) (*mcp.CallToolResult, LogInteractionOutput, error) {
|
||||
if strings.TrimSpace(in.Summary) == "" {
|
||||
return nil, LogInteractionOutput{}, errInvalidInput("summary is required")
|
||||
return nil, LogInteractionOutput{}, errRequiredField("summary")
|
||||
}
|
||||
occurredAt := time.Now()
|
||||
if in.OccurredAt != nil {
|
||||
@@ -160,7 +162,7 @@ type CreateOpportunityOutput struct {
|
||||
|
||||
func (t *CRMTool) CreateOpportunity(ctx context.Context, _ *mcp.CallToolRequest, in CreateOpportunityInput) (*mcp.CallToolResult, CreateOpportunityOutput, error) {
|
||||
if strings.TrimSpace(in.Title) == "" {
|
||||
return nil, CreateOpportunityOutput{}, errInvalidInput("title is required")
|
||||
return nil, CreateOpportunityOutput{}, errRequiredField("title")
|
||||
}
|
||||
stage := strings.TrimSpace(in.Stage)
|
||||
if stage == "" {
|
||||
@@ -216,7 +218,10 @@ type LinkThoughtToContactOutput struct {
|
||||
func (t *CRMTool) LinkThought(ctx context.Context, _ *mcp.CallToolRequest, in LinkThoughtToContactInput) (*mcp.CallToolResult, LinkThoughtToContactOutput, error) {
|
||||
thought, err := t.store.GetThought(ctx, in.ThoughtID)
|
||||
if err != nil {
|
||||
return nil, LinkThoughtToContactOutput{}, fmt.Errorf("thought not found: %w", err)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, LinkThoughtToContactOutput{}, errEntityNotFound("thought", "thought_id", in.ThoughtID.String())
|
||||
}
|
||||
return nil, LinkThoughtToContactOutput{}, err
|
||||
}
|
||||
|
||||
appendText := fmt.Sprintf("\n\n[Linked thought %s]: %s", thought.ID, thought.Content)
|
||||
@@ -226,6 +231,9 @@ func (t *CRMTool) LinkThought(ctx context.Context, _ *mcp.CallToolRequest, in Li
|
||||
|
||||
contact, err := t.store.GetContact(ctx, in.ContactID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, LinkThoughtToContactOutput{}, errEntityNotFound("contact", "contact_id", in.ContactID.String())
|
||||
}
|
||||
return nil, LinkThoughtToContactOutput{}, err
|
||||
}
|
||||
return nil, LinkThoughtToContactOutput{Contact: contact}, nil
|
||||
|
||||
@@ -52,7 +52,7 @@ type SaveFileOutput struct {
|
||||
}
|
||||
|
||||
type LoadFileInput struct {
|
||||
ID string `json:"id" jsonschema:"the stored file id"`
|
||||
ID string `json:"id" jsonschema:"the stored file id or amcs://files/{id} URI"`
|
||||
}
|
||||
|
||||
type LoadFileOutput struct {
|
||||
@@ -95,7 +95,7 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
|
||||
b64 := strings.TrimSpace(in.ContentBase64)
|
||||
|
||||
if path != "" && b64 != "" {
|
||||
return nil, UploadFileOutput{}, errInvalidInput("provide content_path or content_base64, not both")
|
||||
return nil, UploadFileOutput{}, errMutuallyExclusiveFields("content_path", "content_base64")
|
||||
}
|
||||
|
||||
var content []byte
|
||||
@@ -103,7 +103,11 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
|
||||
|
||||
if path != "" {
|
||||
if !filepath.IsAbs(path) {
|
||||
return nil, UploadFileOutput{}, errInvalidInput("content_path must be an absolute path")
|
||||
return nil, UploadFileOutput{}, errInvalidField(
|
||||
"content_path",
|
||||
"content_path must be an absolute path",
|
||||
"pass an absolute path on the server filesystem",
|
||||
)
|
||||
}
|
||||
var err error
|
||||
content, err = os.ReadFile(path)
|
||||
@@ -112,7 +116,7 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
|
||||
}
|
||||
} else {
|
||||
if b64 == "" {
|
||||
return nil, UploadFileOutput{}, errInvalidInput("content_path or content_base64 is required")
|
||||
return nil, UploadFileOutput{}, errOneOfRequired("content_path", "content_base64")
|
||||
}
|
||||
if len(b64) > maxBase64ToolBytes {
|
||||
return nil, UploadFileOutput{}, errInvalidInput(
|
||||
@@ -123,7 +127,11 @@ func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in Upl
|
||||
var err error
|
||||
content, err = decodeBase64(raw)
|
||||
if err != nil {
|
||||
return nil, UploadFileOutput{}, errInvalidInput("content_base64 must be valid base64")
|
||||
return nil, UploadFileOutput{}, errInvalidField(
|
||||
"content_base64",
|
||||
"content_base64 must be valid base64",
|
||||
"pass valid base64 data or a data URL",
|
||||
)
|
||||
}
|
||||
mediaTypeFromSource = dataURLMediaType
|
||||
}
|
||||
@@ -149,7 +157,7 @@ func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveF
|
||||
b64 := strings.TrimSpace(in.ContentBase64)
|
||||
|
||||
if uri != "" && b64 != "" {
|
||||
return nil, SaveFileOutput{}, errInvalidInput("provide content_uri or content_base64, not both")
|
||||
return nil, SaveFileOutput{}, errMutuallyExclusiveFields("content_uri", "content_base64")
|
||||
}
|
||||
if len(b64) > maxBase64ToolBytes {
|
||||
return nil, SaveFileOutput{}, errInvalidInput(
|
||||
@@ -162,28 +170,44 @@ func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveF
|
||||
|
||||
if uri != "" {
|
||||
if !strings.HasPrefix(uri, fileURIPrefix) {
|
||||
return nil, SaveFileOutput{}, errInvalidInput("content_uri must be an amcs://files/{id} URI")
|
||||
return nil, SaveFileOutput{}, errInvalidField(
|
||||
"content_uri",
|
||||
"content_uri must be an amcs://files/{id} URI",
|
||||
"pass an amcs://files/{id} URI returned by upload_file or POST /files",
|
||||
)
|
||||
}
|
||||
rawID := strings.TrimPrefix(uri, fileURIPrefix)
|
||||
id, err := parseUUID(rawID)
|
||||
if err != nil {
|
||||
return nil, SaveFileOutput{}, errInvalidInput("content_uri contains an invalid file id")
|
||||
return nil, SaveFileOutput{}, errInvalidField(
|
||||
"content_uri",
|
||||
"content_uri contains an invalid file id",
|
||||
"pass a valid amcs://files/{id} URI",
|
||||
)
|
||||
}
|
||||
file, err := t.store.GetStoredFile(ctx, id)
|
||||
if err != nil {
|
||||
return nil, SaveFileOutput{}, errInvalidInput("content_uri references a file that does not exist")
|
||||
return nil, SaveFileOutput{}, errInvalidField(
|
||||
"content_uri",
|
||||
"content_uri references a file that does not exist",
|
||||
"upload the file first or pass an existing amcs://files/{id} URI",
|
||||
)
|
||||
}
|
||||
content = file.Content
|
||||
mediaTypeFromSource = file.MediaType
|
||||
} else {
|
||||
contentBase64, mediaTypeFromDataURL := splitDataURL(b64)
|
||||
if contentBase64 == "" {
|
||||
return nil, SaveFileOutput{}, errInvalidInput("content_base64 or content_uri is required")
|
||||
return nil, SaveFileOutput{}, errOneOfRequired("content_base64", "content_uri")
|
||||
}
|
||||
var err error
|
||||
content, err = decodeBase64(contentBase64)
|
||||
if err != nil {
|
||||
return nil, SaveFileOutput{}, errInvalidInput("content_base64 must be valid base64")
|
||||
return nil, SaveFileOutput{}, errInvalidField(
|
||||
"content_base64",
|
||||
"content_base64 must be valid base64",
|
||||
"pass valid base64 data or a data URL",
|
||||
)
|
||||
}
|
||||
mediaTypeFromSource = mediaTypeFromDataURL
|
||||
}
|
||||
@@ -205,7 +229,7 @@ func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveF
|
||||
const fileURIPrefix = "amcs://files/"
|
||||
|
||||
func (t *FilesTool) GetRaw(ctx context.Context, rawID string) (thoughttypes.StoredFile, error) {
|
||||
id, err := parseUUID(strings.TrimSpace(rawID))
|
||||
id, err := parseStoredFileID(rawID)
|
||||
if err != nil {
|
||||
return thoughttypes.StoredFile{}, err
|
||||
}
|
||||
@@ -213,7 +237,7 @@ func (t *FilesTool) GetRaw(ctx context.Context, rawID string) (thoughttypes.Stor
|
||||
}
|
||||
|
||||
func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFileInput) (*mcp.CallToolResult, LoadFileOutput, error) {
|
||||
id, err := parseUUID(in.ID)
|
||||
id, err := parseStoredFileID(in.ID)
|
||||
if err != nil {
|
||||
return nil, LoadFileOutput{}, err
|
||||
}
|
||||
@@ -243,8 +267,7 @@ func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFil
|
||||
}
|
||||
|
||||
func (t *FilesTool) ReadResource(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
|
||||
rawID := strings.TrimPrefix(req.Params.URI, fileURIPrefix)
|
||||
id, err := parseUUID(strings.TrimSpace(rawID))
|
||||
id, err := parseStoredFileID(req.Params.URI)
|
||||
if err != nil {
|
||||
return nil, mcp.ResourceNotFoundError(req.Params.URI)
|
||||
}
|
||||
@@ -309,7 +332,7 @@ func (t *FilesTool) List(ctx context.Context, req *mcp.CallToolRequest, in ListF
|
||||
func (t *FilesTool) SaveDecoded(ctx context.Context, req *mcp.CallToolRequest, in SaveFileDecodedInput) (SaveFileOutput, error) {
|
||||
name := strings.TrimSpace(in.Name)
|
||||
if name == "" {
|
||||
return SaveFileOutput{}, errInvalidInput("name is required")
|
||||
return SaveFileOutput{}, errRequiredField("name")
|
||||
}
|
||||
if len(in.Content) == 0 {
|
||||
return SaveFileOutput{}, errInvalidInput("decoded file content must not be empty")
|
||||
@@ -492,3 +515,9 @@ func normalizeFileLimit(limit int) int {
|
||||
return limit
|
||||
}
|
||||
}
|
||||
|
||||
func parseStoredFileID(raw string) (uuid.UUID, error) {
|
||||
value := strings.TrimSpace(raw)
|
||||
value = strings.TrimPrefix(value, fileURIPrefix)
|
||||
return parseUUID(value)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package tools
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestDecodeBase64AcceptsWhitespaceAndMultipleVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -27,3 +31,45 @@ func TestDecodeBase64AcceptsWhitespaceAndMultipleVariants(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStoredFileIDAcceptsUUIDAndURI(t *testing.T) {
|
||||
id := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want uuid.UUID
|
||||
}{
|
||||
{name: "bare uuid", input: id.String(), want: id},
|
||||
{name: "resource uri", input: fileURIPrefix + id.String(), want: id},
|
||||
{name: "resource uri with surrounding whitespace", input: " " + fileURIPrefix + id.String() + " ", want: id},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := parseStoredFileID(tc.input)
|
||||
if err != nil {
|
||||
t.Fatalf("parseStoredFileID(%q) error = %v", tc.input, err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("parseStoredFileID(%q) = %v, want %v", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStoredFileIDRejectsInvalidValues(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"not-a-uuid",
|
||||
fileURIPrefix + "not-a-uuid",
|
||||
}
|
||||
|
||||
for _, input := range tests {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
if _, err := parseStoredFileID(input); err == nil {
|
||||
t.Fatalf("parseStoredFileID(%q) = nil error, want error", input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,22 +9,41 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func parseUUID(id string) (uuid.UUID, error) {
|
||||
parsed, err := uuid.Parse(strings.TrimSpace(id))
|
||||
trimmed := strings.TrimSpace(id)
|
||||
parsed, err := uuid.Parse(trimmed)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("invalid id %q: %w", id, err)
|
||||
return uuid.Nil, newMCPError(
|
||||
codeInvalidID,
|
||||
fmt.Sprintf("invalid id %q", id),
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeInvalidID,
|
||||
Field: "id",
|
||||
Value: trimmed,
|
||||
Detail: err.Error(),
|
||||
Hint: "pass a valid UUID",
|
||||
},
|
||||
)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func sessionID(req *mcp.CallToolRequest) (string, error) {
|
||||
if req == nil || req.Session == nil || req.Session.ID() == "" {
|
||||
return "", fmt.Errorf("tool requires an MCP session")
|
||||
return "", newMCPError(
|
||||
codeSessionRequired,
|
||||
"tool requires an MCP session; use a stateful MCP client for session-scoped operations",
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeSessionRequired,
|
||||
Hint: "use a stateful MCP client for session-scoped operations",
|
||||
},
|
||||
)
|
||||
}
|
||||
return req.Session.ID(), nil
|
||||
}
|
||||
@@ -45,7 +64,15 @@ func resolveProject(ctx context.Context, db *store.DB, sessions *session.ActiveP
|
||||
|
||||
if projectRef == "" {
|
||||
if required {
|
||||
return nil, fmt.Errorf("project is required")
|
||||
return nil, newMCPError(
|
||||
codeProjectRequired,
|
||||
"project is required; pass project explicitly or call set_active_project in this MCP session first",
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeProjectRequired,
|
||||
Field: "project",
|
||||
Hint: "pass project explicitly or call set_active_project in this MCP session first",
|
||||
},
|
||||
)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@@ -53,7 +80,15 @@ func resolveProject(ctx context.Context, db *store.DB, sessions *session.ActiveP
|
||||
project, err := db.GetProject(ctx, projectRef)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, fmt.Errorf("project %q not found", projectRef)
|
||||
return nil, newMCPError(
|
||||
codeProjectNotFound,
|
||||
fmt.Sprintf("project %q not found", projectRef),
|
||||
mcpErrorData{
|
||||
Type: mcperrors.TypeProjectNotFound,
|
||||
Field: "project",
|
||||
Project: projectRef,
|
||||
},
|
||||
)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
107
internal/tools/helpers_test.go
Normal file
107
internal/tools/helpers_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
)
|
||||
|
||||
func TestResolveProjectRequiredErrorGuidesCaller(t *testing.T) {
|
||||
_, err := resolveProject(context.Background(), nil, nil, nil, "", true)
|
||||
if err == nil {
|
||||
t.Fatal("resolveProject() error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireRPCError(t, err)
|
||||
if rpcErr.Code != codeProjectRequired {
|
||||
t.Fatalf("resolveProject() code = %d, want %d", rpcErr.Code, codeProjectRequired)
|
||||
}
|
||||
if data.Type != mcperrors.TypeProjectRequired {
|
||||
t.Fatalf("resolveProject() type = %q, want %q", data.Type, mcperrors.TypeProjectRequired)
|
||||
}
|
||||
if data.Field != "project" {
|
||||
t.Fatalf("resolveProject() field = %q, want %q", data.Field, "project")
|
||||
}
|
||||
if data.Hint == "" {
|
||||
t.Fatal("resolveProject() hint = empty, want guidance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionIDErrorGuidesCaller(t *testing.T) {
|
||||
_, err := sessionID(nil)
|
||||
if err == nil {
|
||||
t.Fatal("sessionID() error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireRPCError(t, err)
|
||||
if rpcErr.Code != codeSessionRequired {
|
||||
t.Fatalf("sessionID() code = %d, want %d", rpcErr.Code, codeSessionRequired)
|
||||
}
|
||||
if data.Type != mcperrors.TypeSessionRequired {
|
||||
t.Fatalf("sessionID() type = %q, want %q", data.Type, mcperrors.TypeSessionRequired)
|
||||
}
|
||||
if data.Hint == "" {
|
||||
t.Fatal("sessionID() hint = empty, want guidance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUUIDReturnsTypedError(t *testing.T) {
|
||||
_, err := parseUUID("not-a-uuid")
|
||||
if err == nil {
|
||||
t.Fatal("parseUUID() error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireRPCError(t, err)
|
||||
if rpcErr.Code != codeInvalidID {
|
||||
t.Fatalf("parseUUID() code = %d, want %d", rpcErr.Code, codeInvalidID)
|
||||
}
|
||||
if data.Type != mcperrors.TypeInvalidID {
|
||||
t.Fatalf("parseUUID() type = %q, want %q", data.Type, mcperrors.TypeInvalidID)
|
||||
}
|
||||
if data.Field != "id" {
|
||||
t.Fatalf("parseUUID() field = %q, want %q", data.Field, "id")
|
||||
}
|
||||
if data.Value != "not-a-uuid" {
|
||||
t.Fatalf("parseUUID() value = %q, want %q", data.Value, "not-a-uuid")
|
||||
}
|
||||
if data.Detail == "" {
|
||||
t.Fatal("parseUUID() detail = empty, want parse failure detail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrInvalidInputReturnsTypedError(t *testing.T) {
|
||||
err := errInvalidInput("name is required")
|
||||
if err == nil {
|
||||
t.Fatal("errInvalidInput() error = nil, want error")
|
||||
}
|
||||
|
||||
rpcErr, data := requireRPCError(t, err)
|
||||
if rpcErr.Code != jsonrpc.CodeInvalidParams {
|
||||
t.Fatalf("errInvalidInput() code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
|
||||
}
|
||||
if data.Type != mcperrors.TypeInvalidInput {
|
||||
t.Fatalf("errInvalidInput() type = %q, want %q", data.Type, mcperrors.TypeInvalidInput)
|
||||
}
|
||||
}
|
||||
|
||||
func requireRPCError(t *testing.T, err error) (*jsonrpc.Error, mcpErrorData) {
|
||||
t.Helper()
|
||||
|
||||
var rpcErr *jsonrpc.Error
|
||||
if !errors.As(err, &rpcErr) {
|
||||
t.Fatalf("error type = %T, want *jsonrpc.Error", err)
|
||||
}
|
||||
|
||||
var data mcpErrorData
|
||||
if len(rpcErr.Data) > 0 {
|
||||
if unmarshalErr := json.Unmarshal(rpcErr.Data, &data); unmarshalErr != nil {
|
||||
t.Fatalf("unmarshal error data: %v", unmarshalErr)
|
||||
}
|
||||
}
|
||||
|
||||
return rpcErr, data
|
||||
}
|
||||
@@ -35,7 +35,7 @@ type AddHouseholdItemOutput struct {
|
||||
|
||||
func (t *HouseholdTool) AddItem(ctx context.Context, _ *mcp.CallToolRequest, in AddHouseholdItemInput) (*mcp.CallToolResult, AddHouseholdItemOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddHouseholdItemOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddHouseholdItemOutput{}, errRequiredField("name")
|
||||
}
|
||||
if in.Details == nil {
|
||||
in.Details = map[string]any{}
|
||||
@@ -112,7 +112,7 @@ type AddVendorOutput struct {
|
||||
|
||||
func (t *HouseholdTool) AddVendor(ctx context.Context, _ *mcp.CallToolRequest, in AddVendorInput) (*mcp.CallToolResult, AddVendorOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddVendorOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddVendorOutput{}, errRequiredField("name")
|
||||
}
|
||||
vendor, err := t.store.AddVendor(ctx, ext.HouseholdVendor{
|
||||
Name: strings.TrimSpace(in.Name),
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInp
|
||||
}
|
||||
relation := strings.TrimSpace(in.Relation)
|
||||
if relation == "" {
|
||||
return nil, LinkOutput{}, errInvalidInput("relation is required")
|
||||
return nil, LinkOutput{}, errRequiredField("relation")
|
||||
}
|
||||
if _, err := t.store.GetThought(ctx, fromID); err != nil {
|
||||
return nil, LinkOutput{}, err
|
||||
|
||||
@@ -37,7 +37,7 @@ type AddMaintenanceTaskOutput struct {
|
||||
|
||||
func (t *MaintenanceTool) AddTask(ctx context.Context, _ *mcp.CallToolRequest, in AddMaintenanceTaskInput) (*mcp.CallToolResult, AddMaintenanceTaskOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddMaintenanceTaskOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddMaintenanceTaskOutput{}, errRequiredField("name")
|
||||
}
|
||||
priority := strings.TrimSpace(in.Priority)
|
||||
if priority == "" {
|
||||
|
||||
@@ -41,7 +41,7 @@ type AddRecipeOutput struct {
|
||||
|
||||
func (t *MealsTool) AddRecipe(ctx context.Context, _ *mcp.CallToolRequest, in AddRecipeInput) (*mcp.CallToolResult, AddRecipeOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddRecipeOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddRecipeOutput{}, errRequiredField("name")
|
||||
}
|
||||
if in.Ingredients == nil {
|
||||
in.Ingredients = []ext.Ingredient{}
|
||||
@@ -116,7 +116,7 @@ type UpdateRecipeOutput struct {
|
||||
|
||||
func (t *MealsTool) UpdateRecipe(ctx context.Context, _ *mcp.CallToolRequest, in UpdateRecipeInput) (*mcp.CallToolResult, UpdateRecipeOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, UpdateRecipeOutput{}, errInvalidInput("name is required")
|
||||
return nil, UpdateRecipeOutput{}, errRequiredField("name")
|
||||
}
|
||||
if in.Ingredients == nil {
|
||||
in.Ingredients = []ext.Ingredient{}
|
||||
|
||||
@@ -52,7 +52,7 @@ func NewProjectsTool(db *store.DB, sessions *session.ActiveProjects) *ProjectsTo
|
||||
func (t *ProjectsTool) Create(ctx context.Context, _ *mcp.CallToolRequest, in CreateProjectInput) (*mcp.CallToolResult, CreateProjectOutput, error) {
|
||||
name := strings.TrimSpace(in.Name)
|
||||
if name == "" {
|
||||
return nil, CreateProjectOutput{}, errInvalidInput("name is required")
|
||||
return nil, CreateProjectOutput{}, errRequiredField("name")
|
||||
}
|
||||
project, err := t.store.CreateProject(ctx, name, strings.TrimSpace(in.Description))
|
||||
if err != nil {
|
||||
|
||||
@@ -39,7 +39,7 @@ func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfi
|
||||
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
||||
query := strings.TrimSpace(in.Query)
|
||||
if query == "" {
|
||||
return nil, RecallOutput{}, errInvalidInput("query is required")
|
||||
return nil, RecallOutput{}, errRequiredField("query")
|
||||
}
|
||||
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
|
||||
@@ -39,7 +39,7 @@ func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfi
|
||||
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
|
||||
query := strings.TrimSpace(in.Query)
|
||||
if query == "" {
|
||||
return nil, SearchOutput{}, errInvalidInput("query is required")
|
||||
return nil, SearchOutput{}, errRequiredField("query")
|
||||
}
|
||||
|
||||
limit := normalizeLimit(in.Limit, t.search)
|
||||
|
||||
@@ -36,10 +36,10 @@ type AddSkillOutput struct {
|
||||
|
||||
func (t *SkillsTool) AddSkill(ctx context.Context, _ *mcp.CallToolRequest, in AddSkillInput) (*mcp.CallToolResult, AddSkillOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddSkillOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddSkillOutput{}, errRequiredField("name")
|
||||
}
|
||||
if strings.TrimSpace(in.Content) == "" {
|
||||
return nil, AddSkillOutput{}, errInvalidInput("content is required")
|
||||
return nil, AddSkillOutput{}, errRequiredField("content")
|
||||
}
|
||||
if in.Tags == nil {
|
||||
in.Tags = []string{}
|
||||
@@ -110,10 +110,10 @@ type AddGuardrailOutput struct {
|
||||
|
||||
func (t *SkillsTool) AddGuardrail(ctx context.Context, _ *mcp.CallToolRequest, in AddGuardrailInput) (*mcp.CallToolResult, AddGuardrailOutput, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, AddGuardrailOutput{}, errInvalidInput("name is required")
|
||||
return nil, AddGuardrailOutput{}, errRequiredField("name")
|
||||
}
|
||||
if strings.TrimSpace(in.Content) == "" {
|
||||
return nil, AddGuardrailOutput{}, errInvalidInput("content is required")
|
||||
return nil, AddGuardrailOutput{}, errRequiredField("content")
|
||||
}
|
||||
severity := strings.TrimSpace(in.Severity)
|
||||
if severity == "" {
|
||||
@@ -122,7 +122,11 @@ func (t *SkillsTool) AddGuardrail(ctx context.Context, _ *mcp.CallToolRequest, i
|
||||
switch severity {
|
||||
case "low", "medium", "high", "critical":
|
||||
default:
|
||||
return nil, AddGuardrailOutput{}, errInvalidInput("severity must be one of: low, medium, high, critical")
|
||||
return nil, AddGuardrailOutput{}, errInvalidField(
|
||||
"severity",
|
||||
"severity must be one of: low, medium, high, critical",
|
||||
"pass one of: low, medium, high, critical",
|
||||
)
|
||||
}
|
||||
if in.Tags == nil {
|
||||
in.Tags = []string{}
|
||||
@@ -231,7 +235,7 @@ type ListProjectSkillsInput struct {
|
||||
}
|
||||
|
||||
type ListProjectSkillsOutput struct {
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
Skills []ext.AgentSkill `json:"skills"`
|
||||
}
|
||||
|
||||
@@ -302,7 +306,7 @@ type ListProjectGuardrailsInput struct {
|
||||
}
|
||||
|
||||
type ListProjectGuardrailsOutput struct {
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
Guardrails []ext.AgentGuardrail `json:"guardrails"`
|
||||
}
|
||||
|
||||
|
||||
45
internal/tools/version.go
Normal file
45
internal/tools/version.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/buildinfo"
|
||||
)
|
||||
|
||||
type VersionTool struct {
|
||||
serverName string
|
||||
info buildinfo.Info
|
||||
}
|
||||
|
||||
type GetVersionInfoInput struct{}
|
||||
|
||||
type GetVersionInfoOutput struct {
|
||||
ServerName string `json:"server_name"`
|
||||
Version string `json:"version"`
|
||||
TagName string `json:"tag_name"`
|
||||
Commit string `json:"commit"`
|
||||
BuildDate string `json:"build_date"`
|
||||
}
|
||||
|
||||
func NewVersionTool(serverName string, info buildinfo.Info) *VersionTool {
|
||||
return &VersionTool{
|
||||
serverName: serverName,
|
||||
info: info,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *VersionTool) GetInfo(_ context.Context, _ *mcp.CallToolRequest, _ GetVersionInfoInput) (*mcp.CallToolResult, GetVersionInfoOutput, error) {
|
||||
if t == nil {
|
||||
return nil, GetVersionInfoOutput{}, nil
|
||||
}
|
||||
|
||||
return nil, GetVersionInfoOutput{
|
||||
ServerName: t.serverName,
|
||||
Version: t.info.Version,
|
||||
TagName: t.info.TagName,
|
||||
Commit: t.info.Commit,
|
||||
BuildDate: t.info.BuildDate,
|
||||
}, nil
|
||||
}
|
||||
38
internal/tools/version_test.go
Normal file
38
internal/tools/version_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/buildinfo"
|
||||
)
|
||||
|
||||
func TestVersionToolReturnsBuildInformation(t *testing.T) {
|
||||
tool := NewVersionTool("amcs", buildinfo.Info{
|
||||
Version: "v1.2.3",
|
||||
TagName: "v1.2.3",
|
||||
Commit: "abc1234",
|
||||
BuildDate: "2026-03-31T12:34:56Z",
|
||||
})
|
||||
|
||||
_, out, err := tool.GetInfo(context.Background(), nil, GetVersionInfoInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetInfo() error = %v", err)
|
||||
}
|
||||
|
||||
if out.ServerName != "amcs" {
|
||||
t.Fatalf("server_name = %q, want %q", out.ServerName, "amcs")
|
||||
}
|
||||
if out.Version != "v1.2.3" {
|
||||
t.Fatalf("version = %q, want %q", out.Version, "v1.2.3")
|
||||
}
|
||||
if out.TagName != "v1.2.3" {
|
||||
t.Fatalf("tag_name = %q, want %q", out.TagName, "v1.2.3")
|
||||
}
|
||||
if out.Commit != "abc1234" {
|
||||
t.Fatalf("commit = %q, want %q", out.Commit, "abc1234")
|
||||
}
|
||||
if out.BuildDate != "2026-03-31T12:34:56Z" {
|
||||
t.Fatalf("build_date = %q, want %q", out.BuildDate, "2026-03-31T12:34:56Z")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user