Files
amcs/internal/tools/files.go

495 lines
15 KiB
Go

package tools
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
"git.warky.dev/wdevs/amcs/internal/session"
"git.warky.dev/wdevs/amcs/internal/store"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// maxBase64ToolBytes is the maximum length, in bytes, of the whitespace-trimmed
// content_base64 string accepted through the MCP tool interface. Both
// FilesTool.Save and FilesTool.Upload enforce it before decoding, so any
// data-URL prefix counts towards the limit. For larger files use POST /files
// (binary) and pass the returned amcs://files/{id} URI as content_uri instead.
const maxBase64ToolBytes = 10 << 20 // 10 MiB of base64 text, roughly 7.5 MiB once decoded
// FilesTool groups the handlers that save, upload, load, list, and serve
// stored files. Construct it with NewFilesTool.
type FilesTool struct {
// store is the database handle used for stored files, thoughts, and projects.
store *store.DB
// sessions is handed to resolveProject when a handler determines the
// project a request applies to.
sessions *session.ActiveProjects
}
// SaveFileInput is the input of FilesTool.Save. Exactly one of ContentBase64
// and ContentURI must be provided; Save rejects requests that set both or
// neither.
type SaveFileInput struct {
Name string `json:"name" jsonschema:"file name including extension, for example photo.png or note.pdf"`
// ContentBase64 may be plain base64 or a "data:<type>;base64,<payload>" URL;
// its trimmed length is capped at maxBase64ToolBytes.
ContentBase64 string `json:"content_base64,omitempty" jsonschema:"file contents encoded as base64; provide this or content_uri, not both"`
// ContentURI must start with fileURIPrefix and reference an existing stored file.
ContentURI string `json:"content_uri,omitempty" jsonschema:"resource URI of an already-uploaded file, e.g. amcs://files/{id}; use this instead of content_base64 to avoid re-encoding binary content"`
// MediaType, when set, takes precedence over the type taken from the data
// URL or from the referenced file.
MediaType string `json:"media_type,omitempty" jsonschema:"optional MIME type such as image/png, application/pdf, or audio/mpeg"`
Kind string `json:"kind,omitempty" jsonschema:"optional logical type such as image, document, audio, or file"`
ThoughtID string `json:"thought_id,omitempty" jsonschema:"optional thought id to link this file to"`
Project string `json:"project,omitempty" jsonschema:"optional project name or id when saving outside a linked thought"`
}
// SaveFileDecodedInput is the input of FilesTool.SaveDecoded: the same
// metadata as SaveFileInput, but with the file bytes already decoded.
type SaveFileDecodedInput struct {
// Name is required; SaveDecoded rejects a blank name.
Name string
// Content holds the raw file bytes; SaveDecoded rejects empty content.
Content []byte
// MediaType is optional; when blank the type is sniffed from Content.
MediaType string
// Kind is optional; when blank it is derived from the media type.
Kind string
// ThoughtID optionally names a thought to attach the file to.
ThoughtID string
// Project optionally names the project, passed to resolveProject.
Project string
}
// SaveFileOutput is returned by FilesTool.Save and FilesTool.SaveDecoded and
// carries the stored file record as returned by the store on insert.
type SaveFileOutput struct {
File thoughttypes.StoredFile `json:"file"`
}
// LoadFileInput is the input of FilesTool.Load; ID is parsed as a UUID.
type LoadFileInput struct {
ID string `json:"id" jsonschema:"the stored file id"`
}
// LoadFileOutput is the structured output of FilesTool.Load: the stored file
// record plus its content re-encoded with standard, padded base64.
type LoadFileOutput struct {
File thoughttypes.StoredFile `json:"file"`
ContentBase64 string `json:"content_base64"`
}
// ListFilesInput is the input of FilesTool.List. All fields are optional
// filters.
type ListFilesInput struct {
// Limit is normalized by normalizeFileLimit: values <= 0 become 20 and
// values above 100 are capped at 100.
Limit int `json:"limit,omitempty" jsonschema:"maximum number of files to return"`
ThoughtID string `json:"thought_id,omitempty" jsonschema:"optional thought id to list files for"`
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the listing"`
// Kind is whitespace-trimmed and passed to the store filter as-is.
Kind string `json:"kind,omitempty" jsonschema:"optional kind filter such as image, document, audio, or file"`
}
// UploadFileInput is the input of FilesTool.Upload. Exactly one of
// ContentPath and ContentBase64 must be provided; Upload rejects requests
// that set both or neither.
type UploadFileInput struct {
Name string `json:"name" jsonschema:"file name including extension, for example photo.png or note.pdf"`
// ContentPath is read from the server's own filesystem and must be absolute.
ContentPath string `json:"content_path,omitempty" jsonschema:"absolute path to a file on the server; preferred for large files — no base64 overhead"`
// ContentBase64 may be plain base64 or a "data:<type>;base64,<payload>" URL;
// its trimmed length is capped at maxBase64ToolBytes.
ContentBase64 string `json:"content_base64,omitempty" jsonschema:"file contents encoded as base64 (≤10 MB); use content_path for larger files"`
// MediaType, when set, takes precedence over the type taken from a data URL.
MediaType string `json:"media_type,omitempty" jsonschema:"optional MIME type such as image/png, application/pdf, or audio/mpeg"`
Kind string `json:"kind,omitempty" jsonschema:"optional logical type such as image, document, audio, or file"`
ThoughtID string `json:"thought_id,omitempty" jsonschema:"optional thought id to link this file to immediately"`
Project string `json:"project,omitempty" jsonschema:"optional project name or id"`
}
// UploadFileOutput is the structured output of FilesTool.Upload: the stored
// file record and its resource URI (fileURIPrefix followed by the file id).
type UploadFileOutput struct {
File thoughttypes.StoredFile `json:"file"`
URI string `json:"uri" jsonschema:"amcs resource URI for this file, e.g. amcs://files/{id}; pass as content_uri in save_file to link without re-uploading"`
}
// ListFilesOutput is the structured output of FilesTool.List and carries the
// files returned by the store for the requested filter.
type ListFilesOutput struct {
Files []thoughttypes.StoredFile `json:"files"`
}
// NewFilesTool returns a FilesTool that uses db for persistence and sessions
// for project resolution. Neither argument is validated here.
func NewFilesTool(db *store.DB, sessions *session.ActiveProjects) *FilesTool {
	tool := new(FilesTool)
	tool.store = db
	tool.sessions = sessions
	return tool
}
// Upload stores a new file whose bytes come from either a path on the server
// (content_path) or an inline base64 payload (content_base64), and returns the
// stored record together with its resource URI (fileURIPrefix + file id).
//
// Exactly one source must be supplied; setting both or neither yields an
// invalid-input error. content_path must be absolute and must name a regular
// file. content_base64 may be plain base64 or a "data:<type>;base64,<payload>"
// URL and is limited to maxBase64ToolBytes of (trimmed) text. An explicit
// media_type wins over the type taken from a data URL. Persistence, thought
// linking, and project resolution are delegated to SaveDecoded, whose errors
// are returned unchanged.
//
// The *mcp.CallToolResult return value is always nil.
func (t *FilesTool) Upload(ctx context.Context, req *mcp.CallToolRequest, in UploadFileInput) (*mcp.CallToolResult, UploadFileOutput, error) {
	path := strings.TrimSpace(in.ContentPath)
	b64 := strings.TrimSpace(in.ContentBase64)
	if path != "" && b64 != "" {
		return nil, UploadFileOutput{}, errInvalidInput("provide content_path or content_base64, not both")
	}

	var content []byte
	var mediaTypeFromSource string
	if path != "" {
		if !filepath.IsAbs(path) {
			return nil, UploadFileOutput{}, errInvalidInput("content_path must be an absolute path")
		}
		// NOTE(review): content_path lets any caller of this tool read any file
		// the server process can read. Consider restricting it to an
		// allow-listed upload directory — confirm the intended trust model.
		//
		// Refuse anything that is not a regular file before reading it:
		// os.ReadFile reads until EOF, so a device or FIFO (for example
		// /dev/zero) could block forever or exhaust memory.
		info, err := os.Stat(path)
		if err != nil {
			return nil, UploadFileOutput{}, errInvalidInput("cannot read content_path: " + err.Error())
		}
		if !info.Mode().IsRegular() {
			return nil, UploadFileOutput{}, errInvalidInput("content_path must point to a regular file")
		}
		content, err = os.ReadFile(path)
		if err != nil {
			return nil, UploadFileOutput{}, errInvalidInput("cannot read content_path: " + err.Error())
		}
	} else {
		if b64 == "" {
			return nil, UploadFileOutput{}, errInvalidInput("content_path or content_base64 is required")
		}
		// The limit applies to the encoded text, before any decoding work.
		if len(b64) > maxBase64ToolBytes {
			return nil, UploadFileOutput{}, errInvalidInput(
				"content_base64 exceeds the 10 MB MCP tool limit; use content_path instead",
			)
		}
		raw, dataURLMediaType := splitDataURL(b64)
		var err error
		content, err = decodeBase64(raw)
		if err != nil {
			return nil, UploadFileOutput{}, errInvalidInput("content_base64 must be valid base64")
		}
		mediaTypeFromSource = dataURLMediaType
	}

	out, err := t.SaveDecoded(ctx, req, SaveFileDecodedInput{
		Name:      in.Name,
		Content:   content,
		MediaType: firstNonEmpty(strings.TrimSpace(in.MediaType), mediaTypeFromSource),
		Kind:      in.Kind,
		ThoughtID: in.ThoughtID,
		Project:   in.Project,
	})
	if err != nil {
		return nil, UploadFileOutput{}, err
	}
	uri := fileURIPrefix + out.File.ID.String()
	return nil, UploadFileOutput{File: out.File, URI: uri}, nil
}
// Save stores a file whose bytes come from either an inline base64 payload
// (content_base64) or an already-stored file referenced by content_uri.
//
// Supplying both sources, or neither, is an invalid-input error, as is a
// trimmed content_base64 longer than maxBase64ToolBytes. With content_uri the
// referenced file's bytes and media type are read from the store and handed to
// SaveDecoded, which inserts a new stored-file record. An explicit media_type
// wins over the type obtained from the data URL or the referenced file.
//
// The *mcp.CallToolResult return value is always nil.
func (t *FilesTool) Save(ctx context.Context, req *mcp.CallToolRequest, in SaveFileInput) (*mcp.CallToolResult, SaveFileOutput, error) {
	contentURI := strings.TrimSpace(in.ContentURI)
	encoded := strings.TrimSpace(in.ContentBase64)

	switch {
	case contentURI != "" && encoded != "":
		return nil, SaveFileOutput{}, errInvalidInput("provide content_uri or content_base64, not both")
	case len(encoded) > maxBase64ToolBytes:
		return nil, SaveFileOutput{}, errInvalidInput(
			"content_base64 exceeds the 10 MB MCP tool limit; upload the file via POST /files and pass the returned amcs://files/{id} URI as content_uri instead",
		)
	}

	var (
		payload         []byte
		sourceMediaType string
	)
	if contentURI == "" {
		// Inline payload: accept plain base64 or a data URL.
		rawBase64, dataURLMediaType := splitDataURL(encoded)
		if rawBase64 == "" {
			return nil, SaveFileOutput{}, errInvalidInput("content_base64 or content_uri is required")
		}
		decoded, err := decodeBase64(rawBase64)
		if err != nil {
			return nil, SaveFileOutput{}, errInvalidInput("content_base64 must be valid base64")
		}
		payload, sourceMediaType = decoded, dataURLMediaType
	} else {
		// Reference to an existing stored file.
		if !strings.HasPrefix(contentURI, fileURIPrefix) {
			return nil, SaveFileOutput{}, errInvalidInput("content_uri must be an amcs://files/{id} URI")
		}
		fileID, err := parseUUID(strings.TrimPrefix(contentURI, fileURIPrefix))
		if err != nil {
			return nil, SaveFileOutput{}, errInvalidInput("content_uri contains an invalid file id")
		}
		// Every lookup failure is reported to the caller as "does not exist".
		existing, err := t.store.GetStoredFile(ctx, fileID)
		if err != nil {
			return nil, SaveFileOutput{}, errInvalidInput("content_uri references a file that does not exist")
		}
		payload, sourceMediaType = existing.Content, existing.MediaType
	}

	saved, err := t.SaveDecoded(ctx, req, SaveFileDecodedInput{
		Name:      in.Name,
		Content:   payload,
		MediaType: firstNonEmpty(strings.TrimSpace(in.MediaType), sourceMediaType),
		Kind:      in.Kind,
		ThoughtID: in.ThoughtID,
		Project:   in.Project,
	})
	if err != nil {
		return nil, SaveFileOutput{}, err
	}
	return nil, saved, nil
}
// fileURIPrefix is the scheme and path prefix of stored-file resource URIs;
// a file's URI is this prefix followed by the file's id.
const fileURIPrefix = "amcs://files/"
// GetRaw looks up a stored file by its textual id. Surrounding whitespace in
// rawID is ignored. It returns the parse error when rawID is not a valid id,
// otherwise whatever the store returns for the lookup.
func (t *FilesTool) GetRaw(ctx context.Context, rawID string) (thoughttypes.StoredFile, error) {
	fileID, err := parseUUID(strings.TrimSpace(rawID))
	if err == nil {
		return t.store.GetStoredFile(ctx, fileID)
	}
	return thoughttypes.StoredFile{}, err
}
// Load fetches a stored file by id and returns it twice over: as an embedded
// resource (URI, MIME type, raw bytes) in the tool result, and as structured
// output holding the file record plus its content in standard base64.
// The call-tool request itself is not used. Errors from id parsing and from
// the store lookup are returned unchanged.
func (t *FilesTool) Load(ctx context.Context, _ *mcp.CallToolRequest, in LoadFileInput) (*mcp.CallToolResult, LoadFileOutput, error) {
	fileID, err := parseUUID(in.ID)
	if err != nil {
		return nil, LoadFileOutput{}, err
	}
	stored, err := t.store.GetStoredFile(ctx, fileID)
	if err != nil {
		return nil, LoadFileOutput{}, err
	}

	embedded := &mcp.EmbeddedResource{
		Resource: &mcp.ResourceContents{
			URI:      fileURIPrefix + stored.ID.String(),
			MIMEType: stored.MediaType,
			Blob:     stored.Content,
		},
	}
	structured := LoadFileOutput{
		File:          stored,
		ContentBase64: base64.StdEncoding.EncodeToString(stored.Content),
	}
	return &mcp.CallToolResult{Content: []mcp.Content{embedded}}, structured, nil
}
// ReadResource serves a stored file as an MCP resource. The request URI must
// be fileURIPrefix followed by the file id; the result echoes that URI and
// carries the file's media type and raw bytes.
//
// A URI outside the amcs://files/ scheme, an unparsable id, and any store
// lookup failure are all reported as a resource-not-found error for the
// requested URI.
func (t *FilesTool) ReadResource(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
	uri := req.Params.URI
	// Validate the scheme explicitly (as Save does for content_uri) rather than
	// relying on TrimPrefix, which silently passes through a URI that lacks the
	// prefix and would let the id parser see a foreign URI.
	if !strings.HasPrefix(uri, fileURIPrefix) {
		return nil, mcp.ResourceNotFoundError(uri)
	}
	id, err := parseUUID(strings.TrimSpace(strings.TrimPrefix(uri, fileURIPrefix)))
	if err != nil {
		return nil, mcp.ResourceNotFoundError(uri)
	}
	file, err := t.store.GetStoredFile(ctx, id)
	if err != nil {
		return nil, mcp.ResourceNotFoundError(uri)
	}
	return &mcp.ReadResourceResult{
		Contents: []*mcp.ResourceContents{
			{
				URI:      uri,
				MIMEType: file.MediaType,
				Blob:     file.Content,
			},
		},
	}, nil
}
// List returns stored files matching the optional thought, project, and kind
// filters, limited to normalizeFileLimit(in.Limit) entries.
//
// The project is resolved first via resolveProject. When a thought id is
// given the thought must exist; if it belongs to a project, that project must
// match the resolved one, or is adopted as the project filter when none was
// resolved. After a successful listing the project, if any, is touched.
//
// The *mcp.CallToolResult return value is always nil.
func (t *FilesTool) List(ctx context.Context, req *mcp.CallToolRequest, in ListFilesInput) (*mcp.CallToolResult, ListFilesOutput, error) {
	project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
	if err != nil {
		return nil, ListFilesOutput{}, err
	}

	filter := thoughttypes.StoredFileFilter{
		Limit: normalizeFileLimit(in.Limit),
		Kind:  strings.TrimSpace(in.Kind),
	}
	if rawThoughtID := strings.TrimSpace(in.ThoughtID); rawThoughtID != "" {
		linkedID, err := parseUUID(rawThoughtID)
		if err != nil {
			return nil, ListFilesOutput{}, err
		}
		thought, err := t.store.GetThought(ctx, linkedID)
		if err != nil {
			return nil, ListFilesOutput{}, err
		}
		if owner := thought.ProjectID; owner != nil {
			if project == nil {
				// No project resolved: scope the listing to the thought's project.
				project = &thoughttypes.Project{ID: *owner}
			} else if *owner != project.ID {
				return nil, ListFilesOutput{}, errInvalidInput("project does not match the linked thought's project")
			}
		}
		filter.ThoughtID = &linkedID
	}
	// Set last: the thought lookup above may have supplied the project.
	filter.ProjectID = projectIDPtr(project)

	files, err := t.store.ListStoredFiles(ctx, filter)
	if err != nil {
		return nil, ListFilesOutput{}, err
	}
	if project != nil {
		// The touch error is deliberately discarded so it cannot fail the listing.
		_ = t.store.TouchProject(ctx, project.ID)
	}
	return nil, ListFilesOutput{Files: files}, nil
}
// SaveDecoded validates and persists already-decoded file bytes.
//
// The trimmed name and the content must be non-empty. The project is resolved
// via resolveProject; when a thought id is given the thought must exist and,
// if both it and the request carry a project, the two must match. A linked
// thought's project always replaces the resolved one on the stored record,
// including when the thought has no project (the record's ProjectID is then
// nil). Media type falls back to content sniffing and kind is derived from the
// media type when not given explicitly. The record stores the content's size
// and hex-encoded SHA-256.
//
// After the insert, the file is added to the linked thought's attachments and
// the record's project is touched. The insert and the attachment are separate
// store calls, so an attachment error is returned after the file record has
// already been inserted.
func (t *FilesTool) SaveDecoded(ctx context.Context, req *mcp.CallToolRequest, in SaveFileDecodedInput) (SaveFileOutput, error) {
	fileName := strings.TrimSpace(in.Name)
	switch {
	case fileName == "":
		return SaveFileOutput{}, errInvalidInput("name is required")
	case len(in.Content) == 0:
		return SaveFileOutput{}, errInvalidInput("decoded file content must not be empty")
	}

	project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
	if err != nil {
		return SaveFileOutput{}, err
	}

	record := thoughttypes.StoredFile{
		Name:      fileName,
		Encoding:  "base64",
		SizeBytes: int64(len(in.Content)),
		Content:   in.Content,
		ProjectID: projectIDPtr(project),
	}
	if rawThoughtID := strings.TrimSpace(in.ThoughtID); rawThoughtID != "" {
		linkedID, err := parseUUID(rawThoughtID)
		if err != nil {
			return SaveFileOutput{}, err
		}
		thought, err := t.store.GetThought(ctx, linkedID)
		if err != nil {
			return SaveFileOutput{}, err
		}
		if project != nil && thought.ProjectID != nil && *thought.ProjectID != project.ID {
			return SaveFileOutput{}, errInvalidInput("project does not match the linked thought's project")
		}
		record.ThoughtID = &linkedID
		// The file follows the thought's project, even if that is nil.
		record.ProjectID = thought.ProjectID
	}

	record.MediaType = normalizeMediaType(strings.TrimSpace(in.MediaType), "", in.Content)
	record.Kind = normalizeFileKind(strings.TrimSpace(in.Kind), record.MediaType)
	digest := sha256.Sum256(in.Content)
	record.SHA256 = hex.EncodeToString(digest[:])

	created, err := t.store.InsertStoredFile(ctx, record)
	if err != nil {
		return SaveFileOutput{}, err
	}
	if linked := created.ThoughtID; linked != nil {
		if err := t.store.AddThoughtAttachment(ctx, *linked, thoughtAttachmentFromFile(created)); err != nil {
			return SaveFileOutput{}, err
		}
	}
	if created.ProjectID != nil {
		// The touch error is deliberately discarded so it cannot fail the save.
		_ = t.store.TouchProject(ctx, *created.ProjectID)
	}
	return SaveFileOutput{File: created}, nil
}
// thoughtAttachmentFromFile builds the attachment entry recorded on a thought
// for a stored file: the file's id, name, media type, kind, size, and SHA-256.
// The file content itself is not copied into the attachment.
func thoughtAttachmentFromFile(file thoughttypes.StoredFile) thoughttypes.ThoughtAttachment {
	var attachment thoughttypes.ThoughtAttachment
	attachment.FileID = file.ID
	attachment.Name = file.Name
	attachment.MediaType = file.MediaType
	attachment.Kind = file.Kind
	attachment.SizeBytes = file.SizeBytes
	attachment.SHA256 = file.SHA256
	return attachment
}
// splitDataURL separates a "data:<media type>;base64,<payload>" URL into its
// payload and its whitespace-trimmed media type. Any value that does not start
// with "data:" or lacks the ";base64," marker is returned unchanged with an
// empty media type. Matching is case-sensitive and the split happens at the
// first occurrence of the marker.
func splitDataURL(value string) (contentBase64 string, mediaType string) {
	const (
		scheme    = "data:"
		base64Tag = ";base64,"
	)
	if !strings.HasPrefix(value, scheme) {
		return value, ""
	}
	tagAt := strings.Index(value, base64Tag)
	if tagAt < 0 {
		return value, ""
	}
	// The marker starts with ';', which cannot occur inside the scheme, so
	// tagAt is never smaller than len(scheme).
	header := value[len(scheme):tagAt]
	return value[tagAt+len(base64Tag):], strings.TrimSpace(header)
}
// decodeBase64 leniently decodes base64 text. Spaces, tabs, and line breaks
// are removed first. The cleaned text is then tried against the standard and
// URL-safe alphabets, each with and without padding, in that order; if all
// fail and the text ends in '=' characters, the same four decoders are tried
// again with the trailing '=' stripped. The first successful decode wins;
// otherwise the error from the last attempt is returned.
func decodeBase64(value string) ([]byte, error) {
	var compact strings.Builder
	compact.Grow(len(value))
	for _, r := range value {
		if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
			continue
		}
		compact.WriteRune(r)
	}
	cleaned := compact.String()

	alphabets := [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	}
	// tryAll returns the first successful decode of candidate, or the error
	// produced by the last alphabet tried.
	tryAll := func(candidate string) ([]byte, error) {
		var lastErr error
		for _, alphabet := range alphabets {
			decoded, err := alphabet.DecodeString(candidate)
			if err == nil {
				return decoded, nil
			}
			lastErr = err
		}
		return nil, lastErr
	}

	decoded, err := tryAll(cleaned)
	if err == nil {
		return decoded, nil
	}
	// Second chance for inputs with too much or malformed trailing padding.
	if unpadded := strings.TrimRight(cleaned, "="); unpadded != cleaned && unpadded != "" {
		return tryAll(unpadded)
	}
	return nil, err
}
// normalizeMediaType picks the media type for a file: the explicit value if
// non-empty, else the one taken from a data URL if non-empty, else the type
// sniffed from content by http.DetectContentType. Inputs are used verbatim;
// no trimming or case folding is applied here.
func normalizeMediaType(explicit string, fromDataURL string, content []byte) string {
	for _, declared := range [...]string{explicit, fromDataURL} {
		if declared != "" {
			return declared
		}
	}
	return http.DetectContentType(content)
}
// firstNonEmpty returns the first value that is non-empty after trimming
// surrounding whitespace, in its trimmed form, or "" when there is none.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// normalizeFileKind returns the logical kind of a file. A non-empty explicit
// kind is returned unchanged. Otherwise the kind is derived from mediaType:
// "image", "audio", or "video" for the matching top-level types; "document"
// for application/pdf, any text/* type, or any type containing "document"
// (for example the OpenXML office types); and "file" for everything else.
//
// Media types are matched case-insensitively and with surrounding whitespace
// ignored, since type and subtype names are case-insensitive. Parameters such
// as "; charset=binary" do not prevent application/pdf from being recognized.
func normalizeFileKind(explicit string, mediaType string) string {
	if explicit != "" {
		return explicit
	}
	normalized := strings.ToLower(strings.TrimSpace(mediaType))
	// essence is the bare type/subtype with any parameters removed.
	essence, _, _ := strings.Cut(normalized, ";")
	essence = strings.TrimSpace(essence)
	switch {
	case strings.HasPrefix(normalized, "image/"):
		return "image"
	case strings.HasPrefix(normalized, "audio/"):
		return "audio"
	case strings.HasPrefix(normalized, "video/"):
		return "video"
	case essence == "application/pdf" || strings.HasPrefix(normalized, "text/") || strings.Contains(normalized, "document"):
		return "document"
	default:
		return "file"
	}
}
// projectIDPtr returns a pointer to project's ID field, or nil when project
// is nil. The pointer refers to the field inside *project, not to a copy.
func projectIDPtr(project *thoughttypes.Project) *uuid.UUID {
	if project != nil {
		return &project.ID
	}
	return nil
}
// normalizeFileLimit clamps a requested listing size: zero or negative values
// select the default of 20, values above 100 are capped at 100, and anything
// in between is returned as given.
func normalizeFileLimit(limit int) int {
	const (
		defaultLimit = 20
		maxLimit     = 100
	)
	if limit <= 0 {
		return defaultLimit
	}
	if limit > maxLimit {
		return maxLimit
	}
	return limit
}
}