feat(tools): implement CRUD operations for thoughts and projects
* Add tools for creating, retrieving, updating, and deleting thoughts. * Implement project management tools for creating and listing projects. * Introduce linking functionality between thoughts. * Add search and recall capabilities for thoughts based on semantic queries. * Implement statistics and summarization tools for thought analysis. * Create database migrations for thoughts, projects, and links. * Add helper functions for UUID parsing and project resolution.
This commit is contained in:
8
.dockerignore
Normal file
8
.dockerignore
Normal file
@@ -0,0 +1,8 @@
|
||||
.git
|
||||
.gitignore
|
||||
.vscode
|
||||
bin
|
||||
.env
|
||||
assets
|
||||
llm
|
||||
*.local.yaml
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -25,3 +25,7 @@ go.work.sum
|
||||
# env file
|
||||
.env
|
||||
|
||||
# local config
|
||||
configs/*.local.yaml
|
||||
cmd/amcs-server/__debug_*
|
||||
bin/
|
||||
13
.vscode/launch.json
vendored
Normal file
13
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Debug amcs-server",
|
||||
"type": "go",
|
||||
"request": "launch",
|
||||
"mode": "auto",
|
||||
"program": "${workspaceFolder}/cmd/amcs-server",
|
||||
"cwd": "${workspaceFolder}/bin"
|
||||
}
|
||||
]
|
||||
}
|
||||
20
.vscode/tasks.json
vendored
Normal file
20
.vscode/tasks.json
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "build",
|
||||
"type": "shell",
|
||||
"command": "make build",
|
||||
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
},
|
||||
"presentation": {
|
||||
"reveal": "silent",
|
||||
"panel": "shared"
|
||||
},
|
||||
"problemMatcher": "$go"
|
||||
}
|
||||
]
|
||||
}
|
||||
30
Dockerfile
Normal file
30
Dockerfile
Normal file
@@ -0,0 +1,30 @@
|
||||
FROM golang:1.26.1-bookworm AS builder
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /out/amcs-server ./cmd/amcs-server
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& useradd --system --create-home --uid 10001 appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /out/amcs-server /app/amcs-server
|
||||
COPY --chown=appuser:appuser configs /app/configs
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
ENV OB1_CONFIG=/app/configs/docker.yaml
|
||||
|
||||
ENTRYPOINT ["/app/amcs-server"]
|
||||
17
Makefile
Normal file
17
Makefile
Normal file
@@ -0,0 +1,17 @@
|
||||
BIN_DIR := bin
|
||||
SERVER_BIN := $(BIN_DIR)/amcs-server
|
||||
CMD_SERVER := ./cmd/amcs-server
|
||||
|
||||
.PHONY: all build clean migrate
|
||||
|
||||
all: build
|
||||
|
||||
build:
|
||||
@mkdir -p $(BIN_DIR)
|
||||
go build -o $(SERVER_BIN) $(CMD_SERVER)
|
||||
|
||||
migrate:
|
||||
./scripts/migrate.sh
|
||||
|
||||
clean:
|
||||
rm -rf $(BIN_DIR)
|
||||
26
README.md
26
README.md
@@ -50,3 +50,29 @@ Config is YAML-driven. Copy `configs/config.example.yaml` and set:
|
||||
- `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy
|
||||
|
||||
See `llm/plan.md` for full architecture and implementation plan.
|
||||
|
||||
## Development
|
||||
|
||||
Run the SQL migrations against a local database with:
|
||||
|
||||
`DATABASE_URL=postgres://... make migrate`
|
||||
|
||||
## Containers
|
||||
|
||||
The repo now includes a `Dockerfile` and Compose files for running the app with Postgres + pgvector.
|
||||
|
||||
1. Set a real LiteLLM key in your shell:
|
||||
`export OB1_LITELLM_API_KEY=your-key`
|
||||
2. Start the stack with your runtime:
|
||||
`docker compose -f docker-compose.yml -f docker-compose.docker.yml up --build`
|
||||
`podman compose -f docker-compose.yml up --build`
|
||||
3. Call the service on `http://localhost:8080`
|
||||
|
||||
Notes:
|
||||
|
||||
- The app uses `configs/docker.yaml` inside the container.
|
||||
- `OB1_LITELLM_BASE_URL` overrides the LiteLLM endpoint, so you can retarget it without editing YAML.
|
||||
- The base Compose file uses `host.containers.internal`, which is Podman-friendly.
|
||||
- The Docker override file adds `host-gateway` aliases so Docker can resolve the same host endpoint.
|
||||
- Database migrations `001` through `005` run automatically when the Postgres volume is created for the first time.
|
||||
- `migrations/006_rls_and_grants.sql` is intentionally skipped during container bootstrap because it contains deployment-specific grants for a role named `amcs_user`.
|
||||
|
||||
24
cmd/amcs-server/main.go
Normal file
24
cmd/amcs-server/main.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/app"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var configPath string
|
||||
flag.StringVar(&configPath, "config", "", "Path to the YAML config file")
|
||||
flag.Parse()
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
if err := app.Run(ctx, configPath); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
72
configs/config.example.yaml
Normal file
72
configs/config.example.yaml
Normal file
@@ -0,0 +1,72 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
read_timeout: "15s"
|
||||
write_timeout: "30s"
|
||||
idle_timeout: "60s"
|
||||
allowed_origins:
|
||||
- "*"
|
||||
|
||||
mcp:
|
||||
path: "/mcp"
|
||||
server_name: "amcs"
|
||||
version: "0.1.0"
|
||||
transport: "streamable_http"
|
||||
|
||||
auth:
|
||||
mode: "api_keys"
|
||||
header_name: "x-brain-key"
|
||||
query_param: "key"
|
||||
allow_query_param: false
|
||||
keys:
|
||||
- id: "local-client"
|
||||
value: "replace-me"
|
||||
description: "main local client key"
|
||||
|
||||
database:
|
||||
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
||||
max_conns: 10
|
||||
min_conns: 2
|
||||
max_conn_lifetime: "30m"
|
||||
max_conn_idle_time: "10m"
|
||||
|
||||
ai:
|
||||
provider: "litellm"
|
||||
embeddings:
|
||||
model: "openai/text-embedding-3-small"
|
||||
dimensions: 1536
|
||||
metadata:
|
||||
model: "gpt-4o-mini"
|
||||
temperature: 0.1
|
||||
litellm:
|
||||
base_url: "http://localhost:4000/v1"
|
||||
api_key: "replace-me"
|
||||
use_responses_api: false
|
||||
request_headers: {}
|
||||
embedding_model: "openrouter/openai/text-embedding-3-small"
|
||||
metadata_model: "gpt-4o-mini"
|
||||
openrouter:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
api_key: ""
|
||||
app_name: "amcs"
|
||||
site_url: ""
|
||||
extra_headers: {}
|
||||
|
||||
capture:
|
||||
source: "mcp"
|
||||
metadata_defaults:
|
||||
type: "observation"
|
||||
topic_fallback: "uncategorized"
|
||||
|
||||
search:
|
||||
default_limit: 10
|
||||
default_threshold: 0.5
|
||||
max_limit: 50
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
format: "json"
|
||||
|
||||
observability:
|
||||
metrics_enabled: true
|
||||
pprof_enabled: false
|
||||
72
configs/dev.yaml
Normal file
72
configs/dev.yaml
Normal file
@@ -0,0 +1,72 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
read_timeout: "15s"
|
||||
write_timeout: "30s"
|
||||
idle_timeout: "60s"
|
||||
allowed_origins:
|
||||
- "*"
|
||||
|
||||
mcp:
|
||||
path: "/mcp"
|
||||
server_name: "amcs"
|
||||
version: "0.1.0"
|
||||
transport: "streamable_http"
|
||||
|
||||
auth:
|
||||
mode: "api_keys"
|
||||
header_name: "x-brain-key"
|
||||
query_param: "key"
|
||||
allow_query_param: false
|
||||
keys:
|
||||
- id: "local-client"
|
||||
value: "replace-me"
|
||||
description: "main local client key"
|
||||
|
||||
database:
|
||||
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
||||
max_conns: 10
|
||||
min_conns: 2
|
||||
max_conn_lifetime: "30m"
|
||||
max_conn_idle_time: "10m"
|
||||
|
||||
ai:
|
||||
provider: "litellm"
|
||||
embeddings:
|
||||
model: "openai/text-embedding-3-small"
|
||||
dimensions: 1536
|
||||
metadata:
|
||||
model: "gpt-4o-mini"
|
||||
temperature: 0.1
|
||||
litellm:
|
||||
base_url: "http://localhost:4000/v1"
|
||||
api_key: "replace-me"
|
||||
use_responses_api: false
|
||||
request_headers: {}
|
||||
embedding_model: "openrouter/openai/text-embedding-3-small"
|
||||
metadata_model: "gpt-4o-mini"
|
||||
openrouter:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
api_key: ""
|
||||
app_name: "amcs"
|
||||
site_url: ""
|
||||
extra_headers: {}
|
||||
|
||||
capture:
|
||||
source: "mcp"
|
||||
metadata_defaults:
|
||||
type: "observation"
|
||||
topic_fallback: "uncategorized"
|
||||
|
||||
search:
|
||||
default_limit: 10
|
||||
default_threshold: 0.5
|
||||
max_limit: 50
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
format: "json"
|
||||
|
||||
observability:
|
||||
metrics_enabled: true
|
||||
pprof_enabled: false
|
||||
72
configs/docker.yaml
Normal file
72
configs/docker.yaml
Normal file
@@ -0,0 +1,72 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
read_timeout: "15s"
|
||||
write_timeout: "30s"
|
||||
idle_timeout: "60s"
|
||||
allowed_origins:
|
||||
- "*"
|
||||
|
||||
mcp:
|
||||
path: "/mcp"
|
||||
server_name: "amcs"
|
||||
version: "0.1.0"
|
||||
transport: "streamable_http"
|
||||
|
||||
auth:
|
||||
mode: "api_keys"
|
||||
header_name: "x-brain-key"
|
||||
query_param: "key"
|
||||
allow_query_param: false
|
||||
keys:
|
||||
- id: "local-client"
|
||||
value: "replace-me"
|
||||
description: "main local client key"
|
||||
|
||||
database:
|
||||
url: "postgres://postgres:postgres@db:5432/amcs?sslmode=disable"
|
||||
max_conns: 10
|
||||
min_conns: 2
|
||||
max_conn_lifetime: "30m"
|
||||
max_conn_idle_time: "10m"
|
||||
|
||||
ai:
|
||||
provider: "litellm"
|
||||
embeddings:
|
||||
model: "openai/text-embedding-3-small"
|
||||
dimensions: 1536
|
||||
metadata:
|
||||
model: "gpt-4o-mini"
|
||||
temperature: 0.1
|
||||
litellm:
|
||||
base_url: "http://host.containers.internal:4000/v1"
|
||||
api_key: "replace-me"
|
||||
use_responses_api: false
|
||||
request_headers: {}
|
||||
embedding_model: "openrouter/openai/text-embedding-3-small"
|
||||
metadata_model: "gpt-4o-mini"
|
||||
openrouter:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
api_key: ""
|
||||
app_name: "amcs"
|
||||
site_url: ""
|
||||
extra_headers: {}
|
||||
|
||||
capture:
|
||||
source: "mcp"
|
||||
metadata_defaults:
|
||||
type: "observation"
|
||||
topic_fallback: "uncategorized"
|
||||
|
||||
search:
|
||||
default_limit: 10
|
||||
default_threshold: 0.5
|
||||
max_limit: 50
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
format: "json"
|
||||
|
||||
observability:
|
||||
metrics_enabled: true
|
||||
pprof_enabled: false
|
||||
5
docker-compose.docker.yml
Normal file
5
docker-compose.docker.yml
Normal file
@@ -0,0 +1,5 @@
|
||||
services:
|
||||
app:
|
||||
extra_hosts:
|
||||
- "host.containers.internal:host-gateway"
|
||||
- "host.docker.internal:host-gateway"
|
||||
38
docker-compose.yml
Normal file
38
docker-compose.yml
Normal file
@@ -0,0 +1,38 @@
|
||||
services:
|
||||
db:
|
||||
image: pgvector/pgvector:pg16
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
POSTGRES_DB: amcs
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./docker/postgres/init:/docker-entrypoint-initdb.d:ro
|
||||
- ./migrations:/migrations:ro
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres -d amcs"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 20
|
||||
|
||||
app:
|
||||
build:
|
||||
context: .
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
OB1_CONFIG: /app/configs/docker.yaml
|
||||
OB1_DATABASE_URL: postgres://postgres:postgres@db:5432/amcs?sslmode=disable
|
||||
OB1_LITELLM_BASE_URL: ${OB1_LITELLM_BASE_URL:-http://host.containers.internal:4000/v1}
|
||||
OB1_LITELLM_API_KEY: ${OB1_LITELLM_API_KEY:-replace-me}
|
||||
OB1_SERVER_PORT: 8080
|
||||
ports:
|
||||
- "8080:8080"
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
15
docker/postgres/init/00_apply_migrations.sh
Executable file
15
docker/postgres/init/00_apply_migrations.sh
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/bin/sh
|
||||
|
||||
set -eu
|
||||
|
||||
for migration in /migrations/*.sql; do
|
||||
case "$migration" in
|
||||
*/006_rls_and_grants.sql)
|
||||
echo "Skipping $migration because it contains deployment-specific grants."
|
||||
;;
|
||||
*)
|
||||
echo "Applying $migration"
|
||||
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" -f "$migration"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
28
go.mod
Normal file
28
go.mod
Normal file
@@ -0,0 +1,28 @@
|
||||
module git.warky.dev/wdevs/amcs
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.9.1
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.1
|
||||
github.com/pgvector/pgvector-go v0.3.0
|
||||
golang.org/x/sync v0.17.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/segmentio/asm v1.1.3 // indirect
|
||||
github.com/segmentio/encoding v0.5.4 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
)
|
||||
99
go.sum
Normal file
99
go.sum
Normal file
@@ -0,0 +1,99 @@
|
||||
entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ=
|
||||
entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-pg/pg/v10 v10.11.0 h1:CMKJqLgTrfpE/aOVeLdybezR2om071Vh38OLZjsyMI0=
|
||||
github.com/go-pg/pg/v10 v10.11.0/go.mod h1:4BpHRoxE61y4Onpof3x1a2SQvi9c+q1dJnrNdMjsroA=
|
||||
github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU=
|
||||
github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc=
|
||||
github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
|
||||
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc=
|
||||
github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s=
|
||||
github.com/pgvector/pgvector-go v0.3.0 h1:Ij+Yt78R//uYqs3Zk35evZFvr+G0blW0OUN+Q2D1RWc=
|
||||
github.com/pgvector/pgvector-go v0.3.0/go.mod h1:duFy+PXWfW7QQd5ibqutBO4GxLsUZ9RVXhFZGIBsWSA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
|
||||
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
|
||||
github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0=
|
||||
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.1.12 h1:sOjDVHxNTuM6dNGaba0wUuz7KvDE1BmNu9Gqs2gJSXQ=
|
||||
github.com/uptrace/bun v1.1.12/go.mod h1:NPG6JGULBeQ9IU6yHp7YGELRa5Agmd7ATZdz4tGZ6z0=
|
||||
github.com/uptrace/bun/dialect/pgdialect v1.1.12 h1:m/CM1UfOkoBTglGO5CUTKnIKKOApOYxkcP2qn0F9tJk=
|
||||
github.com/uptrace/bun/dialect/pgdialect v1.1.12/go.mod h1:Ij6WIxQILxLlL2frUBxUBOZJtLElD2QQNDcu/PWDHTc=
|
||||
github.com/uptrace/bun/driver/pgdriver v1.1.12 h1:3rRWB1GK0psTJrHwxzNfEij2MLibggiLdTqjTtfHc1w=
|
||||
github.com/uptrace/bun/driver/pgdriver v1.1.12/go.mod h1:ssYUP+qwSEgeDDS1xm2XBip9el1y9Mi5mTAvLoiADLM=
|
||||
github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94=
|
||||
github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc=
|
||||
github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
mellium.im/sasl v0.3.1 h1:wE0LW6g7U83vhvxjC1IY8DnXM+EU095yeo8XClvCdfo=
|
||||
mellium.im/sasl v0.3.1/go.mod h1:xm59PUYpZHhgQ9ZqoJ5QaCqzWMi8IeS49dhp6plPCzw=
|
||||
330
internal/ai/compat/client.go
Normal file
330
internal/ai/compat/client.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
const metadataSystemPrompt = `You extract structured metadata from short notes.
|
||||
Return only valid JSON matching this schema:
|
||||
{
|
||||
"people": ["string"],
|
||||
"action_items": ["string"],
|
||||
"dates_mentioned": ["string"],
|
||||
"topics": ["string"],
|
||||
"type": "observation|task|idea|reference|person_note",
|
||||
"source": "string"
|
||||
}
|
||||
Rules:
|
||||
- Keep arrays concise.
|
||||
- Use lowercase for type.
|
||||
- If unsure, prefer "observation".
|
||||
- Do not include any text outside the JSON object.`
|
||||
|
||||
type Client struct {
|
||||
name string
|
||||
baseURL string
|
||||
apiKey string
|
||||
embeddingModel string
|
||||
metadataModel string
|
||||
temperature float64
|
||||
headers map[string]string
|
||||
httpClient *http.Client
|
||||
log *slog.Logger
|
||||
dimensions int
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
EmbeddingModel string
|
||||
MetadataModel string
|
||||
Temperature float64
|
||||
Headers map[string]string
|
||||
HTTPClient *http.Client
|
||||
Log *slog.Logger
|
||||
Dimensions int
|
||||
}
|
||||
|
||||
type embeddingsRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type embeddingsResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
Error *providerError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type chatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
ResponseFormat *responseType `json:"response_format,omitempty"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type responseType struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type chatCompletionsResponse struct {
|
||||
Choices []struct {
|
||||
Message chatMessage `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error *providerError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type providerError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
func New(cfg Config) *Client {
|
||||
return &Client{
|
||||
name: cfg.Name,
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
embeddingModel: cfg.EmbeddingModel,
|
||||
metadataModel: cfg.MetadataModel,
|
||||
temperature: cfg.Temperature,
|
||||
headers: cfg.Headers,
|
||||
httpClient: cfg.HTTPClient,
|
||||
log: cfg.Log,
|
||||
dimensions: cfg.Dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
|
||||
}
|
||||
|
||||
var resp embeddingsResponse
|
||||
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{
|
||||
Input: input,
|
||||
Model: c.embeddingModel,
|
||||
}, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("%s embed error: %s", c.name, resp.Error.Message)
|
||||
}
|
||||
if len(resp.Data) == 0 {
|
||||
return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
|
||||
}
|
||||
if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions {
|
||||
return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding))
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
||||
}
|
||||
|
||||
req := chatCompletionsRequest{
|
||||
Model: c.metadataModel,
|
||||
Temperature: c.temperature,
|
||||
ResponseFormat: &responseType{
|
||||
Type: "json_object",
|
||||
},
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: metadataSystemPrompt},
|
||||
{Role: "user", Content: input},
|
||||
},
|
||||
}
|
||||
|
||||
var resp chatCompletionsResponse
|
||||
if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata error: %s", c.name, resp.Error.Message)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: no choices returned", c.name)
|
||||
}
|
||||
|
||||
metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
|
||||
metadataText = stripCodeFence(metadataText)
|
||||
|
||||
var metadata thoughttypes.ThoughtMetadata
|
||||
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: parse json: %w", c.name, err)
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||
req := chatCompletionsRequest{
|
||||
Model: c.metadataModel,
|
||||
Temperature: 0.2,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
},
|
||||
}
|
||||
|
||||
var resp chatCompletionsResponse
|
||||
if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return "", fmt.Errorf("%s summarize error: %s", c.name, resp.Error.Message)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("%s summarize: no choices returned", c.name)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(resp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
func (c *Client) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
||||
body, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s request marshal: %w", c.name, err)
|
||||
}
|
||||
|
||||
const maxAttempts = 3
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.baseURL, "/")+path, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s build request: %w", c.name, err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for key, value := range c.headers {
|
||||
if strings.TrimSpace(key) == "" || strings.TrimSpace(value) == "" {
|
||||
continue
|
||||
}
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
|
||||
if attempt < maxAttempts && isRetryableError(err) {
|
||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||
return retryErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
payload, readErr := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if readErr != nil {
|
||||
lastErr = fmt.Errorf("%s read response: %w", c.name, readErr)
|
||||
if attempt < maxAttempts {
|
||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||
return retryErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
lastErr = fmt.Errorf("%s request failed with status %d: %s", c.name, resp.StatusCode, strings.TrimSpace(string(payload)))
|
||||
if attempt < maxAttempts && isRetryableStatus(resp.StatusCode) {
|
||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||
return retryErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(payload, dest); err != nil {
|
||||
if c.log != nil {
|
||||
c.log.Debug("provider response body", slog.String("provider", c.name), slog.String("body", string(payload)))
|
||||
}
|
||||
return fmt.Errorf("%s decode response: %w", c.name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func stripCodeFence(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if !strings.HasPrefix(value, "```") {
|
||||
return value
|
||||
}
|
||||
|
||||
value = strings.TrimPrefix(value, "```json")
|
||||
value = strings.TrimPrefix(value, "```")
|
||||
value = strings.TrimSuffix(value, "```")
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
func isRetryableStatus(status int) bool {
|
||||
switch status {
|
||||
case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr)
|
||||
}
|
||||
|
||||
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
|
||||
delay := time.Duration(attempt*attempt) * 200 * time.Millisecond
|
||||
if log != nil {
|
||||
log.Warn("retrying provider request", slog.String("provider", provider), slog.Duration("delay", delay), slog.Int("attempt", attempt+1))
|
||||
}
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
91
internal/ai/compat/client_test.go
Normal file
91
internal/ai/compat/client_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func discardLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
}
|
||||
|
||||
func TestEmbedRetriesTransientFailures(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if calls.Add(1) < 3 {
|
||||
http.Error(w, "temporary failure", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": []map[string]any{
|
||||
{"embedding": []float32{1, 2, 3}},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(Config{
|
||||
Name: "test",
|
||||
BaseURL: server.URL,
|
||||
APIKey: "secret",
|
||||
EmbeddingModel: "embed-model",
|
||||
MetadataModel: "meta-model",
|
||||
HTTPClient: server.Client(),
|
||||
Log: discardLogger(),
|
||||
Dimensions: 3,
|
||||
})
|
||||
|
||||
embedding, err := client.Embed(context.Background(), "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed() error = %v", err)
|
||||
}
|
||||
if len(embedding) != 3 {
|
||||
t.Fatalf("embedding len = %d, want 3", len(embedding))
|
||||
}
|
||||
if got := calls.Load(); got != 3 {
|
||||
t.Fatalf("call count = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMetadataParsesCodeFencedJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "```json\n{\"people\":[\"Alice\"],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"memory\"],\"type\":\"idea\",\"source\":\"mcp\"}\n```",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(Config{
|
||||
Name: "test",
|
||||
BaseURL: server.URL,
|
||||
APIKey: "secret",
|
||||
EmbeddingModel: "embed-model",
|
||||
MetadataModel: "meta-model",
|
||||
HTTPClient: server.Client(),
|
||||
Log: discardLogger(),
|
||||
})
|
||||
|
||||
metadata, err := client.ExtractMetadata(context.Background(), "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
||||
}
|
||||
if metadata.Type != "idea" {
|
||||
t.Fatalf("metadata type = %q, want idea", metadata.Type)
|
||||
}
|
||||
if len(metadata.People) != 1 || metadata.People[0] != "Alice" {
|
||||
t.Fatalf("metadata people = %#v, want [Alice]", metadata.People)
|
||||
}
|
||||
}
|
||||
22
internal/ai/factory.go
Normal file
22
internal/ai/factory.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/litellm"
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/openrouter"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func NewProvider(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (Provider, error) {
|
||||
switch cfg.Provider {
|
||||
case "litellm":
|
||||
return litellm.New(cfg, httpClient, log)
|
||||
case "openrouter":
|
||||
return openrouter.New(cfg, httpClient, log)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported ai.provider: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
24
internal/ai/litellm/client.go
Normal file
24
internal/ai/litellm/client.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package litellm
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||
return compat.New(compat.Config{
|
||||
Name: "litellm",
|
||||
BaseURL: cfg.LiteLLM.BaseURL,
|
||||
APIKey: cfg.LiteLLM.APIKey,
|
||||
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
|
||||
MetadataModel: cfg.LiteLLM.MetadataModel,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: cfg.LiteLLM.RequestHeaders,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
}), nil
|
||||
}
|
||||
35
internal/ai/openrouter/client.go
Normal file
35
internal/ai/openrouter/client.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package openrouter
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||
headers := make(map[string]string, len(cfg.OpenRouter.ExtraHeaders)+2)
|
||||
for key, value := range cfg.OpenRouter.ExtraHeaders {
|
||||
headers[key] = value
|
||||
}
|
||||
if cfg.OpenRouter.SiteURL != "" {
|
||||
headers["HTTP-Referer"] = cfg.OpenRouter.SiteURL
|
||||
}
|
||||
if cfg.OpenRouter.AppName != "" {
|
||||
headers["X-Title"] = cfg.OpenRouter.AppName
|
||||
}
|
||||
|
||||
return compat.New(compat.Config{
|
||||
Name: "openrouter",
|
||||
BaseURL: cfg.OpenRouter.BaseURL,
|
||||
APIKey: cfg.OpenRouter.APIKey,
|
||||
EmbeddingModel: cfg.Embeddings.Model,
|
||||
MetadataModel: cfg.Metadata.Model,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: headers,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
}), nil
|
||||
}
|
||||
14
internal/ai/provider.go
Normal file
14
internal/ai/provider.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
Embed(ctx context.Context, input string) ([]float32, error)
|
||||
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
|
||||
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
|
||||
Name() string
|
||||
}
|
||||
145
internal/app/app.go
Normal file
145
internal/app/app.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/auth"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/mcpserver"
|
||||
"git.warky.dev/wdevs/amcs/internal/observability"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
"git.warky.dev/wdevs/amcs/internal/tools"
|
||||
)
|
||||
|
||||
func Run(ctx context.Context, configPath string) error {
|
||||
cfg, loadedFrom, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger, err := observability.NewLogger(cfg.Logging)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("loaded configuration",
|
||||
slog.String("path", loadedFrom),
|
||||
slog.String("provider", cfg.AI.Provider),
|
||||
)
|
||||
|
||||
db, err := store.New(ctx, cfg.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := db.VerifyRequirements(ctx, cfg.AI.Embeddings.Dimensions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
provider, err := ai.NewProvider(cfg.AI, httpClient, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyring, err := auth.NewKeyring(cfg.Auth.Keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
activeProjects := session.NewActiveProjects()
|
||||
|
||||
logger.Info("database connection verified",
|
||||
slog.String("provider", provider.Name()),
|
||||
)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||
Handler: routes(logger, cfg, db, provider, keyring, activeProjects),
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
logger.Info("starting HTTP server",
|
||||
slog.String("addr", server.Addr),
|
||||
slog.String("mcp_path", cfg.MCP.Path),
|
||||
)
|
||||
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
logger.Info("shutting down HTTP server")
|
||||
return server.Shutdown(shutdownCtx)
|
||||
case err := <-errCh:
|
||||
return fmt.Errorf("run server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.Provider, keyring *auth.Keyring, activeProjects *session.ActiveProjects) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
toolSet := mcpserver.ToolSet{
|
||||
Capture: tools.NewCaptureTool(db, provider, cfg.Capture, activeProjects, logger),
|
||||
Search: tools.NewSearchTool(db, provider, cfg.Search, activeProjects),
|
||||
List: tools.NewListTool(db, cfg.Search, activeProjects),
|
||||
Stats: tools.NewStatsTool(db),
|
||||
Get: tools.NewGetTool(db),
|
||||
Update: tools.NewUpdateTool(db, provider, cfg.Capture, logger),
|
||||
Delete: tools.NewDeleteTool(db),
|
||||
Archive: tools.NewArchiveTool(db),
|
||||
Projects: tools.NewProjectsTool(db, activeProjects),
|
||||
Context: tools.NewContextTool(db, provider, cfg.Search, activeProjects),
|
||||
Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
|
||||
Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
|
||||
Links: tools.NewLinksTool(db, provider, cfg.Search),
|
||||
}
|
||||
|
||||
mcpHandler := mcpserver.New(cfg.MCP, toolSet)
|
||||
mux.Handle(cfg.MCP.Path, auth.Middleware(cfg.Auth, keyring, logger)(mcpHandler))
|
||||
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
mux.HandleFunc("/readyz", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := db.Ready(r.Context()); err != nil {
|
||||
logger.Error("readiness check failed", slog.String("error", err.Error()))
|
||||
http.Error(w, "not ready", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ready"))
|
||||
})
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("amcs is running"))
|
||||
})
|
||||
|
||||
return observability.Chain(
|
||||
mux,
|
||||
observability.RequestID(),
|
||||
observability.Recover(logger),
|
||||
observability.AccessLog(logger),
|
||||
observability.Timeout(cfg.Server.WriteTimeout),
|
||||
)
|
||||
}
|
||||
29
internal/auth/keyring.go
Normal file
29
internal/auth/keyring.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
type Keyring struct {
|
||||
keys []config.APIKey
|
||||
}
|
||||
|
||||
func NewKeyring(keys []config.APIKey) (*Keyring, error) {
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("keyring requires at least one key")
|
||||
}
|
||||
|
||||
return &Keyring{keys: append([]config.APIKey(nil), keys...)}, nil
|
||||
}
|
||||
|
||||
func (k *Keyring) Lookup(value string) (string, bool) {
|
||||
for _, key := range k.keys {
|
||||
if subtle.ConstantTimeCompare([]byte(key.Value), []byte(value)) == 1 {
|
||||
return key.ID, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
107
internal/auth/keyring_test.go
Normal file
107
internal/auth/keyring_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
}
|
||||
|
||||
func TestNewKeyringAndLookup(t *testing.T) {
|
||||
_, err := NewKeyring(nil)
|
||||
if err == nil {
|
||||
t.Fatal("NewKeyring(nil) error = nil, want error")
|
||||
}
|
||||
|
||||
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyring() error = %v", err)
|
||||
}
|
||||
|
||||
if got, ok := keyring.Lookup("secret"); !ok || got != "client-a" {
|
||||
t.Fatalf("Lookup(secret) = (%q, %v), want (client-a, true)", got, ok)
|
||||
}
|
||||
if _, ok := keyring.Lookup("wrong"); ok {
|
||||
t.Fatal("Lookup(wrong) = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareAllowsHeaderAuthAndSetsContext(t *testing.T) {
|
||||
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyring() error = %v", err)
|
||||
}
|
||||
|
||||
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
keyID, ok := KeyIDFromContext(r.Context())
|
||||
if !ok || keyID != "client-a" {
|
||||
t.Fatalf("KeyIDFromContext() = (%q, %v), want (client-a, true)", keyID, ok)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||
req.Header.Set("x-brain-key", "secret")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareAllowsQueryParamWhenEnabled(t *testing.T) {
|
||||
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyring() error = %v", err)
|
||||
}
|
||||
|
||||
handler := Middleware(config.AuthConfig{
|
||||
HeaderName: "x-brain-key",
|
||||
QueryParam: "key",
|
||||
AllowQueryParam: true,
|
||||
}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/mcp?key=secret", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) {
|
||||
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyring() error = %v", err)
|
||||
}
|
||||
|
||||
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("next handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("missing key status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||
req.Header.Set("x-brain-key", "wrong")
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("invalid key status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
49
internal/auth/middleware.go
Normal file
49
internal/auth/middleware.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const keyIDContextKey contextKey = "auth.key_id"
|
||||
|
||||
func Middleware(cfg config.AuthConfig, keyring *Keyring, log *slog.Logger) func(http.Handler) http.Handler {
|
||||
headerName := cfg.HeaderName
|
||||
if headerName == "" {
|
||||
headerName = "x-brain-key"
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := strings.TrimSpace(r.Header.Get(headerName))
|
||||
if token == "" && cfg.AllowQueryParam {
|
||||
token = strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam))
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
http.Error(w, "missing API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
keyID, ok := keyring.Lookup(token)
|
||||
if !ok {
|
||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), keyIDContextKey, keyID)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func KeyIDFromContext(ctx context.Context) (string, bool) {
|
||||
value, ok := ctx.Value(keyIDContextKey).(string)
|
||||
return value, ok
|
||||
}
|
||||
119
internal/config/config.go
Normal file
119
internal/config/config.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
DefaultConfigPath = "./configs/dev.yaml"
|
||||
DefaultSource = "mcp"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
MCP MCPConfig `yaml:"mcp"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
AI AIConfig `yaml:"ai"`
|
||||
Capture CaptureConfig `yaml:"capture"`
|
||||
Search SearchConfig `yaml:"search"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
Observability ObservabilityConfig `yaml:"observability"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout"`
|
||||
IdleTimeout time.Duration `yaml:"idle_timeout"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
}
|
||||
|
||||
type MCPConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
ServerName string `yaml:"server_name"`
|
||||
Version string `yaml:"version"`
|
||||
Transport string `yaml:"transport"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Mode string `yaml:"mode"`
|
||||
HeaderName string `yaml:"header_name"`
|
||||
QueryParam string `yaml:"query_param"`
|
||||
AllowQueryParam bool `yaml:"allow_query_param"`
|
||||
Keys []APIKey `yaml:"keys"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID string `yaml:"id"`
|
||||
Value string `yaml:"value"`
|
||||
Description string `yaml:"description"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
URL string `yaml:"url"`
|
||||
MaxConns int32 `yaml:"max_conns"`
|
||||
MinConns int32 `yaml:"min_conns"`
|
||||
MaxConnLifetime time.Duration `yaml:"max_conn_lifetime"`
|
||||
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time"`
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Embeddings AIEmbeddingConfig `yaml:"embeddings"`
|
||||
Metadata AIMetadataConfig `yaml:"metadata"`
|
||||
LiteLLM LiteLLMConfig `yaml:"litellm"`
|
||||
OpenRouter OpenRouterAIConfig `yaml:"openrouter"`
|
||||
}
|
||||
|
||||
type AIEmbeddingConfig struct {
|
||||
Model string `yaml:"model"`
|
||||
Dimensions int `yaml:"dimensions"`
|
||||
}
|
||||
|
||||
type AIMetadataConfig struct {
|
||||
Model string `yaml:"model"`
|
||||
Temperature float64 `yaml:"temperature"`
|
||||
}
|
||||
|
||||
type LiteLLMConfig struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
UseResponsesAPI bool `yaml:"use_responses_api"`
|
||||
RequestHeaders map[string]string `yaml:"request_headers"`
|
||||
EmbeddingModel string `yaml:"embedding_model"`
|
||||
MetadataModel string `yaml:"metadata_model"`
|
||||
}
|
||||
|
||||
type OpenRouterAIConfig struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
AppName string `yaml:"app_name"`
|
||||
SiteURL string `yaml:"site_url"`
|
||||
ExtraHeaders map[string]string `yaml:"extra_headers"`
|
||||
}
|
||||
|
||||
type CaptureConfig struct {
|
||||
Source string `yaml:"source"`
|
||||
MetadataDefaults CaptureMetadataDefault `yaml:"metadata_defaults"`
|
||||
}
|
||||
|
||||
type CaptureMetadataDefault struct {
|
||||
Type string `yaml:"type"`
|
||||
TopicFallback string `yaml:"topic_fallback"`
|
||||
}
|
||||
|
||||
type SearchConfig struct {
|
||||
DefaultLimit int `yaml:"default_limit"`
|
||||
DefaultThreshold float64 `yaml:"default_threshold"`
|
||||
MaxLimit int `yaml:"max_limit"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
}
|
||||
|
||||
type ObservabilityConfig struct {
|
||||
MetricsEnabled bool `yaml:"metrics_enabled"`
|
||||
PprofEnabled bool `yaml:"pprof_enabled"`
|
||||
}
|
||||
113
internal/config/loader.go
Normal file
113
internal/config/loader.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func Load(explicitPath string) (*Config, string, error) {
|
||||
path := ResolvePath(explicitPath)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, path, fmt.Errorf("read config %q: %w", path, err)
|
||||
}
|
||||
|
||||
cfg := defaultConfig()
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, path, fmt.Errorf("decode config %q: %w", path, err)
|
||||
}
|
||||
|
||||
applyEnvOverrides(&cfg)
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, path, err
|
||||
}
|
||||
|
||||
return &cfg, path, nil
|
||||
}
|
||||
|
||||
func ResolvePath(explicitPath string) string {
|
||||
if strings.TrimSpace(explicitPath) != "" {
|
||||
return explicitPath
|
||||
}
|
||||
|
||||
if envPath := strings.TrimSpace(os.Getenv("OB1_CONFIG")); envPath != "" {
|
||||
return envPath
|
||||
}
|
||||
|
||||
return DefaultConfigPath
|
||||
}
|
||||
|
||||
func defaultConfig() Config {
|
||||
return Config{
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
ReadTimeout: 15 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Path: "/mcp",
|
||||
ServerName: "amcs",
|
||||
Version: "0.1.0",
|
||||
Transport: "streamable_http",
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
Mode: "api_keys",
|
||||
HeaderName: "x-brain-key",
|
||||
QueryParam: "key",
|
||||
},
|
||||
AI: AIConfig{
|
||||
Provider: "litellm",
|
||||
Embeddings: AIEmbeddingConfig{
|
||||
Model: "openai/text-embedding-3-small",
|
||||
Dimensions: 1536,
|
||||
},
|
||||
Metadata: AIMetadataConfig{
|
||||
Model: "gpt-4o-mini",
|
||||
Temperature: 0.1,
|
||||
},
|
||||
},
|
||||
Capture: CaptureConfig{
|
||||
Source: DefaultSource,
|
||||
MetadataDefaults: CaptureMetadataDefault{
|
||||
Type: "observation",
|
||||
TopicFallback: "uncategorized",
|
||||
},
|
||||
},
|
||||
Search: SearchConfig{
|
||||
DefaultLimit: 10,
|
||||
DefaultThreshold: 0.5,
|
||||
MaxLimit: 50,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyEnvOverrides(cfg *Config) {
|
||||
overrideString(&cfg.Database.URL, "OB1_DATABASE_URL")
|
||||
overrideString(&cfg.AI.LiteLLM.BaseURL, "OB1_LITELLM_BASE_URL")
|
||||
overrideString(&cfg.AI.LiteLLM.APIKey, "OB1_LITELLM_API_KEY")
|
||||
overrideString(&cfg.AI.OpenRouter.APIKey, "OB1_OPENROUTER_API_KEY")
|
||||
|
||||
if value, ok := os.LookupEnv("OB1_SERVER_PORT"); ok {
|
||||
if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||
cfg.Server.Port = port
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func overrideString(target *string, envKey string) {
|
||||
if value, ok := os.LookupEnv(envKey); ok {
|
||||
*target = strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
71
internal/config/loader_test.go
Normal file
71
internal/config/loader_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolvePathPrecedence(t *testing.T) {
|
||||
t.Setenv("OB1_CONFIG", "/tmp/from-env.yaml")
|
||||
|
||||
if got := ResolvePath("/tmp/explicit.yaml"); got != "/tmp/explicit.yaml" {
|
||||
t.Fatalf("ResolvePath explicit = %q, want %q", got, "/tmp/explicit.yaml")
|
||||
}
|
||||
|
||||
if got := ResolvePath(""); got != "/tmp/from-env.yaml" {
|
||||
t.Fatalf("ResolvePath env = %q, want %q", got, "/tmp/from-env.yaml")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAppliesEnvOverrides(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(`
|
||||
server:
|
||||
port: 8080
|
||||
mcp:
|
||||
path: "/mcp"
|
||||
auth:
|
||||
keys:
|
||||
- id: "test"
|
||||
value: "secret"
|
||||
database:
|
||||
url: "postgres://from-file"
|
||||
ai:
|
||||
provider: "litellm"
|
||||
embeddings:
|
||||
dimensions: 1536
|
||||
litellm:
|
||||
base_url: "http://localhost:4000/v1"
|
||||
api_key: "file-key"
|
||||
search:
|
||||
default_limit: 10
|
||||
max_limit: 50
|
||||
logging:
|
||||
level: "info"
|
||||
`), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("OB1_DATABASE_URL", "postgres://from-env")
|
||||
t.Setenv("OB1_LITELLM_API_KEY", "env-key")
|
||||
t.Setenv("OB1_SERVER_PORT", "9090")
|
||||
|
||||
cfg, loadedFrom, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if loadedFrom != configPath {
|
||||
t.Fatalf("loadedFrom = %q, want %q", loadedFrom, configPath)
|
||||
}
|
||||
if cfg.Database.URL != "postgres://from-env" {
|
||||
t.Fatalf("database url = %q, want env override", cfg.Database.URL)
|
||||
}
|
||||
if cfg.AI.LiteLLM.APIKey != "env-key" {
|
||||
t.Fatalf("litellm api key = %q, want env override", cfg.AI.LiteLLM.APIKey)
|
||||
}
|
||||
if cfg.Server.Port != 9090 {
|
||||
t.Fatalf("server port = %d, want 9090", cfg.Server.Port)
|
||||
}
|
||||
}
|
||||
71
internal/config/validate.go
Normal file
71
internal/config/validate.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c Config) Validate() error {
|
||||
if strings.TrimSpace(c.Database.URL) == "" {
|
||||
return fmt.Errorf("invalid config: database.url is required")
|
||||
}
|
||||
|
||||
if len(c.Auth.Keys) == 0 {
|
||||
return fmt.Errorf("invalid config: auth.keys must not be empty")
|
||||
}
|
||||
|
||||
for i, key := range c.Auth.Keys {
|
||||
if strings.TrimSpace(key.ID) == "" {
|
||||
return fmt.Errorf("invalid config: auth.keys[%d].id is required", i)
|
||||
}
|
||||
if strings.TrimSpace(key.Value) == "" {
|
||||
return fmt.Errorf("invalid config: auth.keys[%d].value is required", i)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(c.MCP.Path) == "" {
|
||||
return fmt.Errorf("invalid config: mcp.path is required")
|
||||
}
|
||||
|
||||
switch c.AI.Provider {
|
||||
case "litellm", "openrouter":
|
||||
default:
|
||||
return fmt.Errorf("invalid config: unsupported ai.provider %q", c.AI.Provider)
|
||||
}
|
||||
|
||||
if c.AI.Embeddings.Dimensions <= 0 {
|
||||
return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero")
|
||||
}
|
||||
|
||||
switch c.AI.Provider {
|
||||
case "litellm":
|
||||
if strings.TrimSpace(c.AI.LiteLLM.BaseURL) == "" {
|
||||
return fmt.Errorf("invalid config: ai.litellm.base_url is required when ai.provider=litellm")
|
||||
}
|
||||
if strings.TrimSpace(c.AI.LiteLLM.APIKey) == "" {
|
||||
return fmt.Errorf("invalid config: ai.litellm.api_key is required when ai.provider=litellm")
|
||||
}
|
||||
case "openrouter":
|
||||
if strings.TrimSpace(c.AI.OpenRouter.BaseURL) == "" {
|
||||
return fmt.Errorf("invalid config: ai.openrouter.base_url is required when ai.provider=openrouter")
|
||||
}
|
||||
if strings.TrimSpace(c.AI.OpenRouter.APIKey) == "" {
|
||||
return fmt.Errorf("invalid config: ai.openrouter.api_key is required when ai.provider=openrouter")
|
||||
}
|
||||
}
|
||||
|
||||
if c.Server.Port <= 0 {
|
||||
return fmt.Errorf("invalid config: server.port must be greater than zero")
|
||||
}
|
||||
if c.Search.DefaultLimit <= 0 {
|
||||
return fmt.Errorf("invalid config: search.default_limit must be greater than zero")
|
||||
}
|
||||
if c.Search.MaxLimit < c.Search.DefaultLimit {
|
||||
return fmt.Errorf("invalid config: search.max_limit must be greater than or equal to search.default_limit")
|
||||
}
|
||||
if strings.TrimSpace(c.Logging.Level) == "" {
|
||||
return fmt.Errorf("invalid config: logging.level is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
60
internal/config/validate_test.go
Normal file
60
internal/config/validate_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func validConfig() Config {
|
||||
return Config{
|
||||
Server: ServerConfig{Port: 8080},
|
||||
MCP: MCPConfig{Path: "/mcp"},
|
||||
Auth: AuthConfig{
|
||||
Keys: []APIKey{{ID: "test", Value: "secret"}},
|
||||
},
|
||||
Database: DatabaseConfig{URL: "postgres://example"},
|
||||
AI: AIConfig{
|
||||
Provider: "litellm",
|
||||
Embeddings: AIEmbeddingConfig{
|
||||
Dimensions: 1536,
|
||||
},
|
||||
LiteLLM: LiteLLMConfig{
|
||||
BaseURL: "http://localhost:4000/v1",
|
||||
APIKey: "key",
|
||||
},
|
||||
OpenRouter: OpenRouterAIConfig{
|
||||
BaseURL: "https://openrouter.ai/api/v1",
|
||||
APIKey: "key",
|
||||
},
|
||||
},
|
||||
Search: SearchConfig{DefaultLimit: 10, MaxLimit: 50},
|
||||
Logging: LoggingConfig{Level: "info"},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAcceptsLiteLLMAndOpenRouter(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate litellm error = %v", err)
|
||||
}
|
||||
|
||||
cfg.AI.Provider = "openrouter"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate openrouter error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsInvalidProvider(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.AI.Provider = "unknown"
|
||||
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatal("Validate() error = nil, want error for unsupported provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsEmptyAuthKeyValue(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.Auth.Keys[0].Value = ""
|
||||
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatal("Validate() error = nil, want error for empty auth key value")
|
||||
}
|
||||
}
|
||||
126
internal/mcpserver/server.go
Normal file
126
internal/mcpserver/server.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/tools"
|
||||
)
|
||||
|
||||
type ToolSet struct {
|
||||
Capture *tools.CaptureTool
|
||||
Search *tools.SearchTool
|
||||
List *tools.ListTool
|
||||
Stats *tools.StatsTool
|
||||
Get *tools.GetTool
|
||||
Update *tools.UpdateTool
|
||||
Delete *tools.DeleteTool
|
||||
Archive *tools.ArchiveTool
|
||||
Projects *tools.ProjectsTool
|
||||
Context *tools.ContextTool
|
||||
Recall *tools.RecallTool
|
||||
Summarize *tools.SummarizeTool
|
||||
Links *tools.LinksTool
|
||||
}
|
||||
|
||||
func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
||||
server := mcp.NewServer(&mcp.Implementation{
|
||||
Name: cfg.ServerName,
|
||||
Version: cfg.Version,
|
||||
}, nil)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "capture_thought",
|
||||
Description: "Store a thought with generated embeddings and extracted metadata.",
|
||||
}, toolSet.Capture.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "search_thoughts",
|
||||
Description: "Search stored thoughts by semantic similarity.",
|
||||
}, toolSet.Search.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "list_thoughts",
|
||||
Description: "List recent thoughts with optional metadata filters.",
|
||||
}, toolSet.List.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "thought_stats",
|
||||
Description: "Get counts and top metadata buckets across stored thoughts.",
|
||||
}, toolSet.Stats.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_thought",
|
||||
Description: "Retrieve a full thought by id.",
|
||||
}, toolSet.Get.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "update_thought",
|
||||
Description: "Update thought content or merge metadata.",
|
||||
}, toolSet.Update.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "delete_thought",
|
||||
Description: "Hard-delete a thought by id.",
|
||||
}, toolSet.Delete.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "archive_thought",
|
||||
Description: "Archive a thought so it is hidden from default search and listing.",
|
||||
}, toolSet.Archive.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "create_project",
|
||||
Description: "Create a named project container for thoughts.",
|
||||
}, toolSet.Projects.Create)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "list_projects",
|
||||
Description: "List projects and their current thought counts.",
|
||||
}, toolSet.Projects.List)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "set_active_project",
|
||||
Description: "Set the active project for the current MCP session.",
|
||||
}, toolSet.Projects.SetActive)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_active_project",
|
||||
Description: "Return the active project for the current MCP session.",
|
||||
}, toolSet.Projects.GetActive)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_project_context",
|
||||
Description: "Get recent and semantic context for a project.",
|
||||
}, toolSet.Context.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "recall_context",
|
||||
Description: "Recall semantically relevant and recent context.",
|
||||
}, toolSet.Recall.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "summarize_thoughts",
|
||||
Description: "Summarize a filtered or searched set of thoughts.",
|
||||
}, toolSet.Summarize.Handle)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "link_thoughts",
|
||||
Description: "Create a typed relationship between two thoughts.",
|
||||
}, toolSet.Links.Link)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "related_thoughts",
|
||||
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
||||
}, toolSet.Links.Related)
|
||||
|
||||
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
||||
return server
|
||||
}, &mcp.StreamableHTTPOptions{
|
||||
JSONResponse: true,
|
||||
SessionTimeout: 10 * time.Minute,
|
||||
})
|
||||
}
|
||||
155
internal/metadata/normalize.go
Normal file
155
internal/metadata/normalize.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultType = "observation"
|
||||
DefaultTopicFallback = "uncategorized"
|
||||
maxTopics = 10
|
||||
)
|
||||
|
||||
var allowedTypes = map[string]struct{}{
|
||||
"observation": {},
|
||||
"task": {},
|
||||
"idea": {},
|
||||
"reference": {},
|
||||
"person_note": {},
|
||||
}
|
||||
|
||||
func Fallback(capture config.CaptureConfig) thoughttypes.ThoughtMetadata {
|
||||
topicFallback := strings.TrimSpace(capture.MetadataDefaults.TopicFallback)
|
||||
if topicFallback == "" {
|
||||
topicFallback = DefaultTopicFallback
|
||||
}
|
||||
|
||||
return thoughttypes.ThoughtMetadata{
|
||||
People: []string{},
|
||||
ActionItems: []string{},
|
||||
DatesMentioned: []string{},
|
||||
Topics: []string{topicFallback},
|
||||
Type: normalizeType(capture.MetadataDefaults.Type),
|
||||
Source: normalizeSource(capture.Source),
|
||||
}
|
||||
}
|
||||
|
||||
func Normalize(in thoughttypes.ThoughtMetadata, capture config.CaptureConfig) thoughttypes.ThoughtMetadata {
|
||||
out := thoughttypes.ThoughtMetadata{
|
||||
People: normalizeList(in.People, 0),
|
||||
ActionItems: normalizeList(in.ActionItems, 0),
|
||||
DatesMentioned: normalizeList(in.DatesMentioned, 0),
|
||||
Topics: normalizeList(in.Topics, maxTopics),
|
||||
Type: normalizeType(in.Type),
|
||||
Source: normalizeSource(in.Source),
|
||||
}
|
||||
|
||||
if len(out.Topics) == 0 {
|
||||
out.Topics = Fallback(capture).Topics
|
||||
}
|
||||
if out.Type == "" {
|
||||
out.Type = Fallback(capture).Type
|
||||
}
|
||||
if out.Source == "" {
|
||||
out.Source = Fallback(capture).Source
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeList(values []string, limit int) []string {
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
result := make([]string, 0, len(values))
|
||||
|
||||
for _, value := range values {
|
||||
trimmed := strings.Join(strings.Fields(strings.TrimSpace(value)), " ")
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.ToLower(trimmed)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[key] = struct{}{}
|
||||
result = append(result, trimmed)
|
||||
|
||||
if limit > 0 && len(result) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeType(value string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
if normalized == "" {
|
||||
return DefaultType
|
||||
}
|
||||
if _, ok := allowedTypes[normalized]; ok {
|
||||
return normalized
|
||||
}
|
||||
return DefaultType
|
||||
}
|
||||
|
||||
func normalizeSource(value string) string {
|
||||
normalized := strings.TrimSpace(value)
|
||||
if normalized == "" {
|
||||
return config.DefaultSource
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func Merge(base, patch thoughttypes.ThoughtMetadata, capture config.CaptureConfig) thoughttypes.ThoughtMetadata {
|
||||
merged := base
|
||||
|
||||
if len(patch.People) > 0 {
|
||||
merged.People = append(append([]string{}, merged.People...), patch.People...)
|
||||
}
|
||||
if len(patch.ActionItems) > 0 {
|
||||
merged.ActionItems = append(append([]string{}, merged.ActionItems...), patch.ActionItems...)
|
||||
}
|
||||
if len(patch.DatesMentioned) > 0 {
|
||||
merged.DatesMentioned = append(append([]string{}, merged.DatesMentioned...), patch.DatesMentioned...)
|
||||
}
|
||||
if len(patch.Topics) > 0 {
|
||||
merged.Topics = append(append([]string{}, merged.Topics...), patch.Topics...)
|
||||
}
|
||||
if strings.TrimSpace(patch.Type) != "" {
|
||||
merged.Type = patch.Type
|
||||
}
|
||||
if strings.TrimSpace(patch.Source) != "" {
|
||||
merged.Source = patch.Source
|
||||
}
|
||||
|
||||
return Normalize(merged, capture)
|
||||
}
|
||||
|
||||
func SortedTopCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
||||
out := make([]thoughttypes.KeyCount, 0, len(in))
|
||||
for key, count := range in {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, thoughttypes.KeyCount{Key: key, Count: count})
|
||||
}
|
||||
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
if out[i].Count == out[j].Count {
|
||||
return out[i].Key < out[j].Key
|
||||
}
|
||||
return out[i].Count > out[j].Count
|
||||
})
|
||||
|
||||
if limit > 0 && len(out) > limit {
|
||||
return out[:limit]
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
81
internal/metadata/normalize_test.go
Normal file
81
internal/metadata/normalize_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func testCaptureConfig() config.CaptureConfig {
|
||||
return config.CaptureConfig{
|
||||
Source: "mcp",
|
||||
MetadataDefaults: config.CaptureMetadataDefault{
|
||||
Type: "observation",
|
||||
TopicFallback: "uncategorized",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackUsesConfiguredDefaults(t *testing.T) {
|
||||
got := Fallback(testCaptureConfig())
|
||||
if got.Type != "observation" {
|
||||
t.Fatalf("Fallback type = %q, want observation", got.Type)
|
||||
}
|
||||
if len(got.Topics) != 1 || got.Topics[0] != "uncategorized" {
|
||||
t.Fatalf("Fallback topics = %#v, want [uncategorized]", got.Topics)
|
||||
}
|
||||
if got.Source != "mcp" {
|
||||
t.Fatalf("Fallback source = %q, want mcp", got.Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTrimsDedupesAndCapsTopics(t *testing.T) {
|
||||
topics := []string{}
|
||||
for i := 0; i < 12; i++ {
|
||||
topics = append(topics, strings.TrimSpace(" topic "))
|
||||
topics = append(topics, string(rune('a'+i)))
|
||||
}
|
||||
|
||||
got := Normalize(thoughttypes.ThoughtMetadata{
|
||||
People: []string{" Alice ", "alice", "", "Bob"},
|
||||
Topics: topics,
|
||||
Type: "INVALID",
|
||||
}, testCaptureConfig())
|
||||
|
||||
if len(got.People) != 2 {
|
||||
t.Fatalf("People len = %d, want 2", len(got.People))
|
||||
}
|
||||
if got.Type != "observation" {
|
||||
t.Fatalf("Type = %q, want observation", got.Type)
|
||||
}
|
||||
if len(got.Topics) != maxTopics {
|
||||
t.Fatalf("Topics len = %d, want %d", len(got.Topics), maxTopics)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeAddsPatchAndNormalizes(t *testing.T) {
|
||||
base := thoughttypes.ThoughtMetadata{
|
||||
People: []string{"Alice"},
|
||||
Topics: []string{"go"},
|
||||
Type: "idea",
|
||||
Source: "mcp",
|
||||
}
|
||||
patch := thoughttypes.ThoughtMetadata{
|
||||
People: []string{" Bob ", "alice"},
|
||||
Topics: []string{"testing"},
|
||||
Type: "task",
|
||||
}
|
||||
|
||||
got := Merge(base, patch, testCaptureConfig())
|
||||
if got.Type != "task" {
|
||||
t.Fatalf("Type = %q, want task", got.Type)
|
||||
}
|
||||
if len(got.People) != 2 {
|
||||
t.Fatalf("People len = %d, want 2", len(got.People))
|
||||
}
|
||||
if len(got.Topics) != 2 {
|
||||
t.Fatalf("Topics len = %d, want 2", len(got.Topics))
|
||||
}
|
||||
}
|
||||
110
internal/observability/http.go
Normal file
110
internal/observability/http.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const requestIDContextKey contextKey = "request_id"
|
||||
|
||||
func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
|
||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||
h = middlewares[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func RequestID() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := r.Header.Get("X-Request-Id")
|
||||
if requestID == "" {
|
||||
requestID = uuid.NewString()
|
||||
}
|
||||
w.Header().Set("X-Request-Id", requestID)
|
||||
ctx := context.WithValue(r.Context(), requestIDContextKey, requestID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Recover(log *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
log.Error("panic recovered",
|
||||
slog.Any("panic", recovered),
|
||||
slog.String("request_id", RequestIDFromContext(r.Context())),
|
||||
slog.String("stack", string(debug.Stack())),
|
||||
)
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func AccessLog(log *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||
started := time.Now()
|
||||
next.ServeHTTP(recorder, r)
|
||||
|
||||
log.Info("http request",
|
||||
slog.String("request_id", RequestIDFromContext(r.Context())),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.Int("status", recorder.status),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
slog.String("remote_addr", stripPort(r.RemoteAddr)),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Timeout(timeout time.Duration) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
if timeout <= 0 {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RequestIDFromContext(ctx context.Context) string {
|
||||
value, _ := ctx.Value(requestIDContextKey).(string)
|
||||
return value
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (s *statusRecorder) WriteHeader(statusCode int) {
|
||||
s.status = statusCode
|
||||
s.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func stripPort(remote string) string {
|
||||
host, _, err := net.SplitHostPort(remote)
|
||||
if err != nil {
|
||||
return remote
|
||||
}
|
||||
return host
|
||||
}
|
||||
59
internal/observability/http_test.go
Normal file
59
internal/observability/http_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRequestIDSetsHeaderAndContext(t *testing.T) {
|
||||
handler := RequestID()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := RequestIDFromContext(r.Context()); got == "" {
|
||||
t.Fatal("RequestIDFromContext() = empty, want non-empty")
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Header().Get("X-Request-Id") == "" {
|
||||
t.Fatal("X-Request-Id header = empty, want non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutAddsContextDeadline(t *testing.T) {
|
||||
handler := Timeout(50 * time.Millisecond)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := r.Context().Deadline(); !ok {
|
||||
t.Fatal("context deadline missing")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecoverHandlesPanic(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
handler := Recover(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("boom")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
52
internal/observability/logger.go
Normal file
52
internal/observability/logger.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func NewLogger(cfg config.LoggingConfig) (*slog.Logger, error) {
|
||||
level, err := parseLevel(cfg.Level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
options := &slog.HandlerOptions{Level: level}
|
||||
handler, err := newHandler(cfg.Format, os.Stdout, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return slog.New(handler), nil
|
||||
}
|
||||
|
||||
func parseLevel(value string) (slog.Leveler, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "", "info":
|
||||
return slog.LevelInfo, nil
|
||||
case "debug":
|
||||
return slog.LevelDebug, nil
|
||||
case "warn", "warning":
|
||||
return slog.LevelWarn, nil
|
||||
case "error":
|
||||
return slog.LevelError, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid logging.level %q", value)
|
||||
}
|
||||
}
|
||||
|
||||
func newHandler(format string, w io.Writer, opts *slog.HandlerOptions) (slog.Handler, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(format)) {
|
||||
case "", "json":
|
||||
return slog.NewJSONHandler(w, opts), nil
|
||||
case "text":
|
||||
return slog.NewTextHandler(w, opts), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid logging.format %q", format)
|
||||
}
|
||||
}
|
||||
37
internal/session/active_project.go
Normal file
37
internal/session/active_project.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ActiveProjects struct {
|
||||
mu sync.RWMutex
|
||||
bySession map[string]uuid.UUID
|
||||
}
|
||||
|
||||
func NewActiveProjects() *ActiveProjects {
|
||||
return &ActiveProjects{
|
||||
bySession: map[string]uuid.UUID{},
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ActiveProjects) Set(sessionID string, projectID uuid.UUID) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.bySession[sessionID] = projectID
|
||||
}
|
||||
|
||||
func (a *ActiveProjects) Get(sessionID string) (uuid.UUID, bool) {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
projectID, ok := a.bySession[sessionID]
|
||||
return projectID, ok
|
||||
}
|
||||
|
||||
func (a *ActiveProjects) Clear(sessionID string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
delete(a.bySession, sessionID)
|
||||
}
|
||||
27
internal/session/active_project_test.go
Normal file
27
internal/session/active_project_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestActiveProjectsSetGetClear(t *testing.T) {
|
||||
store := NewActiveProjects()
|
||||
projectID := uuid.New()
|
||||
|
||||
if _, ok := store.Get("session-1"); ok {
|
||||
t.Fatal("Get() before Set() = true, want false")
|
||||
}
|
||||
|
||||
store.Set("session-1", projectID)
|
||||
got, ok := store.Get("session-1")
|
||||
if !ok || got != projectID {
|
||||
t.Fatalf("Get() = (%v, %v), want (%v, true)", got, ok, projectID)
|
||||
}
|
||||
|
||||
store.Clear("session-1")
|
||||
if _, ok := store.Get("session-1"); ok {
|
||||
t.Fatal("Get() after Clear() = true, want false")
|
||||
}
|
||||
}
|
||||
118
internal/store/db.go
Normal file
118
internal/store/db.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
pgxvec "github.com/pgvector/pgvector-go/pgx"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func New(ctx context.Context, cfg config.DatabaseConfig) (*DB, error) {
|
||||
poolConfig, err := pgxpool.ParseConfig(cfg.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse database config: %w", err)
|
||||
}
|
||||
|
||||
poolConfig.MaxConns = cfg.MaxConns
|
||||
poolConfig.MinConns = cfg.MinConns
|
||||
poolConfig.MaxConnLifetime = cfg.MaxConnLifetime
|
||||
poolConfig.MaxConnIdleTime = cfg.MaxConnIdleTime
|
||||
poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
||||
return pgxvec.RegisterTypes(ctx, conn)
|
||||
}
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create database pool: %w", err)
|
||||
}
|
||||
|
||||
db := &DB{pool: pool}
|
||||
if err := db.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (db *DB) Close() {
|
||||
if db == nil || db.pool == nil {
|
||||
return
|
||||
}
|
||||
|
||||
db.pool.Close()
|
||||
}
|
||||
|
||||
func (db *DB) Ping(ctx context.Context) error {
|
||||
if err := db.pool.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("ping database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Ready(ctx context.Context) error {
|
||||
readyCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return db.Ping(readyCtx)
|
||||
}
|
||||
|
||||
func (db *DB) VerifyRequirements(ctx context.Context, dimensions int) error {
|
||||
var hasVector bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_extension where extname = 'vector')`).Scan(&hasVector); err != nil {
|
||||
return fmt.Errorf("verify vector extension: %w", err)
|
||||
}
|
||||
if !hasVector {
|
||||
return fmt.Errorf("vector extension is not installed")
|
||||
}
|
||||
|
||||
var hasMatchThoughts bool
|
||||
if err := db.pool.QueryRow(ctx, `select exists(select 1 from pg_proc where proname = 'match_thoughts')`).Scan(&hasMatchThoughts); err != nil {
|
||||
return fmt.Errorf("verify match_thoughts function: %w", err)
|
||||
}
|
||||
if !hasMatchThoughts {
|
||||
return fmt.Errorf("match_thoughts function is missing")
|
||||
}
|
||||
|
||||
var embeddingType string
|
||||
err := db.pool.QueryRow(ctx, `
|
||||
select format_type(a.atttypid, a.atttypmod)
|
||||
from pg_attribute a
|
||||
join pg_class c on c.oid = a.attrelid
|
||||
join pg_namespace n on n.oid = c.relnamespace
|
||||
where n.nspname = 'public'
|
||||
and c.relname = 'thoughts'
|
||||
and a.attname = 'embedding'
|
||||
and not a.attisdropped
|
||||
`).Scan(&embeddingType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify thoughts.embedding type: %w", err)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`vector\((\d+)\)`)
|
||||
matches := re.FindStringSubmatch(embeddingType)
|
||||
if len(matches) != 2 {
|
||||
return fmt.Errorf("unexpected embedding type %q", embeddingType)
|
||||
}
|
||||
|
||||
var actualDimensions int
|
||||
if _, err := fmt.Sscanf(matches[1], "%d", &actualDimensions); err != nil {
|
||||
return fmt.Errorf("parse embedding dimensions from %q: %w", embeddingType, err)
|
||||
}
|
||||
if actualDimensions != dimensions {
|
||||
return fmt.Errorf("embedding dimension mismatch: config=%d db=%d", dimensions, actualDimensions)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
69
internal/store/links.go
Normal file
69
internal/store/links.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) InsertLink(ctx context.Context, link thoughttypes.ThoughtLink) error {
|
||||
_, err := db.pool.Exec(ctx, `
|
||||
insert into thought_links (from_id, to_id, relation)
|
||||
values ($1, $2, $3)
|
||||
`, link.FromID, link.ToID, link.Relation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert link: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) LinkedThoughts(ctx context.Context, thoughtID uuid.UUID) ([]thoughttypes.LinkedThought, error) {
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select t.id, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at, l.relation, 'outgoing' as direction, l.created_at
|
||||
from thought_links l
|
||||
join thoughts t on t.id = l.to_id
|
||||
where l.from_id = $1
|
||||
union all
|
||||
select t.id, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at, l.relation, 'incoming' as direction, l.created_at
|
||||
from thought_links l
|
||||
join thoughts t on t.id = l.from_id
|
||||
where l.to_id = $1
|
||||
order by created_at desc
|
||||
`, thoughtID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query linked thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
links := make([]thoughttypes.LinkedThought, 0)
|
||||
for rows.Next() {
|
||||
var linked thoughttypes.LinkedThought
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(
|
||||
&linked.Thought.ID,
|
||||
&linked.Thought.Content,
|
||||
&metadataBytes,
|
||||
&linked.Thought.ProjectID,
|
||||
&linked.Thought.ArchivedAt,
|
||||
&linked.Thought.CreatedAt,
|
||||
&linked.Thought.UpdatedAt,
|
||||
&linked.Relation,
|
||||
&linked.Direction,
|
||||
&linked.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan linked thought: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &linked.Thought.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode linked thought metadata: %w", err)
|
||||
}
|
||||
links = append(links, linked)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate linked thoughts: %w", err)
|
||||
}
|
||||
return links, nil
|
||||
}
|
||||
90
internal/store/projects.go
Normal file
90
internal/store/projects.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) CreateProject(ctx context.Context, name, description string) (thoughttypes.Project, error) {
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
insert into projects (name, description)
|
||||
values ($1, $2)
|
||||
returning id, name, description, created_at, last_active_at
|
||||
`, name, description)
|
||||
|
||||
var project thoughttypes.Project
|
||||
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
||||
return thoughttypes.Project{}, fmt.Errorf("create project: %w", err)
|
||||
}
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) {
|
||||
var row pgx.Row
|
||||
if parsedID, err := uuid.Parse(strings.TrimSpace(nameOrID)); err == nil {
|
||||
row = db.pool.QueryRow(ctx, `
|
||||
select id, name, description, created_at, last_active_at
|
||||
from projects
|
||||
where id = $1
|
||||
`, parsedID)
|
||||
} else {
|
||||
row = db.pool.QueryRow(ctx, `
|
||||
select id, name, description, created_at, last_active_at
|
||||
from projects
|
||||
where name = $1
|
||||
`, strings.TrimSpace(nameOrID))
|
||||
}
|
||||
|
||||
var project thoughttypes.Project
|
||||
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return thoughttypes.Project{}, err
|
||||
}
|
||||
return thoughttypes.Project{}, fmt.Errorf("get project: %w", err)
|
||||
}
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (db *DB) ListProjects(ctx context.Context) ([]thoughttypes.ProjectSummary, error) {
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select p.id, p.name, p.description, p.created_at, p.last_active_at, count(t.id) as thought_count
|
||||
from projects p
|
||||
left join thoughts t on t.project_id = p.id and t.archived_at is null
|
||||
group by p.id
|
||||
order by p.last_active_at desc, p.created_at desc
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list projects: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
projects := make([]thoughttypes.ProjectSummary, 0)
|
||||
for rows.Next() {
|
||||
var project thoughttypes.ProjectSummary
|
||||
if err := rows.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt, &project.ThoughtCount); err != nil {
|
||||
return nil, fmt.Errorf("scan project summary: %w", err)
|
||||
}
|
||||
projects = append(projects, project)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate projects: %w", err)
|
||||
}
|
||||
return projects, nil
|
||||
}
|
||||
|
||||
func (db *DB) TouchProject(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := db.pool.Exec(ctx, `update projects set last_active_at = now() where id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("touch project: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
345
internal/store/thoughts.go
Normal file
345
internal/store/thoughts.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought) (thoughttypes.Thought, error) {
|
||||
metadata, err := json.Marshal(thought.Metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
insert into thoughts (content, embedding, metadata, project_id)
|
||||
values ($1, $2, $3::jsonb, $4)
|
||||
returning id, created_at, updated_at
|
||||
`, thought.Content, pgvector.NewVector(thought.Embedding), metadata, thought.ProjectID)
|
||||
|
||||
created := thought
|
||||
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("insert thought: %w", err)
|
||||
}
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (db *DB) SearchThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, filter map[string]any) ([]thoughttypes.SearchResult, error) {
|
||||
filterJSON, err := json.Marshal(filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal search filter: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.pool.Query(ctx, `
|
||||
select id, content, metadata, similarity, created_at
|
||||
from match_thoughts($1, $2, $3, $4::jsonb)
|
||||
`, pgvector.NewVector(embedding), threshold, limit, filterJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make([]thoughttypes.SearchResult, 0, limit)
|
||||
for rows.Next() {
|
||||
var result thoughttypes.SearchResult
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan search result: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode search metadata: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate search results: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (db *DB) ListThoughts(ctx context.Context, filter thoughttypes.ListFilter) ([]thoughttypes.Thought, error) {
|
||||
args := make([]any, 0, 6)
|
||||
conditions := []string{}
|
||||
|
||||
if !filter.IncludeArchived {
|
||||
conditions = append(conditions, "archived_at is null")
|
||||
}
|
||||
if value := strings.TrimSpace(filter.Type); value != "" {
|
||||
args = append(args, value)
|
||||
conditions = append(conditions, fmt.Sprintf("metadata->>'type' = $%d", len(args)))
|
||||
}
|
||||
if value := strings.TrimSpace(filter.Topic); value != "" {
|
||||
args = append(args, value)
|
||||
conditions = append(conditions, fmt.Sprintf("metadata->'topics' ? $%d", len(args)))
|
||||
}
|
||||
if value := strings.TrimSpace(filter.Person); value != "" {
|
||||
args = append(args, value)
|
||||
conditions = append(conditions, fmt.Sprintf("metadata->'people' ? $%d", len(args)))
|
||||
}
|
||||
if filter.Days > 0 {
|
||||
args = append(args, time.Now().Add(-time.Duration(filter.Days)*24*time.Hour))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)))
|
||||
}
|
||||
if filter.ProjectID != nil {
|
||||
args = append(args, *filter.ProjectID)
|
||||
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
||||
}
|
||||
|
||||
query := `
|
||||
select id, content, metadata, project_id, archived_at, created_at, updated_at
|
||||
from thoughts
|
||||
`
|
||||
if len(conditions) > 0 {
|
||||
query += " where " + strings.Join(conditions, " and ")
|
||||
}
|
||||
|
||||
args = append(args, filter.Limit)
|
||||
query += fmt.Sprintf(" order by created_at desc limit $%d", len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
thoughts := make([]thoughttypes.Thought, 0, filter.Limit)
|
||||
for rows.Next() {
|
||||
var thought thoughttypes.Thought
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan listed thought: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode listed metadata: %w", err)
|
||||
}
|
||||
thoughts = append(thoughts, thought)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate listed thoughts: %w", err)
|
||||
}
|
||||
|
||||
return thoughts, nil
|
||||
}
|
||||
|
||||
func (db *DB) Stats(ctx context.Context) (thoughttypes.ThoughtStats, error) {
|
||||
var total int
|
||||
if err := db.pool.QueryRow(ctx, `select count(*) from thoughts where archived_at is null`).Scan(&total); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("count thoughts: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.pool.Query(ctx, `select metadata from thoughts where archived_at is null`)
|
||||
if err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("query stats metadata: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
stats := thoughttypes.ThoughtStats{
|
||||
TotalCount: total,
|
||||
TypeCounts: map[string]int{},
|
||||
}
|
||||
topics := map[string]int{}
|
||||
people := map[string]int{}
|
||||
|
||||
for rows.Next() {
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&metadataBytes); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("scan stats metadata: %w", err)
|
||||
}
|
||||
|
||||
var metadata thoughttypes.ThoughtMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("decode stats metadata: %w", err)
|
||||
}
|
||||
|
||||
stats.TypeCounts[metadata.Type]++
|
||||
for _, topic := range metadata.Topics {
|
||||
topics[topic]++
|
||||
}
|
||||
for _, person := range metadata.People {
|
||||
people[person]++
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return thoughttypes.ThoughtStats{}, fmt.Errorf("iterate stats metadata: %w", err)
|
||||
}
|
||||
|
||||
stats.TopTopics = topCounts(topics, 10)
|
||||
stats.TopPeople = topCounts(people, 10)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (db *DB) GetThought(ctx context.Context, id uuid.UUID) (thoughttypes.Thought, error) {
|
||||
row := db.pool.QueryRow(ctx, `
|
||||
select id, content, embedding, metadata, project_id, archived_at, created_at, updated_at
|
||||
from thoughts
|
||||
where id = $1
|
||||
`, id)
|
||||
|
||||
var thought thoughttypes.Thought
|
||||
var embedding pgvector.Vector
|
||||
var metadataBytes []byte
|
||||
if err := row.Scan(&thought.ID, &thought.Content, &embedding, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return thoughttypes.Thought{}, err
|
||||
}
|
||||
return thoughttypes.Thought{}, fmt.Errorf("get thought: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("decode thought metadata: %w", err)
|
||||
}
|
||||
thought.Embedding = embedding.Slice()
|
||||
|
||||
return thought, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateThought(ctx context.Context, id uuid.UUID, content string, embedding []float32, metadata thoughttypes.ThoughtMetadata, projectID *uuid.UUID) (thoughttypes.Thought, error) {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("marshal updated metadata: %w", err)
|
||||
}
|
||||
|
||||
tag, err := db.pool.Exec(ctx, `
|
||||
update thoughts
|
||||
set content = $2,
|
||||
embedding = $3,
|
||||
metadata = $4::jsonb,
|
||||
project_id = $5,
|
||||
updated_at = now()
|
||||
where id = $1
|
||||
`, id, content, pgvector.NewVector(embedding), metadataBytes, projectID)
|
||||
if err != nil {
|
||||
return thoughttypes.Thought{}, fmt.Errorf("update thought: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return thoughttypes.Thought{}, pgx.ErrNoRows
|
||||
}
|
||||
|
||||
return db.GetThought(ctx, id)
|
||||
}
|
||||
|
||||
func (db *DB) DeleteThought(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := db.pool.Exec(ctx, `delete from thoughts where id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete thought: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) ArchiveThought(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := db.pool.Exec(ctx, `update thoughts set archived_at = now(), updated_at = now() where id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("archive thought: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return pgx.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) RecentThoughts(ctx context.Context, projectID *uuid.UUID, limit int, days int) ([]thoughttypes.Thought, error) {
|
||||
filter := thoughttypes.ListFilter{
|
||||
Limit: limit,
|
||||
ProjectID: projectID,
|
||||
Days: days,
|
||||
IncludeArchived: false,
|
||||
}
|
||||
return db.ListThoughts(ctx, filter)
|
||||
}
|
||||
|
||||
func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, threshold float64, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||
args := []any{pgvector.NewVector(embedding), threshold}
|
||||
conditions := []string{
|
||||
"archived_at is null",
|
||||
"1 - (embedding <=> $1) > $2",
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
||||
}
|
||||
if excludeID != nil {
|
||||
args = append(args, *excludeID)
|
||||
conditions = append(conditions, fmt.Sprintf("id <> $%d", len(args)))
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
query := `
|
||||
select id, content, metadata, 1 - (embedding <=> $1) as similarity, created_at
|
||||
from thoughts
|
||||
where ` + strings.Join(conditions, " and ") + fmt.Sprintf(`
|
||||
order by embedding <=> $1
|
||||
limit $%d`, len(args))
|
||||
|
||||
rows, err := db.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search similar thoughts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make([]thoughttypes.SearchResult, 0, limit)
|
||||
for rows.Next() {
|
||||
var result thoughttypes.SearchResult
|
||||
var metadataBytes []byte
|
||||
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan similar thought: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("decode similar thought metadata: %w", err)
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate similar thoughts: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
||||
type kv struct {
|
||||
key string
|
||||
count int
|
||||
}
|
||||
|
||||
pairs := make([]kv, 0, len(in))
|
||||
for key, count := range in {
|
||||
if strings.TrimSpace(key) == "" {
|
||||
continue
|
||||
}
|
||||
pairs = append(pairs, kv{key: key, count: count})
|
||||
}
|
||||
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
if pairs[i].count == pairs[j].count {
|
||||
return pairs[i].key < pairs[j].key
|
||||
}
|
||||
return pairs[i].count > pairs[j].count
|
||||
})
|
||||
|
||||
if limit > 0 && len(pairs) > limit {
|
||||
pairs = pairs[:limit]
|
||||
}
|
||||
|
||||
out := make([]thoughttypes.KeyCount, 0, len(pairs))
|
||||
for _, pair := range pairs {
|
||||
out = append(out, thoughttypes.KeyCount{Key: pair.key, Count: pair.count})
|
||||
}
|
||||
return out
|
||||
}
|
||||
36
internal/tools/archive.go
Normal file
36
internal/tools/archive.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
)
|
||||
|
||||
type ArchiveTool struct {
|
||||
store *store.DB
|
||||
}
|
||||
|
||||
type ArchiveInput struct {
|
||||
ID string `json:"id" jsonschema:"the thought id"`
|
||||
}
|
||||
|
||||
type ArchiveOutput struct {
|
||||
Archived bool `json:"archived"`
|
||||
}
|
||||
|
||||
func NewArchiveTool(db *store.DB) *ArchiveTool {
|
||||
return &ArchiveTool{store: db}
|
||||
}
|
||||
|
||||
func (t *ArchiveTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in ArchiveInput) (*mcp.CallToolResult, ArchiveOutput, error) {
|
||||
id, err := parseUUID(in.ID)
|
||||
if err != nil {
|
||||
return nil, ArchiveOutput{}, err
|
||||
}
|
||||
if err := t.store.ArchiveThought(ctx, id); err != nil {
|
||||
return nil, ArchiveOutput{}, err
|
||||
}
|
||||
return nil, ArchiveOutput{Archived: true}, nil
|
||||
}
|
||||
95
internal/tools/capture.go
Normal file
95
internal/tools/capture.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/metadata"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type CaptureTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
capture config.CaptureConfig
|
||||
sessions *session.ActiveProjects
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
type CaptureInput struct {
|
||||
Content string `json:"content" jsonschema:"the thought or note to capture"`
|
||||
Project string `json:"project,omitempty" jsonschema:"optional project name or id to associate with the thought"`
|
||||
}
|
||||
|
||||
type CaptureOutput struct {
|
||||
Thought thoughttypes.Thought `json:"thought"`
|
||||
}
|
||||
|
||||
func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, log *slog.Logger) *CaptureTool {
|
||||
return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, log: log}
|
||||
}
|
||||
|
||||
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
|
||||
content := strings.TrimSpace(in.Content)
|
||||
if content == "" {
|
||||
return nil, CaptureOutput{}, errInvalidInput("content is required")
|
||||
}
|
||||
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
if err != nil {
|
||||
return nil, CaptureOutput{}, err
|
||||
}
|
||||
|
||||
var embedding []float32
|
||||
rawMetadata := metadata.Fallback(t.capture)
|
||||
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
group.Go(func() error {
|
||||
vector, err := t.provider.Embed(groupCtx, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
embedding = vector
|
||||
return nil
|
||||
})
|
||||
group.Go(func() error {
|
||||
extracted, err := t.provider.ExtractMetadata(groupCtx, content)
|
||||
if err != nil {
|
||||
t.log.Warn("metadata extraction failed, using fallback", slog.String("provider", t.provider.Name()), slog.String("error", err.Error()))
|
||||
return nil
|
||||
}
|
||||
rawMetadata = extracted
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := group.Wait(); err != nil {
|
||||
return nil, CaptureOutput{}, err
|
||||
}
|
||||
|
||||
thought := thoughttypes.Thought{
|
||||
Content: content,
|
||||
Embedding: embedding,
|
||||
Metadata: metadata.Normalize(rawMetadata, t.capture),
|
||||
}
|
||||
if project != nil {
|
||||
thought.ProjectID = &project.ID
|
||||
}
|
||||
|
||||
created, err := t.store.InsertThought(ctx, thought)
|
||||
if err != nil {
|
||||
return nil, CaptureOutput{}, err
|
||||
}
|
||||
if project != nil {
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
}
|
||||
|
||||
return nil, CaptureOutput{Thought: created}, nil
|
||||
}
|
||||
31
internal/tools/common.go
Normal file
31
internal/tools/common.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func normalizeLimit(limit int, cfg config.SearchConfig) int {
|
||||
if limit <= 0 {
|
||||
return cfg.DefaultLimit
|
||||
}
|
||||
if limit > cfg.MaxLimit {
|
||||
return cfg.MaxLimit
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
func normalizeThreshold(value float64, fallback float64) float64 {
|
||||
if value <= 0 {
|
||||
return fallback
|
||||
}
|
||||
if value > 1 {
|
||||
return 1
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func errInvalidInput(message string) error {
|
||||
return fmt.Errorf("invalid input: %s", message)
|
||||
}
|
||||
111
internal/tools/context.go
Normal file
111
internal/tools/context.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type ContextTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type ProjectContextInput struct {
|
||||
Project string `json:"project,omitempty" jsonschema:"project name or id; falls back to the active session project"`
|
||||
Query string `json:"query,omitempty" jsonschema:"optional semantic focus for project context"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"maximum number of context items to return"`
|
||||
}
|
||||
|
||||
type ContextItem struct {
|
||||
ID string `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Metadata thoughttypes.ThoughtMetadata `json:"metadata"`
|
||||
Similarity float64 `json:"similarity,omitempty"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
type ProjectContextOutput struct {
|
||||
Project thoughttypes.Project `json:"project"`
|
||||
Context string `json:"context"`
|
||||
Items []ContextItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewContextTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
|
||||
return &ContextTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) {
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, true)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
|
||||
limit := normalizeLimit(in.Limit, t.search)
|
||||
recent, err := t.store.RecentThoughts(ctx, &project.ID, limit, 0)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
|
||||
items := make([]ContextItem, 0, limit*2)
|
||||
seen := map[string]struct{}{}
|
||||
for _, thought := range recent {
|
||||
key := thought.ID.String()
|
||||
seen[key] = struct{}{}
|
||||
items = append(items, ContextItem{
|
||||
ID: key,
|
||||
Content: thought.Content,
|
||||
Metadata: thought.Metadata,
|
||||
Source: "recent",
|
||||
})
|
||||
}
|
||||
|
||||
query := strings.TrimSpace(in.Query)
|
||||
if query != "" {
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, &project.ID, nil)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
for _, result := range semantic {
|
||||
key := result.ID.String()
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
items = append(items, ContextItem{
|
||||
ID: key,
|
||||
Content: result.Content,
|
||||
Metadata: result.Metadata,
|
||||
Similarity: result.Similarity,
|
||||
Source: "semantic",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
lines := make([]string, 0, len(items))
|
||||
for i, item := range items {
|
||||
lines = append(lines, thoughtContextLine(i, item.Content, item.Metadata, item.Similarity))
|
||||
}
|
||||
contextBlock := formatContextBlock(fmt.Sprintf("Project context for %s", project.Name), lines)
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
|
||||
return nil, ProjectContextOutput{
|
||||
Project: *project,
|
||||
Context: contextBlock,
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
36
internal/tools/delete.go
Normal file
36
internal/tools/delete.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
)
|
||||
|
||||
type DeleteTool struct {
|
||||
store *store.DB
|
||||
}
|
||||
|
||||
type DeleteInput struct {
|
||||
ID string `json:"id" jsonschema:"the thought id"`
|
||||
}
|
||||
|
||||
type DeleteOutput struct {
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
func NewDeleteTool(db *store.DB) *DeleteTool {
|
||||
return &DeleteTool{store: db}
|
||||
}
|
||||
|
||||
func (t *DeleteTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in DeleteInput) (*mcp.CallToolResult, DeleteOutput, error) {
|
||||
id, err := parseUUID(in.ID)
|
||||
if err != nil {
|
||||
return nil, DeleteOutput{}, err
|
||||
}
|
||||
if err := t.store.DeleteThought(ctx, id); err != nil {
|
||||
return nil, DeleteOutput{}, err
|
||||
}
|
||||
return nil, DeleteOutput{Deleted: true}, nil
|
||||
}
|
||||
40
internal/tools/get.go
Normal file
40
internal/tools/get.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type GetTool struct {
|
||||
store *store.DB
|
||||
}
|
||||
|
||||
type GetInput struct {
|
||||
ID string `json:"id" jsonschema:"the thought id"`
|
||||
}
|
||||
|
||||
type GetOutput struct {
|
||||
Thought thoughttypes.Thought `json:"thought"`
|
||||
}
|
||||
|
||||
func NewGetTool(db *store.DB) *GetTool {
|
||||
return &GetTool{store: db}
|
||||
}
|
||||
|
||||
func (t *GetTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in GetInput) (*mcp.CallToolResult, GetOutput, error) {
|
||||
id, err := parseUUID(in.ID)
|
||||
if err != nil {
|
||||
return nil, GetOutput{}, err
|
||||
}
|
||||
|
||||
thought, err := t.store.GetThought(ctx, id)
|
||||
if err != nil {
|
||||
return nil, GetOutput{}, err
|
||||
}
|
||||
|
||||
return nil, GetOutput{Thought: thought}, nil
|
||||
}
|
||||
87
internal/tools/helpers.go
Normal file
87
internal/tools/helpers.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
func parseUUID(id string) (uuid.UUID, error) {
|
||||
parsed, err := uuid.Parse(strings.TrimSpace(id))
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("invalid id %q: %w", id, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func sessionID(req *mcp.CallToolRequest) (string, error) {
|
||||
if req == nil || req.Session == nil || req.Session.ID() == "" {
|
||||
return "", fmt.Errorf("tool requires an MCP session")
|
||||
}
|
||||
return req.Session.ID(), nil
|
||||
}
|
||||
|
||||
func resolveProject(ctx context.Context, db *store.DB, sessions *session.ActiveProjects, req *mcp.CallToolRequest, raw string, required bool) (*thoughttypes.Project, error) {
|
||||
projectRef := strings.TrimSpace(raw)
|
||||
if projectRef == "" && sessions != nil && req != nil && req.Session != nil {
|
||||
if activeID, ok := sessions.Get(req.Session.ID()); ok {
|
||||
project, err := db.GetProject(ctx, activeID.String())
|
||||
if err == nil {
|
||||
return &project, nil
|
||||
}
|
||||
if err != pgx.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if projectRef == "" {
|
||||
if required {
|
||||
return nil, fmt.Errorf("project is required")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
project, err := db.GetProject(ctx, projectRef)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, fmt.Errorf("project %q not found", projectRef)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &project, nil
|
||||
}
|
||||
|
||||
func formatContextBlock(header string, lines []string) string {
|
||||
if len(lines) == 0 {
|
||||
return header + "\n\nNo matching thoughts."
|
||||
}
|
||||
return header + "\n\n" + strings.Join(lines, "\n\n")
|
||||
}
|
||||
|
||||
func thoughtContextLine(index int, content string, metadata thoughttypes.ThoughtMetadata, similarity float64) string {
|
||||
label := fmt.Sprintf("%d. %s", index+1, strings.TrimSpace(content))
|
||||
parts := make([]string, 0, 3)
|
||||
if len(metadata.Topics) > 0 {
|
||||
parts = append(parts, "topics="+strings.Join(metadata.Topics, ", "))
|
||||
}
|
||||
if metadata.Type != "" {
|
||||
parts = append(parts, "type="+metadata.Type)
|
||||
}
|
||||
if similarity > 0 {
|
||||
parts = append(parts, fmt.Sprintf("similarity=%.3f", similarity))
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return label
|
||||
}
|
||||
return label + "\n[" + strings.Join(parts, " | ") + "]"
|
||||
}
|
||||
145
internal/tools/links.go
Normal file
145
internal/tools/links.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type LinksTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
}
|
||||
|
||||
type LinkInput struct {
|
||||
FromID string `json:"from_id" jsonschema:"the source thought id"`
|
||||
ToID string `json:"to_id" jsonschema:"the target thought id"`
|
||||
Relation string `json:"relation" jsonschema:"relationship label such as follows_up or references"`
|
||||
}
|
||||
|
||||
type LinkOutput struct {
|
||||
Linked bool `json:"linked"`
|
||||
}
|
||||
|
||||
type RelatedInput struct {
|
||||
ID string `json:"id" jsonschema:"the thought id"`
|
||||
IncludeSemantic *bool `json:"include_semantic,omitempty" jsonschema:"whether to include semantic neighbors; defaults to true"`
|
||||
}
|
||||
|
||||
type RelatedThought struct {
|
||||
ID string `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Metadata thoughttypes.ThoughtMetadata `json:"metadata"`
|
||||
Relation string `json:"relation,omitempty"`
|
||||
Direction string `json:"direction,omitempty"`
|
||||
Similarity float64 `json:"similarity,omitempty"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
type RelatedOutput struct {
|
||||
Related []RelatedThought `json:"related"`
|
||||
}
|
||||
|
||||
func NewLinksTool(db *store.DB, provider ai.Provider, search config.SearchConfig) *LinksTool {
|
||||
return &LinksTool{store: db, provider: provider, search: search}
|
||||
}
|
||||
|
||||
func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) {
|
||||
fromID, err := parseUUID(in.FromID)
|
||||
if err != nil {
|
||||
return nil, LinkOutput{}, err
|
||||
}
|
||||
toID, err := parseUUID(in.ToID)
|
||||
if err != nil {
|
||||
return nil, LinkOutput{}, err
|
||||
}
|
||||
relation := strings.TrimSpace(in.Relation)
|
||||
if relation == "" {
|
||||
return nil, LinkOutput{}, errInvalidInput("relation is required")
|
||||
}
|
||||
if _, err := t.store.GetThought(ctx, fromID); err != nil {
|
||||
return nil, LinkOutput{}, err
|
||||
}
|
||||
if _, err := t.store.GetThought(ctx, toID); err != nil {
|
||||
return nil, LinkOutput{}, err
|
||||
}
|
||||
if err := t.store.InsertLink(ctx, thoughttypes.ThoughtLink{
|
||||
FromID: fromID,
|
||||
ToID: toID,
|
||||
Relation: relation,
|
||||
}); err != nil {
|
||||
return nil, LinkOutput{}, err
|
||||
}
|
||||
return nil, LinkOutput{Linked: true}, nil
|
||||
}
|
||||
|
||||
func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in RelatedInput) (*mcp.CallToolResult, RelatedOutput, error) {
|
||||
id, err := parseUUID(in.ID)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
|
||||
thought, err := t.store.GetThought(ctx, id)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
|
||||
linked, err := t.store.LinkedThoughts(ctx, id)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
|
||||
related := make([]RelatedThought, 0, len(linked)+t.search.DefaultLimit)
|
||||
seen := map[string]struct{}{thought.ID.String(): {}}
|
||||
for _, item := range linked {
|
||||
key := item.Thought.ID.String()
|
||||
seen[key] = struct{}{}
|
||||
related = append(related, RelatedThought{
|
||||
ID: key,
|
||||
Content: item.Thought.Content,
|
||||
Metadata: item.Thought.Metadata,
|
||||
Relation: item.Relation,
|
||||
Direction: item.Direction,
|
||||
Source: "link",
|
||||
})
|
||||
}
|
||||
|
||||
includeSemantic := true
|
||||
if in.IncludeSemantic != nil {
|
||||
includeSemantic = *in.IncludeSemantic
|
||||
}
|
||||
|
||||
if includeSemantic {
|
||||
embedding, err := t.provider.Embed(ctx, thought.Content)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
for _, item := range semantic {
|
||||
key := item.ID.String()
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
related = append(related, RelatedThought{
|
||||
ID: key,
|
||||
Content: item.Content,
|
||||
Metadata: item.Metadata,
|
||||
Similarity: item.Similarity,
|
||||
Source: "semantic",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil, RelatedOutput{Related: related}, nil
|
||||
}
|
||||
68
internal/tools/list.go
Normal file
68
internal/tools/list.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type ListTool struct {
|
||||
store *store.DB
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type ListInput struct {
|
||||
Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to return"`
|
||||
Type string `json:"type,omitempty" jsonschema:"filter by thought type"`
|
||||
Topic string `json:"topic,omitempty" jsonschema:"filter by topic"`
|
||||
Person string `json:"person,omitempty" jsonschema:"filter by mentioned person"`
|
||||
Days int `json:"days,omitempty" jsonschema:"only include thoughts from the last N days"`
|
||||
IncludeArchived bool `json:"include_archived,omitempty" jsonschema:"include archived thoughts when true"`
|
||||
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the listing"`
|
||||
}
|
||||
|
||||
type ListOutput struct {
|
||||
Thoughts []thoughttypes.Thought `json:"thoughts"`
|
||||
}
|
||||
|
||||
func NewListTool(db *store.DB, search config.SearchConfig, sessions *session.ActiveProjects) *ListTool {
|
||||
return &ListTool{store: db, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *ListTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ListInput) (*mcp.CallToolResult, ListOutput, error) {
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
if err != nil {
|
||||
return nil, ListOutput{}, err
|
||||
}
|
||||
|
||||
var projectID *uuid.UUID
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
thoughts, err := t.store.ListThoughts(ctx, thoughttypes.ListFilter{
|
||||
Limit: normalizeLimit(in.Limit, t.search),
|
||||
Type: strings.TrimSpace(in.Type),
|
||||
Topic: strings.TrimSpace(in.Topic),
|
||||
Person: strings.TrimSpace(in.Person),
|
||||
Days: in.Days,
|
||||
IncludeArchived: in.IncludeArchived,
|
||||
ProjectID: projectID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, ListOutput{}, err
|
||||
}
|
||||
if project != nil {
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
}
|
||||
|
||||
return nil, ListOutput{Thoughts: thoughts}, nil
|
||||
}
|
||||
92
internal/tools/projects.go
Normal file
92
internal/tools/projects.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type ProjectsTool struct {
|
||||
store *store.DB
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type CreateProjectInput struct {
|
||||
Name string `json:"name" jsonschema:"the unique project name"`
|
||||
Description string `json:"description,omitempty" jsonschema:"optional project description"`
|
||||
}
|
||||
|
||||
type CreateProjectOutput struct {
|
||||
Project thoughttypes.Project `json:"project"`
|
||||
}
|
||||
|
||||
type ListProjectsInput struct{}
|
||||
|
||||
type ListProjectsOutput struct {
|
||||
Projects []thoughttypes.ProjectSummary `json:"projects"`
|
||||
}
|
||||
|
||||
type SetActiveProjectInput struct {
|
||||
Project string `json:"project" jsonschema:"project name or id"`
|
||||
}
|
||||
|
||||
type SetActiveProjectOutput struct {
|
||||
Project thoughttypes.Project `json:"project"`
|
||||
}
|
||||
|
||||
type GetActiveProjectInput struct{}
|
||||
|
||||
type GetActiveProjectOutput struct {
|
||||
Project *thoughttypes.Project `json:"project,omitempty"`
|
||||
}
|
||||
|
||||
func NewProjectsTool(db *store.DB, sessions *session.ActiveProjects) *ProjectsTool {
|
||||
return &ProjectsTool{store: db, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *ProjectsTool) Create(ctx context.Context, _ *mcp.CallToolRequest, in CreateProjectInput) (*mcp.CallToolResult, CreateProjectOutput, error) {
|
||||
name := strings.TrimSpace(in.Name)
|
||||
if name == "" {
|
||||
return nil, CreateProjectOutput{}, errInvalidInput("name is required")
|
||||
}
|
||||
project, err := t.store.CreateProject(ctx, name, strings.TrimSpace(in.Description))
|
||||
if err != nil {
|
||||
return nil, CreateProjectOutput{}, err
|
||||
}
|
||||
return nil, CreateProjectOutput{Project: project}, nil
|
||||
}
|
||||
|
||||
func (t *ProjectsTool) List(ctx context.Context, _ *mcp.CallToolRequest, _ ListProjectsInput) (*mcp.CallToolResult, ListProjectsOutput, error) {
|
||||
projects, err := t.store.ListProjects(ctx)
|
||||
if err != nil {
|
||||
return nil, ListProjectsOutput{}, err
|
||||
}
|
||||
return nil, ListProjectsOutput{Projects: projects}, nil
|
||||
}
|
||||
|
||||
func (t *ProjectsTool) SetActive(ctx context.Context, req *mcp.CallToolRequest, in SetActiveProjectInput) (*mcp.CallToolResult, SetActiveProjectOutput, error) {
|
||||
sid, err := sessionID(req)
|
||||
if err != nil {
|
||||
return nil, SetActiveProjectOutput{}, err
|
||||
}
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, true)
|
||||
if err != nil {
|
||||
return nil, SetActiveProjectOutput{}, err
|
||||
}
|
||||
t.sessions.Set(sid, project.ID)
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
return nil, SetActiveProjectOutput{Project: *project}, nil
|
||||
}
|
||||
|
||||
func (t *ProjectsTool) GetActive(ctx context.Context, req *mcp.CallToolRequest, _ GetActiveProjectInput) (*mcp.CallToolResult, GetActiveProjectOutput, error) {
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, "", false)
|
||||
if err != nil {
|
||||
return nil, GetActiveProjectOutput{}, err
|
||||
}
|
||||
return nil, GetActiveProjectOutput{Project: project}, nil
|
||||
}
|
||||
111
internal/tools/recall.go
Normal file
111
internal/tools/recall.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
)
|
||||
|
||||
type RecallTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type RecallInput struct {
|
||||
Query string `json:"query" jsonschema:"semantic query for recalled context"`
|
||||
Project string `json:"project,omitempty" jsonschema:"optional project name or id; falls back to the active session project"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"maximum number of context items to return"`
|
||||
}
|
||||
|
||||
type RecallOutput struct {
|
||||
Context string `json:"context"`
|
||||
Items []ContextItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
|
||||
return &RecallTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
||||
query := strings.TrimSpace(in.Query)
|
||||
if query == "" {
|
||||
return nil, RecallOutput{}, errInvalidInput("query is required")
|
||||
}
|
||||
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
|
||||
limit := normalizeLimit(in.Limit, t.search)
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
|
||||
var projectID *uuid.UUID
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
recent, err := t.store.RecentThoughts(ctx, projectID, limit, 0)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
|
||||
items := make([]ContextItem, 0, limit*2)
|
||||
seen := map[string]struct{}{}
|
||||
for _, result := range semantic {
|
||||
key := result.ID.String()
|
||||
seen[key] = struct{}{}
|
||||
items = append(items, ContextItem{
|
||||
ID: key,
|
||||
Content: result.Content,
|
||||
Metadata: result.Metadata,
|
||||
Similarity: result.Similarity,
|
||||
Source: "semantic",
|
||||
})
|
||||
}
|
||||
for _, thought := range recent {
|
||||
key := thought.ID.String()
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
items = append(items, ContextItem{
|
||||
ID: key,
|
||||
Content: thought.Content,
|
||||
Metadata: thought.Metadata,
|
||||
Source: "recent",
|
||||
})
|
||||
}
|
||||
|
||||
lines := make([]string, 0, len(items))
|
||||
for i, item := range items {
|
||||
lines = append(lines, thoughtContextLine(i, item.Content, item.Metadata, item.Similarity))
|
||||
}
|
||||
header := "Recalled context"
|
||||
if project != nil {
|
||||
header = fmt.Sprintf("Recalled context for %s", project.Name)
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
}
|
||||
|
||||
return nil, RecallOutput{
|
||||
Context: formatContextBlock(header, lines),
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
69
internal/tools/search.go
Normal file
69
internal/tools/search.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type SearchTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type SearchInput struct {
|
||||
Query string `json:"query" jsonschema:"the semantic query to search for"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"maximum number of results to return"`
|
||||
Threshold float64 `json:"threshold,omitempty" jsonschema:"minimum similarity threshold between 0 and 1"`
|
||||
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the search"`
|
||||
}
|
||||
|
||||
type SearchOutput struct {
|
||||
Results []thoughttypes.SearchResult `json:"results"`
|
||||
}
|
||||
|
||||
func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
|
||||
return &SearchTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
|
||||
query := strings.TrimSpace(in.Query)
|
||||
if query == "" {
|
||||
return nil, SearchOutput{}, errInvalidInput("query is required")
|
||||
}
|
||||
|
||||
limit := normalizeLimit(in.Limit, t.search)
|
||||
threshold := normalizeThreshold(in.Threshold, t.search.DefaultThreshold)
|
||||
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
var results []thoughttypes.SearchResult
|
||||
if project != nil {
|
||||
results, err = t.store.SearchSimilarThoughts(ctx, embedding, threshold, limit, &project.ID, nil)
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
} else {
|
||||
results, err = t.store.SearchThoughts(ctx, embedding, threshold, limit, map[string]any{})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
return nil, SearchOutput{Results: results}, nil
|
||||
}
|
||||
33
internal/tools/stats.go
Normal file
33
internal/tools/stats.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type StatsTool struct {
|
||||
store *store.DB
|
||||
}
|
||||
|
||||
type StatsInput struct{}
|
||||
|
||||
type StatsOutput struct {
|
||||
Stats thoughttypes.ThoughtStats `json:"stats"`
|
||||
}
|
||||
|
||||
func NewStatsTool(db *store.DB) *StatsTool {
|
||||
return &StatsTool{store: db}
|
||||
}
|
||||
|
||||
func (t *StatsTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, _ StatsInput) (*mcp.CallToolResult, StatsOutput, error) {
|
||||
stats, err := t.store.Stats(ctx)
|
||||
if err != nil {
|
||||
return nil, StatsOutput{}, err
|
||||
}
|
||||
|
||||
return nil, StatsOutput{Stats: stats}, nil
|
||||
}
|
||||
93
internal/tools/summarize.go
Normal file
93
internal/tools/summarize.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
)
|
||||
|
||||
type SummarizeTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type SummarizeInput struct {
|
||||
Query string `json:"query,omitempty" jsonschema:"optional semantic focus for the summary"`
|
||||
Project string `json:"project,omitempty" jsonschema:"optional project name or id; falls back to the active session project"`
|
||||
Days int `json:"days,omitempty" jsonschema:"only include thoughts from the last N days when query is omitted"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to summarize"`
|
||||
}
|
||||
|
||||
type SummarizeOutput struct {
|
||||
Summary string `json:"summary"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
func NewSummarizeTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
|
||||
return &SummarizeTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) {
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
|
||||
limit := normalizeLimit(in.Limit, t.search)
|
||||
query := strings.TrimSpace(in.Query)
|
||||
lines := make([]string, 0, limit)
|
||||
count := 0
|
||||
|
||||
if query != "" {
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
var projectID *uuid.UUID
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.search.DefaultThreshold, limit, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
for i, result := range results {
|
||||
lines = append(lines, thoughtContextLine(i, result.Content, result.Metadata, result.Similarity))
|
||||
}
|
||||
count = len(results)
|
||||
} else {
|
||||
var projectID *uuid.UUID
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
thoughts, err := t.store.RecentThoughts(ctx, projectID, limit, in.Days)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
for i, thought := range thoughts {
|
||||
lines = append(lines, thoughtContextLine(i, thought.Content, thought.Metadata, 0))
|
||||
}
|
||||
count = len(thoughts)
|
||||
}
|
||||
|
||||
userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines)
|
||||
systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose."
|
||||
summary, err := t.provider.Summarize(ctx, systemPrompt, userPrompt)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
if project != nil {
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
}
|
||||
|
||||
return nil, SummarizeOutput{Summary: summary, Count: count}, nil
|
||||
}
|
||||
88
internal/tools/update.go
Normal file
88
internal/tools/update.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/metadata"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type UpdateTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
capture config.CaptureConfig
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
type UpdateInput struct {
|
||||
ID string `json:"id" jsonschema:"the thought id"`
|
||||
Content *string `json:"content,omitempty" jsonschema:"replacement content for the thought"`
|
||||
Metadata thoughttypes.ThoughtMetadata `json:"metadata,omitempty" jsonschema:"metadata fields to merge into the thought"`
|
||||
Project string `json:"project,omitempty" jsonschema:"optional project name or id to move the thought into"`
|
||||
}
|
||||
|
||||
type UpdateOutput struct {
|
||||
Thought thoughttypes.Thought `json:"thought"`
|
||||
}
|
||||
|
||||
func NewUpdateTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
|
||||
return &UpdateTool{store: db, provider: provider, capture: capture, log: log}
|
||||
}
|
||||
|
||||
func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) {
|
||||
id, err := parseUUID(in.ID)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
|
||||
current, err := t.store.GetThought(ctx, id)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
|
||||
content := current.Content
|
||||
embedding := current.Embedding
|
||||
mergedMetadata := current.Metadata
|
||||
projectID := current.ProjectID
|
||||
|
||||
if in.Content != nil {
|
||||
content = strings.TrimSpace(*in.Content)
|
||||
if content == "" {
|
||||
return nil, UpdateOutput{}, errInvalidInput("content must not be empty")
|
||||
}
|
||||
embedding, err = t.provider.Embed(ctx, content)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
extracted, extractErr := t.provider.ExtractMetadata(ctx, content)
|
||||
if extractErr != nil {
|
||||
t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error()))
|
||||
} else {
|
||||
mergedMetadata = metadata.Normalize(extracted, t.capture)
|
||||
}
|
||||
}
|
||||
|
||||
mergedMetadata = metadata.Merge(mergedMetadata, in.Metadata, t.capture)
|
||||
|
||||
if rawProject := strings.TrimSpace(in.Project); rawProject != "" {
|
||||
project, err := t.store.GetProject(ctx, rawProject)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
updated, err := t.store.UpdateThought(ctx, id, content, embedding, mergedMetadata, projectID)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
|
||||
return nil, UpdateOutput{Thought: updated}, nil
|
||||
}
|
||||
40
internal/types/project.go
Normal file
40
internal/types/project.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ThoughtPatch struct {
|
||||
Content *string `json:"content,omitempty"`
|
||||
Metadata ThoughtMetadata `json:"metadata,omitempty"`
|
||||
Project *uuid.UUID `json:"project_id,omitempty"`
|
||||
}
|
||||
|
||||
type Project struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastActiveAt time.Time `json:"last_active_at"`
|
||||
}
|
||||
|
||||
type ProjectSummary struct {
|
||||
Project
|
||||
ThoughtCount int `json:"thought_count"`
|
||||
}
|
||||
|
||||
type ThoughtLink struct {
|
||||
FromID uuid.UUID `json:"from_id"`
|
||||
ToID uuid.UUID `json:"to_id"`
|
||||
Relation string `json:"relation"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type LinkedThought struct {
|
||||
Thought Thought `json:"thought"`
|
||||
Relation string `json:"relation"`
|
||||
Direction string `json:"direction"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
57
internal/types/thought.go
Normal file
57
internal/types/thought.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ThoughtMetadata struct {
|
||||
People []string `json:"people"`
|
||||
ActionItems []string `json:"action_items"`
|
||||
DatesMentioned []string `json:"dates_mentioned"`
|
||||
Topics []string `json:"topics"`
|
||||
Type string `json:"type"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
type Thought struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Embedding []float32 `json:"embedding,omitempty"`
|
||||
Metadata ThoughtMetadata `json:"metadata"`
|
||||
ProjectID *uuid.UUID `json:"project_id,omitempty"`
|
||||
ArchivedAt *time.Time `json:"archived_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type SearchResult struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Metadata ThoughtMetadata `json:"metadata"`
|
||||
Similarity float64 `json:"similarity"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type ListFilter struct {
|
||||
Limit int
|
||||
Type string
|
||||
Topic string
|
||||
Person string
|
||||
Days int
|
||||
ProjectID *uuid.UUID
|
||||
IncludeArchived bool
|
||||
}
|
||||
|
||||
type ThoughtStats struct {
|
||||
TotalCount int `json:"total_count"`
|
||||
TypeCounts map[string]int `json:"type_counts"`
|
||||
TopTopics []KeyCount `json:"top_topics"`
|
||||
TopPeople []KeyCount `json:"top_people"`
|
||||
}
|
||||
|
||||
type KeyCount struct {
|
||||
Key string `json:"key"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
2
migrations/001_enable_vector.sql
Normal file
2
migrations/001_enable_vector.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
create extension if not exists vector;
|
||||
create extension if not exists pgcrypto;
|
||||
17
migrations/002_create_thoughts.sql
Normal file
17
migrations/002_create_thoughts.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
create table if not exists thoughts (
|
||||
id uuid default gen_random_uuid() primary key,
|
||||
content text not null,
|
||||
embedding vector(1536),
|
||||
metadata jsonb default '{}'::jsonb,
|
||||
created_at timestamptz default now(),
|
||||
updated_at timestamptz default now()
|
||||
);
|
||||
|
||||
create index if not exists thoughts_embedding_hnsw_idx
|
||||
on thoughts using hnsw (embedding vector_cosine_ops);
|
||||
|
||||
create index if not exists thoughts_metadata_gin_idx
|
||||
on thoughts using gin (metadata);
|
||||
|
||||
create index if not exists thoughts_created_at_idx
|
||||
on thoughts (created_at desc);
|
||||
13
migrations/003_add_projects.sql
Normal file
13
migrations/003_add_projects.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
create table if not exists projects (
|
||||
id uuid default gen_random_uuid() primary key,
|
||||
name text not null unique,
|
||||
description text,
|
||||
created_at timestamptz default now(),
|
||||
last_active_at timestamptz default now()
|
||||
);
|
||||
|
||||
alter table thoughts add column if not exists project_id uuid references projects(id);
|
||||
alter table thoughts add column if not exists archived_at timestamptz;
|
||||
|
||||
create index if not exists thoughts_project_id_idx on thoughts (project_id);
|
||||
create index if not exists thoughts_archived_at_idx on thoughts (archived_at);
|
||||
10
migrations/004_create_thought_links.sql
Normal file
10
migrations/004_create_thought_links.sql
Normal file
@@ -0,0 +1,10 @@
|
||||
create table if not exists thought_links (
|
||||
from_id uuid references thoughts(id) on delete cascade,
|
||||
to_id uuid references thoughts(id) on delete cascade,
|
||||
relation text not null,
|
||||
created_at timestamptz default now(),
|
||||
primary key (from_id, to_id, relation)
|
||||
);
|
||||
|
||||
create index if not exists thought_links_from_idx on thought_links (from_id);
|
||||
create index if not exists thought_links_to_idx on thought_links (to_id);
|
||||
31
migrations/005_create_match_thoughts.sql
Normal file
31
migrations/005_create_match_thoughts.sql
Normal file
@@ -0,0 +1,31 @@
|
||||
create or replace function match_thoughts(
|
||||
query_embedding vector(1536),
|
||||
match_threshold float default 0.7,
|
||||
match_count int default 10,
|
||||
filter jsonb default '{}'::jsonb
|
||||
)
|
||||
returns table (
|
||||
id uuid,
|
||||
content text,
|
||||
metadata jsonb,
|
||||
similarity float,
|
||||
created_at timestamptz
|
||||
)
|
||||
language plpgsql
|
||||
as $$
|
||||
begin
|
||||
return query
|
||||
select
|
||||
t.id,
|
||||
t.content,
|
||||
t.metadata,
|
||||
1 - (t.embedding <=> query_embedding) as similarity,
|
||||
t.created_at
|
||||
from thoughts t
|
||||
where 1 - (t.embedding <=> query_embedding) > match_threshold
|
||||
and t.archived_at is null
|
||||
and (filter = '{}'::jsonb or t.metadata @> filter)
|
||||
order by t.embedding <=> query_embedding
|
||||
limit match_count;
|
||||
end;
|
||||
$$;
|
||||
5
migrations/006_rls_and_grants.sql
Normal file
5
migrations/006_rls_and_grants.sql
Normal file
@@ -0,0 +1,5 @@
|
||||
-- Grant these permissions to the database role used by the application.
|
||||
-- Replace amcs_user with the actual role in your deployment before applying.
|
||||
grant select, insert, update, delete on table public.thoughts to amcs_user;
|
||||
grant select, insert, update, delete on table public.projects to amcs_user;
|
||||
grant select, insert, update, delete on table public.thought_links to amcs_user;
|
||||
15
scripts/migrate.sh
Executable file
15
scripts/migrate.sh
Executable file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
DATABASE_URL="${DATABASE_URL:-${OB1_DATABASE_URL:-}}"
|
||||
|
||||
if [[ -z "${DATABASE_URL}" ]]; then
|
||||
echo "DATABASE_URL or OB1_DATABASE_URL must be set" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
for migration in migrations/*.sql; do
|
||||
echo "Applying ${migration}"
|
||||
psql "${DATABASE_URL}" -v ON_ERROR_STOP=1 -f "${migration}"
|
||||
done
|
||||
5
scripts/run-local.sh
Executable file
5
scripts/run-local.sh
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
go run ./cmd/amcs-server --config "${1:-configs/dev.yaml}"
|
||||
Reference in New Issue
Block a user