// Listing metadata: 330 lines, 9.0 KiB, Go.

package tools
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/session"
|
|
"git.warky.dev/wdevs/amcs/internal/store"
|
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
|
)
|
|
|
|
type FilesTool struct {
|
|
store *store.DB
|
|
sessions *session.ActiveProjects
|
|
}
|
|
|
|
// SaveFileInput is the tool input accepted by FilesTool.Save.
//
// ContentBase64 carries the file bytes as base64; Save also accepts a
// "data:<media type>;base64,<payload>" URL in this field. MediaType and Kind
// are optional and are derived from the content when left empty. ThoughtID
// optionally links the file to a thought, and Project names the project to
// use when the file is saved outside a linked thought.
type SaveFileInput struct {
	Name          string `json:"name" jsonschema:"file name including extension, for example photo.png or note.pdf"`
	ContentBase64 string `json:"content_base64" jsonschema:"file contents encoded as base64"`
	MediaType     string `json:"media_type,omitempty" jsonschema:"optional MIME type such as image/png, application/pdf, or audio/mpeg"`
	Kind          string `json:"kind,omitempty" jsonschema:"optional logical type such as image, document, audio, or file"`
	ThoughtID     string `json:"thought_id,omitempty" jsonschema:"optional thought id to link this file to"`
	Project       string `json:"project,omitempty" jsonschema:"optional project name or id when saving outside a linked thought"`
}
|
|
|
|
// SaveFileDecodedInput is the input of FilesTool.SaveDecoded. It mirrors
// SaveFileInput, except that Content already holds the raw (decoded) file
// bytes instead of a base64 string.
type SaveFileDecodedInput struct {
	Name      string
	Content   []byte
	MediaType string
	Kind      string
	ThoughtID string
	Project   string
}
|
|
|
|
type SaveFileOutput struct {
|
|
File thoughttypes.StoredFile `json:"file"`
|
|
}
|
|
|
|
// LoadFileInput is the tool input accepted by FilesTool.Load. ID is the stored
// file id; Load parses it with parseUUID.
type LoadFileInput struct {
	ID string `json:"id" jsonschema:"the stored file id"`
}
|
|
|
|
type LoadFileOutput struct {
|
|
File thoughttypes.StoredFile `json:"file"`
|
|
ContentBase64 string `json:"content_base64"`
|
|
}
|
|
|
|
// ListFilesInput is the tool input accepted by FilesTool.List. All fields are
// optional filters; Limit is clamped by normalizeFileLimit before it reaches
// the store.
type ListFilesInput struct {
	Limit     int    `json:"limit,omitempty" jsonschema:"maximum number of files to return"`
	ThoughtID string `json:"thought_id,omitempty" jsonschema:"optional thought id to list files for"`
	Project   string `json:"project,omitempty" jsonschema:"optional project name or id to scope the listing"`
	Kind      string `json:"kind,omitempty" jsonschema:"optional kind filter such as image, document, audio, or file"`
}
|
|
|
|
type ListFilesOutput struct {
|
|
Files []thoughttypes.StoredFile `json:"files"`
|
|
}
|
|
|
|
func NewFilesTool(db *store.DB, sessions *session.ActiveProjects) *FilesTool {
|
|
return &FilesTool{store: db, sessions: sessions}
|
|
}
|
|
|
|
func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveFileInput) (*mcp.CallToolResult, SaveFileOutput, error) {
|
|
contentBase64, mediaTypeFromDataURL := splitDataURL(strings.TrimSpace(in.ContentBase64))
|
|
if contentBase64 == "" {
|
|
return nil, SaveFileOutput{}, errInvalidInput("content_base64 is required")
|
|
}
|
|
|
|
content, err := decodeBase64(contentBase64)
|
|
if err != nil {
|
|
return nil, SaveFileOutput{}, errInvalidInput("content_base64 must be valid base64")
|
|
}
|
|
out, err := t.SaveDecoded(ctx, req, SaveFileDecodedInput{
|
|
Name: in.Name,
|
|
Content: content,
|
|
MediaType: firstNonEmpty(strings.TrimSpace(in.MediaType), mediaTypeFromDataURL),
|
|
Kind: in.Kind,
|
|
ThoughtID: in.ThoughtID,
|
|
Project: in.Project,
|
|
})
|
|
if err != nil {
|
|
return nil, SaveFileOutput{}, err
|
|
}
|
|
return nil, out, nil
|
|
}
|
|
|
|
func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFileInput) (*mcp.CallToolResult, LoadFileOutput, error) {
|
|
id, err := parseUUID(in.ID)
|
|
if err != nil {
|
|
return nil, LoadFileOutput{}, err
|
|
}
|
|
|
|
file, err := t.store.GetStoredFile(ctx, id)
|
|
if err != nil {
|
|
return nil, LoadFileOutput{}, err
|
|
}
|
|
|
|
return nil, LoadFileOutput{
|
|
File: file,
|
|
ContentBase64: base64.StdEncoding.EncodeToString(file.Content),
|
|
}, nil
|
|
}
|
|
|
|
func (t *FilesTool) List(ctx context.Context, req *mcp.CallToolRequest, in ListFilesInput) (*mcp.CallToolResult, ListFilesOutput, error) {
|
|
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
|
if err != nil {
|
|
return nil, ListFilesOutput{}, err
|
|
}
|
|
|
|
var thoughtID *uuid.UUID
|
|
if rawThoughtID := strings.TrimSpace(in.ThoughtID); rawThoughtID != "" {
|
|
parsedThoughtID, err := parseUUID(rawThoughtID)
|
|
if err != nil {
|
|
return nil, ListFilesOutput{}, err
|
|
}
|
|
thought, err := t.store.GetThought(ctx, parsedThoughtID)
|
|
if err != nil {
|
|
return nil, ListFilesOutput{}, err
|
|
}
|
|
thoughtID = &parsedThoughtID
|
|
if project != nil && thought.ProjectID != nil && *thought.ProjectID != project.ID {
|
|
return nil, ListFilesOutput{}, errInvalidInput("project does not match the linked thought's project")
|
|
}
|
|
if project == nil && thought.ProjectID != nil {
|
|
project = &thoughttypes.Project{ID: *thought.ProjectID}
|
|
}
|
|
}
|
|
|
|
files, err := t.store.ListStoredFiles(ctx, thoughttypes.StoredFileFilter{
|
|
Limit: normalizeFileLimit(in.Limit),
|
|
ThoughtID: thoughtID,
|
|
ProjectID: projectIDPtr(project),
|
|
Kind: strings.TrimSpace(in.Kind),
|
|
})
|
|
if err != nil {
|
|
return nil, ListFilesOutput{}, err
|
|
}
|
|
if project != nil {
|
|
_ = t.store.TouchProject(ctx, project.ID)
|
|
}
|
|
|
|
return nil, ListFilesOutput{Files: files}, nil
|
|
}
|
|
|
|
func (t *FilesTool) SaveDecoded(ctx context.Context, req *mcp.CallToolRequest, in SaveFileDecodedInput) (SaveFileOutput, error) {
|
|
name := strings.TrimSpace(in.Name)
|
|
if name == "" {
|
|
return SaveFileOutput{}, errInvalidInput("name is required")
|
|
}
|
|
if len(in.Content) == 0 {
|
|
return SaveFileOutput{}, errInvalidInput("decoded file content must not be empty")
|
|
}
|
|
|
|
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
|
if err != nil {
|
|
return SaveFileOutput{}, err
|
|
}
|
|
|
|
var thoughtID *uuid.UUID
|
|
var projectID = projectIDPtr(project)
|
|
if rawThoughtID := strings.TrimSpace(in.ThoughtID); rawThoughtID != "" {
|
|
parsedThoughtID, err := parseUUID(rawThoughtID)
|
|
if err != nil {
|
|
return SaveFileOutput{}, err
|
|
}
|
|
thought, err := t.store.GetThought(ctx, parsedThoughtID)
|
|
if err != nil {
|
|
return SaveFileOutput{}, err
|
|
}
|
|
thoughtID = &parsedThoughtID
|
|
projectID = thought.ProjectID
|
|
if project != nil && thought.ProjectID != nil && *thought.ProjectID != project.ID {
|
|
return SaveFileOutput{}, errInvalidInput("project does not match the linked thought's project")
|
|
}
|
|
}
|
|
|
|
mediaType := normalizeMediaType(strings.TrimSpace(in.MediaType), "", in.Content)
|
|
kind := normalizeFileKind(strings.TrimSpace(in.Kind), mediaType)
|
|
sum := sha256.Sum256(in.Content)
|
|
|
|
file := thoughttypes.StoredFile{
|
|
Name: name,
|
|
MediaType: mediaType,
|
|
Kind: kind,
|
|
Encoding: "base64",
|
|
SizeBytes: int64(len(in.Content)),
|
|
SHA256: hex.EncodeToString(sum[:]),
|
|
Content: in.Content,
|
|
ProjectID: projectID,
|
|
}
|
|
if thoughtID != nil {
|
|
file.ThoughtID = thoughtID
|
|
}
|
|
|
|
created, err := t.store.InsertStoredFile(ctx, file)
|
|
if err != nil {
|
|
return SaveFileOutput{}, err
|
|
}
|
|
|
|
if created.ThoughtID != nil {
|
|
if err := t.store.AddThoughtAttachment(ctx, *created.ThoughtID, thoughtAttachmentFromFile(created)); err != nil {
|
|
return SaveFileOutput{}, err
|
|
}
|
|
}
|
|
if created.ProjectID != nil {
|
|
_ = t.store.TouchProject(ctx, *created.ProjectID)
|
|
}
|
|
|
|
return SaveFileOutput{File: created}, nil
|
|
}
|
|
|
|
func thoughtAttachmentFromFile(file thoughttypes.StoredFile) thoughttypes.ThoughtAttachment {
|
|
return thoughttypes.ThoughtAttachment{
|
|
FileID: file.ID,
|
|
Name: file.Name,
|
|
MediaType: file.MediaType,
|
|
Kind: file.Kind,
|
|
SizeBytes: file.SizeBytes,
|
|
SHA256: file.SHA256,
|
|
}
|
|
}
|
|
|
|
// splitDataURL separates a "data:<media type>;base64,<payload>" URL into its
// base64 payload and its (trimmed) media type.
//
// Values that do not start with "data:" or that lack the ";base64," marker are
// returned unchanged with an empty media type, so plain base64 input passes
// straight through. Everything between "data:" and the marker is reported as
// the media type, including parameters such as ";charset=utf-8".
func splitDataURL(value string) (contentBase64 string, mediaType string) {
	const (
		scheme = "data:"
		marker = ";base64,"
	)

	if !strings.HasPrefix(value, scheme) {
		return value, ""
	}

	header, payload, found := strings.Cut(value, marker)
	if !found {
		return value, ""
	}

	return payload, strings.TrimSpace(strings.TrimPrefix(header, scheme))
}
|
|
|
|
// decodeBase64 decodes value leniently: spaces, tabs, carriage returns and
// newlines are removed first, then the standard and URL-safe alphabets are
// tried, each with and without padding. The first successful decoding wins.
//
// If every variant fails, the error of the last attempt (unpadded URL-safe)
// is returned.
func decodeBase64(value string) ([]byte, error) {
	stripWhitespace := func(r rune) rune {
		switch r {
		case ' ', '\t', '\n', '\r':
			return -1 // strings.Map drops runes mapped to a negative value
		default:
			return r
		}
	}
	cleaned := strings.Map(stripWhitespace, value)

	candidates := [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	}

	var lastErr error
	for _, enc := range candidates {
		decoded, err := enc.DecodeString(cleaned)
		if err == nil {
			return decoded, nil
		}
		lastErr = err
	}
	return nil, lastErr
}
|
|
|
|
// normalizeMediaType picks the media type for a file. Precedence: the explicit
// value, then the value taken from a data URL, and finally content sniffing
// via http.DetectContentType. The inputs are used as given; callers are
// expected to trim them beforehand.
func normalizeMediaType(explicit string, fromDataURL string, content []byte) string {
	if explicit != "" {
		return explicit
	}
	if fromDataURL != "" {
		return fromDataURL
	}
	return http.DetectContentType(content)
}
|
|
|
|
// firstNonEmpty returns the first value that is non-empty after trimming
// surrounding whitespace, in its trimmed form. It returns "" when no such
// value exists (including when called without arguments).
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		// Trim once and reuse the result for both the test and the return.
		if trimmed := strings.TrimSpace(value); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
|
|
|
// normalizeFileKind returns the logical kind of a file. A non-empty explicit
// kind is returned unchanged; otherwise the kind is derived from mediaType:
//
//   - "image/…"  -> "image"
//   - "audio/…"  -> "audio"
//   - "video/…"  -> "video"
//   - "application/pdf", "text/…", or any type containing "document" -> "document"
//   - anything else -> "file"
//
// Media type matching is case-insensitive, so "IMAGE/PNG" is classified the
// same way as "image/png".
func normalizeFileKind(explicit string, mediaType string) string {
	if explicit != "" {
		return explicit
	}

	// Media type names are case-insensitive; compare on a lower-cased copy.
	mt := strings.ToLower(mediaType)

	switch {
	case strings.HasPrefix(mt, "image/"):
		return "image"
	case strings.HasPrefix(mt, "audio/"):
		return "audio"
	case strings.HasPrefix(mt, "video/"):
		return "video"
	case mt == "application/pdf" || strings.HasPrefix(mt, "text/") || strings.Contains(mt, "document"):
		return "document"
	default:
		return "file"
	}
}
|
|
|
|
func projectIDPtr(project *thoughttypes.Project) *uuid.UUID {
|
|
if project == nil {
|
|
return nil
|
|
}
|
|
return &project.ID
|
|
}
|
|
|
|
// normalizeFileLimit clamps a requested page size: zero or negative values
// fall back to the default of 20, values above 100 are capped at 100, and
// everything in between is returned unchanged.
func normalizeFileLimit(limit int) int {
	const (
		defaultLimit = 20
		maxLimit     = 100
	)

	if limit <= 0 {
		return defaultLimit
	}
	if limit > maxLimit {
		return maxLimit
	}
	return limit
}
|