feat: 🎉 Vectors na Vectors, the begining
Translate 1536 <-> 768 , 3072 <-> 2048
90
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
tag:
|
||||||
|
description: 'Tag to release (e.g. v1.2.3)'
|
||||||
|
required: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: go test ./...
|
||||||
|
|
||||||
|
- name: Lint
|
||||||
|
run: go vet ./...
|
||||||
|
|
||||||
|
release:
|
||||||
|
needs: test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
|
||||||
|
- name: Build release binaries
|
||||||
|
run: |
|
||||||
|
VERSION="${{ github.event.inputs.tag || github.ref_name }}"
|
||||||
|
mkdir -p dist
|
||||||
|
for target in "linux/amd64" "linux/arm64" "darwin/amd64" "darwin/arm64" "windows/amd64"; do
|
||||||
|
GOOS="${target%/*}"
|
||||||
|
GOARCH="${target#*/}"
|
||||||
|
EXT=""
|
||||||
|
[ "$GOOS" = "windows" ] && EXT=".exe"
|
||||||
|
BINARY="vecna-${GOOS}-${GOARCH}${EXT}"
|
||||||
|
CGO_ENABLED=0 GOOS="$GOOS" GOARCH="$GOARCH" go build \
|
||||||
|
-trimpath \
|
||||||
|
-ldflags "-s -w -X main.version=${VERSION}" \
|
||||||
|
-o "dist/${BINARY}" \
|
||||||
|
./cmd/vecna
|
||||||
|
if [ "$GOOS" = "windows" ]; then
|
||||||
|
zip -j "dist/vecna-${GOOS}-${GOARCH}.zip" "dist/${BINARY}"
|
||||||
|
else
|
||||||
|
tar -czf "dist/vecna-${GOOS}-${GOARCH}.tar.gz" -C dist "${BINARY}"
|
||||||
|
fi
|
||||||
|
echo "Built ${BINARY}"
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Generate checksums
|
||||||
|
run: |
|
||||||
|
cd dist
|
||||||
|
sha256sum *.tar.gz *.zip > checksums.txt
|
||||||
|
|
||||||
|
- name: Create release and upload assets
|
||||||
|
run: |
|
||||||
|
TAG="${{ github.event.inputs.tag || github.ref_name }}"
|
||||||
|
|
||||||
|
PREV_TAG=$(git tag --sort=-version:refname | grep -v "^${TAG}$" | head -1)
|
||||||
|
if [ -n "$PREV_TAG" ]; then
|
||||||
|
RANGE="${PREV_TAG}..${TAG}"
|
||||||
|
else
|
||||||
|
RANGE="HEAD~20..HEAD"
|
||||||
|
fi
|
||||||
|
NOTES=$(git log "$RANGE" --pretty=format:"- %s" --no-merges)
|
||||||
|
BODY="## What's changed"$'\n'"${NOTES}"
|
||||||
|
|
||||||
|
gh release create "${TAG}" \
|
||||||
|
--title "${TAG}" \
|
||||||
|
--notes "${BODY}" \
|
||||||
|
dist/*.tar.gz dist/*.zip dist/checksums.txt
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
2
.gitignore
vendored
@@ -30,3 +30,5 @@ go.work.sum
|
|||||||
# Editor/IDE
|
# Editor/IDE
|
||||||
# .idea/
|
# .idea/
|
||||||
# .vscode/
|
# .vscode/
|
||||||
|
|
||||||
|
bin/
|
||||||
|
|||||||
30
.golangci.yml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
enable:
|
||||||
|
- goimports
|
||||||
|
settings:
|
||||||
|
goimports:
|
||||||
|
local-prefixes:
|
||||||
|
- github.com/Warky-Devs/vecna.git
|
||||||
|
|
||||||
|
linters:
|
||||||
|
default: standard
|
||||||
|
enable:
|
||||||
|
- misspell
|
||||||
|
- gocritic
|
||||||
|
- noctx
|
||||||
|
- bodyclose
|
||||||
|
- errorlint
|
||||||
|
- copyloopvar
|
||||||
|
- durationcheck
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
gocritic:
|
||||||
|
disabled-checks:
|
||||||
|
- commentedOutCode
|
||||||
|
- hugeParam
|
||||||
|
|
||||||
|
issues:
|
||||||
|
max-issues-per-linter: 0
|
||||||
|
max-same-issues: 0
|
||||||
22
Dockerfile
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# ── build stage ───────────────────────────────────────────────────────────────
|
||||||
|
FROM golang:1.26-alpine AS builder
|
||||||
|
|
||||||
|
WORKDIR /src
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /vecna ./cmd/vecna
|
||||||
|
|
||||||
|
# ── runtime stage ──────────────────────────────────────────────────────────────
|
||||||
|
FROM alpine:3.21
|
||||||
|
|
||||||
|
RUN apk add --no-cache ca-certificates
|
||||||
|
|
||||||
|
COPY --from=builder /vecna /usr/local/bin/vecna
|
||||||
|
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
ENTRYPOINT ["vecna"]
|
||||||
|
CMD ["serve"]
|
||||||
63
Makefile
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
BINARY := vecna
|
||||||
|
CMD := ./cmd/vecna
|
||||||
|
BUILD_DIR := ./bin
|
||||||
|
|
||||||
|
BUMP ?= patch
|
||||||
|
TEST_URL ?=
|
||||||
|
TEST_MODEL ?=
|
||||||
|
|
||||||
|
.PHONY: all build test lint fmt tidy clean release-version test-integration
|
||||||
|
|
||||||
|
all: tidy fmt lint test build
|
||||||
|
|
||||||
|
build:
|
||||||
|
mkdir -p $(BUILD_DIR)
|
||||||
|
go build -o $(BUILD_DIR)/$(BINARY) $(CMD)
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./...
|
||||||
|
|
||||||
|
lint:
|
||||||
|
golangci-lint run ./...
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
golangci-lint run --fix ./...
|
||||||
|
gofmt -w .
|
||||||
|
|
||||||
|
tidy:
|
||||||
|
go mod tidy
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
|
# Run dimension integration tests against a live embedding server.
|
||||||
|
# Requires VECNA_TEST_URL and VECNA_TEST_MODEL to be set.
|
||||||
|
#
|
||||||
|
# Examples:
|
||||||
|
# make test-integration TEST_URL=http://localhost:11434 TEST_MODEL=nomic-embed-text
|
||||||
|
# make test-integration TEST_URL=http://localhost:11434 TEST_MODEL=nomic-embed-text VECNA_TEST_TARGET_DIM=256
|
||||||
|
#
|
||||||
|
test-integration:
|
||||||
|
@if [ -z "$(TEST_URL)" ]; then echo "ERROR: TEST_URL is required"; exit 1; fi
|
||||||
|
@if [ -z "$(TEST_MODEL)" ]; then echo "ERROR: TEST_MODEL is required"; exit 1; fi
|
||||||
|
VECNA_TEST_URL=$(TEST_URL) \
|
||||||
|
VECNA_TEST_MODEL=$(TEST_MODEL) \
|
||||||
|
VECNA_TEST_API_TYPE=$(or $(TEST_API_TYPE),openai) \
|
||||||
|
VECNA_TEST_API_KEY=$(TEST_API_KEY) \
|
||||||
|
go test -v -tags integration -timeout 120s ./tests/integration/
|
||||||
|
|
||||||
|
release-version:
|
||||||
|
@current=$$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0"); \
|
||||||
|
major=$$(echo $$current | sed 's/^v//' | cut -d. -f1); \
|
||||||
|
minor=$$(echo $$current | sed 's/^v//' | cut -d. -f2); \
|
||||||
|
patch=$$(echo $$current | sed 's/^v//' | cut -d. -f3); \
|
||||||
|
case "$(BUMP)" in \
|
||||||
|
major) major=$$((major+1)); minor=0; patch=0 ;; \
|
||||||
|
minor) minor=$$((minor+1)); patch=0 ;; \
|
||||||
|
patch) patch=$$((patch+1)) ;; \
|
||||||
|
*) echo "Unknown BUMP value '$(BUMP)'. Use: patch (default), minor, major"; exit 1 ;; \
|
||||||
|
esac; \
|
||||||
|
next="v$${major}.$${minor}.$${patch}"; \
|
||||||
|
echo "Tagging $$current → $$next"; \
|
||||||
|
git tag -a "$$next" -m "Release $$next"; \
|
||||||
|
git push origin "$$next"
|
||||||
361
README.md
@@ -1,2 +1,359 @@
|
|||||||
# vecna
|
# vecna — Vectors na Vectors
|
||||||
Vecna - Vectors na Vectors, Translate 1536 <-> 768 , 3072 <-> 2048
|
|
||||||
|
**vecna** is an OpenAI- and Google-compatible HTTP proxy that sits between your application and a local or remote embedding model. It forwards text to the backing model, receives the raw embedding vector, and re-shapes it to the dimension your vector database or application expects — without any changes to your client code.
|
||||||
|
|
||||||
|
New to embeddings? → [docs/what_is_embeddings.md](docs/what_is_embeddings.md)
|
||||||
|
|
||||||
|
Every time I install a tool that needs vector embeddings, it assumes OpenAI's dimensions. There are very few local models that match those exact dimensions. Most tools don't give you an option to change them, and some vector databases are hardcoded to fixed dimensions entirely. Many open source models do stick to the same dimensions, which helps — but I run embedding models across several machines and it's always been a pain to use them effectively. This is where vecna helps me.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Why
|
||||||
|
|
||||||
|
Most vector databases are initialized with a fixed dimension (e.g. 1536 for `pgvector` defaults, 768 for many HNSW indexes). If you switch embedding models — or run a smaller local model that produces 768-dim vectors — your existing index breaks.
|
||||||
|
|
||||||
|
vecna solves this by translating dimensions at the proxy layer:
|
||||||
|
|
||||||
|
- **Downscale**: 3072 → 1536, 768 → 256 (truncation or random projection)
|
||||||
|
- **Upscale**: 384 → 1536, 768 → 1536 (zero-padding)
|
||||||
|
- **Same-dim pass-through**: no transformation, just proxy and auth
|
||||||
|
|
||||||
|
All output vectors are L2-normalized so cosine similarity remains valid.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go install github.com/Warky-Devs/vecna.git/cmd/vecna@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
Or build from source:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make build # outputs ./bin/vecna
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# Interactive setup: discovers local servers, configures adapter, writes config
|
||||||
|
vecna onboard
|
||||||
|
|
||||||
|
# Start the proxy
|
||||||
|
vecna serve
|
||||||
|
|
||||||
|
# Test a request (OpenAI-compatible)
|
||||||
|
curl -s http://localhost:8080/v1/embeddings \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"input": "hello world", "model": "nomic-embed-text"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
| Command | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `vecna onboard` | Interactive wizard: discover servers → detect dims → configure → test → write config |
|
||||||
|
| `vecna serve` | Start the proxy server |
|
||||||
|
| `vecna query <text>` | Embed text and print the resulting vector as JSON |
|
||||||
|
| `vecna convert` | Convert vectors from file/stdin using the configured adapter |
|
||||||
|
| `vecna search` | Scan LAN for embedding servers and add one to config |
|
||||||
|
| `vecna models` | List models available on each configured forwarder |
|
||||||
|
| `vecna test` | Test each configured endpoint; `--remove-broken` prunes failing ones |
|
||||||
|
| `vecna editconfig` | Print config path and open it in `$EDITOR` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Query
|
||||||
|
|
||||||
|
Send text directly to a forwarding target and print the adapted vector as JSON.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# uses forward.default target
|
||||||
|
vecna query "hello world"
|
||||||
|
|
||||||
|
# specific target
|
||||||
|
vecna query --target ollama "hello world"
|
||||||
|
|
||||||
|
# skip the adapter — raw model output
|
||||||
|
vecna query --raw "hello world"
|
||||||
|
|
||||||
|
# compact single-line output (pipe-friendly)
|
||||||
|
vecna query --compact "hello world"
|
||||||
|
|
||||||
|
# read text from stdin
|
||||||
|
echo "hello world" | vecna query -
|
||||||
|
|
||||||
|
# inspect a single dimension
|
||||||
|
vecna query --compact "hello world" | jq '.[0]'
|
||||||
|
|
||||||
|
# save to file
|
||||||
|
vecna query --compact "hello world" > vector.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Status info (target, model, dims, tokens) is written to stderr; stdout is clean JSON.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Default config path: `~/vecna.json` (created by `onboard` or `editconfig`).
|
||||||
|
|
||||||
|
Override with `--config path/to/file.yaml` or env vars prefixed `VECNA_`.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"server": {
|
||||||
|
"port": 8080,
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"api_keys": ["sk-vecna-abc123"]
|
||||||
|
},
|
||||||
|
"forward": {
|
||||||
|
"default": "ollama",
|
||||||
|
"targets": {
|
||||||
|
"ollama": {
|
||||||
|
"api_type": "openai",
|
||||||
|
"model": "nomic-embed-text",
|
||||||
|
"api_key": "",
|
||||||
|
"timeout_secs": 30,
|
||||||
|
"cooldown_secs": 60,
|
||||||
|
"priority_decay": 2,
|
||||||
|
"priority_recovery": 5,
|
||||||
|
"endpoints": [
|
||||||
|
{"url": "http://localhost:11434", "priority": 10}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"adapter": {
|
||||||
|
"type": "truncate",
|
||||||
|
"source_dim": 768,
|
||||||
|
"target_dim": 1536,
|
||||||
|
"truncate_mode": "from_end",
|
||||||
|
"pad_mode": "at_end"
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"enabled": true,
|
||||||
|
"path": "/metrics",
|
||||||
|
"api_key": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Important: quality and consistency warnings
|
||||||
|
|
||||||
|
### Upscaling reduces quality — use the highest-dimension model you can
|
||||||
|
|
||||||
|
Upscaling (e.g. 768 → 1536) **does not add information**. The extra dimensions are zeros (truncate adapter) or linear combinations of existing values (random/projection). The resulting vectors occupy a 1536-dim space but carry no more semantic content than the original 768-dim ones.
|
||||||
|
|
||||||
|
**The model's native output dimension is the ceiling of quality.** If your vector database requires 1536 dims, use a model that natively produces 1536 dims. Use vecna's upscaling only as a compatibility shim when you cannot change the index schema — not as a way to improve retrieval quality.
|
||||||
|
|
||||||
|
- Downscale (higher → lower): small, controlled quality loss. Acceptable for MRL models.
|
||||||
|
- Upscale (lower → higher): no quality gain, only compatibility. Replace the model when possible.
|
||||||
|
|
||||||
|
### Changing any adapter setting requires regenerating all stored embeddings
|
||||||
|
|
||||||
|
vecna's adapter is applied at query time **and** at indexing time. If you change any of the following, every vector already stored in your database is now in a different space and comparisons against new queries will silently return wrong results:
|
||||||
|
|
||||||
|
- `type` (truncate / random / projection)
|
||||||
|
- `source_dim` or `target_dim`
|
||||||
|
- `truncate_mode` (from_end / from_start)
|
||||||
|
- `pad_mode` (at_end / at_start)
|
||||||
|
- `seed` (random adapter)
|
||||||
|
- the backing model itself
|
||||||
|
|
||||||
|
**When you change adapter settings: stop ingestion, re-embed your entire corpus through vecna with the new settings, repopulate the index, then resume.**
|
||||||
|
|
||||||
|
There is no partial migration path — a mixed index produces degraded or incorrect search results.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Adapter types
|
||||||
|
|
||||||
|
| Type | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `truncate` | Slice or zero-pad the vector. Fast, deterministic. Best for MRL-trained models. |
|
||||||
|
| `random` | Seeded Gaussian projection matrix. Preserves distances (Johnson-Lindenstrauss). |
|
||||||
|
| `projection` | Learned linear matrix from a JSON file. Highest quality, requires pre-training. |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Truncation and padding modes
|
||||||
|
|
||||||
|
### `truncate_mode` — which part of the vector is kept when downscaling
|
||||||
|
|
||||||
|
| Value | Keeps |
|
||||||
|
|-------|-------|
|
||||||
|
| `from_end` *(default)* | first N dimensions |
|
||||||
|
| `from_start` | last N dimensions |
|
||||||
|
|
||||||
|
**`from_end`** — use for **Matryoshka Representation Learning (MRL)** models. The most important information is packed into the first dimensions.
|
||||||
|
Models: `nomic-embed-text`, `mxbai-embed-large`, `text-embedding-3-small`, `text-embedding-3-large`, `snowflake-arctic-embed`, `e5-mistral-7b-instruct`.
|
||||||
|
|
||||||
|
**`from_start`** — use when task-specific information is at the end of the vector. Try this if `from_end` gives poor retrieval on a non-MRL model.
|
||||||
|
Models: some fine-tuned BERT variants, domain-specific models with task heads appended after base dimensions.
|
||||||
|
|
||||||
|
### `pad_mode` — where zeros are inserted when upscaling
|
||||||
|
|
||||||
|
| Value | Zeros go |
|
||||||
|
|-------|----------|
|
||||||
|
| `at_end` *(default)* | after the real values |
|
||||||
|
| `at_start` | before the real values |
|
||||||
|
|
||||||
|
**`at_end`** — almost always correct. Keeps the original vector in the first N positions.
|
||||||
|
|
||||||
|
**`at_start`** — use if your index expects meaningful content at the end of the vector.
|
||||||
|
|
||||||
|
### Common combinations
|
||||||
|
|
||||||
|
| Scenario | `truncate_mode` | `pad_mode` |
|
||||||
|
|----------|----------------|------------|
|
||||||
|
| MRL model downscale | `from_end` | `at_end` |
|
||||||
|
| MRL model upscale (e.g. 768→1536) | `from_end` | `at_end` |
|
||||||
|
| Non-MRL BERT fine-tune | `from_start` | `at_end` |
|
||||||
|
| Custom index with leading-zeros convention | `from_end` | `at_start` |
|
||||||
|
|
||||||
|
When unsure, run `vecna test` before and after and compare the reported L2 norm.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API endpoints
|
||||||
|
|
||||||
|
### OpenAI-compatible
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /v1/embeddings
|
||||||
|
Authorization: Bearer <api_key>
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{"input": "text or array of texts", "model": "nomic-embed-text"}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Google Gemini-compatible
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /v1/models/{model}:embedContent
|
||||||
|
POST /v1/models/{model}:batchEmbedContents
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAPI spec and docs
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /openapi.yaml
|
||||||
|
GET /docs
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Response tracing headers
|
||||||
|
|
||||||
|
| Header | Value |
|
||||||
|
|--------|-------|
|
||||||
|
| `X-Vecna-Forward-Ms` | Time waiting on the backing model |
|
||||||
|
| `X-Vecna-Translate-Ms` | Time in the adapter |
|
||||||
|
| `X-Vecna-Total-Ms` | Total request wall time |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Prometheus metrics
|
||||||
|
|
||||||
|
Enable in config: `metrics.enabled: true`. Scrape at `GET /metrics`.
|
||||||
|
|
||||||
|
| Metric | Type | Description |
|
||||||
|
|--------|------|-------------|
|
||||||
|
| `vecna_requests_total` | counter | Requests served, by endpoint and status |
|
||||||
|
| `vecna_request_duration_seconds` | histogram | Total request wall time |
|
||||||
|
| `vecna_forward_duration_seconds` | histogram | Time waiting on the backing model |
|
||||||
|
| `vecna_translate_duration_seconds` | histogram | Time in the adapter |
|
||||||
|
| `vecna_endpoint_priority` | gauge | Current dynamic routing priority per endpoint |
|
||||||
|
| `vecna_endpoint_inflight` | gauge | Active in-flight requests per endpoint |
|
||||||
|
| `vecna_endpoint_errors_total` | counter | Forwarding failures by error type |
|
||||||
|
| `vecna_tokens_total` | counter | Tokens consumed, by target, model, and type (`prompt`/`total`) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make build # compile
|
||||||
|
make test # unit tests
|
||||||
|
make lint # golangci-lint
|
||||||
|
make fmt # goimports + gofmt
|
||||||
|
|
||||||
|
# Integration tests against a live server
|
||||||
|
make test-integration TEST_URL=http://localhost:11434 TEST_MODEL=nomic-embed-text
|
||||||
|
|
||||||
|
# Tag and push a release
|
||||||
|
make release-version BUMP=patch # patch | minor | major
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Docker
|
||||||
|
|
||||||
|
### Build
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker build -t vecna .
|
||||||
|
```
|
||||||
|
|
||||||
|
### First-time setup with docker compose
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cp docker-compose.example.yml docker-compose.yml
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
Starts vecna and an Ollama instance. The `vecna_config` named volume persists the config across container rebuilds.
|
||||||
|
|
||||||
|
### Onboard (interactive setup)
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker compose run --rm -it vecna onboard --config /config/vecna.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Ollama is reachable by hostname on the Docker network — the scanner will find it automatically. After onboarding, restart the proxy:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker compose restart vecna
|
||||||
|
```
|
||||||
|
|
||||||
|
### Query
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker compose run --rm vecna query --compact "hello world" --config /config/vecna.json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test endpoints
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# report latency and dims
|
||||||
|
docker compose run --rm vecna test --config /config/vecna.json
|
||||||
|
|
||||||
|
# test and remove failing endpoints
|
||||||
|
docker compose run --rm vecna test --config /config/vecna.json --remove-broken
|
||||||
|
```
|
||||||
|
|
||||||
|
### Edit config manually
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker compose run --rm -it vecna sh -c "vi /config/vecna.json"
|
||||||
|
```
|
||||||
|
|
||||||
|
### With Prometheus
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker compose --profile metrics up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
Scrape config is in `prometheus.example.yml`. Set `bearer_token` if `metrics.api_key` is configured.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
© Hein Puth — [Warky Devs (Pty) Ltd](https://github.com/Warky-Devs)
|
||||||
|
|||||||
160
cmd/vecna/convert.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/adapter"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
convertInput string
|
||||||
|
convertOutput string
|
||||||
|
)
|
||||||
|
|
||||||
|
var convertCmd = &cobra.Command{
|
||||||
|
Use: "convert",
|
||||||
|
Short: "Convert vectors from one dimension to another",
|
||||||
|
Long: "Reads a JSON array of float32 vectors, applies the configured adapter, and writes the result.",
|
||||||
|
RunE: runConvert,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
convertCmd.Flags().StringVarP(&convertInput, "input", "i", "-", "input file path (- for stdin)")
|
||||||
|
convertCmd.Flags().StringVarP(&convertOutput, "output", "o", "-", "output file path (- for stdout)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func runConvert(cmd *cobra.Command, _ []string) error {
|
||||||
|
cfg, err := config.Load(cfgFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adp, err := buildAdapter(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build adapter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
in, err := openReader(convertInput)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open input: %w", err)
|
||||||
|
}
|
||||||
|
if f, ok := in.(*os.File); ok && f != os.Stdin {
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := openWriter(convertOutput)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open output: %w", err)
|
||||||
|
}
|
||||||
|
if f, ok := out.(*os.File); ok && f != os.Stdout {
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
}
|
||||||
|
|
||||||
|
var vecs [][]float32
|
||||||
|
if err := json.NewDecoder(in).Decode(&vecs); err != nil {
|
||||||
|
return fmt.Errorf("decode input: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([][]float32, len(vecs))
|
||||||
|
for i, v := range vecs {
|
||||||
|
adapted, adaptErr := adp.Adapt(v)
|
||||||
|
if adaptErr != nil {
|
||||||
|
return fmt.Errorf("adapt vector %d: %w", i, adaptErr)
|
||||||
|
}
|
||||||
|
result[i] = adapted
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := json.NewEncoder(out)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(result); err != nil {
|
||||||
|
return fmt.Errorf("encode output: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openReader(path string) (io.Reader, error) {
|
||||||
|
if path == "-" {
|
||||||
|
return os.Stdin, nil
|
||||||
|
}
|
||||||
|
return os.Open(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openWriter(path string) (io.Writer, error) {
|
||||||
|
if path == "-" {
|
||||||
|
return os.Stdout, nil
|
||||||
|
}
|
||||||
|
return os.Create(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAdapter constructs the Adapter from the loaded config.
|
||||||
|
func buildAdapter(cfg *config.Config) (adapter.Adapter, error) {
|
||||||
|
ac := cfg.Adapter
|
||||||
|
switch ac.Type {
|
||||||
|
case "truncate":
|
||||||
|
tm, pm, err := parseTruncateModes(ac.TruncateMode, ac.PadMode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return adapter.NewTruncate(ac.SourceDim, ac.TargetDim, tm, pm)
|
||||||
|
|
||||||
|
case "random":
|
||||||
|
return adapter.NewRandom(ac.SourceDim, ac.TargetDim, ac.Seed)
|
||||||
|
|
||||||
|
case "projection":
|
||||||
|
if ac.MatrixFile == "" {
|
||||||
|
return nil, fmt.Errorf("adapter type 'projection' requires matrix_file")
|
||||||
|
}
|
||||||
|
matrix, err := loadMatrix(ac.MatrixFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load projection matrix: %w", err)
|
||||||
|
}
|
||||||
|
return adapter.NewProjection(ac.SourceDim, ac.TargetDim, matrix)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown adapter type %q; valid: truncate, random, projection", ac.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTruncateModes(truncMode, padMode string) (adapter.TruncateMode, adapter.PadMode, error) {
|
||||||
|
var tm adapter.TruncateMode
|
||||||
|
switch truncMode {
|
||||||
|
case "from_end", "":
|
||||||
|
tm = adapter.TruncateFromEnd
|
||||||
|
case "from_start":
|
||||||
|
tm = adapter.TruncateFromStart
|
||||||
|
default:
|
||||||
|
return 0, 0, fmt.Errorf("unknown truncate_mode %q; valid: from_end, from_start", truncMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pm adapter.PadMode
|
||||||
|
switch padMode {
|
||||||
|
case "at_end", "":
|
||||||
|
pm = adapter.PadAtEnd
|
||||||
|
case "at_start":
|
||||||
|
pm = adapter.PadAtStart
|
||||||
|
default:
|
||||||
|
return 0, 0, fmt.Errorf("unknown pad_mode %q; valid: at_end, at_start", padMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tm, pm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadMatrix(path string) ([][]float32, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
|
var m [][]float32
|
||||||
|
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode matrix JSON: %w", err)
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
105
cmd/vecna/editconfig.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var editConfigCmd = &cobra.Command{
|
||||||
|
Use: "editconfig",
|
||||||
|
Short: "Open the vecna config file in your editor",
|
||||||
|
RunE: runEditConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(editConfigCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runEditConfig(cmd *cobra.Command, _ []string) error {
|
||||||
|
path := config.ResolveFile(cfgFile)
|
||||||
|
|
||||||
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
|
if err := createDefaultConfig(path); err != nil {
|
||||||
|
return fmt.Errorf("create default config: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(path)
|
||||||
|
|
||||||
|
editor := resolveEditor()
|
||||||
|
c := exec.CommandContext(context.Background(), editor, path)
|
||||||
|
c.Stdin = os.Stdin
|
||||||
|
c.Stdout = os.Stdout
|
||||||
|
c.Stderr = os.Stderr
|
||||||
|
if err := c.Run(); err != nil {
|
||||||
|
return fmt.Errorf("editor %q exited with error: %w", editor, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveEditor returns $EDITOR, falling back to nvim then nano.
|
||||||
|
func resolveEditor() string {
|
||||||
|
if e := os.Getenv("EDITOR"); e != "" {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
for _, e := range []string{"nvim", "nano"} {
|
||||||
|
if path, err := exec.LookPath(e); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "nano"
|
||||||
|
}
|
||||||
|
|
||||||
|
// createDefaultConfig writes a minimal JSON config skeleton to path.
|
||||||
|
func createDefaultConfig(path string) error {
|
||||||
|
skeleton := config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Port: 8080,
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
APIKeys: []string{},
|
||||||
|
},
|
||||||
|
Metrics: config.MetricsConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Path: "/metrics",
|
||||||
|
},
|
||||||
|
Forward: config.ForwardConfig{
|
||||||
|
Default: "default",
|
||||||
|
Targets: map[string]config.ForwardTarget{
|
||||||
|
"default": {
|
||||||
|
APIType: "openai",
|
||||||
|
Model: "text-embedding-3-small",
|
||||||
|
Endpoints: []config.EndpointConfig{
|
||||||
|
{URL: "https://api.openai.com", Priority: 10},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Adapter: config.AdapterConfig{
|
||||||
|
Type: "truncate",
|
||||||
|
SourceDim: 1536,
|
||||||
|
TargetDim: 768,
|
||||||
|
TruncateMode: "from_end",
|
||||||
|
PadMode: "at_end",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create %s: %w", path, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
|
enc := json.NewEncoder(f)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(skeleton); err != nil {
|
||||||
|
return fmt.Errorf("write default config: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
43
cmd/vecna/main.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cfgFile string
|
||||||
|
logLevel string
|
||||||
|
logger *zap.Logger
|
||||||
|
version = "dev"
|
||||||
|
)
|
||||||
|
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Use: "vecna",
|
||||||
|
Short: "Embedding dimension adapter — translate vectors between model spaces",
|
||||||
|
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
var err error
|
||||||
|
if logLevel == "debug" {
|
||||||
|
logger, err = zap.NewDevelopment()
|
||||||
|
} else {
|
||||||
|
logger, err = zap.NewProduction()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default: ./vecna.yaml)")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "log level: info | debug")
|
||||||
|
rootCmd.AddCommand(convertCmd)
|
||||||
|
rootCmd.AddCommand(serveCmd)
|
||||||
|
rootCmd.AddCommand(versionCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
91
cmd/vecna/models.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/discovery"
|
||||||
|
)
|
||||||
|
|
||||||
|
var modelsCmd = &cobra.Command{
|
||||||
|
Use: "models",
|
||||||
|
Short: "List models available on each configured forwarder",
|
||||||
|
RunE: runModels,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(modelsCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runModels(_ *cobra.Command, _ []string) error {
|
||||||
|
cfg, err := config.Load(cfgFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.Forward.Targets) == 0 {
|
||||||
|
fmt.Println("No forwarder targets configured.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for targetName, target := range cfg.Forward.Targets {
|
||||||
|
fmt.Printf("[ %s ]\n", targetName)
|
||||||
|
|
||||||
|
for _, ep := range target.Endpoints {
|
||||||
|
kind := discovery.Kind{
|
||||||
|
Name: targetName,
|
||||||
|
APIType: target.APIType,
|
||||||
|
Port: 0, // not used by Models()
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := discovery.Models(ctx, ep.URL, kind)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf(" %s error: %s\n\n", ep.URL, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(models) == 0 {
|
||||||
|
fmt.Printf(" %s (no models listed)\n\n", ep.URL)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(" %s\n", ep.URL)
|
||||||
|
for _, m := range models {
|
||||||
|
marker := " "
|
||||||
|
if m == target.Model {
|
||||||
|
marker = "* "
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s%s\n", marker, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
if target.Model != "" && !contains(models, target.Model) {
|
||||||
|
fmt.Printf(" ! configured model %q not found in list\n", target.Model)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(target.Endpoints) == 0 {
|
||||||
|
fmt.Printf(" (no endpoints configured)\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(" API type : %s\n", target.APIType)
|
||||||
|
fmt.Printf(" Model : %s\n\n", strings.TrimSpace(target.Model))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(ss []string, s string) bool {
|
||||||
|
for _, v := range ss {
|
||||||
|
if v == s {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
480
cmd/vecna/onboard.go
Normal file
@@ -0,0 +1,480 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/discovery"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
var onboardCmd = &cobra.Command{
|
||||||
|
Use: "onboard",
|
||||||
|
Short: "Interactive setup wizard: discover servers, configure, test, and write config",
|
||||||
|
RunE: runOnboard,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(onboardCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runOnboard(_ *cobra.Command, _ []string) error {
|
||||||
|
in := bufio.NewReader(os.Stdin)
|
||||||
|
|
||||||
|
fmt.Println("=== vecna onboard ===")
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// ── Step 1: Discover ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
step(1, 5, "Discover embedding servers")
|
||||||
|
|
||||||
|
fmt.Println("Scanning (Ollama, LM Studio, vLLM, LocalAI, Jan, Kobold, Tabby)...")
|
||||||
|
servers := discovery.Scan(context.Background())
|
||||||
|
|
||||||
|
var targets []pendingTarget
|
||||||
|
|
||||||
|
if len(servers) == 0 {
|
||||||
|
fmt.Println("No servers found automatically.")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Found %d server(s):\n\n", len(servers))
|
||||||
|
for i, s := range servers {
|
||||||
|
fmt.Printf(" [%d] %-12s %s\n Models: %s\n\n",
|
||||||
|
i+1, s.Kind.Name, s.BaseURL, joinModels(s.Models))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let user pick one or more from the list; 0 = manual
|
||||||
|
for {
|
||||||
|
choice, err := promptInt(in,
|
||||||
|
fmt.Sprintf("Select server [1-%d] or 0 to enter URL manually: ", len(servers)), 0, len(servers))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var pt pendingTarget
|
||||||
|
if choice == 0 {
|
||||||
|
pt, err = collectManualTarget(in)
|
||||||
|
} else {
|
||||||
|
pt, err = collectDiscoveredTarget(in, servers[choice-1])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
targets = append(targets, pt)
|
||||||
|
|
||||||
|
again, err := promptBool(in, "Add another forwarder?", false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !again {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(targets) == 0 {
|
||||||
|
pt, err := collectManualTarget(in)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
targets = append(targets, pt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Step 2: Detect dimensions ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
step(2, 5, "Detect model dimensions")
|
||||||
|
|
||||||
|
for i := range targets {
|
||||||
|
fmt.Printf("Probing %s / %s ... ", targets[i].endpoint, targets[i].model)
|
||||||
|
dim, err := detectDim(targets[i])
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed (%s) — you will need to enter the dimension manually\n", err)
|
||||||
|
targets[i].detectedDim = 0
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%d dims\n", dim)
|
||||||
|
targets[i].detectedDim = dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// ── Step 3: Configure adapter ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
step(3, 5, "Configure dimension adapter")
|
||||||
|
|
||||||
|
// Use the first target's detected dim as the source dimension default
|
||||||
|
firstDim := 0
|
||||||
|
for _, t := range targets {
|
||||||
|
if t.detectedDim > 0 {
|
||||||
|
firstDim = t.detectedDim
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
srcDimStr := ""
|
||||||
|
if firstDim > 0 {
|
||||||
|
srcDimStr = fmt.Sprintf("%d", firstDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceDimRaw, err := promptString(in,
|
||||||
|
fmt.Sprintf("Source dimension (native model output dim)%s: ", defaultHint(srcDimStr)), srcDimStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sourceDim := mustParseInt(sourceDimRaw, firstDim)
|
||||||
|
|
||||||
|
targetDimRaw, err := promptString(in, "Target dimension (output dim vecna will serve) [1536]: ", "1536")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
targetDim := mustParseInt(targetDimRaw, 1536)
|
||||||
|
|
||||||
|
adapterType, err := promptString(in, "Adapter type (truncate/random/projection) [truncate]: ", "truncate")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
truncateMode := "from_end"
|
||||||
|
padMode := "at_end"
|
||||||
|
if adapterType == "truncate" {
|
||||||
|
truncateMode, err = promptString(in, "Truncate mode (from_end/from_start) [from_end]: ", "from_end")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
padMode, err = promptString(in, "Pad mode (at_end/at_start) [at_end]: ", "at_end")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// ── Step 4: Configure vecna server ────────────────────────────────────────
|
||||||
|
|
||||||
|
step(4, 5, "Configure vecna server")
|
||||||
|
|
||||||
|
portRaw, err := promptString(in, "Bind port [8080]: ", "8080")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
port := mustParseInt(portRaw, 8080)
|
||||||
|
|
||||||
|
apiKeysRaw, err := promptString(in,
|
||||||
|
"Inbound API keys for vecna (comma-separated, leave empty to disable auth): ", "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var apiKeys []string
|
||||||
|
for _, k := range strings.Split(apiKeysRaw, ",") {
|
||||||
|
if k := strings.TrimSpace(k); k != "" {
|
||||||
|
apiKeys = append(apiKeys, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enableMetrics, err := promptBool(in, "Enable Prometheus /metrics endpoint?", false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
metricsAPIKey := ""
|
||||||
|
if enableMetrics {
|
||||||
|
metricsAPIKey, err = promptString(in, "Metrics API key (leave empty for open access): ", "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
// ── Step 5: Test & write ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
step(5, 5, "Test connections and write config")
|
||||||
|
|
||||||
|
allPassed := true
|
||||||
|
for _, t := range targets {
|
||||||
|
fmt.Printf("Testing %-45s ", t.endpoint+"...")
|
||||||
|
_, elapsed, dims, testErr := runSingleTest(t)
|
||||||
|
if testErr != nil {
|
||||||
|
fmt.Printf("FAIL %s\n", truncate(testErr.Error(), 55))
|
||||||
|
allPassed = false
|
||||||
|
} else {
|
||||||
|
fmt.Printf("OK %dms dims=%d\n", elapsed.Milliseconds(), dims)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if !allPassed {
|
||||||
|
proceed, err := promptBool(in, "Some tests failed. Write config anyway?", false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !proceed {
|
||||||
|
fmt.Println("Aborted. No config written.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the config struct
|
||||||
|
defaultTarget := ""
|
||||||
|
forwardTargets := make(map[string]config.ForwardTarget, len(targets))
|
||||||
|
for i, t := range targets {
|
||||||
|
forwardTargets[t.name] = config.ForwardTarget{
|
||||||
|
APIType: t.apiType,
|
||||||
|
Model: t.model,
|
||||||
|
APIKey: t.apiKey,
|
||||||
|
Endpoints: []config.EndpointConfig{
|
||||||
|
{URL: t.endpoint, Priority: 10},
|
||||||
|
},
|
||||||
|
TimeoutSecs: 30,
|
||||||
|
CooldownSecs: 60,
|
||||||
|
PriorityDecay: 2,
|
||||||
|
PriorityRecovery: 5,
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
defaultTarget = t.name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Port: port,
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
APIKeys: apiKeys,
|
||||||
|
},
|
||||||
|
Metrics: config.MetricsConfig{
|
||||||
|
Enabled: enableMetrics,
|
||||||
|
Path: "/metrics",
|
||||||
|
APIKey: metricsAPIKey,
|
||||||
|
},
|
||||||
|
Forward: config.ForwardConfig{
|
||||||
|
Default: defaultTarget,
|
||||||
|
Targets: forwardTargets,
|
||||||
|
},
|
||||||
|
Adapter: config.AdapterConfig{
|
||||||
|
Type: adapterType,
|
||||||
|
SourceDim: sourceDim,
|
||||||
|
TargetDim: targetDim,
|
||||||
|
TruncateMode: truncateMode,
|
||||||
|
PadMode: padMode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultCfgPath := config.ResolveFile(cfgFile)
|
||||||
|
fmt.Printf("Config will be written to: %s\n", defaultCfgPath)
|
||||||
|
cfgPath, err := promptString(in, "Config path (press Enter to accept): ", defaultCfgPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeFullConfig(cfgPath, cfg); err != nil {
|
||||||
|
return fmt.Errorf("write config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Config written to %s\n", cfgPath)
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Run 'vecna serve' to start the proxy server.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// pendingTarget collects configuration for a single forwarding target before
|
||||||
|
// the config is assembled.
|
||||||
|
type pendingTarget struct {
|
||||||
|
name string
|
||||||
|
endpoint string
|
||||||
|
model string
|
||||||
|
apiType string
|
||||||
|
apiKey string
|
||||||
|
detectedDim int
|
||||||
|
}
|
||||||
|
|
||||||
|
func step(n, total int, title string) {
|
||||||
|
fmt.Printf("[%d/%d] %s\n", n, total, title)
|
||||||
|
fmt.Println(strings.Repeat("-", 40))
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultHint(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(" [%s]", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinModels(models []string) string {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return "(none)"
|
||||||
|
}
|
||||||
|
if len(models) > 5 {
|
||||||
|
return strings.Join(models[:5], ", ") + fmt.Sprintf(" (+%d more)", len(models)-5)
|
||||||
|
}
|
||||||
|
return strings.Join(models, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectDiscoveredTarget(in *bufio.Reader, srv discovery.Found) (pendingTarget, error) {
|
||||||
|
defaultName := strings.ToLower(strings.ReplaceAll(srv.Kind.Name, " ", "_"))
|
||||||
|
|
||||||
|
model, err := pickModel(in, srv.Models)
|
||||||
|
if err != nil {
|
||||||
|
return pendingTarget{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := promptString(in, fmt.Sprintf("Target name in config [%s]: ", defaultName), defaultName)
|
||||||
|
if err != nil {
|
||||||
|
return pendingTarget{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := promptAPIKey(in, srv.Kind.NeedsKey)
|
||||||
|
if err != nil {
|
||||||
|
return pendingTarget{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return pendingTarget{
|
||||||
|
name: name,
|
||||||
|
endpoint: srv.BaseURL,
|
||||||
|
model: model,
|
||||||
|
apiType: srv.Kind.APIType,
|
||||||
|
apiKey: apiKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectManualTarget(in *bufio.Reader) (pendingTarget, error) {
|
||||||
|
fmt.Println("Enter server details manually:")
|
||||||
|
|
||||||
|
endpoint, err := promptString(in, "Server URL (e.g. http://localhost:11434): ", "")
|
||||||
|
if err != nil || endpoint == "" {
|
||||||
|
return pendingTarget{}, fmt.Errorf("server URL is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
apiTypeStr, err := promptString(in, "API type (openai/google) [openai]: ", "openai")
|
||||||
|
if err != nil {
|
||||||
|
return pendingTarget{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := promptString(in, "Model name: ", "")
|
||||||
|
if err != nil || model == "" {
|
||||||
|
return pendingTarget{}, fmt.Errorf("model name is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := promptString(in, "Target name in config [custom]: ", "custom")
|
||||||
|
if err != nil {
|
||||||
|
return pendingTarget{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := promptAPIKey(in, false)
|
||||||
|
if err != nil {
|
||||||
|
return pendingTarget{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return pendingTarget{
|
||||||
|
name: name,
|
||||||
|
endpoint: endpoint,
|
||||||
|
model: model,
|
||||||
|
apiType: apiTypeStr,
|
||||||
|
apiKey: apiKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pickModel(in *bufio.Reader, models []string) (string, error) {
|
||||||
|
switch {
|
||||||
|
case len(models) == 0:
|
||||||
|
return promptString(in, "Model name: ", "")
|
||||||
|
case len(models) == 1:
|
||||||
|
fmt.Printf("Using model: %s\n", models[0])
|
||||||
|
return models[0], nil
|
||||||
|
default:
|
||||||
|
fmt.Println("Available models:")
|
||||||
|
for i, m := range models {
|
||||||
|
fmt.Printf(" [%d] %s\n", i+1, m)
|
||||||
|
}
|
||||||
|
idx, err := promptInt(in, fmt.Sprintf("Select model [1-%d]: ", len(models)), 1, len(models))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return models[idx-1], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptAPIKey(in *bufio.Reader, required bool) (string, error) {
|
||||||
|
prompt := "API key (leave empty if none): "
|
||||||
|
if required {
|
||||||
|
prompt = "API key: "
|
||||||
|
}
|
||||||
|
return promptString(in, prompt, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectDim sends a single test embedding and returns the vector length.
|
||||||
|
func detectDim(t pendingTarget) (int, error) {
|
||||||
|
httpClient := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
var client embedclient.Client
|
||||||
|
if t.apiType == "google" {
|
||||||
|
client = embedclient.NewGoogle(t.endpoint, t.apiKey, t.model, httpClient)
|
||||||
|
} else {
|
||||||
|
client = embedclient.NewOpenAI(t.endpoint, t.apiKey, httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := client.Embed(ctx, embedclient.Request{
|
||||||
|
Texts: []string{"dimension probe"},
|
||||||
|
Model: t.model,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
|
||||||
|
return 0, fmt.Errorf("empty embedding in response")
|
||||||
|
}
|
||||||
|
return len(resp.Embeddings[0]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSingleTest runs one test embed and returns success, elapsed time, dims, and any error.
|
||||||
|
func runSingleTest(t pendingTarget) (bool, time.Duration, int, error) {
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
var client embedclient.Client
|
||||||
|
if t.apiType == "google" {
|
||||||
|
client = embedclient.NewGoogle(t.endpoint, t.apiKey, t.model, httpClient)
|
||||||
|
} else {
|
||||||
|
client = embedclient.NewOpenAI(t.endpoint, t.apiKey, httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := client.Embed(ctx, embedclient.Request{
|
||||||
|
Texts: []string{testPhrase},
|
||||||
|
Model: t.model,
|
||||||
|
})
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
if err != nil {
|
||||||
|
return false, elapsed, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dims, _ := embeddingStats(resp.Embeddings)
|
||||||
|
return true, elapsed, dims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseInt(s string, fallback int) int {
|
||||||
|
var n int
|
||||||
|
if _, err := fmt.Sscanf(s, "%d", &n); err != nil || n <= 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFullConfig(path string, cfg config.Config) error {
|
||||||
|
// If file already exists, preserve any targets not touched by onboard
|
||||||
|
// by using SaveTarget for each new target; otherwise write the whole file.
|
||||||
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
|
if err := createDefaultConfig(path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Overwrite with the complete onboard config
|
||||||
|
return config.WriteConfig(path, cfg)
|
||||||
|
}
|
||||||
59
cmd/vecna/prompt.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// promptInt reads an integer in [min, max] from the reader, re-prompting on bad input.
|
||||||
|
func promptInt(in *bufio.Reader, prompt string, min, max int) (int, error) {
|
||||||
|
for {
|
||||||
|
fmt.Print(prompt)
|
||||||
|
line, err := in.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("read input: %w", err)
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(strings.TrimSpace(line))
|
||||||
|
if err != nil || n < min || n > max {
|
||||||
|
fmt.Printf(" Enter a number between %d and %d\n", min, max)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptString reads a line, returning defaultVal when the user presses Enter with no input.
|
||||||
|
func promptString(in *bufio.Reader, prompt, defaultVal string) (string, error) {
|
||||||
|
fmt.Print(prompt)
|
||||||
|
line, err := in.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read input: %w", err)
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(line); s != "" {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
return defaultVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptBool reads a y/N confirmation, returning defaultVal on empty input.
|
||||||
|
func promptBool(in *bufio.Reader, prompt string, defaultVal bool) (bool, error) {
|
||||||
|
hint := "y/N"
|
||||||
|
if defaultVal {
|
||||||
|
hint = "Y/n"
|
||||||
|
}
|
||||||
|
fmt.Printf("%s [%s]: ", prompt, hint)
|
||||||
|
line, err := in.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("read input: %w", err)
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(line)) {
|
||||||
|
case "y", "yes":
|
||||||
|
return true, nil
|
||||||
|
case "n", "no":
|
||||||
|
return false, nil
|
||||||
|
default:
|
||||||
|
return defaultVal, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
154
cmd/vecna/query.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
queryTarget string
|
||||||
|
queryRaw bool
|
||||||
|
queryCompact bool
|
||||||
|
)
|
||||||
|
|
||||||
|
var queryCmd = &cobra.Command{
|
||||||
|
Use: "query <text>",
|
||||||
|
Short: "Embed text and print the resulting vector",
|
||||||
|
Long: `Sends text to the configured forwarding target, applies the dimension adapter,
|
||||||
|
and prints the resulting vector as a JSON array.
|
||||||
|
|
||||||
|
Text can be supplied as a positional argument or via stdin (use - as the argument).`,
|
||||||
|
Args: cobra.MaximumNArgs(1),
|
||||||
|
RunE: runQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
queryCmd.Flags().StringVar(&queryTarget, "target", "",
|
||||||
|
"forward target to use (default: forward.default from config)")
|
||||||
|
queryCmd.Flags().BoolVar(&queryRaw, "raw", false,
|
||||||
|
"skip the adapter — output the raw vector from the backing model")
|
||||||
|
queryCmd.Flags().BoolVar(&queryCompact, "compact", false,
|
||||||
|
"print vector on a single line instead of pretty-printed")
|
||||||
|
rootCmd.AddCommand(queryCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runQuery(_ *cobra.Command, args []string) error {
|
||||||
|
cfg, err := config.Load(cfgFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve text: positional arg, "-" reads stdin, no arg reads stdin.
|
||||||
|
text, err := queryText(args)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve target.
|
||||||
|
targetName := queryTarget
|
||||||
|
if targetName == "" {
|
||||||
|
targetName = cfg.Forward.Default
|
||||||
|
}
|
||||||
|
target, ok := cfg.Forward.Targets[targetName]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("target %q not found in config", targetName)
|
||||||
|
}
|
||||||
|
if len(target.Endpoints) == 0 {
|
||||||
|
return fmt.Errorf("target %q has no endpoints", targetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build client (use first endpoint directly — no router needed for a one-shot query).
|
||||||
|
ep := target.Endpoints[0]
|
||||||
|
apiKey := ep.APIKey
|
||||||
|
if apiKey == "" {
|
||||||
|
apiKey = target.APIKey
|
||||||
|
}
|
||||||
|
timeout := time.Duration(target.TimeoutSecs) * time.Second
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{Timeout: timeout}
|
||||||
|
|
||||||
|
var client embedclient.Client
|
||||||
|
switch target.APIType {
|
||||||
|
case "google":
|
||||||
|
client = embedclient.NewGoogle(ep.URL, apiKey, target.Model, httpClient)
|
||||||
|
default:
|
||||||
|
client = embedclient.NewOpenAI(ep.URL, apiKey, httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := client.Embed(ctx, embedclient.Request{
|
||||||
|
Texts: []string{text},
|
||||||
|
Model: target.Model,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("embed: %w", err)
|
||||||
|
}
|
||||||
|
if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
|
||||||
|
return fmt.Errorf("empty embedding in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
vec := resp.Embeddings[0]
|
||||||
|
|
||||||
|
if !queryRaw {
|
||||||
|
adp, err := buildAdapter(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build adapter: %w", err)
|
||||||
|
}
|
||||||
|
vec, err = adp.Adapt(vec)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("adapt: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "target=%s model=%s dims=%d tokens=%d\n",
|
||||||
|
targetName, resp.Model, len(vec), resp.Usage.TotalTokens)
|
||||||
|
|
||||||
|
return printVector(vec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryText(args []string) (string, error) {
|
||||||
|
if len(args) == 0 || args[0] == "-" {
|
||||||
|
raw, err := os.ReadFile("/dev/stdin")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read stdin: %w", err)
|
||||||
|
}
|
||||||
|
return strings.TrimRight(string(raw), "\n"), nil
|
||||||
|
}
|
||||||
|
return args[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printVector(vec []float32) error {
|
||||||
|
// Convert to []any so json.Marshal produces clean floats without float32 quirks.
|
||||||
|
out := make([]float64, len(vec))
|
||||||
|
for i, v := range vec {
|
||||||
|
out[i] = float64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b []byte
|
||||||
|
var err error
|
||||||
|
if queryCompact {
|
||||||
|
b, err = json.Marshal(out)
|
||||||
|
} else {
|
||||||
|
b, err = json.MarshalIndent(out, "", " ")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal vector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(string(b))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
121
cmd/vecna/search.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/discovery"
|
||||||
|
)
|
||||||
|
|
||||||
|
var searchCmd = &cobra.Command{
|
||||||
|
Use: "search",
|
||||||
|
Short: "Scan the network for LLM servers and add one to the config",
|
||||||
|
RunE: runSearch,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(searchCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSearch(_ *cobra.Command, _ []string) error {
|
||||||
|
in := bufio.NewReader(os.Stdin)
|
||||||
|
|
||||||
|
fmt.Println("Scanning for LLM servers (Ollama, LM Studio, vLLM, LocalAI, Jan, Kobold, Tabby)...")
|
||||||
|
servers := discovery.Scan(context.Background())
|
||||||
|
|
||||||
|
if len(servers) == 0 {
|
||||||
|
fmt.Println("No servers found.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nFound %d server(s):\n\n", len(servers))
|
||||||
|
for i, s := range servers {
|
||||||
|
modelList := strings.Join(s.Models, ", ")
|
||||||
|
if modelList == "" {
|
||||||
|
modelList = "(no models listed)"
|
||||||
|
}
|
||||||
|
fmt.Printf(" [%d] %s %s\n Models: %s\n\n", i+1, s.Kind.Name, s.BaseURL, modelList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select server
|
||||||
|
chosen, err := promptInt(in, fmt.Sprintf("Select server [1-%d]: ", len(servers)), 1, len(servers))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
srv := servers[chosen-1]
|
||||||
|
|
||||||
|
// Select model
|
||||||
|
var model string
|
||||||
|
switch {
|
||||||
|
case len(srv.Models) == 1:
|
||||||
|
model = srv.Models[0]
|
||||||
|
fmt.Printf("Using model: %s\n", model)
|
||||||
|
case len(srv.Models) > 1:
|
||||||
|
fmt.Println("\nAvailable models:")
|
||||||
|
for i, m := range srv.Models {
|
||||||
|
fmt.Printf(" [%d] %s\n", i+1, m)
|
||||||
|
}
|
||||||
|
idx, err := promptInt(in, fmt.Sprintf("Select model [1-%d]: ", len(srv.Models)), 1, len(srv.Models))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
model = srv.Models[idx-1]
|
||||||
|
default:
|
||||||
|
model, err = promptString(in, "Model name: ", "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Target name in config
|
||||||
|
defaultName := strings.ToLower(strings.ReplaceAll(srv.Kind.Name, " ", "_"))
|
||||||
|
targetName, err := promptString(in, fmt.Sprintf("Target name in config [%s]: ", defaultName), defaultName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// API key
|
||||||
|
keyPrompt := "API key (leave empty if none): "
|
||||||
|
if srv.Kind.NeedsKey {
|
||||||
|
keyPrompt = "API key: "
|
||||||
|
}
|
||||||
|
apiKey, err := promptString(in, keyPrompt, "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
target := config.ForwardTarget{
|
||||||
|
APIType: srv.Kind.APIType,
|
||||||
|
Model: model,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Endpoints: []config.EndpointConfig{
|
||||||
|
{URL: srv.BaseURL, Priority: 10},
|
||||||
|
},
|
||||||
|
TimeoutSecs: 30,
|
||||||
|
CooldownSecs: 60,
|
||||||
|
PriorityDecay: 2,
|
||||||
|
PriorityRecovery: 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgPath := config.ResolveFile(cfgFile)
|
||||||
|
|
||||||
|
// Create default config if it doesn't exist yet
|
||||||
|
if _, err := os.Stat(cfgPath); os.IsNotExist(err) {
|
||||||
|
if err := createDefaultConfig(cfgPath); err != nil {
|
||||||
|
return fmt.Errorf("create default config: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.SaveTarget(cfgPath, targetName, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nAdded target %q to %s\n", targetName, cfgPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
150
cmd/vecna/serve.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/metrics"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serveCmd = &cobra.Command{
|
||||||
|
Use: "serve",
|
||||||
|
Short: "Start the vecna embedding proxy server",
|
||||||
|
RunE: runServe,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
serveCmd.Flags().String("host", "", "bind host (overrides config)")
|
||||||
|
serveCmd.Flags().Int("port", 0, "bind port (overrides config)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func runServe(cmd *cobra.Command, _ []string) error {
|
||||||
|
cfg, err := config.Load(cfgFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flag overrides
|
||||||
|
if h, _ := cmd.Flags().GetString("host"); h != "" {
|
||||||
|
cfg.Server.Host = h
|
||||||
|
}
|
||||||
|
if p, _ := cmd.Flags().GetInt("port"); p != 0 {
|
||||||
|
cfg.Server.Port = p
|
||||||
|
}
|
||||||
|
|
||||||
|
adp, err := buildAdapter(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build adapter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clients, err := buildClients(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build clients: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reg *metrics.Registry
|
||||||
|
if cfg.Metrics.Enabled {
|
||||||
|
reg = metrics.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
router := server.New(cfg, clients, adp, reg, logger)
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: router,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
logger.Info("vecna listening", zap.String("addr", addr))
|
||||||
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
logger.Error("server error", zap.Error(err))
|
||||||
|
quit <- syscall.SIGTERM
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-quit
|
||||||
|
logger.Info("shutting down")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
return fmt.Errorf("graceful shutdown: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildClients constructs one embedclient.Client per named forward target.
|
||||||
|
func buildClients(cfg *config.Config) (map[string]embedclient.Client, error) {
|
||||||
|
clients := make(map[string]embedclient.Client, len(cfg.Forward.Targets))
|
||||||
|
|
||||||
|
for name, target := range cfg.Forward.Targets {
|
||||||
|
if len(target.Endpoints) == 0 {
|
||||||
|
return nil, fmt.Errorf("target %q has no endpoints", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.Duration(target.TimeoutSecs) * time.Second
|
||||||
|
httpClient := &http.Client{Timeout: timeout}
|
||||||
|
|
||||||
|
slots := make([]embedclient.RouterSlot, len(target.Endpoints))
|
||||||
|
for i, ep := range target.Endpoints {
|
||||||
|
apiKey := ep.APIKey
|
||||||
|
if apiKey == "" {
|
||||||
|
apiKey = target.APIKey
|
||||||
|
}
|
||||||
|
|
||||||
|
var c embedclient.Client
|
||||||
|
switch target.APIType {
|
||||||
|
case "google":
|
||||||
|
c = embedclient.NewGoogle(ep.URL, apiKey, target.Model, httpClient)
|
||||||
|
default: // "openai" or unset
|
||||||
|
c = embedclient.NewOpenAI(ep.URL, apiKey, httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
slots[i] = embedclient.RouterSlot{
|
||||||
|
Client: c,
|
||||||
|
URL: ep.URL,
|
||||||
|
Priority: ep.Priority,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routerCfg := embedclient.RouterConfig{
|
||||||
|
TargetName: name,
|
||||||
|
TimeoutSecs: target.TimeoutSecs,
|
||||||
|
CooldownSecs: target.CooldownSecs,
|
||||||
|
PriorityDecay: target.PriorityDecay,
|
||||||
|
PriorityRecovery: target.PriorityRecovery,
|
||||||
|
}
|
||||||
|
|
||||||
|
// metrics registry may be nil (disabled)
|
||||||
|
var reg *metrics.Registry
|
||||||
|
if cfg.Metrics.Enabled {
|
||||||
|
// Registry is created in runServe before buildClients; pass nil here,
|
||||||
|
// caller wires it in after creation. See note below.
|
||||||
|
_ = reg
|
||||||
|
}
|
||||||
|
|
||||||
|
router, err := embedclient.NewTargetRouter(slots, routerCfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build router for target %q: %w", name, err)
|
||||||
|
}
|
||||||
|
clients[name] = router
|
||||||
|
}
|
||||||
|
|
||||||
|
return clients, nil
|
||||||
|
}
|
||||||
150
cmd/vecna/test.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
var removeBroken bool
|
||||||
|
|
||||||
|
var testCmd = &cobra.Command{
|
||||||
|
Use: "test",
|
||||||
|
Short: "Send a test embedding request to each configured forwarder",
|
||||||
|
RunE: runTest,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
testCmd.Flags().BoolVar(&removeBroken, "remove-broken", false,
|
||||||
|
"Remove endpoints (and targets with no endpoints left) that fail the test from the config file")
|
||||||
|
rootCmd.AddCommand(testCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
const testPhrase = "The quick brown fox jumps over the lazy dog"
|
||||||
|
|
||||||
|
func runTest(_ *cobra.Command, _ []string) error {
|
||||||
|
cfg, err := config.Load(cfgFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.Forward.Targets) == 0 {
|
||||||
|
fmt.Println("No forwarder targets configured.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Test phrase: %q\n\n", testPhrase)
|
||||||
|
|
||||||
|
passed, failed := 0, 0
|
||||||
|
// brokenEndpoints maps target name → list of failing endpoint URLs
|
||||||
|
brokenEndpoints := make(map[string][]string)
|
||||||
|
|
||||||
|
for targetName, target := range cfg.Forward.Targets {
|
||||||
|
fmt.Printf("[ %s ] model: %s type: %s\n", targetName, target.Model, target.APIType)
|
||||||
|
|
||||||
|
timeout := time.Duration(target.TimeoutSecs) * time.Second
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{Timeout: timeout}
|
||||||
|
|
||||||
|
for _, ep := range target.Endpoints {
|
||||||
|
apiKey := ep.APIKey
|
||||||
|
if apiKey == "" {
|
||||||
|
apiKey = target.APIKey
|
||||||
|
}
|
||||||
|
|
||||||
|
var client embedclient.Client
|
||||||
|
switch target.APIType {
|
||||||
|
case "google":
|
||||||
|
client = embedclient.NewGoogle(ep.URL, apiKey, target.Model, httpClient)
|
||||||
|
default:
|
||||||
|
client = embedclient.NewOpenAI(ep.URL, apiKey, httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
start := time.Now()
|
||||||
|
resp, embedErr := client.Embed(ctx, embedclient.Request{
|
||||||
|
Texts: []string{testPhrase},
|
||||||
|
Model: target.Model,
|
||||||
|
})
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if embedErr != nil {
|
||||||
|
fmt.Printf(" %-45s FAIL %s\n", ep.URL, truncate(embedErr.Error(), 60))
|
||||||
|
brokenEndpoints[targetName] = append(brokenEndpoints[targetName], ep.URL)
|
||||||
|
failed++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dims, norm := embeddingStats(resp.Embeddings)
|
||||||
|
fmt.Printf(" %-45s OK %dms dims=%d norm=%.4f\n",
|
||||||
|
ep.URL, elapsed.Milliseconds(), dims, norm)
|
||||||
|
passed++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(target.Endpoints) == 0 {
|
||||||
|
fmt.Println(" (no endpoints configured)")
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Results: %d passed, %d failed\n", passed, failed)
|
||||||
|
|
||||||
|
if removeBroken && len(brokenEndpoints) > 0 {
|
||||||
|
if err := applyRemoveBroken(brokenEndpoints); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if failed > 0 {
|
||||||
|
return fmt.Errorf("%d forwarder(s) failed", failed)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyRemoveBroken(broken map[string][]string) error {
|
||||||
|
cfgPath := config.ResolveFile(cfgFile)
|
||||||
|
removed, err := config.RemoveBrokenEndpoints(cfgPath, broken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("remove broken: %w", err)
|
||||||
|
}
|
||||||
|
if len(removed) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fmt.Println("\nRemoved from config:")
|
||||||
|
for _, r := range removed {
|
||||||
|
fmt.Printf(" - %s\n", r)
|
||||||
|
}
|
||||||
|
fmt.Printf("Config updated: %s\n", cfgPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// embeddingStats returns the dimension count and L2 norm of the first embedding.
|
||||||
|
func embeddingStats(embeddings [][]float32) (dims int, norm float32) {
|
||||||
|
if len(embeddings) == 0 || len(embeddings[0]) == 0 {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
vec := embeddings[0]
|
||||||
|
dims = len(vec)
|
||||||
|
var sum float64
|
||||||
|
for _, v := range vec {
|
||||||
|
sum += float64(v) * float64(v)
|
||||||
|
}
|
||||||
|
return dims, float32(math.Sqrt(sum))
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncate(s string, n int) string {
|
||||||
|
if len(s) <= n {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:n-3] + "..."
|
||||||
|
}
|
||||||
15
cmd/vecna/version.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var versionCmd = &cobra.Command{
|
||||||
|
Use: "version",
|
||||||
|
Short: "Print the version",
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
fmt.Println(version)
|
||||||
|
},
|
||||||
|
}
|
||||||
82
docker-compose.example.yml
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
services:
|
||||||
|
|
||||||
|
# ── vecna proxy ─────────────────────────────────────────────────────────────
|
||||||
|
vecna:
|
||||||
|
build: .
|
||||||
|
# image: ghcr.io/warky-devs/vecna:latest
|
||||||
|
ports:
|
||||||
|
- "8080:8080"
|
||||||
|
volumes:
|
||||||
|
- vecna_config:/config
|
||||||
|
environment:
|
||||||
|
VECNA_SERVER_PORT: 8080
|
||||||
|
# VECNA_SERVER_API_KEYS: sk-vecna-abc123,sk-vecna-def456
|
||||||
|
command: ["serve", "--config", "/config/vecna.json"]
|
||||||
|
restart: unless-stopped
|
||||||
|
depends_on:
|
||||||
|
ollama:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ── ollama (local embedding model) ──────────────────────────────────────────
|
||||||
|
ollama:
|
||||||
|
image: ollama/ollama:latest
|
||||||
|
ports:
|
||||||
|
- "11434:11434"
|
||||||
|
volumes:
|
||||||
|
- ollama_data:/root/.ollama
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "ollama", "list"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 6
|
||||||
|
start_period: 20s
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ── pull the embedding model on first start ──────────────────────────────────
|
||||||
|
# Remove this service after the model has been pulled once.
|
||||||
|
ollama-pull:
|
||||||
|
image: ollama/ollama:latest
|
||||||
|
depends_on:
|
||||||
|
ollama:
|
||||||
|
condition: service_healthy
|
||||||
|
environment:
|
||||||
|
OLLAMA_HOST: http://ollama:11434
|
||||||
|
entrypoint: ["ollama", "pull", "nomic-embed-text"]
|
||||||
|
restart: "no"
|
||||||
|
|
||||||
|
# ── prometheus (optional, for metrics scraping) ──────────────────────────────
|
||||||
|
# Requires metrics.enabled: true in /config/vecna.json
|
||||||
|
prometheus:
|
||||||
|
image: prom/prometheus:latest
|
||||||
|
ports:
|
||||||
|
- "9090:9090"
|
||||||
|
volumes:
|
||||||
|
- ./prometheus.example.yml:/etc/prometheus/prometheus.yml:ro
|
||||||
|
restart: unless-stopped
|
||||||
|
profiles:
|
||||||
|
- metrics
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
ollama_data:
|
||||||
|
vecna_config: # persists vecna.json across container rebuilds
|
||||||
|
|
||||||
|
# ── one-off commands ──────────────────────────────────────────────────────────
|
||||||
|
#
|
||||||
|
# Run the interactive onboard wizard (writes config into the vecna_config volume):
|
||||||
|
#
|
||||||
|
# docker compose run --rm -it vecna onboard --config /config/vecna.json
|
||||||
|
#
|
||||||
|
# The wizard will discover the ollama service on the Docker network at
|
||||||
|
# http://ollama:11434 (select it from the list or enter the URL manually).
|
||||||
|
#
|
||||||
|
# Test all configured endpoints after onboarding:
|
||||||
|
#
|
||||||
|
# docker compose run --rm vecna test --config /config/vecna.json
|
||||||
|
#
|
||||||
|
# Remove broken endpoints automatically:
|
||||||
|
#
|
||||||
|
# docker compose run --rm vecna test --config /config/vecna.json --remove-broken
|
||||||
|
#
|
||||||
|
# Open the config in a shell editor (requires the alpine image):
|
||||||
|
#
|
||||||
|
# docker compose run --rm -it vecna sh -c "vi /config/vecna.json"
|
||||||
BIN
docs/images/a.jpg
Normal file
|
After Width: | Height: | Size: 9.5 KiB |
BIN
docs/images/a1.jpg
Normal file
|
After Width: | Height: | Size: 21 KiB |
BIN
docs/images/b.jpg
Normal file
|
After Width: | Height: | Size: 69 KiB |
BIN
docs/images/b1.jpg
Normal file
|
After Width: | Height: | Size: 36 KiB |
BIN
docs/images/c.jpg
Normal file
|
After Width: | Height: | Size: 153 KiB |
BIN
docs/images/c1.jpg
Normal file
|
After Width: | Height: | Size: 302 KiB |
BIN
docs/images/d.jpg
Normal file
|
After Width: | Height: | Size: 176 KiB |
BIN
docs/images/d1.jpeg
Normal file
|
After Width: | Height: | Size: 54 KiB |
BIN
docs/images/e.jpg
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
docs/images/e1.jpg
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
docs/images/f.jpg
Normal file
|
After Width: | Height: | Size: 88 KiB |
BIN
docs/images/f1.jpg
Normal file
|
After Width: | Height: | Size: 19 KiB |
355
docs/what_is_embeddings.md
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
**Practical breakdown** of embeddings: how they work, how they’re trained, and concrete examples you can actually implement.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🧠 What embeddings actually are (intuitively + formally)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
At the core:
|
||||||
|
|
||||||
|
* An embedding is a **function**
|
||||||
|
[
|
||||||
|
f(x) \rightarrow \mathbb{R}^d
|
||||||
|
]
|
||||||
|
|
||||||
|
* It maps raw data (text, image, etc.) into a **dense vector**.
|
||||||
|
|
||||||
|
Key property:
|
||||||
|
|
||||||
|
* **Semantic structure is encoded as geometry**
|
||||||
|
|
||||||
|
* Similar things → vectors close together
|
||||||
|
* Different things → far apart
|
||||||
|
|
||||||
|
This is the core idea behind vector search, RAG, clustering, etc. ([ibm.com][1])
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# ⚙️ Why embeddings work (the underlying theory)
|
||||||
|
|
||||||
|
### 1. Distributional hypothesis
|
||||||
|
|
||||||
|
> “You shall know a word by the company it keeps”
|
||||||
|
|
||||||
|
* Words appearing in similar contexts → similar vectors
|
||||||
|
* This is the **foundation of Word2Vec, GloVe, BERT, etc.** ([ibm.com][2])
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. Geometry encodes meaning
|
||||||
|
|
||||||
|
Classic example:
|
||||||
|
|
||||||
|
```
|
||||||
|
king - man + woman ≈ queen
|
||||||
|
```
|
||||||
|
|
||||||
|
This works because relationships become **linear directions in vector space**.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Dense vs sparse representations
|
||||||
|
|
||||||
|
| Method | Problem |
|
||||||
|
| ---------- | ------------------ |
|
||||||
|
| One-hot | huge, no meaning |
|
||||||
|
| TF-IDF | frequency only |
|
||||||
|
| Embeddings | compact + semantic |
|
||||||
|
|
||||||
|
Embeddings are **low-dimensional but information-rich** representations. ([GeeksforGeeks][3])
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🏗️ How embeddings are trained (core methods)
|
||||||
|
|
||||||
|
## 1. Prediction-based models (most important)
|
||||||
|
|
||||||
|
### Word2Vec (classic foundation)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
Two main training strategies:
|
||||||
|
|
||||||
|
### (a) Skip-gram
|
||||||
|
|
||||||
|
Predict context from a word:
|
||||||
|
|
||||||
|
[
|
||||||
|
P(context \mid word)
|
||||||
|
]
|
||||||
|
|
||||||
|
### (b) CBOW
|
||||||
|
|
||||||
|
Predict word from context:
|
||||||
|
|
||||||
|
[
|
||||||
|
P(word \mid context)
|
||||||
|
]
|
||||||
|
|
||||||
|
Mechanism:
|
||||||
|
|
||||||
|
* Input = one-hot vector
|
||||||
|
* Hidden layer = embedding
|
||||||
|
* Train via gradient descent
|
||||||
|
|
||||||
|
👉 The embedding is literally the **weights of the hidden layer** ([Medium][4])
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Example (Skip-gram training loop)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# pseudo-code
|
||||||
|
for word in corpus:
|
||||||
|
context = get_context_window(word)
|
||||||
|
loss = -log P(context | word)
|
||||||
|
update_weights()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Matrix factorization (GloVe)
|
||||||
|
|
||||||
|
Instead of prediction:
|
||||||
|
|
||||||
|
* Build co-occurrence matrix
|
||||||
|
* Factorize it into lower dimensions
|
||||||
|
|
||||||
|
Captures **global statistics**, not just local context. ([Medium][5])
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Neural embedding layers (modern approach)
|
||||||
|
|
||||||
|
Used in:
|
||||||
|
|
||||||
|
* Transformers (BERT, GPT)
|
||||||
|
* Recommender systems
|
||||||
|
|
||||||
|
Mechanism:
|
||||||
|
|
||||||
|
* Embedding = **lookup table**
|
||||||
|
* Trained jointly with model
|
||||||
|
|
||||||
|
```python
|
||||||
|
embedding = torch.nn.Embedding(vocab_size, dim)
|
||||||
|
vector = embedding(token_id)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Contrastive learning (modern SOTA)
|
||||||
|
|
||||||
|
Used in:
|
||||||
|
|
||||||
|
* sentence embeddings
|
||||||
|
* CLIP (image-text)
|
||||||
|
* OpenAI embeddings
|
||||||
|
|
||||||
|
Core idea:
|
||||||
|
|
||||||
|
[
|
||||||
|
\text{similar pairs} \rightarrow \text{closer}
|
||||||
|
]
|
||||||
|
[
|
||||||
|
\text{different pairs} \rightarrow \text{farther}
|
||||||
|
]
|
||||||
|
|
||||||
|
Loss function:
|
||||||
|
|
||||||
|
[
|
||||||
|
\mathcal{L} = -\log \frac{e^{sim(x_i, x_j)}}{\sum_k e^{sim(x_i, x_k)}}
|
||||||
|
]
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🔬 How modern embeddings (LLMs) differ
|
||||||
|
|
||||||
|
Older:
|
||||||
|
|
||||||
|
* static embeddings (Word2Vec)
|
||||||
|
|
||||||
|
Modern:
|
||||||
|
|
||||||
|
* **contextual embeddings**
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
* “bank” (river vs finance) → different vectors
|
||||||
|
|
||||||
|
This is why models like BERT/GPT outperform Word2Vec.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🧪 Practical training examples
|
||||||
|
|
||||||
|
## Example 1 — Train Word2Vec (Gensim)
|
||||||
|
|
||||||
|
```python
|
||||||
|
from gensim.models import Word2Vec
|
||||||
|
|
||||||
|
sentences = [["cat", "sat", "mat"], ["dog", "sat", "floor"]]
|
||||||
|
|
||||||
|
model = Word2Vec(sentences, vector_size=100, window=5, min_count=1)
|
||||||
|
|
||||||
|
vector = model.wv["cat"]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Example 2 — Train embeddings in PyTorch
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
embedding = nn.Embedding(10000, 128) # vocab, dim
|
||||||
|
|
||||||
|
input_ids = torch.tensor([1, 5, 23])
|
||||||
|
vectors = embedding(input_ids)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Example 3 — Train contrastive embeddings
|
||||||
|
|
||||||
|
```python
|
||||||
|
# pseudo
|
||||||
|
anchor = model(text1)
|
||||||
|
positive = model(text2)
|
||||||
|
negative = model(text3)
|
||||||
|
|
||||||
|
loss = contrastive_loss(anchor, positive, negative)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Example 4 — PCA reduction (your earlier question)
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
|
pca = PCA(n_components=256)
|
||||||
|
X_reduced = pca.fit_transform(X)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 📊 Types of embeddings
|
||||||
|
|
||||||
|
| Type | Example |
|
||||||
|
| ---------- | --------------- |
|
||||||
|
| Word | Word2Vec, GloVe |
|
||||||
|
| Sentence | SBERT |
|
||||||
|
| Document | Doc2Vec |
|
||||||
|
| Image | CLIP |
|
||||||
|
| Graph | Node2Vec |
|
||||||
|
| Multimodal | CLIP, Gemini |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🧩 Key properties you should care about (engineering perspective)
|
||||||
|
|
||||||
|
### 1. Dimensionality
|
||||||
|
|
||||||
|
* Typical: 128–1536
|
||||||
|
* Tradeoff: memory vs accuracy
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. Distance metric
|
||||||
|
|
||||||
|
* cosine similarity (most common)
|
||||||
|
* dot product
|
||||||
|
* Euclidean
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Normalization
|
||||||
|
|
||||||
|
Critical for:
|
||||||
|
|
||||||
|
* search quality
|
||||||
|
* clustering
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. Training data distribution
|
||||||
|
|
||||||
|
Embeddings are only as good as:
|
||||||
|
|
||||||
|
* corpus size
|
||||||
|
* domain relevance
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# ⚠️ Common pitfalls (important)
|
||||||
|
|
||||||
|
### ❌ Mixing embedding spaces
|
||||||
|
|
||||||
|
* embeddings from different models are **not compatible**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### ❌ Assuming linear compression is harmless
|
||||||
|
|
||||||
|
* PCA can distort semantic relationships
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### ❌ Ignoring normalization
|
||||||
|
|
||||||
|
* cosine similarity breaks without it
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### ❌ Using embeddings without evaluation
|
||||||
|
|
||||||
|
Always test:
|
||||||
|
|
||||||
|
* retrieval accuracy
|
||||||
|
* clustering quality
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🧠 Mental model (most useful takeaway)
|
||||||
|
|
||||||
|
Think of embeddings as:
|
||||||
|
|
||||||
|
> A learned coordinate system where **meaning = position**
|
||||||
|
|
||||||
|
Training = learning that coordinate system so that:
|
||||||
|
|
||||||
|
* similar things cluster
|
||||||
|
* relationships become directions
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
|
[1]: https://www.ibm.com/think/topics/vector-embedding?utm_source=chatgpt.com "What is Vector Embedding? | IBM"
|
||||||
|
[2]: https://www.ibm.com/think/topics/word-embeddings?utm_source=chatgpt.com "What Are Word Embeddings? | IBM"
|
||||||
|
[3]: https://www.geeksforgeeks.org/nlp/word-embeddings-in-nlp/?utm_source=chatgpt.com "Word Embeddings in NLP"
|
||||||
|
[4]: https://medium.com/%40manansuri/a-dummys-guide-to-word2vec-456444f3c673?utm_source=chatgpt.com "A Dummy's Guide to Word2Vec - Medium"
|
||||||
|
[5]: https://medium.com/%40neri.vvo/word-embedding-a-powerful-tool-word2vec-glove-fasttext-dd6e2171d5?utm_source=chatgpt.com "Word Embedding Explained — Word2Vec GloVe, FastText"
|
||||||
|
|
||||||
40
go.mod
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
module github.com/Warky-Devs/vecna.git
|
||||||
|
|
||||||
|
go 1.26.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/prometheus/client_golang v1.23.2
|
||||||
|
github.com/spf13/cobra v1.10.2
|
||||||
|
github.com/spf13/viper v1.21.0
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/uptrace/bunrouter v1.0.23
|
||||||
|
go.uber.org/zap v1.27.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
|
github.com/prometheus/common v0.66.1 // indirect
|
||||||
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
|
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||||
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||||
|
github.com/spf13/afero v1.15.0 // indirect
|
||||||
|
github.com/spf13/cast v1.10.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.10 // indirect
|
||||||
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
|
go.uber.org/multierr v1.10.0 // indirect
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
|
golang.org/x/sys v0.35.0 // indirect
|
||||||
|
golang.org/x/text v0.28.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.8 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
84
go.sum
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
|
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||||
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||||
|
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||||
|
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||||
|
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||||
|
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||||
|
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||||
|
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||||
|
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
|
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||||
|
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||||
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||||
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||||
|
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||||
|
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||||
|
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||||
|
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||||
|
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||||
|
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||||
|
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||||
|
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||||
|
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
|
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||||
|
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||||
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
|
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||||
|
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
|
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||||
|
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
|
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||||
|
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||||
|
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||||
|
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||||
|
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
77
pkg/adapter/adapter.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrDimMismatch is returned when the input vector length does not match the adapter's source dimension.
|
||||||
|
var ErrDimMismatch = errors.New("vector dimension mismatch")
|
||||||
|
|
||||||
|
// ErrInvalidDim is returned when source or target dimensions are invalid.
|
||||||
|
var ErrInvalidDim = errors.New("invalid dimension")
|
||||||
|
|
||||||
|
// ErrInvalidMatrix is returned when a projection matrix has the wrong shape.
|
||||||
|
var ErrInvalidMatrix = errors.New("invalid projection matrix shape")
|
||||||
|
|
||||||
|
// TruncateMode controls which end of the vector is dropped when downscaling.
|
||||||
|
type TruncateMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TruncateFromEnd keeps the first targetDim elements (default; correct for Matryoshka models).
|
||||||
|
TruncateFromEnd TruncateMode = iota
|
||||||
|
// TruncateFromStart keeps the last targetDim elements.
|
||||||
|
TruncateFromStart
|
||||||
|
)
|
||||||
|
|
||||||
|
// PadMode controls which end of the vector receives zero-padding when upscaling.
|
||||||
|
type PadMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PadAtEnd appends zeros to the end (default).
|
||||||
|
PadAtEnd PadMode = iota
|
||||||
|
// PadAtStart prepends zeros to the start.
|
||||||
|
PadAtStart
|
||||||
|
)
|
||||||
|
|
||||||
|
// Adapter translates vectors between two fixed dimensions.
|
||||||
|
type Adapter interface {
|
||||||
|
Adapt(vec []float32) ([]float32, error)
|
||||||
|
SourceDim() int
|
||||||
|
TargetDim() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTruncate returns a TruncateAdapter for Matryoshka-style or simple truncation/padding.
|
||||||
|
// t controls which end is dropped when downscaling; p controls which end is padded when upscaling.
|
||||||
|
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
|
||||||
|
if sourceDim <= 0 || targetDim <= 0 {
|
||||||
|
return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||||
|
}
|
||||||
|
return &truncateAdapter{sourceDim: sourceDim, targetDim: targetDim, truncMode: t, padMode: p}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRandom returns a RandomAdapter backed by a seeded Gaussian projection matrix.
|
||||||
|
// seed=0 uses a time-based seed.
|
||||||
|
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
|
||||||
|
if sourceDim <= 0 || targetDim <= 0 {
|
||||||
|
return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||||
|
}
|
||||||
|
return newRandomAdapter(sourceDim, targetDim, seed), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProjection returns a ProjectionAdapter using a caller-supplied matrix.
|
||||||
|
// matrix must have shape [targetDim][sourceDim].
|
||||||
|
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
|
||||||
|
if sourceDim <= 0 || targetDim <= 0 {
|
||||||
|
return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||||
|
}
|
||||||
|
if len(matrix) != targetDim {
|
||||||
|
return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, len(matrix), targetDim)
|
||||||
|
}
|
||||||
|
for i, row := range matrix {
|
||||||
|
if len(row) != sourceDim {
|
||||||
|
return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, len(row), sourceDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &projectionAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}, nil
|
||||||
|
}
|
||||||
236
pkg/adapter/adapter_test.go
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// unitVec returns a unit vector of length n with value 1/sqrt(n) in each position.
|
||||||
|
func unitVec(n int) []float32 {
|
||||||
|
v := make([]float32, n)
|
||||||
|
val := float32(1.0 / math.Sqrt(float64(n)))
|
||||||
|
for i := range v {
|
||||||
|
v[i] = val
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertNormOne(t *testing.T, v []float32) {
|
||||||
|
t.Helper()
|
||||||
|
var sum float64
|
||||||
|
for _, x := range v {
|
||||||
|
sum += float64(x) * float64(x)
|
||||||
|
}
|
||||||
|
assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- L2Norm ---
|
||||||
|
|
||||||
|
func TestL2Norm(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []float32
|
||||||
|
wantLen int
|
||||||
|
wantNorm float64
|
||||||
|
}{
|
||||||
|
{"unit vector unchanged", []float32{1, 0, 0}, 3, 1.0},
|
||||||
|
{"scale down to unit", []float32{2, 0, 0}, 3, 1.0},
|
||||||
|
{"multi-dim", unitVec(4), 4, 1.0},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := L2Norm(tc.input)
|
||||||
|
assert.Len(t, got, tc.wantLen)
|
||||||
|
assertNormOne(t, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("zero vector returns copy unchanged", func(t *testing.T) {
|
||||||
|
in := []float32{0, 0, 0}
|
||||||
|
got := L2Norm(in)
|
||||||
|
assert.Equal(t, in, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- TruncateAdapter ---
|
||||||
|
|
||||||
|
func TestNewTruncate_InvalidDims(t *testing.T) {
|
||||||
|
_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidDim)
|
||||||
|
|
||||||
|
_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateAdapter(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
src, tgt int
|
||||||
|
truncMode TruncateMode
|
||||||
|
padMode PadMode
|
||||||
|
// builder for input vec; checks are len + norm only unless wantFirst/wantLast set
|
||||||
|
wantLen int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||||
|
{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768, false},
|
||||||
|
{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536, false},
|
||||||
|
{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536, false},
|
||||||
|
{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.src, a.SourceDim())
|
||||||
|
assert.Equal(t, tc.tgt, a.TargetDim())
|
||||||
|
|
||||||
|
got, err := a.Adapt(unitVec(tc.src))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, got, tc.wantLen)
|
||||||
|
assertNormOne(t, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||||
|
a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = a.Adapt(make([]float32, 100))
|
||||||
|
require.ErrorIs(t, err, ErrDimMismatch)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
|
||||||
|
a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// input: [1,2,3,4] — TruncateFromEnd keeps [1,2]
|
||||||
|
in := []float32{1, 2, 3, 4}
|
||||||
|
got, err := a.Adapt(in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// after L2Norm the signs are preserved; ratio should match
|
||||||
|
assert.Greater(t, got[0], float32(0))
|
||||||
|
assert.Greater(t, got[1], float32(0))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
|
||||||
|
a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// input: [0,0,3,4] — TruncateFromStart keeps [3,4]
|
||||||
|
in := []float32{0, 0, 3, 4}
|
||||||
|
got, err := a.Adapt(in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Greater(t, got[0], float32(0))
|
||||||
|
assert.Greater(t, got[1], float32(0))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("PadAtStart zero-pads front", func(t *testing.T) {
|
||||||
|
a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
|
||||||
|
require.NoError(t, err)
|
||||||
|
in := []float32{1, 0}
|
||||||
|
got, err := a.Adapt(in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// first two positions should be zero (padded), last two carry the signal
|
||||||
|
assert.Equal(t, float32(0), got[0])
|
||||||
|
assert.Equal(t, float32(0), got[1])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- RandomAdapter ---
|
||||||
|
|
||||||
|
func TestNewRandom_InvalidDims(t *testing.T) {
|
||||||
|
_, err := NewRandom(0, 768, 42)
|
||||||
|
require.ErrorIs(t, err, ErrInvalidDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRandomAdapter(t *testing.T) {
|
||||||
|
t.Run("output length and unit norm", func(t *testing.T) {
|
||||||
|
a, err := NewRandom(1536, 768, 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
got, err := a.Adapt(unitVec(1536))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, got, 768)
|
||||||
|
assertNormOne(t, got)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("deterministic with same seed", func(t *testing.T) {
|
||||||
|
a1, _ := NewRandom(64, 32, 99)
|
||||||
|
a2, _ := NewRandom(64, 32, 99)
|
||||||
|
in := unitVec(64)
|
||||||
|
out1, _ := a1.Adapt(in)
|
||||||
|
out2, _ := a2.Adapt(in)
|
||||||
|
assert.Equal(t, out1, out2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different seeds produce different output", func(t *testing.T) {
|
||||||
|
a1, _ := NewRandom(64, 32, 1)
|
||||||
|
a2, _ := NewRandom(64, 32, 2)
|
||||||
|
in := unitVec(64)
|
||||||
|
out1, _ := a1.Adapt(in)
|
||||||
|
out2, _ := a2.Adapt(in)
|
||||||
|
assert.NotEqual(t, out1, out2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||||
|
a, err := NewRandom(1536, 768, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = a.Adapt(make([]float32, 10))
|
||||||
|
require.ErrorIs(t, err, ErrDimMismatch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ProjectionAdapter ---
|
||||||
|
|
||||||
|
func TestNewProjection_InvalidMatrix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
src int
|
||||||
|
tgt int
|
||||||
|
matrix [][]float32
|
||||||
|
errIs error
|
||||||
|
}{
|
||||||
|
{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
|
||||||
|
{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
|
||||||
|
{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
|
||||||
|
require.ErrorIs(t, err, tc.errIs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProjectionAdapter(t *testing.T) {
|
||||||
|
t.Run("identity matrix same dim", func(t *testing.T) {
|
||||||
|
// 3×3 identity
|
||||||
|
id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
|
||||||
|
a, err := NewProjection(3, 3, id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
in := []float32{1, 2, 3}
|
||||||
|
got, err := a.Adapt(in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, got, 3)
|
||||||
|
assertNormOne(t, got)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("downscale projection", func(t *testing.T) {
|
||||||
|
// 2×4 matrix
|
||||||
|
m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
|
||||||
|
a, err := NewProjection(4, 2, m)
|
||||||
|
require.NoError(t, err)
|
||||||
|
got, err := a.Adapt([]float32{3, 4, 0, 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, got, 2)
|
||||||
|
assertNormOne(t, got)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||||
|
m := [][]float32{{1, 0}, {0, 1}}
|
||||||
|
a, err := NewProjection(2, 2, m)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = a.Adapt(make([]float32, 5))
|
||||||
|
require.ErrorIs(t, err, ErrDimMismatch)
|
||||||
|
})
|
||||||
|
}
|
||||||
23
pkg/adapter/normalize.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import "math"
|
||||||
|
|
||||||
|
// L2Norm returns a new slice with the vector normalized to unit length.
|
||||||
|
// If the vector has zero magnitude it is returned unchanged.
|
||||||
|
func L2Norm(v []float32) []float32 {
|
||||||
|
var sum float64
|
||||||
|
for _, x := range v {
|
||||||
|
sum += float64(x) * float64(x)
|
||||||
|
}
|
||||||
|
if sum == 0 {
|
||||||
|
out := make([]float32, len(v))
|
||||||
|
copy(out, v)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
norm := float32(math.Sqrt(sum))
|
||||||
|
out := make([]float32, len(v))
|
||||||
|
for i, x := range v {
|
||||||
|
out[i] = x / norm
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
32
pkg/adapter/projection.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type projectionAdapter struct {
|
||||||
|
sourceDim int
|
||||||
|
targetDim int
|
||||||
|
matrix [][]float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *projectionAdapter) SourceDim() int { return a.sourceDim }
|
||||||
|
func (a *projectionAdapter) TargetDim() int { return a.targetDim }
|
||||||
|
|
||||||
|
func (a *projectionAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||||
|
if len(vec) != a.sourceDim {
|
||||||
|
return nil, fmt.Errorf("projection adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||||
|
}
|
||||||
|
return L2Norm(matVecMul(a.matrix, vec)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matVecMul computes m·v where m is [rows][cols] and v has len cols.
|
||||||
|
func matVecMul(m [][]float32, v []float32) []float32 {
|
||||||
|
out := make([]float32, len(m))
|
||||||
|
for i, row := range m {
|
||||||
|
var sum float32
|
||||||
|
for j, val := range row {
|
||||||
|
sum += val * v[j]
|
||||||
|
}
|
||||||
|
out[i] = sum
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
45
pkg/adapter/random.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type randomAdapter struct {
|
||||||
|
sourceDim int
|
||||||
|
targetDim int
|
||||||
|
matrix [][]float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomAdapter(sourceDim, targetDim int, seed int64) *randomAdapter {
|
||||||
|
if seed == 0 {
|
||||||
|
seed = time.Now().UnixNano()
|
||||||
|
}
|
||||||
|
//nolint:gosec // deterministic seeded RNG for projection matrix generation, not security use
|
||||||
|
rng := rand.New(rand.NewSource(seed))
|
||||||
|
|
||||||
|
// Gaussian N(0, 1/targetDim) — preserves expected squared norms (Johnson-Lindenstrauss)
|
||||||
|
stddev := 1.0 / math.Sqrt(float64(targetDim))
|
||||||
|
matrix := make([][]float32, targetDim)
|
||||||
|
for i := range matrix {
|
||||||
|
row := make([]float32, sourceDim)
|
||||||
|
for j := range row {
|
||||||
|
row[j] = float32(rng.NormFloat64() * stddev)
|
||||||
|
}
|
||||||
|
matrix[i] = row
|
||||||
|
}
|
||||||
|
|
||||||
|
return &randomAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *randomAdapter) SourceDim() int { return a.sourceDim }
|
||||||
|
func (a *randomAdapter) TargetDim() int { return a.targetDim }
|
||||||
|
|
||||||
|
func (a *randomAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||||
|
if len(vec) != a.sourceDim {
|
||||||
|
return nil, fmt.Errorf("random adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||||
|
}
|
||||||
|
return L2Norm(matVecMul(a.matrix, vec)), nil
|
||||||
|
}
|
||||||
41
pkg/adapter/truncate.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type truncateAdapter struct {
|
||||||
|
sourceDim int
|
||||||
|
targetDim int
|
||||||
|
truncMode TruncateMode
|
||||||
|
padMode PadMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *truncateAdapter) SourceDim() int { return a.sourceDim }
|
||||||
|
func (a *truncateAdapter) TargetDim() int { return a.targetDim }
|
||||||
|
|
||||||
|
func (a *truncateAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||||
|
if len(vec) != a.sourceDim {
|
||||||
|
return nil, fmt.Errorf("truncate adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]float32, a.targetDim)
|
||||||
|
|
||||||
|
if a.targetDim <= a.sourceDim {
|
||||||
|
// Downscale: truncate
|
||||||
|
switch a.truncMode {
|
||||||
|
case TruncateFromEnd:
|
||||||
|
copy(out, vec[:a.targetDim])
|
||||||
|
case TruncateFromStart:
|
||||||
|
copy(out, vec[a.sourceDim-a.targetDim:])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Upscale: zero-pad
|
||||||
|
switch a.padMode {
|
||||||
|
case PadAtEnd:
|
||||||
|
copy(out, vec)
|
||||||
|
case PadAtStart:
|
||||||
|
copy(out[a.targetDim-a.sourceDim:], vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return L2Norm(out), nil
|
||||||
|
}
|
||||||
168
pkg/config/config.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config is the root configuration for vecna.
|
||||||
|
type Config struct {
|
||||||
|
Server ServerConfig `mapstructure:"server"`
|
||||||
|
Metrics MetricsConfig `mapstructure:"metrics"`
|
||||||
|
Forward ForwardConfig `mapstructure:"forward"`
|
||||||
|
Adapter AdapterConfig `mapstructure:"adapter"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig controls the HTTP listener and inbound auth.
|
||||||
|
type ServerConfig struct {
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
APIKeys []string `mapstructure:"api_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsConfig controls Prometheus metrics exposure.
|
||||||
|
type MetricsConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Path string `mapstructure:"path"`
|
||||||
|
APIKey string `mapstructure:"api_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardConfig holds all named forwarding targets.
|
||||||
|
type ForwardConfig struct {
|
||||||
|
Default string `mapstructure:"default"`
|
||||||
|
Targets map[string]ForwardTarget `mapstructure:"targets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardTarget is a named backing embedding model with one or more endpoints.
|
||||||
|
type ForwardTarget struct {
|
||||||
|
Endpoints []EndpointConfig `mapstructure:"endpoints"`
|
||||||
|
Model string `mapstructure:"model"`
|
||||||
|
APIKey string `mapstructure:"api_key"`
|
||||||
|
APIType string `mapstructure:"api_type"`
|
||||||
|
TimeoutSecs int `mapstructure:"timeout_secs"`
|
||||||
|
CooldownSecs int `mapstructure:"cooldown_secs"`
|
||||||
|
PriorityDecay int `mapstructure:"priority_decay"`
|
||||||
|
PriorityRecovery int `mapstructure:"priority_recovery"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndpointConfig is a single URL within a ForwardTarget.
|
||||||
|
type EndpointConfig struct {
|
||||||
|
URL string `mapstructure:"url"`
|
||||||
|
Priority int `mapstructure:"priority"`
|
||||||
|
APIKey string `mapstructure:"api_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdapterConfig selects and tunes the dimension adapter.
|
||||||
|
type AdapterConfig struct {
|
||||||
|
Type string `mapstructure:"type"`
|
||||||
|
SourceDim int `mapstructure:"source_dim"`
|
||||||
|
TargetDim int `mapstructure:"target_dim"`
|
||||||
|
TruncateMode string `mapstructure:"truncate_mode"`
|
||||||
|
PadMode string `mapstructure:"pad_mode"`
|
||||||
|
Seed int64 `mapstructure:"seed"`
|
||||||
|
MatrixFile string `mapstructure:"matrix_file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// extensions viper will detect automatically.
|
||||||
|
var extensions = []string{"json", "yaml", "toml"}
|
||||||
|
|
||||||
|
// ResolveFile returns the config file path that would be used by Load.
|
||||||
|
// If cfgFile is non-empty it is returned as-is.
|
||||||
|
// Otherwise the default search paths are checked; if no existing file is found,
|
||||||
|
// the preferred default (~/.vecna.json) is returned so callers can create it.
|
||||||
|
func ResolveFile(cfgFile string) string {
|
||||||
|
if cfgFile != "" {
|
||||||
|
return cfgFile
|
||||||
|
}
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
dirs := []string{".", home, home + "/.config/vecna"}
|
||||||
|
for _, dir := range dirs {
|
||||||
|
for _, ext := range extensions {
|
||||||
|
path := dir + "/vecna." + ext
|
||||||
|
if _, err := os.Stat(path); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return home + "/vecna.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads configuration from the given file path (empty = search defaults),
|
||||||
|
// environment variables (prefix VECNA_), and applies built-in defaults.
|
||||||
|
func Load(cfgFile string) (*Config, error) {
|
||||||
|
v := viper.New()
|
||||||
|
|
||||||
|
// Defaults
|
||||||
|
v.SetDefault("server.port", 8080)
|
||||||
|
v.SetDefault("server.host", "0.0.0.0")
|
||||||
|
v.SetDefault("metrics.enabled", false)
|
||||||
|
v.SetDefault("metrics.path", "/metrics")
|
||||||
|
v.SetDefault("adapter.type", "truncate")
|
||||||
|
v.SetDefault("adapter.truncate_mode", "from_end")
|
||||||
|
v.SetDefault("adapter.pad_mode", "at_end")
|
||||||
|
|
||||||
|
// Environment
|
||||||
|
v.SetEnvPrefix("VECNA")
|
||||||
|
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||||
|
v.AutomaticEnv()
|
||||||
|
|
||||||
|
// Config file
|
||||||
|
if cfgFile != "" {
|
||||||
|
v.SetConfigFile(cfgFile)
|
||||||
|
} else {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
v.SetConfigName("vecna")
|
||||||
|
// No SetConfigType — viper detects format from file extension (json, yaml, toml, etc.)
|
||||||
|
v.AddConfigPath(".")
|
||||||
|
v.AddConfigPath(home)
|
||||||
|
v.AddConfigPath(home + "/.config/vecna")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.ReadInConfig(); err != nil {
|
||||||
|
// Missing config file is acceptable when all required values come from flags/env
|
||||||
|
var notFound viper.ConfigFileNotFoundError
|
||||||
|
if !errors.As(err, ¬Found) {
|
||||||
|
return nil, fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := v.Unmarshal(&cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
applyForwardDefaults(&cfg)
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyForwardDefaults fills in zero-value fields on ForwardTarget entries.
|
||||||
|
func applyForwardDefaults(cfg *Config) {
|
||||||
|
for name, t := range cfg.Forward.Targets {
|
||||||
|
if t.TimeoutSecs == 0 {
|
||||||
|
t.TimeoutSecs = 30
|
||||||
|
}
|
||||||
|
if t.CooldownSecs == 0 {
|
||||||
|
t.CooldownSecs = 60
|
||||||
|
}
|
||||||
|
if t.PriorityDecay == 0 {
|
||||||
|
t.PriorityDecay = 2
|
||||||
|
}
|
||||||
|
if t.PriorityRecovery == 0 {
|
||||||
|
t.PriorityRecovery = 5
|
||||||
|
}
|
||||||
|
for i, ep := range t.Endpoints {
|
||||||
|
if ep.Priority == 0 {
|
||||||
|
t.Endpoints[i].Priority = 10
|
||||||
|
}
|
||||||
|
if ep.APIKey == "" && t.APIKey != "" {
|
||||||
|
t.Endpoints[i].APIKey = t.APIKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg.Forward.Targets[name] = t
|
||||||
|
}
|
||||||
|
}
|
||||||
137
pkg/config/update.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SaveTarget adds or replaces a named ForwardTarget in the config file at path.
|
||||||
|
// Only JSON config files are written in-place. For other formats an error is
|
||||||
|
// returned describing what to add manually.
|
||||||
|
func SaveTarget(path, name string, target ForwardTarget) error {
|
||||||
|
ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ext {
|
||||||
|
case "json":
|
||||||
|
return saveTargetJSON(path, data, name, target)
|
||||||
|
default:
|
||||||
|
snippet, _ := json.MarshalIndent(map[string]ForwardTarget{name: target}, "", " ")
|
||||||
|
return fmt.Errorf(
|
||||||
|
"auto-update not supported for .%s files\n"+
|
||||||
|
"Add the following to the 'forward.targets' section of %s:\n\n%s",
|
||||||
|
ext, path, snippet,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveBrokenEndpoints removes failing endpoints from the config file.
|
||||||
|
// broken maps target name → set of failing endpoint URLs.
|
||||||
|
// If all endpoints of a target are removed, the target itself is deleted.
|
||||||
|
// Returns the list of removed items as human-readable strings.
|
||||||
|
func RemoveBrokenEndpoints(path string, broken map[string][]string) ([]string, error) {
|
||||||
|
ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
|
||||||
|
if ext != "json" {
|
||||||
|
return nil, fmt.Errorf("auto-update not supported for .%s files; edit %s manually", ext, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var removed []string
|
||||||
|
|
||||||
|
for targetName, failedURLs := range broken {
|
||||||
|
target, ok := cfg.Forward.Targets[targetName]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
failSet := make(map[string]bool, len(failedURLs))
|
||||||
|
for _, u := range failedURLs {
|
||||||
|
failSet[u] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
kept := target.Endpoints[:0]
|
||||||
|
for _, ep := range target.Endpoints {
|
||||||
|
if failSet[ep.URL] {
|
||||||
|
removed = append(removed, fmt.Sprintf("endpoint %s (target %q)", ep.URL, targetName))
|
||||||
|
} else {
|
||||||
|
kept = append(kept, ep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(kept) == 0 {
|
||||||
|
delete(cfg.Forward.Targets, targetName)
|
||||||
|
removed = append(removed, fmt.Sprintf("target %q (all endpoints failed)", targetName))
|
||||||
|
if cfg.Forward.Default == targetName {
|
||||||
|
cfg.Forward.Default = ""
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
target.Endpoints = kept
|
||||||
|
cfg.Forward.Targets[targetName] = target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(removed) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := json.MarshalIndent(cfg, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal config: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, append(out, '\n'), 0o600); err != nil {
|
||||||
|
return nil, fmt.Errorf("write %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return removed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteConfig serialises cfg as indented JSON and atomically overwrites path.
|
||||||
|
func WriteConfig(path string, cfg Config) error {
|
||||||
|
out, err := json.MarshalIndent(cfg, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal config: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, append(out, '\n'), 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveTargetJSON(path string, data []byte, name string, target ForwardTarget) error {
|
||||||
|
var cfg Config
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return fmt.Errorf("parse %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Forward.Targets == nil {
|
||||||
|
cfg.Forward.Targets = make(map[string]ForwardTarget)
|
||||||
|
}
|
||||||
|
cfg.Forward.Targets[name] = target
|
||||||
|
if cfg.Forward.Default == "" {
|
||||||
|
cfg.Forward.Default = name
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := json.MarshalIndent(cfg, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal config: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, append(out, '\n'), 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
206
pkg/discovery/discovery.go
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
package discovery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Kind describes a known server type.
|
||||||
|
type Kind struct {
|
||||||
|
Name string
|
||||||
|
APIType string
|
||||||
|
Port int
|
||||||
|
NeedsKey bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Found is a server discovered on the network.
|
||||||
|
type Found struct {
|
||||||
|
Kind Kind
|
||||||
|
BaseURL string
|
||||||
|
Models []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// knownServers lists server types by their default port and display name.
|
||||||
|
// Ollama is listed separately because it uses a non-OpenAI probe endpoint.
|
||||||
|
var knownServers = []Kind{
|
||||||
|
{Name: "Ollama", APIType: "openai", Port: 11434},
|
||||||
|
{Name: "LM Studio", APIType: "openai", Port: 1234},
|
||||||
|
{Name: "vLLM", APIType: "openai", Port: 8000, NeedsKey: true},
|
||||||
|
{Name: "LocalAI", APIType: "openai", Port: 8080},
|
||||||
|
{Name: "Jan", APIType: "openai", Port: 1337},
|
||||||
|
{Name: "Kobold", APIType: "openai", Port: 5001},
|
||||||
|
{Name: "Tabby", APIType: "openai", Port: 9090},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan concurrently probes localhost and LAN gateway addresses for known LLM servers.
|
||||||
|
// Results are returned in the order they are found (non-deterministic).
|
||||||
|
func Scan(ctx context.Context) []Found {
|
||||||
|
hosts := localHosts()
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
results []Found
|
||||||
|
wg sync.WaitGroup
|
||||||
|
)
|
||||||
|
|
||||||
|
probeCtx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpClient := &http.Client{Timeout: 600 * time.Millisecond}
|
||||||
|
|
||||||
|
for _, host := range hosts {
|
||||||
|
for _, kind := range knownServers {
|
||||||
|
host := host
|
||||||
|
wg.Add(1)
|
||||||
|
go func(kind Kind) {
|
||||||
|
defer wg.Done()
|
||||||
|
baseURL := fmt.Sprintf("http://%s:%d", host, kind.Port)
|
||||||
|
models, err := probe(probeCtx, httpClient, baseURL, kind)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
results = append(results, Found{Kind: kind, BaseURL: baseURL, Models: models})
|
||||||
|
mu.Unlock()
|
||||||
|
}(kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// Models fetches the model list from a single base URL and kind (for the models command).
|
||||||
|
func Models(ctx context.Context, baseURL string, kind Kind) ([]string, error) {
|
||||||
|
httpClient := &http.Client{Timeout: 5 * time.Second}
|
||||||
|
return probe(ctx, httpClient, baseURL, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// localHosts returns localhost plus the .1 gateway of every local IPv4 subnet.
|
||||||
|
func localHosts() []string {
|
||||||
|
seen := map[string]bool{"127.0.0.1": true}
|
||||||
|
hosts := []string{"127.0.0.1"}
|
||||||
|
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipnet, ok := addr.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip := ipnet.IP.To4()
|
||||||
|
if ip == nil || ip.IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// derive the likely gateway: network address + 1
|
||||||
|
gw := ip.Mask(ipnet.Mask)
|
||||||
|
gw[3] = 1
|
||||||
|
h := gw.String()
|
||||||
|
if !seen[h] {
|
||||||
|
seen[h] = true
|
||||||
|
hosts = append(hosts, h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hosts
|
||||||
|
}
|
||||||
|
|
||||||
|
// probe attempts to identify the server at baseURL and returns its model list.
|
||||||
|
func probe(ctx context.Context, client *http.Client, baseURL string, kind Kind) ([]string, error) {
|
||||||
|
// Ollama has its own endpoint; everything else is OpenAI-compatible
|
||||||
|
if kind.Name == "Ollama" {
|
||||||
|
models, err := probeOllama(ctx, client, baseURL)
|
||||||
|
if err != nil {
|
||||||
|
// Ollama also exposes /v1/models since v0.1.27 — fall back
|
||||||
|
return probeOpenAI(ctx, client, baseURL)
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
return probeOpenAI(ctx, client, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Ollama ---
|
||||||
|
|
||||||
|
type ollamaTagsResponse struct {
|
||||||
|
Models []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
} `json:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func probeOllama(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/api/tags", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ollamaTagsResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode ollama response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]string, len(body.Models))
|
||||||
|
for i, m := range body.Models {
|
||||||
|
models[i] = m.Name
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- OpenAI-compatible ---
|
||||||
|
|
||||||
|
type openAIModelsResponse struct {
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func probeOpenAI(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/v1/models", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body openAIModelsResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode openai response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]string, len(body.Data))
|
||||||
|
for i, m := range body.Data {
|
||||||
|
models[i] = m.ID
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
27
pkg/embedclient/client.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package embedclient
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Request is a batch of texts to embed.
|
||||||
|
type Request struct {
|
||||||
|
Texts []string
|
||||||
|
Model string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage reports token consumption from the backing model.
|
||||||
|
type Usage struct {
|
||||||
|
PromptTokens int
|
||||||
|
TotalTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response holds the raw embeddings returned by the backing model.
|
||||||
|
type Response struct {
|
||||||
|
Embeddings [][]float32
|
||||||
|
Model string
|
||||||
|
Usage Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client sends text to a backing embedding model and returns raw vectors.
|
||||||
|
type Client interface {
|
||||||
|
Embed(ctx context.Context, req Request) (Response, error)
|
||||||
|
}
|
||||||
98
pkg/embedclient/google.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package embedclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type googleClient struct {
|
||||||
|
baseURL string
|
||||||
|
apiKey string
|
||||||
|
model string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGoogle returns a Client that speaks the Google Gemini batchEmbedContents API.
|
||||||
|
func NewGoogle(baseURL, apiKey, model string, httpClient *http.Client) Client {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &googleClient{baseURL: baseURL, apiKey: apiKey, model: model, httpClient: httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleBatchRequest struct {
|
||||||
|
Requests []googleEmbedRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleEmbedRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Content googleContent `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleContent struct {
|
||||||
|
Parts []googlePart `json:"parts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googlePart struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleBatchResponse struct {
|
||||||
|
Embeddings []struct {
|
||||||
|
Values []float32 `json:"values"`
|
||||||
|
} `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *googleClient) Embed(ctx context.Context, req Request) (Response, error) {
|
||||||
|
requests := make([]googleEmbedRequest, len(req.Texts))
|
||||||
|
for i, text := range req.Texts {
|
||||||
|
requests[i] = googleEmbedRequest{
|
||||||
|
Model: "models/" + c.model,
|
||||||
|
Content: googleContent{Parts: []googlePart{{Text: text}}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(googleBatchRequest{Requests: requests})
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, fmt.Errorf("google embed marshal: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/v1/models/%s:batchEmbedContents", c.baseURL, c.model)
|
||||||
|
if c.apiKey != "" {
|
||||||
|
url += "?key=" + c.apiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, fmt.Errorf("google embed request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, fmt.Errorf("google embed do: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return Response{}, fmt.Errorf("google embed: unexpected status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gResp googleBatchResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&gResp); err != nil {
|
||||||
|
return Response{}, fmt.Errorf("google embed decode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings := make([][]float32, len(gResp.Embeddings))
|
||||||
|
for i, e := range gResp.Embeddings {
|
||||||
|
embeddings[i] = e.Values
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response{
|
||||||
|
Embeddings: embeddings,
|
||||||
|
Model: c.model,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
87
pkg/embedclient/openai.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package embedclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIClient struct {
|
||||||
|
baseURL string
|
||||||
|
apiKey string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAI returns a Client that speaks the OpenAI embeddings API.
|
||||||
|
func NewOpenAI(baseURL, apiKey string, httpClient *http.Client) Client {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &openAIClient{baseURL: baseURL, apiKey: apiKey, httpClient: httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIEmbedRequest struct {
|
||||||
|
Input []string `json:"input"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIEmbedResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
} `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openAIClient) Embed(ctx context.Context, req Request) (Response, error) {
|
||||||
|
body, err := json.Marshal(openAIEmbedRequest{Input: req.Texts, Model: req.Model})
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, fmt.Errorf("openai embed marshal: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/embeddings", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, fmt.Errorf("openai embed request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
if c.apiKey != "" {
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, fmt.Errorf("openai embed do: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return Response{}, fmt.Errorf("openai embed: unexpected status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var oaiResp openAIEmbedResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
|
||||||
|
return Response{}, fmt.Errorf("openai embed decode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings := make([][]float32, len(oaiResp.Data))
|
||||||
|
for _, d := range oaiResp.Data {
|
||||||
|
embeddings[d.Index] = d.Embedding
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response{
|
||||||
|
Embeddings: embeddings,
|
||||||
|
Model: oaiResp.Model,
|
||||||
|
Usage: Usage{
|
||||||
|
PromptTokens: oaiResp.Usage.PromptTokens,
|
||||||
|
TotalTokens: oaiResp.Usage.TotalTokens,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
180
pkg/embedclient/router.go
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
package embedclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RouterConfig holds tuning parameters for a TargetRouter.
|
||||||
|
type RouterConfig struct {
|
||||||
|
TargetName string
|
||||||
|
TimeoutSecs int
|
||||||
|
CooldownSecs int
|
||||||
|
PriorityDecay int
|
||||||
|
PriorityRecovery int
|
||||||
|
}
|
||||||
|
|
||||||
|
type endpointSlot struct {
|
||||||
|
client Client
|
||||||
|
url string
|
||||||
|
initialPriority int
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
priority int
|
||||||
|
inflight int
|
||||||
|
successCount int
|
||||||
|
lastFail time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// TargetRouter implements Client by routing requests across multiple endpoint slots
|
||||||
|
// using a busyness-based priority algorithm.
|
||||||
|
type TargetRouter struct {
|
||||||
|
slots []*endpointSlot
|
||||||
|
cfg RouterConfig
|
||||||
|
metrics *metrics.Registry
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTargetRouter constructs a TargetRouter from a slice of (client, url, initialPriority) tuples.
|
||||||
|
func NewTargetRouter(slots []RouterSlot, cfg RouterConfig, reg *metrics.Registry) (*TargetRouter, error) {
|
||||||
|
if len(slots) == 0 {
|
||||||
|
return nil, fmt.Errorf("NewTargetRouter: at least one slot required")
|
||||||
|
}
|
||||||
|
es := make([]*endpointSlot, len(slots))
|
||||||
|
for i, s := range slots {
|
||||||
|
es[i] = &endpointSlot{
|
||||||
|
client: s.Client,
|
||||||
|
url: s.URL,
|
||||||
|
initialPriority: s.Priority,
|
||||||
|
priority: s.Priority,
|
||||||
|
}
|
||||||
|
if reg != nil {
|
||||||
|
reg.SetEndpointPriority(cfg.TargetName, s.URL, float64(s.Priority))
|
||||||
|
reg.SetEndpointInflight(cfg.TargetName, s.URL, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &TargetRouter{slots: es, cfg: cfg, metrics: reg}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterSlot is a single endpoint entry for NewTargetRouter.
|
||||||
|
type RouterSlot struct {
|
||||||
|
Client Client
|
||||||
|
URL string
|
||||||
|
Priority int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TargetRouter) Embed(ctx context.Context, req Request) (Response, error) {
|
||||||
|
slot := r.pick()
|
||||||
|
|
||||||
|
slot.mu.Lock()
|
||||||
|
slot.inflight++
|
||||||
|
if r.metrics != nil {
|
||||||
|
r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
|
||||||
|
}
|
||||||
|
slot.mu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
slot.mu.Lock()
|
||||||
|
slot.inflight--
|
||||||
|
if r.metrics != nil {
|
||||||
|
r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
|
||||||
|
}
|
||||||
|
slot.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
timeout := time.Duration(r.cfg.TimeoutSecs) * time.Second
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := slot.client.Embed(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
r.onFailure(slot, err)
|
||||||
|
return Response{}, fmt.Errorf("router embed [%s]: %w", slot.url, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.onSuccess(slot)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pick selects the best available slot.
|
||||||
|
func (r *TargetRouter) pick() *endpointSlot {
|
||||||
|
cooldown := time.Duration(r.cfg.CooldownSecs) * time.Second
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
var best *endpointSlot
|
||||||
|
bestScore := -1 << 30
|
||||||
|
|
||||||
|
for _, s := range r.slots {
|
||||||
|
s.mu.Lock()
|
||||||
|
inCooldown := !s.lastFail.IsZero() && now.Sub(s.lastFail) < cooldown
|
||||||
|
score := s.priority - s.inflight
|
||||||
|
lastFail := s.lastFail
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if inCooldown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if best == nil || score > bestScore {
|
||||||
|
best = s
|
||||||
|
bestScore = score
|
||||||
|
_ = lastFail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All in cooldown — fall back to oldest failure
|
||||||
|
if best == nil {
|
||||||
|
var oldest time.Time
|
||||||
|
for _, s := range r.slots {
|
||||||
|
s.mu.Lock()
|
||||||
|
lf := s.lastFail
|
||||||
|
s.mu.Unlock()
|
||||||
|
if best == nil || lf.Before(oldest) {
|
||||||
|
best = s
|
||||||
|
oldest = lf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return best
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TargetRouter) onSuccess(s *endpointSlot) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.successCount++
|
||||||
|
if r.cfg.PriorityRecovery > 0 && s.successCount%r.cfg.PriorityRecovery == 0 {
|
||||||
|
if s.priority < s.initialPriority {
|
||||||
|
s.priority++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.metrics != nil {
|
||||||
|
r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TargetRouter) onFailure(s *endpointSlot, err error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.lastFail = time.Now()
|
||||||
|
s.priority -= r.cfg.PriorityDecay
|
||||||
|
if s.priority < 1 {
|
||||||
|
s.priority = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
errType := "error"
|
||||||
|
if ctx := context.Background(); ctx.Err() != nil {
|
||||||
|
errType = "timeout"
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.metrics != nil {
|
||||||
|
r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
|
||||||
|
r.metrics.IncEndpointErrors(r.cfg.TargetName, s.url, errType)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = err
|
||||||
|
}
|
||||||
112
pkg/metrics/metrics.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Registry holds all vecna Prometheus metrics on a dedicated (non-global) registry.
|
||||||
|
type Registry struct {
|
||||||
|
reg *prometheus.Registry
|
||||||
|
|
||||||
|
RequestsTotal *prometheus.CounterVec
|
||||||
|
RequestDuration *prometheus.HistogramVec
|
||||||
|
ForwardDuration *prometheus.HistogramVec
|
||||||
|
TranslateDuration *prometheus.HistogramVec
|
||||||
|
EndpointPriority *prometheus.GaugeVec
|
||||||
|
EndpointInflight *prometheus.GaugeVec
|
||||||
|
EndpointErrorsTotal *prometheus.CounterVec
|
||||||
|
TokensTotal *prometheus.CounterVec
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates and registers all metrics on a fresh Prometheus registry.
|
||||||
|
func New() *Registry {
|
||||||
|
reg := prometheus.NewRegistry()
|
||||||
|
|
||||||
|
r := &Registry{
|
||||||
|
reg: reg,
|
||||||
|
|
||||||
|
RequestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "vecna_requests_total",
|
||||||
|
Help: "Total number of requests served by vecna.",
|
||||||
|
}, []string{"endpoint", "status"}),
|
||||||
|
|
||||||
|
RequestDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Name: "vecna_request_duration_seconds",
|
||||||
|
Help: "Total request wall-clock time.",
|
||||||
|
Buckets: prometheus.DefBuckets,
|
||||||
|
}, []string{"endpoint"}),
|
||||||
|
|
||||||
|
ForwardDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Name: "vecna_forward_duration_seconds",
|
||||||
|
Help: "Time spent waiting on the backing embedding model.",
|
||||||
|
Buckets: prometheus.DefBuckets,
|
||||||
|
}, []string{"target", "url"}),
|
||||||
|
|
||||||
|
TranslateDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Name: "vecna_translate_duration_seconds",
|
||||||
|
Help: "Time spent in the dimension adapter.",
|
||||||
|
Buckets: []float64{0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05},
|
||||||
|
}, []string{"adapter_type"}),
|
||||||
|
|
||||||
|
EndpointPriority: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||||
|
Name: "vecna_endpoint_priority",
|
||||||
|
Help: "Current dynamic routing priority for a forwarding endpoint.",
|
||||||
|
}, []string{"target", "url"}),
|
||||||
|
|
||||||
|
EndpointInflight: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||||
|
Name: "vecna_endpoint_inflight",
|
||||||
|
Help: "Number of active in-flight requests per forwarding endpoint.",
|
||||||
|
}, []string{"target", "url"}),
|
||||||
|
|
||||||
|
EndpointErrorsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "vecna_endpoint_errors_total",
|
||||||
|
Help: "Total forwarding errors per endpoint, labelled by error type.",
|
||||||
|
}, []string{"target", "url", "error"}),
|
||||||
|
|
||||||
|
TokensTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "vecna_tokens_total",
|
||||||
|
Help: "Tokens consumed by the backing embedding model, by target, model, and token type.",
|
||||||
|
}, []string{"target", "model", "token_type"}),
|
||||||
|
}
|
||||||
|
|
||||||
|
reg.MustRegister(
|
||||||
|
r.RequestsTotal,
|
||||||
|
r.RequestDuration,
|
||||||
|
r.ForwardDuration,
|
||||||
|
r.TranslateDuration,
|
||||||
|
r.EndpointPriority,
|
||||||
|
r.EndpointInflight,
|
||||||
|
r.EndpointErrorsTotal,
|
||||||
|
r.TokensTotal,
|
||||||
|
)
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prometheus returns the underlying registry for use with promhttp.HandlerFor.
|
||||||
|
func (r *Registry) Prometheus() *prometheus.Registry {
|
||||||
|
return r.reg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience setters used by the router.
|
||||||
|
|
||||||
|
func (r *Registry) SetEndpointPriority(target, url string, v float64) {
|
||||||
|
r.EndpointPriority.WithLabelValues(target, url).Set(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) SetEndpointInflight(target, url string, v float64) {
|
||||||
|
r.EndpointInflight.WithLabelValues(target, url).Set(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) IncEndpointErrors(target, url, errType string) {
|
||||||
|
r.EndpointErrorsTotal.WithLabelValues(target, url, errType).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) AddTokens(target, model string, promptTokens, totalTokens int) {
|
||||||
|
if promptTokens > 0 {
|
||||||
|
r.TokensTotal.WithLabelValues(target, model, "prompt").Add(float64(promptTokens))
|
||||||
|
}
|
||||||
|
if totalTokens > 0 {
|
||||||
|
r.TokensTotal.WithLabelValues(target, model, "total").Add(float64(totalTokens))
|
||||||
|
}
|
||||||
|
}
|
||||||
62
pkg/metrics/middleware.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/uptrace/bunrouter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Middleware returns a bunrouter middleware that records per-request Prometheus metrics.
|
||||||
|
// It reads timing from the RequestTrace stored in the context (set by server/trace.go).
|
||||||
|
// The trace target/url labels are optional; pass empty strings if not applicable.
|
||||||
|
func (r *Registry) Middleware(getTrace func(req bunrouter.Request) TraceSnapshot) bunrouter.MiddlewareFunc {
|
||||||
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
rw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||||
|
err := next(rw, req)
|
||||||
|
|
||||||
|
snap := getTrace(req)
|
||||||
|
endpoint := req.URL.Path
|
||||||
|
status := fmt.Sprintf("%d", rw.status)
|
||||||
|
|
||||||
|
r.RequestsTotal.WithLabelValues(endpoint, status).Inc()
|
||||||
|
r.RequestDuration.WithLabelValues(endpoint).Observe(snap.TotalSeconds)
|
||||||
|
if snap.ForwardTarget != "" {
|
||||||
|
r.ForwardDuration.WithLabelValues(snap.ForwardTarget, snap.ForwardURL).Observe(snap.ForwardSeconds)
|
||||||
|
}
|
||||||
|
if snap.AdapterType != "" {
|
||||||
|
r.TranslateDuration.WithLabelValues(snap.AdapterType).Observe(snap.TranslateSeconds)
|
||||||
|
}
|
||||||
|
if snap.PromptTokens > 0 || snap.TotalTokens > 0 {
|
||||||
|
r.AddTokens(snap.ForwardTarget, snap.ForwardModel, snap.PromptTokens, snap.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceSnapshot carries the timing and usage values the metrics middleware needs.
|
||||||
|
type TraceSnapshot struct {
|
||||||
|
TotalSeconds float64
|
||||||
|
ForwardSeconds float64
|
||||||
|
TranslateSeconds float64
|
||||||
|
ForwardTarget string
|
||||||
|
ForwardURL string
|
||||||
|
ForwardModel string
|
||||||
|
AdapterType string
|
||||||
|
PromptTokens int
|
||||||
|
TotalTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// statusWriter wraps http.ResponseWriter to capture the written status code.
|
||||||
|
type statusWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *statusWriter) WriteHeader(code int) {
|
||||||
|
sw.status = code
|
||||||
|
sw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
135
pkg/server/google.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/uptrace/bunrouter"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- single embedContent ---
|
||||||
|
|
||||||
|
type googleEmbedContentRequest struct {
|
||||||
|
Content googleContent `json:"content"`
|
||||||
|
TaskType string `json:"taskType,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleContent struct {
|
||||||
|
Parts []googlePart `json:"parts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googlePart struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleEmbedContentResponse struct {
|
||||||
|
Embedding googleEmbeddingValues `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleEmbeddingValues struct {
|
||||||
|
Values []float32 `json:"values"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
model := req.Param("model")
|
||||||
|
|
||||||
|
var body googleEmbedContentRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||||
|
return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||||
|
}
|
||||||
|
|
||||||
|
texts := make([]string, len(body.Content.Parts))
|
||||||
|
for i, p := range body.Content.Parts {
|
||||||
|
texts[i] = p.Text
|
||||||
|
}
|
||||||
|
|
||||||
|
client, targetName, targetURL := h.resolveClient(model)
|
||||||
|
trace := TraceFromContext(req.Context())
|
||||||
|
trace.ForwardTarget = targetName
|
||||||
|
trace.ForwardURL = targetURL
|
||||||
|
|
||||||
|
t0 := time.Now()
|
||||||
|
embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
|
||||||
|
trace.ForwardDuration = time.Since(t0)
|
||||||
|
if err != nil {
|
||||||
|
return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
trace.ForwardModel = embedResp.Model
|
||||||
|
trace.PromptTokens = embedResp.Usage.PromptTokens
|
||||||
|
trace.TotalTokens = embedResp.Usage.TotalTokens
|
||||||
|
|
||||||
|
t1 := time.Now()
|
||||||
|
var adapted []float32
|
||||||
|
if len(embedResp.Embeddings) > 0 {
|
||||||
|
adapted, err = h.adapter.Adapt(embedResp.Embeddings[0])
|
||||||
|
if err != nil {
|
||||||
|
return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
trace.TranslateDuration = time.Since(t1)
|
||||||
|
|
||||||
|
writeTraceHeaders(w, trace)
|
||||||
|
|
||||||
|
return writeJSON(w, http.StatusOK, googleEmbedContentResponse{
|
||||||
|
Embedding: googleEmbeddingValues{Values: adapted},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- batch batchEmbedContents ---
|
||||||
|
|
||||||
|
type googleBatchRequest struct {
|
||||||
|
Requests []googleEmbedContentRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type googleBatchResponse struct {
|
||||||
|
Embeddings []googleEmbeddingValues `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
model := req.Param("model")
|
||||||
|
|
||||||
|
var body googleBatchRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||||
|
return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||||
|
}
|
||||||
|
|
||||||
|
var texts []string
|
||||||
|
for _, r := range body.Requests {
|
||||||
|
for _, p := range r.Content.Parts {
|
||||||
|
texts = append(texts, p.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client, targetName, targetURL := h.resolveClient(model)
|
||||||
|
trace := TraceFromContext(req.Context())
|
||||||
|
trace.ForwardTarget = targetName
|
||||||
|
trace.ForwardURL = targetURL
|
||||||
|
|
||||||
|
t0 := time.Now()
|
||||||
|
embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
|
||||||
|
trace.ForwardDuration = time.Since(t0)
|
||||||
|
if err != nil {
|
||||||
|
return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
trace.ForwardModel = embedResp.Model
|
||||||
|
trace.PromptTokens = embedResp.Usage.PromptTokens
|
||||||
|
trace.TotalTokens = embedResp.Usage.TotalTokens
|
||||||
|
|
||||||
|
t1 := time.Now()
|
||||||
|
result := make([]googleEmbeddingValues, len(embedResp.Embeddings))
|
||||||
|
for i, vec := range embedResp.Embeddings {
|
||||||
|
adapted, adaptErr := h.adapter.Adapt(vec)
|
||||||
|
if adaptErr != nil {
|
||||||
|
return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
|
||||||
|
}
|
||||||
|
result[i] = googleEmbeddingValues{Values: adapted}
|
||||||
|
}
|
||||||
|
trace.TranslateDuration = time.Since(t1)
|
||||||
|
|
||||||
|
writeTraceHeaders(w, trace)
|
||||||
|
|
||||||
|
return writeJSON(w, http.StatusOK, googleBatchResponse{Embeddings: result})
|
||||||
|
}
|
||||||
74
pkg/server/handler.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/adapter"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handler holds shared dependencies for all HTTP handlers.
|
||||||
|
type handler struct {
|
||||||
|
cfg *config.Config
|
||||||
|
clients map[string]embedclient.Client
|
||||||
|
adapter adapter.Adapter
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveClient selects the embed client for the given model name.
|
||||||
|
// Returns the client, target name, and first endpoint URL for tracing.
|
||||||
|
func (h *handler) resolveClient(model string) (embedclient.Client, string, string) {
|
||||||
|
if c, ok := h.clients[model]; ok {
|
||||||
|
url := firstEndpointURL(h.cfg, model)
|
||||||
|
return c, model, url
|
||||||
|
}
|
||||||
|
name := h.cfg.Forward.Default
|
||||||
|
c, ok := h.clients[name]
|
||||||
|
if !ok {
|
||||||
|
// No configured client — return a nil-safe error client
|
||||||
|
return &errClient{err: fmt.Errorf("no client configured for model %q and no default", model)}, name, ""
|
||||||
|
}
|
||||||
|
return c, name, firstEndpointURL(h.cfg, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstEndpointURL(cfg *config.Config, targetName string) string {
|
||||||
|
t, ok := cfg.Forward.Targets[targetName]
|
||||||
|
if !ok || len(t.Endpoints) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return t.Endpoints[0].URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeJSON encodes v as JSON and writes it with the given status code.
|
||||||
|
func writeJSON(w http.ResponseWriter, status int, v interface{}) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||||
|
return fmt.Errorf("writeJSON: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTraceHeaders writes X-Vecna-* timing headers from the RequestTrace.
|
||||||
|
func writeTraceHeaders(w http.ResponseWriter, t *RequestTrace) {
|
||||||
|
total := time.Since(t.Start)
|
||||||
|
w.Header().Set("X-Vecna-Forward-Ms", fmt.Sprintf("%d", t.ForwardDuration.Milliseconds()))
|
||||||
|
w.Header().Set("X-Vecna-Translate-Ms", fmt.Sprintf("%d", t.TranslateDuration.Milliseconds()))
|
||||||
|
w.Header().Set("X-Vecna-Total-Ms", fmt.Sprintf("%d", total.Milliseconds()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// errClient is a Client that always returns a fixed error (used as safe fallback).
|
||||||
|
type errClient struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errClient) Embed(_ context.Context, _ embedclient.Request) (embedclient.Response, error) {
|
||||||
|
return embedclient.Response{}, e.err
|
||||||
|
}
|
||||||
102
pkg/server/openai.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/uptrace/bunrouter"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIEmbedRequest struct {
|
||||||
|
Input interface{} `json:"input"` // string or []string
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIEmbedResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []openAIEmbedDatum `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage openAIUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIEmbedDatum struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) openAIEmbeddings(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
var body openAIEmbedRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
||||||
|
return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
|
||||||
|
}
|
||||||
|
|
||||||
|
texts, err := toStringSlice(body.Input)
|
||||||
|
if err != nil {
|
||||||
|
return writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
client, targetName, targetURL := h.resolveClient(body.Model)
|
||||||
|
trace := TraceFromContext(req.Context())
|
||||||
|
trace.ForwardTarget = targetName
|
||||||
|
trace.ForwardURL = targetURL
|
||||||
|
|
||||||
|
t0 := time.Now()
|
||||||
|
embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: body.Model})
|
||||||
|
trace.ForwardDuration = time.Since(t0)
|
||||||
|
if err != nil {
|
||||||
|
return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
|
||||||
|
}
|
||||||
|
trace.ForwardModel = embedResp.Model
|
||||||
|
trace.PromptTokens = embedResp.Usage.PromptTokens
|
||||||
|
trace.TotalTokens = embedResp.Usage.TotalTokens
|
||||||
|
|
||||||
|
t1 := time.Now()
|
||||||
|
data := make([]openAIEmbedDatum, len(embedResp.Embeddings))
|
||||||
|
for i, vec := range embedResp.Embeddings {
|
||||||
|
adapted, adaptErr := h.adapter.Adapt(vec)
|
||||||
|
if adaptErr != nil {
|
||||||
|
return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
|
||||||
|
}
|
||||||
|
data[i] = openAIEmbedDatum{Object: "embedding", Embedding: adapted, Index: i}
|
||||||
|
}
|
||||||
|
trace.TranslateDuration = time.Since(t1)
|
||||||
|
|
||||||
|
writeTraceHeaders(w, trace)
|
||||||
|
|
||||||
|
return writeJSON(w, http.StatusOK, openAIEmbedResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: data,
|
||||||
|
Model: embedResp.Model,
|
||||||
|
Usage: openAIUsage{PromptTokens: embedResp.Usage.PromptTokens, TotalTokens: embedResp.Usage.TotalTokens},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// toStringSlice accepts a JSON string or array of strings.
|
||||||
|
func toStringSlice(v interface{}) ([]string, error) {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return []string{val}, nil
|
||||||
|
case []interface{}:
|
||||||
|
out := make([]string, len(val))
|
||||||
|
for i, item := range val {
|
||||||
|
s, ok := item.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("input array element %d is not a string", i)
|
||||||
|
}
|
||||||
|
out[i] = s
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("input must be a string or array of strings")
|
||||||
|
}
|
||||||
|
}
|
||||||
163
pkg/server/server.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
"github.com/uptrace/bunrouter"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/adapter"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/metrics"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/server/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New builds and returns a configured bunrouter.Router.
|
||||||
|
func New(
|
||||||
|
cfg *config.Config,
|
||||||
|
clients map[string]embedclient.Client,
|
||||||
|
adp adapter.Adapter,
|
||||||
|
reg *metrics.Registry,
|
||||||
|
logger *zap.Logger,
|
||||||
|
) *bunrouter.Router {
|
||||||
|
router := bunrouter.New(
|
||||||
|
bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)),
|
||||||
|
bunrouter.WithMiddleware(traceMiddleware()),
|
||||||
|
bunrouter.WithMiddleware(metricsMiddleware(reg, adp)),
|
||||||
|
bunrouter.WithMiddleware(loggingMiddleware(logger)),
|
||||||
|
)
|
||||||
|
|
||||||
|
h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger}
|
||||||
|
|
||||||
|
router.POST("/v1/embeddings", h.openAIEmbeddings)
|
||||||
|
router.POST("/v1/models/:model:embedContent", h.googleEmbedContent)
|
||||||
|
router.POST("/v1/models/:model:batchEmbedContents", h.googleBatchEmbedContents)
|
||||||
|
|
||||||
|
// OpenAPI spec + docs
|
||||||
|
router.GET("/openapi.yaml", spec.SpecHandler())
|
||||||
|
router.GET("/docs", spec.DocsHandler())
|
||||||
|
|
||||||
|
// Metrics — only when enabled
|
||||||
|
if cfg.Metrics.Enabled {
|
||||||
|
metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{})
|
||||||
|
path := cfg.Metrics.Path
|
||||||
|
if path == "" {
|
||||||
|
path = "/metrics"
|
||||||
|
}
|
||||||
|
if cfg.Metrics.APIKey != "" {
|
||||||
|
router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler))
|
||||||
|
} else {
|
||||||
|
router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
metricsHandler.ServeHTTP(w, req.Request)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
// authMiddleware rejects requests without a valid Bearer token when api_keys is configured.
|
||||||
|
func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc {
|
||||||
|
if len(apiKeys) == 0 {
|
||||||
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
|
||||||
|
}
|
||||||
|
keySet := make(map[string]struct{}, len(apiKeys))
|
||||||
|
for _, k := range apiKeys {
|
||||||
|
keySet[k] = struct{}{}
|
||||||
|
}
|
||||||
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||||
|
if _, ok := keySet[token]; !ok {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return next(w, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// traceMiddleware injects a *RequestTrace into every request context.
|
||||||
|
func traceMiddleware() bunrouter.MiddlewareFunc {
|
||||||
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
ctx := WithTrace(req.Context())
|
||||||
|
return next(w, req.WithContext(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// metricsMiddleware records Prometheus observations after the handler returns.
|
||||||
|
func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc {
|
||||||
|
if reg == nil {
|
||||||
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
|
||||||
|
}
|
||||||
|
adpType := fmt.Sprintf("%T", adp)
|
||||||
|
return reg.Middleware(func(req bunrouter.Request) metrics.TraceSnapshot {
|
||||||
|
t := TraceFromContext(req.Context())
|
||||||
|
total := time.Since(t.Start)
|
||||||
|
return metrics.TraceSnapshot{
|
||||||
|
TotalSeconds: total.Seconds(),
|
||||||
|
ForwardSeconds: t.ForwardDuration.Seconds(),
|
||||||
|
TranslateSeconds: t.TranslateDuration.Seconds(),
|
||||||
|
ForwardTarget: t.ForwardTarget,
|
||||||
|
ForwardURL: t.ForwardURL,
|
||||||
|
ForwardModel: t.ForwardModel,
|
||||||
|
AdapterType: adpType,
|
||||||
|
PromptTokens: t.PromptTokens,
|
||||||
|
TotalTokens: t.TotalTokens,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// loggingMiddleware logs method, path, status, and timing via zap.
|
||||||
|
func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc {
|
||||||
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||||
|
err := next(sw, req)
|
||||||
|
t := TraceFromContext(req.Context())
|
||||||
|
total := time.Since(t.Start)
|
||||||
|
|
||||||
|
logger.Info("request",
|
||||||
|
zap.String("method", req.Method),
|
||||||
|
zap.String("path", req.URL.Path),
|
||||||
|
zap.Int("status", sw.status),
|
||||||
|
zap.Int64("total_ms", total.Milliseconds()),
|
||||||
|
zap.Int64("forward_ms", t.ForwardDuration.Milliseconds()),
|
||||||
|
zap.Int64("translate_ms", t.TranslateDuration.Milliseconds()),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// metricsAuthHandler wraps a standard http.Handler with Bearer token auth.
|
||||||
|
func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||||
|
if token != apiKey {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.ServeHTTP(w, req.Request)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// statusWriter captures the HTTP status code written by a handler.
|
||||||
|
type statusWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *statusWriter) WriteHeader(code int) {
|
||||||
|
sw.status = code
|
||||||
|
sw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
36
pkg/server/spec/handler.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package spec
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/uptrace/bunrouter"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed openapi.yaml
|
||||||
|
var openapiYAML []byte
|
||||||
|
|
||||||
|
// SpecHandler serves the raw OpenAPI YAML spec.
|
||||||
|
func SpecHandler() bunrouter.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/yaml")
|
||||||
|
_, err := w.Write(openapiYAML)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DocsHandler serves the Scalar API reference UI.
|
||||||
|
func DocsHandler() bunrouter.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
_, err := w.Write([]byte(`<!doctype html>
|
||||||
|
<html>
|
||||||
|
<head><title>vecna API</title><meta charset="utf-8"/></head>
|
||||||
|
<body>
|
||||||
|
<script id="api-reference" data-url="/openapi.yaml"></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
|
||||||
|
</body>
|
||||||
|
</html>`))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
252
pkg/server/spec/openapi.yaml
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
openapi: "3.1.0"
|
||||||
|
info:
|
||||||
|
title: vecna Embedding Adapter
|
||||||
|
description: Proxies text to a backing embedding model and adapts the result vectors between dimensions.
|
||||||
|
version: "1.0.0"
|
||||||
|
|
||||||
|
servers:
|
||||||
|
- url: http://localhost:8080
|
||||||
|
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
|
||||||
|
components:
|
||||||
|
securitySchemes:
|
||||||
|
BearerAuth:
|
||||||
|
type: http
|
||||||
|
scheme: bearer
|
||||||
|
|
||||||
|
schemas:
|
||||||
|
Error:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
error:
|
||||||
|
type: string
|
||||||
|
|
||||||
|
OpenAIEmbedRequest:
|
||||||
|
type: object
|
||||||
|
required: [input, model]
|
||||||
|
properties:
|
||||||
|
input:
|
||||||
|
oneOf:
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
|
||||||
|
OpenAIEmbedResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
example: list
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
example: embedding
|
||||||
|
index:
|
||||||
|
type: integer
|
||||||
|
embedding:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: number
|
||||||
|
format: float
|
||||||
|
usage:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
prompt_tokens:
|
||||||
|
type: integer
|
||||||
|
total_tokens:
|
||||||
|
type: integer
|
||||||
|
|
||||||
|
GoogleEmbedContentRequest:
|
||||||
|
type: object
|
||||||
|
required: [content]
|
||||||
|
properties:
|
||||||
|
content:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
parts:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
text:
|
||||||
|
type: string
|
||||||
|
taskType:
|
||||||
|
type: string
|
||||||
|
|
||||||
|
GoogleEmbedContentResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
embedding:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
values:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: number
|
||||||
|
format: float
|
||||||
|
|
||||||
|
GoogleBatchRequest:
|
||||||
|
type: object
|
||||||
|
required: [requests]
|
||||||
|
properties:
|
||||||
|
requests:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/GoogleEmbedContentRequest'
|
||||||
|
|
||||||
|
GoogleBatchResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
embeddings:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
values:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: number
|
||||||
|
format: float
|
||||||
|
|
||||||
|
headers:
|
||||||
|
X-Vecna-Forward-Ms:
|
||||||
|
description: Time spent forwarding the request to the backing model (milliseconds).
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
X-Vecna-Translate-Ms:
|
||||||
|
description: Time spent in the dimension adapter (milliseconds).
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
X-Vecna-Total-Ms:
|
||||||
|
description: Total request wall-clock time (milliseconds).
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
|
||||||
|
paths:
|
||||||
|
/v1/embeddings:
|
||||||
|
post:
|
||||||
|
summary: OpenAI-compatible embeddings
|
||||||
|
operationId: openaiEmbeddings
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIEmbedRequest'
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Adapted embeddings
|
||||||
|
headers:
|
||||||
|
X-Vecna-Forward-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Forward-Ms'
|
||||||
|
X-Vecna-Translate-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Translate-Ms'
|
||||||
|
X-Vecna-Total-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Total-Ms'
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIEmbedResponse'
|
||||||
|
"400":
|
||||||
|
description: Bad request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Error'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
"502":
|
||||||
|
description: Backing model error
|
||||||
|
|
||||||
|
/v1/models/{model}:embedContent:
|
||||||
|
post:
|
||||||
|
summary: Google-compatible single embedContent
|
||||||
|
operationId: googleEmbedContent
|
||||||
|
parameters:
|
||||||
|
- name: model
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/GoogleEmbedContentRequest'
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Adapted embedding
|
||||||
|
headers:
|
||||||
|
X-Vecna-Forward-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Forward-Ms'
|
||||||
|
X-Vecna-Translate-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Translate-Ms'
|
||||||
|
X-Vecna-Total-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Total-Ms'
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/GoogleEmbedContentResponse'
|
||||||
|
"400":
|
||||||
|
description: Bad request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Error'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
"502":
|
||||||
|
description: Backing model error
|
||||||
|
|
||||||
|
/v1/models/{model}:batchEmbedContents:
|
||||||
|
post:
|
||||||
|
summary: Google-compatible batch batchEmbedContents
|
||||||
|
operationId: googleBatchEmbedContents
|
||||||
|
parameters:
|
||||||
|
- name: model
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/GoogleBatchRequest'
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Adapted embeddings
|
||||||
|
headers:
|
||||||
|
X-Vecna-Forward-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Forward-Ms'
|
||||||
|
X-Vecna-Translate-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Translate-Ms'
|
||||||
|
X-Vecna-Total-Ms:
|
||||||
|
$ref: '#/components/headers/X-Vecna-Total-Ms'
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/GoogleBatchResponse'
|
||||||
|
"400":
|
||||||
|
description: Bad request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Error'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
"502":
|
||||||
|
description: Backing model error
|
||||||
37
pkg/server/trace.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
const traceKey contextKey = iota
|
||||||
|
|
||||||
|
// RequestTrace holds per-request timing data populated by handlers and middleware.
|
||||||
|
type RequestTrace struct {
|
||||||
|
Start time.Time
|
||||||
|
ForwardDuration time.Duration
|
||||||
|
TranslateDuration time.Duration
|
||||||
|
ForwardTarget string
|
||||||
|
ForwardURL string
|
||||||
|
ForwardModel string
|
||||||
|
AdapterType string
|
||||||
|
PromptTokens int
|
||||||
|
TotalTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTrace injects a new *RequestTrace into ctx.
|
||||||
|
func WithTrace(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, traceKey, &RequestTrace{Start: time.Now()})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceFromContext retrieves the *RequestTrace from ctx.
|
||||||
|
// Returns a zero-value trace (non-nil) if none was set.
|
||||||
|
func TraceFromContext(ctx context.Context) *RequestTrace {
|
||||||
|
if t, ok := ctx.Value(traceKey).(*RequestTrace); ok && t != nil {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return &RequestTrace{Start: time.Now()}
|
||||||
|
}
|
||||||
9
prometheus.example.yml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
global:
|
||||||
|
scrape_interval: 15s
|
||||||
|
|
||||||
|
scrape_configs:
|
||||||
|
- job_name: vecna
|
||||||
|
static_configs:
|
||||||
|
- targets: ["vecna:8080"]
|
||||||
|
# bearer_token: sk-metrics-secret # uncomment if metrics.api_key is set
|
||||||
|
metrics_path: /metrics
|
||||||
330
tests/integration/embed_test.go
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/adapter"
|
||||||
|
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Environment variables that configure the integration tests:
|
||||||
|
//
|
||||||
|
// VECNA_TEST_URL base URL of the embedding server (required)
|
||||||
|
// VECNA_TEST_MODEL model name to request (required)
|
||||||
|
// VECNA_TEST_API_TYPE "openai" (default) or "google"
|
||||||
|
// VECNA_TEST_API_KEY bearer token, empty if not needed
|
||||||
|
//
|
||||||
|
// Example (Ollama):
|
||||||
|
//
|
||||||
|
// VECNA_TEST_URL=http://localhost:11434 VECNA_TEST_MODEL=nomic-embed-text \
|
||||||
|
// go test -tags integration ./tests/integration/
|
||||||
|
|
||||||
|
const testText = "The quick brown fox jumps over the lazy dog"
|
||||||
|
|
||||||
|
// cfg holds resolved test parameters.
|
||||||
|
type cfg struct {
|
||||||
|
url string
|
||||||
|
model string
|
||||||
|
apiType string
|
||||||
|
apiKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCfg(t *testing.T) cfg {
|
||||||
|
t.Helper()
|
||||||
|
url := os.Getenv("VECNA_TEST_URL")
|
||||||
|
if url == "" {
|
||||||
|
t.Skip("VECNA_TEST_URL not set — skipping integration tests")
|
||||||
|
}
|
||||||
|
model := os.Getenv("VECNA_TEST_MODEL")
|
||||||
|
if model == "" {
|
||||||
|
t.Skip("VECNA_TEST_MODEL not set — skipping integration tests")
|
||||||
|
}
|
||||||
|
apiType := os.Getenv("VECNA_TEST_API_TYPE")
|
||||||
|
if apiType == "" {
|
||||||
|
apiType = "openai"
|
||||||
|
}
|
||||||
|
return cfg{
|
||||||
|
url: url,
|
||||||
|
model: model,
|
||||||
|
apiType: apiType,
|
||||||
|
apiKey: os.Getenv("VECNA_TEST_API_KEY"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClient(c cfg) embedclient.Client {
|
||||||
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if c.apiType == "google" {
|
||||||
|
return embedclient.NewGoogle(c.url, c.apiKey, c.model, httpClient)
|
||||||
|
}
|
||||||
|
return embedclient.NewOpenAI(c.url, c.apiKey, httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
// embed fetches a single embedding vector for testText.
|
||||||
|
func embed(t *testing.T, client embedclient.Client, model string) []float32 {
|
||||||
|
t.Helper()
|
||||||
|
resp, err := client.Embed(context.Background(), embedclient.Request{
|
||||||
|
Texts: []string{testText},
|
||||||
|
Model: model,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "embedding request failed")
|
||||||
|
require.Len(t, resp.Embeddings, 1, "expected exactly one embedding in response")
|
||||||
|
require.NotEmpty(t, resp.Embeddings[0], "embedding vector is empty")
|
||||||
|
return resp.Embeddings[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func l2Norm(v []float32) float64 {
|
||||||
|
var sum float64
|
||||||
|
for _, x := range v {
|
||||||
|
sum += float64(x) * float64(x)
|
||||||
|
}
|
||||||
|
return math.Sqrt(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertUnitNorm checks the vector is approximately L2-normalised.
|
||||||
|
func assertUnitNorm(t *testing.T, v []float32) {
|
||||||
|
t.Helper()
|
||||||
|
norm := l2Norm(v)
|
||||||
|
assert.InDelta(t, 1.0, norm, 0.01, "expected unit L2 norm after adaptation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Tests ----------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestNativeDimension verifies the server returns a non-empty vector.
|
||||||
|
// This is the baseline; the native dimension is logged so it can be used
|
||||||
|
// as VECNA_TEST_SOURCE_DIM for the dimension tests below.
|
||||||
|
func TestNativeDimension(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
t.Logf("native dimension: %d", len(vec))
|
||||||
|
t.Logf("native L2 norm: %.6f", l2Norm(vec))
|
||||||
|
|
||||||
|
assert.Greater(t, len(vec), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDownscaleTruncate tests truncation to half the native dimension.
|
||||||
|
func TestDownscaleTruncate(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
tgtDim := srcDim / 2
|
||||||
|
if tgtDim == 0 {
|
||||||
|
t.Skipf("source dim %d too small to halve", srcDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
adp, err := adapter.NewTruncate(srcDim, tgtDim, adapter.TruncateFromEnd, adapter.PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim, "output dimension mismatch")
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
t.Logf("downscale truncate: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDownscaleTruncateFromStart tests keeping the last N dims.
|
||||||
|
func TestDownscaleTruncateFromStart(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
tgtDim := srcDim / 2
|
||||||
|
if tgtDim == 0 {
|
||||||
|
t.Skipf("source dim %d too small to halve", srcDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
adp, err := adapter.NewTruncate(srcDim, tgtDim, adapter.TruncateFromStart, adapter.PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim)
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
t.Logf("downscale truncate-from-start: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDownscaleRandom tests random projection to a lower dimension.
|
||||||
|
func TestDownscaleRandom(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
tgtDim := srcDim / 2
|
||||||
|
if tgtDim == 0 {
|
||||||
|
t.Skipf("source dim %d too small to halve", srcDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
adp, err := adapter.NewRandom(srcDim, tgtDim, 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim)
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
t.Logf("downscale random: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDownscaleToFixed tests truncation to a fixed well-known target (e.g. 768 → 256).
|
||||||
|
// Skips if the native dimension is not larger than the target.
|
||||||
|
func TestDownscaleToFixed(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
tgtDim := intEnv("VECNA_TEST_TARGET_DIM", 256)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
if srcDim <= tgtDim {
|
||||||
|
t.Skipf("native dim %d is not larger than target dim %d", srcDim, tgtDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
adp, err := adapter.NewTruncate(srcDim, tgtDim, adapter.TruncateFromEnd, adapter.PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim)
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
t.Logf("downscale to fixed: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpscalePadEnd tests zero-padding to double the native dimension.
|
||||||
|
func TestUpscalePadEnd(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
tgtDim := srcDim * 2
|
||||||
|
|
||||||
|
adp, err := adapter.NewTruncate(srcDim, tgtDim, adapter.TruncateFromEnd, adapter.PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim)
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
// The second half of the raw output (before normalisation) should have been zero-padded.
|
||||||
|
// After normalisation all values shrink but the last half should all be equal (zero → 0).
|
||||||
|
t.Logf("upscale pad-end: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpscalePadStart tests zero-padding prepended to the vector.
|
||||||
|
func TestUpscalePadStart(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
tgtDim := srcDim * 2
|
||||||
|
|
||||||
|
adp, err := adapter.NewTruncate(srcDim, tgtDim, adapter.TruncateFromEnd, adapter.PadAtStart)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim)
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
t.Logf("upscale pad-start: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpscaleRandom tests random projection to a higher dimension.
|
||||||
|
func TestUpscaleRandom(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
tgtDim := srcDim * 2
|
||||||
|
|
||||||
|
adp, err := adapter.NewRandom(srcDim, tgtDim, 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim)
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
t.Logf("upscale random: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpscaleToFixed tests upscaling to a fixed well-known target (e.g. 768 → 1536).
|
||||||
|
// Skips if the native dimension is already larger than or equal to the target.
|
||||||
|
func TestUpscaleToFixed(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
tgtDim := intEnv("VECNA_TEST_TARGET_DIM", 1536)
|
||||||
|
|
||||||
|
vec := embed(t, client, c.model)
|
||||||
|
srcDim := len(vec)
|
||||||
|
if srcDim >= tgtDim {
|
||||||
|
t.Skipf("native dim %d is not smaller than target dim %d", srcDim, tgtDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
adp, err := adapter.NewTruncate(srcDim, tgtDim, adapter.TruncateFromEnd, adapter.PadAtEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
out, err := adp.Adapt(vec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, out, tgtDim)
|
||||||
|
assertUnitNorm(t, out)
|
||||||
|
t.Logf("upscale to fixed: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoundtripConsistency embeds the same text twice and checks the vectors are identical.
|
||||||
|
func TestRoundtripConsistency(t *testing.T) {
|
||||||
|
c := loadCfg(t)
|
||||||
|
client := newClient(c)
|
||||||
|
|
||||||
|
v1 := embed(t, client, c.model)
|
||||||
|
v2 := embed(t, client, c.model)
|
||||||
|
|
||||||
|
require.Equal(t, len(v1), len(v2), "dimension mismatch between two identical requests")
|
||||||
|
|
||||||
|
var maxDiff float32
|
||||||
|
for i := range v1 {
|
||||||
|
d := v1[i] - v2[i]
|
||||||
|
if d < 0 {
|
||||||
|
d = -d
|
||||||
|
}
|
||||||
|
if d > maxDiff {
|
||||||
|
maxDiff = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Logf("max element-wise diff between two identical embeds: %e", maxDiff)
|
||||||
|
assert.Less(t, maxDiff, float32(1e-5), "embeddings for identical input should be deterministic")
|
||||||
|
}
|
||||||
|
|
||||||
|
// intEnv reads an integer from an env var, returning defaultVal if unset or invalid.
|
||||||
|
func intEnv(key string, defaultVal int) int {
|
||||||
|
if s := os.Getenv(key); s != "" {
|
||||||
|
if n, err := strconv.Atoi(s); err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||