Files
amcs/internal/tools/files.go

330 lines
9.0 KiB
Go

package tools
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/session"
"git.warky.dev/wdevs/amcs/internal/store"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// FilesTool implements the MCP file tools (save, load, list) on top of the
// store. sessions is handed to resolveProject on every save and list call,
// presumably so a session's active project can be used when the caller does
// not name one — TODO confirm against resolveProject.
type FilesTool struct {
store *store.DB
sessions *session.ActiveProjects
}
// SaveFileInput is the argument payload of the save tool. ContentBase64 may be
// plain base64 or a base64 data URL; Save splits off a data-URL media type and
// uses it when MediaType is empty.
type SaveFileInput struct {
Name string `json:"name" jsonschema:"file name including extension, for example photo.png or note.pdf"`
ContentBase64 string `json:"content_base64" jsonschema:"file contents encoded as base64"`
MediaType string `json:"media_type,omitempty" jsonschema:"optional MIME type such as image/png, application/pdf, or audio/mpeg"`
Kind string `json:"kind,omitempty" jsonschema:"optional logical type such as image, document, audio, or file"`
ThoughtID string `json:"thought_id,omitempty" jsonschema:"optional thought id to link this file to"`
Project string `json:"project,omitempty" jsonschema:"optional project name or id when saving outside a linked thought"`
}
// SaveFileDecodedInput mirrors SaveFileInput but carries the raw file bytes
// instead of a base64 string. It is the input of SaveDecoded, which Save calls
// after decoding.
type SaveFileDecodedInput struct {
Name string
Content []byte
MediaType string
Kind string
ThoughtID string
Project string
}
// SaveFileOutput returns the stored file record as produced by the store's
// insert.
type SaveFileOutput struct {
File thoughttypes.StoredFile `json:"file"`
}
// LoadFileInput identifies a stored file by its UUID string.
type LoadFileInput struct {
ID string `json:"id" jsonschema:"the stored file id"`
}
// LoadFileOutput returns the stored file plus its content encoded as standard
// (padded) base64.
// NOTE(review): if thoughttypes.StoredFile also serializes its Content field to
// JSON, the bytes are sent twice — confirm the StoredFile JSON tags.
type LoadFileOutput struct {
File thoughttypes.StoredFile `json:"file"`
ContentBase64 string `json:"content_base64"`
}
// ListFilesInput is the argument payload of the list tool. A non-positive
// Limit falls back to 20 and values above 100 are capped at 100
// (see normalizeFileLimit).
type ListFilesInput struct {
Limit int `json:"limit,omitempty" jsonschema:"maximum number of files to return"`
ThoughtID string `json:"thought_id,omitempty" jsonschema:"optional thought id to list files for"`
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the listing"`
Kind string `json:"kind,omitempty" jsonschema:"optional kind filter such as image, document, audio, or file"`
}
// ListFilesOutput holds the stored files matching a listing request.
type ListFilesOutput struct {
Files []thoughttypes.StoredFile `json:"files"`
}
// NewFilesTool builds a FilesTool backed by db that resolves project scope
// through the given session tracker.
func NewFilesTool(db *store.DB, sessions *session.ActiveProjects) *FilesTool {
	tool := new(FilesTool)
	tool.store = db
	tool.sessions = sessions
	return tool
}
// Save is the MCP handler for storing a file sent as base64.
//
// in.ContentBase64 may be a bare base64 string or a "data:<type>;base64,<data>"
// URL; the URL's media type is used only when in.MediaType is blank. The
// decoded bytes are handed to SaveDecoded, which performs validation, project
// resolution, and persistence. Invalid or missing content is reported through
// errInvalidInput.
func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveFileInput) (*mcp.CallToolResult, SaveFileOutput, error) {
	payload, dataURLMediaType := splitDataURL(strings.TrimSpace(in.ContentBase64))
	if payload == "" {
		return nil, SaveFileOutput{}, errInvalidInput("content_base64 is required")
	}

	raw, decodeErr := decodeBase64(payload)
	if decodeErr != nil {
		return nil, SaveFileOutput{}, errInvalidInput("content_base64 must be valid base64")
	}

	decoded := SaveFileDecodedInput{
		Name:      in.Name,
		Content:   raw,
		MediaType: firstNonEmpty(strings.TrimSpace(in.MediaType), dataURLMediaType),
		Kind:      in.Kind,
		ThoughtID: in.ThoughtID,
		Project:   in.Project,
	}
	saved, err := t.SaveDecoded(ctx, req, decoded)
	if err != nil {
		return nil, SaveFileOutput{}, err
	}
	return nil, saved, nil
}
// Load is the MCP handler that fetches a stored file by id and returns it
// together with its content encoded as standard (padded) base64. It performs
// no project scoping: any file id the store can resolve is returned.
func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFileInput) (*mcp.CallToolResult, LoadFileOutput, error) {
	fileID, err := parseUUID(in.ID)
	if err != nil {
		return nil, LoadFileOutput{}, err
	}

	stored, err := t.store.GetStoredFile(ctx, fileID)
	if err != nil {
		return nil, LoadFileOutput{}, err
	}

	var out LoadFileOutput
	out.File = stored
	out.ContentBase64 = base64.StdEncoding.EncodeToString(stored.Content)
	return nil, out, nil
}
// List is the MCP handler that lists stored files, optionally filtered by
// linked thought, project, and kind.
//
// The project is resolved via resolveProject (not required). When a thought
// id is given, the thought's own project — possibly nil — becomes the project
// filter. This matches SaveDecoded, which stores thought-linked files under
// the thought's project. Otherwise a session or explicit project would be
// ANDed with the thought filter and hide files attached to a project-less
// thought. A resolved project that differs from a non-nil thought project is
// rejected with errInvalidInput.
//
// The result is capped by normalizeFileLimit and is always a non-nil slice.
func (t *FilesTool) List(ctx context.Context, req *mcp.CallToolRequest, in ListFilesInput) (*mcp.CallToolResult, ListFilesOutput, error) {
	project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
	if err != nil {
		return nil, ListFilesOutput{}, err
	}

	var thoughtID *uuid.UUID
	projectID := projectIDPtr(project)
	if rawThoughtID := strings.TrimSpace(in.ThoughtID); rawThoughtID != "" {
		parsedThoughtID, err := parseUUID(rawThoughtID)
		if err != nil {
			return nil, ListFilesOutput{}, err
		}
		thought, err := t.store.GetThought(ctx, parsedThoughtID)
		if err != nil {
			return nil, ListFilesOutput{}, err
		}
		if project != nil && thought.ProjectID != nil && *thought.ProjectID != project.ID {
			return nil, ListFilesOutput{}, errInvalidInput("project does not match the linked thought's project")
		}
		thoughtID = &parsedThoughtID
		// The thought's project is authoritative, exactly as in SaveDecoded.
		projectID = thought.ProjectID
	}

	files, err := t.store.ListStoredFiles(ctx, thoughttypes.StoredFileFilter{
		Limit:     normalizeFileLimit(in.Limit),
		ThoughtID: thoughtID,
		ProjectID: projectID,
		Kind:      strings.TrimSpace(in.Kind),
	})
	if err != nil {
		return nil, ListFilesOutput{}, err
	}
	if files == nil {
		// Encode as [] rather than null so clients (and any output-schema
		// validation expecting an array) always see a list.
		files = []thoughttypes.StoredFile{}
	}

	if projectID != nil {
		// Recency bump only; its error is deliberately ignored so it cannot
		// fail an otherwise successful listing.
		_ = t.store.TouchProject(ctx, *projectID)
	}
	return nil, ListFilesOutput{Files: files}, nil
}
// SaveDecoded validates already-decoded file content, persists it via the
// store, and returns the stored record.
//
// Validation: the trimmed name and the content must both be non-empty
// (errInvalidInput otherwise).
//
// Scoping: a project is resolved through resolveProject (not required). When
// in.ThoughtID is set, the thought must exist. Its project (possibly nil) is
// then used for the file, and a resolved project that differs from a non-nil
// thought project is rejected.
//
// The media type falls back to content sniffing and the kind is derived from
// it when not given. A thought-linked file is also registered as a thought
// attachment. The insert and the attachment are separate store calls: if
// attaching fails, the error is returned but the inserted file is not removed
// here.
func (t *FilesTool) SaveDecoded(ctx context.Context, req *mcp.CallToolRequest, in SaveFileDecodedInput) (SaveFileOutput, error) {
	fileName := strings.TrimSpace(in.Name)
	switch {
	case fileName == "":
		return SaveFileOutput{}, errInvalidInput("name is required")
	case len(in.Content) == 0:
		return SaveFileOutput{}, errInvalidInput("decoded file content must not be empty")
	}

	project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
	if err != nil {
		return SaveFileOutput{}, err
	}

	var linkedThought *uuid.UUID
	scopeProject := projectIDPtr(project)
	if raw := strings.TrimSpace(in.ThoughtID); raw != "" {
		id, parseErr := parseUUID(raw)
		if parseErr != nil {
			return SaveFileOutput{}, parseErr
		}
		thought, getErr := t.store.GetThought(ctx, id)
		if getErr != nil {
			return SaveFileOutput{}, getErr
		}
		linkedThought = &id
		scopeProject = thought.ProjectID
		if project != nil && thought.ProjectID != nil && *thought.ProjectID != project.ID {
			return SaveFileOutput{}, errInvalidInput("project does not match the linked thought's project")
		}
	}

	mediaType := normalizeMediaType(strings.TrimSpace(in.MediaType), "", in.Content)
	digest := sha256.Sum256(in.Content)
	created, err := t.store.InsertStoredFile(ctx, thoughttypes.StoredFile{
		Name:      fileName,
		MediaType: mediaType,
		Kind:      normalizeFileKind(strings.TrimSpace(in.Kind), mediaType),
		Encoding:  "base64",
		SizeBytes: int64(len(in.Content)),
		SHA256:    hex.EncodeToString(digest[:]),
		Content:   in.Content,
		ThoughtID: linkedThought,
		ProjectID: scopeProject,
	})
	if err != nil {
		return SaveFileOutput{}, err
	}

	if created.ThoughtID != nil {
		if err := t.store.AddThoughtAttachment(ctx, *created.ThoughtID, thoughtAttachmentFromFile(created)); err != nil {
			return SaveFileOutput{}, err
		}
	}
	if created.ProjectID != nil {
		// Recency bump only; its error is deliberately ignored.
		_ = t.store.TouchProject(ctx, *created.ProjectID)
	}
	return SaveFileOutput{File: created}, nil
}
// thoughtAttachmentFromFile projects a stored file onto the attachment
// metadata recorded on its thought: id, name, media type, kind, size and
// SHA-256. The file content itself is not copied.
func thoughtAttachmentFromFile(file thoughttypes.StoredFile) thoughttypes.ThoughtAttachment {
	var attachment thoughttypes.ThoughtAttachment
	attachment.FileID = file.ID
	attachment.Name = file.Name
	attachment.MediaType = file.MediaType
	attachment.Kind = file.Kind
	attachment.SizeBytes = file.SizeBytes
	attachment.SHA256 = file.SHA256
	return attachment
}
// splitDataURL separates a base64 data URL of the form
// "data:[<mediatype>];base64,<data>" (RFC 2397) into its payload and media
// type.
//
// The "data:" scheme and the ";base64" token are matched case-insensitively
// (URL schemes and ABNF literals are case-insensitive). The header is
// everything up to the first comma, which is the header/data separator.
// Any value that is not a base64 data URL — including non-base64 data URLs —
// is returned unchanged with an empty media type, so the caller's base64
// decoding decides whether it is acceptable.
func splitDataURL(value string) (contentBase64 string, mediaType string) {
	const (
		scheme = "data:"
		marker = ";base64"
	)
	if len(value) < len(scheme) || !strings.EqualFold(value[:len(scheme)], scheme) {
		return value, ""
	}
	header, payload, ok := strings.Cut(value[len(scheme):], ",")
	if !ok || len(header) < len(marker) || !strings.EqualFold(header[len(header)-len(marker):], marker) {
		return value, ""
	}
	return payload, strings.TrimSpace(header[:len(header)-len(marker)])
}
// decodeBase64 decodes value leniently. It first strips ASCII spaces, tabs,
// CR and LF (common in pasted or line-wrapped payloads). It then tries, in
// order: padded standard, unpadded standard, padded URL-safe and unpadded
// URL-safe alphabets. The first successful decoding is returned; if all fail,
// the error from the last attempt is returned with a nil slice.
func decodeBase64(value string) ([]byte, error) {
	compact := strings.NewReplacer(" ", "", "\t", "", "\n", "", "\r", "").Replace(value)

	var err error
	for _, enc := range [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	} {
		var decoded []byte
		if decoded, err = enc.DecodeString(compact); err == nil {
			return decoded, nil
		}
	}
	return nil, err
}
// normalizeMediaType picks the media type for stored content. An explicit
// type wins, then one taken from a data URL, and finally the result of
// http.DetectContentType sniffing the content bytes.
func normalizeMediaType(explicit string, fromDataURL string, content []byte) string {
	if explicit != "" {
		return explicit
	}
	if fromDataURL != "" {
		return fromDataURL
	}
	return http.DetectContentType(content)
}
// firstNonEmpty returns the first argument that is not blank after trimming
// surrounding whitespace, in its trimmed form, or "" if every argument is
// blank (or none are given).
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// normalizeFileKind returns the explicit kind when one is given. Otherwise it
// derives a coarse kind from the media type: "image", "audio", "video",
// "document" (PDF, any text/* type, or a subtype mentioning "document"), or
// "file" for everything else.
//
// MIME type and subtype names are case-insensitive (RFC 2045 §5.1), so the
// media type is lower-cased before matching. Parameters such as
// "; charset=utf-8" are stripped so they cannot defeat the exact PDF match or
// trigger the "document" substring check.
func normalizeFileKind(explicit string, mediaType string) string {
	if explicit != "" {
		return explicit
	}
	mt := strings.ToLower(strings.TrimSpace(mediaType))
	if i := strings.IndexByte(mt, ';'); i >= 0 {
		mt = strings.TrimSpace(mt[:i])
	}
	switch {
	case strings.HasPrefix(mt, "image/"):
		return "image"
	case strings.HasPrefix(mt, "audio/"):
		return "audio"
	case strings.HasPrefix(mt, "video/"):
		return "video"
	case mt == "application/pdf" || strings.HasPrefix(mt, "text/") || strings.Contains(mt, "document"):
		return "document"
	default:
		return "file"
	}
}
// projectIDPtr returns a pointer to the project's ID field, or nil when no
// project was resolved. The pointer aliases project.ID; it is not a copy.
func projectIDPtr(project *thoughttypes.Project) *uuid.UUID {
	var id *uuid.UUID
	if project != nil {
		id = &project.ID
	}
	return id
}
// normalizeFileLimit clamps a requested listing size: non-positive values
// select the default of 20, and anything above 100 is capped at 100.
func normalizeFileLimit(limit int) int {
	const (
		defaultLimit = 20
		maxLimit     = 100
	)
	if limit <= 0 {
		return defaultLimit
	}
	if limit > maxLimit {
		return maxLimit
	}
	return limit
}