From 66370a7f0ea6bd30998125ba5f784de1570d8860 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 24 Mar 2026 15:38:59 +0200 Subject: [PATCH] feat(tools): implement CRUD operations for thoughts and projects * Add tools for creating, retrieving, updating, and deleting thoughts. * Implement project management tools for creating and listing projects. * Introduce linking functionality between thoughts. * Add search and recall capabilities for thoughts based on semantic queries. * Implement statistics and summarization tools for thought analysis. * Create database migrations for thoughts, projects, and links. * Add helper functions for UUID parsing and project resolution. --- .dockerignore | 8 + .gitignore | 4 + .vscode/launch.json | 13 + .vscode/tasks.json | 20 ++ Dockerfile | 30 ++ Makefile | 17 + README.md | 26 ++ cmd/amcs-server/main.go | 24 ++ configs/config.example.yaml | 72 ++++ configs/dev.yaml | 72 ++++ configs/docker.yaml | 72 ++++ docker-compose.docker.yml | 5 + docker-compose.yml | 38 +++ docker/postgres/init/00_apply_migrations.sh | 15 + go.mod | 28 ++ go.sum | 99 ++++++ internal/ai/compat/client.go | 330 +++++++++++++++++++ internal/ai/compat/client_test.go | 91 ++++++ internal/ai/factory.go | 22 ++ internal/ai/litellm/client.go | 24 ++ internal/ai/openrouter/client.go | 35 ++ internal/ai/provider.go | 14 + internal/app/app.go | 145 ++++++++ internal/auth/keyring.go | 29 ++ internal/auth/keyring_test.go | 107 ++++++ internal/auth/middleware.go | 49 +++ internal/config/config.go | 119 +++++++ internal/config/loader.go | 113 +++++++ internal/config/loader_test.go | 71 ++++ internal/config/validate.go | 71 ++++ internal/config/validate_test.go | 60 ++++ internal/mcpserver/server.go | 126 +++++++ internal/metadata/normalize.go | 155 +++++++++ internal/metadata/normalize_test.go | 81 +++++ internal/observability/http.go | 110 +++++++ internal/observability/http_test.go | 59 ++++ internal/observability/logger.go | 52 +++ internal/session/active_project.go | 37 +++ internal/session/active_project_test.go | 27 ++ internal/store/db.go | 118 +++++++ internal/store/links.go | 69 ++++ internal/store/projects.go | 90 +++++ internal/store/thoughts.go | 345 ++++++++++++++++++++ internal/tools/archive.go | 36 ++ internal/tools/capture.go | 95 ++++++ internal/tools/common.go | 31 ++ internal/tools/context.go | 111 +++++++ internal/tools/delete.go | 36 ++ internal/tools/get.go | 40 +++ internal/tools/helpers.go | 87 +++++ internal/tools/links.go | 145 ++++++++ internal/tools/list.go | 68 ++++ internal/tools/projects.go | 92 ++++++ internal/tools/recall.go | 111 +++++++ internal/tools/search.go | 69 ++++ internal/tools/stats.go | 33 ++ internal/tools/summarize.go | 93 ++++++ internal/tools/update.go | 88 +++++ internal/types/project.go | 40 +++ internal/types/thought.go | 57 ++++ migrations/001_enable_vector.sql | 2 + migrations/002_create_thoughts.sql | 17 + migrations/003_add_projects.sql | 13 + migrations/004_create_thought_links.sql | 10 + migrations/005_create_match_thoughts.sql | 31 ++ migrations/006_rls_and_grants.sql | 5 + scripts/migrate.sh | 15 + scripts/run-local.sh | 5 + 68 files changed, 4422 insertions(+) create mode 100644 .dockerignore create mode 100644 .vscode/launch.json create mode 100644 .vscode/tasks.json create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 cmd/amcs-server/main.go create mode 100644 configs/config.example.yaml create mode 100644 configs/dev.yaml create mode 100644 configs/docker.yaml create mode 100644 docker-compose.docker.yml create mode 100644 docker-compose.yml create mode 100755 docker/postgres/init/00_apply_migrations.sh create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/ai/compat/client.go create mode 100644 internal/ai/compat/client_test.go create mode 100644 internal/ai/factory.go create mode 100644 internal/ai/litellm/client.go create mode 100644 internal/ai/openrouter/client.go create mode 100644 internal/ai/provider.go create mode 100644 internal/app/app.go create mode 100644 internal/auth/keyring.go create mode 100644 internal/auth/keyring_test.go create mode 100644 internal/auth/middleware.go create mode 100644 internal/config/config.go create mode 100644 internal/config/loader.go create mode 100644 internal/config/loader_test.go create mode 100644 internal/config/validate.go create mode 100644 internal/config/validate_test.go create mode 100644 internal/mcpserver/server.go create mode 100644 internal/metadata/normalize.go create mode 100644 internal/metadata/normalize_test.go create mode 100644 internal/observability/http.go create mode 100644 internal/observability/http_test.go create mode 100644 internal/observability/logger.go create mode 100644 internal/session/active_project.go create mode 100644 internal/session/active_project_test.go create mode 100644 internal/store/db.go create mode 100644 internal/store/links.go create mode 100644 internal/store/projects.go create mode 100644 internal/store/thoughts.go create mode 100644 internal/tools/archive.go create mode 100644 internal/tools/capture.go create mode 100644 internal/tools/common.go create mode 100644 internal/tools/context.go create mode 100644 internal/tools/delete.go create mode 100644 internal/tools/get.go create mode 100644 internal/tools/helpers.go create mode 100644 internal/tools/links.go create mode 100644 internal/tools/list.go create mode 100644 internal/tools/projects.go create mode 100644 internal/tools/recall.go create mode 100644 internal/tools/search.go create mode 100644 internal/tools/stats.go create mode 100644 internal/tools/summarize.go create mode 100644 internal/tools/update.go create mode 100644 internal/types/project.go create mode 100644 internal/types/thought.go create mode 100644 migrations/001_enable_vector.sql create mode 100644 migrations/002_create_thoughts.sql create mode 100644 migrations/003_add_projects.sql create mode 100644 migrations/004_create_thought_links.sql create mode 100644 migrations/005_create_match_thoughts.sql create mode 100644 migrations/006_rls_and_grants.sql create mode 100755 scripts/migrate.sh create mode 100755 scripts/run-local.sh diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..3dfcfb1 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.git +.gitignore +.vscode +bin +.env +assets +llm +*.local.yaml diff --git a/.gitignore b/.gitignore index 5b90e79..8d82007 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,7 @@ go.work.sum # env file .env +# local config +configs/*.local.yaml +cmd/amcs-server/__debug_* +bin/ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..6e570d2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,13 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug amcs-server", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/cmd/amcs-server", + "cwd": "${workspaceFolder}/bin" + } + ] +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..25dfd73 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,20 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "build", + "type": "shell", + "command": "make build", + + "group": { + "kind": "build", + "isDefault": true + }, + "presentation": { + "reveal": "silent", + "panel": "shared" + }, + "problemMatcher": "$go" + } + ] +} diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..7fd5f9a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,30 @@ +FROM golang:1.26.1-bookworm AS builder + +WORKDIR /src + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . + +RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /out/amcs-server ./cmd/amcs-server + +FROM debian:bookworm-slim + +RUN apt-get update \ + && apt-get install -y --no-install-recommends ca-certificates \ + && rm -rf /var/lib/apt/lists/* \ + && useradd --system --create-home --uid 10001 appuser + +WORKDIR /app + +COPY --from=builder /out/amcs-server /app/amcs-server +COPY --chown=appuser:appuser configs /app/configs + +USER appuser + +EXPOSE 8080 + +ENV OB1_CONFIG=/app/configs/docker.yaml + +ENTRYPOINT ["/app/amcs-server"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..abec7bf --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +BIN_DIR := bin +SERVER_BIN := $(BIN_DIR)/amcs-server +CMD_SERVER := ./cmd/amcs-server + +.PHONY: all build clean migrate + +all: build + +build: + @mkdir -p $(BIN_DIR) + go build -o $(SERVER_BIN) $(CMD_SERVER) + +migrate: + ./scripts/migrate.sh + +clean: + rm -rf $(BIN_DIR) diff --git a/README.md b/README.md index e8f488c..703ea57 100644 --- a/README.md +++ b/README.md @@ -50,3 +50,29 @@ Config is YAML-driven. Copy `configs/config.example.yaml` and set: - `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy See `llm/plan.md` for full architecture and implementation plan. + +## Development + +Run the SQL migrations against a local database with: + +`DATABASE_URL=postgres://... make migrate` + +## Containers + +The repo now includes a `Dockerfile` and Compose files for running the app with Postgres + pgvector. + +1. Set a real LiteLLM key in your shell: + `export OB1_LITELLM_API_KEY=your-key` +2. Start the stack with your runtime: + `docker compose -f docker-compose.yml -f docker-compose.docker.yml up --build` + `podman compose -f docker-compose.yml up --build` +3. Call the service on `http://localhost:8080` + +Notes: + +- The app uses `configs/docker.yaml` inside the container. +- `OB1_LITELLM_BASE_URL` overrides the LiteLLM endpoint, so you can retarget it without editing YAML. +- The base Compose file uses `host.containers.internal`, which is Podman-friendly. +- The Docker override file adds `host-gateway` aliases so Docker can resolve the same host endpoint. +- Database migrations `001` through `005` run automatically when the Postgres volume is created for the first time. +- `migrations/006_rls_and_grants.sql` is intentionally skipped during container bootstrap because it contains deployment-specific grants for a role named `amcs_user`. diff --git a/cmd/amcs-server/main.go b/cmd/amcs-server/main.go new file mode 100644 index 0000000..3a7da27 --- /dev/null +++ b/cmd/amcs-server/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "context" + "flag" + "log" + "os/signal" + "syscall" + + "git.warky.dev/wdevs/amcs/internal/app" +) + +func main() { + var configPath string + flag.StringVar(&configPath, "config", "", "Path to the YAML config file") + flag.Parse() + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + if err := app.Run(ctx, configPath); err != nil { + log.Fatal(err) + } +} diff --git a/configs/config.example.yaml b/configs/config.example.yaml new file mode 100644 index 0000000..5c09c67 --- /dev/null +++ b/configs/config.example.yaml @@ -0,0 +1,72 @@ +server: + host: "0.0.0.0" + port: 8080 + read_timeout: "15s" + write_timeout: "30s" + idle_timeout: "60s" + allowed_origins: + - "*" + +mcp: + path: "/mcp" + server_name: "amcs" + version: "0.1.0" + transport: "streamable_http" + +auth: + mode: "api_keys" + header_name: "x-brain-key" + query_param: "key" + allow_query_param: false + keys: + - id: "local-client" + value: "replace-me" + description: "main local client key" + +database: + url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable" + max_conns: 10 + min_conns: 2 + max_conn_lifetime: "30m" + max_conn_idle_time: "10m" + +ai: + provider: "litellm" + embeddings: + model: "openai/text-embedding-3-small" + dimensions: 1536 + metadata: + model: "gpt-4o-mini" + temperature: 0.1 + litellm: + base_url: "http://localhost:4000/v1" + api_key: "replace-me" + use_responses_api: false + request_headers: {} + embedding_model: "openrouter/openai/text-embedding-3-small" + metadata_model: "gpt-4o-mini" + openrouter: + base_url: "https://openrouter.ai/api/v1" + api_key: "" + app_name: "amcs" + site_url: "" + extra_headers: {} + +capture: + source: "mcp" + metadata_defaults: + type: "observation" + topic_fallback: "uncategorized" + +search: + default_limit: 10 + default_threshold: 0.5 + max_limit: 50 + +logging: + level: "info" + format: "json" + +observability: + metrics_enabled: true + pprof_enabled: false diff --git a/configs/dev.yaml b/configs/dev.yaml new file mode 100644 index 0000000..5c09c67 --- /dev/null +++ b/configs/dev.yaml @@ -0,0 +1,72 @@ +server: + host: "0.0.0.0" + port: 8080 + read_timeout: "15s" + write_timeout: "30s" + idle_timeout: "60s" + allowed_origins: + - "*" + +mcp: + path: "/mcp" + server_name: "amcs" + version: "0.1.0" + transport: "streamable_http" + +auth: + mode: "api_keys" + header_name: "x-brain-key" + query_param: "key" + allow_query_param: false + keys: + - id: "local-client" + value: "replace-me" + description: "main local client key" + +database: + url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable" + max_conns: 10 + min_conns: 2 + max_conn_lifetime: "30m" + max_conn_idle_time: "10m" + +ai: + provider: "litellm" + embeddings: + model: "openai/text-embedding-3-small" + dimensions: 1536 + metadata: + model: "gpt-4o-mini" + temperature: 0.1 + litellm: + base_url: "http://localhost:4000/v1" + api_key: "replace-me" + use_responses_api: false + request_headers: {} + embedding_model: "openrouter/openai/text-embedding-3-small" + metadata_model: "gpt-4o-mini" + openrouter: + base_url: "https://openrouter.ai/api/v1" + api_key: "" + app_name: "amcs" + site_url: "" + extra_headers: {} + +capture: + source: "mcp" + metadata_defaults: + type: "observation" + topic_fallback: "uncategorized" + +search: + default_limit: 10 + default_threshold: 0.5 + max_limit: 50 + +logging: + level: "info" + format: "json" + +observability: + metrics_enabled: true + pprof_enabled: false diff --git a/configs/docker.yaml b/configs/docker.yaml new file mode 100644 index 0000000..2193d0e --- /dev/null +++ b/configs/docker.yaml @@ -0,0 +1,72 @@ +server: + host: "0.0.0.0" + port: 8080 + read_timeout: "15s" + write_timeout: "30s" + idle_timeout: "60s" + allowed_origins: + - "*" + +mcp: + path: "/mcp" + server_name: "amcs" + version: "0.1.0" + transport: "streamable_http" + +auth: + mode: "api_keys" + header_name: "x-brain-key" + query_param: "key" + allow_query_param: false + keys: + - id: "local-client" + value: "replace-me" + description: "main local client key" + +database: + url: "postgres://postgres:postgres@db:5432/amcs?sslmode=disable" + max_conns: 10 + min_conns: 2 + max_conn_lifetime: "30m" + max_conn_idle_time: "10m" + +ai: + provider: "litellm" + embeddings: + model: "openai/text-embedding-3-small" + dimensions: 1536 + metadata: + model: "gpt-4o-mini" + temperature: 0.1 + litellm: + base_url: "http://host.containers.internal:4000/v1" + api_key: "replace-me" + use_responses_api: false + request_headers: {} + embedding_model: "openrouter/openai/text-embedding-3-small" + metadata_model: "gpt-4o-mini" + openrouter: + base_url: "https://openrouter.ai/api/v1" + api_key: "" + app_name: "amcs" + site_url: "" + extra_headers: {} + +capture: + source: "mcp" + metadata_defaults: + type: "observation" + topic_fallback: "uncategorized" + +search: + default_limit: 10 + default_threshold: 0.5 + max_limit: 50 + +logging: + level: "info" + format: "json" + +observability: + metrics_enabled: true + pprof_enabled: false diff --git a/docker-compose.docker.yml b/docker-compose.docker.yml new file mode 100644 index 0000000..4c30c19 --- /dev/null +++ b/docker-compose.docker.yml @@ -0,0 +1,5 @@ +services: + app: + extra_hosts: + - "host.containers.internal:host-gateway" + - "host.docker.internal:host-gateway" diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..0f79d8c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,38 @@ +services: + db: + image: pgvector/pgvector:pg16 + restart: unless-stopped + environment: + POSTGRES_DB: amcs + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + - ./docker/postgres/init:/docker-entrypoint-initdb.d:ro + - ./migrations:/migrations:ro + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres -d amcs"] + interval: 5s + timeout: 5s + retries: 20 + + app: + build: + context: . + depends_on: + db: + condition: service_healthy + restart: unless-stopped + environment: + OB1_CONFIG: /app/configs/docker.yaml + OB1_DATABASE_URL: postgres://postgres:postgres@db:5432/amcs?sslmode=disable + OB1_LITELLM_BASE_URL: ${OB1_LITELLM_BASE_URL:-http://host.containers.internal:4000/v1} + OB1_LITELLM_API_KEY: ${OB1_LITELLM_API_KEY:-replace-me} + OB1_SERVER_PORT: 8080 + ports: + - "8080:8080" + +volumes: + postgres_data: diff --git a/docker/postgres/init/00_apply_migrations.sh b/docker/postgres/init/00_apply_migrations.sh new file mode 100755 index 0000000..c556cbb --- /dev/null +++ b/docker/postgres/init/00_apply_migrations.sh @@ -0,0 +1,15 @@ +#!/bin/sh + +set -eu + +for migration in /migrations/*.sql; do + case "$migration" in + */006_rls_and_grants.sql) + echo "Skipping $migration because it contains deployment-specific grants." + ;; + *) + echo "Applying $migration" + psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" -f "$migration" + ;; + esac +done diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..fdd4491 --- /dev/null +++ b/go.mod @@ -0,0 +1,28 @@ +module git.warky.dev/wdevs/amcs + +go 1.26.1 + +require ( + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.9.1 + github.com/modelcontextprotocol/go-sdk v1.4.1 + github.com/pgvector/pgvector-go v0.3.0 + golang.org/x/sync v0.17.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/google/jsonschema-go v0.4.2 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect + github.com/x448/float16 v0.8.4 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.29.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c3b96a1 --- /dev/null +++ b/go.sum @@ -0,0 +1,99 @@ +entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ= +entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-pg/pg/v10 v10.11.0 h1:CMKJqLgTrfpE/aOVeLdybezR2om071Vh38OLZjsyMI0= +github.com/go-pg/pg/v10 v10.11.0/go.mod h1:4BpHRoxE61y4Onpof3x1a2SQvi9c+q1dJnrNdMjsroA= +github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= +github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc= +github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s= +github.com/pgvector/pgvector-go v0.3.0 h1:Ij+Yt78R//uYqs3Zk35evZFvr+G0blW0OUN+Q2D1RWc= +github.com/pgvector/pgvector-go v0.3.0/go.mod h1:duFy+PXWfW7QQd5ibqutBO4GxLsUZ9RVXhFZGIBsWSA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= +github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/uptrace/bun v1.1.12 h1:sOjDVHxNTuM6dNGaba0wUuz7KvDE1BmNu9Gqs2gJSXQ= +github.com/uptrace/bun v1.1.12/go.mod h1:NPG6JGULBeQ9IU6yHp7YGELRa5Agmd7ATZdz4tGZ6z0= +github.com/uptrace/bun/dialect/pgdialect v1.1.12 h1:m/CM1UfOkoBTglGO5CUTKnIKKOApOYxkcP2qn0F9tJk= +github.com/uptrace/bun/dialect/pgdialect v1.1.12/go.mod h1:Ij6WIxQILxLlL2frUBxUBOZJtLElD2QQNDcu/PWDHTc= +github.com/uptrace/bun/driver/pgdriver v1.1.12 h1:3rRWB1GK0psTJrHwxzNfEij2MLibggiLdTqjTtfHc1w= +github.com/uptrace/bun/driver/pgdriver v1.1.12/go.mod h1:ssYUP+qwSEgeDDS1xm2XBip9el1y9Mi5mTAvLoiADLM= +github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= +github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= +github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +mellium.im/sasl v0.3.1 h1:wE0LW6g7U83vhvxjC1IY8DnXM+EU095yeo8XClvCdfo= +mellium.im/sasl v0.3.1/go.mod h1:xm59PUYpZHhgQ9ZqoJ5QaCqzWMi8IeS49dhp6plPCzw= diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go new file mode 100644 index 0000000..d03ee38 --- /dev/null +++ b/internal/ai/compat/client.go @@ -0,0 +1,330 @@ +package compat + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "strings" + "time" + + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +const metadataSystemPrompt = `You extract structured metadata from short notes. +Return only valid JSON matching this schema: +{ + "people": ["string"], + "action_items": ["string"], + "dates_mentioned": ["string"], + "topics": ["string"], + "type": "observation|task|idea|reference|person_note", + "source": "string" +} +Rules: +- Keep arrays concise. +- Use lowercase for type. +- If unsure, prefer "observation". +- Do not include any text outside the JSON object.` + +type Client struct { + name string + baseURL string + apiKey string + embeddingModel string + metadataModel string + temperature float64 + headers map[string]string + httpClient *http.Client + log *slog.Logger + dimensions int +} + +type Config struct { + Name string + BaseURL string + APIKey string + EmbeddingModel string + MetadataModel string + Temperature float64 + Headers map[string]string + HTTPClient *http.Client + Log *slog.Logger + Dimensions int +} + +type embeddingsRequest struct { + Input string `json:"input"` + Model string `json:"model"` +} + +type embeddingsResponse struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` + Error *providerError `json:"error,omitempty"` +} + +type chatCompletionsRequest struct { + Model string `json:"model"` + Temperature float64 `json:"temperature,omitempty"` + ResponseFormat *responseType `json:"response_format,omitempty"` + Messages []chatMessage `json:"messages"` +} + +type responseType struct { + Type string `json:"type"` +} + +type chatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type chatCompletionsResponse struct { + Choices []struct { + Message chatMessage `json:"message"` + } `json:"choices"` + Error *providerError `json:"error,omitempty"` +} + +type providerError struct { + Message string `json:"message"` + Type string `json:"type,omitempty"` +} + +func New(cfg Config) *Client { + return &Client{ + name: cfg.Name, + baseURL: cfg.BaseURL, + apiKey: cfg.APIKey, + embeddingModel: cfg.EmbeddingModel, + metadataModel: cfg.MetadataModel, + temperature: cfg.Temperature, + headers: cfg.Headers, + httpClient: cfg.HTTPClient, + log: cfg.Log, + dimensions: cfg.Dimensions, + } +} + +func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) { + input = strings.TrimSpace(input) + if input == "" { + return nil, fmt.Errorf("%s embed: input must not be empty", c.name) + } + + var resp embeddingsResponse + err := c.doJSON(ctx, "/embeddings", embeddingsRequest{ + Input: input, + Model: c.embeddingModel, + }, &resp) + if err != nil { + return nil, err + } + if resp.Error != nil { + return nil, fmt.Errorf("%s embed error: %s", c.name, resp.Error.Message) + } + if len(resp.Data) == 0 { + return nil, fmt.Errorf("%s embed: no embedding returned", c.name) + } + if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions { + return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding)) + } + + return resp.Data[0].Embedding, nil +} + +func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) { + input = strings.TrimSpace(input) + if input == "" { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name) + } + + req := chatCompletionsRequest{ + Model: c.metadataModel, + Temperature: c.temperature, + ResponseFormat: &responseType{ + Type: "json_object", + }, + Messages: []chatMessage{ + {Role: "system", Content: metadataSystemPrompt}, + {Role: "user", Content: input}, + }, + } + + var resp chatCompletionsResponse + if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil { + return thoughttypes.ThoughtMetadata{}, err + } + if resp.Error != nil { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata error: %s", c.name, resp.Error.Message) + } + if len(resp.Choices) == 0 { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: no choices returned", c.name) + } + + metadataText := strings.TrimSpace(resp.Choices[0].Message.Content) + metadataText = stripCodeFence(metadataText) + + var metadata thoughttypes.ThoughtMetadata + if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: parse json: %w", c.name, err) + } + + return metadata, nil +} + +func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) { + req := chatCompletionsRequest{ + Model: c.metadataModel, + Temperature: 0.2, + Messages: []chatMessage{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: userPrompt}, + }, + } + + var resp chatCompletionsResponse + if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil { + return "", err + } + if resp.Error != nil { + return "", fmt.Errorf("%s summarize error: %s", c.name, resp.Error.Message) + } + if len(resp.Choices) == 0 { + return "", fmt.Errorf("%s summarize: no choices returned", c.name) + } + + return strings.TrimSpace(resp.Choices[0].Message.Content), nil +} + +func (c *Client) Name() string { + return c.name +} + +func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error { + body, err := json.Marshal(requestBody) + if err != nil { + return fmt.Errorf("%s request marshal: %w", c.name, err) + } + + const maxAttempts = 3 + + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.baseURL, "/")+path, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("%s build request: %w", c.name, err) + } + + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + for key, value := range c.headers { + if strings.TrimSpace(key) == "" || strings.TrimSpace(value) == "" { + continue + } + req.Header.Set(key, value) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("%s request failed: %w", c.name, err) + if attempt < maxAttempts && isRetryableError(err) { + if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil { + return retryErr + } + continue + } + return lastErr + } + + payload, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + if readErr != nil { + lastErr = fmt.Errorf("%s read response: %w", c.name, readErr) + if attempt < maxAttempts { + if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil { + return retryErr + } + continue + } + return lastErr + } + + if resp.StatusCode >= http.StatusBadRequest { + lastErr = fmt.Errorf("%s request failed with status %d: %s", c.name, resp.StatusCode, strings.TrimSpace(string(payload))) + if attempt < maxAttempts && isRetryableStatus(resp.StatusCode) { + if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil { + return retryErr + } + continue + } + return lastErr + } + + if err := json.Unmarshal(payload, dest); err != nil { + if c.log != nil { + c.log.Debug("provider response body", slog.String("provider", c.name), slog.String("body", string(payload))) + } + return fmt.Errorf("%s decode response: %w", c.name, err) + } + + return nil + } + + return lastErr +} + +func stripCodeFence(value string) string { + value = strings.TrimSpace(value) + if !strings.HasPrefix(value, "```") { + return value + } + + value = strings.TrimPrefix(value, "```json") + value = strings.TrimPrefix(value, "```") + value = strings.TrimSuffix(value, "```") + return strings.TrimSpace(value) +} + +func isRetryableStatus(status int) bool { + switch status { + case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return true + default: + return false + } +} + +func isRetryableError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return true + } + var netErr net.Error + return errors.As(err, &netErr) +} + +func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error { + delay := time.Duration(attempt*attempt) * 200 * time.Millisecond + if log != nil { + log.Warn("retrying provider request", slog.String("provider", provider), slog.Duration("delay", delay), slog.Int("attempt", attempt+1)) + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/internal/ai/compat/client_test.go b/internal/ai/compat/client_test.go new file mode 100644 index 0000000..e25a841 --- /dev/null +++ b/internal/ai/compat/client_test.go @@ -0,0 +1,91 @@ +package compat + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestEmbedRetriesTransientFailures(t *testing.T) { + var calls atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if calls.Add(1) < 3 { + http.Error(w, "temporary failure", http.StatusServiceUnavailable) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"embedding": []float32{1, 2, 3}}, + }, + }) + })) + defer server.Close() + + client := New(Config{ + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + EmbeddingModel: "embed-model", + MetadataModel: "meta-model", + HTTPClient: server.Client(), + Log: discardLogger(), + Dimensions: 3, + }) + + embedding, err := client.Embed(context.Background(), "hello") + if err != nil { + t.Fatalf("Embed() error = %v", err) + } + if len(embedding) != 3 { + t.Fatalf("embedding len = %d, want 3", len(embedding)) + } + if got := calls.Load(); got != 3 { + t.Fatalf("call count = %d, want 3", got) + } +} + +func TestExtractMetadataParsesCodeFencedJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": "```json\n{\"people\":[\"Alice\"],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"memory\"],\"type\":\"idea\",\"source\":\"mcp\"}\n```", + }, + }, + }, + }) + })) + defer server.Close() + + client := New(Config{ + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + EmbeddingModel: "embed-model", + MetadataModel: "meta-model", + HTTPClient: server.Client(), + Log: discardLogger(), + }) + + metadata, err := client.ExtractMetadata(context.Background(), "hello") + if err != nil { + t.Fatalf("ExtractMetadata() error = %v", err) + } + if metadata.Type != "idea" { + t.Fatalf("metadata type = %q, want idea", metadata.Type) + } + if len(metadata.People) != 1 || metadata.People[0] != "Alice" { + t.Fatalf("metadata people = %#v, want [Alice]", metadata.People) + } +} diff --git a/internal/ai/factory.go b/internal/ai/factory.go new file mode 100644 index 0000000..994ba75 --- /dev/null +++ b/internal/ai/factory.go @@ -0,0 +1,22 @@ +package ai + +import ( + "fmt" + "log/slog" + "net/http" + + "git.warky.dev/wdevs/amcs/internal/ai/litellm" + "git.warky.dev/wdevs/amcs/internal/ai/openrouter" + "git.warky.dev/wdevs/amcs/internal/config" +) + +func NewProvider(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (Provider, error) { + switch cfg.Provider { + case "litellm": + return litellm.New(cfg, httpClient, log) + case "openrouter": + return openrouter.New(cfg, httpClient, log) + default: + return nil, fmt.Errorf("unsupported ai.provider: %s", cfg.Provider) + } +} diff --git a/internal/ai/litellm/client.go b/internal/ai/litellm/client.go new file mode 100644 index 0000000..c2753dc --- /dev/null +++ b/internal/ai/litellm/client.go @@ -0,0 +1,24 @@ +package litellm + +import ( + "log/slog" + "net/http" + + "git.warky.dev/wdevs/amcs/internal/ai/compat" + "git.warky.dev/wdevs/amcs/internal/config" +) + +func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { + return compat.New(compat.Config{ + Name: "litellm", + BaseURL: cfg.LiteLLM.BaseURL, + APIKey: cfg.LiteLLM.APIKey, + EmbeddingModel: cfg.LiteLLM.EmbeddingModel, + MetadataModel: cfg.LiteLLM.MetadataModel, + Temperature: cfg.Metadata.Temperature, + Headers: cfg.LiteLLM.RequestHeaders, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, + }), nil +} diff --git a/internal/ai/openrouter/client.go b/internal/ai/openrouter/client.go new file mode 100644 index 0000000..c715080 --- /dev/null +++ b/internal/ai/openrouter/client.go @@ -0,0 +1,35 @@ +package openrouter + +import ( + "log/slog" + "net/http" + + "git.warky.dev/wdevs/amcs/internal/ai/compat" + "git.warky.dev/wdevs/amcs/internal/config" +) + +func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { + headers := make(map[string]string, len(cfg.OpenRouter.ExtraHeaders)+2) + for key, value := range cfg.OpenRouter.ExtraHeaders { + headers[key] = value + } + if cfg.OpenRouter.SiteURL != "" { + headers["HTTP-Referer"] = cfg.OpenRouter.SiteURL + } + if cfg.OpenRouter.AppName != "" { + headers["X-Title"] = cfg.OpenRouter.AppName + } + + return compat.New(compat.Config{ + Name: "openrouter", + BaseURL: cfg.OpenRouter.BaseURL, + APIKey: cfg.OpenRouter.APIKey, + EmbeddingModel: cfg.Embeddings.Model, + MetadataModel: cfg.Metadata.Model, + Temperature: cfg.Metadata.Temperature, + Headers: headers, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, + }), nil +} diff --git a/internal/ai/provider.go b/internal/ai/provider.go new file mode 100644 index 0000000..ac35676 --- /dev/null +++ b/internal/ai/provider.go @@ -0,0 +1,14 @@ +package ai + +import ( + "context" + + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type Provider interface { + Embed(ctx context.Context, input string) ([]float32, error) + ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) + Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) + Name() string +} diff --git a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 0000000..573d24a --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,145 @@ +package app + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/auth" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/mcpserver" + "git.warky.dev/wdevs/amcs/internal/observability" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" + "git.warky.dev/wdevs/amcs/internal/tools" +) + +func Run(ctx context.Context, configPath string) error { + cfg, loadedFrom, err := config.Load(configPath) + if err != nil { + return err + } + + logger, err := observability.NewLogger(cfg.Logging) + if err != nil { + return err + } + + logger.Info("loaded configuration", + slog.String("path", loadedFrom), + slog.String("provider", cfg.AI.Provider), + ) + + db, err := store.New(ctx, cfg.Database) + if err != nil { + return err + } + defer db.Close() + + if err := db.VerifyRequirements(ctx, cfg.AI.Embeddings.Dimensions); err != nil { + return err + } + + httpClient := &http.Client{Timeout: 30 * time.Second} + provider, err := ai.NewProvider(cfg.AI, httpClient, logger) + if err != nil { + return err + } + + keyring, err := auth.NewKeyring(cfg.Auth.Keys) + if err != nil { + return err + } + activeProjects := session.NewActiveProjects() + + logger.Info("database connection verified", + slog.String("provider", provider.Name()), + ) + + server := &http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), + Handler: routes(logger, cfg, db, provider, keyring, activeProjects), + ReadTimeout: cfg.Server.ReadTimeout, + WriteTimeout: cfg.Server.WriteTimeout, + IdleTimeout: cfg.Server.IdleTimeout, + } + + errCh := make(chan error, 1) + go func() { + logger.Info("starting HTTP server", + slog.String("addr", server.Addr), + slog.String("mcp_path", cfg.MCP.Path), + ) + + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + logger.Info("shutting down HTTP server") + return server.Shutdown(shutdownCtx) + case err := <-errCh: + return fmt.Errorf("run server: %w", err) + } +} + +func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, activeProjects *session.ActiveProjects) http.Handler { + mux := http.NewServeMux() + + toolSet := mcpserver.ToolSet{ + Capture: tools.NewCaptureTool(db, provider, cfg.Capture, activeProjects, logger), + Search: tools.NewSearchTool(db, provider, cfg.Search, activeProjects), + List: tools.NewListTool(db, cfg.Search, activeProjects), + Stats: tools.NewStatsTool(db), + Get: tools.NewGetTool(db), + Update: tools.NewUpdateTool(db, provider, cfg.Capture, logger), + Delete: tools.NewDeleteTool(db), + Archive: tools.NewArchiveTool(db), + Projects: tools.NewProjectsTool(db, activeProjects), + Context: tools.NewContextTool(db, provider, cfg.Search, activeProjects), + Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects), + Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects), + Links: tools.NewLinksTool(db, provider, cfg.Search), + } + + mcpHandler := mcpserver.New(cfg.MCP, toolSet) + mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, logger)(mcpHandler)) + + mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + + mux.HandleFunc("/readyz", func(w http.ResponseWriter, r *http.Request) { + if err := db.Ready(r.Context()); err != nil { + logger.Error("readiness check failed", slog.String("error", err.Error())) + http.Error(w, "not ready", http.StatusServiceUnavailable) + return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ready")) + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("amcs is running")) + }) + + return observability.Chain( + mux, + observability.RequestID(), + observability.Recover(logger), + observability.AccessLog(logger), + observability.Timeout(cfg.Server.WriteTimeout), + ) +} diff --git a/internal/auth/keyring.go b/internal/auth/keyring.go new file mode 100644 index 0000000..56c3135 --- /dev/null +++ b/internal/auth/keyring.go @@ -0,0 +1,29 @@ +package auth + +import ( + "crypto/subtle" + "fmt" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +type Keyring struct { + keys []config.APIKey +} + +func NewKeyring(keys []config.APIKey) (*Keyring, error) { + if len(keys) == 0 { + return nil, fmt.Errorf("keyring requires at least one key") + } + + return &Keyring{keys: append([]config.APIKey(nil), keys...)}, nil +} + +func (k *Keyring) Lookup(value string) (string, bool) { + for _, key := range k.keys { + if subtle.ConstantTimeCompare([]byte(key.Value), []byte(value)) == 1 { + return key.ID, true + } + } + return "", false +} diff --git a/internal/auth/keyring_test.go b/internal/auth/keyring_test.go new file mode 100644 index 0000000..95fc074 --- /dev/null +++ b/internal/auth/keyring_test.go @@ -0,0 +1,107 @@ +package auth + +import ( + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestNewKeyringAndLookup(t *testing.T) { + _, err := NewKeyring(nil) + if err == nil { + t.Fatal("NewKeyring(nil) error = nil, want error") + } + + keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) + if err != nil { + t.Fatalf("NewKeyring() error = %v", err) + } + + if got, ok := keyring.Lookup("secret"); !ok || got != "client-a" { + t.Fatalf("Lookup(secret) = (%q, %v), want (client-a, true)", got, ok) + } + if _, ok := keyring.Lookup("wrong"); ok { + t.Fatal("Lookup(wrong) = true, want false") + } +} + +func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) { + keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) + if err != nil { + t.Fatalf("NewKeyring() error = %v", err) + } + + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + keyID, ok := KeyIDFromContext(r.Context()) + if !ok || keyID != "client-a" { + t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok) + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("x-brain-key", "secret") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) { + keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) + if err != nil { + t.Fatalf("NewKeyring() error = %v", err) + } + + handler := Middleware(config.AuthConfig{ + HeaderName: "x-brain-key", + QueryParam: "key", + AllowQueryParam: true, + }, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + + req := httptest.NewRequest(http.MethodGet, "/mcp?key=secret", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } +} + +func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) { + keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}}) + if err != nil { + t.Fatalf("NewKeyring() error = %v", err) + } + + handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("next handler should not be called") + })) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("missing key status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + + req = httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("x-brain-key", "wrong") + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("invalid key status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go new file mode 100644 index 0000000..a99da09 --- /dev/null +++ b/internal/auth/middleware.go @@ -0,0 +1,49 @@ +package auth + +import ( + "context" + "log/slog" + "net/http" + "strings" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +type contextKey string + +const keyIDContextKey contextKey = "auth.key_id" + +func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(http.Handler) http.Handler { + headerName := cfg.HeaderName + if headerName == "" { + headerName = "x-brain-key" + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := strings.TrimSpace(r.Header.Get(headerName)) + if token == "" && cfg.AllowQueryParam { + token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)) + } + + if token == "" { + http.Error(w, "missing API key", http.StatusUnauthorized) + return + } + + keyID, ok := keyring.Lookup(token) + if !ok { + log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr)) + http.Error(w, "invalid API key", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID))) + }) + } +} + +func KeyIDFromContext(ctx context.Context) (string, bool) { + value, ok := ctx.Value(keyIDContextKey).(string) + return value, ok +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..3bb07f8 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,119 @@ +package config + +import "time" + +const ( + DefaultConfigPath = "./configs/dev.yaml" + DefaultSource = "mcp" +) + +type Config struct { + Server ServerConfig `yaml:"server"` + MCP MCPConfig `yaml:"mcp"` + Auth AuthConfig `yaml:"auth"` + Database DatabaseConfig `yaml:"database"` + AI AIConfig `yaml:"ai"` + Capture CaptureConfig `yaml:"capture"` + Search SearchConfig `yaml:"search"` + Logging LoggingConfig `yaml:"logging"` + Observability ObservabilityConfig `yaml:"observability"` +} + +type ServerConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + ReadTimeout time.Duration `yaml:"read_timeout"` + WriteTimeout time.Duration `yaml:"write_timeout"` + IdleTimeout time.Duration `yaml:"idle_timeout"` + AllowedOrigins []string `yaml:"allowed_origins"` +} + +type MCPConfig struct { + Path string `yaml:"path"` + ServerName string `yaml:"server_name"` + Version string `yaml:"version"` + Transport string `yaml:"transport"` +} + +type AuthConfig struct { + Mode string `yaml:"mode"` + HeaderName string `yaml:"header_name"` + QueryParam string `yaml:"query_param"` + AllowQueryParam bool `yaml:"allow_query_param"` + Keys []APIKey `yaml:"keys"` +} + +type APIKey struct { + ID string `yaml:"id"` + Value string `yaml:"value"` + Description string `yaml:"description"` +} + +type DatabaseConfig struct { + URL string `yaml:"url"` + MaxConns int32 `yaml:"max_conns"` + MinConns int32 `yaml:"min_conns"` + MaxConnLifetime time.Duration `yaml:"max_conn_lifetime"` + MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time"` +} + +type AIConfig struct { + Provider string `yaml:"provider"` + Embeddings AIEmbeddingConfig `yaml:"embeddings"` + Metadata AIMetadataConfig `yaml:"metadata"` + LiteLLM LiteLLMConfig `yaml:"litellm"` + OpenRouter OpenRouterAIConfig `yaml:"openrouter"` +} + +type AIEmbeddingConfig struct { + Model string `yaml:"model"` + Dimensions int `yaml:"dimensions"` +} + +type AIMetadataConfig struct { + Model string `yaml:"model"` + Temperature float64 `yaml:"temperature"` +} + +type LiteLLMConfig struct { + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + UseResponsesAPI bool `yaml:"use_responses_api"` + RequestHeaders map[string]string `yaml:"request_headers"` + EmbeddingModel string `yaml:"embedding_model"` + MetadataModel string `yaml:"metadata_model"` +} + +type OpenRouterAIConfig struct { + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + AppName string `yaml:"app_name"` + SiteURL string `yaml:"site_url"` + ExtraHeaders map[string]string `yaml:"extra_headers"` +} + +type CaptureConfig struct { + Source string `yaml:"source"` + MetadataDefaults CaptureMetadataDefault `yaml:"metadata_defaults"` +} + +type CaptureMetadataDefault struct { + Type string `yaml:"type"` + TopicFallback string `yaml:"topic_fallback"` +} + +type SearchConfig struct { + DefaultLimit int `yaml:"default_limit"` + DefaultThreshold float64 `yaml:"default_threshold"` + MaxLimit int `yaml:"max_limit"` +} + +type LoggingConfig struct { + Level string `yaml:"level"` + Format string `yaml:"format"` +} + +type ObservabilityConfig struct { + MetricsEnabled bool `yaml:"metrics_enabled"` + PprofEnabled bool `yaml:"pprof_enabled"` +} diff --git a/internal/config/loader.go b/internal/config/loader.go new file mode 100644 index 0000000..0013eaa --- /dev/null +++ b/internal/config/loader.go @@ -0,0 +1,113 @@ +package config + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" + + "gopkg.in/yaml.v3" +) + +func Load(explicitPath string) (*Config, string, error) { + path := ResolvePath(explicitPath) + + data, err := os.ReadFile(path) + if err != nil { + return nil, path, fmt.Errorf("read config %q: %w", path, err) + } + + cfg := defaultConfig() + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, path, fmt.Errorf("decode config %q: %w", path, err) + } + + applyEnvOverrides(&cfg) + if err := cfg.Validate(); err != nil { + return nil, path, err + } + + return &cfg, path, nil +} + +func ResolvePath(explicitPath string) string { + if strings.TrimSpace(explicitPath) != "" { + return explicitPath + } + + if envPath := strings.TrimSpace(os.Getenv("OB1_CONFIG")); envPath != "" { + return envPath + } + + return DefaultConfigPath +} + +func defaultConfig() Config { + return Config{ + Server: ServerConfig{ + Host: "0.0.0.0", + Port: 8080, + ReadTimeout: 15 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + }, + MCP: MCPConfig{ + Path: "/mcp", + ServerName: "amcs", + Version: "0.1.0", + Transport: "streamable_http", + }, + Auth: AuthConfig{ + Mode: "api_keys", + HeaderName: "x-brain-key", + QueryParam: "key", + }, + AI: AIConfig{ + Provider: "litellm", + Embeddings: AIEmbeddingConfig{ + Model: "openai/text-embedding-3-small", + Dimensions: 1536, + }, + Metadata: AIMetadataConfig{ + Model: "gpt-4o-mini", + Temperature: 0.1, + }, + }, + Capture: CaptureConfig{ + Source: DefaultSource, + MetadataDefaults: CaptureMetadataDefault{ + Type: "observation", + TopicFallback: "uncategorized", + }, + }, + Search: SearchConfig{ + DefaultLimit: 10, + DefaultThreshold: 0.5, + MaxLimit: 50, + }, + Logging: LoggingConfig{ + Level: "info", + Format: "json", + }, + } +} + +func applyEnvOverrides(cfg *Config) { + overrideString(&cfg.Database.URL, "OB1_DATABASE_URL") + overrideString(&cfg.AI.LiteLLM.BaseURL, "OB1_LITELLM_BASE_URL") + overrideString(&cfg.AI.LiteLLM.APIKey, "OB1_LITELLM_API_KEY") + overrideString(&cfg.AI.OpenRouter.APIKey, "OB1_OPENROUTER_API_KEY") + + if value, ok := os.LookupEnv("OB1_SERVER_PORT"); ok { + if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { + cfg.Server.Port = port + } + } +} + +func overrideString(target *string, envKey string) { + if value, ok := os.LookupEnv(envKey); ok { + *target = strings.TrimSpace(value) + } +} diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go new file mode 100644 index 0000000..87b5933 --- /dev/null +++ b/internal/config/loader_test.go @@ -0,0 +1,71 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestResolvePathPrecedence(t *testing.T) { + t.Setenv("OB1_CONFIG", "/tmp/from-env.yaml") + + if got := ResolvePath("/tmp/explicit.yaml"); got != "/tmp/explicit.yaml" { + t.Fatalf("ResolvePath explicit = %q, want %q", got, "/tmp/explicit.yaml") + } + + if got := ResolvePath(""); got != "/tmp/from-env.yaml" { + t.Fatalf("ResolvePath env = %q, want %q", got, "/tmp/from-env.yaml") + } +} + +func TestLoadAppliesEnvOverrides(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "test.yaml") + if err := os.WriteFile(configPath, []byte(` +server: + port: 8080 +mcp: + path: "/mcp" +auth: + keys: + - id: "test" + value: "secret" +database: + url: "postgres://from-file" +ai: + provider: "litellm" + embeddings: + dimensions: 1536 + litellm: + base_url: "http://localhost:4000/v1" + api_key: "file-key" +search: + default_limit: 10 + max_limit: 50 +logging: + level: "info" +`), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + t.Setenv("OB1_DATABASE_URL", "postgres://from-env") + t.Setenv("OB1_LITELLM_API_KEY", "env-key") + t.Setenv("OB1_SERVER_PORT", "9090") + + cfg, loadedFrom, err := Load(configPath) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + if loadedFrom != configPath { + t.Fatalf("loadedFrom = %q, want %q", loadedFrom, configPath) + } + if cfg.Database.URL != "postgres://from-env" { + t.Fatalf("database url = %q, want env override", cfg.Database.URL) + } + if cfg.AI.LiteLLM.APIKey != "env-key" { + t.Fatalf("litellm api key = %q, want env override", cfg.AI.LiteLLM.APIKey) + } + if cfg.Server.Port != 9090 { + t.Fatalf("server port = %d, want 9090", cfg.Server.Port) + } +} diff --git a/internal/config/validate.go b/internal/config/validate.go new file mode 100644 index 0000000..ef09be5 --- /dev/null +++ b/internal/config/validate.go @@ -0,0 +1,71 @@ +package config + +import ( + "fmt" + "strings" +) + +func (c Config) Validate() error { + if strings.TrimSpace(c.Database.URL) == "" { + return fmt.Errorf("invalid config: database.url is required") + } + + if len(c.Auth.Keys) == 0 { + return fmt.Errorf("invalid config: auth.keys must not be empty") + } + + for i, key := range c.Auth.Keys { + if strings.TrimSpace(key.ID) == "" { + return fmt.Errorf("invalid config: auth.keys[%d].id is required", i) + } + if strings.TrimSpace(key.Value) == "" { + return fmt.Errorf("invalid config: auth.keys[%d].value is required", i) + } + } + + if strings.TrimSpace(c.MCP.Path) == "" { + return fmt.Errorf("invalid config: mcp.path is required") + } + + switch c.AI.Provider { + case "litellm", "openrouter": + default: + return fmt.Errorf("invalid config: unsupported ai.provider %q", c.AI.Provider) + } + + if c.AI.Embeddings.Dimensions <= 0 { + return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero") + } + + switch c.AI.Provider { + case "litellm": + if strings.TrimSpace(c.AI.LiteLLM.BaseURL) == "" { + return fmt.Errorf("invalid config: ai.litellm.base_url is required when ai.provider=litellm") + } + if strings.TrimSpace(c.AI.LiteLLM.APIKey) == "" { + return fmt.Errorf("invalid config: ai.litellm.api_key is required when ai.provider=litellm") + } + case "openrouter": + if strings.TrimSpace(c.AI.OpenRouter.BaseURL) == "" { + return fmt.Errorf("invalid config: ai.openrouter.base_url is required when ai.provider=openrouter") + } + if strings.TrimSpace(c.AI.OpenRouter.APIKey) == "" { + return fmt.Errorf("invalid config: ai.openrouter.api_key is required when ai.provider=openrouter") + } + } + + if c.Server.Port <= 0 { + return fmt.Errorf("invalid config: server.port must be greater than zero") + } + if c.Search.DefaultLimit <= 0 { + return fmt.Errorf("invalid config: search.default_limit must be greater than zero") + } + if c.Search.MaxLimit < c.Search.DefaultLimit { + return fmt.Errorf("invalid config: search.max_limit must be greater than or equal to search.default_limit") + } + if strings.TrimSpace(c.Logging.Level) == "" { + return fmt.Errorf("invalid config: logging.level is required") + } + + return nil +} diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go new file mode 100644 index 0000000..8d8192e --- /dev/null +++ b/internal/config/validate_test.go @@ -0,0 +1,60 @@ +package config + +import "testing" + +func validConfig() Config { + return Config{ + Server: ServerConfig{Port: 8080}, + MCP: MCPConfig{Path: "/mcp"}, + Auth: AuthConfig{ + Keys: []APIKey{{ID: "test", Value: "secret"}}, + }, + Database: DatabaseConfig{URL: "postgres://example"}, + AI: AIConfig{ + Provider: "litellm", + Embeddings: AIEmbeddingConfig{ + Dimensions: 1536, + }, + LiteLLM: LiteLLMConfig{ + BaseURL: "http://localhost:4000/v1", + APIKey: "key", + }, + OpenRouter: OpenRouterAIConfig{ + BaseURL: "https://openrouter.ai/api/v1", + APIKey: "key", + }, + }, + Search: SearchConfig{DefaultLimit: 10, MaxLimit: 50}, + Logging: LoggingConfig{Level: "info"}, + } +} + +func TestValidateAcceptsLiteLLMAndOpenRouter(t *testing.T) { + cfg := validConfig() + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate litellm error = %v", err) + } + + cfg.AI.Provider = "openrouter" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate openrouter error = %v", err) + } +} + +func TestValidateRejectsInvalidProvider(t *testing.T) { + cfg := validConfig() + cfg.AI.Provider = "unknown" + + if err := cfg.Validate(); err == nil { + t.Fatal("Validate() error = nil, want error for unsupported provider") + } +} + +func TestValidateRejectsEmptyAuthKeyValue(t *testing.T) { + cfg := validConfig() + cfg.Auth.Keys[0].Value = "" + + if err := cfg.Validate(); err == nil { + t.Fatal("Validate() error = nil, want error for empty auth key value") + } +} diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go new file mode 100644 index 0000000..ce0a513 --- /dev/null +++ b/internal/mcpserver/server.go @@ -0,0 +1,126 @@ +package mcpserver + +import ( + "net/http" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/tools" +) + +type ToolSet struct { + Capture *tools.CaptureTool + Search *tools.SearchTool + List *tools.ListTool + Stats *tools.StatsTool + Get *tools.GetTool + Update *tools.UpdateTool + Delete *tools.DeleteTool + Archive *tools.ArchiveTool + Projects *tools.ProjectsTool + Context *tools.ContextTool + Recall *tools.RecallTool + Summarize *tools.SummarizeTool + Links *tools.LinksTool +} + +func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler { + server := mcp.NewServer(&mcp.Implementation{ + Name: cfg.ServerName, + Version: cfg.Version, + }, nil) + + mcp.AddTool(server, &mcp.Tool{ + Name: "capture_thought", + Description: "Store a thought with generated embeddings and extracted metadata.", + }, toolSet.Capture.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "search_thoughts", + Description: "Search stored thoughts by semantic similarity.", + }, toolSet.Search.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "list_thoughts", + Description: "List recent thoughts with optional metadata filters.", + }, toolSet.List.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "thought_stats", + Description: "Get counts and top metadata buckets across stored thoughts.", + }, toolSet.Stats.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_thought", + Description: "Retrieve a full thought by id.", + }, toolSet.Get.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "update_thought", + Description: "Update thought content or merge metadata.", + }, toolSet.Update.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_thought", + Description: "Hard-delete a thought by id.", + }, toolSet.Delete.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "archive_thought", + Description: "Archive a thought so it is hidden from default search and listing.", + }, toolSet.Archive.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "create_project", + Description: "Create a named project container for thoughts.", + }, toolSet.Projects.Create) + + mcp.AddTool(server, &mcp.Tool{ + Name: "list_projects", + Description: "List projects and their current thought counts.", + }, toolSet.Projects.List) + + mcp.AddTool(server, &mcp.Tool{ + Name: "set_active_project", + Description: "Set the active project for the current MCP session.", + }, toolSet.Projects.SetActive) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_active_project", + Description: "Return the active project for the current MCP session.", + }, toolSet.Projects.GetActive) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_project_context", + Description: "Get recent and semantic context for a project.", + }, toolSet.Context.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "recall_context", + Description: "Recall semantically relevant and recent context.", + }, toolSet.Recall.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "summarize_thoughts", + Description: "Summarize a filtered or searched set of thoughts.", + }, toolSet.Summarize.Handle) + + mcp.AddTool(server, &mcp.Tool{ + Name: "link_thoughts", + Description: "Create a typed relationship between two thoughts.", + }, toolSet.Links.Link) + + mcp.AddTool(server, &mcp.Tool{ + Name: "related_thoughts", + Description: "Retrieve explicit links and semantic neighbors for a thought.", + }, toolSet.Links.Related) + + return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{ + JSONResponse: true, + SessionTimeout: 10 * time.Minute, + }) +} diff --git a/internal/metadata/normalize.go b/internal/metadata/normalize.go new file mode 100644 index 0000000..68a9df7 --- /dev/null +++ b/internal/metadata/normalize.go @@ -0,0 +1,155 @@ +package metadata + +import ( + "sort" + "strings" + + "git.warky.dev/wdevs/amcs/internal/config" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +const ( + DefaultType = "observation" + DefaultTopicFallback = "uncategorized" + maxTopics = 10 +) + +var allowedTypes = map[string]struct{}{ + "observation": {}, + "task": {}, + "idea": {}, + "reference": {}, + "person_note": {}, +} + +func Fallback(capture config.CaptureConfig) thoughttypes.ThoughtMetadata { + topicFallback := strings.TrimSpace(capture.MetadataDefaults.TopicFallback) + if topicFallback == "" { + topicFallback = DefaultTopicFallback + } + + return thoughttypes.ThoughtMetadata{ + People: []string{}, + ActionItems: []string{}, + DatesMentioned: []string{}, + Topics: []string{topicFallback}, + Type: normalizeType(capture.MetadataDefaults.Type), + Source: normalizeSource(capture.Source), + } +} + +func Normalize(in thoughttypes.ThoughtMetadata, capture config.CaptureConfig) thoughttypes.ThoughtMetadata { + out := thoughttypes.ThoughtMetadata{ + People: normalizeList(in.People, 0), + ActionItems: normalizeList(in.ActionItems, 0), + DatesMentioned: normalizeList(in.DatesMentioned, 0), + Topics: normalizeList(in.Topics, maxTopics), + Type: normalizeType(in.Type), + Source: normalizeSource(in.Source), + } + + if len(out.Topics) == 0 { + out.Topics = Fallback(capture).Topics + } + if out.Type == "" { + out.Type = Fallback(capture).Type + } + if out.Source == "" { + out.Source = Fallback(capture).Source + } + + return out +} + +func normalizeList(values []string, limit int) []string { + seen := make(map[string]struct{}, len(values)) + result := make([]string, 0, len(values)) + + for _, value := range values { + trimmed := strings.Join(strings.Fields(strings.TrimSpace(value)), " ") + if trimmed == "" { + continue + } + + key := strings.ToLower(trimmed) + if _, ok := seen[key]; ok { + continue + } + + seen[key] = struct{}{} + result = append(result, trimmed) + + if limit > 0 && len(result) >= limit { + break + } + } + + return result +} + +func normalizeType(value string) string { + normalized := strings.ToLower(strings.TrimSpace(value)) + if normalized == "" { + return DefaultType + } + if _, ok := allowedTypes[normalized]; ok { + return normalized + } + return DefaultType +} + +func normalizeSource(value string) string { + normalized := strings.TrimSpace(value) + if normalized == "" { + return config.DefaultSource + } + return normalized +} + +func Merge(base, patch thoughttypes.ThoughtMetadata, capture config.CaptureConfig) thoughttypes.ThoughtMetadata { + merged := base + + if len(patch.People) > 0 { + merged.People = append(append([]string{}, merged.People...), patch.People...) + } + if len(patch.ActionItems) > 0 { + merged.ActionItems = append(append([]string{}, merged.ActionItems...), patch.ActionItems...) + } + if len(patch.DatesMentioned) > 0 { + merged.DatesMentioned = append(append([]string{}, merged.DatesMentioned...), patch.DatesMentioned...) + } + if len(patch.Topics) > 0 { + merged.Topics = append(append([]string{}, merged.Topics...), patch.Topics...) + } + if strings.TrimSpace(patch.Type) != "" { + merged.Type = patch.Type + } + if strings.TrimSpace(patch.Source) != "" { + merged.Source = patch.Source + } + + return Normalize(merged, capture) +} + +func SortedTopCounts(in map[string]int, limit int) []thoughttypes.KeyCount { + out := make([]thoughttypes.KeyCount, 0, len(in)) + for key, count := range in { + if strings.TrimSpace(key) == "" { + continue + } + out = append(out, thoughttypes.KeyCount{Key: key, Count: count}) + } + + sort.Slice(out, func(i, j int) bool { + if out[i].Count == out[j].Count { + return out[i].Key < out[j].Key + } + return out[i].Count > out[j].Count + }) + + if limit > 0 && len(out) > limit { + return out[:limit] + } + + return out +} diff --git a/internal/metadata/normalize_test.go b/internal/metadata/normalize_test.go new file mode 100644 index 0000000..55397b7 --- /dev/null +++ b/internal/metadata/normalize_test.go @@ -0,0 +1,81 @@ +package metadata + +import ( + "strings" + "testing" + + "git.warky.dev/wdevs/amcs/internal/config" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +func testCaptureConfig() config.CaptureConfig { + return config.CaptureConfig{ + Source: "mcp", + MetadataDefaults: config.CaptureMetadataDefault{ + Type: "observation", + TopicFallback: "uncategorized", + }, + } +} + +func TestFallbackUsesConfiguredDefaults(t *testing.T) { + got := Fallback(testCaptureConfig()) + if got.Type != "observation" { + t.Fatalf("Fallback type = %q, want observation", got.Type) + } + if len(got.Topics) != 1 || got.Topics[0] != "uncategorized" { + t.Fatalf("Fallback topics = %#v, want [uncategorized]", got.Topics) + } + if got.Source != "mcp" { + t.Fatalf("Fallback source = %q, want mcp", got.Source) + } +} + +func TestNormalizeTrimsDedupesAndCapsTopics(t *testing.T) { + topics := []string{} + for i := 0; i < 12; i++ { + topics = append(topics, strings.TrimSpace(" topic ")) + topics = append(topics, string(rune('a'+i))) + } + + got := Normalize(thoughttypes.ThoughtMetadata{ + People: []string{" Alice ", "alice", "", "Bob"}, + Topics: topics, + Type: "INVALID", + }, testCaptureConfig()) + + if len(got.People) != 2 { + t.Fatalf("People len = %d, want 2", len(got.People)) + } + if got.Type != "observation" { + t.Fatalf("Type = %q, want observation", got.Type) + } + if len(got.Topics) != maxTopics { + t.Fatalf("Topics len = %d, want %d", len(got.Topics), maxTopics) + } +} + +func TestMergeAddsPatchAndNormalizes(t *testing.T) { + base := thoughttypes.ThoughtMetadata{ + People: []string{"Alice"}, + Topics: []string{"go"}, + Type: "idea", + Source: "mcp", + } + patch := thoughttypes.ThoughtMetadata{ + People: []string{" Bob ", "alice"}, + Topics: []string{"testing"}, + Type: "task", + } + + got := Merge(base, patch, testCaptureConfig()) + if got.Type != "task" { + t.Fatalf("Type = %q, want task", got.Type) + } + if len(got.People) != 2 { + t.Fatalf("People len = %d, want 2", len(got.People)) + } + if len(got.Topics) != 2 { + t.Fatalf("Topics len = %d, want 2", len(got.Topics)) + } +} diff --git a/internal/observability/http.go b/internal/observability/http.go new file mode 100644 index 0000000..f7d7aae --- /dev/null +++ b/internal/observability/http.go @@ -0,0 +1,110 @@ +package observability + +import ( + "context" + "log/slog" + "net" + "net/http" + "runtime/debug" + "time" + + "github.com/google/uuid" +) + +type contextKey string + +const requestIDContextKey contextKey = "request_id" + +func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + h = middlewares[i](h) + } + return h +} + +func RequestID() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestID := r.Header.Get("X-Request-Id") + if requestID == "" { + requestID = uuid.NewString() + } + w.Header().Set("X-Request-Id", requestID) + ctx := context.WithValue(r.Context(), requestIDContextKey, requestID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func Recover(log *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if recovered := recover(); recovered != nil { + log.Error("panic recovered", + slog.Any("panic", recovered), + slog.String("request_id", RequestIDFromContext(r.Context())), + slog.String("stack", string(debug.Stack())), + ) + http.Error(w, "internal server error", http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) + } +} + +func AccessLog(log *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK} + started := time.Now() + next.ServeHTTP(recorder, r) + + log.Info("http request", + slog.String("request_id", RequestIDFromContext(r.Context())), + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", recorder.status), + slog.Duration("duration", time.Since(started)), + slog.String("remote_addr", stripPort(r.RemoteAddr)), + ) + }) + } +} + +func Timeout(timeout time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + if timeout <= 0 { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func RequestIDFromContext(ctx context.Context) string { + value, _ := ctx.Value(requestIDContextKey).(string) + return value +} + +type statusRecorder struct { + http.ResponseWriter + status int +} + +func (s *statusRecorder) WriteHeader(statusCode int) { + s.status = statusCode + s.ResponseWriter.WriteHeader(statusCode) +} + +func stripPort(remote string) string { + host, _, err := net.SplitHostPort(remote) + if err != nil { + return remote + } + return host +} diff --git a/internal/observability/http_test.go b/internal/observability/http_test.go new file mode 100644 index 0000000..45b524c --- /dev/null +++ b/internal/observability/http_test.go @@ -0,0 +1,59 @@ +package observability + +import ( + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRequestIDSetsHeaderAndContext(t *testing.T) { + handler := RequestID()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := RequestIDFromContext(r.Context()); got == "" { + t.Fatal("RequestIDFromContext() = empty, want non-empty") + } + w.WriteHeader(http.StatusNoContent) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("X-Request-Id") == "" { + t.Fatal("X-Request-Id header = empty, want non-empty") + } +} + +func TestTimeoutAddsContextDeadline(t *testing.T) { + handler := Timeout(50 * time.Millisecond)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, ok := r.Context().Deadline(); !ok { + t.Fatal("context deadline missing") + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestRecoverHandlesPanic(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + handler := Recover(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("boom") + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } +} diff --git a/internal/observability/logger.go b/internal/observability/logger.go new file mode 100644 index 0000000..efa4bca --- /dev/null +++ b/internal/observability/logger.go @@ -0,0 +1,52 @@ +package observability + +import ( + "fmt" + "io" + "log/slog" + "os" + "strings" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +func NewLogger(cfg config.LoggingConfig) (*slog.Logger, error) { + level, err := parseLevel(cfg.Level) + if err != nil { + return nil, err + } + + options := &slog.HandlerOptions{Level: level} + handler, err := newHandler(cfg.Format, os.Stdout, options) + if err != nil { + return nil, err + } + + return slog.New(handler), nil +} + +func parseLevel(value string) (slog.Leveler, error) { + switch strings.ToLower(strings.TrimSpace(value)) { + case "", "info": + return slog.LevelInfo, nil + case "debug": + return slog.LevelDebug, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + default: + return nil, fmt.Errorf("invalid logging.level %q", value) + } +} + +func newHandler(format string, w io.Writer, opts *slog.HandlerOptions) (slog.Handler, error) { + switch strings.ToLower(strings.TrimSpace(format)) { + case "", "json": + return slog.NewJSONHandler(w, opts), nil + case "text": + return slog.NewTextHandler(w, opts), nil + default: + return nil, fmt.Errorf("invalid logging.format %q", format) + } +} diff --git a/internal/session/active_project.go b/internal/session/active_project.go new file mode 100644 index 0000000..bfab87d --- /dev/null +++ b/internal/session/active_project.go @@ -0,0 +1,37 @@ +package session + +import ( + "sync" + + "github.com/google/uuid" +) + +type ActiveProjects struct { + mu sync.RWMutex + bySession map[string]uuid.UUID +} + +func NewActiveProjects() *ActiveProjects { + return &ActiveProjects{ + bySession: map[string]uuid.UUID{}, + } +} + +func (a *ActiveProjects) Set(sessionID string, projectID uuid.UUID) { + a.mu.Lock() + defer a.mu.Unlock() + a.bySession[sessionID] = projectID +} + +func (a *ActiveProjects) Get(sessionID string) (uuid.UUID, bool) { + a.mu.RLock() + defer a.mu.RUnlock() + projectID, ok := a.bySession[sessionID] + return projectID, ok +} + +func (a *ActiveProjects) Clear(sessionID string) { + a.mu.Lock() + defer a.mu.Unlock() + delete(a.bySession, sessionID) +} diff --git a/internal/session/active_project_test.go b/internal/session/active_project_test.go new file mode 100644 index 0000000..e9b50f9 --- /dev/null +++ b/internal/session/active_project_test.go @@ -0,0 +1,27 @@ +package session + +import ( + "testing" + + "github.com/google/uuid" +) + +func TestActiveProjectsSetGetClear(t *testing.T) { + store := NewActiveProjects() + projectID := uuid.New() + + if _, ok := store.Get("session-1"); ok { + t.Fatal("Get() before Set() = true, want false") + } + + store.Set("session-1", projectID) + got, ok := store.Get("session-1") + if !ok || got != projectID { + t.Fatalf("Get() = (%v, %v), want (%v, true)", got, ok, projectID) + } + + store.Clear("session-1") + if _, ok := store.Get("session-1"); ok { + t.Fatal("Get() after Clear() = true, want false") + } +} diff --git a/internal/store/db.go b/internal/store/db.go new file mode 100644 index 0000000..8235d53 --- /dev/null +++ b/internal/store/db.go @@ -0,0 +1,118 @@ +package store + +import ( + "context" + "fmt" + "regexp" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + pgxvec "github.com/pgvector/pgvector-go/pgx" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +type DB struct { + pool *pgxpool.Pool +} + +func New(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) { + poolConfig, err := pgxpool.ParseConfig(cfg.URL) + if err != nil { + return nil, fmt.Errorf("parse database config: %w", err) + } + + poolConfig.MaxConns = cfg.MaxConns + poolConfig.MinConns = cfg.MinConns + poolConfig.MaxConnLifetime = cfg.MaxConnLifetime + poolConfig.MaxConnIdleTime = cfg.MaxConnIdleTime + poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + return pgxvec.RegisterTypes(ctx, conn) + } + + pool, err := pgxpool.NewWithConfig(ctx, poolConfig) + if err != nil { + return nil, fmt.Errorf("create database pool: %w", err) + } + + db := &DB{pool: pool} + if err := db.Ping(ctx); err != nil { + pool.Close() + return nil, err + } + + return db, nil +} + +func (db *DB) Close() { + if db == nil || db.pool == nil { + return + } + + db.pool.Close() +} + +func (db *DB) Ping(ctx context.Context) error { + if err := db.pool.Ping(ctx); err != nil { + return fmt.Errorf("ping database: %w", err) + } + + return nil +} + +func (db *DB) Ready(ctx context.Context) error { + readyCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + + return db.Ping(readyCtx) +} + +func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error { + var hasVector bool + if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil { + return fmt.Errorf("verify vector extension: %w", err) + } + if !hasVector { + return fmt.Errorf("vector extension is not installed") + } + + var hasMatchThoughts bool + if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_proc where proname = 'match_thoughts')`).Scan(&hasMatchThoughts); err != nil { + return fmt.Errorf("verify match_thoughts function: %w", err) + } + if !hasMatchThoughts { + return fmt.Errorf("match_thoughts function is missing") + } + + var embeddingType string + err := db.pool.QueryRow(ctx, ` + select format_type(a.atttypid, a.atttypmod) + from pg_attribute a + join pg_class c on c.oid = a.attrelid + join pg_namespace n on n.oid = c.relnamespace + where n.nspname = 'public' + and c.relname = 'thoughts' + and a.attname = 'embedding' + and not a.attisdropped + `).Scan(&embeddingType) + if err != nil { + return fmt.Errorf("verify thoughts.embedding type: %w", err) + } + + re := regexp.MustCompile(`vector\((\d+)\)`) + matches := re.FindStringSubmatch(embeddingType) + if len(matches) != 2 { + return fmt.Errorf("unexpected embedding type %q", embeddingType) + } + + var actualDimensions int + if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil { + return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err) + } + if actualDimensions != dimensions { + return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions) + } + + return nil +} diff --git a/internal/store/links.go b/internal/store/links.go new file mode 100644 index 0000000..6749084 --- /dev/null +++ b/internal/store/links.go @@ -0,0 +1,69 @@ +package store + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +func (db *DB) InsertLink(ctx context.Context, link thoughttypes.ThoughtLink) error { + _, err := db.pool.Exec(ctx, ` + insert into thought_links (from_id, to_id, relation) + values ($1, $2, $3) + `, link.FromID, link.ToID, link.Relation) + if err != nil { + return fmt.Errorf("insert link: %w", err) + } + return nil +} + +func (db *DB) LinkedThoughts(ctx context.Context, thoughtID uuid.UUID) ([]thoughttypes.LinkedThought, error) { + rows, err := db.pool.Query(ctx, ` + select t.id, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at, l.relation, 'outgoing' as direction, l.created_at + from thought_links l + join thoughts t on t.id = l.to_id + where l.from_id = $1 + union all + select t.id, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at, l.relation, 'incoming' as direction, l.created_at + from thought_links l + join thoughts t on t.id = l.from_id + where l.to_id = $1 + order by created_at desc + `, thoughtID) + if err != nil { + return nil, fmt.Errorf("query linked thoughts: %w", err) + } + defer rows.Close() + + links := make([]thoughttypes.LinkedThought, 0) + for rows.Next() { + var linked thoughttypes.LinkedThought + var metadataBytes []byte + if err := rows.Scan( + &linked.Thought.ID, + &linked.Thought.Content, + &metadataBytes, + &linked.Thought.ProjectID, + &linked.Thought.ArchivedAt, + &linked.Thought.CreatedAt, + &linked.Thought.UpdatedAt, + &linked.Relation, + &linked.Direction, + &linked.CreatedAt, + ); err != nil { + return nil, fmt.Errorf("scan linked thought: %w", err) + } + if err := json.Unmarshal(metadataBytes, &linked.Thought.Metadata); err != nil { + return nil, fmt.Errorf("decode linked thought metadata: %w", err) + } + links = append(links, linked) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate linked thoughts: %w", err) + } + return links, nil +} diff --git a/internal/store/projects.go b/internal/store/projects.go new file mode 100644 index 0000000..651c156 --- /dev/null +++ b/internal/store/projects.go @@ -0,0 +1,90 @@ +package store + +import ( + "context" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +func (db *DB) CreateProject(ctx context.Context, name, description string) (thoughttypes.Project, error) { + row := db.pool.QueryRow(ctx, ` + insert into projects (name, description) + values ($1, $2) + returning id, name, description, created_at, last_active_at + `, name, description) + + var project thoughttypes.Project + if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil { + return thoughttypes.Project{}, fmt.Errorf("create project: %w", err) + } + return project, nil +} + +func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) { + var row pgx.Row + if parsedID, err := uuid.Parse(strings.TrimSpace(nameOrID)); err == nil { + row = db.pool.QueryRow(ctx, ` + select id, name, description, created_at, last_active_at + from projects + where id = $1 + `, parsedID) + } else { + row = db.pool.QueryRow(ctx, ` + select id, name, description, created_at, last_active_at + from projects + where name = $1 + `, strings.TrimSpace(nameOrID)) + } + + var project thoughttypes.Project + if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil { + if err == pgx.ErrNoRows { + return thoughttypes.Project{}, err + } + return thoughttypes.Project{}, fmt.Errorf("get project: %w", err) + } + return project, nil +} + +func (db *DB) ListProjects(ctx context.Context) ([]thoughttypes.ProjectSummary, error) { + rows, err := db.pool.Query(ctx, ` + select p.id, p.name, p.description, p.created_at, p.last_active_at, count(t.id) as thought_count + from projects p + left join thoughts t on t.project_id = p.id and t.archived_at is null + group by p.id + order by p.last_active_at desc, p.created_at desc + `) + if err != nil { + return nil, fmt.Errorf("list projects: %w", err) + } + defer rows.Close() + + projects := make([]thoughttypes.ProjectSummary, 0) + for rows.Next() { + var project thoughttypes.ProjectSummary + if err := rows.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt, &project.ThoughtCount); err != nil { + return nil, fmt.Errorf("scan project summary: %w", err) + } + projects = append(projects, project) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate projects: %w", err) + } + return projects, nil +} + +func (db *DB) TouchProject(ctx context.Context, id uuid.UUID) error { + tag, err := db.pool.Exec(ctx, `update projects set last_active_at = now() where id = $1`, id) + if err != nil { + return fmt.Errorf("touch project: %w", err) + } + if tag.RowsAffected() == 0 { + return pgx.ErrNoRows + } + return nil +} diff --git a/internal/store/thoughts.go b/internal/store/thoughts.go new file mode 100644 index 0000000..36d54cb --- /dev/null +++ b/internal/store/thoughts.go @@ -0,0 +1,345 @@ +package store + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/pgvector/pgvector-go" + + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) { + metadata, err := json.Marshal(thought.Metadata) + if err != nil { + return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err) + } + + row := db.pool.QueryRow(ctx, ` + insert into thoughts (content, embedding, metadata, project_id) + values ($1, $2, $3::jsonb, $4) + returning id, created_at, updated_at + `, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID) + + created := thought + if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil { + return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err) + } + + return created, nil +} + +func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) { + filterJSON, err := json.Marshal(filter) + if err != nil { + return nil, fmt.Errorf("marshal search filter: %w", err) + } + + rows, err := db.pool.Query(ctx, ` + select id, content, metadata, similarity, created_at + from match_thoughts($1, $2, $3, $4::jsonb) + `, pgvector.NewVector(embedding), threshold, limit, filterJSON) + if err != nil { + return nil, fmt.Errorf("search thoughts: %w", err) + } + defer rows.Close() + + results := make([]thoughttypes.SearchResult, 0, limit) + for rows.Next() { + var result thoughttypes.SearchResult + var metadataBytes []byte + if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil { + return nil, fmt.Errorf("scan search result: %w", err) + } + if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil { + return nil, fmt.Errorf("decode search metadata: %w", err) + } + results = append(results, result) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate search results: %w", err) + } + + return results, nil +} + +func (db *DB) ListThoughts(ctx context.Context, filter thoughttypes.ListFilter) ([]thoughttypes.Thought, error) { + args := make([]any, 0, 6) + conditions := []string{} + + if !filter.IncludeArchived { + conditions = append(conditions, "archived_at is null") + } + if value := strings.TrimSpace(filter.Type); value != "" { + args = append(args, value) + conditions = append(conditions, fmt.Sprintf("metadata->>'type' = $%d", len(args))) + } + if value := strings.TrimSpace(filter.Topic); value != "" { + args = append(args, value) + conditions = append(conditions, fmt.Sprintf("metadata->'topics' ? $%d", len(args))) + } + if value := strings.TrimSpace(filter.Person); value != "" { + args = append(args, value) + conditions = append(conditions, fmt.Sprintf("metadata->'people' ? $%d", len(args))) + } + if filter.Days > 0 { + args = append(args, time.Now().Add(-time.Duration(filter.Days)*24*time.Hour)) + conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args))) + } + if filter.ProjectID != nil { + args = append(args, *filter.ProjectID) + conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args))) + } + + query := ` + select id, content, metadata, project_id, archived_at, created_at, updated_at + from thoughts + ` + if len(conditions) > 0 { + query += " where " + strings.Join(conditions, " and ") + } + + args = append(args, filter.Limit) + query += fmt.Sprintf(" order by created_at desc limit $%d", len(args)) + + rows, err := db.pool.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("list thoughts: %w", err) + } + defer rows.Close() + + thoughts := make([]thoughttypes.Thought, 0, filter.Limit) + for rows.Next() { + var thought thoughttypes.Thought + var metadataBytes []byte + if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan listed thought: %w", err) + } + if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil { + return nil, fmt.Errorf("decode listed metadata: %w", err) + } + thoughts = append(thoughts, thought) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate listed thoughts: %w", err) + } + + return thoughts, nil +} + +func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) { + var total int + if err := db.pool.QueryRow(ctx, `select count(*) from thoughts where archived_at is null`).Scan(&total); err != nil { + return thoughttypes.ThoughtStats{}, fmt.Errorf("count thoughts: %w", err) + } + + rows, err := db.pool.Query(ctx, `select metadata from thoughts where archived_at is null`) + if err != nil { + return thoughttypes.ThoughtStats{}, fmt.Errorf("query stats metadata: %w", err) + } + defer rows.Close() + + stats := thoughttypes.ThoughtStats{ + TotalCount: total, + TypeCounts: map[string]int{}, + } + topics := map[string]int{} + people := map[string]int{} + + for rows.Next() { + var metadataBytes []byte + if err := rows.Scan(&metadataBytes); err != nil { + return thoughttypes.ThoughtStats{}, fmt.Errorf("scan stats metadata: %w", err) + } + + var metadata thoughttypes.ThoughtMetadata + if err := json.Unmarshal(metadataBytes, &metadata); err != nil { + return thoughttypes.ThoughtStats{}, fmt.Errorf("decode stats metadata: %w", err) + } + + stats.TypeCounts[metadata.Type]++ + for _, topic := range metadata.Topics { + topics[topic]++ + } + for _, person := range metadata.People { + people[person]++ + } + } + + if err := rows.Err(); err != nil { + return thoughttypes.ThoughtStats{}, fmt.Errorf("iterate stats metadata: %w", err) + } + + stats.TopTopics = topCounts(topics, 10) + stats.TopPeople = topCounts(people, 10) + return stats, nil +} + +func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) { + row := db.pool.QueryRow(ctx, ` + select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at + from thoughts + where id = $1 + `, id) + + var thought thoughttypes.Thought + var embedding pgvector.Vector + var metadataBytes []byte + if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil { + if err == pgx.ErrNoRows { + return thoughttypes.Thought{}, err + } + return thoughttypes.Thought{}, fmt.Errorf("get thought: %w", err) + } + + if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil { + return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err) + } + thought.Embedding = embedding.Slice() + + return thought, nil +} + +func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) { + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err) + } + + tag, err := db.pool.Exec(ctx, ` + update thoughts + set content = $2, + embedding = $3, + metadata = $4::jsonb, + project_id = $5, + updated_at = now() + where id = $1 + `, id, content, pgvector.NewVector(embedding), metadataBytes, projectID) + if err != nil { + return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err) + } + if tag.RowsAffected() == 0 { + return thoughttypes.Thought{}, pgx.ErrNoRows + } + + return db.GetThought(ctx, id) +} + +func (db *DB) DeleteThought(ctx context.Context, id uuid.UUID) error { + tag, err := db.pool.Exec(ctx, `delete from thoughts where id = $1`, id) + if err != nil { + return fmt.Errorf("delete thought: %w", err) + } + if tag.RowsAffected() == 0 { + return pgx.ErrNoRows + } + return nil +} + +func (db *DB) ArchiveThought(ctx context.Context, id uuid.UUID) error { + tag, err := db.pool.Exec(ctx, `update thoughts set archived_at = now(), updated_at = now() where id = $1`, id) + if err != nil { + return fmt.Errorf("archive thought: %w", err) + } + if tag.RowsAffected() == 0 { + return pgx.ErrNoRows + } + return nil +} + +func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit int, days int) ([]thoughttypes.Thought, error) { + filter := thoughttypes.ListFilter{ + Limit: limit, + ProjectID: projectID, + Days: days, + IncludeArchived: false, + } + return db.ListThoughts(ctx, filter) +} + +func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) { + args := []any{pgvector.NewVector(embedding), threshold} + conditions := []string{ + "archived_at is null", + "1 - (embedding <=> $1) > $2", + } + if projectID != nil { + args = append(args, *projectID) + conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args))) + } + if excludeID != nil { + args = append(args, *excludeID) + conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args))) + } + args = append(args, limit) + + query := ` + select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at + from thoughts + where ` + strings.Join(conditions, " and ") + fmt.Sprintf(` + order by embedding <=> $1 + limit $%d`, len(args)) + + rows, err := db.pool.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("search similar thoughts: %w", err) + } + defer rows.Close() + + results := make([]thoughttypes.SearchResult, 0, limit) + for rows.Next() { + var result thoughttypes.SearchResult + var metadataBytes []byte + if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil { + return nil, fmt.Errorf("scan similar thought: %w", err) + } + if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil { + return nil, fmt.Errorf("decode similar thought metadata: %w", err) + } + results = append(results, result) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate similar thoughts: %w", err) + } + return results, nil +} + +func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount { + type kv struct { + key string + count int + } + + pairs := make([]kv, 0, len(in)) + for key, count := range in { + if strings.TrimSpace(key) == "" { + continue + } + pairs = append(pairs, kv{key: key, count: count}) + } + + sort.Slice(pairs, func(i, j int) bool { + if pairs[i].count == pairs[j].count { + return pairs[i].key < pairs[j].key + } + return pairs[i].count > pairs[j].count + }) + + if limit > 0 && len(pairs) > limit { + pairs = pairs[:limit] + } + + out := make([]thoughttypes.KeyCount, 0, len(pairs)) + for _, pair := range pairs { + out = append(out, thoughttypes.KeyCount{Key: pair.key, Count: pair.count}) + } + return out +} diff --git a/internal/tools/archive.go b/internal/tools/archive.go new file mode 100644 index 0000000..5b94b74 --- /dev/null +++ b/internal/tools/archive.go @@ -0,0 +1,36 @@ +package tools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/store" +) + +type ArchiveTool struct { + store *store.DB +} + +type ArchiveInput struct { + ID string `json:"id" jsonschema:"the thought id"` +} + +type ArchiveOutput struct { + Archived bool `json:"archived"` +} + +func NewArchiveTool(db *store.DB) *ArchiveTool { + return &ArchiveTool{store: db} +} + +func (t *ArchiveTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in ArchiveInput) (*mcp.CallToolResult, ArchiveOutput, error) { + id, err := parseUUID(in.ID) + if err != nil { + return nil, ArchiveOutput{}, err + } + if err := t.store.ArchiveThought(ctx, id); err != nil { + return nil, ArchiveOutput{}, err + } + return nil, ArchiveOutput{Archived: true}, nil +} diff --git a/internal/tools/capture.go b/internal/tools/capture.go new file mode 100644 index 0000000..9c1b8fe --- /dev/null +++ b/internal/tools/capture.go @@ -0,0 +1,95 @@ +package tools + +import ( + "context" + "log/slog" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "golang.org/x/sync/errgroup" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/metadata" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type CaptureTool struct { + store *store.DB + provider ai.Provider + capture config.CaptureConfig + sessions *session.ActiveProjects + log *slog.Logger +} + +type CaptureInput struct { + Content string `json:"content" jsonschema:"the thought or note to capture"` + Project string `json:"project,omitempty" jsonschema:"optional project name or id to associate with the thought"` +} + +type CaptureOutput struct { + Thought thoughttypes.Thought `json:"thought"` +} + +func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, log *slog.Logger) *CaptureTool { + return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, log: log} +} + +func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) { + content := strings.TrimSpace(in.Content) + if content == "" { + return nil, CaptureOutput{}, errInvalidInput("content is required") + } + + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) + if err != nil { + return nil, CaptureOutput{}, err + } + + var embedding []float32 + rawMetadata := metadata.Fallback(t.capture) + + group, groupCtx := errgroup.WithContext(ctx) + group.Go(func() error { + vector, err := t.provider.Embed(groupCtx, content) + if err != nil { + return err + } + embedding = vector + return nil + }) + group.Go(func() error { + extracted, err := t.provider.ExtractMetadata(groupCtx, content) + if err != nil { + t.log.Warn("metadata extraction failed, using fallback", slog.String("provider", t.provider.Name()), slog.String("error", err.Error())) + return nil + } + rawMetadata = extracted + return nil + }) + + if err := group.Wait(); err != nil { + return nil, CaptureOutput{}, err + } + + thought := thoughttypes.Thought{ + Content: content, + Embedding: embedding, + Metadata: metadata.Normalize(rawMetadata, t.capture), + } + if project != nil { + thought.ProjectID = &project.ID + } + + created, err := t.store.InsertThought(ctx, thought) + if err != nil { + return nil, CaptureOutput{}, err + } + if project != nil { + _ = t.store.TouchProject(ctx, project.ID) + } + + return nil, CaptureOutput{Thought: created}, nil +} diff --git a/internal/tools/common.go b/internal/tools/common.go new file mode 100644 index 0000000..c2567b2 --- /dev/null +++ b/internal/tools/common.go @@ -0,0 +1,31 @@ +package tools + +import ( + "fmt" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +func normalizeLimit(limit int, cfg config.SearchConfig) int { + if limit <= 0 { + return cfg.DefaultLimit + } + if limit > cfg.MaxLimit { + return cfg.MaxLimit + } + return limit +} + +func normalizeThreshold(value float64, fallback float64) float64 { + if value <= 0 { + return fallback + } + if value > 1 { + return 1 + } + return value +} + +func errInvalidInput(message string) error { + return fmt.Errorf("invalid input: %s", message) +} diff --git a/internal/tools/context.go b/internal/tools/context.go new file mode 100644 index 0000000..3ca9e94 --- /dev/null +++ b/internal/tools/context.go @@ -0,0 +1,111 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type ContextTool struct { + store *store.DB + provider ai.Provider + search config.SearchConfig + sessions *session.ActiveProjects +} + +type ProjectContextInput struct { + Project string `json:"project,omitempty" jsonschema:"project name or id; falls back to the active session project"` + Query string `json:"query,omitempty" jsonschema:"optional semantic focus for project context"` + Limit int `json:"limit,omitempty" jsonschema:"maximum number of context items to return"` +} + +type ContextItem struct { + ID string `json:"id"` + Content string `json:"content"` + Metadata thoughttypes.ThoughtMetadata `json:"metadata"` + Similarity float64 `json:"similarity,omitempty"` + Source string `json:"source"` +} + +type ProjectContextOutput struct { + Project thoughttypes.Project `json:"project"` + Context string `json:"context"` + Items []ContextItem `json:"items"` +} + +func NewContextTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool { + return &ContextTool{store: db, provider: provider, search: search, sessions: sessions} +} + +func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) { + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, true) + if err != nil { + return nil, ProjectContextOutput{}, err + } + + limit := normalizeLimit(in.Limit, t.search) + recent, err := t.store.RecentThoughts(ctx, &project.ID, limit, 0) + if err != nil { + return nil, ProjectContextOutput{}, err + } + + items := make([]ContextItem, 0, limit*2) + seen := map[string]struct{}{} + for _, thought := range recent { + key := thought.ID.String() + seen[key] = struct{}{} + items = append(items, ContextItem{ + ID: key, + Content: thought.Content, + Metadata: thought.Metadata, + Source: "recent", + }) + } + + query := strings.TrimSpace(in.Query) + if query != "" { + embedding, err := t.provider.Embed(ctx, query) + if err != nil { + return nil, ProjectContextOutput{}, err + } + semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, &project.ID, nil) + if err != nil { + return nil, ProjectContextOutput{}, err + } + for _, result := range semantic { + key := result.ID.String() + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + items = append(items, ContextItem{ + ID: key, + Content: result.Content, + Metadata: result.Metadata, + Similarity: result.Similarity, + Source: "semantic", + }) + } + } + + lines := make([]string, 0, len(items)) + for i, item := range items { + lines = append(lines, thoughtContextLine(i, item.Content, item.Metadata, item.Similarity)) + } + contextBlock := formatContextBlock(fmt.Sprintf("Project context for %s", project.Name), lines) + _ = t.store.TouchProject(ctx, project.ID) + + return nil, ProjectContextOutput{ + Project: *project, + Context: contextBlock, + Items: items, + }, nil +} diff --git a/internal/tools/delete.go b/internal/tools/delete.go new file mode 100644 index 0000000..18c6f4d --- /dev/null +++ b/internal/tools/delete.go @@ -0,0 +1,36 @@ +package tools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/store" +) + +type DeleteTool struct { + store *store.DB +} + +type DeleteInput struct { + ID string `json:"id" jsonschema:"the thought id"` +} + +type DeleteOutput struct { + Deleted bool `json:"deleted"` +} + +func NewDeleteTool(db *store.DB) *DeleteTool { + return &DeleteTool{store: db} +} + +func (t *DeleteTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in DeleteInput) (*mcp.CallToolResult, DeleteOutput, error) { + id, err := parseUUID(in.ID) + if err != nil { + return nil, DeleteOutput{}, err + } + if err := t.store.DeleteThought(ctx, id); err != nil { + return nil, DeleteOutput{}, err + } + return nil, DeleteOutput{Deleted: true}, nil +} diff --git a/internal/tools/get.go b/internal/tools/get.go new file mode 100644 index 0000000..b38f710 --- /dev/null +++ b/internal/tools/get.go @@ -0,0 +1,40 @@ +package tools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type GetTool struct { + store *store.DB +} + +type GetInput struct { + ID string `json:"id" jsonschema:"the thought id"` +} + +type GetOutput struct { + Thought thoughttypes.Thought `json:"thought"` +} + +func NewGetTool(db *store.DB) *GetTool { + return &GetTool{store: db} +} + +func (t *GetTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in GetInput) (*mcp.CallToolResult, GetOutput, error) { + id, err := parseUUID(in.ID) + if err != nil { + return nil, GetOutput{}, err + } + + thought, err := t.store.GetThought(ctx, id) + if err != nil { + return nil, GetOutput{}, err + } + + return nil, GetOutput{Thought: thought}, nil +} diff --git a/internal/tools/helpers.go b/internal/tools/helpers.go new file mode 100644 index 0000000..306147a --- /dev/null +++ b/internal/tools/helpers.go @@ -0,0 +1,87 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +func parseUUID(id string) (uuid.UUID, error) { + parsed, err := uuid.Parse(strings.TrimSpace(id)) + if err != nil { + return uuid.Nil, fmt.Errorf("invalid id %q: %w", id, err) + } + return parsed, nil +} + +func sessionID(req *mcp.CallToolRequest) (string, error) { + if req == nil || req.Session == nil || req.Session.ID() == "" { + return "", fmt.Errorf("tool requires an MCP session") + } + return req.Session.ID(), nil +} + +func resolveProject(ctx context.Context, db *store.DB, sessions *session.ActiveProjects, req *mcp.CallToolRequest, raw string, required bool) (*thoughttypes.Project, error) { + projectRef := strings.TrimSpace(raw) + if projectRef == "" && sessions != nil && req != nil && req.Session != nil { + if activeID, ok := sessions.Get(req.Session.ID()); ok { + project, err := db.GetProject(ctx, activeID.String()) + if err == nil { + return &project, nil + } + if err != pgx.ErrNoRows { + return nil, err + } + } + } + + if projectRef == "" { + if required { + return nil, fmt.Errorf("project is required") + } + return nil, nil + } + + project, err := db.GetProject(ctx, projectRef) + if err != nil { + if err == pgx.ErrNoRows { + return nil, fmt.Errorf("project %q not found", projectRef) + } + return nil, err + } + + return &project, nil +} + +func formatContextBlock(header string, lines []string) string { + if len(lines) == 0 { + return header + "\n\nNo matching thoughts." + } + return header + "\n\n" + strings.Join(lines, "\n\n") +} + +func thoughtContextLine(index int, content string, metadata thoughttypes.ThoughtMetadata, similarity float64) string { + label := fmt.Sprintf("%d. %s", index+1, strings.TrimSpace(content)) + parts := make([]string, 0, 3) + if len(metadata.Topics) > 0 { + parts = append(parts, "topics="+strings.Join(metadata.Topics, ", ")) + } + if metadata.Type != "" { + parts = append(parts, "type="+metadata.Type) + } + if similarity > 0 { + parts = append(parts, fmt.Sprintf("similarity=%.3f", similarity)) + } + if len(parts) == 0 { + return label + } + return label + "\n[" + strings.Join(parts, " | ") + "]" +} diff --git a/internal/tools/links.go b/internal/tools/links.go new file mode 100644 index 0000000..1940ea3 --- /dev/null +++ b/internal/tools/links.go @@ -0,0 +1,145 @@ +package tools + +import ( + "context" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type LinksTool struct { + store *store.DB + provider ai.Provider + search config.SearchConfig +} + +type LinkInput struct { + FromID string `json:"from_id" jsonschema:"the source thought id"` + ToID string `json:"to_id" jsonschema:"the target thought id"` + Relation string `json:"relation" jsonschema:"relationship label such as follows_up or references"` +} + +type LinkOutput struct { + Linked bool `json:"linked"` +} + +type RelatedInput struct { + ID string `json:"id" jsonschema:"the thought id"` + IncludeSemantic *bool `json:"include_semantic,omitempty" jsonschema:"whether to include semantic neighbors; defaults to true"` +} + +type RelatedThought struct { + ID string `json:"id"` + Content string `json:"content"` + Metadata thoughttypes.ThoughtMetadata `json:"metadata"` + Relation string `json:"relation,omitempty"` + Direction string `json:"direction,omitempty"` + Similarity float64 `json:"similarity,omitempty"` + Source string `json:"source"` +} + +type RelatedOutput struct { + Related []RelatedThought `json:"related"` +} + +func NewLinksTool(db *store.DB, provider ai.Provider, search config.SearchConfig) *LinksTool { + return &LinksTool{store: db, provider: provider, search: search} +} + +func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) { + fromID, err := parseUUID(in.FromID) + if err != nil { + return nil, LinkOutput{}, err + } + toID, err := parseUUID(in.ToID) + if err != nil { + return nil, LinkOutput{}, err + } + relation := strings.TrimSpace(in.Relation) + if relation == "" { + return nil, LinkOutput{}, errInvalidInput("relation is required") + } + if _, err := t.store.GetThought(ctx, fromID); err != nil { + return nil, LinkOutput{}, err + } + if _, err := t.store.GetThought(ctx, toID); err != nil { + return nil, LinkOutput{}, err + } + if err := t.store.InsertLink(ctx, thoughttypes.ThoughtLink{ + FromID: fromID, + ToID: toID, + Relation: relation, + }); err != nil { + return nil, LinkOutput{}, err + } + return nil, LinkOutput{Linked: true}, nil +} + +func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in RelatedInput) (*mcp.CallToolResult, RelatedOutput, error) { + id, err := parseUUID(in.ID) + if err != nil { + return nil, RelatedOutput{}, err + } + + thought, err := t.store.GetThought(ctx, id) + if err != nil { + return nil, RelatedOutput{}, err + } + + linked, err := t.store.LinkedThoughts(ctx, id) + if err != nil { + return nil, RelatedOutput{}, err + } + + related := make([]RelatedThought, 0, len(linked)+t.search.DefaultLimit) + seen := map[string]struct{}{thought.ID.String(): {}} + for _, item := range linked { + key := item.Thought.ID.String() + seen[key] = struct{}{} + related = append(related, RelatedThought{ + ID: key, + Content: item.Thought.Content, + Metadata: item.Thought.Metadata, + Relation: item.Relation, + Direction: item.Direction, + Source: "link", + }) + } + + includeSemantic := true + if in.IncludeSemantic != nil { + includeSemantic = *in.IncludeSemantic + } + + if includeSemantic { + embedding, err := t.provider.Embed(ctx, thought.Content) + if err != nil { + return nil, RelatedOutput{}, err + } + semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID) + if err != nil { + return nil, RelatedOutput{}, err + } + for _, item := range semantic { + key := item.ID.String() + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + related = append(related, RelatedThought{ + ID: key, + Content: item.Content, + Metadata: item.Metadata, + Similarity: item.Similarity, + Source: "semantic", + }) + } + } + + return nil, RelatedOutput{Related: related}, nil +} diff --git a/internal/tools/list.go b/internal/tools/list.go new file mode 100644 index 0000000..e8bc890 --- /dev/null +++ b/internal/tools/list.go @@ -0,0 +1,68 @@ +package tools + +import ( + "context" + "strings" + + "github.com/google/uuid" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type ListTool struct { + store *store.DB + search config.SearchConfig + sessions *session.ActiveProjects +} + +type ListInput struct { + Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to return"` + Type string `json:"type,omitempty" jsonschema:"filter by thought type"` + Topic string `json:"topic,omitempty" jsonschema:"filter by topic"` + Person string `json:"person,omitempty" jsonschema:"filter by mentioned person"` + Days int `json:"days,omitempty" jsonschema:"only include thoughts from the last N days"` + IncludeArchived bool `json:"include_archived,omitempty" jsonschema:"include archived thoughts when true"` + Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the listing"` +} + +type ListOutput struct { + Thoughts []thoughttypes.Thought `json:"thoughts"` +} + +func NewListTool(db *store.DB, search config.SearchConfig, sessions *session.ActiveProjects) *ListTool { + return &ListTool{store: db, search: search, sessions: sessions} +} + +func (t *ListTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ListInput) (*mcp.CallToolResult, ListOutput, error) { + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) + if err != nil { + return nil, ListOutput{}, err + } + + var projectID *uuid.UUID + if project != nil { + projectID = &project.ID + } + + thoughts, err := t.store.ListThoughts(ctx, thoughttypes.ListFilter{ + Limit: normalizeLimit(in.Limit, t.search), + Type: strings.TrimSpace(in.Type), + Topic: strings.TrimSpace(in.Topic), + Person: strings.TrimSpace(in.Person), + Days: in.Days, + IncludeArchived: in.IncludeArchived, + ProjectID: projectID, + }) + if err != nil { + return nil, ListOutput{}, err + } + if project != nil { + _ = t.store.TouchProject(ctx, project.ID) + } + + return nil, ListOutput{Thoughts: thoughts}, nil +} diff --git a/internal/tools/projects.go b/internal/tools/projects.go new file mode 100644 index 0000000..de6f462 --- /dev/null +++ b/internal/tools/projects.go @@ -0,0 +1,92 @@ +package tools + +import ( + "context" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type ProjectsTool struct { + store *store.DB + sessions *session.ActiveProjects +} + +type CreateProjectInput struct { + Name string `json:"name" jsonschema:"the unique project name"` + Description string `json:"description,omitempty" jsonschema:"optional project description"` +} + +type CreateProjectOutput struct { + Project thoughttypes.Project `json:"project"` +} + +type ListProjectsInput struct{} + +type ListProjectsOutput struct { + Projects []thoughttypes.ProjectSummary `json:"projects"` +} + +type SetActiveProjectInput struct { + Project string `json:"project" jsonschema:"project name or id"` +} + +type SetActiveProjectOutput struct { + Project thoughttypes.Project `json:"project"` +} + +type GetActiveProjectInput struct{} + +type GetActiveProjectOutput struct { + Project *thoughttypes.Project `json:"project,omitempty"` +} + +func NewProjectsTool(db *store.DB, sessions *session.ActiveProjects) *ProjectsTool { + return &ProjectsTool{store: db, sessions: sessions} +} + +func (t *ProjectsTool) Create(ctx context.Context, _ *mcp.CallToolRequest, in CreateProjectInput) (*mcp.CallToolResult, CreateProjectOutput, error) { + name := strings.TrimSpace(in.Name) + if name == "" { + return nil, CreateProjectOutput{}, errInvalidInput("name is required") + } + project, err := t.store.CreateProject(ctx, name, strings.TrimSpace(in.Description)) + if err != nil { + return nil, CreateProjectOutput{}, err + } + return nil, CreateProjectOutput{Project: project}, nil +} + +func (t *ProjectsTool) List(ctx context.Context, _ *mcp.CallToolRequest, _ ListProjectsInput) (*mcp.CallToolResult, ListProjectsOutput, error) { + projects, err := t.store.ListProjects(ctx) + if err != nil { + return nil, ListProjectsOutput{}, err + } + return nil, ListProjectsOutput{Projects: projects}, nil +} + +func (t *ProjectsTool) SetActive(ctx context.Context, req *mcp.CallToolRequest, in SetActiveProjectInput) (*mcp.CallToolResult, SetActiveProjectOutput, error) { + sid, err := sessionID(req) + if err != nil { + return nil, SetActiveProjectOutput{}, err + } + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, true) + if err != nil { + return nil, SetActiveProjectOutput{}, err + } + t.sessions.Set(sid, project.ID) + _ = t.store.TouchProject(ctx, project.ID) + return nil, SetActiveProjectOutput{Project: *project}, nil +} + +func (t *ProjectsTool) GetActive(ctx context.Context, req *mcp.CallToolRequest, _ GetActiveProjectInput) (*mcp.CallToolResult, GetActiveProjectOutput, error) { + project, err := resolveProject(ctx, t.store, t.sessions, req, "", false) + if err != nil { + return nil, GetActiveProjectOutput{}, err + } + return nil, GetActiveProjectOutput{Project: project}, nil +} diff --git a/internal/tools/recall.go b/internal/tools/recall.go new file mode 100644 index 0000000..a7cddc0 --- /dev/null +++ b/internal/tools/recall.go @@ -0,0 +1,111 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" +) + +type RecallTool struct { + store *store.DB + provider ai.Provider + search config.SearchConfig + sessions *session.ActiveProjects +} + +type RecallInput struct { + Query string `json:"query" jsonschema:"semantic query for recalled context"` + Project string `json:"project,omitempty" jsonschema:"optional project name or id; falls back to the active session project"` + Limit int `json:"limit,omitempty" jsonschema:"maximum number of context items to return"` +} + +type RecallOutput struct { + Context string `json:"context"` + Items []ContextItem `json:"items"` +} + +func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool { + return &RecallTool{store: db, provider: provider, search: search, sessions: sessions} +} + +func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) { + query := strings.TrimSpace(in.Query) + if query == "" { + return nil, RecallOutput{}, errInvalidInput("query is required") + } + + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) + if err != nil { + return nil, RecallOutput{}, err + } + + limit := normalizeLimit(in.Limit, t.search) + embedding, err := t.provider.Embed(ctx, query) + if err != nil { + return nil, RecallOutput{}, err + } + + var projectID *uuid.UUID + if project != nil { + projectID = &project.ID + } + + semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil) + if err != nil { + return nil, RecallOutput{}, err + } + recent, err := t.store.RecentThoughts(ctx, projectID, limit, 0) + if err != nil { + return nil, RecallOutput{}, err + } + + items := make([]ContextItem, 0, limit*2) + seen := map[string]struct{}{} + for _, result := range semantic { + key := result.ID.String() + seen[key] = struct{}{} + items = append(items, ContextItem{ + ID: key, + Content: result.Content, + Metadata: result.Metadata, + Similarity: result.Similarity, + Source: "semantic", + }) + } + for _, thought := range recent { + key := thought.ID.String() + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + items = append(items, ContextItem{ + ID: key, + Content: thought.Content, + Metadata: thought.Metadata, + Source: "recent", + }) + } + + lines := make([]string, 0, len(items)) + for i, item := range items { + lines = append(lines, thoughtContextLine(i, item.Content, item.Metadata, item.Similarity)) + } + header := "Recalled context" + if project != nil { + header = fmt.Sprintf("Recalled context for %s", project.Name) + _ = t.store.TouchProject(ctx, project.ID) + } + + return nil, RecallOutput{ + Context: formatContextBlock(header, lines), + Items: items, + }, nil +} diff --git a/internal/tools/search.go b/internal/tools/search.go new file mode 100644 index 0000000..56aabb4 --- /dev/null +++ b/internal/tools/search.go @@ -0,0 +1,69 @@ +package tools + +import ( + "context" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type SearchTool struct { + store *store.DB + provider ai.Provider + search config.SearchConfig + sessions *session.ActiveProjects +} + +type SearchInput struct { + Query string `json:"query" jsonschema:"the semantic query to search for"` + Limit int `json:"limit,omitempty" jsonschema:"maximum number of results to return"` + Threshold float64 `json:"threshold,omitempty" jsonschema:"minimum similarity threshold between 0 and 1"` + Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the search"` +} + +type SearchOutput struct { + Results []thoughttypes.SearchResult `json:"results"` +} + +func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool { + return &SearchTool{store: db, provider: provider, search: search, sessions: sessions} +} + +func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) { + query := strings.TrimSpace(in.Query) + if query == "" { + return nil, SearchOutput{}, errInvalidInput("query is required") + } + + limit := normalizeLimit(in.Limit, t.search) + threshold := normalizeThreshold(in.Threshold, t.search.DefaultThreshold) + + embedding, err := t.provider.Embed(ctx, query) + if err != nil { + return nil, SearchOutput{}, err + } + + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) + if err != nil { + return nil, SearchOutput{}, err + } + + var results []thoughttypes.SearchResult + if project != nil { + results, err = t.store.SearchSimilarThoughts(ctx, embedding, threshold, limit, &project.ID, nil) + _ = t.store.TouchProject(ctx, project.ID) + } else { + results, err = t.store.SearchThoughts(ctx, embedding, threshold, limit, map[string]any{}) + } + if err != nil { + return nil, SearchOutput{}, err + } + + return nil, SearchOutput{Results: results}, nil +} diff --git a/internal/tools/stats.go b/internal/tools/stats.go new file mode 100644 index 0000000..73ea30f --- /dev/null +++ b/internal/tools/stats.go @@ -0,0 +1,33 @@ +package tools + +import ( + "context" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type StatsTool struct { + store *store.DB +} + +type StatsInput struct{} + +type StatsOutput struct { + Stats thoughttypes.ThoughtStats `json:"stats"` +} + +func NewStatsTool(db *store.DB) *StatsTool { + return &StatsTool{store: db} +} + +func (t *StatsTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, _ StatsInput) (*mcp.CallToolResult, StatsOutput, error) { + stats, err := t.store.Stats(ctx) + if err != nil { + return nil, StatsOutput{}, err + } + + return nil, StatsOutput{Stats: stats}, nil +} diff --git a/internal/tools/summarize.go b/internal/tools/summarize.go new file mode 100644 index 0000000..11a7912 --- /dev/null +++ b/internal/tools/summarize.go @@ -0,0 +1,93 @@ +package tools + +import ( + "context" + "strings" + + "github.com/google/uuid" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/session" + "git.warky.dev/wdevs/amcs/internal/store" +) + +type SummarizeTool struct { + store *store.DB + provider ai.Provider + search config.SearchConfig + sessions *session.ActiveProjects +} + +type SummarizeInput struct { + Query string `json:"query,omitempty" jsonschema:"optional semantic focus for the summary"` + Project string `json:"project,omitempty" jsonschema:"optional project name or id; falls back to the active session project"` + Days int `json:"days,omitempty" jsonschema:"only include thoughts from the last N days when query is omitted"` + Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to summarize"` +} + +type SummarizeOutput struct { + Summary string `json:"summary"` + Count int `json:"count"` +} + +func NewSummarizeTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool { + return &SummarizeTool{store: db, provider: provider, search: search, sessions: sessions} +} + +func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) { + project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false) + if err != nil { + return nil, SummarizeOutput{}, err + } + + limit := normalizeLimit(in.Limit, t.search) + query := strings.TrimSpace(in.Query) + lines := make([]string, 0, limit) + count := 0 + + if query != "" { + embedding, err := t.provider.Embed(ctx, query) + if err != nil { + return nil, SummarizeOutput{}, err + } + var projectID *uuid.UUID + if project != nil { + projectID = &project.ID + } + results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil) + if err != nil { + return nil, SummarizeOutput{}, err + } + for i, result := range results { + lines = append(lines, thoughtContextLine(i, result.Content, result.Metadata, result.Similarity)) + } + count = len(results) + } else { + var projectID *uuid.UUID + if project != nil { + projectID = &project.ID + } + thoughts, err := t.store.RecentThoughts(ctx, projectID, limit, in.Days) + if err != nil { + return nil, SummarizeOutput{}, err + } + for i, thought := range thoughts { + lines = append(lines, thoughtContextLine(i, thought.Content, thought.Metadata, 0)) + } + count = len(thoughts) + } + + userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines) + systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose." + summary, err := t.provider.Summarize(ctx, systemPrompt, userPrompt) + if err != nil { + return nil, SummarizeOutput{}, err + } + if project != nil { + _ = t.store.TouchProject(ctx, project.ID) + } + + return nil, SummarizeOutput{Summary: summary, Count: count}, nil +} diff --git a/internal/tools/update.go b/internal/tools/update.go new file mode 100644 index 0000000..65bf75f --- /dev/null +++ b/internal/tools/update.go @@ -0,0 +1,88 @@ +package tools + +import ( + "context" + "log/slog" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "git.warky.dev/wdevs/amcs/internal/ai" + "git.warky.dev/wdevs/amcs/internal/config" + "git.warky.dev/wdevs/amcs/internal/metadata" + "git.warky.dev/wdevs/amcs/internal/store" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +type UpdateTool struct { + store *store.DB + provider ai.Provider + capture config.CaptureConfig + log *slog.Logger +} + +type UpdateInput struct { + ID string `json:"id" jsonschema:"the thought id"` + Content *string `json:"content,omitempty" jsonschema:"replacement content for the thought"` + Metadata thoughttypes.ThoughtMetadata `json:"metadata,omitempty" jsonschema:"metadata fields to merge into the thought"` + Project string `json:"project,omitempty" jsonschema:"optional project name or id to move the thought into"` +} + +type UpdateOutput struct { + Thought thoughttypes.Thought `json:"thought"` +} + +func NewUpdateTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, log *slog.Logger) *UpdateTool { + return &UpdateTool{store: db, provider: provider, capture: capture, log: log} +} + +func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) { + id, err := parseUUID(in.ID) + if err != nil { + return nil, UpdateOutput{}, err + } + + current, err := t.store.GetThought(ctx, id) + if err != nil { + return nil, UpdateOutput{}, err + } + + content := current.Content + embedding := current.Embedding + mergedMetadata := current.Metadata + projectID := current.ProjectID + + if in.Content != nil { + content = strings.TrimSpace(*in.Content) + if content == "" { + return nil, UpdateOutput{}, errInvalidInput("content must not be empty") + } + embedding, err = t.provider.Embed(ctx, content) + if err != nil { + return nil, UpdateOutput{}, err + } + extracted, extractErr := t.provider.ExtractMetadata(ctx, content) + if extractErr != nil { + t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error())) + } else { + mergedMetadata = metadata.Normalize(extracted, t.capture) + } + } + + mergedMetadata = metadata.Merge(mergedMetadata, in.Metadata, t.capture) + + if rawProject := strings.TrimSpace(in.Project); rawProject != "" { + project, err := t.store.GetProject(ctx, rawProject) + if err != nil { + return nil, UpdateOutput{}, err + } + projectID = &project.ID + } + + updated, err := t.store.UpdateThought(ctx, id, content, embedding, mergedMetadata, projectID) + if err != nil { + return nil, UpdateOutput{}, err + } + + return nil, UpdateOutput{Thought: updated}, nil +} diff --git a/internal/types/project.go b/internal/types/project.go new file mode 100644 index 0000000..633c03c --- /dev/null +++ b/internal/types/project.go @@ -0,0 +1,40 @@ +package types + +import ( + "time" + + "github.com/google/uuid" +) + +type ThoughtPatch struct { + Content *string `json:"content,omitempty"` + Metadata ThoughtMetadata `json:"metadata,omitempty"` + Project *uuid.UUID `json:"project_id,omitempty"` +} + +type Project struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastActiveAt time.Time `json:"last_active_at"` +} + +type ProjectSummary struct { + Project + ThoughtCount int `json:"thought_count"` +} + +type ThoughtLink struct { + FromID uuid.UUID `json:"from_id"` + ToID uuid.UUID `json:"to_id"` + Relation string `json:"relation"` + CreatedAt time.Time `json:"created_at"` +} + +type LinkedThought struct { + Thought Thought `json:"thought"` + Relation string `json:"relation"` + Direction string `json:"direction"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/internal/types/thought.go b/internal/types/thought.go new file mode 100644 index 0000000..e02ac1b --- /dev/null +++ b/internal/types/thought.go @@ -0,0 +1,57 @@ +package types + +import ( + "time" + + "github.com/google/uuid" +) + +type ThoughtMetadata struct { + People []string `json:"people"` + ActionItems []string `json:"action_items"` + DatesMentioned []string `json:"dates_mentioned"` + Topics []string `json:"topics"` + Type string `json:"type"` + Source string `json:"source"` +} + +type Thought struct { + ID uuid.UUID `json:"id"` + Content string `json:"content"` + Embedding []float32 `json:"embedding,omitempty"` + Metadata ThoughtMetadata `json:"metadata"` + ProjectID *uuid.UUID `json:"project_id,omitempty"` + ArchivedAt *time.Time `json:"archived_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type SearchResult struct { + ID uuid.UUID `json:"id"` + Content string `json:"content"` + Metadata ThoughtMetadata `json:"metadata"` + Similarity float64 `json:"similarity"` + CreatedAt time.Time `json:"created_at"` +} + +type ListFilter struct { + Limit int + Type string + Topic string + Person string + Days int + ProjectID *uuid.UUID + IncludeArchived bool +} + +type ThoughtStats struct { + TotalCount int `json:"total_count"` + TypeCounts map[string]int `json:"type_counts"` + TopTopics []KeyCount `json:"top_topics"` + TopPeople []KeyCount `json:"top_people"` +} + +type KeyCount struct { + Key string `json:"key"` + Count int `json:"count"` +} diff --git a/migrations/001_enable_vector.sql b/migrations/001_enable_vector.sql new file mode 100644 index 0000000..c92e838 --- /dev/null +++ b/migrations/001_enable_vector.sql @@ -0,0 +1,2 @@ +create extension if not exists vector; +create extension if not exists pgcrypto; diff --git a/migrations/002_create_thoughts.sql b/migrations/002_create_thoughts.sql new file mode 100644 index 0000000..11ff04c --- /dev/null +++ b/migrations/002_create_thoughts.sql @@ -0,0 +1,17 @@ +create table if not exists thoughts ( + id uuid default gen_random_uuid() primary key, + content text not null, + embedding vector(1536), + metadata jsonb default '{}'::jsonb, + created_at timestamptz default now(), + updated_at timestamptz default now() +); + +create index if not exists thoughts_embedding_hnsw_idx + on thoughts using hnsw (embedding vector_cosine_ops); + +create index if not exists thoughts_metadata_gin_idx + on thoughts using gin (metadata); + +create index if not exists thoughts_created_at_idx + on thoughts (created_at desc); diff --git a/migrations/003_add_projects.sql b/migrations/003_add_projects.sql new file mode 100644 index 0000000..909a1d9 --- /dev/null +++ b/migrations/003_add_projects.sql @@ -0,0 +1,13 @@ +create table if not exists projects ( + id uuid default gen_random_uuid() primary key, + name text not null unique, + description text, + created_at timestamptz default now(), + last_active_at timestamptz default now() +); + +alter table thoughts add column if not exists project_id uuid references projects(id); +alter table thoughts add column if not exists archived_at timestamptz; + +create index if not exists thoughts_project_id_idx on thoughts (project_id); +create index if not exists thoughts_archived_at_idx on thoughts (archived_at); diff --git a/migrations/004_create_thought_links.sql b/migrations/004_create_thought_links.sql new file mode 100644 index 0000000..ff6d0f3 --- /dev/null +++ b/migrations/004_create_thought_links.sql @@ -0,0 +1,10 @@ +create table if not exists thought_links ( + from_id uuid references thoughts(id) on delete cascade, + to_id uuid references thoughts(id) on delete cascade, + relation text not null, + created_at timestamptz default now(), + primary key (from_id, to_id, relation) +); + +create index if not exists thought_links_from_idx on thought_links (from_id); +create index if not exists thought_links_to_idx on thought_links (to_id); diff --git a/migrations/005_create_match_thoughts.sql b/migrations/005_create_match_thoughts.sql new file mode 100644 index 0000000..49bc2d6 --- /dev/null +++ b/migrations/005_create_match_thoughts.sql @@ -0,0 +1,31 @@ +create or replace function match_thoughts( + query_embedding vector(1536), + match_threshold float default 0.7, + match_count int default 10, + filter jsonb default '{}'::jsonb +) +returns table ( + id uuid, + content text, + metadata jsonb, + similarity float, + created_at timestamptz +) +language plpgsql +as $$ +begin + return query + select + t.id, + t.content, + t.metadata, + 1 - (t.embedding <=> query_embedding) as similarity, + t.created_at + from thoughts t + where 1 - (t.embedding <=> query_embedding) > match_threshold + and t.archived_at is null + and (filter = '{}'::jsonb or t.metadata @> filter) + order by t.embedding <=> query_embedding + limit match_count; +end; +$$; diff --git a/migrations/006_rls_and_grants.sql b/migrations/006_rls_and_grants.sql new file mode 100644 index 0000000..c94fdb5 --- /dev/null +++ b/migrations/006_rls_and_grants.sql @@ -0,0 +1,5 @@ +-- Grant these permissions to the database role used by the application. +-- Replace amcs_user with the actual role in your deployment before applying. +grant select, insert, update, delete on table public.thoughts to amcs_user; +grant select, insert, update, delete on table public.projects to amcs_user; +grant select, insert, update, delete on table public.thought_links to amcs_user; diff --git a/scripts/migrate.sh b/scripts/migrate.sh new file mode 100755 index 0000000..b4b7771 --- /dev/null +++ b/scripts/migrate.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +set -euo pipefail + +DATABASE_URL="${DATABASE_URL:-${OB1_DATABASE_URL:-}}" + +if [[ -z "${DATABASE_URL}" ]]; then + echo "DATABASE_URL or OB1_DATABASE_URL must be set" >&2 + exit 1 +fi + +for migration in migrations/*.sql; do + echo "Applying ${migration}" + psql "${DATABASE_URL}" -v ON_ERROR_STOP=1 -f "${migration}" +done diff --git a/scripts/run-local.sh b/scripts/run-local.sh new file mode 100755 index 0000000..449ed46 --- /dev/null +++ b/scripts/run-local.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +set -euo pipefail + +go run ./cmd/amcs-server --config "${1:-configs/dev.yaml}"