Add structured learnings updates #35
@@ -31,6 +31,9 @@ jobs:
|
|||||||
- name: Download dependencies
|
- name: Download dependencies
|
||||||
run: go mod download
|
run: go mod download
|
||||||
|
|
||||||
|
- name: Tidy modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: go test ./...
|
run: go test ./...
|
||||||
|
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -34,3 +34,4 @@ OB1/
|
|||||||
ui/node_modules/
|
ui/node_modules/
|
||||||
ui/.svelte-kit/
|
ui/.svelte-kit/
|
||||||
internal/app/ui/dist/
|
internal/app/ui/dist/
|
||||||
|
.codex
|
||||||
|
|||||||
10
Dockerfile
10
Dockerfile
@@ -29,7 +29,14 @@ RUN set -eu; \
|
|||||||
-X git.warky.dev/wdevs/amcs/internal/buildinfo.TagName=${VERSION_TAG} \
|
-X git.warky.dev/wdevs/amcs/internal/buildinfo.TagName=${VERSION_TAG} \
|
||||||
-X git.warky.dev/wdevs/amcs/internal/buildinfo.Commit=${COMMIT_SHA} \
|
-X git.warky.dev/wdevs/amcs/internal/buildinfo.Commit=${COMMIT_SHA} \
|
||||||
-X git.warky.dev/wdevs/amcs/internal/buildinfo.BuildDate=${BUILD_DATE}" \
|
-X git.warky.dev/wdevs/amcs/internal/buildinfo.BuildDate=${BUILD_DATE}" \
|
||||||
-o /out/amcs-server ./cmd/amcs-server
|
-o /out/amcs-server ./cmd/amcs-server; \
|
||||||
|
CGO_ENABLED=0 GOOS=linux go build -trimpath \
|
||||||
|
-ldflags="-s -w \
|
||||||
|
-X git.warky.dev/wdevs/amcs/internal/buildinfo.Version=${VERSION_TAG} \
|
||||||
|
-X git.warky.dev/wdevs/amcs/internal/buildinfo.TagName=${VERSION_TAG} \
|
||||||
|
-X git.warky.dev/wdevs/amcs/internal/buildinfo.Commit=${COMMIT_SHA} \
|
||||||
|
-X git.warky.dev/wdevs/amcs/internal/buildinfo.BuildDate=${BUILD_DATE}" \
|
||||||
|
-o /out/amcs-migrate-config ./cmd/amcs-migrate-config
|
||||||
|
|
||||||
FROM debian:bookworm-slim
|
FROM debian:bookworm-slim
|
||||||
|
|
||||||
@@ -41,6 +48,7 @@ RUN apt-get update \
|
|||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY --from=builder /out/amcs-server /app/amcs-server
|
COPY --from=builder /out/amcs-server /app/amcs-server
|
||||||
|
COPY --from=builder /out/amcs-migrate-config /app/amcs-migrate-config
|
||||||
COPY --chown=appuser:appuser configs /app/configs
|
COPY --chown=appuser:appuser configs /app/configs
|
||||||
|
|
||||||
USER appuser
|
USER appuser
|
||||||
|
|||||||
685
README.md
685
README.md
@@ -14,6 +14,685 @@ The AMCS directory is used to store configuration and code for the Advanced Modu
|
|||||||
|
|
||||||
## Next Steps
|
## Next Steps
|
||||||
|
|
||||||
- Review the configuration files in `configs/`
|
## Tools
|
||||||
- Run the setup script in `scripts/`
|
|
||||||
- Check the `assets/` directory for any required media files
|
| Tool | Purpose |
|
||||||
|
|---|---|
|
||||||
|
| `capture_thought` | Store a thought with embedding and metadata |
|
||||||
|
| `search_thoughts` | Semantic similarity search |
|
||||||
|
| `list_thoughts` | Filter thoughts by type, topic, person, date |
|
||||||
|
| `thought_stats` | Counts and top topics/people |
|
||||||
|
| `get_thought` | Retrieve a thought by ID |
|
||||||
|
| `update_thought` | Patch content or metadata |
|
||||||
|
| `delete_thought` | Hard delete |
|
||||||
|
| `archive_thought` | Soft delete |
|
||||||
|
| `create_project` | Register a named project |
|
||||||
|
| `list_projects` | List projects with thought counts |
|
||||||
|
| `get_project_context` | Recent + semantic context for a project; uses explicit `project` or the active session project |
|
||||||
|
| `set_active_project` | Set session project scope; requires a stateful MCP session |
|
||||||
|
| `get_active_project` | Get current session project |
|
||||||
|
| `summarize_thoughts` | LLM prose summary over a filtered set |
|
||||||
|
| `recall_context` | Semantic + recency context block for injection |
|
||||||
|
| `link_thoughts` | Create a typed relationship between thoughts |
|
||||||
|
| `related_thoughts` | Explicit links + semantic neighbours |
|
||||||
|
| `upload_file` | Stage a file from a server-side path or base64 and get an `amcs://files/{id}` resource URI |
|
||||||
|
| `save_file` | Store a file (base64 or resource URI) and optionally link it to a thought |
|
||||||
|
| `load_file` | Retrieve a stored file by ID; returns metadata, base64 content, and an embedded MCP binary resource |
|
||||||
|
| `list_files` | Browse stored files by thought, project, or kind |
|
||||||
|
| `backfill_embeddings` | Generate missing embeddings for stored thoughts |
|
||||||
|
| `reparse_thought_metadata` | Re-extract metadata from thought content |
|
||||||
|
| `retry_failed_metadata` | Retry pending/failed metadata extraction |
|
||||||
|
| `add_maintenance_task` | Create a recurring or one-time home maintenance task |
|
||||||
|
| `log_maintenance` | Log completed maintenance; updates next due date |
|
||||||
|
| `get_upcoming_maintenance` | List maintenance tasks due within the next N days |
|
||||||
|
| `search_maintenance_history` | Search the maintenance log by task name, category, or date range |
|
||||||
|
| `save_chat_history` | Save chat messages with optional title, summary, channel, agent, and project |
|
||||||
|
| `get_chat_history` | Fetch chat history by UUID or session_id |
|
||||||
|
| `list_chat_histories` | List chat histories; filter by project, channel, agent_id, session_id, or days |
|
||||||
|
| `delete_chat_history` | Delete a chat history by id |
|
||||||
|
| `add_skill` | Store an agent skill (instruction or capability prompt) |
|
||||||
|
| `remove_skill` | Delete an agent skill by id |
|
||||||
|
| `list_skills` | List all agent skills, optionally filtered by tag |
|
||||||
|
| `add_guardrail` | Store an agent guardrail (constraint or safety rule) |
|
||||||
|
| `remove_guardrail` | Delete an agent guardrail by id |
|
||||||
|
| `list_guardrails` | List all agent guardrails, optionally filtered by tag or severity |
|
||||||
|
| `add_project_skill` | Link a skill to a project; pass `project` if client is stateless |
|
||||||
|
| `remove_project_skill` | Unlink a skill from a project; pass `project` if client is stateless |
|
||||||
|
| `list_project_skills` | Skills for a project; pass `project` if client is stateless |
|
||||||
|
| `add_project_guardrail` | Link a guardrail to a project; pass `project` if client is stateless |
|
||||||
|
| `remove_project_guardrail` | Unlink a guardrail from a project; pass `project` if client is stateless |
|
||||||
|
| `list_project_guardrails` | Guardrails for a project; pass `project` if client is stateless |
|
||||||
|
| `get_version_info` | Build version, commit, and date |
|
||||||
|
| `describe_tools` | List all available MCP tools with names, descriptions, categories, and model-authored usage notes; call this at the start of a session to orient yourself |
|
||||||
|
| `annotate_tool` | Persist your own usage notes for a specific tool; notes are returned by `describe_tools` in future sessions |
|
||||||
|
|
||||||
|
## Self-Documenting Tools
|
||||||
|
|
||||||
|
AMCS includes a built-in tool directory that models can read and annotate.
|
||||||
|
|
||||||
|
**`describe_tools`** returns every registered tool with its name, description, category, and any model-written notes. Call it with no arguments to get the full list, or filter by category:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "category": "thoughts" }
|
||||||
|
```
|
||||||
|
|
||||||
|
Available categories: `system`, `thoughts`, `projects`, `files`, `admin`, `maintenance`, `skills`, `chat`, `meta`.
|
||||||
|
|
||||||
|
**`annotate_tool`** lets a model write persistent usage notes against a tool name. Notes survive across sessions and are returned by `describe_tools`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "tool_name": "capture_thought", "notes": "Always pass project explicitly — session state is not reliable in this client." }
|
||||||
|
```
|
||||||
|
|
||||||
|
Pass an empty string to clear notes. The intended workflow is:
|
||||||
|
|
||||||
|
1. At the start of a session, call `describe_tools` to discover tools and read accumulated notes.
|
||||||
|
2. As you learn something non-obvious about a tool — a gotcha, a workflow pattern, a required field ordering — call `annotate_tool` to record it.
|
||||||
|
3. Future sessions receive the annotation automatically via `describe_tools`.
|
||||||
|
|
||||||
|
## MCP Error Contract
|
||||||
|
|
||||||
|
AMCS returns structured JSON-RPC errors for common MCP failures. Clients should branch on both `error.code` and `error.data.type` instead of parsing the human-readable message.
|
||||||
|
|
||||||
|
### Stable error codes
|
||||||
|
|
||||||
|
| Code | `data.type` | Meaning |
|
||||||
|
|---|---|---|
|
||||||
|
| `-32602` | `invalid_arguments` | MCP argument/schema validation failed before the tool handler ran |
|
||||||
|
| `-32602` | `invalid_input` | Tool-level input validation failed inside the handler |
|
||||||
|
| `-32050` | `session_required` | Tool requires a stateful MCP session |
|
||||||
|
| `-32051` | `project_required` | No explicit `project` was provided and no active session project was available |
|
||||||
|
| `-32052` | `project_not_found` | The referenced project does not exist |
|
||||||
|
| `-32053` | `invalid_id` | A UUID-like identifier was malformed |
|
||||||
|
| `-32054` | `entity_not_found` | A referenced entity such as a thought or contact does not exist |
|
||||||
|
|
||||||
|
### Error data shape
|
||||||
|
|
||||||
|
AMCS may include these fields in `error.data`:
|
||||||
|
|
||||||
|
- `type` — stable machine-readable error type
|
||||||
|
- `field` — single argument name such as `name`, `project`, or `thought_id`
|
||||||
|
- `fields` — multiple argument names for one-of or mutually-exclusive validation
|
||||||
|
- `value` — offending value when safe to expose
|
||||||
|
- `detail` — validation detail such as `required`, `invalid`, `one_of_required`, `mutually_exclusive`, or a schema validation message
|
||||||
|
- `hint` — remediation guidance
|
||||||
|
- `entity` — entity name for generic not-found errors
|
||||||
|
|
||||||
|
Example schema-level error:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": -32602,
|
||||||
|
"message": "invalid tool arguments",
|
||||||
|
"data": {
|
||||||
|
"type": "invalid_arguments",
|
||||||
|
"field": "name",
|
||||||
|
"detail": "validating root: required: missing properties: [\"name\"]",
|
||||||
|
"hint": "check the name argument"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Example tool-level error:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": -32051,
|
||||||
|
"message": "project is required; pass project explicitly or call set_active_project in this MCP session first",
|
||||||
|
"data": {
|
||||||
|
"type": "project_required",
|
||||||
|
"field": "project",
|
||||||
|
"hint": "pass project explicitly or call set_active_project in this MCP session first"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Client example
|
||||||
|
|
||||||
|
Go client example handling AMCS MCP errors:
|
||||||
|
|
||||||
|
```go
|
||||||
|
result, err := session.CallTool(ctx, &mcp.CallToolParams{
|
||||||
|
Name: "get_project_context",
|
||||||
|
Arguments: map[string]any{},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
var rpcErr *jsonrpc.Error
|
||||||
|
if errors.As(err, &rpcErr) {
|
||||||
|
var data struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Field string `json:"field"`
|
||||||
|
Hint string `json:"hint"`
|
||||||
|
}
|
||||||
|
_ = json.Unmarshal(rpcErr.Data, &data)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case rpcErr.Code == -32051 && data.Type == "project_required":
|
||||||
|
// Retry with an explicit project, or call set_active_project first.
|
||||||
|
case rpcErr.Code == -32602 && data.Type == "invalid_arguments":
|
||||||
|
// Ask the caller to fix the malformed arguments.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = result
|
||||||
|
```
|
||||||
|
|
||||||
|
## Build Versioning
|
||||||
|
|
||||||
|
AMCS embeds build metadata into the binary at build time.
|
||||||
|
|
||||||
|
- `version` is generated from the current git tag when building from a tagged commit
|
||||||
|
- `tag_name` is the repo tag name, for example `v1.0.1`
|
||||||
|
- `build_date` is the UTC build timestamp in RFC3339 format
|
||||||
|
- `commit` is the short git commit SHA
|
||||||
|
|
||||||
|
For untagged builds, `version` and `tag_name` fall back to `dev`.
|
||||||
|
|
||||||
|
Use `get_version_info` to retrieve the runtime build metadata:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"server_name": "amcs",
|
||||||
|
"version": "v1.0.1",
|
||||||
|
"tag_name": "v1.0.1",
|
||||||
|
"commit": "abc1234",
|
||||||
|
"build_date": "2026-03-31T14:22:10Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Agent Skills and Guardrails
|
||||||
|
|
||||||
|
Skills and guardrails are reusable agent behaviour instructions and constraints that can be attached to projects.
|
||||||
|
|
||||||
|
**At the start of every project session, always call `list_project_skills` and `list_project_guardrails` first.** Use the returned skills and guardrails to guide agent behaviour for that project. Only generate or create new skills/guardrails if none are returned. If your MCP client does not preserve sessions across calls, pass `project` explicitly instead of relying on `set_active_project`.
|
||||||
|
|
||||||
|
### Skills
|
||||||
|
|
||||||
|
A skill is a reusable behavioural instruction or capability prompt — for example, "always respond in structured markdown" or "break complex tasks into numbered steps before starting".
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "name": "structured-output", "description": "Enforce markdown output format", "content": "Always structure responses using markdown headers and bullet points.", "tags": ["formatting"] }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Guardrails
|
||||||
|
|
||||||
|
A guardrail is a constraint or safety rule — for example, "never delete files without explicit confirmation" or "do not expose secrets in output".
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "name": "no-silent-deletes", "description": "Require confirmation before deletes", "content": "Never delete, drop, or truncate data without first confirming with the user.", "severity": "high", "tags": ["safety"] }
|
||||||
|
```
|
||||||
|
|
||||||
|
Severity levels: `low`, `medium`, `high`, `critical`.
|
||||||
|
|
||||||
|
### Project linking
|
||||||
|
|
||||||
|
Link existing skills and guardrails to a project so they are automatically available when that project is active:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "project": "my-project", "skill_id": "<uuid>" }
|
||||||
|
{ "project": "my-project", "guardrail_id": "<uuid>" }
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Config is YAML-driven. Copy `configs/config.example.yaml` and set:
|
||||||
|
|
||||||
|
- `database.url` — Postgres connection string
|
||||||
|
- `auth.keys` — static API keys for MCP access via `x-brain-key` or `Authorization: Bearer <key>`
|
||||||
|
- `auth.oauth.clients` — optional OAuth client credentials registry
|
||||||
|
- `ai.providers` — named provider definitions (`litellm`, `ollama`, `openrouter`)
|
||||||
|
- `ai.embeddings.primary` / `ai.metadata.primary` — primary role targets (`provider` + `model`)
|
||||||
|
- `ai.embeddings.fallbacks` / `ai.metadata.fallbacks` — sequential fallback targets
|
||||||
|
- `mcp.version` is build-generated and should not be set in config
|
||||||
|
|
||||||
|
Config schema is versioned. Current schema version is `2`.
|
||||||
|
|
||||||
|
Use the migration helper to rewrite legacy configs in-place:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go run ./cmd/amcs-migrate-config --config ./configs/dev.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `--dry-run` to print migrated YAML without writing.
|
||||||
|
Server startup migrates older config formats in memory only and does not write files.
|
||||||
|
|
||||||
|
**OAuth Client Credentials flow**:
|
||||||
|
|
||||||
|
1. Obtain a token — `POST /oauth/token` (public, no auth required):
|
||||||
|
```
|
||||||
|
POST /oauth/token
|
||||||
|
Content-Type: application/x-www-form-urlencoded
|
||||||
|
Authorization: Basic base64(client_id:client_secret)
|
||||||
|
|
||||||
|
grant_type=client_credentials
|
||||||
|
```
|
||||||
|
Returns: `{"access_token": "...", "token_type": "bearer", "expires_in": 3600}`
|
||||||
|
|
||||||
|
2. Use the token on the MCP endpoint:
|
||||||
|
```
|
||||||
|
Authorization: Bearer <access_token>
|
||||||
|
```
|
||||||
|
|
||||||
|
Alternatively, pass `client_id` and `client_secret` as body parameters instead of `Authorization: Basic`. Direct `Authorization: Basic` credential validation on the MCP endpoint is also supported as a fallback (no token required).
|
||||||
|
- `AMCS_LITELLM_BASE_URL` / `AMCS_LITELLM_API_KEY` override all configured LiteLLM providers
|
||||||
|
- `AMCS_OLLAMA_BASE_URL` / `AMCS_OLLAMA_API_KEY` override all configured Ollama providers
|
||||||
|
- `AMCS_OPENROUTER_API_KEY` overrides all configured OpenRouter providers
|
||||||
|
|
||||||
|
See `llm/plan.md` for an audited high-level status summary of the original implementation plan, and `llm/todo.md` for the audited backfill/fallback follow-up status.
|
||||||
|
|
||||||
|
## Backfill
|
||||||
|
|
||||||
|
Run `backfill_embeddings` after switching embedding models or importing thoughts without vectors.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"project": "optional-project-name",
|
||||||
|
"limit": 100,
|
||||||
|
"include_archived": false,
|
||||||
|
"older_than_days": 0,
|
||||||
|
"dry_run": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `dry_run: true` — report counts without calling the embedding provider
|
||||||
|
- `limit` — max thoughts per call (default 100)
|
||||||
|
- Embeddings are generated in parallel (4 workers) and upserted; one failure does not abort the run
|
||||||
|
|
||||||
|
## Metadata Reparse
|
||||||
|
|
||||||
|
Run `reparse_thought_metadata` to fix stale or inconsistent metadata by re-extracting it from thought content.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"project": "optional-project-name",
|
||||||
|
"limit": 100,
|
||||||
|
"include_archived": false,
|
||||||
|
"older_than_days": 0,
|
||||||
|
"dry_run": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `dry_run: true` scans only and does not call metadata extraction or write updates
|
||||||
|
- If extraction fails for a thought, existing metadata is normalized and written only if it changes
|
||||||
|
- Metadata reparse runs in parallel (4 workers); one failure does not abort the run
|
||||||
|
|
||||||
|
## Failed Metadata Retry
|
||||||
|
|
||||||
|
`capture_thought` now stores the thought even when metadata extraction times out or fails. Those thoughts are marked with `metadata_status: "pending"` and retried in the background. Use `retry_failed_metadata` to sweep any thoughts still marked `pending` or `failed`.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"project": "optional-project-name",
|
||||||
|
"limit": 100,
|
||||||
|
"include_archived": false,
|
||||||
|
"older_than_days": 1,
|
||||||
|
"dry_run": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `dry_run: true` scans only and does not call metadata extraction or write updates
|
||||||
|
- successful retries mark the thought metadata as `complete` and clear the last error
|
||||||
|
- failed retries update the retry markers so the daily sweep can pick them up again later
|
||||||
|
|
||||||
|
## File Storage
|
||||||
|
|
||||||
|
Files can optionally be linked to a thought by passing `thought_id`, which also adds an attachment reference to that thought's metadata. AI clients should prefer `save_file` when the goal is to retain the artifact itself, rather than reading or summarizing the file first. Stored files and attachment metadata are not forwarded to the metadata extraction client.
|
||||||
|
|
||||||
|
### MCP tools
|
||||||
|
|
||||||
|
**Stage a file and get a URI** (`upload_file`) — preferred for large or binary files:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "diagram.png",
|
||||||
|
"content_path": "/absolute/path/to/diagram.png"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Or with base64 for small files (≤10 MB):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "diagram.png",
|
||||||
|
"content_base64": "<base64-payload>"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Returns `{"file": {...}, "uri": "amcs://files/<id>"}`. Pass `thought_id`/`project` to link immediately, or omit them and use the URI in a later `save_file` call.
|
||||||
|
|
||||||
|
**Link a staged file to a thought** (`save_file` with `content_uri`):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "meeting-notes.pdf",
|
||||||
|
"thought_id": "optional-thought-uuid",
|
||||||
|
"content_uri": "amcs://files/<id-from-upload_file>"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Save small files inline** (`save_file` with `content_base64`, ≤10 MB):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"name": "meeting-notes.pdf",
|
||||||
|
"media_type": "application/pdf",
|
||||||
|
"kind": "document",
|
||||||
|
"thought_id": "optional-thought-uuid",
|
||||||
|
"content_base64": "<base64-payload>"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`content_base64` and `content_uri` are mutually exclusive in both tools.
|
||||||
|
|
||||||
|
**Load a file** — returns metadata, base64 content, and an embedded MCP binary resource (`amcs://files/{id}`). The `id` field accepts either the bare stored file UUID or the full `amcs://files/{id}` URI:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "id": "stored-file-uuid" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**List files** for a thought or project:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"thought_id": "optional-thought-uuid",
|
||||||
|
"project": "optional-project-name",
|
||||||
|
"kind": "optional-image-document-audio-file",
|
||||||
|
"limit": 20
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### MCP resources
|
||||||
|
|
||||||
|
Stored files are also exposed as MCP resources at `amcs://files/{id}`. MCP clients can read raw binary content directly via `resources/read` without going through `load_file`.
|
||||||
|
|
||||||
|
### HTTP upload and download
|
||||||
|
|
||||||
|
Direct HTTP access avoids base64 encoding entirely. The Go server caps `/files` uploads at 100 MB per request. Large uploads are also subject to available memory, Postgres limits, and any reverse proxy or load balancer in front of AMCS.
|
||||||
|
|
||||||
|
Multipart upload:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/files \
|
||||||
|
-H "x-brain-key: <key>" \
|
||||||
|
-F "file=@./diagram.png" \
|
||||||
|
-F "project=amcs" \
|
||||||
|
-F "kind=image"
|
||||||
|
```
|
||||||
|
|
||||||
|
Raw body upload:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST "http://localhost:8080/files?project=amcs&name=meeting-notes.pdf" \
|
||||||
|
-H "x-brain-key: <key>" \
|
||||||
|
-H "Content-Type: application/pdf" \
|
||||||
|
--data-binary @./meeting-notes.pdf
|
||||||
|
```
|
||||||
|
|
||||||
|
Binary download:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/files/<id> \
|
||||||
|
-H "x-brain-key: <key>" \
|
||||||
|
-o meeting-notes.pdf
|
||||||
|
```
|
||||||
|
|
||||||
|
**Automatic backfill** (optional, config-gated):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
backfill:
|
||||||
|
enabled: true
|
||||||
|
run_on_startup: true # run once on server start
|
||||||
|
interval: "15m" # repeat every 15 minutes
|
||||||
|
batch_size: 20
|
||||||
|
max_per_run: 100
|
||||||
|
include_archived: false
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
metadata_retry:
|
||||||
|
enabled: true
|
||||||
|
run_on_startup: true # retry failed metadata once on server start
|
||||||
|
interval: "24h" # retry pending/failed metadata daily
|
||||||
|
max_per_run: 100
|
||||||
|
include_archived: false
|
||||||
|
```
|
||||||
|
|
||||||
|
**Search fallback**: when no embeddings exist for the active model in scope, `search_thoughts`, `recall_context`, `get_project_context`, `summarize_thoughts`, and `related_thoughts` automatically fall back to Postgres full-text search so results are never silently empty.
|
||||||
|
|
||||||
|
## Client Setup
|
||||||
|
|
||||||
|
### Claude Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# API key auth
|
||||||
|
claude mcp add --transport http amcs http://localhost:8080/mcp --header "x-brain-key: <key>"
|
||||||
|
|
||||||
|
# Bearer token auth
|
||||||
|
claude mcp add --transport http amcs http://localhost:8080/mcp --header "Authorization: Bearer <token>"
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI Codex
|
||||||
|
|
||||||
|
Add to `~/.codex/config.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[mcp_servers]]
|
||||||
|
name = "amcs"
|
||||||
|
url = "http://localhost:8080/mcp"
|
||||||
|
|
||||||
|
[mcp_servers.headers]
|
||||||
|
x-brain-key = "<key>"
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenCode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# API key auth
|
||||||
|
opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "x-brain-key=<key>"
|
||||||
|
|
||||||
|
# Bearer token auth
|
||||||
|
opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "Authorization=Bearer <token>"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or add directly to `opencode.json` / `~/.config/opencode/config.json`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcp": {
|
||||||
|
"amcs": {
|
||||||
|
"type": "remote",
|
||||||
|
"url": "http://localhost:8080/mcp",
|
||||||
|
"headers": {
|
||||||
|
"x-brain-key": "<key>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Apache Proxy
|
||||||
|
|
||||||
|
If AMCS is deployed behind Apache HTTP Server, configure the proxy explicitly for larger uploads and longer-running requests.
|
||||||
|
|
||||||
|
Example virtual host settings for the current AMCS defaults:
|
||||||
|
|
||||||
|
```apache
|
||||||
|
<VirtualHost *:443>
|
||||||
|
ServerName amcs.example.com
|
||||||
|
|
||||||
|
ProxyPreserveHost On
|
||||||
|
LimitRequestBody 104857600
|
||||||
|
RequestReadTimeout handshake=0 header=20-40,MinRate=500 body=600,MinRate=500
|
||||||
|
Timeout 600
|
||||||
|
ProxyTimeout 600
|
||||||
|
|
||||||
|
ProxyPass /mcp http://127.0.0.1:8080/mcp connectiontimeout=30 timeout=600
|
||||||
|
ProxyPassReverse /mcp http://127.0.0.1:8080/mcp
|
||||||
|
|
||||||
|
ProxyPass /files http://127.0.0.1:8080/files connectiontimeout=30 timeout=600
|
||||||
|
ProxyPassReverse /files http://127.0.0.1:8080/files
|
||||||
|
</VirtualHost>
|
||||||
|
```
|
||||||
|
|
||||||
|
Recommended Apache settings:
|
||||||
|
|
||||||
|
- `LimitRequestBody 104857600` matches AMCS's 100 MB `/files` upload cap.
|
||||||
|
- `RequestReadTimeout ... body=600` gives clients up to 10 minutes to send larger request bodies.
|
||||||
|
- `ProxyTimeout 600` and `ProxyPass ... timeout=600` give Apache enough time to wait for the Go backend.
|
||||||
|
- If another proxy or load balancer sits in front of Apache, align its size and timeout settings too.
|
||||||
|
|
||||||
|
## CLI
|
||||||
|
|
||||||
|
`amcs-cli` is a pre-built CLI client for the AMCS MCP server. Download it from https://git.warky.dev/wdevs/amcs/releases
|
||||||
|
|
||||||
|
The primary purpose is to give agents and MCP clients a ready-made bridge to the AMCS server so they do not need to implement their own HTTP MCP client. Configure it once and any stdio-based MCP client can use AMCS immediately.
|
||||||
|
|
||||||
|
### Commands
|
||||||
|
|
||||||
|
| Command | Purpose |
|
||||||
|
|---|---|
|
||||||
|
| `amcs-cli tools` | List all tools available on the remote server |
|
||||||
|
| `amcs-cli call <tool>` | Call a tool by name with `--arg key=value` flags |
|
||||||
|
| `amcs-cli stdio` | Start a stdio MCP bridge backed by the remote server |
|
||||||
|
|
||||||
|
`stdio` is the main integration point. It connects to the remote HTTP MCP server, discovers all its tools, and re-exposes them over stdio. Register it as a stdio MCP server in your agent config and it proxies every tool call through to AMCS.
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
Config file: `~/.config/amcs/config.yaml`
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server: https://your-amcs-server
|
||||||
|
token: your-bearer-token
|
||||||
|
```
|
||||||
|
|
||||||
|
Env vars override the config file: `AMCS_SERVER` (preferred), `AMCS_URL` (legacy alias), and `AMCS_TOKEN`. Flags `--server` and `--token` override env vars.
|
||||||
|
|
||||||
|
### stdio MCP client setup
|
||||||
|
|
||||||
|
#### Claude Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
claude mcp add --transport stdio amcs amcs-cli stdio
|
||||||
|
```
|
||||||
|
|
||||||
|
With inline credentials (no config file):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
claude mcp add --transport stdio amcs amcs-cli stdio \
|
||||||
|
--env AMCS_SERVER=https://your-amcs-server \
|
||||||
|
--env AMCS_TOKEN=your-bearer-token
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Output format
|
||||||
|
|
||||||
|
`call` outputs JSON by default. Pass `--output yaml` for YAML.
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
Run the SQL migrations against a local database with:
|
||||||
|
|
||||||
|
`DATABASE_URL=postgres://... make migrate`
|
||||||
|
|
||||||
|
### Backend + embedded UI build
|
||||||
|
|
||||||
|
The web UI now lives in the top-level `ui/` module and is embedded into the Go binary at build time with `go:embed`.
|
||||||
|
|
||||||
|
**Use `pnpm` for all UI work in this repo.**
|
||||||
|
|
||||||
|
- `make build` — runs the real UI build first, then compiles the Go server
|
||||||
|
- `make test` — runs `svelte-check` for the frontend and `go test ./...` for the backend
|
||||||
|
- `make ui-install` — installs frontend dependencies with `pnpm install --frozen-lockfile`
|
||||||
|
- `make ui-build` — builds only the frontend bundle
|
||||||
|
- `make ui-dev` — starts the Vite dev server with hot reload on `http://localhost:5173`
|
||||||
|
- `make ui-check` — runs the frontend type and Svelte checks
|
||||||
|
|
||||||
|
### Local UI workflow
|
||||||
|
|
||||||
|
For the normal production-style local flow:
|
||||||
|
|
||||||
|
1. Start the backend: `./scripts/run-local.sh configs/dev.yaml`
|
||||||
|
2. Open `http://localhost:8080`
|
||||||
|
|
||||||
|
For frontend iteration with hot reload and no Go rebuilds:
|
||||||
|
|
||||||
|
1. Start the backend once: `go run ./cmd/amcs-server --config configs/dev.yaml`
|
||||||
|
2. In another shell start the UI dev server: `make ui-dev`
|
||||||
|
3. Open `http://localhost:5173`
|
||||||
|
|
||||||
|
The Vite dev server proxies backend routes such as `/api/status`, `/llm`, `/healthz`, `/readyz`, `/files`, `/mcp`, and the OAuth endpoints back to the Go server on `http://127.0.0.1:8080` by default. Override that target with `AMCS_UI_BACKEND` if needed.
|
||||||
|
|
||||||
|
The root page (`/`) is now the Svelte frontend. It preserves the existing landing-page content and status information by fetching data from `GET /api/status`.
|
||||||
|
|
||||||
|
LLM integration instructions are still served at `/llm`.
|
||||||
|
|
||||||
|
## Containers
|
||||||
|
|
||||||
|
The repo now includes a `Dockerfile` and Compose files for running the app with Postgres + pgvector.
|
||||||
|
|
||||||
|
1. Set a real LiteLLM key in your shell:
|
||||||
|
`export AMCS_LITELLM_API_KEY=your-key`
|
||||||
|
2. Start the stack with your runtime:
|
||||||
|
`docker compose -f docker-compose.yml -f docker-compose.docker.yml up --build`
|
||||||
|
`podman compose -f docker-compose.yml up --build`
|
||||||
|
3. Call the service on `http://localhost:8080`
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- The app uses `configs/docker.yaml` inside the container.
|
||||||
|
- The local `./configs` directory is mounted into `/app/configs`, so config edits apply without rebuilding the image.
|
||||||
|
- `AMCS_LITELLM_BASE_URL` overrides the LiteLLM endpoint, so you can retarget it without editing YAML.
|
||||||
|
- `AMCS_OLLAMA_BASE_URL` overrides the Ollama endpoint for local or remote servers.
|
||||||
|
- The Compose stack uses a default bridge network named `amcs`.
|
||||||
|
- The base Compose file uses `host.containers.internal`, which is Podman-friendly.
|
||||||
|
- The Docker override file adds `host-gateway` aliases so Docker can resolve the same host endpoint.
|
||||||
|
- Database migrations `001` through `005` run automatically when the Postgres volume is created for the first time.
|
||||||
|
- `migrations/006_rls_and_grants.sql` is intentionally skipped during container bootstrap because it contains deployment-specific grants for a role named `amcs_user`.
|
||||||
|
|
||||||
|
### Run config migration with Compose
|
||||||
|
|
||||||
|
The container image now includes `/app/amcs-migrate-config`.
|
||||||
|
|
||||||
|
Dry-run (prints migrated YAML, does not write files):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose --profile tools run --rm migrate-config --config /app/configs/dev.yaml --dry-run
|
||||||
|
```
|
||||||
|
|
||||||
|
Apply migration in-place (writes file + creates backup):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose --profile tools run --rm migrate-config --config /app/configs/dev.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Ollama
|
||||||
|
|
||||||
|
Set your role targets to an Ollama provider to use a local or self-hosted Ollama server through its OpenAI-compatible API.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
ai:
|
||||||
|
providers:
|
||||||
|
local:
|
||||||
|
type: "ollama"
|
||||||
|
base_url: "http://localhost:11434/v1"
|
||||||
|
api_key: "ollama"
|
||||||
|
request_headers: {}
|
||||||
|
embeddings:
|
||||||
|
dimensions: 768
|
||||||
|
primary:
|
||||||
|
provider: "local"
|
||||||
|
model: "nomic-embed-text"
|
||||||
|
metadata:
|
||||||
|
temperature: 0.1
|
||||||
|
primary:
|
||||||
|
provider: "local"
|
||||||
|
model: "llama3.2"
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- For remote Ollama servers, point `ai.providers.<name>.base_url` at the remote `/v1` endpoint.
|
||||||
|
- The client always sends Bearer auth; Ollama ignores it locally, so `api_key: "ollama"` is a safe default.
|
||||||
|
- `ai.embeddings.dimensions` must match the embedding model you actually use, or startup will fail the database vector-dimension check.
|
||||||
|
|||||||
90
changelog.md
Normal file
90
changelog.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
## 2026-04-21
|
||||||
|
|
||||||
|
### 2026-04-21 21h - Config Schema v2 Introduced
|
||||||
|
|
||||||
|
- Refactored configuration to schema version `2` with named AI providers and role-based model chains.
|
||||||
|
- Added support for per-role primary and fallback targets for embeddings and metadata.
|
||||||
|
- Added optional background role overrides for backfill and metadata retry workers.
|
||||||
|
|
||||||
|
### 2026-04-21 21h - Automatic v1 -> v2 Migration
|
||||||
|
|
||||||
|
- Added config migration framework with explicit schema versioning.
|
||||||
|
- Implemented `v1 -> v2` migration to transform legacy provider blocks into named providers + role chains.
|
||||||
|
- Loader now auto-migrates older config files, rewrites migrated YAML, and creates timestamped backups.
|
||||||
|
|
||||||
|
### 2026-04-21 21h - AI Registry and Role Runners
|
||||||
|
|
||||||
|
- Added `ai.Registry` to build provider clients from named provider config entries.
|
||||||
|
- Added `EmbeddingRunner` and `MetadataRunner` with sequential fallback execution.
|
||||||
|
- Added target health tracking with cooldowns for transient/permanent/empty-response failures.
|
||||||
|
|
||||||
|
### 2026-04-21 21h - App and Tool Wiring Updates
|
||||||
|
|
||||||
|
- Rewired app startup to use provider registry + role runners for foreground and background flows.
|
||||||
|
- Updated capture, search, summarize, context, recall, backfill, metadata retry, and reparse paths to use new runners.
|
||||||
|
- Preserved environment override behavior for provider credentials/endpoints across matching provider types.
|
||||||
|
|
||||||
|
### 2026-04-21 21h - Migrate Config CLI Added
|
||||||
|
|
||||||
|
- Added `cmd/amcs-migrate-config` CLI to migrate config files to the current schema version.
|
||||||
|
- Supports dry-run output and in-place write mode with automatic backup file creation.
|
||||||
|
|
||||||
|
### 2026-04-21 21h - Tests and Documentation Updated
|
||||||
|
|
||||||
|
- Added focused tests for config migration, AI registry behavior, and runner fallback behavior.
|
||||||
|
- Updated `configs/config.example.yaml` to the new v2 schema.
|
||||||
|
- Updated README configuration sections and migration guidance to reflect v2 and `amcs-migrate-config` usage.
|
||||||
|
|
||||||
|
### 2026-04-21 21h - Uncommitted File Change List
|
||||||
|
|
||||||
|
- Modified: `.gitignore`
|
||||||
|
- Modified: `README.md`
|
||||||
|
- Modified: `configs/config.example.yaml`
|
||||||
|
- Modified: `internal/ai/compat/client.go`
|
||||||
|
- Modified: `internal/ai/compat/client_test.go`
|
||||||
|
- Modified: `internal/app/app.go`
|
||||||
|
- Modified: `internal/config/config.go`
|
||||||
|
- Modified: `internal/config/loader.go`
|
||||||
|
- Modified: `internal/config/loader_test.go`
|
||||||
|
- Modified: `internal/config/validate.go`
|
||||||
|
- Modified: `internal/config/validate_test.go`
|
||||||
|
- Modified: `internal/mcpserver/server.go`
|
||||||
|
- Modified: `internal/mcpserver/streamable_integration_test.go`
|
||||||
|
- Modified: `internal/tools/backfill.go`
|
||||||
|
- Modified: `internal/tools/capture.go`
|
||||||
|
- Modified: `internal/tools/context.go`
|
||||||
|
- Modified: `internal/tools/enrichment_retry.go`
|
||||||
|
- Modified: `internal/tools/links.go`
|
||||||
|
- Modified: `internal/tools/metadata_retry.go`
|
||||||
|
- Modified: `internal/tools/recall.go`
|
||||||
|
- Modified: `internal/tools/reparse_metadata.go`
|
||||||
|
- Modified: `internal/tools/retrieval.go`
|
||||||
|
- Modified: `internal/tools/search.go`
|
||||||
|
- Modified: `internal/tools/summarize.go`
|
||||||
|
- Modified: `internal/tools/update.go`
|
||||||
|
- Deleted: `internal/ai/factory.go`
|
||||||
|
- Deleted: `internal/ai/factory_test.go`
|
||||||
|
- Deleted: `internal/ai/litellm/client.go`
|
||||||
|
- Deleted: `internal/ai/ollama/client.go`
|
||||||
|
- Deleted: `internal/ai/openrouter/client.go`
|
||||||
|
- Deleted: `internal/ai/provider.go`
|
||||||
|
- New: `changelog.md`
|
||||||
|
- New: `cmd/amcs-migrate-config/main.go`
|
||||||
|
- New: `internal/ai/registry.go`
|
||||||
|
- New: `internal/ai/registry_test.go`
|
||||||
|
- New: `internal/ai/runner.go`
|
||||||
|
- New: `internal/ai/runner_test.go`
|
||||||
|
- New: `internal/config/migrate.go`
|
||||||
|
- New: `internal/config/migrate_test.go`
|
||||||
|
|
||||||
|
### 2026-04-21 21h - Docker Support for Config Migration CLI
|
||||||
|
- Added `amcs-migrate-config` binary to the Docker image build output.
|
||||||
|
- Added `migrate-config` service in `docker-compose.yml` under the `tools` profile.
|
||||||
|
- Documented compose-based migration commands (dry-run and in-place apply) in the README.
|
||||||
|
|
||||||
|
### 2026-04-21 21h - Startup Migration Write Disabled
|
||||||
|
- Changed config loading to migrate legacy schemas in memory only during startup.
|
||||||
|
- Removed automatic file rewrite and backup creation from the startup config loader.
|
||||||
|
- Added loader log hint to use `amcs-migrate-config` when persistent conversion is needed.
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -17,9 +16,12 @@ var (
|
|||||||
serverFlag string
|
serverFlag string
|
||||||
tokenFlag string
|
tokenFlag string
|
||||||
outputFlag string
|
outputFlag string
|
||||||
|
verbose bool
|
||||||
cfg Config
|
cfg Config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const cliUserAgent = "amcs-cli/0.0.1"
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
Use: "amcs-cli",
|
Use: "amcs-cli",
|
||||||
Short: "CLI for connecting to a remote AMCS MCP server",
|
Short: "CLI for connecting to a remote AMCS MCP server",
|
||||||
@@ -42,6 +44,7 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVar(&serverFlag, "server", "", "AMCS server URL")
|
rootCmd.PersistentFlags().StringVar(&serverFlag, "server", "", "AMCS server URL")
|
||||||
rootCmd.PersistentFlags().StringVar(&tokenFlag, "token", "", "AMCS bearer token")
|
rootCmd.PersistentFlags().StringVar(&tokenFlag, "token", "", "AMCS bearer token")
|
||||||
rootCmd.PersistentFlags().StringVar(&outputFlag, "output", "json", "Output format: json or yaml")
|
rootCmd.PersistentFlags().StringVar(&outputFlag, "output", "json", "Output format: json or yaml")
|
||||||
|
rootCmd.PersistentFlags().BoolVar(&verbose, "verbose", false, "Enable verbose logging to stderr")
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadConfig() error {
|
func loadConfig() error {
|
||||||
@@ -54,6 +57,9 @@ func loadConfig() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cfg = loaded
|
cfg = loaded
|
||||||
|
if v := strings.TrimSpace(os.Getenv("AMCS_SERVER")); v != "" {
|
||||||
|
cfg.Server = v
|
||||||
|
}
|
||||||
if v := strings.TrimSpace(os.Getenv("AMCS_URL")); v != "" {
|
if v := strings.TrimSpace(os.Getenv("AMCS_URL")); v != "" {
|
||||||
cfg.Server = v
|
cfg.Server = v
|
||||||
}
|
}
|
||||||
@@ -75,7 +81,7 @@ func loadConfig() error {
|
|||||||
|
|
||||||
func requireServer() error {
|
func requireServer() error {
|
||||||
if strings.TrimSpace(cfg.Server) == "" {
|
if strings.TrimSpace(cfg.Server) == "" {
|
||||||
return fmt.Errorf("server URL is required; set --server, AMCS_URL, or config server")
|
return fmt.Errorf("server URL is required; set --server, AMCS_SERVER, AMCS_URL, or config server")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -109,6 +115,9 @@ func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
base = http.DefaultTransport
|
base = http.DefaultTransport
|
||||||
}
|
}
|
||||||
clone := req.Clone(req.Context())
|
clone := req.Clone(req.Context())
|
||||||
|
if strings.TrimSpace(clone.Header.Get("User-Agent")) == "" {
|
||||||
|
clone.Header.Set("User-Agent", cliUserAgent)
|
||||||
|
}
|
||||||
if strings.TrimSpace(t.token) != "" {
|
if strings.TrimSpace(t.token) != "" {
|
||||||
clone.Header.Set("Authorization", "Bearer "+t.token)
|
clone.Header.Set("Authorization", "Bearer "+t.token)
|
||||||
}
|
}
|
||||||
@@ -119,16 +128,24 @@ func connectRemote(ctx context.Context) (*mcp.ClientSession, error) {
|
|||||||
if err := requireServer(); err != nil {
|
if err := requireServer(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
verboseLogf("connecting to %s", endpointURL())
|
||||||
client := mcp.NewClient(&mcp.Implementation{Name: "amcs-cli", Version: "0.0.1"}, nil)
|
client := mcp.NewClient(&mcp.Implementation{Name: "amcs-cli", Version: "0.0.1"}, nil)
|
||||||
transport := &mcp.StreamableClientTransport{
|
transport := &mcp.StreamableClientTransport{
|
||||||
Endpoint: endpointURL(),
|
Endpoint: endpointURL(),
|
||||||
HTTPClient: newHTTPClient(),
|
HTTPClient: newHTTPClient(),
|
||||||
|
DisableStandaloneSSE: true,
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
session, err := client.Connect(ctx, transport, nil)
|
session, err := client.Connect(ctx, transport, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("connect to AMCS server: %w", err)
|
return nil, fmt.Errorf("connect to AMCS server: %w", err)
|
||||||
}
|
}
|
||||||
|
verboseLogf("connected to %s", endpointURL())
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func verboseLogf(format string, args ...any) {
|
||||||
|
if !verbose {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "[amcs-cli] "+format+"\n", args...)
|
||||||
|
}
|
||||||
|
|||||||
35
cmd/amcs-cli/cmd/root_test.go
Normal file
35
cmd/amcs-cli/cmd/root_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBearerTransportFormatsBearerToken(t *testing.T) {
|
||||||
|
const want = "Bearer X"
|
||||||
|
const wantUA = "amcs-cli/0.0.1"
|
||||||
|
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("Authorization"); got != want {
|
||||||
|
t.Fatalf("Authorization header = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("User-Agent"); got != wantUA {
|
||||||
|
t.Fatalf("User-Agent header = %q, want %q", got, wantUA)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := &http.Client{Transport: &bearerTransport{token: "X"}}
|
||||||
|
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client.Do() error = %v", err)
|
||||||
|
}
|
||||||
|
_ = res.Body.Close()
|
||||||
|
}
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -26,14 +25,13 @@ var sseCmd = &cobra.Command{
|
|||||||
HTTPClient: newHTTPClient(),
|
HTTPClient: newHTTPClient(),
|
||||||
}
|
}
|
||||||
|
|
||||||
connectCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
verboseLogf("connecting to SSE endpoint %s", sseEndpointURL())
|
||||||
defer cancel()
|
remote, err := client.Connect(ctx, transport, nil)
|
||||||
|
|
||||||
remote, err := client.Connect(connectCtx, transport, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("connect to AMCS SSE endpoint: %w", err)
|
return fmt.Errorf("connect to AMCS SSE endpoint: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = remote.Close() }()
|
defer func() { _ = remote.Close() }()
|
||||||
|
verboseLogf("connected to SSE endpoint %s", sseEndpointURL())
|
||||||
|
|
||||||
tools, err := remote.ListTools(ctx, nil)
|
tools, err := remote.ListTools(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -67,6 +65,8 @@ var sseCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("start stdio bridge: %w", err)
|
return fmt.Errorf("start stdio bridge: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = session.Close() }()
|
defer func() { _ = session.Close() }()
|
||||||
|
verboseLogf("sse stdio bridge ready")
|
||||||
|
verboseLogf("waiting for MCP commands on stdin")
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
return nil
|
return nil
|
||||||
@@ -75,6 +75,9 @@ var sseCmd = &cobra.Command{
|
|||||||
|
|
||||||
func sseEndpointURL() string {
|
func sseEndpointURL() string {
|
||||||
base := strings.TrimRight(strings.TrimSpace(cfg.Server), "/")
|
base := strings.TrimRight(strings.TrimSpace(cfg.Server), "/")
|
||||||
|
if strings.HasSuffix(base, "/mcp") {
|
||||||
|
base = strings.TrimSuffix(base, "/mcp")
|
||||||
|
}
|
||||||
if strings.HasSuffix(base, "/sse") {
|
if strings.HasSuffix(base, "/sse") {
|
||||||
return base
|
return base
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ var stdioCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("start stdio bridge: %w", err)
|
return fmt.Errorf("start stdio bridge: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = session.Close() }()
|
defer func() { _ = session.Close() }()
|
||||||
|
verboseLogf("stdio bridge connected to remote AMCS and ready")
|
||||||
|
verboseLogf("waiting for MCP commands on stdin")
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
105
cmd/amcs-migrate-config/main.go
Normal file
105
cmd/amcs-migrate-config/main.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var (
|
||||||
|
configPath string
|
||||||
|
dryRun bool
|
||||||
|
toVersion int
|
||||||
|
)
|
||||||
|
flag.StringVar(&configPath, "config", "", "Path to the YAML config file (default: $AMCS_CONFIG or ./configs/dev.yaml)")
|
||||||
|
flag.BoolVar(&dryRun, "dry-run", false, "Print the migrated config to stdout instead of writing it back")
|
||||||
|
flag.IntVar(&toVersion, "to-version", config.CurrentConfigVersion, "Stop migrating after reaching this version")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if toVersion <= 0 || toVersion > config.CurrentConfigVersion {
|
||||||
|
log.Fatalf("invalid -to-version %d (must be between 1 and %d)", toVersion, config.CurrentConfigVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := config.ResolvePath(configPath)
|
||||||
|
original, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("read config %q: %v", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := map[string]any{}
|
||||||
|
if err := yaml.Unmarshal(original, &raw); err != nil {
|
||||||
|
log.Fatalf("decode config %q: %v", path, err)
|
||||||
|
}
|
||||||
|
if raw == nil {
|
||||||
|
raw = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
applied, err := migrateUpTo(raw, toVersion)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(applied) == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s already at version %d; nothing to do\n", path, currentVersion(raw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := yaml.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("marshal migrated config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, step := range applied {
|
||||||
|
fmt.Fprintf(os.Stderr, "applied migration v%d -> v%d: %s\n", step.From, step.To, step.Describe)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dryRun {
|
||||||
|
_, _ = os.Stdout.Write(out)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backup := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix())
|
||||||
|
if err := os.WriteFile(backup, original, 0o600); err != nil {
|
||||||
|
log.Fatalf("write backup %q: %v", backup, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, out, 0o600); err != nil {
|
||||||
|
log.Fatalf("write migrated config %q: %v", path, err)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "wrote migrated config to %s (backup: %s)\n", path, backup)
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateUpTo runs the migration ladder but stops at the requested version.
|
||||||
|
func migrateUpTo(raw map[string]any, target int) ([]config.ConfigMigration, error) {
|
||||||
|
if currentVersion(raw) >= target {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if target == config.CurrentConfigVersion {
|
||||||
|
return config.Migrate(raw)
|
||||||
|
}
|
||||||
|
// Partial migrations are rare; for now reject anything other than the
|
||||||
|
// current version target since the migration ladder is short.
|
||||||
|
return nil, fmt.Errorf("partial migration to v%d is not supported (use -to-version=%d)", target, config.CurrentConfigVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func currentVersion(raw map[string]any) int {
|
||||||
|
v, ok := raw["version"]
|
||||||
|
if !ok {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
switch n := v.(type) {
|
||||||
|
case int:
|
||||||
|
return n
|
||||||
|
case int64:
|
||||||
|
return int(n)
|
||||||
|
case float64:
|
||||||
|
return int(n)
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
version: 2
|
||||||
|
|
||||||
server:
|
server:
|
||||||
host: "0.0.0.0"
|
host: "0.0.0.0"
|
||||||
port: 8080
|
port: 8080
|
||||||
@@ -27,7 +29,7 @@ auth:
|
|||||||
- id: "oauth-client"
|
- id: "oauth-client"
|
||||||
client_id: ""
|
client_id: ""
|
||||||
client_secret: ""
|
client_secret: ""
|
||||||
description: "used when auth.mode=oauth_client_credentials"
|
description: "optional OAuth client credentials"
|
||||||
|
|
||||||
database:
|
database:
|
||||||
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
|
||||||
@@ -37,33 +39,58 @@ database:
|
|||||||
max_conn_idle_time: "10m"
|
max_conn_idle_time: "10m"
|
||||||
|
|
||||||
ai:
|
ai:
|
||||||
provider: "litellm"
|
providers:
|
||||||
embeddings:
|
default:
|
||||||
model: "openai/text-embedding-3-small"
|
type: "litellm"
|
||||||
dimensions: 1536
|
|
||||||
metadata:
|
|
||||||
model: "gpt-4o-mini"
|
|
||||||
fallback_models: []
|
|
||||||
temperature: 0.1
|
|
||||||
log_conversations: false
|
|
||||||
litellm:
|
|
||||||
base_url: "http://localhost:4000/v1"
|
base_url: "http://localhost:4000/v1"
|
||||||
api_key: "replace-me"
|
api_key: "replace-me"
|
||||||
use_responses_api: false
|
|
||||||
request_headers: {}
|
request_headers: {}
|
||||||
embedding_model: "openrouter/openai/text-embedding-3-small"
|
|
||||||
metadata_model: "gpt-4o-mini"
|
ollama_local:
|
||||||
fallback_metadata_models: []
|
type: "ollama"
|
||||||
ollama:
|
|
||||||
base_url: "http://localhost:11434/v1"
|
base_url: "http://localhost:11434/v1"
|
||||||
api_key: "ollama"
|
api_key: "ollama"
|
||||||
request_headers: {}
|
request_headers: {}
|
||||||
|
|
||||||
openrouter:
|
openrouter:
|
||||||
|
type: "openrouter"
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
base_url: "https://openrouter.ai/api/v1"
|
||||||
api_key: ""
|
api_key: "replace-me"
|
||||||
app_name: "amcs"
|
app_name: "amcs"
|
||||||
site_url: ""
|
site_url: ""
|
||||||
extra_headers: {}
|
request_headers: {}
|
||||||
|
|
||||||
|
embeddings:
|
||||||
|
dimensions: 1536
|
||||||
|
primary:
|
||||||
|
provider: "default"
|
||||||
|
model: "openai/text-embedding-3-small"
|
||||||
|
fallbacks:
|
||||||
|
- provider: "ollama_local"
|
||||||
|
model: "nomic-embed-text"
|
||||||
|
|
||||||
|
metadata:
|
||||||
|
temperature: 0.1
|
||||||
|
log_conversations: false
|
||||||
|
timeout: "10s"
|
||||||
|
primary:
|
||||||
|
provider: "default"
|
||||||
|
model: "gpt-4o-mini"
|
||||||
|
fallbacks:
|
||||||
|
- provider: "openrouter"
|
||||||
|
model: "openai/gpt-4.1-mini"
|
||||||
|
|
||||||
|
# Optional overrides for background jobs (backfill_embeddings,
|
||||||
|
# retry_failed_metadata, reparse_thought_metadata).
|
||||||
|
background:
|
||||||
|
embeddings:
|
||||||
|
primary:
|
||||||
|
provider: "default"
|
||||||
|
model: "openai/text-embedding-3-small"
|
||||||
|
metadata:
|
||||||
|
primary:
|
||||||
|
provider: "default"
|
||||||
|
model: "gpt-4o-mini"
|
||||||
|
|
||||||
capture:
|
capture:
|
||||||
source: "mcp"
|
source: "mcp"
|
||||||
|
|||||||
@@ -36,6 +36,18 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "8080:8080"
|
- "8080:8080"
|
||||||
|
|
||||||
|
migrate-config:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
profiles: ["tools"]
|
||||||
|
restart: "no"
|
||||||
|
volumes:
|
||||||
|
- ./configs:/app/configs
|
||||||
|
environment:
|
||||||
|
AMCS_CONFIG: /app/configs/docker.yaml
|
||||||
|
entrypoint: ["/app/amcs-migrate-config"]
|
||||||
|
command: ["--config", "/app/configs/docker.yaml", "--dry-run"]
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
@@ -36,38 +35,41 @@ Rules:
|
|||||||
- If unsure, prefer "observation".
|
- If unsure, prefer "observation".
|
||||||
- Do not include any text outside the JSON object.`
|
- Do not include any text outside the JSON object.`
|
||||||
|
|
||||||
|
// Client is a low-level OpenAI-compatible HTTP client. It knows nothing about
|
||||||
|
// role chains, fallbacks, or health — those concerns belong to ai.Runner. Each
|
||||||
|
// method takes the model name per-call so a single Client instance can service
|
||||||
|
// many different models on the same base URL.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
name string
|
name string
|
||||||
baseURL string
|
baseURL string
|
||||||
apiKey string
|
apiKey string
|
||||||
embeddingModel string
|
|
||||||
metadataModel string
|
|
||||||
fallbackMetadataModels []string
|
|
||||||
temperature float64
|
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
dimensions int
|
|
||||||
logConversations bool
|
|
||||||
modelHealthMu sync.Mutex
|
|
||||||
modelHealth map[string]modelHealthState
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Name string
|
Name string
|
||||||
BaseURL string
|
BaseURL string
|
||||||
APIKey string
|
APIKey string
|
||||||
EmbeddingModel string
|
|
||||||
MetadataModel string
|
|
||||||
FallbackMetadataModels []string
|
|
||||||
Temperature float64
|
|
||||||
Headers map[string]string
|
Headers map[string]string
|
||||||
HTTPClient *http.Client
|
HTTPClient *http.Client
|
||||||
Log *slog.Logger
|
Log *slog.Logger
|
||||||
Dimensions int
|
}
|
||||||
|
|
||||||
|
// MetadataOptions control a single ExtractMetadataWith call.
|
||||||
|
type MetadataOptions struct {
|
||||||
|
Model string
|
||||||
|
Temperature float64
|
||||||
LogConversations bool
|
LogConversations bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SummarizeOptions control a single SummarizeWith call.
|
||||||
|
type SummarizeOptions struct {
|
||||||
|
Model string
|
||||||
|
Temperature float64
|
||||||
|
}
|
||||||
|
|
||||||
type embeddingsRequest struct {
|
type embeddingsRequest struct {
|
||||||
Input string `json:"input"`
|
Input string `json:"input"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
@@ -127,65 +129,38 @@ type providerError struct {
|
|||||||
|
|
||||||
const maxMetadataAttempts = 3
|
const maxMetadataAttempts = 3
|
||||||
|
|
||||||
const (
|
// ErrEmptyResponse and ErrNoJSONObject are sentinel errors callers can inspect
|
||||||
emptyResponseCircuitThreshold = 3
|
// to classify metadata failures (e.g. bump empty-response health counters).
|
||||||
emptyResponseCircuitTTL = 5 * time.Minute
|
|
||||||
permanentModelFailureTTL = 24 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errMetadataEmptyResponse = errors.New("metadata empty response")
|
ErrEmptyResponse = errors.New("metadata empty response")
|
||||||
errMetadataNoJSONObject = errors.New("metadata response contains no JSON object")
|
ErrNoJSONObject = errors.New("metadata response contains no JSON object")
|
||||||
)
|
)
|
||||||
|
|
||||||
type modelHealthState struct {
|
|
||||||
consecutiveEmpty int
|
|
||||||
unhealthyUntil time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(cfg Config) *Client {
|
func New(cfg Config) *Client {
|
||||||
fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels))
|
|
||||||
seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels))
|
|
||||||
for _, model := range cfg.FallbackMetadataModels {
|
|
||||||
model = strings.TrimSpace(model)
|
|
||||||
if model == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok := seen[model]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[model] = struct{}{}
|
|
||||||
fallbacks = append(fallbacks, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
name: cfg.Name,
|
name: cfg.Name,
|
||||||
baseURL: cfg.BaseURL,
|
baseURL: cfg.BaseURL,
|
||||||
apiKey: cfg.APIKey,
|
apiKey: cfg.APIKey,
|
||||||
embeddingModel: cfg.EmbeddingModel,
|
|
||||||
metadataModel: cfg.MetadataModel,
|
|
||||||
fallbackMetadataModels: fallbacks,
|
|
||||||
temperature: cfg.Temperature,
|
|
||||||
headers: cfg.Headers,
|
headers: cfg.Headers,
|
||||||
httpClient: cfg.HTTPClient,
|
httpClient: cfg.HTTPClient,
|
||||||
log: cfg.Log,
|
log: cfg.Log,
|
||||||
dimensions: cfg.Dimensions,
|
|
||||||
logConversations: cfg.LogConversations,
|
|
||||||
modelHealth: make(map[string]modelHealthState),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
|
func (c *Client) Name() string { return c.name }
|
||||||
|
|
||||||
|
// EmbedWith generates an embedding for the given input using model.
|
||||||
|
func (c *Client) EmbedWith(ctx context.Context, model, input string) ([]float32, error) {
|
||||||
input = strings.TrimSpace(input)
|
input = strings.TrimSpace(input)
|
||||||
if input == "" {
|
if input == "" {
|
||||||
return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
|
return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
return nil, fmt.Errorf("%s embed: model is required", c.name)
|
||||||
|
}
|
||||||
|
|
||||||
var resp embeddingsResponse
|
var resp embeddingsResponse
|
||||||
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{
|
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{Input: input, Model: model}, &resp)
|
||||||
Input: input,
|
|
||||||
Model: c.embeddingModel,
|
|
||||||
}, &resp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -195,133 +170,26 @@ func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
|
|||||||
if len(resp.Data) == 0 {
|
if len(resp.Data) == 0 {
|
||||||
return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
|
return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
|
||||||
}
|
}
|
||||||
if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions {
|
|
||||||
return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding))
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp.Data[0].Embedding, nil
|
return resp.Data[0].Embedding, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
|
// ExtractMetadataWith extracts structured metadata for input using opts.Model.
|
||||||
|
// Returns compat.ErrEmptyResponse / ErrNoJSONObject wrapped when the model
|
||||||
|
// produces unusable output so callers can classify the failure.
|
||||||
|
func (c *Client) ExtractMetadataWith(ctx context.Context, opts MetadataOptions, input string) (thoughttypes.ThoughtMetadata, error) {
|
||||||
input = strings.TrimSpace(input)
|
input = strings.TrimSpace(input)
|
||||||
if input == "" {
|
if input == "" {
|
||||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(opts.Model) == "" {
|
||||||
start := time.Now()
|
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: model is required", c.name)
|
||||||
if c.log != nil {
|
|
||||||
c.log.Info("metadata client started",
|
|
||||||
slog.String("provider", c.name),
|
|
||||||
slog.String("model", c.metadataModel),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
logCompletion := func(model string, err error) {
|
|
||||||
if c.log == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
attrs := []any{
|
|
||||||
slog.String("provider", c.name),
|
|
||||||
slog.String("model", model),
|
|
||||||
slog.String("duration", formatLogDuration(time.Since(start))),
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
attrs = append(attrs, slog.String("error", err.Error()))
|
|
||||||
c.log.Error("metadata client completed", attrs...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.log.Info("metadata client completed", attrs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
|
|
||||||
if errors.Is(err, errMetadataEmptyResponse) {
|
|
||||||
c.noteEmptyResponse(c.metadataModel)
|
|
||||||
}
|
|
||||||
if isPermanentModelError(err) {
|
|
||||||
c.notePermanentModelFailure(c.metadataModel, err)
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
c.noteModelSuccess(c.metadataModel)
|
|
||||||
logCompletion(c.metadataModel, nil)
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, fallbackModel := range c.fallbackMetadataModels {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if fallbackModel == "" || fallbackModel == c.metadataModel {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c.shouldBypassModel(fallbackModel) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c.log != nil {
|
|
||||||
c.log.Warn("metadata extraction failed, trying fallback model",
|
|
||||||
slog.String("provider", c.name),
|
|
||||||
slog.String("primary_model", c.metadataModel),
|
|
||||||
slog.String("fallback_model", fallbackModel),
|
|
||||||
slog.String("error", err.Error()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel)
|
|
||||||
if errors.Is(fallbackErr, errMetadataEmptyResponse) {
|
|
||||||
c.noteEmptyResponse(fallbackModel)
|
|
||||||
}
|
|
||||||
if isPermanentModelError(fallbackErr) {
|
|
||||||
c.notePermanentModelFailure(fallbackModel, fallbackErr)
|
|
||||||
}
|
|
||||||
if fallbackErr == nil {
|
|
||||||
c.noteModelSuccess(fallbackModel)
|
|
||||||
logCompletion(fallbackModel, nil)
|
|
||||||
return fallbackResult, nil
|
|
||||||
}
|
|
||||||
err = fallbackErr
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
err = fmt.Errorf("%s metadata: %w", c.name, ctx.Err())
|
|
||||||
logCompletion(c.metadataModel, err)
|
|
||||||
return thoughttypes.ThoughtMetadata{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
heuristic := heuristicMetadataFromInput(input)
|
|
||||||
if c.log != nil {
|
|
||||||
c.log.Warn("metadata extraction failed for all models, using heuristic fallback",
|
|
||||||
slog.String("provider", c.name),
|
|
||||||
slog.String("error", err.Error()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
logCompletion(c.metadataModel, nil)
|
|
||||||
return heuristic, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatLogDuration(d time.Duration) string {
|
|
||||||
if d < 0 {
|
|
||||||
d = -d
|
|
||||||
}
|
|
||||||
|
|
||||||
totalMilliseconds := d.Milliseconds()
|
|
||||||
minutes := totalMilliseconds / 60000
|
|
||||||
seconds := (totalMilliseconds / 1000) % 60
|
|
||||||
milliseconds := totalMilliseconds % 1000
|
|
||||||
return fmt.Sprintf("%02d:%02d:%03d", minutes, seconds, milliseconds)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
|
||||||
if c.shouldBypassModel(model) {
|
|
||||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
stream := true
|
stream := true
|
||||||
req := chatCompletionsRequest{
|
req := chatCompletionsRequest{
|
||||||
Model: model,
|
Model: opts.Model,
|
||||||
Temperature: c.temperature,
|
Temperature: opts.Temperature,
|
||||||
ResponseFormat: &responseType{
|
ResponseFormat: &responseType{Type: "json_object"},
|
||||||
Type: "json_object",
|
|
||||||
},
|
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Messages: []chatMessage{
|
Messages: []chatMessage{
|
||||||
{Role: "system", Content: metadataSystemPrompt},
|
{Role: "system", Content: metadataSystemPrompt},
|
||||||
@@ -329,7 +197,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, err := c.extractMetadataWithRequest(ctx, req, input, model)
|
metadata, err := c.extractMetadataWithRequest(ctx, req, input, opts)
|
||||||
if err == nil || !shouldRetryWithoutJSONMode(err) {
|
if err == nil || !shouldRetryWithoutJSONMode(err) {
|
||||||
return metadata, err
|
return metadata, err
|
||||||
}
|
}
|
||||||
@@ -337,23 +205,22 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
|
|||||||
if c.log != nil {
|
if c.log != nil {
|
||||||
c.log.Warn("metadata json mode failed, retrying without response_format",
|
c.log.Warn("metadata json mode failed, retrying without response_format",
|
||||||
slog.String("provider", c.name),
|
slog.String("provider", c.name),
|
||||||
slog.String("model", model),
|
slog.String("model", opts.Model),
|
||||||
slog.String("error", err.Error()),
|
slog.String("error", err.Error()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.ResponseFormat = nil
|
req.ResponseFormat = nil
|
||||||
return c.extractMetadataWithRequest(ctx, req, input, model)
|
return c.extractMetadataWithRequest(ctx, req, input, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input string, opts MetadataOptions) (thoughttypes.ThoughtMetadata, error) {
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 1; attempt <= maxMetadataAttempts; attempt++ {
|
for attempt := 1; attempt <= maxMetadataAttempts; attempt++ {
|
||||||
if c.logConversations && c.log != nil {
|
if opts.LogConversations && c.log != nil {
|
||||||
c.log.Info("metadata conversation request",
|
c.log.Info("metadata conversation request",
|
||||||
slog.String("provider", c.name),
|
slog.String("provider", c.name),
|
||||||
slog.String("model", model),
|
slog.String("model", opts.Model),
|
||||||
slog.Int("attempt", attempt),
|
slog.Int("attempt", attempt),
|
||||||
slog.String("system", metadataSystemPrompt),
|
slog.String("system", metadataSystemPrompt),
|
||||||
slog.String("input", input),
|
slog.String("input", input),
|
||||||
@@ -373,10 +240,10 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
|||||||
|
|
||||||
rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text)
|
rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text)
|
||||||
|
|
||||||
if c.logConversations && c.log != nil {
|
if opts.LogConversations && c.log != nil {
|
||||||
c.log.Info("metadata conversation response",
|
c.log.Info("metadata conversation response",
|
||||||
slog.String("provider", c.name),
|
slog.String("provider", c.name),
|
||||||
slog.String("model", model),
|
slog.String("model", opts.Model),
|
||||||
slog.Int("attempt", attempt),
|
slog.Int("attempt", attempt),
|
||||||
slog.String("response", rawResponse),
|
slog.String("response", rawResponse),
|
||||||
)
|
)
|
||||||
@@ -387,13 +254,13 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
|||||||
metadataText = stripCodeFence(metadataText)
|
metadataText = stripCodeFence(metadataText)
|
||||||
metadataText = extractJSONObject(metadataText)
|
metadataText = extractJSONObject(metadataText)
|
||||||
if metadataText == "" {
|
if metadataText == "" {
|
||||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
|
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
|
||||||
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
|
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
|
||||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
|
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
|
||||||
if c.log != nil {
|
if c.log != nil {
|
||||||
c.log.Warn("metadata response empty, waiting and retrying",
|
c.log.Warn("metadata response empty, waiting and retrying",
|
||||||
slog.String("provider", c.name),
|
slog.String("provider", c.name),
|
||||||
slog.String("model", model),
|
slog.String("model", opts.Model),
|
||||||
slog.Int("attempt", attempt+1),
|
slog.Int("attempt", attempt+1),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -403,7 +270,7 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(rawResponse) == "" {
|
if strings.TrimSpace(rawResponse) == "" {
|
||||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
|
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
|
||||||
}
|
}
|
||||||
return thoughttypes.ThoughtMetadata{}, lastErr
|
return thoughttypes.ThoughtMetadata{}, lastErr
|
||||||
}
|
}
|
||||||
@@ -420,13 +287,17 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
|||||||
if lastErr != nil {
|
if lastErr != nil {
|
||||||
return thoughttypes.ThoughtMetadata{}, lastErr
|
return thoughttypes.ThoughtMetadata{}, lastErr
|
||||||
}
|
}
|
||||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
|
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
// SummarizeWith runs a chat-completion summarisation using opts.Model.
|
||||||
|
func (c *Client) SummarizeWith(ctx context.Context, opts SummarizeOptions, systemPrompt, userPrompt string) (string, error) {
|
||||||
|
if strings.TrimSpace(opts.Model) == "" {
|
||||||
|
return "", fmt.Errorf("%s summarize: model is required", c.name)
|
||||||
|
}
|
||||||
req := chatCompletionsRequest{
|
req := chatCompletionsRequest{
|
||||||
Model: c.metadataModel,
|
Model: opts.Model,
|
||||||
Temperature: 0.2,
|
Temperature: opts.Temperature,
|
||||||
Messages: []chatMessage{
|
Messages: []chatMessage{
|
||||||
{Role: "system", Content: systemPrompt},
|
{Role: "system", Content: systemPrompt},
|
||||||
{Role: "user", Content: userPrompt},
|
{Role: "user", Content: userPrompt},
|
||||||
@@ -447,12 +318,49 @@ func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string)
|
|||||||
return extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text), nil
|
return extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Name() string {
|
// IsPermanentModelError reports whether err indicates the model itself is
|
||||||
return c.name
|
// invalid or missing (vs. a transient outage). Runners use this to mark a
|
||||||
|
// target unhealthy for longer.
|
||||||
|
func IsPermanentModelError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(err.Error())
|
||||||
|
for _, marker := range []string{
|
||||||
|
"invalid model name",
|
||||||
|
"model_not_found",
|
||||||
|
"model not found",
|
||||||
|
"unknown model",
|
||||||
|
"no such model",
|
||||||
|
"does not exist",
|
||||||
|
} {
|
||||||
|
if strings.Contains(lower, marker) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) EmbeddingModel() string {
|
// HeuristicMetadataFromInput produces best-effort metadata from the note text
|
||||||
return c.embeddingModel
|
// when every model in the chain has failed. Exported so ai.Runner can use it.
|
||||||
|
func HeuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
|
||||||
|
text := strings.TrimSpace(input)
|
||||||
|
lower := strings.ToLower(text)
|
||||||
|
|
||||||
|
metadata := thoughttypes.ThoughtMetadata{
|
||||||
|
People: heuristicPeople(text),
|
||||||
|
ActionItems: heuristicActionItems(text),
|
||||||
|
DatesMentioned: heuristicDates(text),
|
||||||
|
Topics: heuristicTopics(lower),
|
||||||
|
Type: heuristicType(lower),
|
||||||
|
}
|
||||||
|
if len(metadata.Topics) == 0 {
|
||||||
|
metadata.Topics = []string{"uncategorized"}
|
||||||
|
}
|
||||||
|
if metadata.Type == "" {
|
||||||
|
metadata.Type = "observation"
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
||||||
@@ -724,8 +632,6 @@ func isRetryableChatResponseError(err error) bool {
|
|||||||
return strings.Contains(lower, "read response") || strings.Contains(lower, "read stream response")
|
return strings.Contains(lower, "read response") || strings.Contains(lower, "read stream response")
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractJSONObject finds the first complete {...} block in s.
|
|
||||||
// It handles models that prepend prose to a JSON response despite json_object mode.
|
|
||||||
func extractJSONObject(s string) string {
|
func extractJSONObject(s string) string {
|
||||||
for start := 0; start < len(s); start++ {
|
for start := 0; start < len(s); start++ {
|
||||||
if s[start] != '{' {
|
if s[start] != '{' {
|
||||||
@@ -768,10 +674,6 @@ func extractJSONObject(s string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// stripThinkingBlocks removes <think>...</think> and <thinking>...</thinking>
|
|
||||||
// blocks produced by reasoning models (DeepSeek R1, QwQ, etc.) so that the
|
|
||||||
// remaining text can be parsed as JSON without interference from thinking content
|
|
||||||
// that may itself contain braces.
|
|
||||||
func stripThinkingBlocks(s string) string {
|
func stripThinkingBlocks(s string) string {
|
||||||
for _, tag := range []string{"think", "thinking"} {
|
for _, tag := range []string{"think", "thinking"} {
|
||||||
open := "<" + tag + ">"
|
open := "<" + tag + ">"
|
||||||
@@ -857,7 +759,6 @@ func extractTextFromAny(value any) string {
|
|||||||
}
|
}
|
||||||
return strings.Join(parts, "\n")
|
return strings.Join(parts, "\n")
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
// Common provider shapes for chat content parts.
|
|
||||||
for _, key := range []string{"text", "output_text", "content", "value"} {
|
for _, key := range []string{"text", "output_text", "content", "value"} {
|
||||||
if nested, ok := typed[key]; ok {
|
if nested, ok := typed[key]; ok {
|
||||||
if text := strings.TrimSpace(extractTextFromAny(nested)); text != "" {
|
if text := strings.TrimSpace(extractTextFromAny(nested)); text != "" {
|
||||||
@@ -875,28 +776,6 @@ var (
|
|||||||
wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`)
|
wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`)
|
||||||
)
|
)
|
||||||
|
|
||||||
func heuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
|
|
||||||
text := strings.TrimSpace(input)
|
|
||||||
lower := strings.ToLower(text)
|
|
||||||
|
|
||||||
metadata := thoughttypes.ThoughtMetadata{
|
|
||||||
People: heuristicPeople(text),
|
|
||||||
ActionItems: heuristicActionItems(text),
|
|
||||||
DatesMentioned: heuristicDates(text),
|
|
||||||
Topics: heuristicTopics(lower),
|
|
||||||
Type: heuristicType(lower),
|
|
||||||
Source: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(metadata.Topics) == 0 {
|
|
||||||
metadata.Topics = []string{"uncategorized"}
|
|
||||||
}
|
|
||||||
if metadata.Type == "" {
|
|
||||||
metadata.Type = "observation"
|
|
||||||
}
|
|
||||||
return metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
func heuristicType(lower string) string {
|
func heuristicType(lower string) string {
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(lower, "preferred name"), strings.Contains(lower, "personal profile"), strings.Contains(lower, "wife:"), strings.Contains(lower, "daughter:"), strings.Contains(lower, "born:"):
|
case strings.Contains(lower, "preferred name"), strings.Contains(lower, "personal profile"), strings.Contains(lower, "wife:"), strings.Contains(lower, "daughter:"), strings.Contains(lower, "born:"):
|
||||||
@@ -1055,7 +934,7 @@ func shouldRetryWithoutJSONMode(err error) bool {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if errors.Is(err, errMetadataEmptyResponse) || errors.Is(err, errMetadataNoJSONObject) {
|
if errors.Is(err, ErrEmptyResponse) || errors.Is(err, ErrNoJSONObject) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1063,27 +942,6 @@ func shouldRetryWithoutJSONMode(err error) bool {
|
|||||||
return strings.Contains(lower, "parse json")
|
return strings.Contains(lower, "parse json")
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPermanentModelError(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
lower := strings.ToLower(err.Error())
|
|
||||||
for _, marker := range []string{
|
|
||||||
"invalid model name",
|
|
||||||
"model_not_found",
|
|
||||||
"model not found",
|
|
||||||
"unknown model",
|
|
||||||
"no such model",
|
|
||||||
"does not exist",
|
|
||||||
} {
|
|
||||||
if strings.Contains(lower, marker) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
|
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
|
||||||
delay := time.Duration(attempt*attempt) * 200 * time.Millisecond
|
delay := time.Duration(attempt*attempt) * 200 * time.Millisecond
|
||||||
if log != nil {
|
if log != nil {
|
||||||
@@ -1110,59 +968,3 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) shouldBypassModel(model string) bool {
|
|
||||||
c.modelHealthMu.Lock()
|
|
||||||
defer c.modelHealthMu.Unlock()
|
|
||||||
|
|
||||||
state, ok := c.modelHealth[model]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) noteEmptyResponse(model string) {
|
|
||||||
c.modelHealthMu.Lock()
|
|
||||||
defer c.modelHealthMu.Unlock()
|
|
||||||
|
|
||||||
state := c.modelHealth[model]
|
|
||||||
state.consecutiveEmpty++
|
|
||||||
if state.consecutiveEmpty >= emptyResponseCircuitThreshold {
|
|
||||||
state.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL)
|
|
||||||
if c.log != nil {
|
|
||||||
c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses",
|
|
||||||
slog.String("provider", c.name),
|
|
||||||
slog.String("model", model),
|
|
||||||
slog.Time("until", state.unhealthyUntil),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.modelHealth[model] = state
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) noteModelSuccess(model string) {
|
|
||||||
c.modelHealthMu.Lock()
|
|
||||||
defer c.modelHealthMu.Unlock()
|
|
||||||
|
|
||||||
delete(c.modelHealth, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) notePermanentModelFailure(model string, err error) {
|
|
||||||
c.modelHealthMu.Lock()
|
|
||||||
defer c.modelHealthMu.Unlock()
|
|
||||||
|
|
||||||
state := c.modelHealth[model]
|
|
||||||
state.consecutiveEmpty = emptyResponseCircuitThreshold
|
|
||||||
state.unhealthyUntil = time.Now().Add(permanentModelFailureTTL)
|
|
||||||
c.modelHealth[model] = state
|
|
||||||
|
|
||||||
if c.log != nil {
|
|
||||||
c.log.Warn("metadata model marked unhealthy after permanent failure",
|
|
||||||
slog.String("provider", c.name),
|
|
||||||
slog.String("model", model),
|
|
||||||
slog.String("error", err.Error()),
|
|
||||||
slog.Time("until", state.unhealthyUntil),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,6 +11,17 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func newTestClient(t *testing.T, url string) *Client {
|
||||||
|
t.Helper()
|
||||||
|
return New(Config{
|
||||||
|
Name: "litellm",
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: "test-key",
|
||||||
|
HTTPClient: http.DefaultClient,
|
||||||
|
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestExtractMetadataFromStreamingResponse(t *testing.T) {
|
func TestExtractMetadataFromStreamingResponse(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -26,6 +37,9 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) {
|
|||||||
if req.Stream == nil || !*req.Stream {
|
if req.Stream == nil || !*req.Stream {
|
||||||
t.Fatalf("stream flag = %v, want true", req.Stream)
|
t.Fatalf("stream flag = %v, want true", req.Stream)
|
||||||
}
|
}
|
||||||
|
if req.Model != "qwen3.5:latest" {
|
||||||
|
t.Fatalf("model = %q, want qwen3.5:latest", req.Model)
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"{\\\"people\\\":[],\"}}]}\n\n")
|
_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"{\\\"people\\\":[],\"}}]}\n\n")
|
||||||
@@ -35,20 +49,13 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := New(Config{
|
client := newTestClient(t, server.URL)
|
||||||
Name: "litellm",
|
metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{
|
||||||
BaseURL: server.URL,
|
Model: "qwen3.5:latest",
|
||||||
APIKey: "test-key",
|
|
||||||
MetadataModel: "qwen3.5:latest",
|
|
||||||
Temperature: 0.1,
|
Temperature: 0.1,
|
||||||
HTTPClient: server.Client(),
|
}, "Project idea: Build an Android companion app.")
|
||||||
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
|
||||||
EmbeddingModel: "unused",
|
|
||||||
})
|
|
||||||
|
|
||||||
metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
t.Fatalf("ExtractMetadataWith() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if metadata.Type != "idea" {
|
if metadata.Type != "idea" {
|
||||||
@@ -94,20 +101,13 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := New(Config{
|
client := newTestClient(t, server.URL)
|
||||||
Name: "litellm",
|
metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{
|
||||||
BaseURL: server.URL,
|
Model: "qwen3.5:latest",
|
||||||
APIKey: "test-key",
|
|
||||||
MetadataModel: "qwen3.5:latest",
|
|
||||||
Temperature: 0.1,
|
Temperature: 0.1,
|
||||||
HTTPClient: server.Client(),
|
}, "Project idea: Build an Android companion app.")
|
||||||
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
|
||||||
EmbeddingModel: "unused",
|
|
||||||
})
|
|
||||||
|
|
||||||
metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
t.Fatalf("ExtractMetadataWith() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if metadata.Type != "idea" {
|
if metadata.Type != "idea" {
|
||||||
@@ -127,71 +127,33 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractMetadataBypassesInvalidFallbackModelAfterFirstFailure(t *testing.T) {
|
func TestIsPermanentModelError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var mu sync.Mutex
|
cases := []struct {
|
||||||
primaryCalls := 0
|
name string
|
||||||
invalidFallbackCalls := 0
|
err error
|
||||||
|
want bool
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
}{
|
||||||
defer func() {
|
{"nil", nil, false},
|
||||||
_ = r.Body.Close()
|
{"invalid model", errMsg("Invalid model name passed in model=qwen3"), true},
|
||||||
}()
|
{"model not found", errMsg("model_not_found"), true},
|
||||||
|
{"no such model", errMsg("no such model"), true},
|
||||||
var req chatCompletionsRequest
|
{"transient", errMsg("connection refused"), false},
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
t.Fatalf("decode request: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch req.Model {
|
for _, tc := range cases {
|
||||||
case "empty-primary":
|
tc := tc
|
||||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":""}}]}`)
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
case "qwen3.5:latest":
|
if got := IsPermanentModelError(tc.err); got != tc.want {
|
||||||
mu.Lock()
|
t.Fatalf("IsPermanentModelError(%v) = %v, want %v", tc.err, got, tc.want)
|
||||||
primaryCalls++
|
|
||||||
mu.Unlock()
|
|
||||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"metadata\"],\"type\":\"observation\",\"source\":\"primary\"}"}}]}`)
|
|
||||||
case "qwen3":
|
|
||||||
mu.Lock()
|
|
||||||
invalidFallbackCalls++
|
|
||||||
mu.Unlock()
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, _ = io.WriteString(w, "{\"error\":{\"message\":\"{'error': '/chat/completions: Invalid model name passed in model=qwen3. Call `/v1/models` to view available models for your key.'}\"}}")
|
|
||||||
default:
|
|
||||||
t.Fatalf("unexpected model %q", req.Model)
|
|
||||||
}
|
}
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := New(Config{
|
|
||||||
Name: "litellm",
|
|
||||||
BaseURL: server.URL,
|
|
||||||
APIKey: "test-key",
|
|
||||||
MetadataModel: "empty-primary",
|
|
||||||
FallbackMetadataModels: []string{"qwen3", "qwen3.5:latest"},
|
|
||||||
Temperature: 0.1,
|
|
||||||
HTTPClient: server.Client(),
|
|
||||||
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
|
||||||
EmbeddingModel: "unused",
|
|
||||||
})
|
})
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
metadata, err := client.ExtractMetadata(context.Background(), "A short note about metadata.")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
|
||||||
}
|
|
||||||
if metadata.Source != "primary" {
|
|
||||||
t.Fatalf("metadata source = %q, want primary", metadata.Source)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mu.Lock()
|
type stringError string
|
||||||
defer mu.Unlock()
|
|
||||||
if invalidFallbackCalls != 1 {
|
func (s stringError) Error() string { return string(s) }
|
||||||
t.Fatalf("invalid fallback calls = %d, want 1", invalidFallbackCalls)
|
|
||||||
}
|
func errMsg(s string) error { return stringError(s) }
|
||||||
if primaryCalls != 2 {
|
|
||||||
t.Fatalf("valid fallback calls = %d, want 2", primaryCalls)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package ai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai/litellm"
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai/ollama"
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai/openrouter"
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewProvider(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (Provider, error) {
|
|
||||||
switch cfg.Provider {
|
|
||||||
case "litellm":
|
|
||||||
return litellm.New(cfg, httpClient, log)
|
|
||||||
case "ollama":
|
|
||||||
return ollama.New(cfg, httpClient, log)
|
|
||||||
case "openrouter":
|
|
||||||
return openrouter.New(cfg, httpClient, log)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported ai.provider: %s", cfg.Provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
package ai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewProviderSupportsOllama(t *testing.T) {
|
|
||||||
provider, err := NewProvider(config.AIConfig{
|
|
||||||
Provider: "ollama",
|
|
||||||
Embeddings: config.AIEmbeddingConfig{
|
|
||||||
Model: "nomic-embed-text",
|
|
||||||
Dimensions: 768,
|
|
||||||
},
|
|
||||||
Metadata: config.AIMetadataConfig{
|
|
||||||
Model: "llama3.2",
|
|
||||||
},
|
|
||||||
Ollama: config.OllamaConfig{
|
|
||||||
BaseURL: "http://localhost:11434/v1",
|
|
||||||
APIKey: "ollama",
|
|
||||||
},
|
|
||||||
}, &http.Client{}, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewProvider() error = %v", err)
|
|
||||||
}
|
|
||||||
if provider.Name() != "ollama" {
|
|
||||||
t.Fatalf("provider name = %q, want ollama", provider.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
package litellm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
|
||||||
fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels()
|
|
||||||
if len(fallbacks) == 0 {
|
|
||||||
fallbacks = cfg.Metadata.EffectiveFallbackModels()
|
|
||||||
}
|
|
||||||
return compat.New(compat.Config{
|
|
||||||
Name: "litellm",
|
|
||||||
BaseURL: cfg.LiteLLM.BaseURL,
|
|
||||||
APIKey: cfg.LiteLLM.APIKey,
|
|
||||||
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
|
|
||||||
MetadataModel: cfg.LiteLLM.MetadataModel,
|
|
||||||
FallbackMetadataModels: fallbacks,
|
|
||||||
Temperature: cfg.Metadata.Temperature,
|
|
||||||
Headers: cfg.LiteLLM.RequestHeaders,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
Log: log,
|
|
||||||
Dimensions: cfg.Embeddings.Dimensions,
|
|
||||||
LogConversations: cfg.Metadata.LogConversations,
|
|
||||||
}), nil
|
|
||||||
}
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
package ollama
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
|
||||||
return compat.New(compat.Config{
|
|
||||||
Name: "ollama",
|
|
||||||
BaseURL: cfg.Ollama.BaseURL,
|
|
||||||
APIKey: cfg.Ollama.APIKey,
|
|
||||||
EmbeddingModel: cfg.Embeddings.Model,
|
|
||||||
MetadataModel: cfg.Metadata.Model,
|
|
||||||
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
|
|
||||||
Temperature: cfg.Metadata.Temperature,
|
|
||||||
Headers: cfg.Ollama.RequestHeaders,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
Log: log,
|
|
||||||
Dimensions: cfg.Embeddings.Dimensions,
|
|
||||||
LogConversations: cfg.Metadata.LogConversations,
|
|
||||||
}), nil
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package openrouter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
|
||||||
headers := make(map[string]string, len(cfg.OpenRouter.ExtraHeaders)+2)
|
|
||||||
for key, value := range cfg.OpenRouter.ExtraHeaders {
|
|
||||||
headers[key] = value
|
|
||||||
}
|
|
||||||
if cfg.OpenRouter.SiteURL != "" {
|
|
||||||
headers["HTTP-Referer"] = cfg.OpenRouter.SiteURL
|
|
||||||
}
|
|
||||||
if cfg.OpenRouter.AppName != "" {
|
|
||||||
headers["X-Title"] = cfg.OpenRouter.AppName
|
|
||||||
}
|
|
||||||
|
|
||||||
return compat.New(compat.Config{
|
|
||||||
Name: "openrouter",
|
|
||||||
BaseURL: cfg.OpenRouter.BaseURL,
|
|
||||||
APIKey: cfg.OpenRouter.APIKey,
|
|
||||||
EmbeddingModel: cfg.Embeddings.Model,
|
|
||||||
MetadataModel: cfg.Metadata.Model,
|
|
||||||
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
|
|
||||||
Temperature: cfg.Metadata.Temperature,
|
|
||||||
Headers: headers,
|
|
||||||
HTTPClient: httpClient,
|
|
||||||
Log: log,
|
|
||||||
Dimensions: cfg.Embeddings.Dimensions,
|
|
||||||
LogConversations: cfg.Metadata.LogConversations,
|
|
||||||
}), nil
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
package ai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Provider interface {
|
|
||||||
Embed(ctx context.Context, input string) ([]float32, error)
|
|
||||||
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
|
|
||||||
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
|
|
||||||
Name() string
|
|
||||||
EmbeddingModel() string
|
|
||||||
}
|
|
||||||
96
internal/ai/registry.go
Normal file
96
internal/ai/registry.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Registry holds one compat.Client per named provider. Runners look up clients
// by provider name when walking a role chain.
type Registry struct {
	clients map[string]*compat.Client // keyed by configured provider name
}

// NewRegistry builds a Registry from the configured providers. Each provider
// type maps onto a compat.Client with type-specific header plumbing (e.g.
// openrouter's HTTP-Referer / X-Title).
//
// Returns an error when httpClient is nil, when no providers are configured,
// or when any provider has an unsupported type.
func NewRegistry(providers map[string]config.ProviderConfig, httpClient *http.Client, log *slog.Logger) (*Registry, error) {
	if httpClient == nil {
		return nil, fmt.Errorf("ai registry: http client is required")
	}
	if len(providers) == 0 {
		return nil, fmt.Errorf("ai registry: no providers configured")
	}

	clients := make(map[string]*compat.Client, len(providers))
	for name, p := range providers {
		// providerHeaders rejects unknown provider types, so a bad config
		// fails construction instead of failing on the first request.
		headers, err := providerHeaders(p)
		if err != nil {
			return nil, fmt.Errorf("ai registry: provider %q: %w", name, err)
		}
		clients[name] = compat.New(compat.Config{
			Name:       name,
			BaseURL:    p.BaseURL,
			APIKey:     p.APIKey,
			Headers:    headers,
			HTTPClient: httpClient,
			Log:        log,
		})
	}
	return &Registry{clients: clients}, nil
}
|
||||||
|
|
||||||
|
// Client returns the compat.Client registered under name.
|
||||||
|
func (r *Registry) Client(name string) (*compat.Client, error) {
|
||||||
|
c, ok := r.clients[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("ai registry: provider %q is not configured", name)
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Names returns the registered provider names.
|
||||||
|
func (r *Registry) Names() []string {
|
||||||
|
names := make([]string, 0, len(r.clients))
|
||||||
|
for name := range r.clients {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerHeaders(p config.ProviderConfig) (map[string]string, error) {
|
||||||
|
switch p.Type {
|
||||||
|
case "litellm", "ollama":
|
||||||
|
return cloneHeaders(p.RequestHeaders), nil
|
||||||
|
case "openrouter":
|
||||||
|
headers := cloneHeaders(p.RequestHeaders)
|
||||||
|
if headers == nil {
|
||||||
|
headers = map[string]string{}
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(p.SiteURL); s != "" {
|
||||||
|
headers["HTTP-Referer"] = s
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(p.AppName); s != "" {
|
||||||
|
headers["X-Title"] = s
|
||||||
|
}
|
||||||
|
return headers, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported provider type %q", p.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneHeaders returns an independent copy of in so callers may mutate the
// result without aliasing the caller's (config-owned) map. A nil or empty
// input yields nil.
func cloneHeaders(in map[string]string) map[string]string {
	if len(in) == 0 {
		return nil
	}
	cloned := make(map[string]string, len(in))
	for key, value := range in {
		cloned[key] = value
	}
	return cloned
}
|
||||||
80
internal/ai/registry_test.go
Normal file
80
internal/ai/registry_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewRegistryOpenRouterHeaders verifies that an openrouter provider sends
// its attribution headers (HTTP-Referer from SiteURL, X-Title from AppName)
// together with any custom RequestHeaders on a chat completion request.
func TestNewRegistryOpenRouterHeaders(t *testing.T) {
	// Captured from the fake server inside the handler below.
	var (
		gotReferer string
		gotTitle   string
		gotCustom  string
	)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotReferer = r.Header.Get("HTTP-Referer")
		gotTitle = r.Header.Get("X-Title")
		gotCustom = r.Header.Get("X-Custom")
		// Minimal OpenAI-compatible chat completion body so SummarizeWith succeeds.
		_ = json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{{"message": map[string]any{"role": "assistant", "content": "ok"}}},
		})
	}))
	defer srv.Close()

	providers := map[string]config.ProviderConfig{
		"router": {
			Type:    "openrouter",
			BaseURL: srv.URL,
			APIKey:  "secret",
			RequestHeaders: map[string]string{
				"X-Custom": "value",
			},
			AppName: "amcs",
			SiteURL: "https://example.com",
		},
	}

	reg, err := NewRegistry(providers, srv.Client(), nil)
	if err != nil {
		t.Fatalf("NewRegistry() error = %v", err)
	}

	client, err := reg.Client("router")
	if err != nil {
		t.Fatalf("Client(router) error = %v", err)
	}

	// One real round-trip so the handler records the headers.
	if _, err := client.SummarizeWith(context.Background(), compat.SummarizeOptions{Model: "gpt-4.1-mini"}, "system", "user"); err != nil {
		t.Fatalf("SummarizeWith() error = %v", err)
	}
	if gotReferer != "https://example.com" {
		t.Fatalf("HTTP-Referer = %q, want https://example.com", gotReferer)
	}
	if gotTitle != "amcs" {
		t.Fatalf("X-Title = %q, want amcs", gotTitle)
	}
	if gotCustom != "value" {
		t.Fatalf("X-Custom = %q, want value", gotCustom)
	}
}
|
||||||
|
|
||||||
|
func TestNewRegistryRejectsUnsupportedProviderType(t *testing.T) {
|
||||||
|
providers := map[string]config.ProviderConfig{
|
||||||
|
"bad": {
|
||||||
|
Type: "unknown",
|
||||||
|
BaseURL: "http://localhost:4000/v1",
|
||||||
|
APIKey: "secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewRegistry(providers, &http.Client{}, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("NewRegistry() error = nil, want unsupported provider type error")
|
||||||
|
}
|
||||||
|
}
|
||||||
367
internal/ai/runner.go
Normal file
367
internal/ai/runner.go
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Health TTLs per failure class. These are short enough that a healed target
// gets retried without manual intervention, but long enough to avoid hammering
// a broken provider every call.
const (
	// Generic failure (network, 5xx, wrong vector size).
	transientCooldown = 30 * time.Second
	// Failure the provider will not recover from soon (e.g. unknown model).
	permanentCooldown = 10 * time.Minute
	// Consecutive empty responses needed before a cooldown is applied.
	emptyResponseThreshold = 3
	emptyResponseCooldown  = 2 * time.Minute
	// Message prefix used for wrong-sized embedding vectors.
	dimensionMismatchWarning = "embedding dimension mismatch"
)

// EmbedResult carries the vector plus the (provider, model) that produced it —
// callers store the actual model so later searches against that row use the
// matching query embedding.
type EmbedResult struct {
	Vector   []float32
	Provider string
	Model    string
}

// EmbeddingRunner executes the embeddings role chain with sequential fallback.
type EmbeddingRunner struct {
	registry   *Registry           // provider-name -> compat client lookup
	chain      []config.RoleTarget // ordered (provider, model) fallback chain
	dimensions int                 // required vector length; mismatches are rejected
	health     *healthTracker      // per-target cooldown state
	log        *slog.Logger        // may be nil; failures are then unlogged
}

// MetadataRunner executes the metadata role chain with sequential fallback and
// a heuristic fallthrough when every target is unhealthy or fails.
type MetadataRunner struct {
	registry *Registry
	chain    []config.RoleTarget
	opts     metadataRunOpts
	health   *healthTracker
	log      *slog.Logger
}

// metadataRunOpts bundles the per-call options shared by metadata extraction
// and summarisation.
type metadataRunOpts struct {
	temperature      float64
	logConversations bool
}
|
||||||
|
|
||||||
|
// NewEmbeddingRunner builds a runner for the embeddings role. chain must be
|
||||||
|
// non-empty and every target must be registered.
|
||||||
|
func NewEmbeddingRunner(registry *Registry, chain []config.RoleTarget, dimensions int, log *slog.Logger) (*EmbeddingRunner, error) {
|
||||||
|
if registry == nil {
|
||||||
|
return nil, fmt.Errorf("embedding runner: registry is required")
|
||||||
|
}
|
||||||
|
if len(chain) == 0 {
|
||||||
|
return nil, fmt.Errorf("embedding runner: chain is empty")
|
||||||
|
}
|
||||||
|
if dimensions <= 0 {
|
||||||
|
return nil, fmt.Errorf("embedding runner: dimensions must be > 0")
|
||||||
|
}
|
||||||
|
for i, t := range chain {
|
||||||
|
if _, err := registry.Client(t.Provider); err != nil {
|
||||||
|
return nil, fmt.Errorf("embedding runner: chain[%d]: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &EmbeddingRunner{
|
||||||
|
registry: registry,
|
||||||
|
chain: chain,
|
||||||
|
dimensions: dimensions,
|
||||||
|
health: newHealthTracker(),
|
||||||
|
log: log,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMetadataRunner builds a runner for the metadata role.
|
||||||
|
func NewMetadataRunner(registry *Registry, chain []config.RoleTarget, temperature float64, logConversations bool, log *slog.Logger) (*MetadataRunner, error) {
|
||||||
|
if registry == nil {
|
||||||
|
return nil, fmt.Errorf("metadata runner: registry is required")
|
||||||
|
}
|
||||||
|
if len(chain) == 0 {
|
||||||
|
return nil, fmt.Errorf("metadata runner: chain is empty")
|
||||||
|
}
|
||||||
|
for i, t := range chain {
|
||||||
|
if _, err := registry.Client(t.Provider); err != nil {
|
||||||
|
return nil, fmt.Errorf("metadata runner: chain[%d]: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &MetadataRunner{
|
||||||
|
registry: registry,
|
||||||
|
chain: chain,
|
||||||
|
opts: metadataRunOpts{
|
||||||
|
temperature: temperature,
|
||||||
|
logConversations: logConversations,
|
||||||
|
},
|
||||||
|
health: newHealthTracker(),
|
||||||
|
log: log,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrimaryProvider returns the first provider in the chain.
func (r *EmbeddingRunner) PrimaryProvider() string { return r.chain[0].Provider }

// PrimaryModel returns the first model in the chain — the one used as the
// storage key for search matching.
func (r *EmbeddingRunner) PrimaryModel() string { return r.chain[0].Model }

// Dimensions returns the required vector dimension.
func (r *EmbeddingRunner) Dimensions() int { return r.dimensions }
|
||||||
|
|
||||||
|
// Embed walks the chain and returns the first successful embedding. The
// returned EmbedResult names the actual (provider, model) that produced the
// vector — callers use that when recording the row.
func (r *EmbeddingRunner) Embed(ctx context.Context, input string) (EmbedResult, error) {
	var errs []error
	for _, target := range r.chain {
		// Targets still inside a cooldown window are skipped, not retried.
		if r.health.skip(target) {
			continue
		}
		client, err := r.registry.Client(target.Provider)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		vec, err := client.EmbedWith(ctx, target.Model, input)
		if err != nil {
			// Caller cancellation aborts the whole walk — the failure is the
			// caller's, not the target's, so no health state is touched.
			if ctx.Err() != nil {
				return EmbedResult{}, ctx.Err()
			}
			r.classify(target, err)
			r.logFailure("embed", target, err)
			errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
			continue
		}
		// A wrong-sized vector is unusable by the store; treat it as a
		// transient target failure and fall through to the next target.
		if len(vec) != r.dimensions {
			dimErr := fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
			r.health.markTransient(target)
			r.logFailure("embed", target, dimErr)
			errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, dimErr))
			continue
		}
		r.health.markHealthy(target)
		return EmbedResult{Vector: vec, Provider: target.Provider, Model: target.Model}, nil
	}
	return EmbedResult{}, fmt.Errorf("all embedding targets failed: %w", errors.Join(errs...))
}
|
||||||
|
|
||||||
|
// EmbedPrimary embeds using only the primary target — used for search queries
|
||||||
|
// so the query vector matches rows stored under the primary model. Falls back
|
||||||
|
// to returning the error without walking the chain.
|
||||||
|
func (r *EmbeddingRunner) EmbedPrimary(ctx context.Context, input string) ([]float32, error) {
|
||||||
|
target := r.chain[0]
|
||||||
|
client, err := r.registry.Client(target.Provider)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
vec, err := client.EmbedWith(ctx, target.Model, input)
|
||||||
|
if err != nil {
|
||||||
|
r.classify(target, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(vec) != r.dimensions {
|
||||||
|
return nil, fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
|
||||||
|
}
|
||||||
|
r.health.markHealthy(target)
|
||||||
|
return vec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrimaryProvider / PrimaryModel for metadata mirror the embedding runner.
func (r *MetadataRunner) PrimaryProvider() string { return r.chain[0].Provider }

// PrimaryModel returns the first (preferred) model in the metadata chain.
func (r *MetadataRunner) PrimaryModel() string { return r.chain[0].Model }
|
||||||
|
|
||||||
|
// ExtractMetadata walks the chain sequentially. If every target fails or is
// unhealthy, it returns a heuristic metadata so capture never hard-fails.
func (r *MetadataRunner) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
	var errs []error
	for _, target := range r.chain {
		// Skip targets that are cooling down after recent failures.
		if r.health.skip(target) {
			continue
		}
		client, err := r.registry.Client(target.Provider)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		md, err := client.ExtractMetadataWith(ctx, compat.MetadataOptions{
			Model:            target.Model,
			Temperature:      r.opts.temperature,
			LogConversations: r.opts.logConversations,
		}, input)
		if err != nil {
			// Caller cancellation ends the walk without touching health state.
			if ctx.Err() != nil {
				return thoughttypes.ThoughtMetadata{}, ctx.Err()
			}
			r.classify(target, err)
			r.logFailure("metadata", target, err)
			errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
			continue
		}
		r.health.markHealthy(target)
		return md, nil
	}
	// Chain exhausted: degrade gracefully rather than failing the capture.
	if r.log != nil {
		r.log.Warn("metadata chain exhausted, using heuristic fallback",
			slog.Int("targets", len(r.chain)),
			slog.String("error", errors.Join(errs...).Error()),
		)
	}
	return compat.HeuristicMetadataFromInput(input), nil
}
|
||||||
|
|
||||||
|
// Summarize walks the chain; unlike metadata, there is no heuristic fallback —
// returns the joined error when everything fails.
func (r *MetadataRunner) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
	var errs []error
	for _, target := range r.chain {
		// Skip targets that are cooling down after recent failures.
		if r.health.skip(target) {
			continue
		}
		client, err := r.registry.Client(target.Provider)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		out, err := client.SummarizeWith(ctx, compat.SummarizeOptions{
			Model:       target.Model,
			Temperature: r.opts.temperature,
		}, systemPrompt, userPrompt)
		if err != nil {
			// Caller cancellation ends the walk without touching health state.
			if ctx.Err() != nil {
				return "", ctx.Err()
			}
			r.classify(target, err)
			r.logFailure("summarize", target, err)
			errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
			continue
		}
		r.health.markHealthy(target)
		return out, nil
	}
	return "", fmt.Errorf("all summarize targets failed: %w", errors.Join(errs...))
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRunner) classify(target config.RoleTarget, err error) {
|
||||||
|
switch {
|
||||||
|
case compat.IsPermanentModelError(err):
|
||||||
|
r.health.markPermanent(target)
|
||||||
|
default:
|
||||||
|
r.health.markTransient(target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MetadataRunner) classify(target config.RoleTarget, err error) {
|
||||||
|
switch {
|
||||||
|
case compat.IsPermanentModelError(err):
|
||||||
|
r.health.markPermanent(target)
|
||||||
|
case errors.Is(err, compat.ErrEmptyResponse):
|
||||||
|
r.health.markEmpty(target)
|
||||||
|
default:
|
||||||
|
r.health.markTransient(target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// logFailure emits a structured warning for a failed target. No-op when the
// runner was constructed without a logger.
func (r *EmbeddingRunner) logFailure(role string, target config.RoleTarget, err error) {
	if r.log == nil {
		return
	}
	r.log.Warn("ai target failed",
		slog.String("role", role),
		slog.String("provider", target.Provider),
		slog.String("model", target.Model),
		slog.String("error", err.Error()),
	)
}

// logFailure mirrors EmbeddingRunner.logFailure for the metadata runner.
func (r *MetadataRunner) logFailure(role string, target config.RoleTarget, err error) {
	if r.log == nil {
		return
	}
	r.log.Warn("ai target failed",
		slog.String("role", role),
		slog.String("provider", target.Provider),
		slog.String("model", target.Model),
		slog.String("error", err.Error()),
	)
}
|
||||||
|
|
||||||
|
// healthTracker records per-(provider, model) failure state. skip returns true
// when a target is still inside its cooldown window; the caller then tries the
// next target in the chain.
type healthTracker struct {
	mu     sync.Mutex // guards states
	states map[config.RoleTarget]*healthState
}

// healthState is the per-target record kept by healthTracker.
type healthState struct {
	unhealthyUntil time.Time // zero when the target is healthy
	emptyCount     int       // consecutive empty responses (see markEmpty)
}

// newHealthTracker returns an empty tracker; every target starts healthy.
func newHealthTracker() *healthTracker {
	return &healthTracker{states: map[config.RoleTarget]*healthState{}}
}

// skip reports whether target is still inside a cooldown window.
func (h *healthTracker) skip(target config.RoleTarget) bool {
	h.mu.Lock()
	defer h.mu.Unlock()
	s, ok := h.states[target]
	if !ok {
		// Never seen -> healthy.
		return false
	}
	return time.Now().Before(s.unhealthyUntil)
}

// markTransient applies the short cooldown (network blips, 5xx, bad vectors).
func (h *healthTracker) markTransient(target config.RoleTarget) {
	h.setCooldown(target, transientCooldown)
}

// markPermanent applies the long cooldown (errors unlikely to self-heal).
func (h *healthTracker) markPermanent(target config.RoleTarget) {
	h.setCooldown(target, permanentCooldown)
}

// markEmpty counts consecutive empty responses; only once the threshold is
// reached does the target enter a cooldown (a single empty reply is tolerated).
func (h *healthTracker) markEmpty(target config.RoleTarget) {
	h.mu.Lock()
	defer h.mu.Unlock()
	s := h.states[target]
	if s == nil {
		s = &healthState{}
		h.states[target] = s
	}
	s.emptyCount++
	if s.emptyCount >= emptyResponseThreshold {
		s.unhealthyUntil = time.Now().Add(emptyResponseCooldown)
		// Reset so the next cooldown again requires a full run of empties.
		s.emptyCount = 0
	}
}

// markHealthy clears any cooldown and the empty-response streak after a
// successful call.
func (h *healthTracker) markHealthy(target config.RoleTarget) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if s, ok := h.states[target]; ok {
		s.unhealthyUntil = time.Time{}
		s.emptyCount = 0
	}
}

// setCooldown puts target into a cooldown of duration d and resets its
// empty-response streak.
func (h *healthTracker) setCooldown(target config.RoleTarget, d time.Duration) {
	h.mu.Lock()
	defer h.mu.Unlock()
	s := h.states[target]
	if s == nil {
		s = &healthState{}
		h.states[target] = s
	}
	s.unhealthyUntil = time.Now().Add(d)
	s.emptyCount = 0
}
|
||||||
139
internal/ai/runner_test.go
Normal file
139
internal/ai/runner_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary drives the runner
// against a fake OpenAI-compatible server where the primary model always
// returns 502 and the fallback succeeds. It checks that (a) results come from
// the fallback and (b) the second Embed call does not re-contact the primary
// because its transient cooldown is still active.
func TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary(t *testing.T) {
	var (
		mu           sync.Mutex // guards primaryCalls (handler runs on server goroutines)
		primaryCalls int
	)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/embeddings" {
			http.NotFound(w, r)
			return
		}
		var req struct {
			Model string `json:"model"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		switch req.Model {
		case "embed-primary":
			mu.Lock()
			primaryCalls++
			mu.Unlock()
			http.Error(w, "upstream down", http.StatusBadGateway)
		case "embed-fallback":
			// 3-dimensional vector matching the runner's configured dimensions.
			_ = json.NewEncoder(w).Encode(map[string]any{
				"data": []map[string]any{{"embedding": []float32{0.1, 0.2, 0.3}}},
			})
		default:
			http.Error(w, "unknown model", http.StatusBadRequest)
		}
	}))
	defer srv.Close()

	reg, err := NewRegistry(map[string]config.ProviderConfig{
		"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
		"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
	}, srv.Client(), nil)
	if err != nil {
		t.Fatalf("NewRegistry() error = %v", err)
	}

	runner, err := NewEmbeddingRunner(reg, []config.RoleTarget{
		{Provider: "p1", Model: "embed-primary"},
		{Provider: "p2", Model: "embed-fallback"},
	}, 3, nil)
	if err != nil {
		t.Fatalf("NewEmbeddingRunner() error = %v", err)
	}

	res, err := runner.Embed(context.Background(), "hello")
	if err != nil {
		t.Fatalf("Embed() first call error = %v", err)
	}
	if res.Provider != "p2" || res.Model != "embed-fallback" {
		t.Fatalf("Embed() first call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
	}

	res, err = runner.Embed(context.Background(), "hello again")
	if err != nil {
		t.Fatalf("Embed() second call error = %v", err)
	}
	if res.Provider != "p2" || res.Model != "embed-fallback" {
		t.Fatalf("Embed() second call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
	}

	mu.Lock()
	calls := primaryCalls
	mu.Unlock()
	// Expecting 3 hits for a single Embed call implies compat.Client retries
	// a failed request 3x internally — NOTE(review): confirm against
	// compat.Client; the runner itself calls each target once per walk.
	if calls != 3 {
		t.Fatalf("primary calls = %d, want 3 (first request retries 3x; second call should skip unhealthy primary)", calls)
	}
}
|
||||||
|
|
||||||
|
// TestMetadataRunnerSummarizeFallsBack verifies that Summarize falls through
// to the second chain target when the primary model's completions endpoint
// fails, and returns the fallback's content verbatim.
func TestMetadataRunnerSummarizeFallsBack(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			http.NotFound(w, r)
			return
		}
		var req struct {
			Model string `json:"model"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		switch req.Model {
		case "sum-primary":
			http.Error(w, "provider error", http.StatusBadGateway)
		case "sum-fallback":
			_ = json.NewEncoder(w).Encode(map[string]any{
				"choices": []map[string]any{{
					"message": map[string]any{"role": "assistant", "content": "fallback summary"},
				}},
			})
		default:
			http.Error(w, "unknown model", http.StatusBadRequest)
		}
	}))
	defer srv.Close()

	reg, err := NewRegistry(map[string]config.ProviderConfig{
		"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
		"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
	}, srv.Client(), nil)
	if err != nil {
		t.Fatalf("NewRegistry() error = %v", err)
	}

	runner, err := NewMetadataRunner(reg, []config.RoleTarget{
		{Provider: "p1", Model: "sum-primary"},
		{Provider: "p2", Model: "sum-fallback"},
	}, 0.1, false, nil)
	if err != nil {
		t.Fatalf("NewMetadataRunner() error = %v", err)
	}

	summary, err := runner.Summarize(context.Background(), "system", "user")
	if err != nil {
		t.Fatalf("Summarize() error = %v", err)
	}
	if summary != "fallback summary" {
		t.Fatalf("summary = %q, want %q", summary, "fallback summary")
	}
}
|
||||||
@@ -34,7 +34,7 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
|
|
||||||
logger.Info("loaded configuration",
|
logger.Info("loaded configuration",
|
||||||
slog.String("path", loadedFrom),
|
slog.String("path", loadedFrom),
|
||||||
slog.String("provider", cfg.AI.Provider),
|
slog.Int("config_version", cfg.Version),
|
||||||
slog.String("version", info.Version),
|
slog.String("version", info.Version),
|
||||||
slog.String("tag_name", info.TagName),
|
slog.String("tag_name", info.TagName),
|
||||||
slog.String("build_date", info.BuildDate),
|
slog.String("build_date", info.BuildDate),
|
||||||
@@ -52,11 +52,37 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
provider, err := ai.NewProvider(cfg.AI, httpClient, logger)
|
registry, err := ai.NewRegistry(cfg.AI.Providers, httpClient, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
foregroundEmbeddings, err := ai.NewEmbeddingRunner(registry, cfg.AI.Embeddings.Chain(), cfg.AI.Embeddings.Dimensions, logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
foregroundMetadata, err := ai.NewMetadataRunner(registry, cfg.AI.Metadata.Chain(), cfg.AI.Metadata.Temperature, cfg.AI.Metadata.LogConversations, logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
backgroundEmbeddings := foregroundEmbeddings
|
||||||
|
backgroundMetadata := foregroundMetadata
|
||||||
|
if cfg.AI.Background != nil {
|
||||||
|
if cfg.AI.Background.Embeddings != nil {
|
||||||
|
backgroundEmbeddings, err = ai.NewEmbeddingRunner(registry, cfg.AI.Background.Embeddings.AsTargets(), cfg.AI.Embeddings.Dimensions, logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.AI.Background.Metadata != nil {
|
||||||
|
backgroundMetadata, err = ai.NewMetadataRunner(registry, cfg.AI.Background.Metadata.AsTargets(), cfg.AI.Metadata.Temperature, cfg.AI.Metadata.LogConversations, logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var keyring *auth.Keyring
|
var keyring *auth.Keyring
|
||||||
var oauthRegistry *auth.OAuthRegistry
|
var oauthRegistry *auth.OAuthRegistry
|
||||||
var tokenStore *auth.TokenStore
|
var tokenStore *auth.TokenStore
|
||||||
@@ -77,12 +103,13 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
dynClients := auth.NewDynamicClientStore()
|
dynClients := auth.NewDynamicClientStore()
|
||||||
activeProjects := session.NewActiveProjects()
|
activeProjects := session.NewActiveProjects()
|
||||||
|
|
||||||
logger.Info("database connection verified",
|
logger.Info("ai providers initialised",
|
||||||
slog.String("provider", provider.Name()),
|
slog.String("embedding_primary", foregroundEmbeddings.PrimaryProvider()+"/"+foregroundEmbeddings.PrimaryModel()),
|
||||||
|
slog.String("metadata_primary", foregroundMetadata.PrimaryProvider()+"/"+foregroundMetadata.PrimaryModel()),
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.Backfill.Enabled && cfg.Backfill.RunOnStartup {
|
if cfg.Backfill.Enabled && cfg.Backfill.RunOnStartup {
|
||||||
go runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
|
go runBackfillPass(ctx, db, backgroundEmbeddings, cfg.Backfill, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Backfill.Enabled && cfg.Backfill.Interval > 0 {
|
if cfg.Backfill.Enabled && cfg.Backfill.Interval > 0 {
|
||||||
@@ -94,14 +121,14 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
|
runBackfillPass(ctx, db, backgroundEmbeddings, cfg.Backfill, logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.RunOnStartup {
|
if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.RunOnStartup {
|
||||||
go runMetadataRetryPass(ctx, db, provider, cfg, activeProjects, logger)
|
go runMetadataRetryPass(ctx, db, backgroundMetadata, cfg, activeProjects, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.Interval > 0 {
|
if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.Interval > 0 {
|
||||||
@@ -113,13 +140,13 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
runMetadataRetryPass(ctx, db, provider, cfg, activeProjects, logger)
|
runMetadataRetryPass(ctx, db, backgroundMetadata, cfg, activeProjects, logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
handler, err := routes(logger, cfg, info, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects)
|
handler, err := routes(logger, cfg, info, db, foregroundEmbeddings, foregroundMetadata, backgroundEmbeddings, backgroundMetadata, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -156,33 +183,35 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) (http.Handler, error) {
|
func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, bgEmbeddings *ai.EmbeddingRunner, bgMetadata *ai.MetadataRunner, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) (http.Handler, error) {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
accessTracker := auth.NewAccessTracker()
|
accessTracker := auth.NewAccessTracker()
|
||||||
oauthEnabled := oauthRegistry != nil && tokenStore != nil
|
oauthEnabled := oauthRegistry != nil && tokenStore != nil
|
||||||
authMiddleware := auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, accessTracker, logger)
|
authMiddleware := auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, accessTracker, logger)
|
||||||
filesTool := tools.NewFilesTool(db, activeProjects)
|
filesTool := tools.NewFilesTool(db, activeProjects)
|
||||||
metadataRetryer := tools.NewMetadataRetryer(context.Background(), db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
|
enrichmentRetryer := tools.NewEnrichmentRetryer(context.Background(), db, bgMetadata, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
|
||||||
|
backfillTool := tools.NewBackfillTool(db, bgEmbeddings, activeProjects, logger)
|
||||||
|
|
||||||
toolSet := mcpserver.ToolSet{
|
toolSet := mcpserver.ToolSet{
|
||||||
Capture: tools.NewCaptureTool(db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, metadataRetryer, logger),
|
Capture: tools.NewCaptureTool(db, embeddings, cfg.Capture, activeProjects, enrichmentRetryer, backfillTool),
|
||||||
Search: tools.NewSearchTool(db, provider, cfg.Search, activeProjects),
|
Search: tools.NewSearchTool(db, embeddings, cfg.Search, activeProjects),
|
||||||
List: tools.NewListTool(db, cfg.Search, activeProjects),
|
List: tools.NewListTool(db, cfg.Search, activeProjects),
|
||||||
Stats: tools.NewStatsTool(db),
|
Stats: tools.NewStatsTool(db),
|
||||||
Get: tools.NewGetTool(db),
|
Get: tools.NewGetTool(db),
|
||||||
Update: tools.NewUpdateTool(db, provider, cfg.Capture, logger),
|
Update: tools.NewUpdateTool(db, embeddings, metadata, cfg.Capture, logger),
|
||||||
Delete: tools.NewDeleteTool(db),
|
Delete: tools.NewDeleteTool(db),
|
||||||
Archive: tools.NewArchiveTool(db),
|
Archive: tools.NewArchiveTool(db),
|
||||||
Projects: tools.NewProjectsTool(db, activeProjects),
|
Projects: tools.NewProjectsTool(db, activeProjects),
|
||||||
Version: tools.NewVersionTool(cfg.MCP.ServerName, info),
|
Version: tools.NewVersionTool(cfg.MCP.ServerName, info),
|
||||||
Context: tools.NewContextTool(db, provider, cfg.Search, activeProjects),
|
Learnings: tools.NewLearningsTool(db, activeProjects, cfg.Search),
|
||||||
Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
|
Context: tools.NewContextTool(db, embeddings, cfg.Search, activeProjects),
|
||||||
Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
|
Recall: tools.NewRecallTool(db, embeddings, cfg.Search, activeProjects),
|
||||||
Links: tools.NewLinksTool(db, provider, cfg.Search),
|
Summarize: tools.NewSummarizeTool(db, embeddings, metadata, cfg.Search, activeProjects),
|
||||||
|
Links: tools.NewLinksTool(db, embeddings, cfg.Search),
|
||||||
Files: filesTool,
|
Files: filesTool,
|
||||||
Backfill: tools.NewBackfillTool(db, provider, activeProjects, logger),
|
Backfill: backfillTool,
|
||||||
Reparse: tools.NewReparseMetadataTool(db, provider, cfg.Capture, activeProjects, logger),
|
Reparse: tools.NewReparseMetadataTool(db, bgMetadata, cfg.Capture, activeProjects, logger),
|
||||||
RetryMetadata: tools.NewRetryMetadataTool(metadataRetryer),
|
RetryMetadata: tools.NewRetryEnrichmentTool(enrichmentRetryer),
|
||||||
Maintenance: tools.NewMaintenanceTool(db),
|
Maintenance: tools.NewMaintenanceTool(db),
|
||||||
Skills: tools.NewSkillsTool(db, activeProjects),
|
Skills: tools.NewSkillsTool(db, activeProjects),
|
||||||
ChatHistory: tools.NewChatHistoryTool(db, activeProjects),
|
ChatHistory: tools.NewChatHistoryTool(db, activeProjects),
|
||||||
@@ -241,8 +270,8 @@ func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *st
|
|||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) {
|
func runMetadataRetryPass(ctx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) {
|
||||||
retryer := tools.NewMetadataRetryer(ctx, db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
|
retryer := tools.NewMetadataRetryer(ctx, db, metadataRunner, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
|
||||||
_, out, err := retryer.Handle(ctx, nil, tools.RetryMetadataInput{
|
_, out, err := retryer.Handle(ctx, nil, tools.RetryMetadataInput{
|
||||||
Limit: cfg.MetadataRetry.MaxPerRun,
|
Limit: cfg.MetadataRetry.MaxPerRun,
|
||||||
IncludeArchived: cfg.MetadataRetry.IncludeArchived,
|
IncludeArchived: cfg.MetadataRetry.IncludeArchived,
|
||||||
@@ -260,8 +289,8 @@ func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provide
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runBackfillPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg config.BackfillConfig, logger *slog.Logger) {
|
func runBackfillPass(ctx context.Context, db *store.DB, embeddings *ai.EmbeddingRunner, cfg config.BackfillConfig, logger *slog.Logger) {
|
||||||
backfiller := tools.NewBackfillTool(db, provider, nil, logger)
|
backfiller := tools.NewBackfillTool(db, embeddings, nil, logger)
|
||||||
_, out, err := backfiller.Handle(ctx, nil, tools.BackfillInput{
|
_, out, err := backfiller.Handle(ctx, nil, tools.BackfillInput{
|
||||||
Limit: cfg.MaxPerRun,
|
Limit: cfg.MaxPerRun,
|
||||||
IncludeArchived: cfg.IncludeArchived,
|
IncludeArchived: cfg.IncludeArchived,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/auth"
|
"git.warky.dev/wdevs/amcs/internal/auth"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/requestip"
|
||||||
)
|
)
|
||||||
|
|
||||||
// --- JSON types ---
|
// --- JSON types ---
|
||||||
@@ -261,7 +262,7 @@ func handleClientCredentials(w http.ResponseWriter, r *http.Request, oauthRegist
|
|||||||
}
|
}
|
||||||
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warn("oauth token: invalid client credentials", slog.String("remote_addr", r.RemoteAddr))
|
log.Warn("oauth token: invalid client credentials", slog.String("remote_addr", requestip.FromRequest(r)))
|
||||||
w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`)
|
w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`)
|
||||||
writeTokenError(w, "invalid_client", http.StatusUnauthorized)
|
writeTokenError(w, "invalid_client", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
@@ -290,7 +291,7 @@ func handleAuthorizationCode(w http.ResponseWriter, r *http.Request, authCodes *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !verifyPKCE(codeVerifier, entry.CodeChallenge, entry.CodeChallengeMethod) {
|
if !verifyPKCE(codeVerifier, entry.CodeChallenge, entry.CodeChallengeMethod) {
|
||||||
log.Warn("oauth token: PKCE verification failed", slog.String("remote_addr", r.RemoteAddr))
|
log.Warn("oauth token: PKCE verification failed", slog.String("remote_addr", requestip.FromRequest(r)))
|
||||||
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
|
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -157,3 +157,34 @@ func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) {
|
|||||||
t.Fatalf("invalid key status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
t.Fatalf("invalid key status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareRecordsForwardedRemoteAddr(t *testing.T) {
|
||||||
|
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewKeyring() error = %v", err)
|
||||||
|
}
|
||||||
|
tracker := NewAccessTracker()
|
||||||
|
|
||||||
|
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.5:2222"
|
||||||
|
req.Header.Set("x-brain-key", "secret")
|
||||||
|
req.Header.Set("X-Real-IP", "203.0.113.99")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := tracker.Snapshot()
|
||||||
|
if len(snap) != 1 {
|
||||||
|
t.Fatalf("len(snapshot) = %d, want 1", len(snap))
|
||||||
|
}
|
||||||
|
if snap[0].RemoteAddr != "203.0.113.99" {
|
||||||
|
t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/config"
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/requestip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey string
|
type contextKey string
|
||||||
@@ -22,17 +23,18 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
|||||||
}
|
}
|
||||||
recordAccess := func(r *http.Request, keyID string) {
|
recordAccess := func(r *http.Request, keyID string) {
|
||||||
if tracker != nil {
|
if tracker != nil {
|
||||||
tracker.Record(keyID, r.URL.Path, r.RemoteAddr, r.UserAgent(), time.Now())
|
tracker.Record(keyID, r.URL.Path, requestip.FromRequest(r), r.UserAgent(), time.Now())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
remoteAddr := requestip.FromRequest(r)
|
||||||
// 1. Custom header → keyring only.
|
// 1. Custom header → keyring only.
|
||||||
if keyring != nil {
|
if keyring != nil {
|
||||||
if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
|
if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
|
||||||
keyID, ok := keyring.Lookup(token)
|
keyID, ok := keyring.Lookup(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
log.Warn("authentication failed", slog.String("remote_addr", remoteAddr))
|
||||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -58,7 +60,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr))
|
log.Warn("bearer token rejected", slog.String("remote_addr", remoteAddr))
|
||||||
http.Error(w, "invalid token or API key", http.StatusUnauthorized)
|
http.Error(w, "invalid token or API key", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -71,7 +73,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
|||||||
}
|
}
|
||||||
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
log.Warn("oauth client authentication failed", slog.String("remote_addr", remoteAddr))
|
||||||
http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized)
|
http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -85,7 +87,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
|||||||
if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
|
if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
|
||||||
keyID, ok := keyring.Lookup(token)
|
keyID, ok := keyring.Lookup(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
log.Warn("authentication failed", slog.String("remote_addr", remoteAddr))
|
||||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
Version int `yaml:"version"`
|
||||||
Server ServerConfig `yaml:"server"`
|
Server ServerConfig `yaml:"server"`
|
||||||
MCP MCPConfig `yaml:"mcp"`
|
MCP MCPConfig `yaml:"mcp"`
|
||||||
Auth AuthConfig `yaml:"auth"`
|
Auth AuthConfig `yaml:"auth"`
|
||||||
@@ -37,10 +38,7 @@ type MCPConfig struct {
|
|||||||
Version string `yaml:"version"`
|
Version string `yaml:"version"`
|
||||||
Transport string `yaml:"transport"`
|
Transport string `yaml:"transport"`
|
||||||
SessionTimeout time.Duration `yaml:"session_timeout"`
|
SessionTimeout time.Duration `yaml:"session_timeout"`
|
||||||
// PublicURL is the externally reachable base URL of this server (e.g. https://amcs.example.com).
|
|
||||||
// When set, it is used to build absolute icon URLs in the MCP server identity.
|
|
||||||
PublicURL string `yaml:"public_url"`
|
PublicURL string `yaml:"public_url"`
|
||||||
// Instructions is set at startup from the embedded memory.md and sent to MCP clients on initialise.
|
|
||||||
Instructions string `yaml:"-"`
|
Instructions string `yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,52 +75,82 @@ type DatabaseConfig struct {
|
|||||||
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time"`
|
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AIConfig (v2): named providers + per-role chains.
|
||||||
type AIConfig struct {
|
type AIConfig struct {
|
||||||
|
Providers map[string]ProviderConfig `yaml:"providers"`
|
||||||
|
Embeddings EmbeddingsRoleConfig `yaml:"embeddings"`
|
||||||
|
Metadata MetadataRoleConfig `yaml:"metadata"`
|
||||||
|
Background *BackgroundRolesConfig `yaml:"background,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProviderConfig struct {
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
BaseURL string `yaml:"base_url"`
|
||||||
|
APIKey string `yaml:"api_key"`
|
||||||
|
RequestHeaders map[string]string `yaml:"request_headers,omitempty"`
|
||||||
|
AppName string `yaml:"app_name,omitempty"`
|
||||||
|
SiteURL string `yaml:"site_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RoleTarget struct {
|
||||||
Provider string `yaml:"provider"`
|
Provider string `yaml:"provider"`
|
||||||
Embeddings AIEmbeddingConfig `yaml:"embeddings"`
|
Model string `yaml:"model"`
|
||||||
Metadata AIMetadataConfig `yaml:"metadata"`
|
|
||||||
LiteLLM LiteLLMConfig `yaml:"litellm"`
|
|
||||||
Ollama OllamaConfig `yaml:"ollama"`
|
|
||||||
OpenRouter OpenRouterAIConfig `yaml:"openrouter"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AIEmbeddingConfig struct {
|
type RoleChain struct {
|
||||||
Model string `yaml:"model"`
|
Primary RoleTarget `yaml:"primary"`
|
||||||
|
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingsRoleConfig struct {
|
||||||
Dimensions int `yaml:"dimensions"`
|
Dimensions int `yaml:"dimensions"`
|
||||||
|
Primary RoleTarget `yaml:"primary"`
|
||||||
|
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AIMetadataConfig struct {
|
type MetadataRoleConfig struct {
|
||||||
Model string `yaml:"model"`
|
|
||||||
FallbackModels []string `yaml:"fallback_models"`
|
|
||||||
FallbackModel string `yaml:"fallback_model"` // legacy single fallback
|
|
||||||
Temperature float64 `yaml:"temperature"`
|
Temperature float64 `yaml:"temperature"`
|
||||||
LogConversations bool `yaml:"log_conversations"`
|
LogConversations bool `yaml:"log_conversations"`
|
||||||
Timeout time.Duration `yaml:"timeout"`
|
Timeout time.Duration `yaml:"timeout"`
|
||||||
|
Primary RoleTarget `yaml:"primary"`
|
||||||
|
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LiteLLMConfig struct {
|
// BackgroundRolesConfig overrides the foreground chains for background workers
|
||||||
BaseURL string `yaml:"base_url"`
|
// (backfill_embeddings, metadata_retry, reparse_metadata). Either field may be
|
||||||
APIKey string `yaml:"api_key"`
|
// nil to inherit the foreground role unchanged.
|
||||||
UseResponsesAPI bool `yaml:"use_responses_api"`
|
type BackgroundRolesConfig struct {
|
||||||
RequestHeaders map[string]string `yaml:"request_headers"`
|
Embeddings *RoleChain `yaml:"embeddings,omitempty"`
|
||||||
EmbeddingModel string `yaml:"embedding_model"`
|
Metadata *RoleChain `yaml:"metadata,omitempty"`
|
||||||
MetadataModel string `yaml:"metadata_model"`
|
|
||||||
FallbackMetadataModels []string `yaml:"fallback_metadata_models"`
|
|
||||||
FallbackMetadataModel string `yaml:"fallback_metadata_model"` // legacy single fallback
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OllamaConfig struct {
|
// Chain returns primary followed by fallbacks (deduped, blanks dropped).
|
||||||
BaseURL string `yaml:"base_url"`
|
func (e EmbeddingsRoleConfig) Chain() []RoleTarget {
|
||||||
APIKey string `yaml:"api_key"`
|
return dedupeTargets(append([]RoleTarget{e.Primary}, e.Fallbacks...))
|
||||||
RequestHeaders map[string]string `yaml:"request_headers"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenRouterAIConfig struct {
|
func (m MetadataRoleConfig) Chain() []RoleTarget {
|
||||||
BaseURL string `yaml:"base_url"`
|
return dedupeTargets(append([]RoleTarget{m.Primary}, m.Fallbacks...))
|
||||||
APIKey string `yaml:"api_key"`
|
}
|
||||||
AppName string `yaml:"app_name"`
|
|
||||||
SiteURL string `yaml:"site_url"`
|
func (c RoleChain) AsTargets() []RoleTarget {
|
||||||
ExtraHeaders map[string]string `yaml:"extra_headers"`
|
return dedupeTargets(append([]RoleTarget{c.Primary}, c.Fallbacks...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupeTargets(in []RoleTarget) []RoleTarget {
|
||||||
|
out := make([]RoleTarget, 0, len(in))
|
||||||
|
seen := make(map[RoleTarget]struct{}, len(in))
|
||||||
|
for _, t := range in {
|
||||||
|
if t.Provider == "" || t.Model == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[t]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[t] = struct{}{}
|
||||||
|
out = append(out, t)
|
||||||
|
}
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
type CaptureConfig struct {
|
type CaptureConfig struct {
|
||||||
@@ -167,45 +195,3 @@ type MetadataRetryConfig struct {
|
|||||||
MaxPerRun int `yaml:"max_per_run"`
|
MaxPerRun int `yaml:"max_per_run"`
|
||||||
IncludeArchived bool `yaml:"include_archived"`
|
IncludeArchived bool `yaml:"include_archived"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c AIMetadataConfig) EffectiveFallbackModels() []string {
|
|
||||||
models := make([]string, 0, len(c.FallbackModels)+1)
|
|
||||||
for _, model := range c.FallbackModels {
|
|
||||||
if model != "" {
|
|
||||||
models = append(models, model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.FallbackModel != "" {
|
|
||||||
models = append(models, c.FallbackModel)
|
|
||||||
}
|
|
||||||
return dedupeNonEmpty(models)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c LiteLLMConfig) EffectiveFallbackMetadataModels() []string {
|
|
||||||
models := make([]string, 0, len(c.FallbackMetadataModels)+1)
|
|
||||||
for _, model := range c.FallbackMetadataModels {
|
|
||||||
if model != "" {
|
|
||||||
models = append(models, model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.FallbackMetadataModel != "" {
|
|
||||||
models = append(models, c.FallbackMetadataModel)
|
|
||||||
}
|
|
||||||
return dedupeNonEmpty(models)
|
|
||||||
}
|
|
||||||
|
|
||||||
func dedupeNonEmpty(values []string) []string {
|
|
||||||
seen := make(map[string]struct{}, len(values))
|
|
||||||
out := make([]string, 0, len(values))
|
|
||||||
for _, value := range values {
|
|
||||||
if value == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok := seen[value]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[value] = struct{}{}
|
|
||||||
out = append(out, value)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -12,6 +13,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func Load(explicitPath string) (*Config, string, error) {
|
func Load(explicitPath string) (*Config, string, error) {
|
||||||
|
return LoadWithLogger(explicitPath, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadWithLogger is Load with a logger surface for migration notices. Passing
|
||||||
|
// nil is fine — migration events will simply not be logged.
|
||||||
|
func LoadWithLogger(explicitPath string, log *slog.Logger) (*Config, string, error) {
|
||||||
path := ResolvePath(explicitPath)
|
path := ResolvePath(explicitPath)
|
||||||
|
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
@@ -19,10 +26,38 @@ func Load(explicitPath string) (*Config, string, error) {
|
|||||||
return nil, path, fmt.Errorf("read config %q: %w", path, err)
|
return nil, path, fmt.Errorf("read config %q: %w", path, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := defaultConfig()
|
raw := map[string]any{}
|
||||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
if err := yaml.Unmarshal(data, &raw); err != nil {
|
||||||
return nil, path, fmt.Errorf("decode config %q: %w", path, err)
|
return nil, path, fmt.Errorf("decode config %q: %w", path, err)
|
||||||
}
|
}
|
||||||
|
if raw == nil {
|
||||||
|
raw = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
applied, err := Migrate(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, path, fmt.Errorf("migrate config %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(applied) > 0 {
|
||||||
|
if log != nil {
|
||||||
|
for _, step := range applied {
|
||||||
|
log.Warn("config migrated in memory",
|
||||||
|
slog.String("path", path),
|
||||||
|
slog.Int("from_version", step.From),
|
||||||
|
slog.Int("to_version", step.To),
|
||||||
|
slog.String("describe", step.Describe),
|
||||||
|
slog.String("hint", "persist with amcs-migrate-config"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := decodeTyped(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, path, fmt.Errorf("decode migrated config %q: %w", path, err)
|
||||||
|
}
|
||||||
|
cfg.Version = CurrentConfigVersion
|
||||||
|
|
||||||
applyEnvOverrides(&cfg)
|
applyEnvOverrides(&cfg)
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
@@ -32,6 +67,18 @@ func Load(explicitPath string) (*Config, string, error) {
|
|||||||
return &cfg, path, nil
|
return &cfg, path, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeTyped(raw map[string]any) (Config, error) {
|
||||||
|
out, err := yaml.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("re-marshal migrated config: %w", err)
|
||||||
|
}
|
||||||
|
cfg := defaultConfig()
|
||||||
|
if err := yaml.Unmarshal(out, &cfg); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ResolvePath(explicitPath string) string {
|
func ResolvePath(explicitPath string) string {
|
||||||
if path := strings.TrimSpace(explicitPath); path != "" {
|
if path := strings.TrimSpace(explicitPath); path != "" {
|
||||||
if path != ".yaml" && path != ".yml" {
|
if path != ".yaml" && path != ".yml" {
|
||||||
@@ -49,6 +96,7 @@ func ResolvePath(explicitPath string) string {
|
|||||||
func defaultConfig() Config {
|
func defaultConfig() Config {
|
||||||
info := buildinfo.Current()
|
info := buildinfo.Current()
|
||||||
return Config{
|
return Config{
|
||||||
|
Version: CurrentConfigVersion,
|
||||||
Server: ServerConfig{
|
Server: ServerConfig{
|
||||||
Host: "0.0.0.0",
|
Host: "0.0.0.0",
|
||||||
Port: 8080,
|
Port: 8080,
|
||||||
@@ -69,20 +117,14 @@ func defaultConfig() Config {
|
|||||||
QueryParam: "key",
|
QueryParam: "key",
|
||||||
},
|
},
|
||||||
AI: AIConfig{
|
AI: AIConfig{
|
||||||
Provider: "litellm",
|
Providers: map[string]ProviderConfig{},
|
||||||
Embeddings: AIEmbeddingConfig{
|
Embeddings: EmbeddingsRoleConfig{
|
||||||
Model: "openai/text-embedding-3-small",
|
|
||||||
Dimensions: 1536,
|
Dimensions: 1536,
|
||||||
},
|
},
|
||||||
Metadata: AIMetadataConfig{
|
Metadata: MetadataRoleConfig{
|
||||||
Model: "gpt-4o-mini",
|
|
||||||
Temperature: 0.1,
|
Temperature: 0.1,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
},
|
},
|
||||||
Ollama: OllamaConfig{
|
|
||||||
BaseURL: "http://localhost:11434/v1",
|
|
||||||
APIKey: "ollama",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
Capture: CaptureConfig{
|
Capture: CaptureConfig{
|
||||||
Source: DefaultSource,
|
Source: DefaultSource,
|
||||||
@@ -119,11 +161,12 @@ func defaultConfig() Config {
|
|||||||
func applyEnvOverrides(cfg *Config) {
|
func applyEnvOverrides(cfg *Config) {
|
||||||
overrideString(&cfg.Database.URL, "AMCS_DATABASE_URL")
|
overrideString(&cfg.Database.URL, "AMCS_DATABASE_URL")
|
||||||
overrideString(&cfg.MCP.PublicURL, "AMCS_PUBLIC_URL")
|
overrideString(&cfg.MCP.PublicURL, "AMCS_PUBLIC_URL")
|
||||||
overrideString(&cfg.AI.LiteLLM.BaseURL, "AMCS_LITELLM_BASE_URL")
|
|
||||||
overrideString(&cfg.AI.LiteLLM.APIKey, "AMCS_LITELLM_API_KEY")
|
overrideProviderField(cfg, "AMCS_LITELLM_BASE_URL", "litellm", func(p *ProviderConfig, v string) { p.BaseURL = v })
|
||||||
overrideString(&cfg.AI.Ollama.BaseURL, "AMCS_OLLAMA_BASE_URL")
|
overrideProviderField(cfg, "AMCS_LITELLM_API_KEY", "litellm", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||||
overrideString(&cfg.AI.Ollama.APIKey, "AMCS_OLLAMA_API_KEY")
|
overrideProviderField(cfg, "AMCS_OLLAMA_BASE_URL", "ollama", func(p *ProviderConfig, v string) { p.BaseURL = v })
|
||||||
overrideString(&cfg.AI.OpenRouter.APIKey, "AMCS_OPENROUTER_API_KEY")
|
overrideProviderField(cfg, "AMCS_OLLAMA_API_KEY", "ollama", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||||
|
overrideProviderField(cfg, "AMCS_OPENROUTER_API_KEY", "openrouter", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||||
|
|
||||||
if value, ok := os.LookupEnv("AMCS_SERVER_PORT"); ok {
|
if value, ok := os.LookupEnv("AMCS_SERVER_PORT"); ok {
|
||||||
if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||||
@@ -132,6 +175,24 @@ func applyEnvOverrides(cfg *Config) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// overrideProviderField applies an env var to every configured provider of the
|
||||||
|
// given type. This preserves the v1 behaviour where e.g. AMCS_LITELLM_API_KEY
|
||||||
|
// rewrote the single litellm block — in v2 it rewrites every litellm provider.
|
||||||
|
func overrideProviderField(cfg *Config, envKey, providerType string, apply func(*ProviderConfig, string)) {
|
||||||
|
value, ok := os.LookupEnv(envKey)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
for name, p := range cfg.AI.Providers {
|
||||||
|
if p.Type != providerType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
apply(&p, value)
|
||||||
|
cfg.AI.Providers[name] = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func overrideString(target *string, envKey string) {
|
func overrideString(target *string, envKey string) {
|
||||||
if value, ok := os.LookupEnv(envKey); ok {
|
if value, ok := os.LookupEnv(envKey); ok {
|
||||||
*target = strings.TrimSpace(value)
|
*target = strings.TrimSpace(value)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -31,9 +32,8 @@ func TestResolvePathIgnoresBareYAMLExtension(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadAppliesEnvOverrides(t *testing.T) {
|
const v2ConfigYAML = `
|
||||||
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
version: 2
|
||||||
if err := os.WriteFile(configPath, []byte(`
|
|
||||||
server:
|
server:
|
||||||
port: 8080
|
port: 8080
|
||||||
mcp:
|
mcp:
|
||||||
@@ -46,18 +46,30 @@ auth:
|
|||||||
database:
|
database:
|
||||||
url: "postgres://from-file"
|
url: "postgres://from-file"
|
||||||
ai:
|
ai:
|
||||||
provider: "litellm"
|
providers:
|
||||||
embeddings:
|
default:
|
||||||
dimensions: 1536
|
type: "litellm"
|
||||||
litellm:
|
|
||||||
base_url: "http://localhost:4000/v1"
|
base_url: "http://localhost:4000/v1"
|
||||||
api_key: "file-key"
|
api_key: "file-key"
|
||||||
|
embeddings:
|
||||||
|
dimensions: 1536
|
||||||
|
primary:
|
||||||
|
provider: "default"
|
||||||
|
model: "text-embed"
|
||||||
|
metadata:
|
||||||
|
primary:
|
||||||
|
provider: "default"
|
||||||
|
model: "gpt-4"
|
||||||
search:
|
search:
|
||||||
default_limit: 10
|
default_limit: 10
|
||||||
max_limit: 50
|
max_limit: 50
|
||||||
logging:
|
logging:
|
||||||
level: "info"
|
level: "info"
|
||||||
`), 0o600); err != nil {
|
`
|
||||||
|
|
||||||
|
func TestLoadAppliesEnvOverrides(t *testing.T) {
|
||||||
|
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte(v2ConfigYAML), 0o600); err != nil {
|
||||||
t.Fatalf("write config: %v", err)
|
t.Fatalf("write config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,8 +88,8 @@ logging:
|
|||||||
if cfg.Database.URL != "postgres://from-env" {
|
if cfg.Database.URL != "postgres://from-env" {
|
||||||
t.Fatalf("database url = %q, want env override", cfg.Database.URL)
|
t.Fatalf("database url = %q, want env override", cfg.Database.URL)
|
||||||
}
|
}
|
||||||
if cfg.AI.LiteLLM.APIKey != "env-key" {
|
if cfg.AI.Providers["default"].APIKey != "env-key" {
|
||||||
t.Fatalf("litellm api key = %q, want env override", cfg.AI.LiteLLM.APIKey)
|
t.Fatalf("litellm api key = %q, want env override", cfg.AI.Providers["default"].APIKey)
|
||||||
}
|
}
|
||||||
if cfg.Server.Port != 9090 {
|
if cfg.Server.Port != 9090 {
|
||||||
t.Fatalf("server port = %d, want 9090", cfg.Server.Port)
|
t.Fatalf("server port = %d, want 9090", cfg.Server.Port)
|
||||||
@@ -90,10 +102,12 @@ logging:
|
|||||||
func TestLoadAppliesOllamaEnvOverrides(t *testing.T) {
|
func TestLoadAppliesOllamaEnvOverrides(t *testing.T) {
|
||||||
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
||||||
if err := os.WriteFile(configPath, []byte(`
|
if err := os.WriteFile(configPath, []byte(`
|
||||||
|
version: 2
|
||||||
server:
|
server:
|
||||||
port: 8080
|
port: 8080
|
||||||
mcp:
|
mcp:
|
||||||
path: "/mcp"
|
path: "/mcp"
|
||||||
|
session_timeout: "10m"
|
||||||
auth:
|
auth:
|
||||||
keys:
|
keys:
|
||||||
- id: "test"
|
- id: "test"
|
||||||
@@ -101,15 +115,20 @@ auth:
|
|||||||
database:
|
database:
|
||||||
url: "postgres://from-file"
|
url: "postgres://from-file"
|
||||||
ai:
|
ai:
|
||||||
provider: "ollama"
|
providers:
|
||||||
embeddings:
|
local:
|
||||||
model: "nomic-embed-text"
|
type: "ollama"
|
||||||
dimensions: 768
|
|
||||||
metadata:
|
|
||||||
model: "llama3.2"
|
|
||||||
ollama:
|
|
||||||
base_url: "http://localhost:11434/v1"
|
base_url: "http://localhost:11434/v1"
|
||||||
api_key: "ollama"
|
api_key: "ollama"
|
||||||
|
embeddings:
|
||||||
|
dimensions: 768
|
||||||
|
primary:
|
||||||
|
provider: "local"
|
||||||
|
model: "nomic-embed-text"
|
||||||
|
metadata:
|
||||||
|
primary:
|
||||||
|
provider: "local"
|
||||||
|
model: "llama3.2"
|
||||||
search:
|
search:
|
||||||
default_limit: 10
|
default_limit: 10
|
||||||
max_limit: 50
|
max_limit: 50
|
||||||
@@ -127,10 +146,85 @@ logging:
|
|||||||
t.Fatalf("Load() error = %v", err)
|
t.Fatalf("Load() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.AI.Ollama.BaseURL != "https://ollama.example.com/v1" {
|
p := cfg.AI.Providers["local"]
|
||||||
t.Fatalf("ollama base url = %q, want env override", cfg.AI.Ollama.BaseURL)
|
if p.BaseURL != "https://ollama.example.com/v1" {
|
||||||
|
t.Fatalf("ollama base url = %q, want env override", p.BaseURL)
|
||||||
}
|
}
|
||||||
if cfg.AI.Ollama.APIKey != "remote-key" {
|
if p.APIKey != "remote-key" {
|
||||||
t.Fatalf("ollama api key = %q, want env override", cfg.AI.Ollama.APIKey)
|
t.Fatalf("ollama api key = %q, want env override", p.APIKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadMigratesV1Config(t *testing.T) {
|
||||||
|
configPath := filepath.Join(t.TempDir(), "v1.yaml")
|
||||||
|
v1 := `
|
||||||
|
server:
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
path: "/mcp"
|
||||||
|
session_timeout: "10m"
|
||||||
|
auth:
|
||||||
|
keys:
|
||||||
|
- id: "test"
|
||||||
|
value: "secret"
|
||||||
|
database:
|
||||||
|
url: "postgres://from-file"
|
||||||
|
ai:
|
||||||
|
provider: "litellm"
|
||||||
|
embeddings:
|
||||||
|
model: "text-embed"
|
||||||
|
dimensions: 1536
|
||||||
|
metadata:
|
||||||
|
model: "gpt-4"
|
||||||
|
temperature: 0.2
|
||||||
|
fallback_models: ["gpt-3.5"]
|
||||||
|
litellm:
|
||||||
|
base_url: "http://localhost:4000/v1"
|
||||||
|
api_key: "file-key"
|
||||||
|
search:
|
||||||
|
default_limit: 10
|
||||||
|
max_limit: 50
|
||||||
|
logging:
|
||||||
|
level: "info"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(configPath, []byte(v1), 0o600); err != nil {
|
||||||
|
t.Fatalf("write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, _, err := Load(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Version != CurrentConfigVersion {
|
||||||
|
t.Fatalf("version = %d, want %d", cfg.Version, CurrentConfigVersion)
|
||||||
|
}
|
||||||
|
if p, ok := cfg.AI.Providers["default"]; !ok || p.Type != "litellm" || p.APIKey != "file-key" {
|
||||||
|
t.Fatalf("providers[default] = %+v, want litellm/file-key", p)
|
||||||
|
}
|
||||||
|
if cfg.AI.Embeddings.Primary.Model != "text-embed" || cfg.AI.Embeddings.Primary.Provider != "default" {
|
||||||
|
t.Fatalf("embeddings.primary = %+v, want default/text-embed", cfg.AI.Embeddings.Primary)
|
||||||
|
}
|
||||||
|
if cfg.AI.Metadata.Primary.Model != "gpt-4" || cfg.AI.Metadata.Primary.Provider != "default" {
|
||||||
|
t.Fatalf("metadata.primary = %+v, want default/gpt-4", cfg.AI.Metadata.Primary)
|
||||||
|
}
|
||||||
|
if len(cfg.AI.Metadata.Fallbacks) != 1 || cfg.AI.Metadata.Fallbacks[0].Model != "gpt-3.5" {
|
||||||
|
t.Fatalf("metadata.fallbacks = %+v, want [default/gpt-3.5]", cfg.AI.Metadata.Fallbacks)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := filepath.Glob(configPath + ".bak.*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("glob backups: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Fatalf("backup files = %d, want 0 (load should not rewrite config)", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
originalOnDisk, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read original config: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(originalOnDisk), "provider: \"litellm\"") {
|
||||||
|
t.Fatalf("expected source config to remain unchanged on disk")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
341
internal/config/migrate.go
Normal file
341
internal/config/migrate.go
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CurrentConfigVersion is the schema version this binary expects. Files at a
|
||||||
|
// lower version are migrated automatically when loaded.
|
||||||
|
const CurrentConfigVersion = 2
|
||||||
|
|
||||||
|
// ConfigMigration upgrades a raw YAML map by one version.
|
||||||
|
type ConfigMigration struct {
|
||||||
|
From, To int
|
||||||
|
Describe string
|
||||||
|
Apply func(map[string]any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrations is the ordered ladder of upgrades. Add new entries at the end.
|
||||||
|
var migrations = []ConfigMigration{
|
||||||
|
{From: 1, To: 2, Describe: "named providers + role chains", Apply: migrateV1toV2},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Migrate brings raw up to CurrentConfigVersion in place. Returns the list of
|
||||||
|
// migrations that were applied (may be empty if already current).
|
||||||
|
func Migrate(raw map[string]any) ([]ConfigMigration, error) {
|
||||||
|
if raw == nil {
|
||||||
|
return nil, fmt.Errorf("migrate: raw config is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
version := readVersion(raw)
|
||||||
|
if version > CurrentConfigVersion {
|
||||||
|
return nil, fmt.Errorf("migrate: config version %d is newer than supported version %d", version, CurrentConfigVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
applied := make([]ConfigMigration, 0)
|
||||||
|
for {
|
||||||
|
if version >= CurrentConfigVersion {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
step, ok := findMigration(version)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("migrate: no migration registered from version %d", version)
|
||||||
|
}
|
||||||
|
if err := step.Apply(raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("migrate v%d->v%d: %w", step.From, step.To, err)
|
||||||
|
}
|
||||||
|
raw["version"] = step.To
|
||||||
|
version = step.To
|
||||||
|
applied = append(applied, step)
|
||||||
|
}
|
||||||
|
return applied, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findMigration(from int) (ConfigMigration, bool) {
|
||||||
|
for _, m := range migrations {
|
||||||
|
if m.From == from {
|
||||||
|
return m, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ConfigMigration{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// readVersion returns the version from raw. Files without a version field are
|
||||||
|
// treated as version 1 (the original schema).
|
||||||
|
func readVersion(raw map[string]any) int {
|
||||||
|
v, ok := raw["version"]
|
||||||
|
if !ok {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
switch n := v.(type) {
|
||||||
|
case int:
|
||||||
|
return n
|
||||||
|
case int64:
|
||||||
|
return int(n)
|
||||||
|
case float64:
|
||||||
|
return int(n)
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateV1toV2 lifts the single-provider config into the named-providers +
|
||||||
|
// role-chains layout. The pre-v2 config implicitly used one provider for both
|
||||||
|
// embeddings and metadata; we materialise that as a provider named "default".
|
||||||
|
func migrateV1toV2(raw map[string]any) error {
|
||||||
|
aiRaw := mapValue(raw, "ai")
|
||||||
|
if aiRaw == nil {
|
||||||
|
aiRaw = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
providerType := stringValue(aiRaw, "provider")
|
||||||
|
if providerType == "" {
|
||||||
|
providerType = "litellm"
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, embeddingModel, metadataModel, fallbackModels := buildV1Provider(aiRaw, providerType)
|
||||||
|
|
||||||
|
embeddingsOld := mapValue(aiRaw, "embeddings")
|
||||||
|
dimensions := intValue(embeddingsOld, "dimensions")
|
||||||
|
if dimensions <= 0 {
|
||||||
|
dimensions = 1536
|
||||||
|
}
|
||||||
|
if embeddingModel == "" {
|
||||||
|
embeddingModel = stringValue(embeddingsOld, "model")
|
||||||
|
}
|
||||||
|
|
||||||
|
metadataOld := mapValue(aiRaw, "metadata")
|
||||||
|
if metadataModel == "" {
|
||||||
|
metadataModel = stringValue(metadataOld, "model")
|
||||||
|
}
|
||||||
|
temperature := floatValue(metadataOld, "temperature")
|
||||||
|
logConversations := boolValue(metadataOld, "log_conversations")
|
||||||
|
timeoutStr := stringValue(metadataOld, "timeout")
|
||||||
|
|
||||||
|
if list := stringListValue(metadataOld, "fallback_models"); len(list) > 0 {
|
||||||
|
fallbackModels = append(fallbackModels, list...)
|
||||||
|
}
|
||||||
|
if v := stringValue(metadataOld, "fallback_model"); v != "" {
|
||||||
|
fallbackModels = append(fallbackModels, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings := map[string]any{
|
||||||
|
"dimensions": dimensions,
|
||||||
|
"primary": map[string]any{"provider": "default", "model": embeddingModel},
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := map[string]any{
|
||||||
|
"temperature": temperature,
|
||||||
|
"log_conversations": logConversations,
|
||||||
|
"primary": map[string]any{"provider": "default", "model": metadataModel},
|
||||||
|
}
|
||||||
|
if timeoutStr != "" {
|
||||||
|
metadata["timeout"] = timeoutStr
|
||||||
|
}
|
||||||
|
if fallbacks := chainTargets("default", fallbackModels); len(fallbacks) > 0 {
|
||||||
|
metadata["fallbacks"] = fallbacks
|
||||||
|
}
|
||||||
|
|
||||||
|
raw["ai"] = map[string]any{
|
||||||
|
"providers": providers,
|
||||||
|
"embeddings": embeddings,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildV1Provider(aiRaw map[string]any, providerType string) (map[string]any, string, string, []string) {
|
||||||
|
providers := map[string]any{}
|
||||||
|
defaultEntry := map[string]any{"type": providerType}
|
||||||
|
embedModel := ""
|
||||||
|
metaModel := ""
|
||||||
|
var fallbacks []string
|
||||||
|
|
||||||
|
switch providerType {
|
||||||
|
case "litellm":
|
||||||
|
block := mapValue(aiRaw, "litellm")
|
||||||
|
copyKeys(defaultEntry, block, "base_url", "api_key")
|
||||||
|
copyHeaders(defaultEntry, block, "request_headers")
|
||||||
|
embedModel = stringValue(block, "embedding_model")
|
||||||
|
metaModel = stringValue(block, "metadata_model")
|
||||||
|
if list := stringListValue(block, "fallback_metadata_models"); len(list) > 0 {
|
||||||
|
fallbacks = append(fallbacks, list...)
|
||||||
|
}
|
||||||
|
if v := stringValue(block, "fallback_metadata_model"); v != "" {
|
||||||
|
fallbacks = append(fallbacks, v)
|
||||||
|
}
|
||||||
|
case "ollama":
|
||||||
|
block := mapValue(aiRaw, "ollama")
|
||||||
|
copyKeys(defaultEntry, block, "base_url", "api_key")
|
||||||
|
copyHeaders(defaultEntry, block, "request_headers")
|
||||||
|
case "openrouter":
|
||||||
|
block := mapValue(aiRaw, "openrouter")
|
||||||
|
copyKeys(defaultEntry, block, "base_url", "api_key", "app_name", "site_url")
|
||||||
|
copyHeaders(defaultEntry, block, "extra_headers")
|
||||||
|
// rename: extra_headers → request_headers
|
||||||
|
if hdr, ok := defaultEntry["extra_headers"]; ok {
|
||||||
|
defaultEntry["request_headers"] = hdr
|
||||||
|
delete(defaultEntry, "extra_headers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
providers["default"] = defaultEntry
|
||||||
|
return providers, embedModel, metaModel, fallbacks
|
||||||
|
}
|
||||||
|
|
||||||
|
func chainTargets(provider string, models []string) []any {
|
||||||
|
out := make([]any, 0, len(models))
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
for _, m := range models {
|
||||||
|
if m == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := provider + "|" + m
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
out = append(out, map[string]any{"provider": provider, "model": m})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapValue(raw map[string]any, key string) map[string]any {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
v, ok := raw[key]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch m := v.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
return m
|
||||||
|
case map[any]any:
|
||||||
|
return convertAnyMap(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertAnyMap(in map[any]any) map[string]any {
|
||||||
|
out := make(map[string]any, len(in))
|
||||||
|
keys := make([]string, 0, len(in))
|
||||||
|
for k, v := range in {
|
||||||
|
ks, ok := k.(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys = append(keys, ks)
|
||||||
|
out[ks] = v
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringValue(raw map[string]any, key string) string {
|
||||||
|
if raw == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
v, ok := raw[key]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func intValue(raw map[string]any, key string) int {
|
||||||
|
if raw == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch n := raw[key].(type) {
|
||||||
|
case int:
|
||||||
|
return n
|
||||||
|
case int64:
|
||||||
|
return int(n)
|
||||||
|
case float64:
|
||||||
|
return int(n)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func floatValue(raw map[string]any, key string) float64 {
|
||||||
|
if raw == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch n := raw[key].(type) {
|
||||||
|
case float64:
|
||||||
|
return n
|
||||||
|
case int:
|
||||||
|
return float64(n)
|
||||||
|
case int64:
|
||||||
|
return float64(n)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolValue(raw map[string]any, key string) bool {
|
||||||
|
if raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b, ok := raw[key].(bool); ok {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringListValue(raw map[string]any, key string) []string {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
v, ok := raw[key]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
list, ok := v.([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(list))
|
||||||
|
for _, item := range list {
|
||||||
|
if s, ok := item.(string); ok && s != "" {
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyKeys(dst, src map[string]any, keys ...string) {
|
||||||
|
if src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, k := range keys {
|
||||||
|
if v, ok := src[k]; ok {
|
||||||
|
dst[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyHeaders(dst, src map[string]any, key string) {
|
||||||
|
if src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v, ok := src[key]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch headers := v.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if len(headers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dst[key] = headers
|
||||||
|
case map[any]any:
|
||||||
|
if len(headers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dst[key] = convertAnyMap(headers)
|
||||||
|
}
|
||||||
|
}
|
||||||
77
internal/config/migrate_test.go
Normal file
77
internal/config/migrate_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestMigrateV1ToV2Litellm(t *testing.T) {
|
||||||
|
raw := map[string]any{
|
||||||
|
"ai": map[string]any{
|
||||||
|
"provider": "litellm",
|
||||||
|
"embeddings": map[string]any{
|
||||||
|
"model": "text-embedding-3-small",
|
||||||
|
"dimensions": 1536,
|
||||||
|
},
|
||||||
|
"metadata": map[string]any{
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"temperature": 0.2,
|
||||||
|
"fallback_models": []any{"gpt-4.1-mini"},
|
||||||
|
},
|
||||||
|
"litellm": map[string]any{
|
||||||
|
"base_url": "http://localhost:4000/v1",
|
||||||
|
"api_key": "secret",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applied, err := Migrate(raw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Migrate() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(applied) != 1 || applied[0].From != 1 || applied[0].To != 2 {
|
||||||
|
t.Fatalf("applied = %+v, want [v1->v2]", applied)
|
||||||
|
}
|
||||||
|
if got := readVersion(raw); got != CurrentConfigVersion {
|
||||||
|
t.Fatalf("version = %d, want %d", got, CurrentConfigVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
ai := mapValue(raw, "ai")
|
||||||
|
providers := mapValue(ai, "providers")
|
||||||
|
def := mapValue(providers, "default")
|
||||||
|
if got := stringValue(def, "type"); got != "litellm" {
|
||||||
|
t.Fatalf("providers.default.type = %q, want litellm", got)
|
||||||
|
}
|
||||||
|
if got := stringValue(def, "base_url"); got != "http://localhost:4000/v1" {
|
||||||
|
t.Fatalf("providers.default.base_url = %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
emb := mapValue(ai, "embeddings")
|
||||||
|
embPrimary := mapValue(emb, "primary")
|
||||||
|
if stringValue(embPrimary, "provider") != "default" || stringValue(embPrimary, "model") != "text-embedding-3-small" {
|
||||||
|
t.Fatalf("embeddings.primary = %+v, want default/text-embedding-3-small", embPrimary)
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := mapValue(ai, "metadata")
|
||||||
|
metaPrimary := mapValue(meta, "primary")
|
||||||
|
if stringValue(metaPrimary, "provider") != "default" || stringValue(metaPrimary, "model") != "gpt-4o-mini" {
|
||||||
|
t.Fatalf("metadata.primary = %+v, want default/gpt-4o-mini", metaPrimary)
|
||||||
|
}
|
||||||
|
fallbacks, ok := meta["fallbacks"].([]any)
|
||||||
|
if !ok || len(fallbacks) != 1 {
|
||||||
|
t.Fatalf("metadata.fallbacks = %#v, want len=1", meta["fallbacks"])
|
||||||
|
}
|
||||||
|
firstFallback, ok := fallbacks[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("metadata.fallbacks[0] type = %T, want map[string]any", fallbacks[0])
|
||||||
|
}
|
||||||
|
if stringValue(firstFallback, "provider") != "default" || stringValue(firstFallback, "model") != "gpt-4.1-mini" {
|
||||||
|
t.Fatalf("metadata fallback = %+v, want default/gpt-4.1-mini", firstFallback)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateRejectsNewerVersion(t *testing.T) {
|
||||||
|
raw := map[string]any{"version": CurrentConfigVersion + 1}
|
||||||
|
|
||||||
|
_, err := Migrate(raw)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Migrate() error = nil, want error for newer config version")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -45,38 +45,8 @@ func (c Config) Validate() error {
|
|||||||
return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero")
|
return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch c.AI.Provider {
|
if err := c.AI.validate(); err != nil {
|
||||||
case "litellm", "ollama", "openrouter":
|
return err
|
||||||
default:
|
|
||||||
return fmt.Errorf("invalid config: unsupported ai.provider %q", c.AI.Provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.AI.Embeddings.Dimensions <= 0 {
|
|
||||||
return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch c.AI.Provider {
|
|
||||||
case "litellm":
|
|
||||||
if strings.TrimSpace(c.AI.LiteLLM.BaseURL) == "" {
|
|
||||||
return fmt.Errorf("invalid config: ai.litellm.base_url is required when ai.provider=litellm")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(c.AI.LiteLLM.APIKey) == "" {
|
|
||||||
return fmt.Errorf("invalid config: ai.litellm.api_key is required when ai.provider=litellm")
|
|
||||||
}
|
|
||||||
case "ollama":
|
|
||||||
if strings.TrimSpace(c.AI.Ollama.BaseURL) == "" {
|
|
||||||
return fmt.Errorf("invalid config: ai.ollama.base_url is required when ai.provider=ollama")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(c.AI.Ollama.APIKey) == "" {
|
|
||||||
return fmt.Errorf("invalid config: ai.ollama.api_key is required when ai.provider=ollama")
|
|
||||||
}
|
|
||||||
case "openrouter":
|
|
||||||
if strings.TrimSpace(c.AI.OpenRouter.BaseURL) == "" {
|
|
||||||
return fmt.Errorf("invalid config: ai.openrouter.base_url is required when ai.provider=openrouter")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(c.AI.OpenRouter.APIKey) == "" {
|
|
||||||
return fmt.Errorf("invalid config: ai.openrouter.api_key is required when ai.provider=openrouter")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Server.Port <= 0 {
|
if c.Server.Port <= 0 {
|
||||||
@@ -108,3 +78,61 @@ func (c Config) Validate() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a AIConfig) validate() error {
|
||||||
|
if len(a.Providers) == 0 {
|
||||||
|
return fmt.Errorf("invalid config: ai.providers must contain at least one entry")
|
||||||
|
}
|
||||||
|
for name, p := range a.Providers {
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
return fmt.Errorf("invalid config: ai.providers contains an entry with an empty name")
|
||||||
|
}
|
||||||
|
switch p.Type {
|
||||||
|
case "litellm", "ollama", "openrouter":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid config: ai.providers.%s.type %q is not supported", name, p.Type)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(p.BaseURL) == "" {
|
||||||
|
return fmt.Errorf("invalid config: ai.providers.%s.base_url is required", name)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(p.APIKey) == "" {
|
||||||
|
return fmt.Errorf("invalid config: ai.providers.%s.api_key is required", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Embeddings.Dimensions <= 0 {
|
||||||
|
return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.validateChain("ai.embeddings", a.Embeddings.Chain()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := a.validateChain("ai.metadata", a.Metadata.Chain()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if a.Background != nil {
|
||||||
|
if a.Background.Embeddings != nil {
|
||||||
|
if err := a.validateChain("ai.background.embeddings", a.Background.Embeddings.AsTargets()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if a.Background.Metadata != nil {
|
||||||
|
if err := a.validateChain("ai.background.metadata", a.Background.Metadata.AsTargets()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a AIConfig) validateChain(prefix string, chain []RoleTarget) error {
|
||||||
|
if len(chain) == 0 {
|
||||||
|
return fmt.Errorf("invalid config: %s.primary must reference a configured provider and model", prefix)
|
||||||
|
}
|
||||||
|
for i, target := range chain {
|
||||||
|
if _, ok := a.Providers[target.Provider]; !ok {
|
||||||
|
return fmt.Errorf("invalid config: %s[%d] references unknown provider %q", prefix, i, target.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
func validConfig() Config {
|
func validConfig() Config {
|
||||||
return Config{
|
return Config{
|
||||||
|
Version: CurrentConfigVersion,
|
||||||
Server: ServerConfig{Port: 8080},
|
Server: ServerConfig{Port: 8080},
|
||||||
MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute},
|
MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute},
|
||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
@@ -14,21 +15,15 @@ func validConfig() Config {
|
|||||||
},
|
},
|
||||||
Database: DatabaseConfig{URL: "postgres://example"},
|
Database: DatabaseConfig{URL: "postgres://example"},
|
||||||
AI: AIConfig{
|
AI: AIConfig{
|
||||||
Provider: "litellm",
|
Providers: map[string]ProviderConfig{
|
||||||
Embeddings: AIEmbeddingConfig{
|
"default": {Type: "litellm", BaseURL: "http://localhost:4000/v1", APIKey: "key"},
|
||||||
|
},
|
||||||
|
Embeddings: EmbeddingsRoleConfig{
|
||||||
Dimensions: 1536,
|
Dimensions: 1536,
|
||||||
|
Primary: RoleTarget{Provider: "default", Model: "text-embed"},
|
||||||
},
|
},
|
||||||
LiteLLM: LiteLLMConfig{
|
Metadata: MetadataRoleConfig{
|
||||||
BaseURL: "http://localhost:4000/v1",
|
Primary: RoleTarget{Provider: "default", Model: "gpt-4"},
|
||||||
APIKey: "key",
|
|
||||||
},
|
|
||||||
Ollama: OllamaConfig{
|
|
||||||
BaseURL: "http://localhost:11434/v1",
|
|
||||||
APIKey: "ollama",
|
|
||||||
},
|
|
||||||
OpenRouter: OpenRouterAIConfig{
|
|
||||||
BaseURL: "https://openrouter.ai/api/v1",
|
|
||||||
APIKey: "key",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Search: SearchConfig{DefaultLimit: 10, MaxLimit: 50},
|
Search: SearchConfig{DefaultLimit: 10, MaxLimit: 50},
|
||||||
@@ -36,29 +31,44 @@ func validConfig() Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateAcceptsSupportedProviders(t *testing.T) {
|
func TestValidateAcceptsSupportedProviderTypes(t *testing.T) {
|
||||||
|
for _, providerType := range []string{"litellm", "ollama", "openrouter"} {
|
||||||
cfg := validConfig()
|
cfg := validConfig()
|
||||||
|
p := cfg.AI.Providers["default"]
|
||||||
|
p.Type = providerType
|
||||||
|
cfg.AI.Providers["default"] = p
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
t.Fatalf("Validate litellm error = %v", err)
|
t.Fatalf("Validate %s error = %v", providerType, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.AI.Provider = "ollama"
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
t.Fatalf("Validate ollama error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.AI.Provider = "openrouter"
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
t.Fatalf("Validate openrouter error = %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateRejectsInvalidProvider(t *testing.T) {
|
func TestValidateRejectsInvalidProviderType(t *testing.T) {
|
||||||
cfg := validConfig()
|
cfg := validConfig()
|
||||||
cfg.AI.Provider = "unknown"
|
p := cfg.AI.Providers["default"]
|
||||||
|
p.Type = "unknown"
|
||||||
|
cfg.AI.Providers["default"] = p
|
||||||
|
|
||||||
if err := cfg.Validate(); err == nil {
|
if err := cfg.Validate(); err == nil {
|
||||||
t.Fatal("Validate() error = nil, want error for unsupported provider")
|
t.Fatal("Validate() error = nil, want error for unsupported provider type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRejectsChainWithUnknownProvider(t *testing.T) {
|
||||||
|
cfg := validConfig()
|
||||||
|
cfg.AI.Metadata.Primary = RoleTarget{Provider: "does-not-exist", Model: "x"}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err == nil {
|
||||||
|
t.Fatal("Validate() error = nil, want error for chain referencing unknown provider")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRejectsEmptyProviders(t *testing.T) {
|
||||||
|
cfg := validConfig()
|
||||||
|
cfg.AI.Providers = map[string]ProviderConfig{}
|
||||||
|
|
||||||
|
if err := cfg.Validate(); err == nil {
|
||||||
|
t.Fatal("Validate() error = nil, want error for empty providers")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,11 +35,12 @@ type ToolSet struct {
|
|||||||
Files *tools.FilesTool
|
Files *tools.FilesTool
|
||||||
Backfill *tools.BackfillTool
|
Backfill *tools.BackfillTool
|
||||||
Reparse *tools.ReparseMetadataTool
|
Reparse *tools.ReparseMetadataTool
|
||||||
RetryMetadata *tools.RetryMetadataTool
|
RetryMetadata *tools.RetryEnrichmentTool
|
||||||
Maintenance *tools.MaintenanceTool
|
Maintenance *tools.MaintenanceTool
|
||||||
Skills *tools.SkillsTool
|
Skills *tools.SkillsTool
|
||||||
ChatHistory *tools.ChatHistoryTool
|
ChatHistory *tools.ChatHistoryTool
|
||||||
Describe *tools.DescribeTool
|
Describe *tools.DescribeTool
|
||||||
|
Learnings *tools.LearningsTool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handlers groups the HTTP handlers produced for an MCP server instance.
|
// Handlers groups the HTTP handlers produced for an MCP server instance.
|
||||||
@@ -83,6 +84,7 @@ func NewHandlers(cfg config.MCPConfig, logger *slog.Logger, toolSet ToolSet, onS
|
|||||||
registerSystemTools,
|
registerSystemTools,
|
||||||
registerThoughtTools,
|
registerThoughtTools,
|
||||||
registerProjectTools,
|
registerProjectTools,
|
||||||
|
registerLearningTools,
|
||||||
registerFileTools,
|
registerFileTools,
|
||||||
registerMaintenanceTools,
|
registerMaintenanceTools,
|
||||||
registerSkillTools,
|
registerSkillTools,
|
||||||
@@ -249,6 +251,28 @@ func registerProjectTools(server *mcp.Server, logger *slog.Logger, toolSet ToolS
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerLearningTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||||
|
if err := addTool(server, logger, &mcp.Tool{
|
||||||
|
Name: "add_learning",
|
||||||
|
Description: "Create a curated learning record distinct from raw thoughts.",
|
||||||
|
}, toolSet.Learnings.Add); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := addTool(server, logger, &mcp.Tool{
|
||||||
|
Name: "get_learning",
|
||||||
|
Description: "Retrieve a structured learning by id.",
|
||||||
|
}, toolSet.Learnings.Get); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := addTool(server, logger, &mcp.Tool{
|
||||||
|
Name: "list_learnings",
|
||||||
|
Description: "List structured learnings with optional project, status, priority, tag, and text filters.",
|
||||||
|
}, toolSet.Learnings.List); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func registerFileTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
func registerFileTools(server *mcp.Server, logger *slog.Logger, toolSet ToolSet) error {
|
||||||
server.AddResourceTemplate(&mcp.ResourceTemplate{
|
server.AddResourceTemplate(&mcp.ResourceTemplate{
|
||||||
Name: "stored_file",
|
Name: "stored_file",
|
||||||
@@ -477,6 +501,11 @@ func BuildToolCatalog() []tools.ToolEntry {
|
|||||||
{Name: "get_active_project", Description: "Return the active project for the current MCP session. If your client does not preserve MCP sessions, pass project explicitly to project-scoped tools instead of relying on this.", Category: "projects"},
|
{Name: "get_active_project", Description: "Return the active project for the current MCP session. If your client does not preserve MCP sessions, pass project explicitly to project-scoped tools instead of relying on this.", Category: "projects"},
|
||||||
{Name: "get_project_context", Description: "Get recent and semantic context for a project. Uses the explicit project when provided, otherwise the active MCP session project. Falls back to full-text search when no embeddings exist.", Category: "projects"},
|
{Name: "get_project_context", Description: "Get recent and semantic context for a project. Uses the explicit project when provided, otherwise the active MCP session project. Falls back to full-text search when no embeddings exist.", Category: "projects"},
|
||||||
|
|
||||||
|
// learnings
|
||||||
|
{Name: "add_learning", Description: "Create a curated learning record distinct from raw thoughts.", Category: "projects"},
|
||||||
|
{Name: "get_learning", Description: "Retrieve a structured learning by id.", Category: "projects"},
|
||||||
|
{Name: "list_learnings", Description: "List structured learnings with optional project, category, area, status, priority, tag, and text filters.", Category: "projects"},
|
||||||
|
|
||||||
// files
|
// files
|
||||||
{Name: "upload_file", Description: "Stage a file and get an amcs://files/{id} resource URI. Use content_path (absolute server-side path, no size limit) for large or binary files, or content_base64 (≤10 MB) for small files. Pass thought_id/project to link immediately, or omit and pass the URI to save_file later.", Category: "files"},
|
{Name: "upload_file", Description: "Stage a file and get an amcs://files/{id} resource URI. Use content_path (absolute server-side path, no size limit) for large or binary files, or content_base64 (≤10 MB) for small files. Pass thought_id/project to link immediately, or omit and pass the URI to save_file later.", Category: "files"},
|
||||||
{Name: "save_file", Description: "Store a file and optionally link it to a thought. Use content_base64 (≤10 MB) for small files, or content_uri (amcs://files/{id} from a prior upload_file) for previously staged files. For files larger than 10 MB, use upload_file with content_path first. If the goal is to retain the artifact, store the file directly instead of reading or summarising it first.", Category: "files"},
|
{Name: "save_file", Description: "Store a file and optionally link it to a thought. Use content_base64 (≤10 MB) for small files, or content_uri (amcs://files/{id} from a prior upload_file) for previously staged files. For files larger than 10 MB, use upload_file with content_path first. If the goal is to retain the artifact, store the file directly instead of reading or summarising it first.", Category: "files"},
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func TestNewListsAllRegisteredTools(t *testing.T) {
|
|||||||
|
|
||||||
want := []string{
|
want := []string{
|
||||||
"add_guardrail",
|
"add_guardrail",
|
||||||
|
"add_learning",
|
||||||
"add_maintenance_task",
|
"add_maintenance_task",
|
||||||
"add_project_guardrail",
|
"add_project_guardrail",
|
||||||
"add_project_skill",
|
"add_project_skill",
|
||||||
@@ -43,6 +44,7 @@ func TestNewListsAllRegisteredTools(t *testing.T) {
|
|||||||
"describe_tools",
|
"describe_tools",
|
||||||
"get_active_project",
|
"get_active_project",
|
||||||
"get_chat_history",
|
"get_chat_history",
|
||||||
|
"get_learning",
|
||||||
"get_project_context",
|
"get_project_context",
|
||||||
"get_thought",
|
"get_thought",
|
||||||
"get_upcoming_maintenance",
|
"get_upcoming_maintenance",
|
||||||
@@ -51,6 +53,7 @@ func TestNewListsAllRegisteredTools(t *testing.T) {
|
|||||||
"list_chat_histories",
|
"list_chat_histories",
|
||||||
"list_files",
|
"list_files",
|
||||||
"list_guardrails",
|
"list_guardrails",
|
||||||
|
"list_learnings",
|
||||||
"list_project_guardrails",
|
"list_project_guardrails",
|
||||||
"list_project_skills",
|
"list_project_skills",
|
||||||
"list_projects",
|
"list_projects",
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ func streamableTestToolSet() ToolSet {
|
|||||||
Files: new(tools.FilesTool),
|
Files: new(tools.FilesTool),
|
||||||
Backfill: new(tools.BackfillTool),
|
Backfill: new(tools.BackfillTool),
|
||||||
Reparse: new(tools.ReparseMetadataTool),
|
Reparse: new(tools.ReparseMetadataTool),
|
||||||
RetryMetadata: new(tools.RetryMetadataTool),
|
RetryMetadata: new(tools.RetryEnrichmentTool),
|
||||||
Maintenance: new(tools.MaintenanceTool),
|
Maintenance: new(tools.MaintenanceTool),
|
||||||
Skills: new(tools.SkillsTool),
|
Skills: new(tools.SkillsTool),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,25 @@
|
|||||||
package observability
|
package observability
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/requestip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey string
|
type contextKey string
|
||||||
|
|
||||||
const requestIDContextKey contextKey = "request_id"
|
const requestIDContextKey contextKey = "request_id"
|
||||||
|
const mcpToolContextKey contextKey = "mcp_tool"
|
||||||
|
|
||||||
func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
|
func Chain(h http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
|
||||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||||
@@ -57,18 +63,27 @@ func Recover(log *slog.Logger) func(http.Handler) http.Handler {
|
|||||||
func AccessLog(log *slog.Logger) func(http.Handler) http.Handler {
|
func AccessLog(log *slog.Logger) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if tool := mcpToolFromRequest(r); tool != "" {
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), mcpToolContextKey, tool))
|
||||||
|
}
|
||||||
|
|
||||||
recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
||||||
started := time.Now()
|
started := time.Now()
|
||||||
next.ServeHTTP(recorder, r)
|
next.ServeHTTP(recorder, r)
|
||||||
|
|
||||||
log.Info("http request",
|
attrs := []any{
|
||||||
slog.String("request_id", RequestIDFromContext(r.Context())),
|
slog.String("request_id", RequestIDFromContext(r.Context())),
|
||||||
slog.String("method", r.Method),
|
slog.String("method", r.Method),
|
||||||
slog.String("path", r.URL.Path),
|
slog.String("path", r.URL.Path),
|
||||||
slog.Int("status", recorder.status),
|
slog.Int("status", recorder.status),
|
||||||
slog.Duration("duration", time.Since(started)),
|
slog.Duration("duration", time.Since(started)),
|
||||||
slog.String("remote_addr", stripPort(r.RemoteAddr)),
|
slog.String("remote_addr", requestip.FromRequest(r)),
|
||||||
)
|
slog.String("mcp_session_id", mcpSessionIDFromRequest(r)),
|
||||||
|
}
|
||||||
|
if tool, _ := r.Context().Value(mcpToolContextKey).(string); strings.TrimSpace(tool) != "" {
|
||||||
|
attrs = append(attrs, slog.String("tool", tool), slog.String("tool_call", tool))
|
||||||
|
}
|
||||||
|
log.Info("http request", attrs...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,10 +116,67 @@ func (s *statusRecorder) WriteHeader(statusCode int) {
|
|||||||
s.ResponseWriter.WriteHeader(statusCode)
|
s.ResponseWriter.WriteHeader(statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func stripPort(remote string) string {
|
func mcpToolFromRequest(r *http.Request) string {
|
||||||
host, _, err := net.SplitHostPort(remote)
|
if r == nil || r.Method != http.MethodPost || !strings.HasPrefix(r.URL.Path, "/mcp") || r.Body == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return remote
|
return ""
|
||||||
}
|
}
|
||||||
return host
|
r.Body = io.NopCloser(bytes.NewReader(raw))
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support both single and batch JSON-RPC payloads.
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(string(raw)), "[") {
|
||||||
|
var batch []rpcEnvelope
|
||||||
|
if err := json.Unmarshal(raw, &batch); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, msg := range batch {
|
||||||
|
if tool := msg.toolName(); tool != "" {
|
||||||
|
return tool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg rpcEnvelope
|
||||||
|
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return msg.toolName()
|
||||||
|
}
|
||||||
|
|
||||||
|
func mcpSessionIDFromRequest(r *http.Request) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(r.Header.Get("MCP-Session-Id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
// Some clients/proxies may propagate the session in query params.
|
||||||
|
for _, key := range []string{"session_id", "sessionId", "mcp_session_id"} {
|
||||||
|
if v := strings.TrimSpace(r.URL.Query().Get(key)); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcEnvelope struct {
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
} `json:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m rpcEnvelope) toolName() string {
|
||||||
|
if m.Method != "tools/call" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(m.Params.Name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package observability
|
package observability
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -57,3 +60,99 @@ func TestRecoverHandlesPanic(t *testing.T) {
|
|||||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccessLogUsesForwardedClientIP(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.10:1234"
|
||||||
|
req.Header.Set("X-Real-IP", "203.0.113.7")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "remote_addr=203.0.113.7") {
|
||||||
|
t.Fatalf("log output = %q, want remote_addr=203.0.113.7", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogIncludesMCPToolName(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "1",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": map[string]any{
|
||||||
|
"name": "list_projects",
|
||||||
|
"arguments": map[string]any{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("json.Marshal() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "tool=list_projects") {
|
||||||
|
t.Fatalf("log output = %q, want tool=list_projects", buf.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "tool_call=list_projects") {
|
||||||
|
t.Fatalf("log output = %q, want tool_call=list_projects", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogIncludesMCPSessionIDHeader(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/sse", nil)
|
||||||
|
req.Header.Set("MCP-Session-Id", "sess-123")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "mcp_session_id=sess-123") {
|
||||||
|
t.Fatalf("log output = %q, want mcp_session_id=sess-123", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLogIncludesMCPSessionIDQueryParam(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/sse?session_id=sess-q-1", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "mcp_session_id=sess-q-1") {
|
||||||
|
t.Fatalf("log output = %q, want mcp_session_id=sess-q-1", buf.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
72
internal/requestip/requestip.go
Normal file
72
internal/requestip/requestip.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package requestip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FromRequest returns the best-effort client IP/host for a request, preferring
|
||||||
|
// proxy headers before falling back to RemoteAddr.
|
||||||
|
//
|
||||||
|
// Header precedence:
|
||||||
|
// 1) X-Real-IP
|
||||||
|
// 2) X-Forwarded-For (first value)
|
||||||
|
// 3) Forwarded (for=...)
|
||||||
|
// 4) RemoteAddr (host part)
|
||||||
|
func FromRequest(r *http.Request) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := firstAddressToken(r.Header.Get("X-Real-IP")); v != "" {
|
||||||
|
return stripPort(v)
|
||||||
|
}
|
||||||
|
if v := firstAddressToken(r.Header.Get("X-Forwarded-For")); v != "" {
|
||||||
|
return stripPort(v)
|
||||||
|
}
|
||||||
|
if v := forwardedForValue(r.Header.Get("Forwarded")); v != "" {
|
||||||
|
return stripPort(v)
|
||||||
|
}
|
||||||
|
return stripPort(strings.TrimSpace(r.RemoteAddr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstAddressToken(v string) string {
|
||||||
|
if v == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
part := strings.TrimSpace(strings.Split(v, ",")[0])
|
||||||
|
part = strings.Trim(part, `"`)
|
||||||
|
return strings.TrimSpace(part)
|
||||||
|
}
|
||||||
|
|
||||||
|
func forwardedForValue(v string) string {
|
||||||
|
for _, part := range strings.Split(v, ",") {
|
||||||
|
for _, kv := range strings.Split(part, ";") {
|
||||||
|
k, raw, ok := strings.Cut(strings.TrimSpace(kv), "=")
|
||||||
|
if !ok || !strings.EqualFold(strings.TrimSpace(k), "for") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidate := strings.Trim(strings.TrimSpace(raw), `"`)
|
||||||
|
if candidate == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripPort(addr string) string {
|
||||||
|
addr = strings.TrimSpace(addr)
|
||||||
|
if addr == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// RFC 7239 quoted values may wrap IPv6 with brackets.
|
||||||
|
addr = strings.Trim(addr, "[]")
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err == nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
37
internal/requestip/requestip_test.go
Normal file
37
internal/requestip/requestip_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package requestip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFromRequestPrefersXRealIP(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.10:5555"
|
||||||
|
req.Header.Set("X-Forwarded-For", "198.51.100.1")
|
||||||
|
req.Header.Set("X-Real-IP", "203.0.113.10")
|
||||||
|
|
||||||
|
if got := FromRequest(req); got != "203.0.113.10" {
|
||||||
|
t.Fatalf("FromRequest() = %q, want %q", got, "203.0.113.10")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromRequestUsesXForwardedForFirstValue(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.10:5555"
|
||||||
|
req.Header.Set("X-Forwarded-For", "198.51.100.7, 10.1.1.2")
|
||||||
|
|
||||||
|
if got := FromRequest(req); got != "198.51.100.7" {
|
||||||
|
t.Fatalf("FromRequest() = %q, want %q", got, "198.51.100.7")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromRequestFallsBackToRemoteAddr(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "192.0.2.5:1234"
|
||||||
|
|
||||||
|
if got := FromRequest(req); got != "192.0.2.5" {
|
||||||
|
t.Fatalf("FromRequest() = %q, want %q", got, "192.0.2.5")
|
||||||
|
}
|
||||||
|
}
|
||||||
215
internal/store/learnings.go
Normal file
215
internal/store/learnings.go
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
|
||||||
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (db *DB) CreateLearning(ctx context.Context, learning thoughttypes.Learning) (thoughttypes.Learning, error) {
|
||||||
|
row := db.pool.QueryRow(ctx, `
|
||||||
|
insert into learnings (
|
||||||
|
summary, details, category, area, status, priority, confidence,
|
||||||
|
action_required, source_type, source_ref, project_id, related_thought_id,
|
||||||
|
related_skill_id, reviewed_by, reviewed_at, duplicate_of_learning_id,
|
||||||
|
supersedes_learning_id, tags
|
||||||
|
) values (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7,
|
||||||
|
$8, $9, $10, $11, $12,
|
||||||
|
$13, $14, $15, $16,
|
||||||
|
$17, $18
|
||||||
|
)
|
||||||
|
returning id, created_at, updated_at
|
||||||
|
`,
|
||||||
|
strings.TrimSpace(learning.Summary),
|
||||||
|
strings.TrimSpace(learning.Details),
|
||||||
|
strings.TrimSpace(learning.Category),
|
||||||
|
strings.TrimSpace(learning.Area),
|
||||||
|
string(learning.Status),
|
||||||
|
string(learning.Priority),
|
||||||
|
string(learning.Confidence),
|
||||||
|
learning.ActionRequired,
|
||||||
|
nullableText(learning.SourceType),
|
||||||
|
nullableText(learning.SourceRef),
|
||||||
|
learning.ProjectID,
|
||||||
|
learning.RelatedThoughtID,
|
||||||
|
learning.RelatedSkillID,
|
||||||
|
nullableTextPtr(learning.ReviewedBy),
|
||||||
|
learning.ReviewedAt,
|
||||||
|
learning.DuplicateOfLearningID,
|
||||||
|
learning.SupersedesLearningID,
|
||||||
|
learning.Tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
created := learning
|
||||||
|
if err := row.Scan(&created.ID, &created.CreatedAt, &created.UpdatedAt); err != nil {
|
||||||
|
return thoughttypes.Learning{}, fmt.Errorf("create learning: %w", err)
|
||||||
|
}
|
||||||
|
return created, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) GetLearning(ctx context.Context, id uuid.UUID) (thoughttypes.Learning, error) {
|
||||||
|
row := db.pool.QueryRow(ctx, `
|
||||||
|
select id, summary, details, category, area, status, priority, confidence,
|
||||||
|
action_required, source_type, source_ref, project_id, related_thought_id,
|
||||||
|
related_skill_id, reviewed_by, reviewed_at, duplicate_of_learning_id,
|
||||||
|
supersedes_learning_id, tags, created_at, updated_at
|
||||||
|
from learnings
|
||||||
|
where id = $1
|
||||||
|
`, id)
|
||||||
|
|
||||||
|
learning, err := scanLearning(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == pgx.ErrNoRows {
|
||||||
|
return thoughttypes.Learning{}, fmt.Errorf("learning not found: %s", id)
|
||||||
|
}
|
||||||
|
return thoughttypes.Learning{}, fmt.Errorf("get learning: %w", err)
|
||||||
|
}
|
||||||
|
return learning, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ListLearnings(ctx context.Context, filter thoughttypes.LearningFilter) ([]thoughttypes.Learning, error) {
|
||||||
|
args := make([]any, 0, 8)
|
||||||
|
conditions := make([]string, 0, 8)
|
||||||
|
|
||||||
|
if filter.ProjectID != nil {
|
||||||
|
args = append(args, *filter.ProjectID)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("project_id = $%d", len(args)))
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(filter.Category); value != "" {
|
||||||
|
args = append(args, value)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("category = $%d", len(args)))
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(filter.Area); value != "" {
|
||||||
|
args = append(args, value)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("area = $%d", len(args)))
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(filter.Status); value != "" {
|
||||||
|
args = append(args, value)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("status = $%d", len(args)))
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(filter.Priority); value != "" {
|
||||||
|
args = append(args, value)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("priority = $%d", len(args)))
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(filter.Tag); value != "" {
|
||||||
|
args = append(args, value)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("$%d = any(tags)", len(args)))
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(filter.Query); value != "" {
|
||||||
|
args = append(args, value)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("to_tsvector('simple', summary || ' ' || coalesce(details, '')) @@ websearch_to_tsquery('simple', $%d)", len(args)))
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
select id, summary, details, category, area, status, priority, confidence,
|
||||||
|
action_required, source_type, source_ref, project_id, related_thought_id,
|
||||||
|
related_skill_id, reviewed_by, reviewed_at, duplicate_of_learning_id,
|
||||||
|
supersedes_learning_id, tags, created_at, updated_at
|
||||||
|
from learnings
|
||||||
|
`
|
||||||
|
if len(conditions) > 0 {
|
||||||
|
query += " where " + strings.Join(conditions, " and ")
|
||||||
|
}
|
||||||
|
query += " order by updated_at desc"
|
||||||
|
if filter.Limit > 0 {
|
||||||
|
args = append(args, filter.Limit)
|
||||||
|
query += fmt.Sprintf(" limit $%d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.pool.Query(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list learnings: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
items := make([]thoughttypes.Learning, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
item, err := scanLearning(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("scan learning: %w", err)
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate learnings: %w", err)
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type learningScanner interface {
|
||||||
|
Scan(dest ...any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanLearning(row learningScanner) (thoughttypes.Learning, error) {
|
||||||
|
var learning thoughttypes.Learning
|
||||||
|
var sourceType pgtype.Text
|
||||||
|
var sourceRef pgtype.Text
|
||||||
|
var reviewedBy pgtype.Text
|
||||||
|
var tags []string
|
||||||
|
|
||||||
|
err := row.Scan(
|
||||||
|
&learning.ID,
|
||||||
|
&learning.Summary,
|
||||||
|
&learning.Details,
|
||||||
|
&learning.Category,
|
||||||
|
&learning.Area,
|
||||||
|
&learning.Status,
|
||||||
|
&learning.Priority,
|
||||||
|
&learning.Confidence,
|
||||||
|
&learning.ActionRequired,
|
||||||
|
&sourceType,
|
||||||
|
&sourceRef,
|
||||||
|
&learning.ProjectID,
|
||||||
|
&learning.RelatedThoughtID,
|
||||||
|
&learning.RelatedSkillID,
|
||||||
|
&reviewedBy,
|
||||||
|
&learning.ReviewedAt,
|
||||||
|
&learning.DuplicateOfLearningID,
|
||||||
|
&learning.SupersedesLearningID,
|
||||||
|
&tags,
|
||||||
|
&learning.CreatedAt,
|
||||||
|
&learning.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return thoughttypes.Learning{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
learning.SourceType = sourceType.String
|
||||||
|
learning.SourceRef = sourceRef.String
|
||||||
|
if reviewedBy.Valid {
|
||||||
|
value := reviewedBy.String
|
||||||
|
learning.ReviewedBy = &value
|
||||||
|
}
|
||||||
|
if tags == nil {
|
||||||
|
learning.Tags = []string{}
|
||||||
|
} else {
|
||||||
|
learning.Tags = tags
|
||||||
|
}
|
||||||
|
return learning, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullableText(value string) *string {
|
||||||
|
trimmed := strings.TrimSpace(value)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullableTextPtr(value *string) *string {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(*value)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &trimmed
|
||||||
|
}
|
||||||
@@ -26,21 +26,42 @@ func (db *DB) CreateProject(ctx context.Context, name, description string) (thou
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) {
|
func (db *DB) GetProject(ctx context.Context, nameOrID string) (thoughttypes.Project, error) {
|
||||||
var row pgx.Row
|
lookup := strings.TrimSpace(nameOrID)
|
||||||
if parsedID, err := uuid.Parse(strings.TrimSpace(nameOrID)); err == nil {
|
|
||||||
row = db.pool.QueryRow(ctx, `
|
// Prefer guid lookup when input parses as UUID, but fall back to name lookup
|
||||||
|
// so UUID-shaped project names can still be resolved by name.
|
||||||
|
if parsedID, err := uuid.Parse(lookup); err == nil {
|
||||||
|
project, queryErr := db.getProjectByGUID(ctx, parsedID)
|
||||||
|
if queryErr == nil {
|
||||||
|
return project, nil
|
||||||
|
}
|
||||||
|
if queryErr != pgx.ErrNoRows {
|
||||||
|
return thoughttypes.Project{}, queryErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.getProjectByName(ctx, lookup)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) getProjectByGUID(ctx context.Context, id uuid.UUID) (thoughttypes.Project, error) {
|
||||||
|
row := db.pool.QueryRow(ctx, `
|
||||||
select guid, name, description, created_at, last_active_at
|
select guid, name, description, created_at, last_active_at
|
||||||
from projects
|
from projects
|
||||||
where guid = $1
|
where guid = $1
|
||||||
`, parsedID)
|
`, id)
|
||||||
} else {
|
return scanProject(row)
|
||||||
row = db.pool.QueryRow(ctx, `
|
}
|
||||||
|
|
||||||
|
func (db *DB) getProjectByName(ctx context.Context, name string) (thoughttypes.Project, error) {
|
||||||
|
row := db.pool.QueryRow(ctx, `
|
||||||
select guid, name, description, created_at, last_active_at
|
select guid, name, description, created_at, last_active_at
|
||||||
from projects
|
from projects
|
||||||
where name = $1
|
where name = $1
|
||||||
`, strings.TrimSpace(nameOrID))
|
`, name)
|
||||||
|
return scanProject(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func scanProject(row pgx.Row) (thoughttypes.Project, error) {
|
||||||
var project thoughttypes.Project
|
var project thoughttypes.Project
|
||||||
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
if err := row.Scan(&project.ID, &project.Name, &project.Description, &project.CreatedAt, &project.LastActiveAt); err != nil {
|
||||||
if err == pgx.ErrNoRows {
|
if err == pgx.ErrNoRows {
|
||||||
|
|||||||
@@ -58,6 +58,12 @@ func (db *DB) InsertThought(ctx context.Context, thought thoughttypes.Thought, e
|
|||||||
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
|
return thoughttypes.Thought{}, fmt.Errorf("commit thought insert: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(thought.Embedding) > 0 {
|
||||||
|
created.EmbeddingStatus = "done"
|
||||||
|
} else {
|
||||||
|
created.EmbeddingStatus = "pending"
|
||||||
|
}
|
||||||
|
|
||||||
return created, nil
|
return created, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -576,7 +582,7 @@ func (db *DB) SearchThoughtsText(ctx context.Context, query string, limit int, p
|
|||||||
args := []any{query}
|
args := []any{query}
|
||||||
conditions := []string{
|
conditions := []string{
|
||||||
"t.archived_at is null",
|
"t.archived_at is null",
|
||||||
"to_tsvector('simple', t.content) @@ websearch_to_tsquery('simple', $1)",
|
"(to_tsvector('simple', t.content) || to_tsvector('simple', coalesce(p.name, ''))) @@ websearch_to_tsquery('simple', $1)",
|
||||||
}
|
}
|
||||||
if projectID != nil {
|
if projectID != nil {
|
||||||
args = append(args, *projectID)
|
args = append(args, *projectID)
|
||||||
@@ -590,9 +596,10 @@ func (db *DB) SearchThoughtsText(ctx context.Context, query string, limit int, p
|
|||||||
|
|
||||||
q := `
|
q := `
|
||||||
select t.guid, t.content, t.metadata,
|
select t.guid, t.content, t.metadata,
|
||||||
ts_rank_cd(to_tsvector('simple', t.content), websearch_to_tsquery('simple', $1)) as similarity,
|
ts_rank_cd(to_tsvector('simple', t.content) || to_tsvector('simple', coalesce(p.name, '')), websearch_to_tsquery('simple', $1)) as similarity,
|
||||||
t.created_at
|
t.created_at
|
||||||
from thoughts t
|
from thoughts t
|
||||||
|
left join projects p on t.project_id = p.guid
|
||||||
where ` + strings.Join(conditions, " and ") + `
|
where ` + strings.Join(conditions, " and ") + `
|
||||||
order by similarity desc
|
order by similarity desc
|
||||||
limit $` + fmt.Sprintf("%d", len(args))
|
limit $` + fmt.Sprintf("%d", len(args))
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ const backfillConcurrency = 4
|
|||||||
|
|
||||||
type BackfillTool struct {
|
type BackfillTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
@@ -47,8 +47,51 @@ type BackfillOutput struct {
|
|||||||
Failures []BackfillFailure `json:"failures,omitempty"`
|
Failures []BackfillFailure `json:"failures,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
|
func NewBackfillTool(db *store.DB, embeddings *ai.EmbeddingRunner, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
|
||||||
return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger}
|
return &BackfillTool{store: db, embeddings: embeddings, sessions: sessions, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueueThought queues a single thought for background embedding generation.
|
||||||
|
// It is used by capture when the embedding provider is temporarily unavailable.
|
||||||
|
func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content string) {
|
||||||
|
go func() {
|
||||||
|
started := time.Now()
|
||||||
|
t.logger.Info("background embedding started",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||||
|
slog.String("model", t.embeddings.PrimaryModel()),
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err := t.embeddings.Embed(ctx, content)
|
||||||
|
if err != nil {
|
||||||
|
t.logger.Warn("background embedding error",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||||
|
slog.String("model", t.embeddings.PrimaryModel()),
|
||||||
|
slog.String("stage", "embed"),
|
||||||
|
slog.Duration("duration", time.Since(started)),
|
||||||
|
slog.String("error", err.Error()),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); err != nil {
|
||||||
|
t.logger.Warn("background embedding error",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||||
|
slog.String("model", result.Model),
|
||||||
|
slog.String("stage", "upsert"),
|
||||||
|
slog.Duration("duration", time.Since(started)),
|
||||||
|
slog.String("error", err.Error()),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.logger.Info("background embedding complete",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||||
|
slog.String("model", result.Model),
|
||||||
|
slog.Duration("duration", time.Since(started)),
|
||||||
|
)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in BackfillInput) (*mcp.CallToolResult, BackfillOutput, error) {
|
func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in BackfillInput) (*mcp.CallToolResult, BackfillOutput, error) {
|
||||||
@@ -67,15 +110,15 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
model := t.provider.EmbeddingModel()
|
primaryModel := t.embeddings.PrimaryModel()
|
||||||
|
|
||||||
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays)
|
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, primaryModel, limit, projectID, in.IncludeArchived, in.OlderThanDays)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, BackfillOutput{}, err
|
return nil, BackfillOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
out := BackfillOutput{
|
out := BackfillOutput{
|
||||||
Model: model,
|
Model: primaryModel,
|
||||||
Scanned: len(thoughts),
|
Scanned: len(thoughts),
|
||||||
DryRun: in.DryRun,
|
DryRun: in.DryRun,
|
||||||
}
|
}
|
||||||
@@ -101,7 +144,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer sem.Release(1)
|
defer sem.Release(1)
|
||||||
|
|
||||||
vec, embedErr := t.provider.Embed(ctx, content)
|
result, embedErr := t.embeddings.Embed(ctx, content)
|
||||||
if embedErr != nil {
|
if embedErr != nil {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()})
|
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()})
|
||||||
@@ -110,7 +153,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil {
|
if upsertErr := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); upsertErr != nil {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()})
|
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()})
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
@@ -130,7 +173,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
out.Skipped = out.Scanned - out.Embedded - out.Failed
|
out.Skipped = out.Scanned - out.Embedded - out.Failed
|
||||||
|
|
||||||
t.logger.Info("backfill completed",
|
t.logger.Info("backfill completed",
|
||||||
slog.String("model", model),
|
slog.String("model", primaryModel),
|
||||||
slog.Int("scanned", out.Scanned),
|
slog.Int("scanned", out.Scanned),
|
||||||
slog.Int("embedded", out.Embedded),
|
slog.Int("embedded", out.Embedded),
|
||||||
slog.Int("failed", out.Failed),
|
slog.Int("failed", out.Failed),
|
||||||
|
|||||||
@@ -2,12 +2,10 @@ package tools
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||||
"git.warky.dev/wdevs/amcs/internal/config"
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
@@ -17,14 +15,24 @@ import (
|
|||||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// EmbeddingQueuer queues a thought for background embedding generation.
|
||||||
|
type EmbeddingQueuer interface {
|
||||||
|
QueueThought(ctx context.Context, id uuid.UUID, content string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetadataQueuer queues a thought for background metadata retry. Both
|
||||||
|
// MetadataRetryer and EnrichmentRetryer satisfy this.
|
||||||
|
type MetadataQueuer interface {
|
||||||
|
QueueThought(id uuid.UUID)
|
||||||
|
}
|
||||||
|
|
||||||
type CaptureTool struct {
|
type CaptureTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
capture config.CaptureConfig
|
capture config.CaptureConfig
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
metadataTimeout time.Duration
|
retryer MetadataQueuer
|
||||||
retryer *MetadataRetryer
|
embedRetryer EmbeddingQueuer
|
||||||
log *slog.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CaptureInput struct {
|
type CaptureInput struct {
|
||||||
@@ -36,8 +44,8 @@ type CaptureOutput struct {
|
|||||||
Thought thoughttypes.Thought `json:"thought"`
|
Thought thoughttypes.Thought `json:"thought"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer *MetadataRetryer, log *slog.Logger) *CaptureTool {
|
func NewCaptureTool(db *store.DB, embeddings *ai.EmbeddingRunner, capture config.CaptureConfig, sessions *session.ActiveProjects, retryer MetadataQueuer, embedRetryer EmbeddingQueuer) *CaptureTool {
|
||||||
return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, log: log}
|
return &CaptureTool{store: db, embeddings: embeddings, capture: capture, sessions: sessions, retryer: retryer, embedRetryer: embedRetryer}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
|
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
|
||||||
@@ -51,61 +59,30 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
|
|||||||
return nil, CaptureOutput{}, err
|
return nil, CaptureOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var embedding []float32
|
|
||||||
rawMetadata := metadata.Fallback(t.capture)
|
rawMetadata := metadata.Fallback(t.capture)
|
||||||
metadataNeedsRetry := false
|
rawMetadata.MetadataStatus = metadata.MetadataStatusPending
|
||||||
|
|
||||||
group, groupCtx := errgroup.WithContext(ctx)
|
|
||||||
group.Go(func() error {
|
|
||||||
vector, err := t.provider.Embed(groupCtx, content)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
embedding = vector
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
group.Go(func() error {
|
|
||||||
metaCtx := groupCtx
|
|
||||||
attemptedAt := time.Now().UTC()
|
|
||||||
if t.metadataTimeout > 0 {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
metaCtx, cancel = context.WithTimeout(groupCtx, t.metadataTimeout)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
extracted, err := t.provider.ExtractMetadata(metaCtx, content)
|
|
||||||
if err != nil {
|
|
||||||
t.log.Warn("metadata extraction failed, using fallback", slog.String("provider", t.provider.Name()), slog.String("error", err.Error()))
|
|
||||||
rawMetadata = metadata.MarkMetadataPending(rawMetadata, t.capture, attemptedAt, err)
|
|
||||||
metadataNeedsRetry = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
rawMetadata = metadata.MarkMetadataComplete(extracted, t.capture, attemptedAt)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := group.Wait(); err != nil {
|
|
||||||
return nil, CaptureOutput{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
thought := thoughttypes.Thought{
|
thought := thoughttypes.Thought{
|
||||||
Content: content,
|
Content: content,
|
||||||
Embedding: embedding,
|
Metadata: rawMetadata,
|
||||||
Metadata: metadata.Normalize(metadata.SanitizeExtracted(rawMetadata), t.capture),
|
|
||||||
}
|
}
|
||||||
if project != nil {
|
if project != nil {
|
||||||
thought.ProjectID = &project.ID
|
thought.ProjectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel())
|
created, err := t.store.InsertThought(ctx, thought, t.embeddings.PrimaryModel())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, CaptureOutput{}, err
|
return nil, CaptureOutput{}, err
|
||||||
}
|
}
|
||||||
if project != nil {
|
if project != nil {
|
||||||
_ = t.store.TouchProject(ctx, project.ID)
|
_ = t.store.TouchProject(ctx, project.ID)
|
||||||
}
|
}
|
||||||
if metadataNeedsRetry && t.retryer != nil {
|
|
||||||
|
if t.retryer != nil {
|
||||||
t.retryer.QueueThought(created.ID)
|
t.retryer.QueueThought(created.ID)
|
||||||
}
|
}
|
||||||
|
if t.embedRetryer != nil {
|
||||||
|
t.embedRetryer.QueueThought(ctx, created.ID, content)
|
||||||
|
}
|
||||||
|
|
||||||
return nil, CaptureOutput{Thought: created}, nil
|
return nil, CaptureOutput{Thought: created}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
|
|
||||||
type ContextTool struct {
|
type ContextTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
search config.SearchConfig
|
search config.SearchConfig
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
}
|
}
|
||||||
@@ -41,8 +41,8 @@ type ProjectContextOutput struct {
|
|||||||
Items []ContextItem `json:"items"`
|
Items []ContextItem `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
|
func NewContextTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
|
||||||
return &ContextTool{store: db, provider: provider, search: search, sessions: sessions}
|
return &ContextTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) {
|
func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) {
|
||||||
@@ -72,7 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
|
|||||||
|
|
||||||
query := strings.TrimSpace(in.Query)
|
query := strings.TrimSpace(in.Query)
|
||||||
if query != "" {
|
if query != "" {
|
||||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
|
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ProjectContextOutput{}, err
|
return nil, ProjectContextOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
227
internal/tools/enrichment_retry.go
Normal file
227
internal/tools/enrichment_retry.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/metadata"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/session"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/store"
|
||||||
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const enrichmentRetryConcurrency = 4
|
||||||
|
const enrichmentRetryMaxAttempts = 5
|
||||||
|
|
||||||
|
var enrichmentRetryBackoff = []time.Duration{
|
||||||
|
30 * time.Second,
|
||||||
|
2 * time.Minute,
|
||||||
|
10 * time.Minute,
|
||||||
|
30 * time.Minute,
|
||||||
|
2 * time.Hour,
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnrichmentRetryer struct {
|
||||||
|
backgroundCtx context.Context
|
||||||
|
store *store.DB
|
||||||
|
metadata *ai.MetadataRunner
|
||||||
|
capture config.CaptureConfig
|
||||||
|
sessions *session.ActiveProjects
|
||||||
|
metadataTimeout time.Duration
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetryEnrichmentTool struct {
|
||||||
|
retryer *EnrichmentRetryer
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetryEnrichmentInput struct {
|
||||||
|
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the retry"`
|
||||||
|
Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to process in one call; defaults to 100"`
|
||||||
|
IncludeArchived bool `json:"include_archived,omitempty" jsonschema:"whether to include archived thoughts; defaults to false"`
|
||||||
|
OlderThanDays int `json:"older_than_days,omitempty" jsonschema:"only retry thoughts whose last metadata attempt was at least N days ago; 0 means no restriction"`
|
||||||
|
DryRun bool `json:"dry_run,omitempty" jsonschema:"report counts without retrying metadata extraction"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetryEnrichmentFailure struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetryEnrichmentOutput struct {
|
||||||
|
Scanned int `json:"scanned"`
|
||||||
|
Retried int `json:"retried"`
|
||||||
|
Updated int `json:"updated"`
|
||||||
|
Skipped int `json:"skipped"`
|
||||||
|
Failed int `json:"failed"`
|
||||||
|
DryRun bool `json:"dry_run"`
|
||||||
|
Failures []RetryEnrichmentFailure `json:"failures,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer {
|
||||||
|
if backgroundCtx == nil {
|
||||||
|
backgroundCtx = context.Background()
|
||||||
|
}
|
||||||
|
return &EnrichmentRetryer{
|
||||||
|
backgroundCtx: backgroundCtx,
|
||||||
|
store: db,
|
||||||
|
metadata: metadataRunner,
|
||||||
|
capture: capture,
|
||||||
|
sessions: sessions,
|
||||||
|
metadataTimeout: metadataTimeout,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRetryEnrichmentTool(retryer *EnrichmentRetryer) *RetryEnrichmentTool {
|
||||||
|
return &RetryEnrichmentTool{retryer: retryer}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *RetryEnrichmentTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RetryEnrichmentInput) (*mcp.CallToolResult, RetryEnrichmentOutput, error) {
|
||||||
|
return t.retryer.Handle(ctx, req, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EnrichmentRetryer) QueueThought(id uuid.UUID) {
|
||||||
|
go func() {
|
||||||
|
started := time.Now()
|
||||||
|
r.logger.Info("background metadata started",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||||
|
slog.String("model", r.metadata.PrimaryModel()),
|
||||||
|
)
|
||||||
|
updated, err := r.retryOne(r.backgroundCtx, id)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Warn("background metadata error",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||||
|
slog.String("model", r.metadata.PrimaryModel()),
|
||||||
|
slog.Duration("duration", time.Since(started)),
|
||||||
|
slog.String("error", err.Error()),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.logger.Info("background metadata complete",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||||
|
slog.String("model", r.metadata.PrimaryModel()),
|
||||||
|
slog.Bool("updated", updated),
|
||||||
|
slog.Duration("duration", time.Since(started)),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EnrichmentRetryer) Handle(ctx context.Context, req *mcp.CallToolRequest, in RetryEnrichmentInput) (*mcp.CallToolResult, RetryEnrichmentOutput, error) {
|
||||||
|
limit := in.Limit
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
project, err := resolveProject(ctx, r.store, r.sessions, req, in.Project, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, RetryEnrichmentOutput{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var projectID *uuid.UUID
|
||||||
|
if project != nil {
|
||||||
|
projectID = &project.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
thoughts, err := r.store.ListThoughtsPendingMetadataRetry(ctx, limit, projectID, in.IncludeArchived, in.OlderThanDays)
|
||||||
|
if err != nil {
|
||||||
|
return nil, RetryEnrichmentOutput{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := RetryEnrichmentOutput{Scanned: len(thoughts), DryRun: in.DryRun}
|
||||||
|
if in.DryRun || len(thoughts) == 0 {
|
||||||
|
return nil, out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sem := semaphore.NewWeighted(enrichmentRetryConcurrency)
|
||||||
|
var mu sync.Mutex
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for _, thought := range thoughts {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := sem.Acquire(ctx, 1); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(thought thoughttypes.Thought) {
|
||||||
|
defer wg.Done()
|
||||||
|
defer sem.Release(1)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
out.Retried++
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
updated, err := r.retryOne(ctx, thought.ID)
|
||||||
|
if err != nil {
|
||||||
|
mu.Lock()
|
||||||
|
out.Failures = append(out.Failures, RetryEnrichmentFailure{ID: thought.ID.String(), Error: err.Error()})
|
||||||
|
mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if updated {
|
||||||
|
mu.Lock()
|
||||||
|
out.Updated++
|
||||||
|
mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
out.Skipped++
|
||||||
|
mu.Unlock()
|
||||||
|
}(thought)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
out.Failed = len(out.Failures)
|
||||||
|
|
||||||
|
return nil, out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EnrichmentRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, error) {
|
||||||
|
thought, err := r.store.GetThought(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if thought.Metadata.MetadataStatus == metadata.MetadataStatusComplete {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
attemptCtx := ctx
|
||||||
|
if r.metadataTimeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
attemptCtx, cancel = context.WithTimeout(ctx, r.metadataTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
attemptedAt := time.Now().UTC()
|
||||||
|
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
|
||||||
|
if extractErr != nil {
|
||||||
|
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
|
||||||
|
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {
|
||||||
|
return false, updateErr
|
||||||
|
}
|
||||||
|
return false, extractErr
|
||||||
|
}
|
||||||
|
|
||||||
|
completedMetadata := metadata.MarkMetadataComplete(metadata.SanitizeExtracted(extracted), r.capture, attemptedAt)
|
||||||
|
completedMetadata.Attachments = thought.Metadata.Attachments
|
||||||
|
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, completedMetadata); updateErr != nil {
|
||||||
|
return false, updateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
174
internal/tools/learnings.go
Normal file
174
internal/tools/learnings.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/session"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/store"
|
||||||
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LearningsTool struct {
|
||||||
|
store *store.DB
|
||||||
|
sessions *session.ActiveProjects
|
||||||
|
cfg config.SearchConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddLearningInput struct {
|
||||||
|
Summary string `json:"summary" jsonschema:"short curated learning summary"`
|
||||||
|
Details string `json:"details,omitempty" jsonschema:"optional detailed learning body"`
|
||||||
|
Category string `json:"category,omitempty"`
|
||||||
|
Area string `json:"area,omitempty"`
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
Priority string `json:"priority,omitempty"`
|
||||||
|
Confidence string `json:"confidence,omitempty"`
|
||||||
|
ActionRequired *bool `json:"action_required,omitempty"`
|
||||||
|
SourceType string `json:"source_type,omitempty"`
|
||||||
|
SourceRef string `json:"source_ref,omitempty"`
|
||||||
|
Project string `json:"project,omitempty" jsonschema:"project name or id; falls back to active session project"`
|
||||||
|
RelatedThoughtID *uuid.UUID `json:"related_thought_id,omitempty"`
|
||||||
|
RelatedSkillID *uuid.UUID `json:"related_skill_id,omitempty"`
|
||||||
|
ReviewedBy *string `json:"reviewed_by,omitempty"`
|
||||||
|
DuplicateOfLearningID *uuid.UUID `json:"duplicate_of_learning_id,omitempty"`
|
||||||
|
SupersedesLearningID *uuid.UUID `json:"supersedes_learning_id,omitempty"`
|
||||||
|
Tags []string `json:"tags,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddLearningOutput struct {
|
||||||
|
Learning thoughttypes.Learning `json:"learning"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetLearningInput struct {
|
||||||
|
ID uuid.UUID `json:"id" jsonschema:"learning id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetLearningOutput struct {
|
||||||
|
Learning thoughttypes.Learning `json:"learning"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListLearningsInput struct {
|
||||||
|
Limit int `json:"limit,omitempty"`
|
||||||
|
Project string `json:"project,omitempty" jsonschema:"project name or id; falls back to active session project"`
|
||||||
|
Category string `json:"category,omitempty"`
|
||||||
|
Area string `json:"area,omitempty"`
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
Priority string `json:"priority,omitempty"`
|
||||||
|
Tag string `json:"tag,omitempty"`
|
||||||
|
Query string `json:"query,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListLearningsOutput struct {
|
||||||
|
Learnings []thoughttypes.Learning `json:"learnings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLearningsTool(db *store.DB, sessions *session.ActiveProjects, cfg config.SearchConfig) *LearningsTool {
|
||||||
|
return &LearningsTool{store: db, sessions: sessions, cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *LearningsTool) Add(ctx context.Context, req *mcp.CallToolRequest, in AddLearningInput) (*mcp.CallToolResult, AddLearningOutput, error) {
|
||||||
|
summary := strings.TrimSpace(in.Summary)
|
||||||
|
if summary == "" {
|
||||||
|
return nil, AddLearningOutput{}, errRequiredField("summary")
|
||||||
|
}
|
||||||
|
|
||||||
|
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, AddLearningOutput{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
learning := thoughttypes.Learning{
|
||||||
|
Summary: summary,
|
||||||
|
Details: strings.TrimSpace(in.Details),
|
||||||
|
Category: defaultString(strings.TrimSpace(in.Category), "insight"),
|
||||||
|
Area: defaultString(strings.TrimSpace(in.Area), "other"),
|
||||||
|
Status: thoughttypes.LearningStatus(defaultString(strings.TrimSpace(in.Status), string(thoughttypes.LearningStatusPending))),
|
||||||
|
Priority: thoughttypes.LearningPriority(defaultString(strings.TrimSpace(in.Priority), string(thoughttypes.LearningPriorityMedium))),
|
||||||
|
Confidence: thoughttypes.LearningEvidenceLevel(defaultString(strings.TrimSpace(in.Confidence), string(thoughttypes.LearningEvidenceHypothesis))),
|
||||||
|
SourceType: strings.TrimSpace(in.SourceType),
|
||||||
|
SourceRef: strings.TrimSpace(in.SourceRef),
|
||||||
|
RelatedThoughtID: in.RelatedThoughtID,
|
||||||
|
RelatedSkillID: in.RelatedSkillID,
|
||||||
|
ReviewedBy: in.ReviewedBy,
|
||||||
|
DuplicateOfLearningID: in.DuplicateOfLearningID,
|
||||||
|
SupersedesLearningID: in.SupersedesLearningID,
|
||||||
|
Tags: normalizeStringSlice(in.Tags),
|
||||||
|
}
|
||||||
|
if in.ActionRequired != nil {
|
||||||
|
learning.ActionRequired = *in.ActionRequired
|
||||||
|
}
|
||||||
|
if project != nil {
|
||||||
|
learning.ProjectID = &project.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := t.store.CreateLearning(ctx, learning)
|
||||||
|
if err != nil {
|
||||||
|
return nil, AddLearningOutput{}, err
|
||||||
|
}
|
||||||
|
return nil, AddLearningOutput{Learning: created}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *LearningsTool) Get(ctx context.Context, _ *mcp.CallToolRequest, in GetLearningInput) (*mcp.CallToolResult, GetLearningOutput, error) {
|
||||||
|
learning, err := t.store.GetLearning(ctx, in.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, GetLearningOutput{}, err
|
||||||
|
}
|
||||||
|
return nil, GetLearningOutput{Learning: learning}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *LearningsTool) List(ctx context.Context, req *mcp.CallToolRequest, in ListLearningsInput) (*mcp.CallToolResult, ListLearningsOutput, error) {
|
||||||
|
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ListLearningsOutput{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := thoughttypes.LearningFilter{
|
||||||
|
Limit: normalizeLimit(in.Limit, t.cfg),
|
||||||
|
Category: strings.TrimSpace(in.Category),
|
||||||
|
Area: strings.TrimSpace(in.Area),
|
||||||
|
Status: strings.TrimSpace(in.Status),
|
||||||
|
Priority: strings.TrimSpace(in.Priority),
|
||||||
|
Tag: strings.TrimSpace(in.Tag),
|
||||||
|
Query: strings.TrimSpace(in.Query),
|
||||||
|
}
|
||||||
|
if project != nil {
|
||||||
|
filter.ProjectID = &project.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
items, err := t.store.ListLearnings(ctx, filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ListLearningsOutput{}, err
|
||||||
|
}
|
||||||
|
return nil, ListLearningsOutput{Learnings: items}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultString(value string, fallback string) string {
|
||||||
|
if value == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeStringSlice(values []string) []string {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(values))
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
for _, value := range values {
|
||||||
|
trimmed := strings.TrimSpace(value)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[trimmed]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[trimmed] = struct{}{}
|
||||||
|
out = append(out, trimmed)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
type LinksTool struct {
|
type LinksTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
search config.SearchConfig
|
search config.SearchConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,8 +47,8 @@ type RelatedOutput struct {
|
|||||||
Related []RelatedThought `json:"related"`
|
Related []RelatedThought `json:"related"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLinksTool(db *store.DB, provider ai.Provider, search config.SearchConfig) *LinksTool {
|
func NewLinksTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig) *LinksTool {
|
||||||
return &LinksTool{store: db, provider: provider, search: search}
|
return &LinksTool{store: db, embeddings: embeddings, search: search}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) {
|
func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) {
|
||||||
@@ -117,7 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeSemantic {
|
if includeSemantic {
|
||||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
|
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RelatedOutput{}, err
|
return nil, RelatedOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,17 +23,47 @@ const metadataRetryConcurrency = 4
|
|||||||
type MetadataRetryer struct {
|
type MetadataRetryer struct {
|
||||||
backgroundCtx context.Context
|
backgroundCtx context.Context
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
metadata *ai.MetadataRunner
|
||||||
capture config.CaptureConfig
|
capture config.CaptureConfig
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
metadataTimeout time.Duration
|
metadataTimeout time.Duration
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
lock *RetryLocker
|
||||||
}
|
}
|
||||||
|
|
||||||
type RetryMetadataTool struct {
|
type RetryMetadataTool struct {
|
||||||
retryer *MetadataRetryer
|
retryer *MetadataRetryer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RetryLocker struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
locks map[uuid.UUID]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRetryLocker() *RetryLocker {
|
||||||
|
return &RetryLocker{locks: map[uuid.UUID]time.Time{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *RetryLocker) Acquire(id uuid.UUID, ttl time.Duration) bool {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
if l.locks == nil {
|
||||||
|
l.locks = map[uuid.UUID]time.Time{}
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if exp, ok := l.locks[id]; ok && exp.After(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
l.locks[id] = now.Add(ttl)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *RetryLocker) Release(id uuid.UUID) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
delete(l.locks, id)
|
||||||
|
}
|
||||||
|
|
||||||
type RetryMetadataInput struct {
|
type RetryMetadataInput struct {
|
||||||
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the retry"`
|
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the retry"`
|
||||||
Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to process in one call; defaults to 100"`
|
Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to process in one call; defaults to 100"`
|
||||||
@@ -57,18 +87,19 @@ type RetryMetadataOutput struct {
|
|||||||
Failures []RetryMetadataFailure `json:"failures,omitempty"`
|
Failures []RetryMetadataFailure `json:"failures,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
|
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
|
||||||
if backgroundCtx == nil {
|
if backgroundCtx == nil {
|
||||||
backgroundCtx = context.Background()
|
backgroundCtx = context.Background()
|
||||||
}
|
}
|
||||||
return &MetadataRetryer{
|
return &MetadataRetryer{
|
||||||
backgroundCtx: backgroundCtx,
|
backgroundCtx: backgroundCtx,
|
||||||
store: db,
|
store: db,
|
||||||
provider: provider,
|
metadata: metadataRunner,
|
||||||
capture: capture,
|
capture: capture,
|
||||||
sessions: sessions,
|
sessions: sessions,
|
||||||
metadataTimeout: metadataTimeout,
|
metadataTimeout: metadataTimeout,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
lock: NewRetryLocker(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,9 +113,35 @@ func (t *RetryMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest
|
|||||||
|
|
||||||
func (r *MetadataRetryer) QueueThought(id uuid.UUID) {
|
func (r *MetadataRetryer) QueueThought(id uuid.UUID) {
|
||||||
go func() {
|
go func() {
|
||||||
if _, err := r.retryOne(r.backgroundCtx, id); err != nil {
|
started := time.Now()
|
||||||
r.logger.Warn("background metadata retry failed", slog.String("thought_id", id.String()), slog.String("error", err.Error()))
|
if !r.lock.Acquire(id, 15*time.Minute) {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
defer r.lock.Release(id)
|
||||||
|
|
||||||
|
r.logger.Info("background metadata started",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||||
|
slog.String("model", r.metadata.PrimaryModel()),
|
||||||
|
)
|
||||||
|
updated, err := r.retryOne(r.backgroundCtx, id)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Warn("background metadata error",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||||
|
slog.String("model", r.metadata.PrimaryModel()),
|
||||||
|
slog.Duration("duration", time.Since(started)),
|
||||||
|
slog.String("error", err.Error()),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.logger.Info("background metadata complete",
|
||||||
|
slog.String("thought_id", id.String()),
|
||||||
|
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||||
|
slog.String("model", r.metadata.PrimaryModel()),
|
||||||
|
slog.Bool("updated", updated),
|
||||||
|
slog.Duration("duration", time.Since(started)),
|
||||||
|
)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +195,14 @@ func (r *MetadataRetryer) Handle(ctx context.Context, req *mcp.CallToolRequest,
|
|||||||
out.Retried++
|
out.Retried++
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
|
|
||||||
|
if !r.lock.Acquire(thought.ID, 15*time.Minute) {
|
||||||
|
mu.Lock()
|
||||||
|
out.Skipped++
|
||||||
|
mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
updated, err := r.retryOne(ctx, thought.ID)
|
updated, err := r.retryOne(ctx, thought.ID)
|
||||||
|
r.lock.Release(thought.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
out.Failures = append(out.Failures, RetryMetadataFailure{ID: thought.ID.String(), Error: err.Error()})
|
out.Failures = append(out.Failures, RetryMetadataFailure{ID: thought.ID.String(), Error: err.Error()})
|
||||||
@@ -181,7 +245,7 @@ func (r *MetadataRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
attemptedAt := time.Now().UTC()
|
attemptedAt := time.Now().UTC()
|
||||||
extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content)
|
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
|
||||||
if extractErr != nil {
|
if extractErr != nil {
|
||||||
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
|
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
|
||||||
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {
|
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
|
|
||||||
type RecallTool struct {
|
type RecallTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
search config.SearchConfig
|
search config.SearchConfig
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
}
|
}
|
||||||
@@ -32,8 +32,8 @@ type RecallOutput struct {
|
|||||||
Items []ContextItem `json:"items"`
|
Items []ContextItem `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
|
func NewRecallTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
|
||||||
return &RecallTool{store: db, provider: provider, search: search, sessions: sessions}
|
return &RecallTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
||||||
@@ -54,7 +54,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
|
|||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RecallOutput{}, err
|
return nil, RecallOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ const metadataReparseConcurrency = 4
|
|||||||
|
|
||||||
type ReparseMetadataTool struct {
|
type ReparseMetadataTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
metadata *ai.MetadataRunner
|
||||||
capture config.CaptureConfig
|
capture config.CaptureConfig
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
@@ -53,8 +53,8 @@ type ReparseMetadataOutput struct {
|
|||||||
Failures []ReparseMetadataFailure `json:"failures,omitempty"`
|
Failures []ReparseMetadataFailure `json:"failures,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReparseMetadataTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
|
func NewReparseMetadataTool(db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
|
||||||
return &ReparseMetadataTool{store: db, provider: provider, capture: capture, sessions: sessions, logger: logger}
|
return &ReparseMetadataTool{store: db, metadata: metadataRunner, capture: capture, sessions: sessions, logger: logger}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ReparseMetadataInput) (*mcp.CallToolResult, ReparseMetadataOutput, error) {
|
func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ReparseMetadataInput) (*mcp.CallToolResult, ReparseMetadataOutput, error) {
|
||||||
@@ -107,7 +107,7 @@ func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolReque
|
|||||||
normalizedCurrent := metadata.Normalize(thought.Metadata, t.capture)
|
normalizedCurrent := metadata.Normalize(thought.Metadata, t.capture)
|
||||||
|
|
||||||
attemptedAt := time.Now().UTC()
|
attemptedAt := time.Now().UTC()
|
||||||
extracted, extractErr := t.provider.ExtractMetadata(ctx, thought.Content)
|
extracted, extractErr := t.metadata.ExtractMetadata(ctx, thought.Content)
|
||||||
normalizedTarget := normalizedCurrent
|
normalizedTarget := normalizedCurrent
|
||||||
if extractErr != nil {
|
if extractErr != nil {
|
||||||
normalizedTarget = metadata.MarkMetadataFailed(normalizedCurrent, t.capture, attemptedAt, extractErr)
|
normalizedTarget = metadata.MarkMetadataFailed(normalizedCurrent, t.capture, attemptedAt, extractErr)
|
||||||
|
|||||||
@@ -11,12 +11,14 @@ import (
|
|||||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// semanticSearch runs vector similarity search if embeddings exist for the active model
|
// semanticSearch runs vector similarity search if embeddings exist for the
|
||||||
// in the given scope, otherwise falls back to Postgres full-text search.
|
// primary embedding model in the given scope, otherwise falls back to Postgres
|
||||||
|
// full-text search. Search always uses the primary model so query vectors
|
||||||
|
// match rows stored under the primary model name.
|
||||||
func semanticSearch(
|
func semanticSearch(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db *store.DB,
|
db *store.DB,
|
||||||
provider ai.Provider,
|
embeddings *ai.EmbeddingRunner,
|
||||||
search config.SearchConfig,
|
search config.SearchConfig,
|
||||||
query string,
|
query string,
|
||||||
limit int,
|
limit int,
|
||||||
@@ -24,17 +26,18 @@ func semanticSearch(
|
|||||||
projectID *uuid.UUID,
|
projectID *uuid.UUID,
|
||||||
excludeID *uuid.UUID,
|
excludeID *uuid.UUID,
|
||||||
) ([]thoughttypes.SearchResult, error) {
|
) ([]thoughttypes.SearchResult, error) {
|
||||||
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID)
|
model := embeddings.PrimaryModel()
|
||||||
|
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, model, projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasEmbeddings {
|
if hasEmbeddings {
|
||||||
embedding, err := provider.Embed(ctx, query)
|
embedding, err := embeddings.EmbedPrimary(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID)
|
return db.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, projectID, excludeID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID)
|
return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
|
|
||||||
type SearchTool struct {
|
type SearchTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
search config.SearchConfig
|
search config.SearchConfig
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
}
|
}
|
||||||
@@ -32,8 +32,8 @@ type SearchOutput struct {
|
|||||||
Results []thoughttypes.SearchResult `json:"results"`
|
Results []thoughttypes.SearchResult `json:"results"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
|
func NewSearchTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
|
||||||
return &SearchTool{store: db, provider: provider, search: search, sessions: sessions}
|
return &SearchTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
|
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
|
||||||
@@ -56,7 +56,7 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
|
|||||||
_ = t.store.TouchProject(ctx, project.ID)
|
_ = t.store.TouchProject(ctx, project.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil)
|
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, threshold, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SearchOutput{}, err
|
return nil, SearchOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ import (
|
|||||||
|
|
||||||
type SummarizeTool struct {
|
type SummarizeTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
|
metadata *ai.MetadataRunner
|
||||||
search config.SearchConfig
|
search config.SearchConfig
|
||||||
sessions *session.ActiveProjects
|
sessions *session.ActiveProjects
|
||||||
}
|
}
|
||||||
@@ -32,8 +33,8 @@ type SummarizeOutput struct {
|
|||||||
Count int `json:"count"`
|
Count int `json:"count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSummarizeTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
|
func NewSummarizeTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
|
||||||
return &SummarizeTool{store: db, provider: provider, search: search, sessions: sessions}
|
return &SummarizeTool{store: db, embeddings: embeddings, metadata: metadata, search: search, sessions: sessions}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) {
|
func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) {
|
||||||
@@ -52,7 +53,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
if project != nil {
|
if project != nil {
|
||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SummarizeOutput{}, err
|
return nil, SummarizeOutput{}, err
|
||||||
}
|
}
|
||||||
@@ -77,7 +78,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
|
|
||||||
userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines)
|
userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines)
|
||||||
systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose."
|
systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose."
|
||||||
summary, err := t.provider.Summarize(ctx, systemPrompt, userPrompt)
|
summary, err := t.metadata.Summarize(ctx, systemPrompt, userPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SummarizeOutput{}, err
|
return nil, SummarizeOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ import (
|
|||||||
|
|
||||||
type UpdateTool struct {
|
type UpdateTool struct {
|
||||||
store *store.DB
|
store *store.DB
|
||||||
provider ai.Provider
|
embeddings *ai.EmbeddingRunner
|
||||||
|
metadata *ai.MetadataRunner
|
||||||
capture config.CaptureConfig
|
capture config.CaptureConfig
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
}
|
}
|
||||||
@@ -33,8 +34,8 @@ type UpdateOutput struct {
|
|||||||
Thought thoughttypes.Thought `json:"thought"`
|
Thought thoughttypes.Thought `json:"thought"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUpdateTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
|
func NewUpdateTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
|
||||||
return &UpdateTool{store: db, provider: provider, capture: capture, log: log}
|
return &UpdateTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, log: log}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) {
|
func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) {
|
||||||
@@ -50,6 +51,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
|||||||
|
|
||||||
content := current.Content
|
content := current.Content
|
||||||
var embedding []float32
|
var embedding []float32
|
||||||
|
embeddingModel := ""
|
||||||
mergedMetadata := current.Metadata
|
mergedMetadata := current.Metadata
|
||||||
projectID := current.ProjectID
|
projectID := current.ProjectID
|
||||||
|
|
||||||
@@ -58,11 +60,13 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
|||||||
if content == "" {
|
if content == "" {
|
||||||
return nil, UpdateOutput{}, errInvalidInput("content must not be empty")
|
return nil, UpdateOutput{}, errInvalidInput("content must not be empty")
|
||||||
}
|
}
|
||||||
embedding, err = t.provider.Embed(ctx, content)
|
embedResult, err := t.embeddings.Embed(ctx, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, UpdateOutput{}, err
|
return nil, UpdateOutput{}, err
|
||||||
}
|
}
|
||||||
extracted, extractErr := t.provider.ExtractMetadata(ctx, content)
|
embedding = embedResult.Vector
|
||||||
|
embeddingModel = embedResult.Model
|
||||||
|
extracted, extractErr := t.metadata.ExtractMetadata(ctx, content)
|
||||||
if extractErr != nil {
|
if extractErr != nil {
|
||||||
t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error()))
|
t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error()))
|
||||||
mergedMetadata = metadata.MarkMetadataFailed(mergedMetadata, t.capture, time.Now().UTC(), extractErr)
|
mergedMetadata = metadata.MarkMetadataFailed(mergedMetadata, t.capture, time.Now().UTC(), extractErr)
|
||||||
@@ -82,7 +86,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
|||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID)
|
updated, err := t.store.UpdateThought(ctx, id, content, embedding, embeddingModel, mergedMetadata, projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, UpdateOutput{}, err
|
return nil, UpdateOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
68
internal/types/learning.go
Normal file
68
internal/types/learning.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LearningEvidenceLevel string
|
||||||
|
|
||||||
|
const (
|
||||||
|
LearningEvidenceHypothesis LearningEvidenceLevel = "hypothesis"
|
||||||
|
LearningEvidenceObserved LearningEvidenceLevel = "observed"
|
||||||
|
LearningEvidenceVerified LearningEvidenceLevel = "verified"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LearningStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
LearningStatusPending LearningStatus = "pending"
|
||||||
|
LearningStatusInProgress LearningStatus = "in_progress"
|
||||||
|
LearningStatusResolved LearningStatus = "resolved"
|
||||||
|
LearningStatusWontFix LearningStatus = "wont_fix"
|
||||||
|
LearningStatusPromoted LearningStatus = "promoted"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LearningPriority string
|
||||||
|
|
||||||
|
const (
|
||||||
|
LearningPriorityLow LearningPriority = "low"
|
||||||
|
LearningPriorityMedium LearningPriority = "medium"
|
||||||
|
LearningPriorityHigh LearningPriority = "high"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Learning struct {
|
||||||
|
ID uuid.UUID `json:"id"`
|
||||||
|
Summary string `json:"summary"`
|
||||||
|
Details string `json:"details"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Area string `json:"area"`
|
||||||
|
Status LearningStatus `json:"status"`
|
||||||
|
Priority LearningPriority `json:"priority"`
|
||||||
|
Confidence LearningEvidenceLevel `json:"confidence"`
|
||||||
|
ActionRequired bool `json:"action_required"`
|
||||||
|
SourceType string `json:"source_type,omitempty"`
|
||||||
|
SourceRef string `json:"source_ref,omitempty"`
|
||||||
|
ProjectID *uuid.UUID `json:"project_id,omitempty"`
|
||||||
|
RelatedThoughtID *uuid.UUID `json:"related_thought_id,omitempty"`
|
||||||
|
RelatedSkillID *uuid.UUID `json:"related_skill_id,omitempty"`
|
||||||
|
ReviewedBy *string `json:"reviewed_by,omitempty"`
|
||||||
|
ReviewedAt *time.Time `json:"reviewed_at,omitempty"`
|
||||||
|
DuplicateOfLearningID *uuid.UUID `json:"duplicate_of_learning_id,omitempty"`
|
||||||
|
SupersedesLearningID *uuid.UUID `json:"supersedes_learning_id,omitempty"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LearningFilter struct {
|
||||||
|
Limit int
|
||||||
|
ProjectID *uuid.UUID
|
||||||
|
Category string
|
||||||
|
Area string
|
||||||
|
Status string
|
||||||
|
Priority string
|
||||||
|
Tag string
|
||||||
|
Query string
|
||||||
|
}
|
||||||
@@ -55,6 +55,7 @@ type Thought struct {
|
|||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Embedding []float32 `json:"embedding,omitempty"`
|
Embedding []float32 `json:"embedding,omitempty"`
|
||||||
|
EmbeddingStatus string `json:"embedding_status,omitempty"`
|
||||||
Metadata ThoughtMetadata `json:"metadata"`
|
Metadata ThoughtMetadata `json:"metadata"`
|
||||||
ProjectID *uuid.UUID `json:"project_id,omitempty"`
|
ProjectID *uuid.UUID `json:"project_id,omitempty"`
|
||||||
ArchivedAt *time.Time `json:"archived_at,omitempty"`
|
ArchivedAt *time.Time `json:"archived_at,omitempty"`
|
||||||
|
|||||||
@@ -275,6 +275,30 @@ CREATE TABLE IF NOT EXISTS public.tool_annotations (
|
|||||||
updated_at timestamptz NOT NULL DEFAULT now()
|
updated_at timestamptz NOT NULL DEFAULT now()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS public.learnings (
|
||||||
|
action_required boolean NOT NULL DEFAULT false,
|
||||||
|
area text NOT NULL DEFAULT 'other',
|
||||||
|
category text NOT NULL DEFAULT 'insight',
|
||||||
|
confidence text NOT NULL DEFAULT 'hypothesis',
|
||||||
|
created_at timestamptz NOT NULL DEFAULT now(),
|
||||||
|
details text NOT NULL DEFAULT '',
|
||||||
|
duplicate_of_learning_id uuid,
|
||||||
|
id uuid NOT NULL DEFAULT gen_random_uuid(),
|
||||||
|
priority text NOT NULL DEFAULT 'medium',
|
||||||
|
project_id uuid,
|
||||||
|
related_skill_id uuid,
|
||||||
|
related_thought_id uuid,
|
||||||
|
reviewed_at timestamptz,
|
||||||
|
reviewed_by text,
|
||||||
|
source_ref text,
|
||||||
|
source_type text,
|
||||||
|
status text NOT NULL DEFAULT 'pending',
|
||||||
|
summary text NOT NULL,
|
||||||
|
supersedes_learning_id uuid,
|
||||||
|
tags text,
|
||||||
|
updated_at timestamptz NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS public.agent_skills (
|
CREATE TABLE IF NOT EXISTS public.agent_skills (
|
||||||
content text NOT NULL,
|
content text NOT NULL,
|
||||||
created_at timestamptz NOT NULL DEFAULT now(),
|
created_at timestamptz NOT NULL DEFAULT now(),
|
||||||
@@ -2597,6 +2621,279 @@ BEGIN
|
|||||||
END;
|
END;
|
||||||
$$;
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'action_required'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN action_required boolean NOT NULL DEFAULT false;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'area'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN area text NOT NULL DEFAULT 'other';
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'category'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN category text NOT NULL DEFAULT 'insight';
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'confidence'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN confidence text NOT NULL DEFAULT 'hypothesis';
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'created_at'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN created_at timestamptz NOT NULL DEFAULT now();
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'details'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN details text NOT NULL DEFAULT '';
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'duplicate_of_learning_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN duplicate_of_learning_id uuid;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN id uuid NOT NULL DEFAULT gen_random_uuid();
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'priority'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN priority text NOT NULL DEFAULT 'medium';
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'project_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN project_id uuid;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'related_skill_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN related_skill_id uuid;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'related_thought_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN related_thought_id uuid;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'reviewed_at'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN reviewed_at timestamptz;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'reviewed_by'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN reviewed_by text;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'source_ref'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN source_ref text;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'source_type'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN source_type text;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'status'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN status text NOT NULL DEFAULT 'pending';
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'summary'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN summary text NOT NULL;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'supersedes_learning_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN supersedes_learning_id uuid;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'tags'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN tags text;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND column_name = 'updated_at'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD COLUMN updated_at timestamptz NOT NULL DEFAULT now();
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
DO $$
|
DO $$
|
||||||
BEGIN
|
BEGIN
|
||||||
IF NOT EXISTS (
|
IF NOT EXISTS (
|
||||||
@@ -3403,6 +3700,34 @@ BEGIN
|
|||||||
END;
|
END;
|
||||||
$$;
|
$$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
auto_pk_name text;
|
||||||
|
BEGIN
|
||||||
|
-- Drop auto-generated primary key if it exists
|
||||||
|
SELECT constraint_name INTO auto_pk_name
|
||||||
|
FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND constraint_type = 'PRIMARY KEY'
|
||||||
|
AND constraint_name IN ('learnings_pkey', 'public_learnings_pkey');
|
||||||
|
|
||||||
|
IF auto_pk_name IS NOT NULL THEN
|
||||||
|
EXECUTE 'ALTER TABLE public.learnings DROP CONSTRAINT ' || quote_ident(auto_pk_name);
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Add named primary key if it doesn't exist
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND constraint_name = 'pk_public_learnings'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings ADD CONSTRAINT pk_public_learnings PRIMARY KEY (id);
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
DO $$
|
DO $$
|
||||||
DECLARE
|
DECLARE
|
||||||
auto_pk_name text;
|
auto_pk_name text;
|
||||||
@@ -3475,6 +3800,15 @@ CREATE INDEX IF NOT EXISTS idx_contact_interactions_contact_id_occurred_at
|
|||||||
CREATE INDEX IF NOT EXISTS idx_maintenance_logs_task_id_completed_at
|
CREATE INDEX IF NOT EXISTS idx_maintenance_logs_task_id_completed_at
|
||||||
ON public.maintenance_logs USING btree (task_id, completed_at);
|
ON public.maintenance_logs USING btree (task_id, completed_at);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_learnings_details
|
||||||
|
ON public.learnings USING gin (details gin_trgm_ops);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_learnings_summary
|
||||||
|
ON public.learnings USING gin (summary gin_trgm_ops);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_learnings_tags
|
||||||
|
ON public.learnings USING gin (tags gin_trgm_ops);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_project_skills_project_id_skill_id
|
CREATE INDEX IF NOT EXISTS idx_project_skills_project_id_skill_id
|
||||||
ON public.project_skills USING btree (project_id, skill_id);
|
ON public.project_skills USING btree (project_id, skill_id);
|
||||||
|
|
||||||
@@ -3810,6 +4144,86 @@ BEGIN
|
|||||||
END IF;
|
END IF;
|
||||||
END;
|
END;
|
||||||
$$;DO $$
|
$$;DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND constraint_name = 'fk_learnings_duplicate_of_learning_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings
|
||||||
|
ADD CONSTRAINT fk_learnings_duplicate_of_learning_id
|
||||||
|
FOREIGN KEY (duplicate_of_learning_id)
|
||||||
|
REFERENCES public.learnings (id)
|
||||||
|
ON DELETE NO ACTION
|
||||||
|
ON UPDATE NO ACTION;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND constraint_name = 'fk_learnings_project_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings
|
||||||
|
ADD CONSTRAINT fk_learnings_project_id
|
||||||
|
FOREIGN KEY (project_id)
|
||||||
|
REFERENCES public.projects (guid)
|
||||||
|
ON DELETE NO ACTION
|
||||||
|
ON UPDATE NO ACTION;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND constraint_name = 'fk_learnings_related_skill_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings
|
||||||
|
ADD CONSTRAINT fk_learnings_related_skill_id
|
||||||
|
FOREIGN KEY (related_skill_id)
|
||||||
|
REFERENCES public.agent_skills (id)
|
||||||
|
ON DELETE NO ACTION
|
||||||
|
ON UPDATE NO ACTION;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND constraint_name = 'fk_learnings_related_thought_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings
|
||||||
|
ADD CONSTRAINT fk_learnings_related_thought_id
|
||||||
|
FOREIGN KEY (related_thought_id)
|
||||||
|
REFERENCES public.thoughts (guid)
|
||||||
|
ON DELETE NO ACTION
|
||||||
|
ON UPDATE NO ACTION;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
AND table_name = 'learnings'
|
||||||
|
AND constraint_name = 'fk_learnings_supersedes_learning_id'
|
||||||
|
) THEN
|
||||||
|
ALTER TABLE public.learnings
|
||||||
|
ADD CONSTRAINT fk_learnings_supersedes_learning_id
|
||||||
|
FOREIGN KEY (supersedes_learning_id)
|
||||||
|
REFERENCES public.learnings (id)
|
||||||
|
ON DELETE NO ACTION
|
||||||
|
ON UPDATE NO ACTION;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;DO $$
|
||||||
BEGIN
|
BEGIN
|
||||||
IF NOT EXISTS (
|
IF NOT EXISTS (
|
||||||
SELECT 1 FROM information_schema.table_constraints
|
SELECT 1 FROM information_schema.table_constraints
|
||||||
@@ -3992,5 +4406,6 @@ $$;
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,3 +30,46 @@ Table tool_annotations {
|
|||||||
|
|
||||||
// Cross-file refs (for relspecgo merge)
|
// Cross-file refs (for relspecgo merge)
|
||||||
Ref: chat_histories.project_id > projects.guid [delete: set null]
|
Ref: chat_histories.project_id > projects.guid [delete: set null]
|
||||||
|
|
||||||
|
Table learnings {
|
||||||
|
id uuid [pk, default: `gen_random_uuid()`]
|
||||||
|
summary text [not null]
|
||||||
|
details text [not null, default: '']
|
||||||
|
category text [not null, default: 'insight']
|
||||||
|
area text [not null, default: 'other']
|
||||||
|
status text [not null, default: 'pending']
|
||||||
|
priority text [not null, default: 'medium']
|
||||||
|
confidence text [not null, default: 'hypothesis']
|
||||||
|
action_required boolean [not null, default: false]
|
||||||
|
source_type text
|
||||||
|
source_ref text
|
||||||
|
project_id uuid [ref: > projects.guid]
|
||||||
|
related_thought_id uuid [ref: > thoughts.guid]
|
||||||
|
related_skill_id uuid [ref: > agent_skills.id]
|
||||||
|
reviewed_by text
|
||||||
|
reviewed_at timestamptz
|
||||||
|
duplicate_of_learning_id uuid [ref: > learnings.id]
|
||||||
|
supersedes_learning_id uuid [ref: > learnings.id]
|
||||||
|
tags "text[]" [not null, default: `'{}'`]
|
||||||
|
created_at timestamptz [not null, default: `now()`]
|
||||||
|
updated_at timestamptz [not null, default: `now()`]
|
||||||
|
|
||||||
|
indexes {
|
||||||
|
project_id
|
||||||
|
category
|
||||||
|
area
|
||||||
|
status
|
||||||
|
priority
|
||||||
|
reviewed_at
|
||||||
|
tags [type: gin]
|
||||||
|
summary [type: gin]
|
||||||
|
details [type: gin]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cross-file refs (for relspecgo merge)
|
||||||
|
Ref: learnings.project_id > projects.guid [delete: set null]
|
||||||
|
Ref: learnings.related_thought_id > thoughts.guid [delete: set null]
|
||||||
|
Ref: learnings.related_skill_id > agent_skills.id [delete: set null]
|
||||||
|
Ref: learnings.duplicate_of_learning_id > learnings.id [delete: set null]
|
||||||
|
Ref: learnings.supersedes_learning_id > learnings.id [delete: set null]
|
||||||
|
|||||||
@@ -19,5 +19,14 @@
|
|||||||
"tailwindcss": "^4.1.4",
|
"tailwindcss": "^4.1.4",
|
||||||
"typescript": "^5.8.3",
|
"typescript": "^5.8.3",
|
||||||
"vite": "^6.3.2"
|
"vite": "^6.3.2"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@sentry/svelte": "^10.49.0",
|
||||||
|
"@skeletonlabs/skeleton": "^4.15.2",
|
||||||
|
"@skeletonlabs/skeleton-svelte": "^4.15.2",
|
||||||
|
"@tanstack/svelte-virtual": "^3.13.24",
|
||||||
|
"@warkypublic/artemis-kit": "file:../../artemis-kit",
|
||||||
|
"@warkypublic/resolvespec-js": "^1.0.1",
|
||||||
|
"@warkypublic/svelix": "^0.1.31"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
3150
ui/pnpm-lock.yaml
generated
3150
ui/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,13 @@
|
|||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { onMount } from "svelte";
|
import { onMount } from 'svelte';
|
||||||
|
import { getApiURL } from '@warkypublic/svelix';
|
||||||
|
import {
|
||||||
|
buildOAuthAuthorizationURL,
|
||||||
|
ensureApiURL,
|
||||||
|
exchangeOAuthCode,
|
||||||
|
GlobalStateStore,
|
||||||
|
setCurrentPath
|
||||||
|
} from './shellState';
|
||||||
|
|
||||||
type AccessEntry = {
|
type AccessEntry = {
|
||||||
key_id: string;
|
key_id: string;
|
||||||
@@ -22,205 +30,367 @@
|
|||||||
entries: AccessEntry[];
|
entries: AccessEntry[];
|
||||||
};
|
};
|
||||||
|
|
||||||
let data: StatusResponse | null = null;
|
type NavItem = {
|
||||||
let loading = true;
|
id: string;
|
||||||
let error = "";
|
label: string;
|
||||||
|
description: string;
|
||||||
|
disabled?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
const quickLinks = [
|
const navItems: NavItem[] = [
|
||||||
{ href: "/llm", label: "LLM Instructions" },
|
{
|
||||||
{ href: "/healthz", label: "Health Check" },
|
id: 'dashboard',
|
||||||
{ href: "/readyz", label: "Readiness Check" },
|
label: 'Dashboard',
|
||||||
|
description: 'System overview and status snapshots.'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'projects',
|
||||||
|
label: 'Projects',
|
||||||
|
description: 'First management module for AMCS projects.'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'thoughts',
|
||||||
|
label: 'Thoughts',
|
||||||
|
description: 'Thought management arrives after projects.',
|
||||||
|
disabled: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'files',
|
||||||
|
label: 'Files',
|
||||||
|
description: 'File inventory and attachment views.',
|
||||||
|
disabled: true
|
||||||
|
}
|
||||||
];
|
];
|
||||||
|
|
||||||
async function loadStatus() {
|
let authMessage = $state('');
|
||||||
loading = true;
|
let authError = $state('');
|
||||||
error = "";
|
let authBusy = $state(false);
|
||||||
|
let callbackBusy = $state(false);
|
||||||
|
let data = $state<StatusResponse | null>(null);
|
||||||
|
let loading = $state(false);
|
||||||
|
let error = $state('');
|
||||||
|
let currentPage = $state<'dashboard' | 'projects'>('dashboard');
|
||||||
|
|
||||||
|
ensureApiURL(import.meta.env.VITE_API_URL);
|
||||||
|
|
||||||
|
const isLoggedIn = $derived(GlobalStateStore.isLoggedIn());
|
||||||
|
const currentPath = $derived(typeof window !== 'undefined' ? window.location.pathname : '/');
|
||||||
|
const isOAuthCallback = $derived(currentPath === '/oauth/callback');
|
||||||
|
const oauthAuthorizeURL = $derived(`${getApiURL()}/oauth/authorize`);
|
||||||
|
|
||||||
|
async function startOAuthLogin(): Promise<void> {
|
||||||
|
authBusy = true;
|
||||||
|
authError = '';
|
||||||
|
authMessage = '';
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch("/api/status");
|
const authorizationURL = await buildOAuthAuthorizationURL();
|
||||||
|
window.location.assign(authorizationURL);
|
||||||
|
} catch (err) {
|
||||||
|
authError = err instanceof Error ? err.message : 'Failed to start OAuth login.';
|
||||||
|
} finally {
|
||||||
|
authBusy = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function finishOAuthLogin(): Promise<void> {
|
||||||
|
callbackBusy = true;
|
||||||
|
authError = '';
|
||||||
|
authMessage = '';
|
||||||
|
|
||||||
|
try {
|
||||||
|
const params = new URLSearchParams(window.location.search);
|
||||||
|
const code = params.get('code');
|
||||||
|
const returnedState = params.get('state');
|
||||||
|
const oauthError = params.get('error');
|
||||||
|
|
||||||
|
if (oauthError) {
|
||||||
|
throw new Error(`OAuth login failed: ${oauthError}`);
|
||||||
|
}
|
||||||
|
if (!code || !returnedState) {
|
||||||
|
throw new Error('OAuth callback is missing code or state.');
|
||||||
|
}
|
||||||
|
|
||||||
|
const token = await exchangeOAuthCode(code, returnedState);
|
||||||
|
await GlobalStateStore.getState().login(token, {
|
||||||
|
username: 'OAuth operator'
|
||||||
|
});
|
||||||
|
|
||||||
|
authMessage = 'OAuth login complete. Welcome back.';
|
||||||
|
window.history.replaceState({}, '', '/');
|
||||||
|
await loadStatus();
|
||||||
|
} catch (err) {
|
||||||
|
authError = err instanceof Error ? err.message : 'OAuth callback failed.';
|
||||||
|
} finally {
|
||||||
|
callbackBusy = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function logout(): Promise<void> {
|
||||||
|
await GlobalStateStore.getState().logout();
|
||||||
|
authMessage = 'Logged out.';
|
||||||
|
authError = '';
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadStatus(): Promise<void> {
|
||||||
|
loading = true;
|
||||||
|
error = '';
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/status');
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(`Status request failed with ${response.status}`);
|
throw new Error(`Status request failed with ${response.status}`);
|
||||||
}
|
}
|
||||||
data = (await response.json()) as StatusResponse;
|
data = (await response.json()) as StatusResponse;
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
error = err instanceof Error ? err.message : "Failed to load status";
|
error = err instanceof Error ? err.message : 'Failed to load status';
|
||||||
} finally {
|
} finally {
|
||||||
loading = false;
|
loading = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function formatDate(value: string) {
|
function formatDate(value: string): string {
|
||||||
return new Date(value).toLocaleString();
|
return new Date(value).toLocaleString();
|
||||||
}
|
}
|
||||||
|
|
||||||
onMount(loadStatus);
|
onMount(async () => {
|
||||||
|
if (typeof window !== 'undefined') {
|
||||||
|
setCurrentPath(window.location.pathname);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isOAuthCallback) {
|
||||||
|
await finishOAuthLogin();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isLoggedIn) {
|
||||||
|
await loadStatus();
|
||||||
|
}
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<svelte:head>
|
<svelte:head>
|
||||||
<title>AMCS</title>
|
<title>AMCS Admin</title>
|
||||||
</svelte:head>
|
</svelte:head>
|
||||||
|
|
||||||
<div class="min-h-screen bg-slate-950 text-slate-100">
|
<div class="min-h-screen bg-slate-950 text-slate-100">
|
||||||
<main
|
{#if !isLoggedIn}
|
||||||
class="mx-auto flex min-h-screen max-w-7xl flex-col px-4 py-6 sm:px-6 lg:px-8"
|
<main class="mx-auto flex min-h-screen max-w-6xl items-center px-4 py-10 sm:px-6 lg:px-8">
|
||||||
>
|
<section class="grid w-full gap-8 lg:grid-cols-[1.15fr_0.85fr]">
|
||||||
<section
|
<div class="rounded-3xl border border-cyan-400/20 bg-slate-900/80 p-8 shadow-2xl shadow-slate-950/40">
|
||||||
class="overflow-hidden rounded-3xl border border-white/10 bg-slate-900 shadow-2xl shadow-slate-950/40"
|
<div class="inline-flex items-center gap-2 rounded-full border border-cyan-400/20 bg-cyan-400/10 px-3 py-1 text-sm font-medium text-cyan-200">
|
||||||
>
|
|
||||||
<img
|
|
||||||
src="/images/project.jpg"
|
|
||||||
alt="Avelon Memory Crystal"
|
|
||||||
class="h-64 w-full object-cover object-center sm:h-80"
|
|
||||||
/>
|
|
||||||
|
|
||||||
<div class="grid gap-8 p-6 sm:p-8 lg:grid-cols-[1.6fr_1fr] lg:p-10">
|
|
||||||
<div class="space-y-6">
|
|
||||||
<div class="space-y-4">
|
|
||||||
<div
|
|
||||||
class="inline-flex items-center gap-2 rounded-full border border-cyan-400/20 bg-cyan-400/10 px-3 py-1 text-sm font-medium text-cyan-200"
|
|
||||||
>
|
|
||||||
<span class="h-2 w-2 rounded-full bg-emerald-400"></span>
|
<span class="h-2 w-2 rounded-full bg-emerald-400"></span>
|
||||||
Avalon Memory Crystal Server
|
AMCS Control Interface
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<h1 class="mt-6 text-4xl font-semibold tracking-tight text-white">
|
||||||
<h1
|
{#if isOAuthCallback}
|
||||||
class="text-3xl font-semibold tracking-tight text-white sm:text-4xl"
|
Completing login
|
||||||
>
|
{:else}
|
||||||
Avelon Memory Crystal Server (AMCS)
|
Login
|
||||||
|
{/if}
|
||||||
</h1>
|
</h1>
|
||||||
<p
|
<p class="mt-3 max-w-2xl text-base leading-7 text-slate-300">
|
||||||
class="mt-3 max-w-3xl text-base leading-7 text-slate-300 sm:text-lg"
|
Origin-style operator access for the AMCS admin interface. ResolveSpec OAuth is the front door now,
|
||||||
>
|
not the old login shortcut.
|
||||||
{data?.description ??
|
|
||||||
"AMCS is a memory server that captures, links, and retrieves structured project thoughts for AI assistants using semantic search, summaries, and MCP tools."}
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<div class="mt-8 grid gap-4 sm:grid-cols-2">
|
||||||
|
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
||||||
|
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Primary module</p>
|
||||||
|
<p class="mt-2 text-2xl font-semibold text-white">Projects</p>
|
||||||
|
<p class="mt-2 text-sm text-slate-400">Projects are the first real admin screen in this rollout.</p>
|
||||||
|
</div>
|
||||||
|
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
||||||
|
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">OAuth path</p>
|
||||||
|
<p class="mt-2 text-2xl font-semibold text-white">ResolveSpec</p>
|
||||||
|
<p class="mt-2 text-sm text-slate-400">Client registration, authorize, callback, token exchange.</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="flex flex-wrap gap-3">
|
<div class="rounded-3xl border border-white/10 bg-slate-900 p-6 shadow-xl shadow-slate-950/30 sm:p-8">
|
||||||
{#each quickLinks as link}
|
{#if isOAuthCallback}
|
||||||
<a
|
<h2 class="text-xl font-semibold text-white">Authorizing operator session</h2>
|
||||||
class="inline-flex items-center justify-center rounded-xl border border-cyan-300/20 bg-cyan-400/10 px-4 py-2 text-sm font-semibold text-cyan-100 transition hover:border-cyan-300/40 hover:bg-cyan-400/20"
|
<p class="mt-2 text-sm leading-6 text-slate-400">
|
||||||
href={link.href}>{link.label}</a
|
Finishing the ResolveSpec handshake and exchanging the returned code for an AMCS token.
|
||||||
>
|
</p>
|
||||||
{/each}
|
|
||||||
{#if data?.oauth_enabled}
|
<div class="mt-6 rounded-2xl border border-cyan-400/20 bg-cyan-400/5 px-4 py-6 text-sm text-cyan-100">
|
||||||
<a
|
{#if callbackBusy}
|
||||||
class="inline-flex items-center justify-center rounded-xl border border-violet-300/20 bg-violet-400/10 px-4 py-2 text-sm font-semibold text-violet-100 transition hover:border-violet-300/40 hover:bg-violet-400/20"
|
Working the callback doohickey…
|
||||||
href="/oauth-authorization-server">OAuth Authorization Server</a
|
{:else if authError}
|
||||||
>
|
Callback failed. Fix the route or try the login run again.
|
||||||
|
{:else}
|
||||||
|
Callback processed.
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
{:else}
|
||||||
|
<h2 class="text-xl font-semibold text-white">Operator login</h2>
|
||||||
|
<p class="mt-1 text-sm text-slate-400">Authenticate through AMCS ResolveSpec OAuth endpoints.</p>
|
||||||
|
|
||||||
<div class="grid gap-4 sm:grid-cols-3">
|
<div class="mt-6 space-y-4">
|
||||||
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
<button
|
||||||
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">
|
type="button"
|
||||||
Connected users
|
class="inline-flex w-full items-center justify-center rounded-xl border border-cyan-300/20 bg-cyan-400/10 px-4 py-3 text-sm font-semibold text-cyan-100 transition hover:border-cyan-300/40 hover:bg-cyan-400/20 disabled:cursor-not-allowed disabled:opacity-60"
|
||||||
</p>
|
onclick={startOAuthLogin}
|
||||||
<p class="mt-2 text-3xl font-semibold text-white">
|
disabled={authBusy}
|
||||||
{data?.connected_count ?? "—"}
|
>
|
||||||
</p>
|
{#if authBusy}Starting OAuth login…{:else}Login with ResolveSpec OAuth{/if}
|
||||||
</div>
|
</button>
|
||||||
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
|
||||||
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">
|
<div class="rounded-2xl border border-white/10 bg-white/5 p-4 text-sm text-slate-300">
|
||||||
Known principals
|
<p class="font-semibold text-white">Routes in play</p>
|
||||||
</p>
|
<ul class="mt-3 space-y-2 text-slate-400">
|
||||||
<p class="mt-2 text-3xl font-semibold text-white">
|
<li>• discovery: <code class="text-cyan-100">/api/.well-known/oauth-authorization-server</code></li>
|
||||||
{data?.total_known ?? "—"}
|
<li>• registration: <code class="text-cyan-100">/api/oauth/register</code></li>
|
||||||
</p>
|
<li>• authorize: <code class="text-cyan-100">{oauthAuthorizeURL}</code></li>
|
||||||
</div>
|
<li>• callback: <code class="text-cyan-100">/oauth/callback</code></li>
|
||||||
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
<li>• token: <code class="text-cyan-100">/api/oauth/token</code></li>
|
||||||
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">
|
</ul>
|
||||||
Version
|
|
||||||
</p>
|
|
||||||
<p class="mt-2 break-all text-2xl font-semibold text-white">
|
|
||||||
{data?.version ?? "—"}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<aside
|
{#if authError}
|
||||||
class="space-y-4 rounded-2xl border border-white/10 bg-slate-950/50 p-5"
|
<p class="text-sm text-rose-300">{authError}</p>
|
||||||
>
|
{/if}
|
||||||
<div>
|
{#if authMessage}
|
||||||
<h2 class="text-lg font-semibold text-white">Build details</h2>
|
<p class="text-sm text-emerald-300">{authMessage}</p>
|
||||||
<p class="mt-1 text-sm text-slate-400">The same status info.</p>
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
<dl class="space-y-3 text-sm text-slate-300">
|
{/if}
|
||||||
<div>
|
|
||||||
<dt class="text-slate-500">Build date</dt>
|
|
||||||
<dd class="mt-1 font-medium text-white">
|
|
||||||
{data?.build_date ?? "unknown"}
|
|
||||||
</dd>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<dt class="text-slate-500">Commit</dt>
|
|
||||||
<dd
|
|
||||||
class="mt-1 break-all rounded-lg bg-white/5 px-3 py-2 font-mono text-xs text-cyan-100"
|
|
||||||
>
|
|
||||||
{data?.commit ?? "unknown"}
|
|
||||||
</dd>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<dt class="text-slate-500">Connected window</dt>
|
|
||||||
<dd class="mt-1 font-medium text-white">
|
|
||||||
{data?.connected_window ?? "last 10 minutes"}
|
|
||||||
</dd>
|
|
||||||
</div>
|
|
||||||
</dl>
|
|
||||||
</aside>
|
|
||||||
</div>
|
</div>
|
||||||
</section>
|
</section>
|
||||||
|
</main>
|
||||||
<section
|
{:else}
|
||||||
class="mt-6 rounded-3xl border border-white/10 bg-slate-900/80 p-6 shadow-xl shadow-slate-950/20 sm:p-8"
|
<div class="grid min-h-screen lg:grid-cols-[17rem_1fr]">
|
||||||
>
|
<aside class="border-r border-white/10 bg-slate-900/90 p-6">
|
||||||
<div
|
|
||||||
class="flex flex-col gap-3 sm:flex-row sm:items-end sm:justify-between"
|
|
||||||
>
|
|
||||||
<div>
|
<div>
|
||||||
<h2 class="text-2xl font-semibold text-white">Recent access</h2>
|
<p class="text-xs uppercase tracking-[0.3em] text-cyan-300">AMCS</p>
|
||||||
<p class="mt-1 text-sm text-slate-400">
|
<h1 class="mt-2 text-2xl font-semibold text-white">Admin</h1>
|
||||||
Authenticated principals AMCS has seen recently.
|
<p class="mt-2 text-sm text-slate-400">Origin-style shell, starting with Projects.</p>
|
||||||
</p>
|
</div>
|
||||||
|
|
||||||
|
<nav class="mt-8 space-y-2">
|
||||||
|
{#each navItems as item}
|
||||||
|
<button
|
||||||
|
class={`w-full rounded-2xl border px-4 py-3 text-left transition ${item.disabled ? 'cursor-not-allowed border-white/5 bg-white/[0.02] text-slate-600' : currentPage === item.id ? 'border-cyan-300/30 bg-cyan-400/10 text-cyan-100' : 'border-white/10 bg-white/5 text-slate-200 hover:bg-white/10'}`}
|
||||||
|
disabled={item.disabled}
|
||||||
|
onclick={() => {
|
||||||
|
if (!item.disabled && (item.id === 'dashboard' || item.id === 'projects')) {
|
||||||
|
currentPage = item.id;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div class="text-sm font-semibold">{item.label}</div>
|
||||||
|
<div class="mt-1 text-xs text-slate-400">{item.description}</div>
|
||||||
|
</button>
|
||||||
|
{/each}
|
||||||
|
</nav>
|
||||||
|
|
||||||
|
<button
|
||||||
|
class="mt-8 inline-flex w-full items-center justify-center rounded-xl border border-white/10 bg-white/5 px-4 py-3 text-sm font-medium text-slate-200 transition hover:bg-white/10"
|
||||||
|
onclick={logout}
|
||||||
|
>
|
||||||
|
Logout
|
||||||
|
</button>
|
||||||
|
</aside>
|
||||||
|
|
||||||
|
<main class="px-4 py-6 sm:px-6 lg:px-8">
|
||||||
|
{#if currentPage === 'dashboard'}
|
||||||
|
<section class="rounded-3xl border border-white/10 bg-slate-900/80 p-6 shadow-xl shadow-slate-950/20 sm:p-8">
|
||||||
|
<div class="flex flex-col gap-3 sm:flex-row sm:items-end sm:justify-between">
|
||||||
|
<div>
|
||||||
|
<h2 class="text-2xl font-semibold text-white">System overview</h2>
|
||||||
|
<p class="mt-1 text-sm text-slate-400">Current AMCS status behind the admin shell.</p>
|
||||||
</div>
|
</div>
|
||||||
<button
|
<button
|
||||||
class="inline-flex items-center justify-center rounded-xl border border-white/10 bg-white/5 px-4 py-2 text-sm font-medium text-slate-200 transition hover:bg-white/10"
|
class="inline-flex items-center justify-center rounded-xl border border-white/10 bg-white/5 px-4 py-2 text-sm font-medium text-slate-200 transition hover:bg-white/10"
|
||||||
on:click={loadStatus}
|
onclick={loadStatus}
|
||||||
>
|
>
|
||||||
Refresh
|
Refresh
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if loading}
|
{#if loading}
|
||||||
<div
|
<div class="mt-6 rounded-2xl border border-dashed border-white/10 bg-slate-950/40 px-4 py-10 text-center text-slate-400">
|
||||||
class="mt-6 rounded-2xl border border-dashed border-white/10 bg-slate-950/40 px-4 py-10 text-center text-slate-400"
|
|
||||||
>
|
|
||||||
Loading status…
|
Loading status…
|
||||||
</div>
|
</div>
|
||||||
{:else if error}
|
{:else if error}
|
||||||
<div
|
<div class="mt-6 rounded-2xl border border-rose-400/30 bg-rose-400/10 px-4 py-6 text-sm text-rose-100">
|
||||||
class="mt-6 rounded-2xl border border-rose-400/30 bg-rose-400/10 px-4 py-6 text-sm text-rose-100"
|
|
||||||
>
|
|
||||||
<p class="font-semibold">Couldn’t load the status snapshot.</p>
|
<p class="font-semibold">Couldn’t load the status snapshot.</p>
|
||||||
<p class="mt-1 text-rose-100/80">{error}</p>
|
<p class="mt-1 text-rose-100/80">{error}</p>
|
||||||
</div>
|
</div>
|
||||||
{:else if data && data.entries.length === 0}
|
|
||||||
<div
|
|
||||||
class="mt-6 rounded-2xl border border-dashed border-white/10 bg-slate-950/40 px-4 py-10 text-center text-slate-400"
|
|
||||||
>
|
|
||||||
No authenticated access recorded yet.
|
|
||||||
</div>
|
|
||||||
{:else if data}
|
{:else if data}
|
||||||
|
<div class="mt-6 grid gap-4 sm:grid-cols-3">
|
||||||
|
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
||||||
|
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Connected users</p>
|
||||||
|
<p class="mt-2 text-3xl font-semibold text-white">{data.connected_count}</p>
|
||||||
|
</div>
|
||||||
|
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
||||||
|
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Known principals</p>
|
||||||
|
<p class="mt-2 text-3xl font-semibold text-white">{data.total_known}</p>
|
||||||
|
</div>
|
||||||
|
<div class="rounded-2xl border border-white/10 bg-white/5 p-5">
|
||||||
|
<p class="text-sm uppercase tracking-[0.2em] text-slate-400">Version</p>
|
||||||
|
<p class="mt-2 break-all text-2xl font-semibold text-white">{data.version}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</section>
|
||||||
|
{:else}
|
||||||
|
<section class="rounded-3xl border border-white/10 bg-slate-900/80 p-6 shadow-xl shadow-slate-950/20 sm:p-8">
|
||||||
|
<div class="flex flex-col gap-3 sm:flex-row sm:items-end sm:justify-between">
|
||||||
|
<div>
|
||||||
|
<h2 class="text-2xl font-semibold text-white">Projects</h2>
|
||||||
|
<p class="mt-1 text-sm text-slate-400">First module scaffold. Grid/Form wiring comes next.</p>
|
||||||
|
</div>
|
||||||
|
<span class="inline-flex items-center rounded-full border border-amber-300/20 bg-amber-400/10 px-3 py-1 text-xs font-medium text-amber-200">
|
||||||
|
Structure phase
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="mt-6 grid gap-4 lg:grid-cols-[1.35fr_0.65fr]">
|
||||||
|
<div class="rounded-2xl border border-dashed border-cyan-400/20 bg-cyan-400/5 p-6">
|
||||||
|
<h3 class="text-lg font-semibold text-white">Project grid placeholder</h3>
|
||||||
|
<p class="mt-2 text-sm leading-6 text-slate-300">
|
||||||
|
This is the landing zone for the Origin-style projects grid using Svelix and GridlerFull.
|
||||||
|
Next pass: wire ResolveSpec-backed project list, row actions, and editor flow.
|
||||||
|
</p>
|
||||||
|
<ul class="mt-4 space-y-2 text-sm text-slate-400">
|
||||||
|
<li>• Project list and search</li>
|
||||||
|
<li>• Project detail/edit drawer or modal</li>
|
||||||
|
<li>• Create/archive actions</li>
|
||||||
|
<li>• Link-outs to related thoughts and skills</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="rounded-2xl border border-white/10 bg-white/5 p-6">
|
||||||
|
<h3 class="text-lg font-semibold text-white">Build notes</h3>
|
||||||
|
<dl class="mt-4 space-y-3 text-sm text-slate-300">
|
||||||
|
<div>
|
||||||
|
<dt class="text-slate-500">Auth path</dt>
|
||||||
|
<dd class="mt-1">ResolveSpec OAuth packages</dd>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<dt class="text-slate-500">Page pattern</dt>
|
||||||
|
<dd class="mt-1">Mapped toward Origin login and shell</dd>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<dt class="text-slate-500">First module</dt>
|
||||||
|
<dd class="mt-1">Projects</dd>
|
||||||
|
</div>
|
||||||
|
</dl>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
{/if}
|
||||||
|
|
||||||
|
{#if data && currentPage === 'dashboard' && data.entries.length > 0}
|
||||||
|
<section class="mt-6 rounded-3xl border border-white/10 bg-slate-900/80 p-6 shadow-xl shadow-slate-950/20 sm:p-8">
|
||||||
|
<h3 class="text-xl font-semibold text-white">Recent access</h3>
|
||||||
<div class="mt-6 overflow-hidden rounded-2xl border border-white/10">
|
<div class="mt-6 overflow-hidden rounded-2xl border border-white/10">
|
||||||
<div class="overflow-x-auto">
|
<div class="overflow-x-auto">
|
||||||
<table
|
<table class="min-w-full divide-y divide-white/10 text-left text-sm text-slate-300">
|
||||||
class="min-w-full divide-y divide-white/10 text-left text-sm text-slate-300"
|
<thead class="bg-white/5 text-xs uppercase tracking-[0.2em] text-slate-500">
|
||||||
>
|
|
||||||
<thead
|
|
||||||
class="bg-white/5 text-xs uppercase tracking-[0.2em] text-slate-500"
|
|
||||||
>
|
|
||||||
<tr>
|
<tr>
|
||||||
<th class="px-4 py-3 font-medium">Principal</th>
|
<th class="px-4 py-3 font-medium">Principal</th>
|
||||||
<th class="px-4 py-3 font-medium">Last accessed</th>
|
<th class="px-4 py-3 font-medium">Last accessed</th>
|
||||||
@@ -232,31 +402,20 @@
|
|||||||
<tbody class="divide-y divide-white/5 bg-slate-950/30">
|
<tbody class="divide-y divide-white/5 bg-slate-950/30">
|
||||||
{#each data.entries as entry}
|
{#each data.entries as entry}
|
||||||
<tr class="hover:bg-white/[0.03]">
|
<tr class="hover:bg-white/[0.03]">
|
||||||
<td class="px-4 py-3 align-top"
|
<td class="px-4 py-3 align-top"><code class="rounded bg-white/5 px-2 py-1 font-mono text-xs text-cyan-100">{entry.key_id}</code></td>
|
||||||
><code
|
<td class="px-4 py-3 align-top text-slate-200">{formatDate(entry.last_accessed_at)}</td>
|
||||||
class="rounded bg-white/5 px-2 py-1 font-mono text-xs text-cyan-100"
|
<td class="px-4 py-3 align-top"><code class="text-slate-100">{entry.last_path}</code></td>
|
||||||
>{entry.key_id}</code
|
<td class="max-w-[16rem] truncate px-4 py-3 align-top text-xs text-slate-400">{entry.user_agent ?? '—'}</td>
|
||||||
></td
|
<td class="px-4 py-3 align-top font-semibold text-white">{entry.request_count}</td>
|
||||||
>
|
|
||||||
<td class="px-4 py-3 align-top text-slate-200"
|
|
||||||
>{formatDate(entry.last_accessed_at)}</td
|
|
||||||
>
|
|
||||||
<td class="px-4 py-3 align-top"
|
|
||||||
><code class="text-slate-100">{entry.last_path}</code></td
|
|
||||||
>
|
|
||||||
<td class="px-4 py-3 align-top text-slate-400 text-xs max-w-[16rem] truncate"
|
|
||||||
>{entry.user_agent ?? "—"}</td
|
|
||||||
>
|
|
||||||
<td class="px-4 py-3 align-top font-semibold text-white"
|
|
||||||
>{entry.request_count}</td
|
|
||||||
>
|
|
||||||
</tr>
|
</tr>
|
||||||
{/each}
|
{/each}
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{/if}
|
|
||||||
</section>
|
</section>
|
||||||
|
{/if}
|
||||||
</main>
|
</main>
|
||||||
</div>
|
</div>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
|||||||
276
ui/src/shellState.ts
Normal file
276
ui/src/shellState.ts
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
import { GlobalStateStore } from '@warkypublic/svelix';
|
||||||
|
|
||||||
|
/** Strip any run of trailing slashes so URL joining never doubles a '/'. */
const normalizeApiURL = (url: string): string => {
  let end = url.length;
  while (end > 0 && url[end - 1] === '/') end -= 1;
  return url.slice(0, end);
};
|
||||||
|
|
||||||
|
/**
 * Resolve the API base URL, in priority order:
 *   1. the explicit `envURL` argument,
 *   2. Vite env vars (VITE_API_URL, VITE_API_BASE_URL, VITE_URL),
 *   3. same-origin `/api` when running in a browser,
 *   4. an apiURL already held in the global session state.
 * Returns '' when nothing resolves. Env-derived values are trimmed and
 * stripped of trailing slashes.
 */
const resolveApiURL = (envURL?: string): string => {
  // `||` (not `??`) is deliberate: empty strings fall through to the
  // next candidate instead of winning the chain.
  const viteEnvURL =
    envURL?.trim() ||
    import.meta.env.VITE_API_URL?.trim() ||
    import.meta.env.VITE_API_BASE_URL?.trim() ||
    import.meta.env.VITE_URL?.trim();

  if (viteEnvURL) return normalizeApiURL(viteEnvURL);

  // In a browser with no env override, default to the current origin's /api.
  if (typeof window !== 'undefined') {
    return `${window.location.protocol}//${window.location.host}/api`;
  }

  // NOTE(review): this branch is only reachable outside a browser (e.g. SSR),
  // because the window check above returns first — confirm that is intended.
  const stateURL = GlobalStateStore.getState().session.apiURL?.trim();
  if (stateURL) return normalizeApiURL(stateURL);

  return '';
};
|
||||||
|
|
||||||
|
export { GlobalStateStore };
|
||||||
|
|
||||||
|
/**
 * Dynamic client registration record returned by the server's OAuth
 * registration endpoint (RFC 7591 shape). Persisted in localStorage so
 * the UI registers at most once per browser.
 */
export type OAuthClientRegistration = {
  client_id: string;
  client_name?: string;
  redirect_uris?: string[];
  grant_types?: string[];
  response_types?: string[];
  token_endpoint_auth_method?: string;
};
|
||||||
|
|
||||||
|
/**
 * Authorization-server metadata fetched from
 * `/.well-known/oauth-authorization-server` (RFC 8414 shape).
 */
export type OAuthServerMetadata = {
  issuer: string;
  authorization_endpoint: string;
  token_endpoint: string;
  registration_endpoint: string;
  scopes_supported?: string[];
  response_types_supported?: string[];
  grant_types_supported?: string[];
  token_endpoint_auth_methods_supported?: string[];
  code_challenge_methods_supported?: string[];
};
|
||||||
|
|
||||||
|
/**
 * An in-flight authorization attempt, persisted in localStorage between the
 * redirect to the authorization endpoint and the callback that exchanges
 * the returned code.
 */
export type OAuthSession = {
  clientId: string; // client_id the flow was started with
  redirectURI: string; // redirect_uri sent in the authorize request
  codeVerifier: string; // PKCE code_verifier matching the sent challenge
  state: string; // anti-CSRF `state` value compared on callback
  createdAt: number; // epoch milliseconds when the flow began
};
|
||||||
|
|
||||||
|
// localStorage keys for the in-flight OAuth session and the cached
// dynamic client registration.
const OAUTH_SESSION_KEY = 'amcs.oauth.session';
const OAUTH_CLIENT_KEY = 'amcs.oauth.client';
// Scope requested in the authorize URL.
const OAUTH_DEFAULT_SCOPE = 'mcp';
|
||||||
|
|
||||||
|
export function ensureApiURL(envURL?: string): string {
|
||||||
|
const resolved = resolveApiURL(envURL);
|
||||||
|
if (!resolved) return '';
|
||||||
|
|
||||||
|
const state = GlobalStateStore.getState();
|
||||||
|
if (state.session.apiURL !== resolved) {
|
||||||
|
state.setApiURL(resolved);
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolved;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getPublicBaseURL(): string {
|
||||||
|
if (typeof window === 'undefined') return '';
|
||||||
|
return `${window.location.protocol}//${window.location.host}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getOAuthRedirectURI(): string {
|
||||||
|
const base = getPublicBaseURL();
|
||||||
|
return base ? `${base}/oauth/callback` : '/oauth/callback';
|
||||||
|
}
|
||||||
|
|
||||||
|
function getStorage(storageKey: string): string | null {
|
||||||
|
if (typeof window === 'undefined') return null;
|
||||||
|
return window.localStorage.getItem(storageKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
function setStorage(storageKey: string, value: string): void {
|
||||||
|
if (typeof window === 'undefined') return;
|
||||||
|
window.localStorage.setItem(storageKey, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
function removeStorage(storageKey: string): void {
|
||||||
|
if (typeof window === 'undefined') return;
|
||||||
|
window.localStorage.removeItem(storageKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function readOAuthClient(): OAuthClientRegistration | null {
|
||||||
|
const raw = getStorage(OAUTH_CLIENT_KEY);
|
||||||
|
if (!raw) return null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
return JSON.parse(raw) as OAuthClientRegistration;
|
||||||
|
} catch {
|
||||||
|
removeStorage(OAUTH_CLIENT_KEY);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function saveOAuthClient(client: OAuthClientRegistration): void {
|
||||||
|
setStorage(OAUTH_CLIENT_KEY, JSON.stringify(client));
|
||||||
|
}
|
||||||
|
|
||||||
|
export function readOAuthSession(): OAuthSession | null {
|
||||||
|
const raw = getStorage(OAUTH_SESSION_KEY);
|
||||||
|
if (!raw) return null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
return JSON.parse(raw) as OAuthSession;
|
||||||
|
} catch {
|
||||||
|
removeStorage(OAUTH_SESSION_KEY);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function saveOAuthSession(session: OAuthSession): void {
|
||||||
|
setStorage(OAUTH_SESSION_KEY, JSON.stringify(session));
|
||||||
|
}
|
||||||
|
|
||||||
|
export function clearOAuthSession(): void {
|
||||||
|
removeStorage(OAUTH_SESSION_KEY);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function setCurrentPath(pathname: string): void {
|
||||||
|
const state = GlobalStateStore.getState();
|
||||||
|
const current = state.navigation.currentPage ?? {};
|
||||||
|
|
||||||
|
state.setCurrentPage({
|
||||||
|
...current,
|
||||||
|
path: pathname
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function createRandomString(length = 48): string {
|
||||||
|
const alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
|
||||||
|
|
||||||
|
if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
|
||||||
|
const bytes = new Uint8Array(length);
|
||||||
|
crypto.getRandomValues(bytes);
|
||||||
|
return Array.from(bytes, (byte) => alphabet[byte % alphabet.length]).join('');
|
||||||
|
}
|
||||||
|
|
||||||
|
return Array.from({ length }, () => alphabet[Math.floor(Math.random() * alphabet.length)]).join('');
|
||||||
|
}
|
||||||
|
|
||||||
|
function base64UrlEncode(buffer: ArrayBuffer): string {
|
||||||
|
let binary = '';
|
||||||
|
const bytes = new Uint8Array(buffer);
|
||||||
|
const chunkSize = 0x8000;
|
||||||
|
|
||||||
|
for (let index = 0; index < bytes.length; index += chunkSize) {
|
||||||
|
binary += String.fromCharCode(...bytes.subarray(index, index + chunkSize));
|
||||||
|
}
|
||||||
|
|
||||||
|
return btoa(binary).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/g, '');
|
||||||
|
}
|
||||||
|
|
||||||
|
async function sha256(input: string): Promise<string> {
|
||||||
|
if (typeof crypto === 'undefined' || !crypto.subtle) {
|
||||||
|
throw new Error('Secure browser crypto is required for OAuth login.');
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = new TextEncoder().encode(input);
|
||||||
|
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||||
|
return base64UrlEncode(digest);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function fetchOAuthMetadata(): Promise<OAuthServerMetadata> {
|
||||||
|
const apiURL = ensureApiURL();
|
||||||
|
const response = await fetch(`${apiURL}/.well-known/oauth-authorization-server`);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to load OAuth metadata (${response.status})`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (await response.json()) as OAuthServerMetadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function ensureOAuthClientRegistration(metadata: OAuthServerMetadata): Promise<OAuthClientRegistration> {
|
||||||
|
const redirectURI = getOAuthRedirectURI();
|
||||||
|
const existing = readOAuthClient();
|
||||||
|
if (existing?.client_id && existing.redirect_uris?.includes(redirectURI)) {
|
||||||
|
return existing;
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await fetch(metadata.registration_endpoint, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
client_name: 'AMCS Admin UI',
|
||||||
|
redirect_uris: [redirectURI],
|
||||||
|
grant_types: ['authorization_code'],
|
||||||
|
response_types: ['code'],
|
||||||
|
token_endpoint_auth_method: 'none'
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to register OAuth client (${response.status})`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const client = (await response.json()) as OAuthClientRegistration;
|
||||||
|
saveOAuthClient(client);
|
||||||
|
return client;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function buildOAuthAuthorizationURL(): Promise<string> {
|
||||||
|
const metadata = await fetchOAuthMetadata();
|
||||||
|
const client = await ensureOAuthClientRegistration(metadata);
|
||||||
|
const codeVerifier = createRandomString(96);
|
||||||
|
const codeChallenge = await sha256(codeVerifier);
|
||||||
|
const state = createRandomString(40);
|
||||||
|
const redirectURI = getOAuthRedirectURI();
|
||||||
|
|
||||||
|
saveOAuthSession({
|
||||||
|
clientId: client.client_id,
|
||||||
|
redirectURI,
|
||||||
|
codeVerifier,
|
||||||
|
state,
|
||||||
|
createdAt: Date.now()
|
||||||
|
});
|
||||||
|
|
||||||
|
const url = new URL(metadata.authorization_endpoint);
|
||||||
|
url.searchParams.set('client_id', client.client_id);
|
||||||
|
url.searchParams.set('redirect_uri', redirectURI);
|
||||||
|
url.searchParams.set('response_type', 'code');
|
||||||
|
url.searchParams.set('scope', OAUTH_DEFAULT_SCOPE);
|
||||||
|
url.searchParams.set('state', state);
|
||||||
|
url.searchParams.set('code_challenge', codeChallenge);
|
||||||
|
url.searchParams.set('code_challenge_method', 'S256');
|
||||||
|
|
||||||
|
return url.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function exchangeOAuthCode(code: string, returnedState: string): Promise<string> {
|
||||||
|
const session = readOAuthSession();
|
||||||
|
if (!session) {
|
||||||
|
throw new Error('OAuth session is missing. Start login again.');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (session.state !== returnedState) {
|
||||||
|
throw new Error('OAuth state mismatch. Start login again.');
|
||||||
|
}
|
||||||
|
|
||||||
|
const metadata = await fetchOAuthMetadata();
|
||||||
|
const response = await fetch(metadata.token_endpoint, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/x-www-form-urlencoded'
|
||||||
|
},
|
||||||
|
body: new URLSearchParams({
|
||||||
|
grant_type: 'authorization_code',
|
||||||
|
code,
|
||||||
|
redirect_uri: session.redirectURI,
|
||||||
|
client_id: session.clientId,
|
||||||
|
code_verifier: session.codeVerifier
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
const payload = (await response.json()) as { access_token?: string; error?: string };
|
||||||
|
if (!response.ok || !payload.access_token) {
|
||||||
|
throw new Error(payload.error || `Token exchange failed (${response.status})`);
|
||||||
|
}
|
||||||
|
|
||||||
|
clearOAuthSession();
|
||||||
|
return payload.access_token;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user