feat: 🎉 Vectors na Vectors, the beginning

Translate 1536 <-> 768 , 3072 <-> 2048
This commit is contained in:
2026-04-11 18:05:05 +02:00
parent d98ea7c222
commit 4009a54e39
58 changed files with 5324 additions and 2 deletions

90
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,90 @@
name: Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
tag:
description: 'Tag to release (e.g. v1.2.3)'
required: true
permissions:
contents: write
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Test
run: go test ./...
- name: Lint
run: go vet ./...
release:
needs: test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Build release binaries
run: |
VERSION="${{ github.event.inputs.tag || github.ref_name }}"
mkdir -p dist
for target in "linux/amd64" "linux/arm64" "darwin/amd64" "darwin/arm64" "windows/amd64"; do
GOOS="${target%/*}"
GOARCH="${target#*/}"
EXT=""
[ "$GOOS" = "windows" ] && EXT=".exe"
BINARY="vecna-${GOOS}-${GOARCH}${EXT}"
CGO_ENABLED=0 GOOS="$GOOS" GOARCH="$GOARCH" go build \
-trimpath \
-ldflags "-s -w -X main.version=${VERSION}" \
-o "dist/${BINARY}" \
./cmd/vecna
if [ "$GOOS" = "windows" ]; then
zip -j "dist/vecna-${GOOS}-${GOARCH}.zip" "dist/${BINARY}"
else
tar -czf "dist/vecna-${GOOS}-${GOARCH}.tar.gz" -C dist "${BINARY}"
fi
echo "Built ${BINARY}"
done
- name: Generate checksums
run: |
cd dist
sha256sum *.tar.gz *.zip > checksums.txt
- name: Create release and upload assets
run: |
TAG="${{ github.event.inputs.tag || github.ref_name }}"
PREV_TAG=$(git tag --sort=-version:refname | grep -v "^${TAG}$" | head -1)
if [ -n "$PREV_TAG" ]; then
RANGE="${PREV_TAG}..${TAG}"
else
RANGE="HEAD~20..HEAD"
fi
NOTES=$(git log "$RANGE" --pretty=format:"- %s" --no-merges)
BODY="## What's changed"$'\n'"${NOTES}"
gh release create "${TAG}" \
--title "${TAG}" \
--notes "${BODY}" \
dist/*.tar.gz dist/*.zip dist/checksums.txt
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

2
.gitignore vendored
View File

@@ -30,3 +30,5 @@ go.work.sum
# Editor/IDE
# .idea/
# .vscode/
bin/

30
.golangci.yml Normal file
View File

@@ -0,0 +1,30 @@
version: "2"
formatters:
  enable:
    - goimports
  settings:
    goimports:
      local-prefixes:
        - github.com/Warky-Devs/vecna.git
linters:
  default: standard
  enable:
    - misspell
    - gocritic
    - noctx
    - bodyclose
    - errorlint
    - copyloopvar
    - durationcheck
  # v2 schema: per-linter settings live under linters.settings; the v1
  # top-level `linters-settings` key is not recognized by a version "2" config.
  settings:
    gocritic:
      disabled-checks:
        - commentedOutCode
        - hugeParam
issues:
  max-issues-per-linter: 0
  max-same-issues: 0

22
Dockerfile Normal file
View File

@@ -0,0 +1,22 @@
# ── build stage ───────────────────────────────────────────────────────────────
FROM golang:1.26-alpine AS builder
WORKDIR /src
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /vecna ./cmd/vecna
# ── runtime stage ──────────────────────────────────────────────────────────────
FROM alpine:3.21
RUN apk add --no-cache ca-certificates
COPY --from=builder /vecna /usr/local/bin/vecna
EXPOSE 8080
ENTRYPOINT ["vecna"]
CMD ["serve"]

63
Makefile Normal file
View File

@@ -0,0 +1,63 @@
BINARY := vecna
CMD := ./cmd/vecna
BUILD_DIR := ./bin
BUMP ?= patch
TEST_URL ?=
TEST_MODEL ?=
.PHONY: all build test lint fmt tidy clean release-version test-integration
all: tidy fmt lint test build
build:
mkdir -p $(BUILD_DIR)
go build -o $(BUILD_DIR)/$(BINARY) $(CMD)
test:
go test ./...
lint:
golangci-lint run ./...
fmt:
golangci-lint run --fix ./...
gofmt -w .
tidy:
go mod tidy
clean:
rm -rf $(BUILD_DIR)
# Run dimension integration tests against a live embedding server.
# Requires VECNA_TEST_URL and VECNA_TEST_MODEL to be set.
#
# Examples:
# make test-integration TEST_URL=http://localhost:11434 TEST_MODEL=nomic-embed-text
# make test-integration TEST_URL=http://localhost:11434 TEST_MODEL=nomic-embed-text VECNA_TEST_TARGET_DIM=256
#
test-integration:
@if [ -z "$(TEST_URL)" ]; then echo "ERROR: TEST_URL is required"; exit 1; fi
@if [ -z "$(TEST_MODEL)" ]; then echo "ERROR: TEST_MODEL is required"; exit 1; fi
VECNA_TEST_URL=$(TEST_URL) \
VECNA_TEST_MODEL=$(TEST_MODEL) \
VECNA_TEST_API_TYPE=$(or $(TEST_API_TYPE),openai) \
VECNA_TEST_API_KEY=$(TEST_API_KEY) \
go test -v -tags integration -timeout 120s ./tests/integration/
release-version:
@current=$$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0"); \
major=$$(echo $$current | sed 's/^v//' | cut -d. -f1); \
minor=$$(echo $$current | sed 's/^v//' | cut -d. -f2); \
patch=$$(echo $$current | sed 's/^v//' | cut -d. -f3); \
case "$(BUMP)" in \
major) major=$$((major+1)); minor=0; patch=0 ;; \
minor) minor=$$((minor+1)); patch=0 ;; \
patch) patch=$$((patch+1)) ;; \
*) echo "Unknown BUMP value '$(BUMP)'. Use: patch (default), minor, major"; exit 1 ;; \
esac; \
next="v$${major}.$${minor}.$${patch}"; \
echo "Tagging $$current → $$next"; \
git tag -a "$$next" -m "Release $$next"; \
git push origin "$$next"

361
README.md
View File

@@ -1,2 +1,359 @@
# vecna
Vecna - Vectors na Vectors, Translate 1536 <-> 768, 3072 <-> 2048
# vecna — Vectors na Vectors
**vecna** is an OpenAI- and Google-compatible HTTP proxy that sits between your application and a local or remote embedding model. It forwards text to the backing model, receives the raw embedding vector, and re-shapes it to the dimension your vector database or application expects — without any changes to your client code.
New to embeddings? → [docs/what_is_embeddings.md](docs/what_is_embeddings.md)
Every time I install a tool that needs vector embeddings, it assumes OpenAI's dimensions. There are very few local models that match those exact dimensions. Most tools don't give you an option to change them, and some vector databases are hardcoded to fixed dimensions entirely. Many open source models do stick to the same dimensions, which helps — but I run embedding models across several machines and it's always been a pain to use them effectively. This is where vecna helps me.
---
## Why
Most vector databases are initialized with a fixed dimension (e.g. 1536 for `pgvector` defaults, 768 for many HNSW indexes). If you switch embedding models — or run a smaller local model that produces 768-dim vectors — your existing index breaks.
vecna solves this by translating dimensions at the proxy layer:
- **Downscale**: 3072 → 1536, 768 → 256 (truncation or random projection)
- **Upscale**: 384 → 1536, 768 → 1536 (zero-padding)
- **Same-dim pass-through**: no transformation, just proxy and auth
All output vectors are L2-normalized so cosine similarity remains valid.
---
## Install
```sh
go install github.com/Warky-Devs/vecna.git/cmd/vecna@latest
```
Or build from source:
```sh
make build # outputs ./bin/vecna
```
---
## Quick start
```sh
# Interactive setup: discovers local servers, configures adapter, writes config
vecna onboard
# Start the proxy
vecna serve
# Test a request (OpenAI-compatible)
curl -s http://localhost:8080/v1/embeddings \
-H "Content-Type: application/json" \
-d '{"input": "hello world", "model": "nomic-embed-text"}'
```
---
## Commands
| Command | Description |
|---------|-------------|
| `vecna onboard` | Interactive wizard: discover servers → detect dims → configure → test → write config |
| `vecna serve` | Start the proxy server |
| `vecna query <text>` | Embed text and print the resulting vector as JSON |
| `vecna convert` | Convert vectors from file/stdin using the configured adapter |
| `vecna search` | Scan LAN for embedding servers and add one to config |
| `vecna models` | List models available on each configured forwarder |
| `vecna test` | Test each configured endpoint; `--remove-broken` prunes failing ones |
| `vecna editconfig` | Print config path and open it in `$EDITOR` |
---
## Query
Send text directly to a forwarding target and print the adapted vector as JSON.
```sh
# uses forward.default target
vecna query "hello world"
# specific target
vecna query --target ollama "hello world"
# skip the adapter — raw model output
vecna query --raw "hello world"
# compact single-line output (pipe-friendly)
vecna query --compact "hello world"
# read text from stdin
echo "hello world" | vecna query -
# inspect a single dimension
vecna query --compact "hello world" | jq '.[0]'
# save to file
vecna query --compact "hello world" > vector.json
```
Status info (target, model, dims, tokens) is written to stderr; stdout is clean JSON.
---
## Configuration
Default config path: `~/vecna.json` (created by `onboard` or `editconfig`).
Override with `--config path/to/file.yaml` or env vars prefixed `VECNA_`.
```json
{
"server": {
"port": 8080,
"host": "0.0.0.0",
"api_keys": ["sk-vecna-abc123"]
},
"forward": {
"default": "ollama",
"targets": {
"ollama": {
"api_type": "openai",
"model": "nomic-embed-text",
"api_key": "",
"timeout_secs": 30,
"cooldown_secs": 60,
"priority_decay": 2,
"priority_recovery": 5,
"endpoints": [
{"url": "http://localhost:11434", "priority": 10}
]
}
}
},
"adapter": {
"type": "truncate",
"source_dim": 768,
"target_dim": 1536,
"truncate_mode": "from_end",
"pad_mode": "at_end"
},
"metrics": {
"enabled": true,
"path": "/metrics",
"api_key": ""
}
}
```
---
## Important: quality and consistency warnings
### Upscaling reduces quality — use the highest-dimension model you can
Upscaling (e.g. 768 → 1536) **does not add information**. The extra dimensions are zeros (truncate adapter) or linear combinations of existing values (random/projection). The resulting vectors occupy a 1536-dim space but carry no more semantic content than the original 768-dim ones.
**The model's native output dimension is the ceiling of quality.** If your vector database requires 1536 dims, use a model that natively produces 1536 dims. Use vecna's upscaling only as a compatibility shim when you cannot change the index schema — not as a way to improve retrieval quality.
- Downscale (higher → lower): small, controlled quality loss. Acceptable for MRL models.
- Upscale (lower → higher): no quality gain, only compatibility. Replace the model when possible.
### Changing any adapter setting requires regenerating all stored embeddings
vecna's adapter is applied at query time **and** at indexing time. If you change any of the following, every vector already stored in your database is now in a different space and comparisons against new queries will silently return wrong results:
- `type` (truncate / random / projection)
- `source_dim` or `target_dim`
- `truncate_mode` (from_end / from_start)
- `pad_mode` (at_end / at_start)
- `seed` (random adapter)
- the backing model itself
**When you change adapter settings: stop ingestion, re-embed your entire corpus through vecna with the new settings, repopulate the index, then resume.**
There is no partial migration path — a mixed index produces degraded or incorrect search results.
---
## Adapter types
| Type | Description |
|------|-------------|
| `truncate` | Slice or zero-pad the vector. Fast, deterministic. Best for MRL-trained models. |
| `random` | Seeded Gaussian projection matrix. Preserves distances (Johnson-Lindenstrauss). |
| `projection` | Learned linear matrix from a JSON file. Highest quality, requires pre-training. |
---
## Truncation and padding modes
### `truncate_mode` — which part of the vector is kept when downscaling
| Value | Keeps |
|-------|-------|
| `from_end` *(default)* | first N dimensions |
| `from_start` | last N dimensions |
**`from_end`** — use for **Matryoshka Representation Learning (MRL)** models. The most important information is packed into the first dimensions.
Models: `nomic-embed-text`, `mxbai-embed-large`, `text-embedding-3-small`, `text-embedding-3-large`, `snowflake-arctic-embed`, `e5-mistral-7b-instruct`.
**`from_start`** — use when task-specific information is at the end of the vector. Try this if `from_end` gives poor retrieval on a non-MRL model.
Models: some fine-tuned BERT variants, domain-specific models with task heads appended after base dimensions.
### `pad_mode` — where zeros are inserted when upscaling
| Value | Zeros go |
|-------|----------|
| `at_end` *(default)* | after the real values |
| `at_start` | before the real values |
**`at_end`** — almost always correct. Keeps the original vector in the first N positions.
**`at_start`** — use if your index expects meaningful content at the end of the vector.
### Common combinations
| Scenario | `truncate_mode` | `pad_mode` |
|----------|----------------|------------|
| MRL model downscale | `from_end` | `at_end` |
| MRL model upscale (e.g. 768→1536) | `from_end` | `at_end` |
| Non-MRL BERT fine-tune | `from_start` | `at_end` |
| Custom index with leading-zeros convention | `from_end` | `at_start` |
When unsure, run `vecna test` before and after and compare the reported L2 norm.
---
## API endpoints
### OpenAI-compatible
```
POST /v1/embeddings
Authorization: Bearer <api_key>
Content-Type: application/json
{"input": "text or array of texts", "model": "nomic-embed-text"}
```
### Google Gemini-compatible
```
POST /v1/models/{model}:embedContent
POST /v1/models/{model}:batchEmbedContents
```
### OpenAPI spec and docs
```
GET /openapi.yaml
GET /docs
```
---
## Response tracing headers
| Header | Value |
|--------|-------|
| `X-Vecna-Forward-Ms` | Time waiting on the backing model |
| `X-Vecna-Translate-Ms` | Time in the adapter |
| `X-Vecna-Total-Ms` | Total request wall time |
---
## Prometheus metrics
Enable in config: `metrics.enabled: true`. Scrape at `GET /metrics`.
| Metric | Type | Description |
|--------|------|-------------|
| `vecna_requests_total` | counter | Requests served, by endpoint and status |
| `vecna_request_duration_seconds` | histogram | Total request wall time |
| `vecna_forward_duration_seconds` | histogram | Time waiting on the backing model |
| `vecna_translate_duration_seconds` | histogram | Time in the adapter |
| `vecna_endpoint_priority` | gauge | Current dynamic routing priority per endpoint |
| `vecna_endpoint_inflight` | gauge | Active in-flight requests per endpoint |
| `vecna_endpoint_errors_total` | counter | Forwarding failures by error type |
| `vecna_tokens_total` | counter | Tokens consumed, by target, model, and type (`prompt`/`total`) |
---
## Development
```sh
make build # compile
make test # unit tests
make lint # golangci-lint
make fmt # goimports + gofmt
# Integration tests against a live server
make test-integration TEST_URL=http://localhost:11434 TEST_MODEL=nomic-embed-text
# Tag and push a release
make release-version BUMP=patch # patch | minor | major
```
---
## Docker
### Build
```sh
docker build -t vecna .
```
### First-time setup with docker compose
```sh
cp docker-compose.example.yml docker-compose.yml
docker compose up -d
```
Starts vecna and an Ollama instance. The `vecna_config` named volume persists the config across container rebuilds.
### Onboard (interactive setup)
```sh
docker compose run --rm -it vecna onboard --config /config/vecna.json
```
Ollama is reachable by hostname on the Docker network — the scanner will find it automatically. After onboarding, restart the proxy:
```sh
docker compose restart vecna
```
### Query
```sh
docker compose run --rm vecna query --compact "hello world" --config /config/vecna.json
```
### Test endpoints
```sh
# report latency and dims
docker compose run --rm vecna test --config /config/vecna.json
# test and remove failing endpoints
docker compose run --rm vecna test --config /config/vecna.json --remove-broken
```
### Edit config manually
```sh
docker compose run --rm -it vecna sh -c "vi /config/vecna.json"
```
### With Prometheus
```sh
docker compose --profile metrics up -d
```
Scrape config is in `prometheus.example.yml`. Set `bearer_token` if `metrics.api_key` is configured.
---
© Hein Puth — [Warky Devs (Pty) Ltd](https://github.com/Warky-Devs)

160
cmd/vecna/convert.go Normal file
View File

@@ -0,0 +1,160 @@
package main
import (
"encoding/json"
"fmt"
"io"
"os"
"github.com/spf13/cobra"
"github.com/Warky-Devs/vecna.git/pkg/adapter"
"github.com/Warky-Devs/vecna.git/pkg/config"
)
var (
convertInput string
convertOutput string
)
var convertCmd = &cobra.Command{
Use: "convert",
Short: "Convert vectors from one dimension to another",
Long: "Reads a JSON array of float32 vectors, applies the configured adapter, and writes the result.",
RunE: runConvert,
}
func init() {
convertCmd.Flags().StringVarP(&convertInput, "input", "i", "-", "input file path (- for stdin)")
convertCmd.Flags().StringVarP(&convertOutput, "output", "o", "-", "output file path (- for stdout)")
}
// runConvert reads a JSON array of float32 vectors from the input, applies
// the configured dimension adapter to each vector in order, and writes the
// adapted vectors as indented JSON to the output. "-" selects stdin/stdout.
func runConvert(cmd *cobra.Command, _ []string) error {
	cfg, err := config.Load(cfgFile)
	if err != nil {
		return fmt.Errorf("load config: %w", err)
	}
	adp, err := buildAdapter(cfg)
	if err != nil {
		return fmt.Errorf("build adapter: %w", err)
	}
	in, err := openReader(convertInput)
	if err != nil {
		return fmt.Errorf("open input: %w", err)
	}
	// Only close real files; never close the process's stdin.
	if f, ok := in.(*os.File); ok && f != os.Stdin {
		defer func() { _ = f.Close() }()
	}
	out, err := openWriter(convertOutput)
	if err != nil {
		return fmt.Errorf("open output: %w", err)
	}
	// Only close real files; never close the process's stdout.
	if f, ok := out.(*os.File); ok && f != os.Stdout {
		defer func() { _ = f.Close() }()
	}
	var vecs [][]float32
	if err := json.NewDecoder(in).Decode(&vecs); err != nil {
		return fmt.Errorf("decode input: %w", err)
	}
	// Fail fast on the first vector the adapter rejects, identifying it by index.
	result := make([][]float32, len(vecs))
	for i, v := range vecs {
		adapted, adaptErr := adp.Adapt(v)
		if adaptErr != nil {
			return fmt.Errorf("adapt vector %d: %w", i, adaptErr)
		}
		result[i] = adapted
	}
	enc := json.NewEncoder(out)
	enc.SetIndent("", " ")
	if err := enc.Encode(result); err != nil {
		return fmt.Errorf("encode output: %w", err)
	}
	return nil
}
// openReader returns stdin for "-", otherwise opens the named file for reading.
func openReader(path string) (io.Reader, error) {
	if path != "-" {
		return os.Open(path)
	}
	return os.Stdin, nil
}
// openWriter returns stdout for "-", otherwise creates/truncates the named file.
func openWriter(path string) (io.Writer, error) {
	if path != "-" {
		return os.Create(path)
	}
	return os.Stdout, nil
}
// buildAdapter constructs the Adapter from the loaded config.
//
// Supported adapter types:
//   - "truncate":   slice/zero-pad between source_dim and target_dim,
//     controlled by truncate_mode and pad_mode
//   - "random":     seeded projection (ac.Seed)
//   - "projection": linear matrix loaded from matrix_file (required)
//
// Returns an error for an unknown type, invalid mode strings, or a missing
// matrix file.
func buildAdapter(cfg *config.Config) (adapter.Adapter, error) {
	ac := cfg.Adapter
	switch ac.Type {
	case "truncate":
		tm, pm, err := parseTruncateModes(ac.TruncateMode, ac.PadMode)
		if err != nil {
			return nil, err
		}
		return adapter.NewTruncate(ac.SourceDim, ac.TargetDim, tm, pm)
	case "random":
		return adapter.NewRandom(ac.SourceDim, ac.TargetDim, ac.Seed)
	case "projection":
		// The projection matrix must be supplied as a JSON file.
		if ac.MatrixFile == "" {
			return nil, fmt.Errorf("adapter type 'projection' requires matrix_file")
		}
		matrix, err := loadMatrix(ac.MatrixFile)
		if err != nil {
			return nil, fmt.Errorf("load projection matrix: %w", err)
		}
		return adapter.NewProjection(ac.SourceDim, ac.TargetDim, matrix)
	default:
		return nil, fmt.Errorf("unknown adapter type %q; valid: truncate, random, projection", ac.Type)
	}
}
// parseTruncateModes maps the config strings for truncate_mode and pad_mode
// to their adapter constants. Empty strings select the defaults (from_end,
// at_end); any other unrecognized value is an error.
func parseTruncateModes(truncMode, padMode string) (adapter.TruncateMode, adapter.PadMode, error) {
	tm := adapter.TruncateFromEnd
	switch truncMode {
	case "", "from_end":
		// default already set
	case "from_start":
		tm = adapter.TruncateFromStart
	default:
		return 0, 0, fmt.Errorf("unknown truncate_mode %q; valid: from_end, from_start", truncMode)
	}
	pm := adapter.PadAtEnd
	switch padMode {
	case "", "at_end":
		// default already set
	case "at_start":
		pm = adapter.PadAtStart
	default:
		return 0, 0, fmt.Errorf("unknown pad_mode %q; valid: at_end, at_start", padMode)
	}
	return tm, pm, nil
}
// loadMatrix reads a JSON-encoded [][]float32 projection matrix from path.
func loadMatrix(path string) ([][]float32, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer func() { _ = f.Close() }()

	var matrix [][]float32
	if decErr := json.NewDecoder(f).Decode(&matrix); decErr != nil {
		return nil, fmt.Errorf("decode matrix JSON: %w", decErr)
	}
	return matrix, nil
}

105
cmd/vecna/editconfig.go Normal file
View File

@@ -0,0 +1,105 @@
package main
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/spf13/cobra"
"github.com/Warky-Devs/vecna.git/pkg/config"
)
// editConfigCmd prints the resolved config path and opens it in the user's
// editor, creating a default config first when none exists.
var editConfigCmd = &cobra.Command{
	Use:   "editconfig",
	Short: "Open the vecna config file in your editor",
	RunE:  runEditConfig,
}

func init() {
	// Self-register on the root command.
	rootCmd.AddCommand(editConfigCmd)
}
// runEditConfig resolves the config path (creating a default skeleton when
// the file does not exist yet), prints the path, and launches the user's
// editor on it, wiring the editor to this process's stdio so interactive
// terminal editors work.
func runEditConfig(cmd *cobra.Command, _ []string) error {
	path := config.ResolveFile(cfgFile)
	if _, err := os.Stat(path); os.IsNotExist(err) {
		if err := createDefaultConfig(path); err != nil {
			return fmt.Errorf("create default config: %w", err)
		}
	}
	// Print the resolved path so the user knows which file is being edited.
	fmt.Println(path)
	editor := resolveEditor()
	c := exec.CommandContext(context.Background(), editor, path)
	c.Stdin = os.Stdin
	c.Stdout = os.Stdout
	c.Stderr = os.Stderr
	if err := c.Run(); err != nil {
		return fmt.Errorf("editor %q exited with error: %w", editor, err)
	}
	return nil
}
// resolveEditor picks the editor command to launch: $EDITOR when set,
// otherwise the first of nvim or nano found on PATH, falling back to "nano".
func resolveEditor() string {
	if editor := os.Getenv("EDITOR"); editor != "" {
		return editor
	}
	if found, err := exec.LookPath("nvim"); err == nil {
		return found
	}
	if found, err := exec.LookPath("nano"); err == nil {
		return found
	}
	return "nano"
}
// createDefaultConfig writes a minimal JSON config skeleton to path.
// The skeleton serves on 0.0.0.0:8080 with no inbound auth and metrics
// disabled, forwards to the OpenAI embeddings API, and downscales
// 1536 -> 768 with the truncate adapter.
func createDefaultConfig(path string) error {
	skeleton := config.Config{
		Server: config.ServerConfig{
			Port:    8080,
			Host:    "0.0.0.0",
			APIKeys: []string{},
		},
		Metrics: config.MetricsConfig{
			Enabled: false,
			Path:    "/metrics",
		},
		Forward: config.ForwardConfig{
			Default: "default",
			Targets: map[string]config.ForwardTarget{
				"default": {
					APIType: "openai",
					Model:   "text-embedding-3-small",
					Endpoints: []config.EndpointConfig{
						{URL: "https://api.openai.com", Priority: 10},
					},
				},
			},
		},
		Adapter: config.AdapterConfig{
			Type:         "truncate",
			SourceDim:    1536,
			TargetDim:    768,
			TruncateMode: "from_end",
			PadMode:      "at_end",
		},
	}
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	defer func() { _ = f.Close() }()
	enc := json.NewEncoder(f)
	// Indent so the generated file is comfortable to hand-edit.
	enc.SetIndent("", " ")
	if err := enc.Encode(skeleton); err != nil {
		return fmt.Errorf("write default config: %w", err)
	}
	return nil
}

43
cmd/vecna/main.go Normal file
View File

@@ -0,0 +1,43 @@
package main
import (
"os"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
var (
cfgFile string
logLevel string
logger *zap.Logger
version = "dev"
)
var rootCmd = &cobra.Command{
Use: "vecna",
Short: "Embedding dimension adapter — translate vectors between model spaces",
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
var err error
if logLevel == "debug" {
logger, err = zap.NewDevelopment()
} else {
logger, err = zap.NewProduction()
}
return err
},
}
func init() {
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default: ./vecna.yaml)")
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "log level: info | debug")
rootCmd.AddCommand(convertCmd)
rootCmd.AddCommand(serveCmd)
rootCmd.AddCommand(versionCmd)
}
func main() {
if err := rootCmd.Execute(); err != nil {
os.Exit(1)
}
}

91
cmd/vecna/models.go Normal file
View File

@@ -0,0 +1,91 @@
package main
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/spf13/cobra"

	"github.com/Warky-Devs/vecna.git/pkg/config"
	"github.com/Warky-Devs/vecna.git/pkg/discovery"
)
// modelsCmd inspects what each configured forwarder actually serves.
var modelsCmd = &cobra.Command{
	Use:   "models",
	Short: "List models available on each configured forwarder",
	RunE:  runModels,
}

func init() {
	// Self-register on the root command.
	rootCmd.AddCommand(modelsCmd)
}
// runModels lists the models reported by every endpoint of every configured
// forwarding target, marking the configured model with "*" and warning when
// the configured model is missing from an endpoint's list.
func runModels(_ *cobra.Command, _ []string) error {
	cfg, err := config.Load(cfgFile)
	if err != nil {
		return fmt.Errorf("load config: %w", err)
	}
	if len(cfg.Forward.Targets) == 0 {
		fmt.Println("No forwarder targets configured.")
		return nil
	}
	ctx := context.Background()
	// Map iteration order is random in Go; sort target names so output is
	// stable across runs.
	names := make([]string, 0, len(cfg.Forward.Targets))
	for name := range cfg.Forward.Targets {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, targetName := range names {
		target := cfg.Forward.Targets[targetName]
		fmt.Printf("[ %s ]\n", targetName)
		for _, ep := range target.Endpoints {
			kind := discovery.Kind{
				Name:    targetName,
				APIType: target.APIType,
				Port:    0, // not used by Models()
			}
			models, err := discovery.Models(ctx, ep.URL, kind)
			if err != nil {
				fmt.Printf(" %s error: %s\n\n", ep.URL, err)
				continue
			}
			if len(models) == 0 {
				fmt.Printf(" %s (no models listed)\n\n", ep.URL)
				continue
			}
			fmt.Printf(" %s\n", ep.URL)
			for _, m := range models {
				marker := " "
				if m == target.Model {
					marker = "* "
				}
				fmt.Printf(" %s%s\n", marker, m)
			}
			if target.Model != "" && !contains(models, target.Model) {
				fmt.Printf(" ! configured model %q not found in list\n", target.Model)
			}
			fmt.Println()
		}
		if len(target.Endpoints) == 0 {
			fmt.Printf(" (no endpoints configured)\n\n")
		}
		fmt.Printf(" API type : %s\n", target.APIType)
		fmt.Printf(" Model : %s\n\n", strings.TrimSpace(target.Model))
	}
	return nil
}
// contains reports whether s is present in ss.
func contains(ss []string, s string) bool {
	for i := range ss {
		if ss[i] == s {
			return true
		}
	}
	return false
}

480
cmd/vecna/onboard.go Normal file
View File

@@ -0,0 +1,480 @@
package main
import (
"bufio"
"context"
"fmt"
"net/http"
"os"
"strings"
"time"
"github.com/spf13/cobra"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/discovery"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// onboardCmd is the interactive first-run wizard; see runOnboard for the
// five-step flow.
var onboardCmd = &cobra.Command{
	Use:   "onboard",
	Short: "Interactive setup wizard: discover servers, configure, test, and write config",
	RunE:  runOnboard,
}

func init() {
	// Self-register on the root command.
	rootCmd.AddCommand(onboardCmd)
}
// runOnboard drives the interactive five-step setup wizard: discover servers,
// detect model dimensions, configure the dimension adapter, configure the
// vecna server itself, then test every target and write the resulting config
// file. The first configured target becomes forward.default.
func runOnboard(_ *cobra.Command, _ []string) error {
	in := bufio.NewReader(os.Stdin)
	fmt.Println("=== vecna onboard ===")
	fmt.Println()
	// ── Step 1: Discover ──────────────────────────────────────────────────────
	step(1, 5, "Discover embedding servers")
	fmt.Println("Scanning (Ollama, LM Studio, vLLM, LocalAI, Jan, Kobold, Tabby)...")
	servers := discovery.Scan(context.Background())
	var targets []pendingTarget
	if len(servers) == 0 {
		fmt.Println("No servers found automatically.")
	} else {
		fmt.Printf("Found %d server(s):\n\n", len(servers))
		for i, s := range servers {
			fmt.Printf(" [%d] %-12s %s\n Models: %s\n\n",
				i+1, s.Kind.Name, s.BaseURL, joinModels(s.Models))
		}
		// Let user pick one or more from the list; 0 = manual
		for {
			choice, err := promptInt(in,
				fmt.Sprintf("Select server [1-%d] or 0 to enter URL manually: ", len(servers)), 0, len(servers))
			if err != nil {
				return err
			}
			var pt pendingTarget
			if choice == 0 {
				pt, err = collectManualTarget(in)
			} else {
				pt, err = collectDiscoveredTarget(in, servers[choice-1])
			}
			if err != nil {
				return err
			}
			targets = append(targets, pt)
			again, err := promptBool(in, "Add another forwarder?", false)
			if err != nil {
				return err
			}
			if !again {
				break
			}
		}
	}
	// Nothing discovered and nothing selected: fall back to manual entry.
	if len(targets) == 0 {
		pt, err := collectManualTarget(in)
		if err != nil {
			return err
		}
		targets = append(targets, pt)
	}
	// ── Step 2: Detect dimensions ─────────────────────────────────────────────
	step(2, 5, "Detect model dimensions")
	for i := range targets {
		fmt.Printf("Probing %s / %s ... ", targets[i].endpoint, targets[i].model)
		dim, err := detectDim(targets[i])
		if err != nil {
			// Detection failure is non-fatal; the user supplies the value below.
			fmt.Printf("failed (%s) — you will need to enter the dimension manually\n", err)
			targets[i].detectedDim = 0
		} else {
			fmt.Printf("%d dims\n", dim)
			targets[i].detectedDim = dim
		}
	}
	fmt.Println()
	// ── Step 3: Configure adapter ─────────────────────────────────────────────
	step(3, 5, "Configure dimension adapter")
	// Use the first target's detected dim as the source dimension default
	firstDim := 0
	for _, t := range targets {
		if t.detectedDim > 0 {
			firstDim = t.detectedDim
			break
		}
	}
	srcDimStr := ""
	if firstDim > 0 {
		srcDimStr = fmt.Sprintf("%d", firstDim)
	}
	sourceDimRaw, err := promptString(in,
		fmt.Sprintf("Source dimension (native model output dim)%s: ", defaultHint(srcDimStr)), srcDimStr)
	if err != nil {
		return err
	}
	sourceDim := mustParseInt(sourceDimRaw, firstDim)
	targetDimRaw, err := promptString(in, "Target dimension (output dim vecna will serve) [1536]: ", "1536")
	if err != nil {
		return err
	}
	targetDim := mustParseInt(targetDimRaw, 1536)
	adapterType, err := promptString(in, "Adapter type (truncate/random/projection) [truncate]: ", "truncate")
	if err != nil {
		return err
	}
	// Mode prompts only apply to the truncate adapter; others keep defaults.
	truncateMode := "from_end"
	padMode := "at_end"
	if adapterType == "truncate" {
		truncateMode, err = promptString(in, "Truncate mode (from_end/from_start) [from_end]: ", "from_end")
		if err != nil {
			return err
		}
		padMode, err = promptString(in, "Pad mode (at_end/at_start) [at_end]: ", "at_end")
		if err != nil {
			return err
		}
	}
	fmt.Println()
	// ── Step 4: Configure vecna server ────────────────────────────────────────
	step(4, 5, "Configure vecna server")
	portRaw, err := promptString(in, "Bind port [8080]: ", "8080")
	if err != nil {
		return err
	}
	port := mustParseInt(portRaw, 8080)
	apiKeysRaw, err := promptString(in,
		"Inbound API keys for vecna (comma-separated, leave empty to disable auth): ", "")
	if err != nil {
		return err
	}
	// Split and trim the key list, dropping empty entries.
	var apiKeys []string
	for _, k := range strings.Split(apiKeysRaw, ",") {
		if k := strings.TrimSpace(k); k != "" {
			apiKeys = append(apiKeys, k)
		}
	}
	enableMetrics, err := promptBool(in, "Enable Prometheus /metrics endpoint?", false)
	if err != nil {
		return err
	}
	metricsAPIKey := ""
	if enableMetrics {
		metricsAPIKey, err = promptString(in, "Metrics API key (leave empty for open access): ", "")
		if err != nil {
			return err
		}
	}
	fmt.Println()
	// ── Step 5: Test & write ──────────────────────────────────────────────────
	step(5, 5, "Test connections and write config")
	allPassed := true
	for _, t := range targets {
		fmt.Printf("Testing %-45s ", t.endpoint+"...")
		_, elapsed, dims, testErr := runSingleTest(t)
		if testErr != nil {
			fmt.Printf("FAIL %s\n", truncate(testErr.Error(), 55))
			allPassed = false
		} else {
			fmt.Printf("OK %dms dims=%d\n", elapsed.Milliseconds(), dims)
		}
	}
	fmt.Println()
	// Failed tests require explicit confirmation before writing the config.
	if !allPassed {
		proceed, err := promptBool(in, "Some tests failed. Write config anyway?", false)
		if err != nil {
			return err
		}
		if !proceed {
			fmt.Println("Aborted. No config written.")
			return nil
		}
	}
	// Build the config struct
	defaultTarget := ""
	forwardTargets := make(map[string]config.ForwardTarget, len(targets))
	for i, t := range targets {
		forwardTargets[t.name] = config.ForwardTarget{
			APIType: t.apiType,
			Model:   t.model,
			APIKey:  t.apiKey,
			Endpoints: []config.EndpointConfig{
				{URL: t.endpoint, Priority: 10},
			},
			TimeoutSecs:      30,
			CooldownSecs:     60,
			PriorityDecay:    2,
			PriorityRecovery: 5,
		}
		// The first target entered becomes forward.default.
		if i == 0 {
			defaultTarget = t.name
		}
	}
	cfg := config.Config{
		Server: config.ServerConfig{
			Port:    port,
			Host:    "0.0.0.0",
			APIKeys: apiKeys,
		},
		Metrics: config.MetricsConfig{
			Enabled: enableMetrics,
			Path:    "/metrics",
			APIKey:  metricsAPIKey,
		},
		Forward: config.ForwardConfig{
			Default: defaultTarget,
			Targets: forwardTargets,
		},
		Adapter: config.AdapterConfig{
			Type:         adapterType,
			SourceDim:    sourceDim,
			TargetDim:    targetDim,
			TruncateMode: truncateMode,
			PadMode:      padMode,
		},
	}
	defaultCfgPath := config.ResolveFile(cfgFile)
	fmt.Printf("Config will be written to: %s\n", defaultCfgPath)
	cfgPath, err := promptString(in, "Config path (press Enter to accept): ", defaultCfgPath)
	if err != nil {
		return err
	}
	if err := writeFullConfig(cfgPath, cfg); err != nil {
		return fmt.Errorf("write config: %w", err)
	}
	fmt.Printf("Config written to %s\n", cfgPath)
	fmt.Println()
	fmt.Println("Run 'vecna serve' to start the proxy server.")
	return nil
}
// ── helpers ──────────────────────────────────────────────────────────────────
// pendingTarget collects configuration for a single forwarding target before
// the config is assembled.
type pendingTarget struct {
	name        string // key under forward.targets in the written config
	endpoint    string // base URL of the embedding server
	model       string // model to request embeddings from
	apiType     string // "openai" or "google"
	apiKey      string // upstream API key; empty when not needed
	detectedDim int    // probed native output dimension; 0 when detection failed
}
// step prints the wizard's "[n/total] title" header followed by a divider line.
func step(n, total int, title string) {
	header := fmt.Sprintf("[%d/%d] %s", n, total, title)
	fmt.Println(header)
	fmt.Println(strings.Repeat("-", 40))
}
// defaultHint renders a " [value]" suffix for a prompt, or "" when there is
// no default value to show.
func defaultHint(s string) string {
	if s != "" {
		return " [" + s + "]"
	}
	return ""
}
// joinModels renders a model list for display: "(none)" when empty, all names
// comma-separated when five or fewer, otherwise the first five plus a
// "(+N more)" suffix.
func joinModels(models []string) string {
	switch {
	case len(models) == 0:
		return "(none)"
	case len(models) <= 5:
		return strings.Join(models, ", ")
	default:
		extra := len(models) - 5
		return strings.Join(models[:5], ", ") + fmt.Sprintf(" (+%d more)", extra)
	}
}
// collectDiscoveredTarget turns a discovered server into a pendingTarget,
// prompting for the model (when more than one is listed), the config entry
// name, and an API key. The suggested name is the server kind, lowercased
// with spaces replaced by underscores.
func collectDiscoveredTarget(in *bufio.Reader, srv discovery.Found) (pendingTarget, error) {
	defaultName := strings.ToLower(strings.ReplaceAll(srv.Kind.Name, " ", "_"))
	model, err := pickModel(in, srv.Models)
	if err != nil {
		return pendingTarget{}, err
	}
	name, err := promptString(in, fmt.Sprintf("Target name in config [%s]: ", defaultName), defaultName)
	if err != nil {
		return pendingTarget{}, err
	}
	apiKey, err := promptAPIKey(in, srv.Kind.NeedsKey)
	if err != nil {
		return pendingTarget{}, err
	}
	return pendingTarget{
		name:     name,
		endpoint: srv.BaseURL,
		model:    model,
		apiType:  srv.Kind.APIType,
		apiKey:   apiKey,
	}, nil
}
// collectManualTarget prompts the user for all details of a target that was
// not (or could not be) auto-discovered.
//
// Returns an error when input cannot be read, or when a required field
// (server URL, model name) is left empty. Unlike the previous version, a
// read failure is no longer masked as a "field is required" message — the
// real I/O error is surfaced instead.
func collectManualTarget(in *bufio.Reader) (pendingTarget, error) {
	fmt.Println("Enter server details manually:")
	endpoint, err := promptString(in, "Server URL (e.g. http://localhost:11434): ", "")
	if err != nil {
		return pendingTarget{}, err
	}
	if endpoint == "" {
		return pendingTarget{}, fmt.Errorf("server URL is required")
	}
	apiTypeStr, err := promptString(in, "API type (openai/google) [openai]: ", "openai")
	if err != nil {
		return pendingTarget{}, err
	}
	model, err := promptString(in, "Model name: ", "")
	if err != nil {
		return pendingTarget{}, err
	}
	if model == "" {
		return pendingTarget{}, fmt.Errorf("model name is required")
	}
	name, err := promptString(in, "Target name in config [custom]: ", "custom")
	if err != nil {
		return pendingTarget{}, err
	}
	apiKey, err := promptAPIKey(in, false)
	if err != nil {
		return pendingTarget{}, err
	}
	return pendingTarget{
		name:     name,
		endpoint: endpoint,
		model:    model,
		apiType:  apiTypeStr,
		apiKey:   apiKey,
	}, nil
}
// pickModel selects a model from the discovered list: prompts free-form when
// the list is empty, auto-picks a sole entry, and shows a numbered menu
// otherwise.
func pickModel(in *bufio.Reader, models []string) (string, error) {
	if len(models) == 0 {
		return promptString(in, "Model name: ", "")
	}
	if len(models) == 1 {
		fmt.Printf("Using model: %s\n", models[0])
		return models[0], nil
	}
	fmt.Println("Available models:")
	for i, m := range models {
		fmt.Printf(" [%d] %s\n", i+1, m)
	}
	choice, err := promptInt(in, fmt.Sprintf("Select model [1-%d]: ", len(models)), 1, len(models))
	if err != nil {
		return "", err
	}
	return models[choice-1], nil
}
// promptAPIKey asks for an API key, adjusting the prompt wording when the
// server kind is known to require one. The key itself is never validated here.
func promptAPIKey(in *bufio.Reader, required bool) (string, error) {
	if required {
		return promptString(in, "API key: ", "")
	}
	return promptString(in, "API key (leave empty if none): ", "")
}
// detectDim sends a single test embedding and returns the vector length.
func detectDim(t pendingTarget) (int, error) {
	const probeTimeout = 10 * time.Second
	httpClient := &http.Client{Timeout: probeTimeout}
	var client embedclient.Client
	switch t.apiType {
	case "google":
		client = embedclient.NewGoogle(t.endpoint, t.apiKey, t.model, httpClient)
	default:
		client = embedclient.NewOpenAI(t.endpoint, t.apiKey, httpClient)
	}
	ctx, cancel := context.WithTimeout(context.Background(), probeTimeout)
	defer cancel()
	resp, err := client.Embed(ctx, embedclient.Request{
		Texts: []string{"dimension probe"},
		Model: t.model,
	})
	if err != nil {
		return 0, err
	}
	if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
		return 0, fmt.Errorf("empty embedding in response")
	}
	return len(resp.Embeddings[0]), nil
}
// runSingleTest runs one test embed and returns success, elapsed time, dims, and any error.
func runSingleTest(t pendingTarget) (bool, time.Duration, int, error) {
	const callTimeout = 30 * time.Second
	httpClient := &http.Client{Timeout: callTimeout}
	var client embedclient.Client
	switch t.apiType {
	case "google":
		client = embedclient.NewGoogle(t.endpoint, t.apiKey, t.model, httpClient)
	default:
		client = embedclient.NewOpenAI(t.endpoint, t.apiKey, httpClient)
	}
	ctx, cancel := context.WithTimeout(context.Background(), callTimeout)
	defer cancel()
	started := time.Now()
	resp, err := client.Embed(ctx, embedclient.Request{
		Texts: []string{testPhrase},
		Model: t.model,
	})
	elapsed := time.Since(started)
	if err != nil {
		return false, elapsed, 0, err
	}
	dims, _ := embeddingStats(resp.Embeddings)
	return true, elapsed, dims, nil
}
// mustParseInt parses s as a positive integer, returning fallback when
// parsing fails or the value is not strictly positive.
func mustParseInt(s string, fallback int) int {
	var n int
	if _, err := fmt.Sscanf(s, "%d", &n); err == nil && n > 0 {
		return n
	}
	return fallback
}
// writeFullConfig persists the assembled onboard config to path. When the
// file does not exist yet it is first seeded via createDefaultConfig, then
// overwritten with the complete config from the wizard.
func writeFullConfig(path string, cfg config.Config) error {
	_, statErr := os.Stat(path)
	if os.IsNotExist(statErr) {
		if err := createDefaultConfig(path); err != nil {
			return err
		}
	}
	return config.WriteConfig(path, cfg)
}

59
cmd/vecna/prompt.go Normal file
View File

@@ -0,0 +1,59 @@
package main
import (
"bufio"
"fmt"
"strconv"
"strings"
)
// promptInt reads an integer in [min, max] from the reader, re-prompting on bad input.
func promptInt(in *bufio.Reader, prompt string, min, max int) (int, error) {
for {
fmt.Print(prompt)
line, err := in.ReadString('\n')
if err != nil {
return 0, fmt.Errorf("read input: %w", err)
}
n, err := strconv.Atoi(strings.TrimSpace(line))
if err != nil || n < min || n > max {
fmt.Printf(" Enter a number between %d and %d\n", min, max)
continue
}
return n, nil
}
}
// promptString reads a line, returning defaultVal when the user presses Enter with no input.
func promptString(in *bufio.Reader, prompt, defaultVal string) (string, error) {
fmt.Print(prompt)
line, err := in.ReadString('\n')
if err != nil {
return "", fmt.Errorf("read input: %w", err)
}
if s := strings.TrimSpace(line); s != "" {
return s, nil
}
return defaultVal, nil
}
// promptBool reads a y/N confirmation, returning defaultVal on empty input.
func promptBool(in *bufio.Reader, prompt string, defaultVal bool) (bool, error) {
hint := "y/N"
if defaultVal {
hint = "Y/n"
}
fmt.Printf("%s [%s]: ", prompt, hint)
line, err := in.ReadString('\n')
if err != nil {
return false, fmt.Errorf("read input: %w", err)
}
switch strings.ToLower(strings.TrimSpace(line)) {
case "y", "yes":
return true, nil
case "n", "no":
return false, nil
default:
return defaultVal, nil
}
}

154
cmd/vecna/query.go Normal file
View File

@@ -0,0 +1,154 @@
package main
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/spf13/cobra"

	"github.com/Warky-Devs/vecna.git/pkg/config"
	"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// Flag values for the query command, bound in init below.
var (
	queryTarget  string // --target: forward target name; empty means forward.default
	queryRaw     bool   // --raw: skip the dimension adapter
	queryCompact bool   // --compact: single-line JSON output
)

// queryCmd embeds one piece of text through a configured forward target and
// writes the resulting vector to stdout as a JSON array.
var queryCmd = &cobra.Command{
	Use:   "query <text>",
	Short: "Embed text and print the resulting vector",
	Long: `Sends text to the configured forwarding target, applies the dimension adapter,
and prints the resulting vector as a JSON array.
Text can be supplied as a positional argument or via stdin (use - as the argument).`,
	Args: cobra.MaximumNArgs(1),
	RunE: runQuery,
}
// init wires up the query command's flags and registers it on the root command.
func init() {
	queryCmd.Flags().StringVar(&queryTarget, "target", "",
		"forward target to use (default: forward.default from config)")
	queryCmd.Flags().BoolVar(&queryRaw, "raw", false,
		"skip the adapter — output the raw vector from the backing model")
	queryCmd.Flags().BoolVar(&queryCompact, "compact", false,
		"print vector on a single line instead of pretty-printed")
	rootCmd.AddCommand(queryCmd)
}
// runQuery embeds a single text via the configured forward target and prints
// the resulting vector as JSON.
//
// Flow: load config → resolve input text → resolve target → build a one-shot
// client for the target's first endpoint → embed → optionally apply the
// dimension adapter → print. Diagnostics go to stderr so stdout stays a
// clean JSON array suitable for piping.
func runQuery(_ *cobra.Command, args []string) error {
	cfg, err := config.Load(cfgFile)
	if err != nil {
		return fmt.Errorf("load config: %w", err)
	}
	// Resolve text: positional arg, "-" reads stdin, no arg reads stdin.
	text, err := queryText(args)
	if err != nil {
		return err
	}
	// Resolve target.
	targetName := queryTarget
	if targetName == "" {
		targetName = cfg.Forward.Default
	}
	target, ok := cfg.Forward.Targets[targetName]
	if !ok {
		return fmt.Errorf("target %q not found in config", targetName)
	}
	if len(target.Endpoints) == 0 {
		return fmt.Errorf("target %q has no endpoints", targetName)
	}
	// Build client (use first endpoint directly — no router needed for a one-shot query).
	ep := target.Endpoints[0]
	// Endpoint-level API key overrides the target-level one.
	apiKey := ep.APIKey
	if apiKey == "" {
		apiKey = target.APIKey
	}
	// TimeoutSecs of 0 in the config means "unset" here and falls back to 30s.
	timeout := time.Duration(target.TimeoutSecs) * time.Second
	if timeout == 0 {
		timeout = 30 * time.Second
	}
	httpClient := &http.Client{Timeout: timeout}
	var client embedclient.Client
	switch target.APIType {
	case "google":
		client = embedclient.NewGoogle(ep.URL, apiKey, target.Model, httpClient)
	default:
		client = embedclient.NewOpenAI(ep.URL, apiKey, httpClient)
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	resp, err := client.Embed(ctx, embedclient.Request{
		Texts: []string{text},
		Model: target.Model,
	})
	if err != nil {
		return fmt.Errorf("embed: %w", err)
	}
	if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
		return fmt.Errorf("empty embedding in response")
	}
	vec := resp.Embeddings[0]
	// Unless --raw was given, translate the vector through the configured
	// dimension adapter before printing.
	if !queryRaw {
		adp, err := buildAdapter(cfg)
		if err != nil {
			return fmt.Errorf("build adapter: %w", err)
		}
		vec, err = adp.Adapt(vec)
		if err != nil {
			return fmt.Errorf("adapt: %w", err)
		}
	}
	// Summary line on stderr so stdout remains pure JSON.
	fmt.Fprintf(os.Stderr, "target=%s model=%s dims=%d tokens=%d\n",
		targetName, resp.Model, len(vec), resp.Usage.TotalTokens)
	return printVector(vec)
}
// queryText resolves the text to embed: the positional argument, or stdin
// when no argument is given or the argument is "-".
//
// Reads os.Stdin via io.ReadAll instead of opening "/dev/stdin", which does
// not exist on Windows and can misbehave when stdin is a pipe.
func queryText(args []string) (string, error) {
	if len(args) == 0 || args[0] == "-" {
		raw, err := io.ReadAll(os.Stdin)
		if err != nil {
			return "", fmt.Errorf("read stdin: %w", err)
		}
		return strings.TrimRight(string(raw), "\n"), nil
	}
	return args[0], nil
}
// printVector writes vec to stdout as a JSON array, honoring --compact.
func printVector(vec []float32) error {
	// Widen to []float64 so json.Marshal emits clean decimal values rather
	// than float32 round-trip artifacts.
	widened := make([]float64, len(vec))
	for i, v := range vec {
		widened[i] = float64(v)
	}
	var (
		b   []byte
		err error
	)
	if queryCompact {
		b, err = json.Marshal(widened)
	} else {
		b, err = json.MarshalIndent(widened, "", " ")
	}
	if err != nil {
		return fmt.Errorf("marshal vector: %w", err)
	}
	fmt.Println(string(b))
	return nil
}

121
cmd/vecna/search.go Normal file
View File

@@ -0,0 +1,121 @@
package main
import (
"bufio"
"context"
"fmt"
"os"
"strings"
"github.com/spf13/cobra"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/discovery"
)
// searchCmd scans the local network for known LLM servers and interactively
// adds one of them as a forward target in the config file.
var searchCmd = &cobra.Command{
	Use:   "search",
	Short: "Scan the network for LLM servers and add one to the config",
	RunE:  runSearch,
}

// init registers the search command on the CLI root.
func init() {
	rootCmd.AddCommand(searchCmd)
}
// runSearch drives the interactive "search" flow: scan the network for known
// LLM servers, let the user pick a server, model, target name and API key,
// then persist the result as a forward target in the config file (creating a
// default config first if none exists).
func runSearch(_ *cobra.Command, _ []string) error {
	in := bufio.NewReader(os.Stdin)
	fmt.Println("Scanning for LLM servers (Ollama, LM Studio, vLLM, LocalAI, Jan, Kobold, Tabby)...")
	servers := discovery.Scan(context.Background())
	if len(servers) == 0 {
		fmt.Println("No servers found.")
		return nil
	}
	fmt.Printf("\nFound %d server(s):\n\n", len(servers))
	for i, s := range servers {
		modelList := strings.Join(s.Models, ", ")
		if modelList == "" {
			modelList = "(no models listed)"
		}
		fmt.Printf(" [%d] %s %s\n Models: %s\n\n", i+1, s.Kind.Name, s.BaseURL, modelList)
	}
	// Select server
	chosen, err := promptInt(in, fmt.Sprintf("Select server [1-%d]: ", len(servers)), 1, len(servers))
	if err != nil {
		return err
	}
	srv := servers[chosen-1]
	// Select model: auto-pick a sole entry, show a menu for several, or ask
	// free-form when the server listed none.
	var model string
	switch {
	case len(srv.Models) == 1:
		model = srv.Models[0]
		fmt.Printf("Using model: %s\n", model)
	case len(srv.Models) > 1:
		fmt.Println("\nAvailable models:")
		for i, m := range srv.Models {
			fmt.Printf(" [%d] %s\n", i+1, m)
		}
		idx, err := promptInt(in, fmt.Sprintf("Select model [1-%d]: ", len(srv.Models)), 1, len(srv.Models))
		if err != nil {
			return err
		}
		model = srv.Models[idx-1]
	default:
		model, err = promptString(in, "Model name: ", "")
		if err != nil {
			return err
		}
	}
	// Target name in config — defaults to the server kind, lowercased with
	// spaces replaced by underscores (e.g. "LM Studio" → "lm_studio").
	defaultName := strings.ToLower(strings.ReplaceAll(srv.Kind.Name, " ", "_"))
	targetName, err := promptString(in, fmt.Sprintf("Target name in config [%s]: ", defaultName), defaultName)
	if err != nil {
		return err
	}
	// API key — wording changes when the server kind is known to require one.
	keyPrompt := "API key (leave empty if none): "
	if srv.Kind.NeedsKey {
		keyPrompt = "API key: "
	}
	apiKey, err := promptString(in, keyPrompt, "")
	if err != nil {
		return err
	}
	// Assemble the target with default routing parameters.
	target := config.ForwardTarget{
		APIType: srv.Kind.APIType,
		Model:   model,
		APIKey:  apiKey,
		Endpoints: []config.EndpointConfig{
			{URL: srv.BaseURL, Priority: 10},
		},
		TimeoutSecs:      30,
		CooldownSecs:     60,
		PriorityDecay:    2,
		PriorityRecovery: 5,
	}
	cfgPath := config.ResolveFile(cfgFile)
	// Create default config if it doesn't exist yet
	if _, err := os.Stat(cfgPath); os.IsNotExist(err) {
		if err := createDefaultConfig(cfgPath); err != nil {
			return fmt.Errorf("create default config: %w", err)
		}
	}
	if err := config.SaveTarget(cfgPath, targetName, target); err != nil {
		return err
	}
	fmt.Printf("\nAdded target %q to %s\n", targetName, cfgPath)
	return nil
}

150
cmd/vecna/serve.go Normal file
View File

@@ -0,0 +1,150 @@
package main
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/spf13/cobra"
"go.uber.org/zap"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
"github.com/Warky-Devs/vecna.git/pkg/metrics"
"github.com/Warky-Devs/vecna.git/pkg/server"
)
// serveCmd starts the long-running embedding proxy server.
var serveCmd = &cobra.Command{
	Use:   "serve",
	Short: "Start the vecna embedding proxy server",
	RunE:  runServe,
}
// init declares serve's flag overrides and registers the command.
func init() {
	serveCmd.Flags().String("host", "", "bind host (overrides config)")
	serveCmd.Flags().Int("port", 0, "bind port (overrides config)")
	// Register on the root command. Every other command file in this package
	// (query, search, test) self-registers in its own init; without this
	// line `vecna serve` would never appear in the CLI.
	// NOTE(review): confirm serveCmd is not also registered in root/main.
	rootCmd.AddCommand(serveCmd)
}
// runServe boots the embedding proxy: load config, apply host/port flag
// overrides, build the dimension adapter and per-target clients, then serve
// HTTP until SIGINT/SIGTERM, shutting down gracefully with a 10s drain.
func runServe(cmd *cobra.Command, _ []string) error {
	cfg, err := config.Load(cfgFile)
	if err != nil {
		return fmt.Errorf("load config: %w", err)
	}
	// Flag overrides
	if h, _ := cmd.Flags().GetString("host"); h != "" {
		cfg.Server.Host = h
	}
	if p, _ := cmd.Flags().GetInt("port"); p != 0 {
		cfg.Server.Port = p
	}
	adp, err := buildAdapter(cfg)
	if err != nil {
		return fmt.Errorf("build adapter: %w", err)
	}
	clients, err := buildClients(cfg)
	if err != nil {
		return fmt.Errorf("build clients: %w", err)
	}
	// reg stays nil when metrics are disabled; server.New receives it either way.
	var reg *metrics.Registry
	if cfg.Metrics.Enabled {
		reg = metrics.New()
	}
	router := server.New(cfg, clients, adp, reg, logger)
	addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
	srv := &http.Server{
		Addr:    addr,
		Handler: router,
	}
	// Graceful shutdown
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		logger.Info("vecna listening", zap.String("addr", addr))
		// A startup/server error (e.g. port already in use) also triggers
		// shutdown by feeding a synthetic SIGTERM into the quit channel.
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.Error("server error", zap.Error(err))
			quit <- syscall.SIGTERM
		}
	}()
	<-quit
	logger.Info("shutting down")
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		return fmt.Errorf("graceful shutdown: %w", err)
	}
	return nil
}
// buildClients constructs one embedclient.Client per named forward target.
//
// Each target becomes a TargetRouter over its endpoints; an endpoint-level
// API key overrides the target-level key. Returns an error when a target has
// no endpoints or its router cannot be built.
func buildClients(cfg *config.Config) (map[string]embedclient.Client, error) {
	clients := make(map[string]embedclient.Client, len(cfg.Forward.Targets))
	for name, target := range cfg.Forward.Targets {
		if len(target.Endpoints) == 0 {
			return nil, fmt.Errorf("target %q has no endpoints", name)
		}
		// One shared http.Client per target so connections are reused across
		// its endpoints. TimeoutSecs == 0 yields a client with no timeout.
		timeout := time.Duration(target.TimeoutSecs) * time.Second
		httpClient := &http.Client{Timeout: timeout}
		slots := make([]embedclient.RouterSlot, len(target.Endpoints))
		for i, ep := range target.Endpoints {
			apiKey := ep.APIKey
			if apiKey == "" {
				apiKey = target.APIKey
			}
			var c embedclient.Client
			switch target.APIType {
			case "google":
				c = embedclient.NewGoogle(ep.URL, apiKey, target.Model, httpClient)
			default: // "openai" or unset
				c = embedclient.NewOpenAI(ep.URL, apiKey, httpClient)
			}
			slots[i] = embedclient.RouterSlot{
				Client:   c,
				URL:      ep.URL,
				Priority: ep.Priority,
			}
		}
		routerCfg := embedclient.RouterConfig{
			TargetName:       name,
			TimeoutSecs:      target.TimeoutSecs,
			CooldownSecs:     target.CooldownSecs,
			PriorityDecay:    target.PriorityDecay,
			PriorityRecovery: target.PriorityRecovery,
		}
		// TODO(metrics): no registry is wired into the routers yet — the
		// previous placeholder block here always passed nil too. Thread the
		// registry created in runServe through if per-router metrics are wanted.
		router, err := embedclient.NewTargetRouter(slots, routerCfg, nil)
		if err != nil {
			return nil, fmt.Errorf("build router for target %q: %w", name, err)
		}
		clients[name] = router
	}
	return clients, nil
}

150
cmd/vecna/test.go Normal file
View File

@@ -0,0 +1,150 @@
package main
import (
"context"
"fmt"
"math"
"net/http"
"time"
"github.com/spf13/cobra"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// removeBroken is bound to --remove-broken: when set, endpoints that fail
// the test (and targets left without endpoints) are deleted from the config.
var removeBroken bool

// testCmd sends one embedding request to every configured forwarder endpoint
// and reports latency, dimension count, and vector norm for each.
var testCmd = &cobra.Command{
	Use:   "test",
	Short: "Send a test embedding request to each configured forwarder",
	RunE:  runTest,
}
// init declares the --remove-broken flag and registers the test command.
func init() {
	testCmd.Flags().BoolVar(&removeBroken, "remove-broken", false,
		"Remove endpoints (and targets with no endpoints left) that fail the test from the config file")
	rootCmd.AddCommand(testCmd)
}
const testPhrase = "The quick brown fox jumps over the lazy dog"
// runTest sends the fixed test phrase to every endpoint of every configured
// forward target and prints a per-endpoint OK/FAIL line with latency,
// dimensions, and vector norm. With --remove-broken, failing endpoints are
// afterwards deleted from the config. Returns a non-nil error when at least
// one endpoint failed so the process exit code reflects the result.
func runTest(_ *cobra.Command, _ []string) error {
	cfg, err := config.Load(cfgFile)
	if err != nil {
		return fmt.Errorf("load config: %w", err)
	}
	if len(cfg.Forward.Targets) == 0 {
		fmt.Println("No forwarder targets configured.")
		return nil
	}
	fmt.Printf("Test phrase: %q\n\n", testPhrase)
	passed, failed := 0, 0
	// brokenEndpoints maps target name → list of failing endpoint URLs
	brokenEndpoints := make(map[string][]string)
	for targetName, target := range cfg.Forward.Targets {
		fmt.Printf("[ %s ] model: %s type: %s\n", targetName, target.Model, target.APIType)
		// TimeoutSecs of 0 in the config means "unset"; default to 30s.
		timeout := time.Duration(target.TimeoutSecs) * time.Second
		if timeout == 0 {
			timeout = 30 * time.Second
		}
		httpClient := &http.Client{Timeout: timeout}
		for _, ep := range target.Endpoints {
			// Endpoint-level API key overrides the target-level key.
			apiKey := ep.APIKey
			if apiKey == "" {
				apiKey = target.APIKey
			}
			var client embedclient.Client
			switch target.APIType {
			case "google":
				client = embedclient.NewGoogle(ep.URL, apiKey, target.Model, httpClient)
			default:
				client = embedclient.NewOpenAI(ep.URL, apiKey, httpClient)
			}
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			start := time.Now()
			resp, embedErr := client.Embed(ctx, embedclient.Request{
				Texts: []string{testPhrase},
				Model: target.Model,
			})
			elapsed := time.Since(start)
			// cancel is called explicitly (not deferred) because this runs in
			// a loop; deferred calls would pile up until the function returns.
			cancel()
			if embedErr != nil {
				fmt.Printf(" %-45s FAIL %s\n", ep.URL, truncate(embedErr.Error(), 60))
				brokenEndpoints[targetName] = append(brokenEndpoints[targetName], ep.URL)
				failed++
				continue
			}
			dims, norm := embeddingStats(resp.Embeddings)
			fmt.Printf(" %-45s OK %dms dims=%d norm=%.4f\n",
				ep.URL, elapsed.Milliseconds(), dims, norm)
			passed++
		}
		if len(target.Endpoints) == 0 {
			fmt.Println(" (no endpoints configured)")
		}
		fmt.Println()
	}
	fmt.Printf("Results: %d passed, %d failed\n", passed, failed)
	if removeBroken && len(brokenEndpoints) > 0 {
		if err := applyRemoveBroken(brokenEndpoints); err != nil {
			return err
		}
	}
	if failed > 0 {
		return fmt.Errorf("%d forwarder(s) failed", failed)
	}
	return nil
}
// applyRemoveBroken deletes the given failing endpoints (and any target left
// with no endpoints) from the config file, then reports what was removed.
func applyRemoveBroken(broken map[string][]string) error {
	cfgPath := config.ResolveFile(cfgFile)
	removed, err := config.RemoveBrokenEndpoints(cfgPath, broken)
	if err != nil {
		return fmt.Errorf("remove broken: %w", err)
	}
	if len(removed) == 0 {
		return nil
	}
	fmt.Println("\nRemoved from config:")
	for _, entry := range removed {
		fmt.Printf(" - %s\n", entry)
	}
	fmt.Printf("Config updated: %s\n", cfgPath)
	return nil
}
// embeddingStats returns the dimension count and L2 norm of the first embedding.
// Both results are zero when the batch is empty or its first vector is empty.
func embeddingStats(embeddings [][]float32) (dims int, norm float32) {
	if len(embeddings) == 0 {
		return 0, 0
	}
	first := embeddings[0]
	if len(first) == 0 {
		return 0, 0
	}
	var sumSquares float64
	for _, component := range first {
		sumSquares += float64(component) * float64(component)
	}
	return len(first), float32(math.Sqrt(sumSquares))
}
// truncate shortens s to at most n bytes, replacing the tail with "..." when
// something was cut. Strings already within the limit are returned unchanged.
//
// Fixes a panic in the original: for n < 3, s[:n-3] used a negative slice
// index. Now n <= 3 falls back to a hard cut (and n <= 0 yields "").
// Note: operates on bytes, so a multi-byte UTF-8 rune at the cut point may
// be split; acceptable for the terminal error display this feeds.
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	if n <= 0 {
		return ""
	}
	if n <= 3 {
		// No room for the ellipsis; hard-cut instead of panicking.
		return s[:n]
	}
	return s[:n-3] + "..."
}

15
cmd/vecna/version.go Normal file
View File

@@ -0,0 +1,15 @@
package main
import (
"fmt"
"github.com/spf13/cobra"
)
// versionCmd prints the build version (the package-level `version` variable,
// injected at link time via -ldflags "-X main.version=...").
var versionCmd = &cobra.Command{
	Use:   "version",
	Short: "Print the version",
	Run: func(cmd *cobra.Command, args []string) {
		fmt.Println(version)
	},
}

// init registers the version command on the CLI root. Every other command
// file in this package self-registers the same way; without it,
// `vecna version` would not exist.
// NOTE(review): confirm versionCmd is not also registered in root/main.
func init() {
	rootCmd.AddCommand(versionCmd)
}

View File

@@ -0,0 +1,82 @@
services:
# ── vecna proxy ─────────────────────────────────────────────────────────────
vecna:
build: .
# image: ghcr.io/warky-devs/vecna:latest
ports:
- "8080:8080"
volumes:
- vecna_config:/config
environment:
VECNA_SERVER_PORT: 8080
# VECNA_SERVER_API_KEYS: sk-vecna-abc123,sk-vecna-def456
command: ["serve", "--config", "/config/vecna.json"]
restart: unless-stopped
depends_on:
ollama:
condition: service_healthy
# ── ollama (local embedding model) ──────────────────────────────────────────
ollama:
image: ollama/ollama:latest
ports:
- "11434:11434"
volumes:
- ollama_data:/root/.ollama
healthcheck:
test: ["CMD", "ollama", "list"]
interval: 10s
timeout: 5s
retries: 6
start_period: 20s
restart: unless-stopped
# ── pull the embedding model on first start ──────────────────────────────────
# Remove this service after the model has been pulled once.
ollama-pull:
image: ollama/ollama:latest
depends_on:
ollama:
condition: service_healthy
environment:
OLLAMA_HOST: http://ollama:11434
entrypoint: ["ollama", "pull", "nomic-embed-text"]
restart: "no"
# ── prometheus (optional, for metrics scraping) ──────────────────────────────
# Requires metrics.enabled: true in /config/vecna.json
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./prometheus.example.yml:/etc/prometheus/prometheus.yml:ro
restart: unless-stopped
profiles:
- metrics
volumes:
ollama_data:
vecna_config: # persists vecna.json across container rebuilds
# ── one-off commands ──────────────────────────────────────────────────────────
#
# Run the interactive onboard wizard (writes config into the vecna_config volume):
#
# docker compose run --rm -it vecna onboard --config /config/vecna.json
#
# The wizard will discover the ollama service on the Docker network at
# http://ollama:11434 (select it from the list or enter the URL manually).
#
# Test all configured endpoints after onboarding:
#
# docker compose run --rm vecna test --config /config/vecna.json
#
# Remove broken endpoints automatically:
#
# docker compose run --rm vecna test --config /config/vecna.json --remove-broken
#
# Open the config in a shell editor (requires the alpine image):
#
# docker compose run --rm -it vecna sh -c "vi /config/vecna.json"

BIN
docs/images/a.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.5 KiB

BIN
docs/images/a1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

BIN
docs/images/b.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

BIN
docs/images/b1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

BIN
docs/images/c.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 153 KiB

BIN
docs/images/c1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 302 KiB

BIN
docs/images/d.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 176 KiB

BIN
docs/images/d1.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

BIN
docs/images/e.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

BIN
docs/images/e1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
docs/images/f.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

BIN
docs/images/f1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

355
docs/what_is_embeddings.md Normal file
View File

@@ -0,0 +1,355 @@
**Practical breakdown** of embeddings: how they work, how they're trained, and concrete examples you can actually implement.
---
# 🧠 What embeddings actually are (intuitively + formally)
![Image](images/a.jpg)
![Image](images/b.jpg)
![Image](images/c.jpg)
![Image](images/d.jpg)
![Image](images/e.jpg)
![Image](images/f.jpg)
At the core:
* An embedding is a **function**
[
f(x) \rightarrow \mathbb{R}^d
]
* It maps raw data (text, image, etc.) into a **dense vector**.
Key property:
* **Semantic structure is encoded as geometry**
* Similar things → vectors close together
* Different things → far apart
This is the core idea behind vector search, RAG, clustering, etc. ([ibm.com][1])
---
# ⚙️ Why embeddings work (the underlying theory)
### 1. Distributional hypothesis
> “You shall know a word by the company it keeps”
* Words appearing in similar contexts → similar vectors
* This is the **foundation of Word2Vec, GloVe, BERT, etc.** ([ibm.com][2])
---
### 2. Geometry encodes meaning
Classic example:
```
king - man + woman ≈ queen
```
This works because relationships become **linear directions in vector space**.
---
### 3. Dense vs sparse representations
| Method | Problem |
| ---------- | ------------------ |
| One-hot | huge, no meaning |
| TF-IDF | frequency only |
| Embeddings | compact + semantic |
Embeddings are **low-dimensional but information-rich** representations. ([GeeksforGeeks][3])
---
# 🏗️ How embeddings are trained (core methods)
## 1. Prediction-based models (most important)
### Word2Vec (classic foundation)
![Image](images/a1.jpg)
![Image](images/b1.jpg)
![Image](images/c1.jpg)
![Image](images/d1.jpg)
![Image](images/e1.jpg)
![Image](images/f1.jpg)
Two main training strategies:
### (a) Skip-gram
Predict context from a word:
[
P(context \mid word)
]
### (b) CBOW
Predict word from context:
[
P(word \mid context)
]
Mechanism:
* Input = one-hot vector
* Hidden layer = embedding
* Train via gradient descent
👉 The embedding is literally the **weights of the hidden layer** ([Medium][4])
---
### Example (Skip-gram training loop)
```python
# pseudo-code
for word in corpus:
context = get_context_window(word)
loss = -log P(context | word)
update_weights()
```
---
## 2. Matrix factorization (GloVe)
Instead of prediction:
* Build co-occurrence matrix
* Factorize it into lower dimensions
Captures **global statistics**, not just local context. ([Medium][5])
---
## 3. Neural embedding layers (modern approach)
Used in:
* Transformers (BERT, GPT)
* Recommender systems
Mechanism:
* Embedding = **lookup table**
* Trained jointly with model
```python
embedding = torch.nn.Embedding(vocab_size, dim)
vector = embedding(token_id)
```
---
## 4. Contrastive learning (modern SOTA)
Used in:
* sentence embeddings
* CLIP (image-text)
* OpenAI embeddings
Core idea:
[
\text{similar pairs} \rightarrow \text{closer}
]
[
\text{different pairs} \rightarrow \text{farther}
]
Loss function:
[
\mathcal{L} = -\log \frac{e^{sim(x_i, x_j)}}{\sum_k e^{sim(x_i, x_k)}}
]
---
# 🔬 How modern embeddings (LLMs) differ
Older:
* static embeddings (Word2Vec)
Modern:
* **contextual embeddings**
Example:
* “bank” (river vs finance) → different vectors
This is why models like BERT/GPT outperform Word2Vec.
---
# 🧪 Practical training examples
## Example 1 — Train Word2Vec (Gensim)
```python
from gensim.models import Word2Vec
sentences = [["cat", "sat", "mat"], ["dog", "sat", "floor"]]
model = Word2Vec(sentences, vector_size=100, window=5, min_count=1)
vector = model.wv["cat"]
```
---
## Example 2 — Train embeddings in PyTorch
```python
import torch
import torch.nn as nn
embedding = nn.Embedding(10000, 128) # vocab, dim
input_ids = torch.tensor([1, 5, 23])
vectors = embedding(input_ids)
```
---
## Example 3 — Train contrastive embeddings
```python
# pseudo
anchor = model(text1)
positive = model(text2)
negative = model(text3)
loss = contrastive_loss(anchor, positive, negative)
```
---
## Example 4 — PCA reduction (your earlier question)
```python
from sklearn.decomposition import PCA
pca = PCA(n_components=256)
X_reduced = pca.fit_transform(X)
```
---
# 📊 Types of embeddings
| Type | Example |
| ---------- | --------------- |
| Word | Word2Vec, GloVe |
| Sentence | SBERT |
| Document | Doc2Vec |
| Image | CLIP |
| Graph | Node2Vec |
| Multimodal | CLIP, Gemini |
---
# 🧩 Key properties you should care about (engineering perspective)
### 1. Dimensionality
* Typical: 128–1536
* Tradeoff: memory vs accuracy
---
### 2. Distance metric
* cosine similarity (most common)
* dot product
* Euclidean
---
### 3. Normalization
Critical for:
* search quality
* clustering
---
### 4. Training data distribution
Embeddings are only as good as:
* corpus size
* domain relevance
---
# ⚠️ Common pitfalls (important)
### ❌ Mixing embedding spaces
* embeddings from different models are **not compatible**
---
### ❌ Assuming linear compression is harmless
* PCA can distort semantic relationships
---
### ❌ Ignoring normalization
* cosine similarity breaks without it
---
### ❌ Using embeddings without evaluation
Always test:
* retrieval accuracy
* clustering quality
---
# 🧠 Mental model (most useful takeaway)
Think of embeddings as:
> A learned coordinate system where **meaning = position**
Training = learning that coordinate system so that:
* similar things cluster
* relationships become directions
---
[1]: https://www.ibm.com/think/topics/vector-embedding?utm_source=chatgpt.com "What is Vector Embedding? | IBM"
[2]: https://www.ibm.com/think/topics/word-embeddings?utm_source=chatgpt.com "What Are Word Embeddings? | IBM"
[3]: https://www.geeksforgeeks.org/nlp/word-embeddings-in-nlp/?utm_source=chatgpt.com "Word Embeddings in NLP"
[4]: https://medium.com/%40manansuri/a-dummys-guide-to-word2vec-456444f3c673?utm_source=chatgpt.com "A Dummy's Guide to Word2Vec - Medium"
[5]: https://medium.com/%40neri.vvo/word-embedding-a-powerful-tool-word2vec-glove-fasttext-dd6e2171d5?utm_source=chatgpt.com "Word Embedding Explained — Word2Vec GloVe, FastText"

40
go.mod Normal file
View File

@@ -0,0 +1,40 @@
module github.com/Warky-Devs/vecna.git
go 1.26.1
require (
github.com/prometheus/client_golang v1.23.2
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
github.com/uptrace/bunrouter v1.0.23
go.uber.org/zap v1.27.1
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

84
go.sum Normal file
View File

@@ -0,0 +1,84 @@
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

77
pkg/adapter/adapter.go Normal file
View File

@@ -0,0 +1,77 @@
package adapter
import (
"errors"
"fmt"
)
// ErrDimMismatch is returned when the input vector length does not match the adapter's source dimension.
var ErrDimMismatch = errors.New("vector dimension mismatch")

// ErrInvalidDim is returned when source or target dimensions are invalid.
var ErrInvalidDim = errors.New("invalid dimension")

// ErrInvalidMatrix is returned when a projection matrix has the wrong shape.
var ErrInvalidMatrix = errors.New("invalid projection matrix shape")

// TruncateMode controls which end of the vector is dropped when downscaling.
type TruncateMode int

const (
	// TruncateFromEnd keeps the first targetDim elements (default; correct for Matryoshka models).
	TruncateFromEnd TruncateMode = iota
	// TruncateFromStart keeps the last targetDim elements.
	TruncateFromStart
)

// PadMode controls which end of the vector receives zero-padding when upscaling.
type PadMode int

const (
	// PadAtEnd appends zeros to the end (default).
	PadAtEnd PadMode = iota
	// PadAtStart prepends zeros to the start.
	PadAtStart
)

// Adapter translates vectors between two fixed dimensions.
// All implementations in this package L2-normalize their output.
type Adapter interface {
	// Adapt converts vec (length SourceDim) to a vector of length TargetDim.
	// It returns ErrDimMismatch when len(vec) != SourceDim.
	Adapt(vec []float32) ([]float32, error)
	// SourceDim reports the expected input vector length.
	SourceDim() int
	// TargetDim reports the produced output vector length.
	TargetDim() int
}
// NewTruncate returns a TruncateAdapter for Matryoshka-style or simple truncation/padding.
// t controls which end is dropped when downscaling; p controls which end is padded when upscaling.
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
	// Reject non-positive dimensions up front.
	if sourceDim < 1 || targetDim < 1 {
		return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}
	a := &truncateAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		truncMode: t,
		padMode:   p,
	}
	return a, nil
}
// NewRandom returns a RandomAdapter backed by a seeded Gaussian projection matrix.
// seed=0 uses a time-based seed.
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
	// Both dimensions must be strictly positive.
	if sourceDim < 1 || targetDim < 1 {
		return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}
	return newRandomAdapter(sourceDim, targetDim, seed), nil
}
// NewProjection returns a ProjectionAdapter using a caller-supplied matrix.
// matrix must have shape [targetDim][sourceDim].
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
	// Validate dimensions first, then the matrix shape against them.
	if sourceDim < 1 || targetDim < 1 {
		return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}
	if rows := len(matrix); rows != targetDim {
		return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, rows, targetDim)
	}
	for i := range matrix {
		if cols := len(matrix[i]); cols != sourceDim {
			return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, cols, sourceDim)
		}
	}
	p := &projectionAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		matrix:    matrix,
	}
	return p, nil
}

236
pkg/adapter/adapter_test.go Normal file
View File

@@ -0,0 +1,236 @@
package adapter
import (
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// unitVec returns a unit vector of length n with value 1/sqrt(n) in each position.
func unitVec(n int) []float32 {
	elem := float32(1.0 / math.Sqrt(float64(n)))
	out := make([]float32, n)
	for i := 0; i < n; i++ {
		out[i] = elem
	}
	return out
}
// assertNormOne fails the test when v does not have (approximately) unit L2 norm.
func assertNormOne(t *testing.T, v []float32) {
	t.Helper()
	// Accumulate the squared norm in float64 to limit rounding error.
	var sum float64
	for _, x := range v {
		sum += float64(x) * float64(x)
	}
	assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
}
// --- L2Norm ---

// TestL2Norm checks that L2Norm keeps the vector length, yields unit norm
// for non-zero input, and returns an unchanged copy for the zero vector.
func TestL2Norm(t *testing.T) {
	tests := []struct {
		name     string
		input    []float32
		wantLen  int
		wantNorm float64
	}{
		{"unit vector unchanged", []float32{1, 0, 0}, 3, 1.0},
		{"scale down to unit", []float32{2, 0, 0}, 3, 1.0},
		{"multi-dim", unitVec(4), 4, 1.0},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := L2Norm(tc.input)
			assert.Len(t, got, tc.wantLen)
			assertNormOne(t, got)
		})
	}
	t.Run("zero vector returns copy unchanged", func(t *testing.T) {
		// Zero magnitude cannot be normalized; contents must be identical.
		in := []float32{0, 0, 0}
		got := L2Norm(in)
		assert.Equal(t, in, got)
	})
}
// --- TruncateAdapter ---

// TestNewTruncate_InvalidDims verifies that non-positive source and target
// dimensions are both rejected with ErrInvalidDim.
func TestNewTruncate_InvalidDims(t *testing.T) {
	_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
	require.ErrorIs(t, err, ErrInvalidDim)
	_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
	require.ErrorIs(t, err, ErrInvalidDim)
}
// TestTruncateAdapter covers downscaling (both truncate modes), upscaling
// (both pad modes), same-dimension passthrough, dimension-mismatch errors,
// and which end of the vector carries the signal after adaptation.
func TestTruncateAdapter(t *testing.T) {
	tests := []struct {
		name      string
		src, tgt  int
		truncMode TruncateMode
		padMode   PadMode
		// builder for input vec; checks are len + norm only unless wantFirst/wantLast set
		wantLen int
		wantErr bool
	}{
		{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768, false},
		{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768, false},
		{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536, false},
		{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536, false},
		{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768, false},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
			require.NoError(t, err)
			assert.Equal(t, tc.src, a.SourceDim())
			assert.Equal(t, tc.tgt, a.TargetDim())
			got, err := a.Adapt(unitVec(tc.src))
			require.NoError(t, err)
			assert.Len(t, got, tc.wantLen)
			assertNormOne(t, got)
		})
	}
	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		// An input shorter than sourceDim must be rejected.
		_, err = a.Adapt(make([]float32, 100))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
	t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		// input: [1,2,3,4] — TruncateFromEnd keeps [1,2]
		in := []float32{1, 2, 3, 4}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		// after L2Norm the signs are preserved; ratio should match
		assert.Greater(t, got[0], float32(0))
		assert.Greater(t, got[1], float32(0))
	})
	t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
		require.NoError(t, err)
		// input: [0,0,3,4] — TruncateFromStart keeps [3,4]
		in := []float32{0, 0, 3, 4}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		assert.Greater(t, got[0], float32(0))
		assert.Greater(t, got[1], float32(0))
	})
	t.Run("PadAtStart zero-pads front", func(t *testing.T) {
		a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
		require.NoError(t, err)
		in := []float32{1, 0}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		// first two positions should be zero (padded), last two carry the signal
		assert.Equal(t, float32(0), got[0])
		assert.Equal(t, float32(0), got[1])
	})
}
// --- RandomAdapter ---

// TestNewRandom_InvalidDims verifies that non-positive source and target
// dimensions are rejected with ErrInvalidDim. Both sides are checked,
// mirroring TestNewTruncate_InvalidDims.
func TestNewRandom_InvalidDims(t *testing.T) {
	_, err := NewRandom(0, 768, 42)
	require.ErrorIs(t, err, ErrInvalidDim)
	// Previously untested: an invalid target dimension must also be rejected.
	_, err = NewRandom(1536, 0, 42)
	require.ErrorIs(t, err, ErrInvalidDim)
}
// TestRandomAdapter checks output shape and unit norm, seed determinism,
// seed sensitivity, and the dimension-mismatch error path.
func TestRandomAdapter(t *testing.T) {
	t.Run("output length and unit norm", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 42)
		require.NoError(t, err)
		got, err := a.Adapt(unitVec(1536))
		require.NoError(t, err)
		assert.Len(t, got, 768)
		assertNormOne(t, got)
	})
	t.Run("deterministic with same seed", func(t *testing.T) {
		// Identical seeds must yield identical projection matrices and output.
		a1, _ := NewRandom(64, 32, 99)
		a2, _ := NewRandom(64, 32, 99)
		in := unitVec(64)
		out1, _ := a1.Adapt(in)
		out2, _ := a2.Adapt(in)
		assert.Equal(t, out1, out2)
	})
	t.Run("different seeds produce different output", func(t *testing.T) {
		a1, _ := NewRandom(64, 32, 1)
		a2, _ := NewRandom(64, 32, 2)
		in := unitVec(64)
		out1, _ := a1.Adapt(in)
		out2, _ := a2.Adapt(in)
		assert.NotEqual(t, out1, out2)
	})
	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 1)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 10))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}
// --- ProjectionAdapter ---

// TestNewProjection_InvalidMatrix verifies shape validation: the matrix must
// be [targetDim][sourceDim], and both dimensions must be positive.
func TestNewProjection_InvalidMatrix(t *testing.T) {
	tests := []struct {
		name   string
		src    int
		tgt    int
		matrix [][]float32
		errIs  error
	}{
		{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
		{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
		{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
			require.ErrorIs(t, err, tc.errIs)
		})
	}
}
// TestProjectionAdapter exercises an identity matrix, a down-projecting
// matrix, and the dimension-mismatch error path.
func TestProjectionAdapter(t *testing.T) {
	t.Run("identity matrix same dim", func(t *testing.T) {
		// 3×3 identity
		id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
		a, err := NewProjection(3, 3, id)
		require.NoError(t, err)
		in := []float32{1, 2, 3}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		assert.Len(t, got, 3)
		assertNormOne(t, got)
	})
	t.Run("downscale projection", func(t *testing.T) {
		// 2×4 matrix — keeps only the first two coordinates
		m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
		a, err := NewProjection(4, 2, m)
		require.NoError(t, err)
		got, err := a.Adapt([]float32{3, 4, 0, 0})
		require.NoError(t, err)
		assert.Len(t, got, 2)
		assertNormOne(t, got)
	})
	t.Run("dim mismatch returns error", func(t *testing.T) {
		m := [][]float32{{1, 0}, {0, 1}}
		a, err := NewProjection(2, 2, m)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 5))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}

23
pkg/adapter/normalize.go Normal file
View File

@@ -0,0 +1,23 @@
package adapter
import "math"
// L2Norm returns a new slice with the vector normalized to unit length.
// If the vector has zero magnitude a copy is returned unchanged.
func L2Norm(v []float32) []float32 {
	// Work on a copy so the caller's slice is never mutated.
	out := make([]float32, len(v))
	copy(out, v)

	// Squared magnitude, accumulated in float64 for precision.
	var sq float64
	for _, x := range out {
		sq += float64(x) * float64(x)
	}
	if sq == 0 {
		return out
	}

	norm := float32(math.Sqrt(sq))
	for i := range out {
		out[i] /= norm
	}
	return out
}

32
pkg/adapter/projection.go Normal file
View File

@@ -0,0 +1,32 @@
package adapter
import "fmt"
// projectionAdapter maps vectors through a fixed caller-supplied matrix.
type projectionAdapter struct {
	sourceDim int
	targetDim int
	matrix    [][]float32
}

func (a *projectionAdapter) SourceDim() int { return a.sourceDim }
func (a *projectionAdapter) TargetDim() int { return a.targetDim }

// Adapt multiplies vec by the projection matrix and L2-normalizes the result.
func (a *projectionAdapter) Adapt(vec []float32) ([]float32, error) {
	if got := len(vec); got != a.sourceDim {
		return nil, fmt.Errorf("projection adapt: %w: got %d, want %d", ErrDimMismatch, got, a.sourceDim)
	}
	projected := matVecMul(a.matrix, vec)
	return L2Norm(projected), nil
}
// matVecMul computes m·v where m is [rows][cols] and v has len cols.
func matVecMul(m [][]float32, v []float32) []float32 {
	result := make([]float32, len(m))
	for r := range m {
		// Dot product of row r with v.
		var dot float32
		for c := range m[r] {
			dot += m[r][c] * v[c]
		}
		result[r] = dot
	}
	return result
}

45
pkg/adapter/random.go Normal file
View File

@@ -0,0 +1,45 @@
package adapter
import (
"fmt"
"math"
"math/rand"
"time"
)
// randomAdapter projects through a seeded Gaussian random matrix.
type randomAdapter struct {
	sourceDim int
	targetDim int
	matrix    [][]float32
}

// newRandomAdapter builds the projection matrix once, up front.
// seed == 0 falls back to a time-based seed (non-deterministic matrix).
func newRandomAdapter(sourceDim, targetDim int, seed int64) *randomAdapter {
	if seed == 0 {
		seed = time.Now().UnixNano()
	}
	//nolint:gosec // deterministic seeded RNG for projection matrix generation, not security use
	rng := rand.New(rand.NewSource(seed))
	// Entries drawn from N(0, 1/targetDim) — preserves expected squared norms (Johnson-Lindenstrauss).
	stddev := 1.0 / math.Sqrt(float64(targetDim))
	rows := make([][]float32, targetDim)
	for r := 0; r < targetDim; r++ {
		cols := make([]float32, sourceDim)
		for c := 0; c < sourceDim; c++ {
			cols[c] = float32(rng.NormFloat64() * stddev)
		}
		rows[r] = cols
	}
	return &randomAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: rows}
}

func (a *randomAdapter) SourceDim() int { return a.sourceDim }
func (a *randomAdapter) TargetDim() int { return a.targetDim }

// Adapt multiplies vec by the random matrix and L2-normalizes the product.
func (a *randomAdapter) Adapt(vec []float32) ([]float32, error) {
	if len(vec) != a.sourceDim {
		return nil, fmt.Errorf("random adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
	}
	return L2Norm(matVecMul(a.matrix, vec)), nil
}

41
pkg/adapter/truncate.go Normal file
View File

@@ -0,0 +1,41 @@
package adapter
import "fmt"
// truncateAdapter resizes vectors by dropping or zero-padding elements.
type truncateAdapter struct {
	sourceDim int
	targetDim int
	truncMode TruncateMode
	padMode   PadMode
}

func (a *truncateAdapter) SourceDim() int { return a.sourceDim }
func (a *truncateAdapter) TargetDim() int { return a.targetDim }

// Adapt truncates (downscale) or zero-pads (upscale) vec to targetDim,
// then L2-normalizes the result.
func (a *truncateAdapter) Adapt(vec []float32) ([]float32, error) {
	if len(vec) != a.sourceDim {
		return nil, fmt.Errorf("truncate adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
	}
	out := make([]float32, a.targetDim)
	if a.targetDim <= a.sourceDim {
		// Downscale: keep one end of the input.
		if a.truncMode == TruncateFromEnd {
			copy(out, vec[:a.targetDim])
		} else if a.truncMode == TruncateFromStart {
			copy(out, vec[a.sourceDim-a.targetDim:])
		}
	} else {
		// Upscale: the remaining positions stay zero.
		if a.padMode == PadAtEnd {
			copy(out, vec)
		} else if a.padMode == PadAtStart {
			copy(out[a.targetDim-a.sourceDim:], vec)
		}
	}
	return L2Norm(out), nil
}

168
pkg/config/config.go Normal file
View File

@@ -0,0 +1,168 @@
package config
import (
"errors"
"fmt"
"os"
"strings"
"github.com/spf13/viper"
)
// Config is the root configuration for vecna.
type Config struct {
	Server  ServerConfig  `mapstructure:"server"`
	Metrics MetricsConfig `mapstructure:"metrics"`
	Forward ForwardConfig `mapstructure:"forward"`
	Adapter AdapterConfig `mapstructure:"adapter"`
}

// ServerConfig controls the HTTP listener and inbound auth.
type ServerConfig struct {
	Port    int      `mapstructure:"port"`     // listen port (default 8080, see Load)
	Host    string   `mapstructure:"host"`     // bind address (default "0.0.0.0")
	APIKeys []string `mapstructure:"api_keys"` // accepted inbound keys — enforcement lives in the server package, not shown here
}

// MetricsConfig controls Prometheus metrics exposure.
type MetricsConfig struct {
	Enabled bool   `mapstructure:"enabled"` // default false (see Load)
	Path    string `mapstructure:"path"`    // default "/metrics"
	APIKey  string `mapstructure:"api_key"` // presumably gates scraping — confirm with the metrics handler
}

// ForwardConfig holds all named forwarding targets.
type ForwardConfig struct {
	Default string                   `mapstructure:"default"` // name of the target to use when none is specified
	Targets map[string]ForwardTarget `mapstructure:"targets"`
}

// ForwardTarget is a named backing embedding model with one or more endpoints.
// Zero-valued numeric fields are filled in by applyForwardDefaults.
type ForwardTarget struct {
	Endpoints        []EndpointConfig `mapstructure:"endpoints"`
	Model            string           `mapstructure:"model"`
	APIKey           string           `mapstructure:"api_key"` // inherited by endpoints that have no key of their own
	APIType          string           `mapstructure:"api_type"`
	TimeoutSecs      int              `mapstructure:"timeout_secs"`      // default 30
	CooldownSecs     int              `mapstructure:"cooldown_secs"`     // default 60
	PriorityDecay    int              `mapstructure:"priority_decay"`    // default 2
	PriorityRecovery int              `mapstructure:"priority_recovery"` // default 5
}

// EndpointConfig is a single URL within a ForwardTarget.
type EndpointConfig struct {
	URL      string `mapstructure:"url"`
	Priority int    `mapstructure:"priority"` // default 10 (see applyForwardDefaults)
	APIKey   string `mapstructure:"api_key"`  // falls back to the target-level key
}

// AdapterConfig selects and tunes the dimension adapter.
type AdapterConfig struct {
	Type         string `mapstructure:"type"`          // default "truncate" — other accepted values set by the adapter wiring, confirm there
	SourceDim    int    `mapstructure:"source_dim"`    // expected input vector length
	TargetDim    int    `mapstructure:"target_dim"`    // produced output vector length
	TruncateMode string `mapstructure:"truncate_mode"` // default "from_end"
	PadMode      string `mapstructure:"pad_mode"`      // default "at_end"
	Seed         int64  `mapstructure:"seed"`          // RNG seed for the random adapter; 0 = time-based
	MatrixFile   string `mapstructure:"matrix_file"`   // path to a projection matrix — file format defined by its loader, not shown here
}
// extensions viper will detect automatically.
var extensions = []string{"json", "yaml", "toml"}

// ResolveFile returns the config file path that would be used by Load.
// If cfgFile is non-empty it is returned as-is.
// Otherwise the default search paths (cwd, $HOME, $HOME/.config/vecna) are
// checked; if no existing file is found, the preferred default
// ($HOME/vecna.json) is returned so callers can create it.
func ResolveFile(cfgFile string) string {
	if cfgFile != "" {
		return cfgFile
	}
	// Error deliberately ignored: on failure home is "" and the relative
	// fallbacks below still work.
	home, _ := os.UserHomeDir()
	dirs := []string{".", home, home + "/.config/vecna"}
	for _, dir := range dirs {
		for _, ext := range extensions {
			path := dir + "/vecna." + ext
			if _, err := os.Stat(path); err == nil {
				return path
			}
		}
	}
	return home + "/vecna.json"
}
// Load reads configuration from the given file path (empty = search defaults),
// environment variables (prefix VECNA_), and applies built-in defaults.
func Load(cfgFile string) (*Config, error) {
	v := viper.New()

	// Built-in defaults (distinct keys, so registration order is irrelevant).
	defaults := map[string]any{
		"server.port":           8080,
		"server.host":           "0.0.0.0",
		"metrics.enabled":       false,
		"metrics.path":          "/metrics",
		"adapter.type":          "truncate",
		"adapter.truncate_mode": "from_end",
		"adapter.pad_mode":      "at_end",
	}
	for key, val := range defaults {
		v.SetDefault(key, val)
	}

	// Environment variables: e.g. VECNA_SERVER_PORT overrides server.port.
	v.SetEnvPrefix("VECNA")
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	v.AutomaticEnv()

	// Config file: an explicit path wins; otherwise search default locations.
	if cfgFile == "" {
		home, _ := os.UserHomeDir()
		v.SetConfigName("vecna")
		// No SetConfigType — viper detects the format from the file extension.
		v.AddConfigPath(".")
		v.AddConfigPath(home)
		v.AddConfigPath(home + "/.config/vecna")
	} else {
		v.SetConfigFile(cfgFile)
	}

	if err := v.ReadInConfig(); err != nil {
		// A missing config file is acceptable when all required values come
		// from flags/env; any other read error is fatal.
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return nil, fmt.Errorf("load config: %w", err)
		}
	}

	cfg := new(Config)
	if err := v.Unmarshal(cfg); err != nil {
		return nil, fmt.Errorf("unmarshal config: %w", err)
	}
	applyForwardDefaults(cfg)
	return cfg, nil
}
// applyForwardDefaults fills in zero-value fields on ForwardTarget entries.
func applyForwardDefaults(cfg *Config) {
	for name, target := range cfg.Forward.Targets {
		// target is a copy of the map value; defaults are applied to it and
		// the copy is written back at the end of the iteration.
		if target.TimeoutSecs == 0 {
			target.TimeoutSecs = 30
		}
		if target.CooldownSecs == 0 {
			target.CooldownSecs = 60
		}
		if target.PriorityDecay == 0 {
			target.PriorityDecay = 2
		}
		if target.PriorityRecovery == 0 {
			target.PriorityRecovery = 5
		}
		for i := range target.Endpoints {
			if target.Endpoints[i].Priority == 0 {
				target.Endpoints[i].Priority = 10
			}
			// Endpoints without their own key inherit the target-level key.
			if target.Endpoints[i].APIKey == "" && target.APIKey != "" {
				target.Endpoints[i].APIKey = target.APIKey
			}
		}
		cfg.Forward.Targets[name] = target
	}
}

137
pkg/config/update.go Normal file
View File

@@ -0,0 +1,137 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
)
// SaveTarget adds or replaces a named ForwardTarget in the config file at path.
// Only JSON config files are written in-place. For other formats an error is
// returned describing what to add manually.
func SaveTarget(path, name string, target ForwardTarget) error {
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("read config %s: %w", path, err)
	}
	ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
	if ext == "json" {
		return saveTargetJSON(path, data, name, target)
	}
	// Non-JSON formats: show the user what to paste instead of rewriting the file.
	snippet, _ := json.MarshalIndent(map[string]ForwardTarget{name: target}, "", " ")
	return fmt.Errorf(
		"auto-update not supported for .%s files\n"+
			"Add the following to the 'forward.targets' section of %s:\n\n%s",
		ext, path, snippet,
	)
}
// RemoveBrokenEndpoints removes failing endpoints from the config file.
// broken maps target name → set of failing endpoint URLs.
// If all endpoints of a target are removed, the target itself is deleted.
// Returns the list of removed items as human-readable strings.
func RemoveBrokenEndpoints(path string, broken map[string][]string) ([]string, error) {
	ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
	if ext != "json" {
		return nil, fmt.Errorf("auto-update not supported for .%s files; edit %s manually", ext, path)
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config %s: %w", path, err)
	}
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse %s: %w", path, err)
	}
	var removed []string
	for targetName, failedURLs := range broken {
		target, ok := cfg.Forward.Targets[targetName]
		if !ok {
			continue
		}
		failSet := make(map[string]bool, len(failedURLs))
		for _, u := range failedURLs {
			failSet[u] = true
		}
		// In-place filter: reuses the backing array of target.Endpoints.
		kept := target.Endpoints[:0]
		for _, ep := range target.Endpoints {
			if failSet[ep.URL] {
				removed = append(removed, fmt.Sprintf("endpoint %s (target %q)", ep.URL, targetName))
			} else {
				kept = append(kept, ep)
			}
		}
		if len(kept) == 0 {
			delete(cfg.Forward.Targets, targetName)
			removed = append(removed, fmt.Sprintf("target %q (all endpoints failed)", targetName))
			// A deleted target can no longer be the default.
			if cfg.Forward.Default == targetName {
				cfg.Forward.Default = ""
			}
		} else {
			target.Endpoints = kept
			cfg.Forward.Targets[targetName] = target
		}
	}
	if len(removed) == 0 {
		// Nothing matched: leave the file untouched.
		return nil, nil
	}
	// Persist through the shared writer instead of duplicating the
	// marshal-and-write logic (keeps serialization consistent with WriteConfig).
	if err := WriteConfig(path, cfg); err != nil {
		return nil, err
	}
	return removed, nil
}
// WriteConfig serialises cfg as indented JSON and atomically overwrites path.
// The data is written to a temporary file in the same directory and renamed
// into place, so a crash mid-write never leaves a truncated config behind.
// (The previous implementation used os.WriteFile, which truncates in place
// and is not atomic, contradicting this function's documented contract.)
func WriteConfig(path string, cfg Config) error {
	out, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return fmt.Errorf("marshal config: %w", err)
	}
	dir := filepath.Dir(path)
	tmp, err := os.CreateTemp(dir, ".vecna-*.tmp")
	if err != nil {
		return fmt.Errorf("create temp file in %s: %w", dir, err)
	}
	tmpName := tmp.Name()
	// Best-effort cleanup; a no-op after a successful rename.
	defer func() { _ = os.Remove(tmpName) }()
	if _, err := tmp.Write(append(out, '\n')); err != nil {
		_ = tmp.Close()
		return fmt.Errorf("write %s: %w", tmpName, err)
	}
	// Keep the 0600 permissions the previous implementation used.
	if err := tmp.Chmod(0o600); err != nil {
		_ = tmp.Close()
		return fmt.Errorf("chmod %s: %w", tmpName, err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("close %s: %w", tmpName, err)
	}
	if err := os.Rename(tmpName, path); err != nil {
		return fmt.Errorf("write %s: %w", path, err)
	}
	return nil
}
// saveTargetJSON decodes the JSON config in data, inserts or replaces the
// named target, promotes it to the default when none is set, and rewrites
// path via the shared WriteConfig writer (previously this duplicated the
// marshal-and-write logic inline).
func saveTargetJSON(path string, data []byte, name string, target ForwardTarget) error {
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("parse %s: %w", path, err)
	}
	if cfg.Forward.Targets == nil {
		cfg.Forward.Targets = make(map[string]ForwardTarget)
	}
	cfg.Forward.Targets[name] = target
	// The first target ever saved becomes the default.
	if cfg.Forward.Default == "" {
		cfg.Forward.Default = name
	}
	return WriteConfig(path, cfg)
}

206
pkg/discovery/discovery.go Normal file
View File

@@ -0,0 +1,206 @@
package discovery
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
)
// Kind describes a known server type.
type Kind struct {
	Name     string // display name, e.g. "Ollama"
	APIType  string // wire protocol family; currently always "openai" in knownServers
	Port     int    // default listen port to probe
	NeedsKey bool   // true when the server typically requires an API key (e.g. vLLM)
}

// Found is a server discovered on the network.
type Found struct {
	Kind    Kind
	BaseURL string   // e.g. "http://127.0.0.1:11434"
	Models  []string // model identifiers reported by the server
}

// knownServers lists server types by their default port and display name.
// Ollama is special-cased in probe because it uses a non-OpenAI probe endpoint.
var knownServers = []Kind{
	{Name: "Ollama", APIType: "openai", Port: 11434},
	{Name: "LM Studio", APIType: "openai", Port: 1234},
	{Name: "vLLM", APIType: "openai", Port: 8000, NeedsKey: true},
	{Name: "LocalAI", APIType: "openai", Port: 8080},
	{Name: "Jan", APIType: "openai", Port: 1337},
	{Name: "Kobold", APIType: "openai", Port: 5001},
	{Name: "Tabby", APIType: "openai", Port: 9090},
}
// Scan concurrently probes localhost and LAN gateway addresses for known LLM servers.
// Results are returned in the order they are found (non-deterministic).
func Scan(ctx context.Context) []Found {
	hosts := localHosts()
	var (
		mu      sync.Mutex // guards results
		results []Found
		wg      sync.WaitGroup
	)
	// Overall budget for the whole sweep; each request is additionally
	// capped much tighter by the client timeout below.
	probeCtx, cancel := context.WithTimeout(ctx, 4*time.Second)
	defer cancel()
	httpClient := &http.Client{Timeout: 600 * time.Millisecond}
	for _, host := range hosts {
		for _, kind := range knownServers {
			wg.Add(1)
			// Both loop variables are passed explicitly and uniformly
			// (previously host was shadowed while kind was a parameter;
			// the shadow is also redundant since go.mod targets Go >= 1.22,
			// where loop variables are per-iteration).
			go func(host string, kind Kind) {
				defer wg.Done()
				baseURL := fmt.Sprintf("http://%s:%d", host, kind.Port)
				models, err := probe(probeCtx, httpClient, baseURL, kind)
				if err != nil {
					return // nothing listening, or not this server type
				}
				mu.Lock()
				results = append(results, Found{Kind: kind, BaseURL: baseURL, Models: models})
				mu.Unlock()
			}(host, kind)
		}
	}
	wg.Wait()
	return results
}
// Models fetches the model list from a single base URL and kind (for the models command).
func Models(ctx context.Context, baseURL string, kind Kind) ([]string, error) {
	// More generous timeout than Scan uses: this is a single, explicit probe.
	httpClient := &http.Client{Timeout: 5 * time.Second}
	return probe(ctx, httpClient, baseURL, kind)
}
// localHosts returns localhost plus the .1 gateway of every local IPv4 subnet.
// It is a heuristic: the first host of each subnet is commonly the gateway,
// where home-lab LLM servers often run.
func localHosts() []string {
	seen := map[string]bool{"127.0.0.1": true}
	hosts := []string{"127.0.0.1"}
	ifaces, err := net.Interfaces()
	if err != nil {
		// Cannot enumerate interfaces: localhost is still worth probing.
		return hosts
	}
	for _, iface := range ifaces {
		if iface.Flags&net.FlagUp == 0 {
			continue // skip interfaces that are down
		}
		addrs, err := iface.Addrs()
		if err != nil {
			continue
		}
		for _, addr := range addrs {
			ipnet, ok := addr.(*net.IPNet)
			if !ok {
				continue
			}
			ip := ipnet.IP.To4()
			if ip == nil || ip.IsLoopback() {
				continue // IPv4 only; loopback is already covered
			}
			// Derive the likely gateway: network address + 1.
			gw := ip.Mask(ipnet.Mask)
			// Defensive: net.IP.Mask returns nil on a malformed mask, and
			// gw[3] below would panic on a nil/short slice.
			if len(gw) != net.IPv4len {
				continue
			}
			gw[3] = 1
			h := gw.String()
			if !seen[h] {
				seen[h] = true
				hosts = append(hosts, h)
			}
		}
	}
	return hosts
}
// probe attempts to identify the server at baseURL and returns its model list.
func probe(ctx context.Context, client *http.Client, baseURL string, kind Kind) ([]string, error) {
	// Everything except Ollama speaks the OpenAI-compatible /v1/models API.
	if kind.Name != "Ollama" {
		return probeOpenAI(ctx, client, baseURL)
	}
	if models, err := probeOllama(ctx, client, baseURL); err == nil {
		return models, nil
	}
	// Ollama also exposes /v1/models since v0.1.27 — fall back.
	return probeOpenAI(ctx, client, baseURL)
}
// --- Ollama ---

// ollamaTagsResponse mirrors the subset of GET /api/tags that we consume.
type ollamaTagsResponse struct {
	Models []struct {
		Name string `json:"name"`
	} `json:"models"`
}

// probeOllama queries Ollama's native /api/tags endpoint for installed models.
func probeOllama(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/api/tags", nil)
	if err != nil {
		return nil, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status %d", resp.StatusCode)
	}
	var parsed ollamaTagsResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return nil, fmt.Errorf("decode ollama response: %w", err)
	}
	names := make([]string, len(parsed.Models))
	for i := range parsed.Models {
		names[i] = parsed.Models[i].Name
	}
	return names, nil
}
// --- OpenAI-compatible ---

// openAIModelsResponse mirrors the subset of the /v1/models payload that we
// consume: just the model IDs.
type openAIModelsResponse struct {
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`
}
// probeOpenAI queries the OpenAI-compatible /v1/models endpoint and returns
// the model IDs in server order. A non-200 status is reported as an error.
func probeOpenAI(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/v1/models", nil)
	if err != nil {
		return nil, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status %d", resp.StatusCode)
	}
	var listing openAIModelsResponse
	if err := json.NewDecoder(resp.Body).Decode(&listing); err != nil {
		return nil, fmt.Errorf("decode openai response: %w", err)
	}
	ids := make([]string, 0, len(listing.Data))
	for _, m := range listing.Data {
		ids = append(ids, m.ID)
	}
	return ids, nil
}

27
pkg/embedclient/client.go Normal file
View File

@@ -0,0 +1,27 @@
package embedclient
import "context"
// Request is a batch of texts to embed.
type Request struct {
	Texts []string // documents/queries to embed, in order
	Model string   // backing model name requested by the caller
}

// Usage reports token consumption from the backing model.
// Fields stay zero when the backend does not report usage.
type Usage struct {
	PromptTokens int
	TotalTokens  int
}

// Response holds the raw embeddings returned by the backing model.
// Embeddings is index-aligned with Request.Texts.
type Response struct {
	Embeddings [][]float32
	Model      string // model name the backend actually used
	Usage      Usage
}

// Client sends text to a backing embedding model and returns raw vectors.
type Client interface {
	Embed(ctx context.Context, req Request) (Response, error)
}

98
pkg/embedclient/google.go Normal file
View File

@@ -0,0 +1,98 @@
package embedclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
// googleClient is a Client backed by the Google Gemini batchEmbedContents API.
type googleClient struct {
	baseURL    string // API root, without trailing slash
	apiKey     string // sent as a ?key= query parameter when non-empty
	model      string // model used for every request; Request.Model is ignored
	httpClient *http.Client
}
// NewGoogle returns a Client that speaks the Google Gemini batchEmbedContents API.
// A nil httpClient falls back to http.DefaultClient.
func NewGoogle(baseURL, apiKey, model string, httpClient *http.Client) Client {
	c := &googleClient{
		baseURL:    baseURL,
		apiKey:     apiKey,
		model:      model,
		httpClient: httpClient,
	}
	if c.httpClient == nil {
		c.httpClient = http.DefaultClient
	}
	return c
}
// googleBatchRequest is the JSON body for models/{model}:batchEmbedContents.
type googleBatchRequest struct {
	Requests []googleEmbedRequest `json:"requests"`
}

// googleEmbedRequest is one item in a batch: a model reference plus content.
type googleEmbedRequest struct {
	Model   string        `json:"model"` // "models/<name>" form
	Content googleContent `json:"content"`
}

// googleContent wraps the text parts of one embed request.
type googleContent struct {
	Parts []googlePart `json:"parts"`
}

// googlePart is a single text fragment.
type googlePart struct {
	Text string `json:"text"`
}

// googleBatchResponse mirrors the subset of the batch reply we consume:
// one vector of values per request item.
type googleBatchResponse struct {
	Embeddings []struct {
		Values []float32 `json:"values"`
	} `json:"embeddings"`
}
// Embed implements Client against the Gemini batchEmbedContents endpoint.
// Every input text becomes its own per-item request pinned to the configured
// model; req.Model is ignored in favor of c.model.
// NOTE(review): Usage is left at its zero value — confirm the batch endpoint
// really reports no token counts.
func (c *googleClient) Embed(ctx context.Context, req Request) (Response, error) {
	items := make([]googleEmbedRequest, 0, len(req.Texts))
	for _, text := range req.Texts {
		items = append(items, googleEmbedRequest{
			Model:   "models/" + c.model,
			Content: googleContent{Parts: []googlePart{{Text: text}}},
		})
	}
	payload, err := json.Marshal(googleBatchRequest{Requests: items})
	if err != nil {
		return Response{}, fmt.Errorf("google embed marshal: %w", err)
	}
	endpoint := fmt.Sprintf("%s/v1/models/%s:batchEmbedContents", c.baseURL, c.model)
	if c.apiKey != "" {
		// Gemini accepts the API key as a query parameter.
		endpoint += "?key=" + c.apiKey
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return Response{}, fmt.Errorf("google embed request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return Response{}, fmt.Errorf("google embed do: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return Response{}, fmt.Errorf("google embed: unexpected status %d", resp.StatusCode)
	}
	var decoded googleBatchResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return Response{}, fmt.Errorf("google embed decode: %w", err)
	}
	vectors := make([][]float32, len(decoded.Embeddings))
	for i, e := range decoded.Embeddings {
		vectors[i] = e.Values
	}
	return Response{Embeddings: vectors, Model: c.model}, nil
}

87
pkg/embedclient/openai.go Normal file
View File

@@ -0,0 +1,87 @@
package embedclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
// openAIClient is a Client backed by the OpenAI embeddings API.
type openAIClient struct {
	baseURL    string // API root, without trailing slash
	apiKey     string // sent as a Bearer token when non-empty
	httpClient *http.Client
}
// NewOpenAI returns a Client that speaks the OpenAI embeddings API.
// A nil httpClient falls back to http.DefaultClient.
func NewOpenAI(baseURL, apiKey string, httpClient *http.Client) Client {
	c := &openAIClient{
		baseURL:    baseURL,
		apiKey:     apiKey,
		httpClient: httpClient,
	}
	if c.httpClient == nil {
		c.httpClient = http.DefaultClient
	}
	return c
}
// openAIEmbedRequest is the JSON body for POST /v1/embeddings.
type openAIEmbedRequest struct {
	Input []string `json:"input"`
	Model string   `json:"model"`
}

// openAIEmbedResponse mirrors the subset of the embeddings reply we consume.
// Each datum carries an explicit Index for reassembly into input order.
type openAIEmbedResponse struct {
	Object string `json:"object"`
	Data   []struct {
		Object    string    `json:"object"`
		Embedding []float32 `json:"embedding"`
		Index     int       `json:"index"`
	} `json:"data"`
	Model string `json:"model"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`
}
// Embed implements Client against the OpenAI /v1/embeddings endpoint.
// It sends req.Texts as one batch and reassembles the result in input order
// using each datum's server-reported Index field.
func (c *openAIClient) Embed(ctx context.Context, req Request) (Response, error) {
	body, err := json.Marshal(openAIEmbedRequest{Input: req.Texts, Model: req.Model})
	if err != nil {
		return Response{}, fmt.Errorf("openai embed marshal: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		return Response{}, fmt.Errorf("openai embed request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return Response{}, fmt.Errorf("openai embed do: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return Response{}, fmt.Errorf("openai embed: unexpected status %d", resp.StatusCode)
	}
	var oaiResp openAIEmbedResponse
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
		return Response{}, fmt.Errorf("openai embed decode: %w", err)
	}
	embeddings := make([][]float32, len(oaiResp.Data))
	for _, d := range oaiResp.Data {
		// A malformed index from the server would previously panic on the
		// slice write; reject it explicitly instead.
		if d.Index < 0 || d.Index >= len(embeddings) {
			return Response{}, fmt.Errorf("openai embed: index %d out of range for %d data items", d.Index, len(oaiResp.Data))
		}
		embeddings[d.Index] = d.Embedding
	}
	return Response{
		Embeddings: embeddings,
		Model:      oaiResp.Model,
		Usage: Usage{
			PromptTokens: oaiResp.Usage.PromptTokens,
			TotalTokens:  oaiResp.Usage.TotalTokens,
		},
	}, nil
}

180
pkg/embedclient/router.go Normal file
View File

@@ -0,0 +1,180 @@
package embedclient
import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/Warky-Devs/vecna.git/pkg/metrics"
)
// RouterConfig holds tuning parameters for a TargetRouter.
type RouterConfig struct {
	TargetName       string // label used for metrics
	TimeoutSecs      int    // per-request deadline applied in Embed
	CooldownSecs     int    // how long a failed endpoint is skipped by pick
	PriorityDecay    int    // priority subtracted on each failure (floored at 1)
	PriorityRecovery int    // successes needed to regain one priority point
}
// endpointSlot is one routable endpoint plus its mutable routing state.
// All fields below mu are guarded by mu.
type endpointSlot struct {
	client          Client
	url             string
	initialPriority int // ceiling that recovery can restore priority to

	mu           sync.Mutex
	priority     int       // current dynamic priority
	inflight     int       // active requests on this slot
	successCount int       // total successes, drives priority recovery
	lastFail     time.Time // zero value means "never failed"
}
// TargetRouter implements Client by routing requests across multiple endpoint slots
// using a busyness-based priority algorithm (score = priority - inflight).
type TargetRouter struct {
	slots   []*endpointSlot
	cfg     RouterConfig
	metrics *metrics.Registry // may be nil; all uses are nil-checked
}
// NewTargetRouter constructs a TargetRouter from a slice of (client, url, initialPriority) tuples.
// It also seeds the priority/inflight gauges when a metrics registry is given.
// At least one slot is required.
func NewTargetRouter(slots []RouterSlot, cfg RouterConfig, reg *metrics.Registry) (*TargetRouter, error) {
	if len(slots) == 0 {
		return nil, fmt.Errorf("NewTargetRouter: at least one slot required")
	}
	endpoints := make([]*endpointSlot, 0, len(slots))
	for _, s := range slots {
		endpoints = append(endpoints, &endpointSlot{
			client:          s.Client,
			url:             s.URL,
			initialPriority: s.Priority,
			priority:        s.Priority,
		})
		if reg != nil {
			reg.SetEndpointPriority(cfg.TargetName, s.URL, float64(s.Priority))
			reg.SetEndpointInflight(cfg.TargetName, s.URL, 0)
		}
	}
	return &TargetRouter{slots: endpoints, cfg: cfg, metrics: reg}, nil
}
// RouterSlot is a single endpoint entry for NewTargetRouter.
type RouterSlot struct {
	Client   Client
	URL      string
	Priority int // initial (and maximum recoverable) routing priority
}
// Embed implements Client: it picks the best endpoint, tracks the in-flight
// count around the call, applies the configured per-request timeout, and
// updates priority/cooldown state based on the outcome.
func (r *TargetRouter) Embed(ctx context.Context, req Request) (Response, error) {
	slot := r.pick()
	// Increment inflight under the slot lock so concurrent pick() calls see a
	// consistent busyness score.
	slot.mu.Lock()
	slot.inflight++
	if r.metrics != nil {
		r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
	}
	slot.mu.Unlock()
	defer func() {
		slot.mu.Lock()
		slot.inflight--
		if r.metrics != nil {
			r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
		}
		slot.mu.Unlock()
	}()
	// NOTE(review): a zero TimeoutSecs yields an already-expired context —
	// confirm config validation enforces a positive timeout.
	timeout := time.Duration(r.cfg.TimeoutSecs) * time.Second
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	resp, err := slot.client.Embed(ctx, req)
	if err != nil {
		r.onFailure(slot, err)
		return Response{}, fmt.Errorf("router embed [%s]: %w", slot.url, err)
	}
	r.onSuccess(slot)
	return resp, nil
}
// pick selects the best available slot: the highest score (priority minus
// in-flight count) among slots that are not in their failure cooldown.
// When every slot is cooling down, it falls back to the one whose last
// failure is oldest, so the router always returns a usable slot.
func (r *TargetRouter) pick() *endpointSlot {
	cooldown := time.Duration(r.cfg.CooldownSecs) * time.Second
	now := time.Now()

	var best *endpointSlot
	bestScore := 0
	for _, s := range r.slots {
		s.mu.Lock()
		inCooldown := !s.lastFail.IsZero() && now.Sub(s.lastFail) < cooldown
		score := s.priority - s.inflight
		s.mu.Unlock()
		if inCooldown {
			continue
		}
		if best == nil || score > bestScore {
			best = s
			bestScore = score
		}
	}
	if best != nil {
		return best
	}

	// All slots in cooldown — fall back to the one that failed longest ago.
	var oldest time.Time
	for _, s := range r.slots {
		s.mu.Lock()
		lf := s.lastFail
		s.mu.Unlock()
		if best == nil || lf.Before(oldest) {
			best = s
			oldest = lf
		}
	}
	return best
}
// onSuccess records a successful request: every PriorityRecovery-th success
// restores one priority point, capped at the slot's initial priority.
func (r *TargetRouter) onSuccess(s *endpointSlot) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.successCount++
	if r.cfg.PriorityRecovery > 0 && s.successCount%r.cfg.PriorityRecovery == 0 {
		if s.priority < s.initialPriority {
			s.priority++
		}
	}
	if r.metrics != nil {
		r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
	}
}
// onFailure records a failed request: it stamps the cooldown clock, decays
// the slot's priority (floored at 1), and reports an error metric labelled
// by failure type.
func (r *TargetRouter) onFailure(s *endpointSlot, err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.lastFail = time.Now()
	s.priority -= r.cfg.PriorityDecay
	if s.priority < 1 {
		s.priority = 1
	}
	// Classify the failure by inspecting the error chain. The previous code
	// checked context.Background().Err(), which is always nil, so the
	// "timeout" label could never be emitted.
	errType := "error"
	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
		errType = "timeout"
	}
	if r.metrics != nil {
		r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
		r.metrics.IncEndpointErrors(r.cfg.TargetName, s.url, errType)
	}
}

112
pkg/metrics/metrics.go Normal file
View File

@@ -0,0 +1,112 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
)
// Registry holds all vecna Prometheus metrics on a dedicated (non-global) registry.
type Registry struct {
	reg *prometheus.Registry

	RequestsTotal       *prometheus.CounterVec   // by endpoint, status
	RequestDuration     *prometheus.HistogramVec // by endpoint
	ForwardDuration     *prometheus.HistogramVec // by target, url
	TranslateDuration   *prometheus.HistogramVec // by adapter_type
	EndpointPriority    *prometheus.GaugeVec     // by target, url
	EndpointInflight    *prometheus.GaugeVec     // by target, url
	EndpointErrorsTotal *prometheus.CounterVec   // by target, url, error
	TokensTotal         *prometheus.CounterVec   // by target, model, token_type
}
// New creates and registers all metrics on a fresh Prometheus registry.
// Using a dedicated registry keeps vecna metrics isolated from the global
// default registry and anything else in the process.
func New() *Registry {
	reg := prometheus.NewRegistry()
	r := &Registry{
		reg: reg,
		// Request-level counters/histograms.
		RequestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_requests_total",
			Help: "Total number of requests served by vecna.",
		}, []string{"endpoint", "status"}),
		RequestDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_request_duration_seconds",
			Help:    "Total request wall-clock time.",
			Buckets: prometheus.DefBuckets,
		}, []string{"endpoint"}),
		ForwardDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_forward_duration_seconds",
			Help:    "Time spent waiting on the backing embedding model.",
			Buckets: prometheus.DefBuckets,
		}, []string{"target", "url"}),
		// Translation is expected to be sub-millisecond, hence the fine buckets.
		TranslateDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_translate_duration_seconds",
			Help:    "Time spent in the dimension adapter.",
			Buckets: []float64{0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05},
		}, []string{"adapter_type"}),
		// Per-endpoint routing state gauges and error counters.
		EndpointPriority: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Name: "vecna_endpoint_priority",
			Help: "Current dynamic routing priority for a forwarding endpoint.",
		}, []string{"target", "url"}),
		EndpointInflight: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Name: "vecna_endpoint_inflight",
			Help: "Number of active in-flight requests per forwarding endpoint.",
		}, []string{"target", "url"}),
		EndpointErrorsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_endpoint_errors_total",
			Help: "Total forwarding errors per endpoint, labelled by error type.",
		}, []string{"target", "url", "error"}),
		TokensTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_tokens_total",
			Help: "Tokens consumed by the backing embedding model, by target, model, and token type.",
		}, []string{"target", "model", "token_type"}),
	}
	reg.MustRegister(
		r.RequestsTotal,
		r.RequestDuration,
		r.ForwardDuration,
		r.TranslateDuration,
		r.EndpointPriority,
		r.EndpointInflight,
		r.EndpointErrorsTotal,
		r.TokensTotal,
	)
	return r
}
// Prometheus returns the underlying registry for use with promhttp.HandlerFor.
func (r *Registry) Prometheus() *prometheus.Registry {
	return r.reg
}
// Convenience setters used by the router.

// SetEndpointPriority records the current routing priority of an endpoint.
func (r *Registry) SetEndpointPriority(target, url string, v float64) {
	r.EndpointPriority.WithLabelValues(target, url).Set(v)
}

// SetEndpointInflight records the current in-flight request count of an endpoint.
func (r *Registry) SetEndpointInflight(target, url string, v float64) {
	r.EndpointInflight.WithLabelValues(target, url).Set(v)
}

// IncEndpointErrors bumps the error counter for an endpoint/error-type pair.
func (r *Registry) IncEndpointErrors(target, url, errType string) {
	r.EndpointErrorsTotal.WithLabelValues(target, url, errType).Inc()
}
// AddTokens records prompt and total token consumption for a target/model
// pair. Zero or negative counts are skipped so absent usage data adds nothing.
func (r *Registry) AddTokens(target, model string, promptTokens, totalTokens int) {
	counts := []struct {
		kind string
		n    int
	}{
		{"prompt", promptTokens},
		{"total", totalTokens},
	}
	for _, c := range counts {
		if c.n > 0 {
			r.TokensTotal.WithLabelValues(target, model, c.kind).Add(float64(c.n))
		}
	}
}

62
pkg/metrics/middleware.go Normal file
View File

@@ -0,0 +1,62 @@
package metrics
import (
"fmt"
"net/http"
"github.com/uptrace/bunrouter"
)
// Middleware returns a bunrouter middleware that records per-request Prometheus metrics.
// It reads timing from the RequestTrace stored in the context (set by server/trace.go).
// The trace target/url labels are optional; pass empty strings if not applicable.
func (r *Registry) Middleware(getTrace func(req bunrouter.Request) TraceSnapshot) bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			// Wrap the writer so we can observe the status code the handler set.
			rw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			err := next(rw, req)
			// Snapshot is taken after the handler returns so all timings are final.
			snap := getTrace(req)
			endpoint := req.URL.Path
			status := fmt.Sprintf("%d", rw.status)
			r.RequestsTotal.WithLabelValues(endpoint, status).Inc()
			r.RequestDuration.WithLabelValues(endpoint).Observe(snap.TotalSeconds)
			// Forward/translate observations only apply when that stage ran.
			if snap.ForwardTarget != "" {
				r.ForwardDuration.WithLabelValues(snap.ForwardTarget, snap.ForwardURL).Observe(snap.ForwardSeconds)
			}
			if snap.AdapterType != "" {
				r.TranslateDuration.WithLabelValues(snap.AdapterType).Observe(snap.TranslateSeconds)
			}
			if snap.PromptTokens > 0 || snap.TotalTokens > 0 {
				r.AddTokens(snap.ForwardTarget, snap.ForwardModel, snap.PromptTokens, snap.TotalTokens)
			}
			return err
		}
	}
}
// TraceSnapshot carries the timing and usage values the metrics middleware needs.
// Empty ForwardTarget/AdapterType mean the corresponding stage did not run.
type TraceSnapshot struct {
	TotalSeconds     float64
	ForwardSeconds   float64
	TranslateSeconds float64
	ForwardTarget    string
	ForwardURL       string
	ForwardModel     string
	AdapterType      string
	PromptTokens     int
	TotalTokens      int
}
// statusWriter wraps http.ResponseWriter to capture the written status code.
// The default is 200 because handlers that never call WriteHeader implicitly
// send http.StatusOK.
type statusWriter struct {
	http.ResponseWriter
	status int
}

// WriteHeader records the status code before delegating to the wrapped writer.
func (sw *statusWriter) WriteHeader(code int) {
	sw.status = code
	sw.ResponseWriter.WriteHeader(code)
}

135
pkg/server/google.go Normal file
View File

@@ -0,0 +1,135 @@
package server
import (
"encoding/json"
"net/http"
"time"
"github.com/uptrace/bunrouter"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// --- single embedContent ---

// googleEmbedContentRequest is the JSON body of a Gemini embedContent call.
type googleEmbedContentRequest struct {
	Content  googleContent `json:"content"`
	TaskType string        `json:"taskType,omitempty"` // accepted but not forwarded here
}

// googleContent wraps the text parts of one request.
type googleContent struct {
	Parts []googlePart `json:"parts"`
}

// googlePart is a single text fragment.
type googlePart struct {
	Text string `json:"text"`
}

// googleEmbedContentResponse is the reply for a single embedContent call.
type googleEmbedContentResponse struct {
	Embedding googleEmbeddingValues `json:"embedding"`
}

// googleEmbeddingValues holds one embedding vector.
type googleEmbeddingValues struct {
	Values []float32 `json:"values"`
}
// googleEmbedContent handles the Gemini-style single embedContent endpoint:
// it forwards all text parts to the resolved client, adapts the FIRST
// returned embedding, and replies in Gemini response shape. Timings and
// usage are recorded on the request trace.
func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")
	var body googleEmbedContentRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	texts := make([]string, len(body.Content.Parts))
	for i, p := range body.Content.Parts {
		texts[i] = p.Text
	}
	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL
	// Time the upstream call separately from the local adaptation step.
	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens
	t1 := time.Now()
	var adapted []float32
	// Only the first embedding is returned on this single-content endpoint.
	// With no embeddings, adapted stays nil (encoded as JSON null values).
	if len(embedResp.Embeddings) > 0 {
		adapted, err = h.adapter.Adapt(embedResp.Embeddings[0])
		if err != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
	}
	trace.TranslateDuration = time.Since(t1)
	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, googleEmbedContentResponse{
		Embedding: googleEmbeddingValues{Values: adapted},
	})
}
// --- batch batchEmbedContents ---

// googleBatchRequest is the JSON body of a Gemini batchEmbedContents call.
type googleBatchRequest struct {
	Requests []googleEmbedContentRequest `json:"requests"`
}

// googleBatchResponse is the reply: one embedding per produced vector.
type googleBatchResponse struct {
	Embeddings []googleEmbeddingValues `json:"embeddings"`
}
// googleBatchEmbedContents handles the Gemini-style batch endpoint: it
// flattens every part of every request into one text list, forwards the
// batch, adapts each returned vector, and replies in Gemini batch shape.
// NOTE(review): because parts are flattened, a request with multiple parts
// yields multiple embeddings, so len(embeddings) can exceed len(requests) —
// confirm this matches the intended one-embedding-per-request contract.
func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")
	var body googleBatchRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	var texts []string
	for _, r := range body.Requests {
		for _, p := range r.Content.Parts {
			texts = append(texts, p.Text)
		}
	}
	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL
	// Time the upstream call separately from the local adaptation step.
	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens
	t1 := time.Now()
	result := make([]googleEmbeddingValues, len(embedResp.Embeddings))
	for i, vec := range embedResp.Embeddings {
		adapted, adaptErr := h.adapter.Adapt(vec)
		if adaptErr != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
		}
		result[i] = googleEmbeddingValues{Values: adapted}
	}
	trace.TranslateDuration = time.Since(t1)
	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, googleBatchResponse{Embeddings: result})
}

74
pkg/server/handler.go Normal file
View File

@@ -0,0 +1,74 @@
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"github.com/Warky-Devs/vecna.git/pkg/adapter"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// handler holds shared dependencies for all HTTP handlers.
type handler struct {
	cfg     *config.Config
	clients map[string]embedclient.Client // keyed by forward-target name
	adapter adapter.Adapter
	logger  *zap.Logger
}
// resolveClient selects the embed client for the given model name.
// Returns the client, target name, and first endpoint URL for tracing.
// Resolution: exact model match first, then the configured default target;
// when neither exists, it returns an errClient so callers get a clean error
// instead of a nil dereference.
func (h *handler) resolveClient(model string) (embedclient.Client, string, string) {
	if c, ok := h.clients[model]; ok {
		url := firstEndpointURL(h.cfg, model)
		return c, model, url
	}
	name := h.cfg.Forward.Default
	c, ok := h.clients[name]
	if !ok {
		// No configured client — return a nil-safe error client
		return &errClient{err: fmt.Errorf("no client configured for model %q and no default", model)}, name, ""
	}
	return c, name, firstEndpointURL(h.cfg, name)
}
// firstEndpointURL returns the URL of the first configured endpoint for a
// forward target, or the empty string when the target is unknown or empty.
func firstEndpointURL(cfg *config.Config, targetName string) string {
	target, found := cfg.Forward.Targets[targetName]
	if found && len(target.Endpoints) > 0 {
		return target.Endpoints[0].URL
	}
	return ""
}
// writeJSON encodes v as JSON and writes it with the given status code.
func writeJSON(w http.ResponseWriter, status int, v interface{}) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
return fmt.Errorf("writeJSON: %w", err)
}
return nil
}
// writeTraceHeaders writes X-Vecna-* timing headers from the RequestTrace.
// Total time is measured from the trace's Start up to this call.
func writeTraceHeaders(w http.ResponseWriter, t *RequestTrace) {
	elapsed := time.Since(t.Start)
	headers := map[string]int64{
		"X-Vecna-Forward-Ms":   t.ForwardDuration.Milliseconds(),
		"X-Vecna-Translate-Ms": t.TranslateDuration.Milliseconds(),
		"X-Vecna-Total-Ms":     elapsed.Milliseconds(),
	}
	for name, ms := range headers {
		w.Header().Set(name, fmt.Sprintf("%d", ms))
	}
}
// errClient is a Client that always returns a fixed error (used as safe fallback).
type errClient struct {
	err error
}

// Embed implements embedclient.Client by returning the stored error unchanged.
func (e *errClient) Embed(_ context.Context, _ embedclient.Request) (embedclient.Response, error) {
	return embedclient.Response{}, e.err
}

102
pkg/server/openai.go Normal file
View File

@@ -0,0 +1,102 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/uptrace/bunrouter"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// openAIEmbedRequest is the inbound body of POST /v1/embeddings.
type openAIEmbedRequest struct {
	Input interface{} `json:"input"` // string or []string; validated by toStringSlice
	Model string      `json:"model"`
}

// openAIEmbedResponse is the outbound OpenAI-compatible reply shape.
type openAIEmbedResponse struct {
	Object string             `json:"object"`
	Data   []openAIEmbedDatum `json:"data"`
	Model  string             `json:"model"`
	Usage  openAIUsage        `json:"usage"`
}

// openAIEmbedDatum is one adapted embedding, index-aligned with the input.
type openAIEmbedDatum struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// openAIUsage echoes token usage reported by the backing model.
type openAIUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
// openAIEmbeddings handles the OpenAI-compatible /v1/embeddings endpoint:
// it validates input, forwards the batch to the resolved client, adapts each
// vector, and replies in OpenAI response shape. Timings and usage are
// recorded on the request trace.
func (h *handler) openAIEmbeddings(w http.ResponseWriter, req bunrouter.Request) error {
	var body openAIEmbedRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	texts, err := toStringSlice(body.Input)
	if err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
	}
	client, targetName, targetURL := h.resolveClient(body.Model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL
	// Time the upstream call separately from the local adaptation step.
	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: body.Model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens
	t1 := time.Now()
	data := make([]openAIEmbedDatum, len(embedResp.Embeddings))
	for i, vec := range embedResp.Embeddings {
		adapted, adaptErr := h.adapter.Adapt(vec)
		if adaptErr != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
		}
		data[i] = openAIEmbedDatum{Object: "embedding", Embedding: adapted, Index: i}
	}
	trace.TranslateDuration = time.Since(t1)
	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, openAIEmbedResponse{
		Object: "list",
		Data:   data,
		Model:  embedResp.Model,
		Usage:  openAIUsage{PromptTokens: embedResp.Usage.PromptTokens, TotalTokens: embedResp.Usage.TotalTokens},
	})
}
// toStringSlice accepts a JSON string or array of strings.
// A bare string becomes a one-element slice; any non-string array element
// or other type is rejected with a descriptive error.
func toStringSlice(v interface{}) ([]string, error) {
	if s, ok := v.(string); ok {
		return []string{s}, nil
	}
	arr, ok := v.([]interface{})
	if !ok {
		return nil, fmt.Errorf("input must be a string or array of strings")
	}
	out := make([]string, len(arr))
	for i, item := range arr {
		str, ok := item.(string)
		if !ok {
			return nil, fmt.Errorf("input array element %d is not a string", i)
		}
		out[i] = str
	}
	return out, nil
}

163
pkg/server/server.go Normal file
View File

@@ -0,0 +1,163 @@
package server
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/uptrace/bunrouter"
"go.uber.org/zap"
"github.com/Warky-Devs/vecna.git/pkg/adapter"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
"github.com/Warky-Devs/vecna.git/pkg/metrics"
"github.com/Warky-Devs/vecna.git/pkg/server/spec"
)
// New builds and returns a configured bunrouter.Router.
// Middleware order (outermost first): auth, trace injection, metrics,
// logging. The metrics endpoint is registered only when cfg.Metrics.Enabled
// is set, defaulting to /metrics, with optional Bearer-token protection.
func New(
	cfg *config.Config,
	clients map[string]embedclient.Client,
	adp adapter.Adapter,
	reg *metrics.Registry,
	logger *zap.Logger,
) *bunrouter.Router {
	router := bunrouter.New(
		bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)),
		bunrouter.WithMiddleware(traceMiddleware()),
		bunrouter.WithMiddleware(metricsMiddleware(reg, adp)),
		bunrouter.WithMiddleware(loggingMiddleware(logger)),
	)
	h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger}
	// Embedding endpoints: OpenAI-compatible plus Gemini single/batch.
	router.POST("/v1/embeddings", h.openAIEmbeddings)
	router.POST("/v1/models/:model:embedContent", h.googleEmbedContent)
	router.POST("/v1/models/:model:batchEmbedContents", h.googleBatchEmbedContents)
	// OpenAPI spec + docs
	router.GET("/openapi.yaml", spec.SpecHandler())
	router.GET("/docs", spec.DocsHandler())
	// Metrics — only when enabled
	if cfg.Metrics.Enabled {
		metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{})
		path := cfg.Metrics.Path
		if path == "" {
			path = "/metrics"
		}
		if cfg.Metrics.APIKey != "" {
			router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler))
		} else {
			router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error {
				metricsHandler.ServeHTTP(w, req.Request)
				return nil
			})
		}
	}
	return router
}
// authMiddleware rejects requests without a valid Bearer token when api_keys
// is configured. An empty key list makes it a no-op passthrough.
func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc {
	if len(apiKeys) == 0 {
		return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
	}
	valid := make(map[string]struct{}, len(apiKeys))
	for _, key := range apiKeys {
		valid[key] = struct{}{}
	}
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			bearer := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
			if _, ok := valid[bearer]; ok {
				return next(w, req)
			}
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return nil
		}
	}
}
// traceMiddleware injects a *RequestTrace into every request context so
// downstream handlers and middleware can record/read per-request timings.
func traceMiddleware() bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			ctx := WithTrace(req.Context())
			return next(w, req.WithContext(ctx))
		}
	}
}
// metricsMiddleware records Prometheus observations after the handler returns.
// With a nil registry it is a no-op passthrough. The adapter's dynamic type
// name is used as the adapter_type metric label.
func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc {
	if reg == nil {
		return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
	}
	adpType := fmt.Sprintf("%T", adp)
	return reg.Middleware(func(req bunrouter.Request) metrics.TraceSnapshot {
		// Convert the context-bound RequestTrace into the snapshot shape the
		// metrics package consumes.
		t := TraceFromContext(req.Context())
		total := time.Since(t.Start)
		return metrics.TraceSnapshot{
			TotalSeconds:     total.Seconds(),
			ForwardSeconds:   t.ForwardDuration.Seconds(),
			TranslateSeconds: t.TranslateDuration.Seconds(),
			ForwardTarget:    t.ForwardTarget,
			ForwardURL:       t.ForwardURL,
			ForwardModel:     t.ForwardModel,
			AdapterType:      adpType,
			PromptTokens:     t.PromptTokens,
			TotalTokens:      t.TotalTokens,
		}
	})
}
// loggingMiddleware logs method, path, status, and timing via zap.
// It wraps the writer to capture the status code and reads timings from the
// context-bound RequestTrace after the handler completes.
func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			err := next(sw, req)
			t := TraceFromContext(req.Context())
			total := time.Since(t.Start)
			logger.Info("request",
				zap.String("method", req.Method),
				zap.String("path", req.URL.Path),
				zap.Int("status", sw.status),
				zap.Int64("total_ms", total.Milliseconds()),
				zap.Int64("forward_ms", t.ForwardDuration.Milliseconds()),
				zap.Int64("translate_ms", t.TranslateDuration.Milliseconds()),
			)
			return err
		}
	}
}
// metricsAuthHandler wraps a standard http.Handler with Bearer token auth.
// NOTE(review): the token comparison is a plain string equality, not a
// constant-time compare — consider crypto/subtle if this surface matters.
func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc {
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
		if token != apiKey {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return nil
		}
		h.ServeHTTP(w, req.Request)
		return nil
	}
}
// statusWriter captures the HTTP status code written by a handler.
// The default is 200 because handlers that never call WriteHeader implicitly
// send http.StatusOK.
type statusWriter struct {
	http.ResponseWriter
	status int
}

// WriteHeader records the status code before delegating to the wrapped writer.
func (sw *statusWriter) WriteHeader(code int) {
	sw.status = code
	sw.ResponseWriter.WriteHeader(code)
}

View File

@@ -0,0 +1,36 @@
package spec
import (
_ "embed"
"net/http"
"github.com/uptrace/bunrouter"
)
//go:embed openapi.yaml
var openapiYAML []byte
// SpecHandler serves the raw OpenAPI YAML spec (embedded at build time via
// go:embed) with a YAML content type.
func SpecHandler() bunrouter.HandlerFunc {
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		w.Header().Set("Content-Type", "application/yaml")
		_, err := w.Write(openapiYAML)
		return err
	}
}
// DocsHandler serves the Scalar API reference UI: a minimal HTML page that
// loads the Scalar script from a CDN and points it at /openapi.yaml.
func DocsHandler() bunrouter.HandlerFunc {
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		_, err := w.Write([]byte(`<!doctype html>
<html>
<head><title>vecna API</title><meta charset="utf-8"/></head>
<body>
<script id="api-reference" data-url="/openapi.yaml"></script>
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
</body>
</html>`))
		return err
	}
}

View File

@@ -0,0 +1,252 @@
openapi: "3.1.0"
info:
title: vecna Embedding Adapter
description: Proxies text to a backing embedding model and adapts the result vectors between dimensions.
version: "1.0.0"
servers:
- url: http://localhost:8080
security:
- BearerAuth: []
components:
securitySchemes:
BearerAuth:
type: http
scheme: bearer
schemas:
Error:
type: object
properties:
error:
type: string
OpenAIEmbedRequest:
type: object
required: [input, model]
properties:
input:
oneOf:
- type: string
- type: array
items:
type: string
model:
type: string
OpenAIEmbedResponse:
type: object
properties:
object:
type: string
example: list
model:
type: string
data:
type: array
items:
type: object
properties:
object:
type: string
example: embedding
index:
type: integer
embedding:
type: array
items:
type: number
format: float
usage:
type: object
properties:
prompt_tokens:
type: integer
total_tokens:
type: integer
GoogleEmbedContentRequest:
type: object
required: [content]
properties:
content:
type: object
properties:
parts:
type: array
items:
type: object
properties:
text:
type: string
taskType:
type: string
GoogleEmbedContentResponse:
type: object
properties:
embedding:
type: object
properties:
values:
type: array
items:
type: number
format: float
GoogleBatchRequest:
type: object
required: [requests]
properties:
requests:
type: array
items:
$ref: '#/components/schemas/GoogleEmbedContentRequest'
GoogleBatchResponse:
type: object
properties:
embeddings:
type: array
items:
type: object
properties:
values:
type: array
items:
type: number
format: float
headers:
X-Vecna-Forward-Ms:
description: Time spent forwarding the request to the backing model (milliseconds).
schema:
type: integer
X-Vecna-Translate-Ms:
description: Time spent in the dimension adapter (milliseconds).
schema:
type: integer
X-Vecna-Total-Ms:
description: Total request wall-clock time (milliseconds).
schema:
type: integer
paths:
/v1/embeddings:
post:
summary: OpenAI-compatible embeddings
operationId: openaiEmbeddings
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAIEmbedRequest'
responses:
"200":
description: Adapted embeddings
headers:
X-Vecna-Forward-Ms:
$ref: '#/components/headers/X-Vecna-Forward-Ms'
X-Vecna-Translate-Ms:
$ref: '#/components/headers/X-Vecna-Translate-Ms'
X-Vecna-Total-Ms:
$ref: '#/components/headers/X-Vecna-Total-Ms'
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAIEmbedResponse'
"400":
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
"401":
description: Unauthorized
"502":
description: Backing model error
/v1/models/{model}:embedContent:
post:
summary: Google-compatible single embedContent
operationId: googleEmbedContent
parameters:
- name: model
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleEmbedContentRequest'
responses:
"200":
description: Adapted embedding
headers:
X-Vecna-Forward-Ms:
$ref: '#/components/headers/X-Vecna-Forward-Ms'
X-Vecna-Translate-Ms:
$ref: '#/components/headers/X-Vecna-Translate-Ms'
X-Vecna-Total-Ms:
$ref: '#/components/headers/X-Vecna-Total-Ms'
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleEmbedContentResponse'
"400":
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
"401":
description: Unauthorized
"502":
description: Backing model error
/v1/models/{model}:batchEmbedContents:
post:
      summary: Google-compatible batch embedContents
operationId: googleBatchEmbedContents
parameters:
- name: model
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleBatchRequest'
responses:
"200":
description: Adapted embeddings
headers:
X-Vecna-Forward-Ms:
$ref: '#/components/headers/X-Vecna-Forward-Ms'
X-Vecna-Translate-Ms:
$ref: '#/components/headers/X-Vecna-Translate-Ms'
X-Vecna-Total-Ms:
$ref: '#/components/headers/X-Vecna-Total-Ms'
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleBatchResponse'
"400":
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
"401":
description: Unauthorized
"502":
description: Backing model error

37
pkg/server/trace.go Normal file
View File

@@ -0,0 +1,37 @@
package server
import (
"context"
"time"
)
// contextKey is an unexported type for context keys, so values stored by
// this package cannot collide with keys defined in other packages.
type contextKey int

// traceKey is the single context key under which *RequestTrace is stored.
const traceKey contextKey = iota
// RequestTrace holds per-request timing data populated by handlers and middleware.
type RequestTrace struct {
	Start             time.Time     // when the trace was created (request start)
	ForwardDuration   time.Duration // time spent forwarding to the backing model
	TranslateDuration time.Duration // time spent in the dimension adapter
	ForwardTarget     string        // NOTE(review): exact semantics not visible here — presumably a backend identifier; confirm at the writer
	ForwardURL        string        // URL the request was forwarded to
	ForwardModel      string        // model name used on the forwarded request
	AdapterType       string        // which adapter strategy handled the vectors
	PromptTokens      int           // prompt token usage reported by the backing model
	TotalTokens       int           // total token usage reported by the backing model
}
// WithTrace injects a new *RequestTrace into ctx.
func WithTrace(ctx context.Context) context.Context {
return context.WithValue(ctx, traceKey, &RequestTrace{Start: time.Now()})
}
// TraceFromContext retrieves the *RequestTrace from ctx.
// Returns a zero-value trace (non-nil) if none was set, so callers never
// need a nil check.
func TraceFromContext(ctx context.Context) *RequestTrace {
	trace, ok := ctx.Value(traceKey).(*RequestTrace)
	if !ok || trace == nil {
		return &RequestTrace{Start: time.Now()}
	}
	return trace
}

9
prometheus.example.yml Normal file
View File

@@ -0,0 +1,9 @@
global:
scrape_interval: 15s
scrape_configs:
- job_name: vecna
static_configs:
- targets: ["vecna:8080"]
      # authorization:                     # uncomment if metrics.api_key is set
      #   credentials: sk-metrics-secret   # (the older `bearer_token` field is deprecated)
metrics_path: /metrics

View File

@@ -0,0 +1,330 @@
//go:build integration
package integration
import (
"context"
"math"
"net/http"
"os"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/Warky-Devs/vecna.git/pkg/adapter"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// Environment variables that configure the integration tests:
//
// VECNA_TEST_URL base URL of the embedding server (required)
// VECNA_TEST_MODEL model name to request (required)
// VECNA_TEST_API_TYPE "openai" (default) or "google"
// VECNA_TEST_API_KEY bearer token, empty if not needed
//
// Example (Ollama):
//
// VECNA_TEST_URL=http://localhost:11434 VECNA_TEST_MODEL=nomic-embed-text \
// go test -tags integration ./tests/integration/
// testText is the fixed input embedded by every test; using identical input
// lets TestRoundtripConsistency compare two responses element-wise.
const testText = "The quick brown fox jumps over the lazy dog"

// cfg holds resolved test parameters.
type cfg struct {
	url     string // base URL of the embedding server (VECNA_TEST_URL)
	model   string // model name to request (VECNA_TEST_MODEL)
	apiType string // "openai" (default) or "google" (VECNA_TEST_API_TYPE)
	apiKey  string // bearer token, empty if not needed (VECNA_TEST_API_KEY)
}
// loadCfg resolves test parameters from the environment. The calling test
// is skipped (not failed) when a required variable is missing, so the suite
// is a no-op unless explicitly configured.
func loadCfg(t *testing.T) cfg {
	t.Helper()
	baseURL := os.Getenv("VECNA_TEST_URL")
	if baseURL == "" {
		t.Skip("VECNA_TEST_URL not set — skipping integration tests")
	}
	modelName := os.Getenv("VECNA_TEST_MODEL")
	if modelName == "" {
		t.Skip("VECNA_TEST_MODEL not set — skipping integration tests")
	}
	kind := os.Getenv("VECNA_TEST_API_TYPE")
	if kind == "" {
		kind = "openai"
	}
	return cfg{
		url:     baseURL,
		model:   modelName,
		apiType: kind,
		apiKey:  os.Getenv("VECNA_TEST_API_KEY"),
	}
}
// newClient builds the embedclient implementation matching the configured
// API type; anything other than "google" falls back to the OpenAI client.
func newClient(c cfg) embedclient.Client {
	hc := &http.Client{Timeout: 30 * time.Second}
	switch c.apiType {
	case "google":
		return embedclient.NewGoogle(c.url, c.apiKey, c.model, hc)
	default:
		return embedclient.NewOpenAI(c.url, c.apiKey, hc)
	}
}
// embed fetches a single embedding vector for testText and fails the test
// immediately on any transport or shape problem.
func embed(t *testing.T, client embedclient.Client, model string) []float32 {
	t.Helper()
	req := embedclient.Request{
		Texts: []string{testText},
		Model: model,
	}
	resp, err := client.Embed(context.Background(), req)
	require.NoError(t, err, "embedding request failed")
	require.Len(t, resp.Embeddings, 1, "expected exactly one embedding in response")
	vec := resp.Embeddings[0]
	require.NotEmpty(t, vec, "embedding vector is empty")
	return vec
}
// l2Norm returns the Euclidean (L2) norm of v, accumulating in float64 to
// limit rounding error.
func l2Norm(v []float32) float64 {
	var acc float64
	for i := range v {
		x := float64(v[i])
		acc += x * x
	}
	return math.Sqrt(acc)
}
// assertUnitNorm checks the vector is approximately L2-normalised.
func assertUnitNorm(t *testing.T, v []float32) {
t.Helper()
norm := l2Norm(v)
assert.InDelta(t, 1.0, norm, 0.01, "expected unit L2 norm after adaptation")
}
// ---- Tests ----------------------------------------------------------------
// TestNativeDimension verifies the server returns a non-empty vector.
// This is the baseline; the native dimension is logged so it can be used
// as VECNA_TEST_SOURCE_DIM for the dimension tests below.
func TestNativeDimension(t *testing.T) {
	conf := loadCfg(t)
	vec := embed(t, newClient(conf), conf.model)
	t.Logf("native dimension: %d", len(vec))
	t.Logf("native L2 norm: %.6f", l2Norm(vec))
	assert.Greater(t, len(vec), 0)
}
// TestDownscaleTruncate tests truncation to half the native dimension.
func TestDownscaleTruncate(t *testing.T) {
	conf := loadCfg(t)
	v := embed(t, newClient(conf), conf.model)
	src := len(v)
	dst := src / 2
	if dst == 0 {
		t.Skipf("source dim %d too small to halve", src)
	}
	tr, err := adapter.NewTruncate(src, dst, adapter.TruncateFromEnd, adapter.PadAtEnd)
	require.NoError(t, err)
	got, err := tr.Adapt(v)
	require.NoError(t, err)
	assert.Len(t, got, dst, "output dimension mismatch")
	assertUnitNorm(t, got)
	t.Logf("downscale truncate: %d → %d norm=%.6f", src, dst, l2Norm(got))
}
// TestDownscaleTruncateFromStart tests keeping the last N dims.
func TestDownscaleTruncateFromStart(t *testing.T) {
	conf := loadCfg(t)
	v := embed(t, newClient(conf), conf.model)
	src := len(v)
	dst := src / 2
	if dst == 0 {
		t.Skipf("source dim %d too small to halve", src)
	}
	tr, err := adapter.NewTruncate(src, dst, adapter.TruncateFromStart, adapter.PadAtEnd)
	require.NoError(t, err)
	got, err := tr.Adapt(v)
	require.NoError(t, err)
	assert.Len(t, got, dst)
	assertUnitNorm(t, got)
	t.Logf("downscale truncate-from-start: %d → %d norm=%.6f", src, dst, l2Norm(got))
}
// TestDownscaleRandom tests random projection to a lower dimension.
// Seed 42 keeps the projection matrix deterministic between runs.
func TestDownscaleRandom(t *testing.T) {
	conf := loadCfg(t)
	v := embed(t, newClient(conf), conf.model)
	src := len(v)
	dst := src / 2
	if dst == 0 {
		t.Skipf("source dim %d too small to halve", src)
	}
	proj, err := adapter.NewRandom(src, dst, 42)
	require.NoError(t, err)
	got, err := proj.Adapt(v)
	require.NoError(t, err)
	assert.Len(t, got, dst)
	assertUnitNorm(t, got)
	t.Logf("downscale random: %d → %d norm=%.6f", src, dst, l2Norm(got))
}
// TestDownscaleToFixed tests truncation to a fixed well-known target
// (e.g. 768 → 256, overridable via VECNA_TEST_TARGET_DIM).
// Skips if the native dimension is not larger than the target.
func TestDownscaleToFixed(t *testing.T) {
	conf := loadCfg(t)
	cl := newClient(conf)
	target := intEnv("VECNA_TEST_TARGET_DIM", 256)
	v := embed(t, cl, conf.model)
	src := len(v)
	if src <= target {
		t.Skipf("native dim %d is not larger than target dim %d", src, target)
	}
	tr, err := adapter.NewTruncate(src, target, adapter.TruncateFromEnd, adapter.PadAtEnd)
	require.NoError(t, err)
	got, err := tr.Adapt(v)
	require.NoError(t, err)
	assert.Len(t, got, target)
	assertUnitNorm(t, got)
	t.Logf("downscale to fixed: %d → %d norm=%.6f", src, target, l2Norm(got))
}
// TestUpscalePadEnd tests zero-padding to double the native dimension.
// In addition to shape and norm, it now asserts that the padded tail is
// actually zero — the original version only described this in a comment.
// Zeros are unaffected by L2 normalisation, so the last srcDim elements
// of the output must be exactly 0 regardless of when normalisation runs.
func TestUpscalePadEnd(t *testing.T) {
	c := loadCfg(t)
	client := newClient(c)
	vec := embed(t, client, c.model)
	srcDim := len(vec)
	tgtDim := srcDim * 2
	adp, err := adapter.NewTruncate(srcDim, tgtDim, adapter.TruncateFromEnd, adapter.PadAtEnd)
	require.NoError(t, err)
	out, err := adp.Adapt(vec)
	require.NoError(t, err)
	assert.Len(t, out, tgtDim)
	assertUnitNorm(t, out)
	// Verify the zero pad: normalisation scales values, but 0 stays 0.
	for i := srcDim; i < tgtDim; i++ {
		assert.Zerof(t, out[i], "expected zero pad at index %d", i)
	}
	t.Logf("upscale pad-end: %d → %d norm=%.6f", srcDim, tgtDim, l2Norm(out))
}
// TestUpscalePadStart tests zero-padding prepended to the vector.
func TestUpscalePadStart(t *testing.T) {
	conf := loadCfg(t)
	v := embed(t, newClient(conf), conf.model)
	src := len(v)
	dst := src * 2
	tr, err := adapter.NewTruncate(src, dst, adapter.TruncateFromEnd, adapter.PadAtStart)
	require.NoError(t, err)
	got, err := tr.Adapt(v)
	require.NoError(t, err)
	assert.Len(t, got, dst)
	assertUnitNorm(t, got)
	t.Logf("upscale pad-start: %d → %d norm=%.6f", src, dst, l2Norm(got))
}
// TestUpscaleRandom tests random projection to a higher dimension.
// Seed 42 keeps the projection matrix deterministic between runs.
func TestUpscaleRandom(t *testing.T) {
	conf := loadCfg(t)
	v := embed(t, newClient(conf), conf.model)
	src := len(v)
	dst := src * 2
	proj, err := adapter.NewRandom(src, dst, 42)
	require.NoError(t, err)
	got, err := proj.Adapt(v)
	require.NoError(t, err)
	assert.Len(t, got, dst)
	assertUnitNorm(t, got)
	t.Logf("upscale random: %d → %d norm=%.6f", src, dst, l2Norm(got))
}
// TestUpscaleToFixed tests upscaling to a fixed well-known target
// (e.g. 768 → 1536, overridable via VECNA_TEST_TARGET_DIM).
// Skips if the native dimension is already larger than or equal to the target.
func TestUpscaleToFixed(t *testing.T) {
	conf := loadCfg(t)
	cl := newClient(conf)
	target := intEnv("VECNA_TEST_TARGET_DIM", 1536)
	v := embed(t, cl, conf.model)
	src := len(v)
	if src >= target {
		t.Skipf("native dim %d is not smaller than target dim %d", src, target)
	}
	tr, err := adapter.NewTruncate(src, target, adapter.TruncateFromEnd, adapter.PadAtEnd)
	require.NoError(t, err)
	got, err := tr.Adapt(v)
	require.NoError(t, err)
	assert.Len(t, got, target)
	assertUnitNorm(t, got)
	t.Logf("upscale to fixed: %d → %d norm=%.6f", src, target, l2Norm(got))
}
// TestRoundtripConsistency embeds the same text twice and checks the
// vectors are (near-)identical, i.e. the backend is deterministic.
func TestRoundtripConsistency(t *testing.T) {
	conf := loadCfg(t)
	cl := newClient(conf)
	first := embed(t, cl, conf.model)
	second := embed(t, cl, conf.model)
	require.Equal(t, len(first), len(second), "dimension mismatch between two identical requests")
	// Track the largest absolute element-wise difference.
	var worst float32
	for i, a := range first {
		d := a - second[i]
		if d < 0 {
			d = -d
		}
		if d > worst {
			worst = d
		}
	}
	t.Logf("max element-wise diff between two identical embeds: %e", worst)
	assert.Less(t, worst, float32(1e-5), "embeddings for identical input should be deterministic")
}
// intEnv reads an integer from an env var, returning defaultVal when the
// variable is unset, empty, or not a valid integer.
func intEnv(key string, defaultVal int) int {
	raw, ok := os.LookupEnv(key)
	if !ok || raw == "" {
		return defaultVal
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return defaultVal
	}
	return n
}

BIN
vecna Executable file

Binary file not shown.