feat: 🎉 Vectors na Vectors, the beginning

Translate 1536 <-> 768 , 3072 <-> 2048
This commit is contained in:
2026-04-11 18:05:05 +02:00
parent d98ea7c222
commit 4009a54e39
58 changed files with 5324 additions and 2 deletions

77
pkg/adapter/adapter.go Normal file
View File

@@ -0,0 +1,77 @@
package adapter
import (
"errors"
"fmt"
)
// ErrDimMismatch is returned when the input vector length does not match the adapter's source dimension.
var ErrDimMismatch = errors.New("vector dimension mismatch")

// ErrInvalidDim is returned when source or target dimensions are invalid.
var ErrInvalidDim = errors.New("invalid dimension")

// ErrInvalidMatrix is returned when a projection matrix has the wrong shape.
var ErrInvalidMatrix = errors.New("invalid projection matrix shape")

// TruncateMode controls which end of the vector is dropped when downscaling.
type TruncateMode int

const (
	// TruncateFromEnd keeps the first targetDim elements (default; correct for Matryoshka models).
	TruncateFromEnd TruncateMode = iota
	// TruncateFromStart keeps the last targetDim elements.
	TruncateFromStart
)

// PadMode controls which end of the vector receives zero-padding when upscaling.
type PadMode int

const (
	// PadAtEnd appends zeros to the end (default).
	PadAtEnd PadMode = iota
	// PadAtStart prepends zeros to the start.
	PadAtStart
)

// Adapter translates vectors between two fixed dimensions.
type Adapter interface {
	// Adapt converts vec (length SourceDim) to a vector of length TargetDim.
	Adapt(vec []float32) ([]float32, error)
	// SourceDim is the expected input vector length.
	SourceDim() int
	// TargetDim is the produced output vector length.
	TargetDim() int
}
// NewTruncate returns a TruncateAdapter for Matryoshka-style or simple truncation/padding.
// t selects which end is dropped when downscaling; p selects which end is
// zero-padded when upscaling. Both dimensions must be positive.
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
	if sourceDim < 1 || targetDim < 1 {
		return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}
	a := &truncateAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		truncMode: t,
		padMode:   p,
	}
	return a, nil
}
// NewRandom returns a RandomAdapter backed by a seeded Gaussian projection matrix.
// seed=0 uses a time-based seed. Both dimensions must be positive.
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
	if sourceDim > 0 && targetDim > 0 {
		return newRandomAdapter(sourceDim, targetDim, seed), nil
	}
	return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
}
// NewProjection returns a ProjectionAdapter using a caller-supplied matrix.
// matrix must have shape [targetDim][sourceDim]; every row is validated.
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
	if sourceDim <= 0 || targetDim <= 0 {
		return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}
	if rows := len(matrix); rows != targetDim {
		return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, rows, targetDim)
	}
	for i := range matrix {
		if cols := len(matrix[i]); cols != sourceDim {
			return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, cols, sourceDim)
		}
	}
	return &projectionAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		matrix:    matrix,
	}, nil
}

236
pkg/adapter/adapter_test.go Normal file
View File

@@ -0,0 +1,236 @@
package adapter
import (
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// unitVec returns a unit vector of length n with value 1/sqrt(n) in each position.
func unitVec(n int) []float32 {
v := make([]float32, n)
val := float32(1.0 / math.Sqrt(float64(n)))
for i := range v {
v[i] = val
}
return v
}
// assertNormOne fails the test unless v has (approximately) unit L2 norm.
func assertNormOne(t *testing.T, v []float32) {
	t.Helper()
	// Accumulate the sum of squares in float64 to limit rounding error.
	var sum float64
	for _, x := range v {
		sum += float64(x) * float64(x)
	}
	assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
}
// --- L2Norm ---
// TestL2Norm checks table-driven normalization cases plus the zero-vector
// special case, which must return an unchanged copy rather than NaNs.
func TestL2Norm(t *testing.T) {
	tests := []struct {
		name    string
		input   []float32
		wantLen int
		// wantNorm is not read by the loop body — NOTE(review): unused field.
		wantNorm float64
	}{
		{"unit vector unchanged", []float32{1, 0, 0}, 3, 1.0},
		{"scale down to unit", []float32{2, 0, 0}, 3, 1.0},
		{"multi-dim", unitVec(4), 4, 1.0},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := L2Norm(tc.input)
			assert.Len(t, got, tc.wantLen)
			assertNormOne(t, got)
		})
	}
	t.Run("zero vector returns copy unchanged", func(t *testing.T) {
		in := []float32{0, 0, 0}
		got := L2Norm(in)
		assert.Equal(t, in, got)
	})
}
// --- TruncateAdapter ---
// TestNewTruncate_InvalidDims verifies that zero source or target dimensions
// are rejected with ErrInvalidDim.
func TestNewTruncate_InvalidDims(t *testing.T) {
	_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
	require.ErrorIs(t, err, ErrInvalidDim)
	_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
	require.ErrorIs(t, err, ErrInvalidDim)
}
// TestTruncateAdapter covers the four resize paths (truncate from either end,
// pad at either end) plus same-dim passthrough, then spot-checks element
// placement and the dimension-mismatch error.
func TestTruncateAdapter(t *testing.T) {
	tests := []struct {
		name      string
		src, tgt  int
		truncMode TruncateMode
		padMode   PadMode
		// builder for input vec; checks are len + norm only unless wantFirst/wantLast set
		wantLen int
		// wantErr is not read by the loop body — NOTE(review): unused field.
		wantErr bool
	}{
		{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768, false},
		{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768, false},
		{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536, false},
		{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536, false},
		{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768, false},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
			require.NoError(t, err)
			assert.Equal(t, tc.src, a.SourceDim())
			assert.Equal(t, tc.tgt, a.TargetDim())
			got, err := a.Adapt(unitVec(tc.src))
			require.NoError(t, err)
			assert.Len(t, got, tc.wantLen)
			assertNormOne(t, got)
		})
	}
	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 100))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
	t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		// input: [1,2,3,4] — TruncateFromEnd keeps [1,2]
		in := []float32{1, 2, 3, 4}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		// after L2Norm the signs are preserved; ratio should match
		assert.Greater(t, got[0], float32(0))
		assert.Greater(t, got[1], float32(0))
	})
	t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
		require.NoError(t, err)
		// input: [0,0,3,4] — TruncateFromStart keeps [3,4]
		in := []float32{0, 0, 3, 4}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		assert.Greater(t, got[0], float32(0))
		assert.Greater(t, got[1], float32(0))
	})
	t.Run("PadAtStart zero-pads front", func(t *testing.T) {
		a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
		require.NoError(t, err)
		in := []float32{1, 0}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		// first two positions should be zero (padded), last two carry the signal
		assert.Equal(t, float32(0), got[0])
		assert.Equal(t, float32(0), got[1])
	})
}
// --- RandomAdapter ---
// TestNewRandom_InvalidDims verifies a zero source dimension is rejected.
func TestNewRandom_InvalidDims(t *testing.T) {
	_, err := NewRandom(0, 768, 42)
	require.ErrorIs(t, err, ErrInvalidDim)
}
// TestRandomAdapter checks output shape/norm, seed determinism, seed
// sensitivity, and the dimension-mismatch error of the random projection.
func TestRandomAdapter(t *testing.T) {
	t.Run("output length and unit norm", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 42)
		require.NoError(t, err)
		got, err := a.Adapt(unitVec(1536))
		require.NoError(t, err)
		assert.Len(t, got, 768)
		assertNormOne(t, got)
	})
	t.Run("deterministic with same seed", func(t *testing.T) {
		a1, _ := NewRandom(64, 32, 99)
		a2, _ := NewRandom(64, 32, 99)
		in := unitVec(64)
		out1, _ := a1.Adapt(in)
		out2, _ := a2.Adapt(in)
		assert.Equal(t, out1, out2)
	})
	t.Run("different seeds produce different output", func(t *testing.T) {
		a1, _ := NewRandom(64, 32, 1)
		a2, _ := NewRandom(64, 32, 2)
		in := unitVec(64)
		out1, _ := a1.Adapt(in)
		out2, _ := a2.Adapt(in)
		assert.NotEqual(t, out1, out2)
	})
	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 1)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 10))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}
// --- ProjectionAdapter ---
// TestNewProjection_InvalidMatrix verifies that shape validation runs after
// dimension validation and reports the matching sentinel error.
func TestNewProjection_InvalidMatrix(t *testing.T) {
	tests := []struct {
		name   string
		src    int
		tgt    int
		matrix [][]float32
		errIs  error
	}{
		{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
		{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
		{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
			require.ErrorIs(t, err, tc.errIs)
		})
	}
}
// TestProjectionAdapter checks identity and downscale projections (length plus
// unit norm of the result) and the dimension-mismatch error.
func TestProjectionAdapter(t *testing.T) {
	t.Run("identity matrix same dim", func(t *testing.T) {
		// 3×3 identity
		id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
		a, err := NewProjection(3, 3, id)
		require.NoError(t, err)
		in := []float32{1, 2, 3}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		assert.Len(t, got, 3)
		assertNormOne(t, got)
	})
	t.Run("downscale projection", func(t *testing.T) {
		// 2×4 matrix
		m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
		a, err := NewProjection(4, 2, m)
		require.NoError(t, err)
		got, err := a.Adapt([]float32{3, 4, 0, 0})
		require.NoError(t, err)
		assert.Len(t, got, 2)
		assertNormOne(t, got)
	})
	t.Run("dim mismatch returns error", func(t *testing.T) {
		m := [][]float32{{1, 0}, {0, 1}}
		a, err := NewProjection(2, 2, m)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 5))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}

23
pkg/adapter/normalize.go Normal file
View File

@@ -0,0 +1,23 @@
package adapter
import "math"
// L2Norm returns a new slice with the vector normalized to unit length.
// If the vector has zero magnitude it is returned unchanged.
func L2Norm(v []float32) []float32 {
var sum float64
for _, x := range v {
sum += float64(x) * float64(x)
}
if sum == 0 {
out := make([]float32, len(v))
copy(out, v)
return out
}
norm := float32(math.Sqrt(sum))
out := make([]float32, len(v))
for i, x := range v {
out[i] = x / norm
}
return out
}

32
pkg/adapter/projection.go Normal file
View File

@@ -0,0 +1,32 @@
package adapter
import "fmt"
// projectionAdapter maps vectors through a caller-supplied matrix of shape
// [targetDim][sourceDim], validated by NewProjection.
type projectionAdapter struct {
	sourceDim int
	targetDim int
	matrix    [][]float32
}

// SourceDim returns the expected input vector length.
func (a *projectionAdapter) SourceDim() int { return a.sourceDim }

// TargetDim returns the produced output vector length.
func (a *projectionAdapter) TargetDim() int { return a.targetDim }
// Adapt multiplies vec by the projection matrix and L2-normalizes the result.
// It returns ErrDimMismatch (wrapped) when len(vec) != SourceDim().
func (a *projectionAdapter) Adapt(vec []float32) ([]float32, error) {
	if got := len(vec); got != a.sourceDim {
		return nil, fmt.Errorf("projection adapt: %w: got %d, want %d", ErrDimMismatch, got, a.sourceDim)
	}
	projected := matVecMul(a.matrix, vec)
	return L2Norm(projected), nil
}
// matVecMul returns the matrix-vector product m·v, where m has shape
// [rows][cols] and v must have at least cols elements.
func matVecMul(m [][]float32, v []float32) []float32 {
	out := make([]float32, len(m))
	for i := range m {
		var acc float32
		for j := range m[i] {
			acc += m[i][j] * v[j]
		}
		out[i] = acc
	}
	return out
}

45
pkg/adapter/random.go Normal file
View File

@@ -0,0 +1,45 @@
package adapter
import (
"fmt"
"math"
"math/rand"
"time"
)
// randomAdapter projects vectors through a seeded Gaussian random matrix.
type randomAdapter struct {
	sourceDim int
	targetDim int
	matrix    [][]float32
}

// newRandomAdapter builds the [targetDim][sourceDim] Gaussian projection.
// A zero seed falls back to the current time so each run differs.
func newRandomAdapter(sourceDim, targetDim int, seed int64) *randomAdapter {
	if seed == 0 {
		seed = time.Now().UnixNano()
	}
	//nolint:gosec // deterministic seeded RNG for projection matrix generation, not security use
	rng := rand.New(rand.NewSource(seed))
	// Entries drawn from N(0, 1/targetDim) roughly preserve expected squared
	// norms (Johnson-Lindenstrauss).
	sigma := 1.0 / math.Sqrt(float64(targetDim))
	rows := make([][]float32, targetDim)
	for r := range rows {
		cols := make([]float32, sourceDim)
		for c := range cols {
			cols[c] = float32(rng.NormFloat64() * sigma)
		}
		rows[r] = cols
	}
	return &randomAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: rows}
}
// SourceDim returns the expected input vector length.
func (a *randomAdapter) SourceDim() int { return a.sourceDim }

// TargetDim returns the produced output vector length.
func (a *randomAdapter) TargetDim() int { return a.targetDim }

// Adapt projects vec through the random matrix and L2-normalizes the result.
// It returns ErrDimMismatch (wrapped) when len(vec) != SourceDim().
func (a *randomAdapter) Adapt(vec []float32) ([]float32, error) {
	if len(vec) != a.sourceDim {
		return nil, fmt.Errorf("random adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
	}
	return L2Norm(matVecMul(a.matrix, vec)), nil
}

41
pkg/adapter/truncate.go Normal file
View File

@@ -0,0 +1,41 @@
package adapter
import "fmt"
// truncateAdapter resizes vectors by truncating (downscale) or zero-padding
// (upscale), then re-normalizing. Which end is cut/padded is configurable.
type truncateAdapter struct {
	sourceDim int
	targetDim int
	truncMode TruncateMode
	padMode   PadMode
}

// SourceDim returns the expected input vector length.
func (a *truncateAdapter) SourceDim() int { return a.sourceDim }

// TargetDim returns the produced output vector length.
func (a *truncateAdapter) TargetDim() int { return a.targetDim }
// Adapt resizes vec to targetDim (truncating or zero-padding per the
// configured modes) and L2-normalizes the result.
// It returns ErrDimMismatch (wrapped) when len(vec) != SourceDim().
func (a *truncateAdapter) Adapt(vec []float32) ([]float32, error) {
	if len(vec) != a.sourceDim {
		return nil, fmt.Errorf("truncate adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
	}
	out := make([]float32, a.targetDim)
	if downscale := a.targetDim <= a.sourceDim; downscale {
		switch a.truncMode {
		case TruncateFromStart:
			copy(out, vec[a.sourceDim-a.targetDim:])
		case TruncateFromEnd:
			copy(out, vec[:a.targetDim])
		}
	} else {
		switch a.padMode {
		case PadAtStart:
			copy(out[a.targetDim-a.sourceDim:], vec)
		case PadAtEnd:
			copy(out, vec)
		}
	}
	return L2Norm(out), nil
}

168
pkg/config/config.go Normal file
View File

@@ -0,0 +1,168 @@
package config
import (
"errors"
"fmt"
"os"
"strings"
"github.com/spf13/viper"
)
// Config is the root configuration for vecna.
type Config struct {
	Server  ServerConfig  `mapstructure:"server"`
	Metrics MetricsConfig `mapstructure:"metrics"`
	Forward ForwardConfig `mapstructure:"forward"`
	Adapter AdapterConfig `mapstructure:"adapter"`
}

// ServerConfig controls the HTTP listener and inbound auth.
type ServerConfig struct {
	Port    int      `mapstructure:"port"`     // listen port (default 8080)
	Host    string   `mapstructure:"host"`     // bind address (default 0.0.0.0)
	APIKeys []string `mapstructure:"api_keys"` // accepted inbound API keys; empty disables auth — TODO confirm with handler code
}

// MetricsConfig controls Prometheus metrics exposure.
type MetricsConfig struct {
	Enabled bool   `mapstructure:"enabled"`
	Path    string `mapstructure:"path"`    // scrape path (default /metrics)
	APIKey  string `mapstructure:"api_key"` // optional key protecting the metrics endpoint
}

// ForwardConfig holds all named forwarding targets.
type ForwardConfig struct {
	Default string                   `mapstructure:"default"` // target used when a request names none
	Targets map[string]ForwardTarget `mapstructure:"targets"`
}

// ForwardTarget is a named backing embedding model with one or more endpoints.
// Zero values for the numeric tuning fields are replaced by
// applyForwardDefaults after loading.
type ForwardTarget struct {
	Endpoints        []EndpointConfig `mapstructure:"endpoints"`
	Model            string           `mapstructure:"model"`
	APIKey           string           `mapstructure:"api_key"` // inherited by endpoints that set none
	APIType          string           `mapstructure:"api_type"`
	TimeoutSecs      int              `mapstructure:"timeout_secs"`      // default 30
	CooldownSecs     int              `mapstructure:"cooldown_secs"`     // default 60
	PriorityDecay    int              `mapstructure:"priority_decay"`    // default 2
	PriorityRecovery int              `mapstructure:"priority_recovery"` // default 5
}

// EndpointConfig is a single URL within a ForwardTarget.
type EndpointConfig struct {
	URL      string `mapstructure:"url"`
	Priority int    `mapstructure:"priority"` // default 10 when left at zero
	APIKey   string `mapstructure:"api_key"`
}

// AdapterConfig selects and tunes the dimension adapter.
type AdapterConfig struct {
	Type        string `mapstructure:"type"`          // "truncate" (default), or others handled elsewhere
	SourceDim   int    `mapstructure:"source_dim"`
	TargetDim   int    `mapstructure:"target_dim"`
	TruncateMode string `mapstructure:"truncate_mode"` // default "from_end"
	PadMode     string `mapstructure:"pad_mode"`      // default "at_end"
	Seed        int64  `mapstructure:"seed"`
	MatrixFile  string `mapstructure:"matrix_file"`
}

// extensions viper will detect automatically; also used by ResolveFile when
// probing the default search paths.
var extensions = []string{"json", "yaml", "toml"}
// ResolveFile returns the config file path that would be used by Load.
// If cfgFile is non-empty it is returned as-is. Otherwise the default search
// paths are checked; if no existing file is found, the preferred default
// (~/vecna.json) is returned so callers can create it.
func ResolveFile(cfgFile string) string {
	if cfgFile != "" {
		return cfgFile
	}
	home, _ := os.UserHomeDir()
	for _, dir := range []string{".", home, home + "/.config/vecna"} {
		for _, ext := range extensions {
			candidate := dir + "/vecna." + ext
			if _, err := os.Stat(candidate); err == nil {
				return candidate
			}
		}
	}
	return home + "/vecna.json"
}
// Load reads configuration from the given file path (empty = search defaults),
// environment variables (prefix VECNA_), and applies built-in defaults.
// Per viper's documented precedence, env vars override config-file values,
// which override the defaults set here.
func Load(cfgFile string) (*Config, error) {
	v := viper.New()
	// Defaults (lowest precedence).
	v.SetDefault("server.port", 8080)
	v.SetDefault("server.host", "0.0.0.0")
	v.SetDefault("metrics.enabled", false)
	v.SetDefault("metrics.path", "/metrics")
	v.SetDefault("adapter.type", "truncate")
	v.SetDefault("adapter.truncate_mode", "from_end")
	v.SetDefault("adapter.pad_mode", "at_end")
	// Environment: VECNA_SERVER_PORT maps to server.port, etc.
	v.SetEnvPrefix("VECNA")
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	v.AutomaticEnv()
	// Config file: explicit path wins; otherwise search ".", $HOME, and
	// $HOME/.config/vecna for a file named "vecna.<ext>".
	if cfgFile != "" {
		v.SetConfigFile(cfgFile)
	} else {
		home, _ := os.UserHomeDir()
		v.SetConfigName("vecna")
		// No SetConfigType — viper detects format from file extension (json, yaml, toml, etc.)
		v.AddConfigPath(".")
		v.AddConfigPath(home)
		v.AddConfigPath(home + "/.config/vecna")
	}
	if err := v.ReadInConfig(); err != nil {
		// Missing config file is acceptable when all required values come from flags/env
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return nil, fmt.Errorf("load config: %w", err)
		}
	}
	var cfg Config
	if err := v.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("unmarshal config: %w", err)
	}
	applyForwardDefaults(&cfg)
	return &cfg, nil
}
// applyForwardDefaults fills in zero-value fields on ForwardTarget entries:
// timeout 30s, cooldown 60s, priority decay 2, priority recovery 5, endpoint
// priority 10. Endpoints without their own API key inherit the target's.
func applyForwardDefaults(cfg *Config) {
	for name, target := range cfg.Forward.Targets {
		if target.TimeoutSecs == 0 {
			target.TimeoutSecs = 30
		}
		if target.CooldownSecs == 0 {
			target.CooldownSecs = 60
		}
		if target.PriorityDecay == 0 {
			target.PriorityDecay = 2
		}
		if target.PriorityRecovery == 0 {
			target.PriorityRecovery = 5
		}
		for i := range target.Endpoints {
			if target.Endpoints[i].Priority == 0 {
				target.Endpoints[i].Priority = 10
			}
			if target.Endpoints[i].APIKey == "" && target.APIKey != "" {
				target.Endpoints[i].APIKey = target.APIKey
			}
		}
		// target is a copy; write the updated struct back into the map.
		cfg.Forward.Targets[name] = target
	}
}

137
pkg/config/update.go Normal file
View File

@@ -0,0 +1,137 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
)
// SaveTarget adds or replaces a named ForwardTarget in the config file at path.
// Only JSON config files are written in-place. For other formats an error is
// returned containing a JSON snippet the user can merge in manually.
func SaveTarget(path, name string, target ForwardTarget) error {
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("read config %s: %w", path, err)
	}
	ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
	if ext == "json" {
		return saveTargetJSON(path, data, name, target)
	}
	// Best-effort snippet for the error message; marshal failure just yields
	// an empty snippet, so the error is ignored deliberately.
	snippet, _ := json.MarshalIndent(map[string]ForwardTarget{name: target}, "", " ")
	return fmt.Errorf(
		"auto-update not supported for .%s files\n"+
			"Add the following to the 'forward.targets' section of %s:\n\n%s",
		ext, path, snippet,
	)
}
// RemoveBrokenEndpoints removes failing endpoints from the config file.
// broken maps target name → set of failing endpoint URLs.
// If all endpoints of a target are removed, the target itself is deleted
// (and unset as the default if it was). Only JSON files are supported.
// Returns the list of removed items as human-readable strings; (nil, nil)
// means nothing matched and the file was left untouched.
func RemoveBrokenEndpoints(path string, broken map[string][]string) ([]string, error) {
	ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
	if ext != "json" {
		return nil, fmt.Errorf("auto-update not supported for .%s files; edit %s manually", ext, path)
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config %s: %w", path, err)
	}
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse %s: %w", path, err)
	}
	var removed []string
	for targetName, failedURLs := range broken {
		target, ok := cfg.Forward.Targets[targetName]
		if !ok {
			continue
		}
		// Turn the URL list into a set for O(1) membership checks.
		failSet := make(map[string]bool, len(failedURLs))
		for _, u := range failedURLs {
			failSet[u] = true
		}
		// In-place filter reusing the existing backing array.
		kept := target.Endpoints[:0]
		for _, ep := range target.Endpoints {
			if failSet[ep.URL] {
				removed = append(removed, fmt.Sprintf("endpoint %s (target %q)", ep.URL, targetName))
			} else {
				kept = append(kept, ep)
			}
		}
		if len(kept) == 0 {
			delete(cfg.Forward.Targets, targetName)
			removed = append(removed, fmt.Sprintf("target %q (all endpoints failed)", targetName))
			if cfg.Forward.Default == targetName {
				cfg.Forward.Default = ""
			}
		} else {
			target.Endpoints = kept
			cfg.Forward.Targets[targetName] = target
		}
	}
	if len(removed) == 0 {
		return nil, nil
	}
	out, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return nil, fmt.Errorf("marshal config: %w", err)
	}
	if err := os.WriteFile(path, append(out, '\n'), 0o600); err != nil {
		return nil, fmt.Errorf("write %s: %w", path, err)
	}
	return removed, nil
}
// WriteConfig serialises cfg as indented JSON and overwrites path with 0600
// permissions, appending a trailing newline. Note: os.WriteFile truncates in
// place, so the write is not atomic.
func WriteConfig(path string, cfg Config) error {
	out, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return fmt.Errorf("marshal config: %w", err)
	}
	out = append(out, '\n')
	if err := os.WriteFile(path, out, 0o600); err != nil {
		return fmt.Errorf("write %s: %w", path, err)
	}
	return nil
}
// saveTargetJSON parses the JSON config in data, inserts or replaces the named
// target, promotes it to the default when none is set, and writes the result
// back to path.
func saveTargetJSON(path string, data []byte, name string, target ForwardTarget) error {
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("parse %s: %w", path, err)
	}
	if cfg.Forward.Targets == nil {
		cfg.Forward.Targets = make(map[string]ForwardTarget)
	}
	cfg.Forward.Targets[name] = target
	if cfg.Forward.Default == "" {
		cfg.Forward.Default = name
	}
	// Delegate serialisation and the 0600 write to WriteConfig so the two
	// JSON write paths in this file stay consistent.
	return WriteConfig(path, cfg)
}

206
pkg/discovery/discovery.go Normal file
View File

@@ -0,0 +1,206 @@
package discovery
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
)
// Kind describes a known server type.
type Kind struct {
	Name     string
	APIType  string // API dialect label stored in config; currently always "openai"
	Port     int    // default port probed on every candidate host
	NeedsKey bool   // server typically requires an API key
}

// Found is a server discovered on the network.
type Found struct {
	Kind    Kind
	BaseURL string
	Models  []string
}

// knownServers lists server types by their default port and display name.
// Ollama gets special handling in probe(): its native /api/tags endpoint is
// tried first, with a fallback to the OpenAI-compatible /v1/models.
var knownServers = []Kind{
	{Name: "Ollama", APIType: "openai", Port: 11434},
	{Name: "LM Studio", APIType: "openai", Port: 1234},
	{Name: "vLLM", APIType: "openai", Port: 8000, NeedsKey: true},
	{Name: "LocalAI", APIType: "openai", Port: 8080},
	{Name: "Jan", APIType: "openai", Port: 1337},
	{Name: "Kobold", APIType: "openai", Port: 5001},
	{Name: "Tabby", APIType: "openai", Port: 9090},
}
// Scan concurrently probes localhost and LAN gateway addresses for known LLM servers.
// Results are returned in the order they are found (non-deterministic).
// The overall scan is capped at 4s; each individual probe at 600ms.
func Scan(ctx context.Context) []Found {
	probeCtx, cancel := context.WithTimeout(ctx, 4*time.Second)
	defer cancel()
	client := &http.Client{Timeout: 600 * time.Millisecond}
	var (
		mu    sync.Mutex
		found []Found
		wg    sync.WaitGroup
	)
	for _, h := range localHosts() {
		for _, k := range knownServers {
			wg.Add(1)
			go func(host string, kind Kind) {
				defer wg.Done()
				base := fmt.Sprintf("http://%s:%d", host, kind.Port)
				models, err := probe(probeCtx, client, base, kind)
				if err != nil {
					return
				}
				mu.Lock()
				found = append(found, Found{Kind: kind, BaseURL: base, Models: models})
				mu.Unlock()
			}(h, k)
		}
	}
	wg.Wait()
	return found
}
// Models fetches the model list from a single base URL and kind (for the models command).
// Unlike Scan, it uses a generous 5-second timeout for a single probe.
func Models(ctx context.Context, baseURL string, kind Kind) ([]string, error) {
	client := &http.Client{Timeout: 5 * time.Second}
	return probe(ctx, client, baseURL, kind)
}
// localHosts returns localhost plus the .1 gateway of every local IPv4 subnet,
// deduplicated, with "127.0.0.1" always first.
func localHosts() []string {
	hosts := []string{"127.0.0.1"}
	seen := map[string]bool{"127.0.0.1": true}
	ifaces, err := net.Interfaces()
	if err != nil {
		return hosts
	}
	for _, iface := range ifaces {
		if iface.Flags&net.FlagUp == 0 {
			continue
		}
		addrs, addrErr := iface.Addrs()
		if addrErr != nil {
			continue
		}
		for _, addr := range addrs {
			ipnet, ok := addr.(*net.IPNet)
			if !ok {
				continue
			}
			ip4 := ipnet.IP.To4()
			if ip4 == nil || ip4.IsLoopback() {
				continue
			}
			// Guess the gateway as the subnet's network address + 1.
			// net.IP.Mask allocates a fresh slice, so mutating it is safe.
			gateway := ip4.Mask(ipnet.Mask)
			gateway[3] = 1
			if g := gateway.String(); !seen[g] {
				seen[g] = true
				hosts = append(hosts, g)
			}
		}
	}
	return hosts
}
// probe attempts to identify the server at baseURL and returns its model list.
func probe(ctx context.Context, client *http.Client, baseURL string, kind Kind) ([]string, error) {
	if kind.Name != "Ollama" {
		return probeOpenAI(ctx, client, baseURL)
	}
	// Prefer Ollama's native /api/tags; fall back to the OpenAI-compatible
	// /v1/models endpoint, which Ollama exposes since v0.1.27.
	if models, err := probeOllama(ctx, client, baseURL); err == nil {
		return models, nil
	}
	return probeOpenAI(ctx, client, baseURL)
}
// --- Ollama ---
// ollamaTagsResponse mirrors the subset of Ollama's /api/tags payload we read.
type ollamaTagsResponse struct {
	Models []struct {
		Name string `json:"name"`
	} `json:"models"`
}

// probeOllama queries Ollama's native /api/tags endpoint and returns the
// installed model names.
func probeOllama(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/api/tags", nil)
	if err != nil {
		return nil, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status %d", resp.StatusCode)
	}
	var tags ollamaTagsResponse
	if err := json.NewDecoder(resp.Body).Decode(&tags); err != nil {
		return nil, fmt.Errorf("decode ollama response: %w", err)
	}
	names := make([]string, 0, len(tags.Models))
	for _, m := range tags.Models {
		names = append(names, m.Name)
	}
	return names, nil
}
// --- OpenAI-compatible ---
// openAIModelsResponse mirrors the subset of the /v1/models payload we read.
type openAIModelsResponse struct {
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`
}

// probeOpenAI queries the OpenAI-compatible /v1/models endpoint and returns
// the advertised model IDs.
func probeOpenAI(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/v1/models", nil)
	if err != nil {
		return nil, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status %d", resp.StatusCode)
	}
	var listing openAIModelsResponse
	if err := json.NewDecoder(resp.Body).Decode(&listing); err != nil {
		return nil, fmt.Errorf("decode openai response: %w", err)
	}
	ids := make([]string, 0, len(listing.Data))
	for _, m := range listing.Data {
		ids = append(ids, m.ID)
	}
	return ids, nil
}

27
pkg/embedclient/client.go Normal file
View File

@@ -0,0 +1,27 @@
package embedclient
import "context"
// Request is a batch of texts to embed.
type Request struct {
	Texts []string // inputs; one embedding is returned per entry
	Model string   // backing model name; the Google client ignores this in favour of its configured model
}

// Usage reports token consumption from the backing model.
type Usage struct {
	PromptTokens int
	TotalTokens  int
}

// Response holds the raw embeddings returned by the backing model.
type Response struct {
	Embeddings [][]float32 // one vector per input text, in request order
	Model      string
	Usage      Usage // zero value when the backend reports no usage (e.g. Google)
}

// Client sends text to a backing embedding model and returns raw vectors.
type Client interface {
	Embed(ctx context.Context, req Request) (Response, error)
}

98
pkg/embedclient/google.go Normal file
View File

@@ -0,0 +1,98 @@
package embedclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
// googleClient implements Client against the Gemini batchEmbedContents API.
type googleClient struct {
	baseURL    string
	apiKey     string // sent as the "key" query parameter when non-empty
	model      string // model used for every request, regardless of Request.Model
	httpClient *http.Client
}

// NewGoogle returns a Client that speaks the Google Gemini batchEmbedContents API.
// A nil httpClient falls back to http.DefaultClient.
func NewGoogle(baseURL, apiKey, model string, httpClient *http.Client) Client {
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	return &googleClient{baseURL: baseURL, apiKey: apiKey, model: model, httpClient: httpClient}
}
// googleBatchRequest is the batchEmbedContents request envelope.
type googleBatchRequest struct {
	Requests []googleEmbedRequest `json:"requests"`
}

// googleEmbedRequest embeds a single piece of content.
type googleEmbedRequest struct {
	Model   string        `json:"model"` // fully qualified "models/<name>"
	Content googleContent `json:"content"`
}

type googleContent struct {
	Parts []googlePart `json:"parts"`
}

type googlePart struct {
	Text string `json:"text"`
}

// googleBatchResponse mirrors the subset of the response we read.
type googleBatchResponse struct {
	Embeddings []struct {
		Values []float32 `json:"values"`
	} `json:"embeddings"`
}
// Embed sends all req.Texts in a single batchEmbedContents call and returns
// the resulting vectors. req.Model is ignored in favour of the client's
// configured model, and Usage is left at its zero value (this endpoint
// response carries no token counts here).
func (c *googleClient) Embed(ctx context.Context, req Request) (Response, error) {
	requests := make([]googleEmbedRequest, len(req.Texts))
	for i, text := range req.Texts {
		requests[i] = googleEmbedRequest{
			// Gemini expects the fully-qualified "models/<name>" form per item.
			Model:   "models/" + c.model,
			Content: googleContent{Parts: []googlePart{{Text: text}}},
		}
	}
	body, err := json.Marshal(googleBatchRequest{Requests: requests})
	if err != nil {
		return Response{}, fmt.Errorf("google embed marshal: %w", err)
	}
	url := fmt.Sprintf("%s/v1/models/%s:batchEmbedContents", c.baseURL, c.model)
	if c.apiKey != "" {
		// Gemini authenticates via the "key" query parameter.
		url += "?key=" + c.apiKey
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return Response{}, fmt.Errorf("google embed request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return Response{}, fmt.Errorf("google embed do: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return Response{}, fmt.Errorf("google embed: unexpected status %d", resp.StatusCode)
	}
	var gResp googleBatchResponse
	if err := json.NewDecoder(resp.Body).Decode(&gResp); err != nil {
		return Response{}, fmt.Errorf("google embed decode: %w", err)
	}
	// Assumes the API returns exactly one embedding per request, in order —
	// NOTE(review): not validated here; consider checking the count.
	embeddings := make([][]float32, len(gResp.Embeddings))
	for i, e := range gResp.Embeddings {
		embeddings[i] = e.Values
	}
	return Response{
		Embeddings: embeddings,
		Model:      c.model,
	}, nil
}

87
pkg/embedclient/openai.go Normal file
View File

@@ -0,0 +1,87 @@
package embedclient
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
// openAIClient implements Client against the OpenAI /v1/embeddings API.
type openAIClient struct {
	baseURL    string
	apiKey     string // sent as a Bearer token when non-empty
	httpClient *http.Client
}

// NewOpenAI returns a Client that speaks the OpenAI embeddings API.
// A nil httpClient falls back to http.DefaultClient.
func NewOpenAI(baseURL, apiKey string, httpClient *http.Client) Client {
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	return &openAIClient{baseURL: baseURL, apiKey: apiKey, httpClient: httpClient}
}
// openAIEmbedRequest is the /v1/embeddings request body.
type openAIEmbedRequest struct {
	Input []string `json:"input"`
	Model string   `json:"model"`
}

// openAIEmbedResponse mirrors the subset of the response we read; Index is the
// position of each embedding within the original input batch.
type openAIEmbedResponse struct {
	Object string `json:"object"`
	Data   []struct {
		Object    string    `json:"object"`
		Embedding []float32 `json:"embedding"`
		Index     int       `json:"index"`
	} `json:"data"`
	Model string `json:"model"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`
}
// Embed posts req.Texts to the OpenAI embeddings endpoint and returns the
// vectors ordered by the server-reported index, plus token usage.
// It returns an error on any non-200 status or a malformed response.
func (c *openAIClient) Embed(ctx context.Context, req Request) (Response, error) {
	body, err := json.Marshal(openAIEmbedRequest{Input: req.Texts, Model: req.Model})
	if err != nil {
		return Response{}, fmt.Errorf("openai embed marshal: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		return Response{}, fmt.Errorf("openai embed request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return Response{}, fmt.Errorf("openai embed do: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return Response{}, fmt.Errorf("openai embed: unexpected status %d", resp.StatusCode)
	}
	var oaiResp openAIEmbedResponse
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
		return Response{}, fmt.Errorf("openai embed decode: %w", err)
	}
	// Place each embedding at its server-reported index so output order
	// matches input order. Bounds-check the index: a buggy or hostile server
	// could otherwise cause an out-of-range panic here.
	embeddings := make([][]float32, len(oaiResp.Data))
	for _, d := range oaiResp.Data {
		if d.Index < 0 || d.Index >= len(embeddings) {
			return Response{}, fmt.Errorf("openai embed: response index %d out of range (%d items)", d.Index, len(embeddings))
		}
		embeddings[d.Index] = d.Embedding
	}
	return Response{
		Embeddings: embeddings,
		Model:      oaiResp.Model,
		Usage: Usage{
			PromptTokens: oaiResp.Usage.PromptTokens,
			TotalTokens:  oaiResp.Usage.TotalTokens,
		},
	}, nil
}

180
pkg/embedclient/router.go Normal file
View File

@@ -0,0 +1,180 @@
package embedclient
import (
"context"
"fmt"
"sync"
"time"
"github.com/Warky-Devs/vecna.git/pkg/metrics"
)
// RouterConfig holds tuning parameters for a TargetRouter.
type RouterConfig struct {
	TargetName       string // label used for metrics
	TimeoutSecs      int    // per-request timeout applied in Embed
	CooldownSecs     int    // how long a failed endpoint is skipped in pick
	PriorityDecay    int    // presumably applied by onFailure — not visible in this view
	PriorityRecovery int    // presumably applied by onSuccess — not visible in this view
}

// endpointSlot carries per-endpoint routing state; mu guards the mutable
// fields below it.
type endpointSlot struct {
	client          Client
	url             string
	initialPriority int // configured priority, before decay/recovery

	mu           sync.Mutex
	priority     int // current effective priority
	inflight     int // requests currently outstanding on this endpoint
	successCount int
	lastFail     time.Time // zero value means "never failed"
}

// TargetRouter implements Client by routing requests across multiple endpoint slots
// using a busyness-based priority algorithm.
type TargetRouter struct {
	slots   []*endpointSlot
	cfg     RouterConfig
	metrics *metrics.Registry // nil disables metrics reporting
}
// NewTargetRouter constructs a TargetRouter from a slice of (client, url, initialPriority) tuples.
// reg may be nil, in which case no metrics are published. At least one slot is required.
func NewTargetRouter(slots []RouterSlot, cfg RouterConfig, reg *metrics.Registry) (*TargetRouter, error) {
	if len(slots) == 0 {
		return nil, fmt.Errorf("NewTargetRouter: at least one slot required")
	}
	endpoints := make([]*endpointSlot, 0, len(slots))
	for _, s := range slots {
		endpoints = append(endpoints, &endpointSlot{
			client:          s.Client,
			url:             s.URL,
			initialPriority: s.Priority,
			priority:        s.Priority,
		})
		if reg != nil {
			reg.SetEndpointPriority(cfg.TargetName, s.URL, float64(s.Priority))
			reg.SetEndpointInflight(cfg.TargetName, s.URL, 0)
		}
	}
	return &TargetRouter{slots: endpoints, cfg: cfg, metrics: reg}, nil
}
// RouterSlot is a single endpoint entry for NewTargetRouter.
type RouterSlot struct {
	Client   Client // client that actually performs the embed call
	URL      string // endpoint URL, used only for metrics/trace labels
	Priority int    // starting (and maximum recoverable) routing priority
}
// Embed routes the request to the best available slot, maintaining the
// slot's in-flight count (and mirroring it to the metrics registry) for the
// duration of the call. A deadline of TimeoutSecs seconds is applied when
// configured. Failures decay the slot's priority via onFailure; successes
// slowly restore it via onSuccess.
func (r *TargetRouter) Embed(ctx context.Context, req Request) (Response, error) {
	slot := r.pick()
	slot.mu.Lock()
	slot.inflight++
	if r.metrics != nil {
		r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
	}
	slot.mu.Unlock()
	defer func() {
		slot.mu.Lock()
		slot.inflight--
		if r.metrics != nil {
			r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
		}
		slot.mu.Unlock()
	}()
	// Only impose a deadline when one is configured: context.WithTimeout with
	// a zero duration yields an already-expired context, which would make
	// every request fail when TimeoutSecs is left unset.
	if r.cfg.TimeoutSecs > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, time.Duration(r.cfg.TimeoutSecs)*time.Second)
		defer cancel()
	}
	resp, err := slot.client.Embed(ctx, req)
	if err != nil {
		r.onFailure(slot, err)
		return Response{}, fmt.Errorf("router embed [%s]: %w", slot.url, err)
	}
	r.onSuccess(slot)
	return resp, nil
}
// pick selects the slot with the highest effective score — priority minus
// current in-flight count — among slots not inside their failure cooldown
// window. If every slot is cooling down, the slot with the oldest failure is
// returned so traffic is never dropped outright. Always returns non-nil for
// a router built by NewTargetRouter (which requires at least one slot).
func (r *TargetRouter) pick() *endpointSlot {
	cooldown := time.Duration(r.cfg.CooldownSecs) * time.Second
	now := time.Now()
	var best *endpointSlot
	bestScore := 0
	for _, s := range r.slots {
		// Snapshot mutable state under the slot's lock.
		s.mu.Lock()
		inCooldown := !s.lastFail.IsZero() && now.Sub(s.lastFail) < cooldown
		score := s.priority - s.inflight
		s.mu.Unlock()
		if inCooldown {
			continue
		}
		if best == nil || score > bestScore {
			best = s
			bestScore = score
		}
	}
	// All in cooldown — fall back to the slot with the oldest failure.
	if best == nil {
		var oldest time.Time
		for _, s := range r.slots {
			s.mu.Lock()
			lf := s.lastFail
			s.mu.Unlock()
			if best == nil || lf.Before(oldest) {
				best = s
				oldest = lf
			}
		}
	}
	return best
}
// onSuccess bumps the slot's success counter and, on every PriorityRecovery-th
// success, restores one point of priority — never beyond the slot's initial
// value. The resulting priority is mirrored to the metrics registry when one
// is configured.
func (r *TargetRouter) onSuccess(s *endpointSlot) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.successCount++
	step := r.cfg.PriorityRecovery
	if step > 0 && s.successCount%step == 0 && s.priority < s.initialPriority {
		s.priority++
	}
	if r.metrics != nil {
		r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
	}
}
// onFailure records a failure on s: stamps lastFail (starting the cooldown
// window used by pick), decays the slot's priority by PriorityDecay (floored
// at 1), and publishes the new priority plus an error-type counter to the
// metrics registry when one is configured.
func (r *TargetRouter) onFailure(s *endpointSlot, err error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.lastFail = time.Now()
	s.priority -= r.cfg.PriorityDecay
	if s.priority < 1 {
		s.priority = 1
	}
	// Classify the failure for the error counter. The original inspected
	// context.Background().Err(), which is always nil, so every failure was
	// labelled "error" and "timeout" was unreachable; inspect the returned
	// error chain instead.
	errType := "error"
	if errors.Is(err, context.DeadlineExceeded) {
		errType = "timeout"
	}
	if r.metrics != nil {
		r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
		r.metrics.IncEndpointErrors(r.cfg.TargetName, s.url, errType)
	}
}

112
pkg/metrics/metrics.go Normal file
View File

@@ -0,0 +1,112 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
)
// Registry holds all vecna Prometheus metrics on a dedicated (non-global) registry.
type Registry struct {
	reg                 *prometheus.Registry     // backing registry; exposed via Prometheus()
	RequestsTotal       *prometheus.CounterVec   // requests by endpoint and HTTP status
	RequestDuration     *prometheus.HistogramVec // total request wall-clock time by endpoint
	ForwardDuration     *prometheus.HistogramVec // time waiting on the backing model, by target/url
	TranslateDuration   *prometheus.HistogramVec // time in the dimension adapter, by adapter type
	EndpointPriority    *prometheus.GaugeVec     // current routing priority per target/url
	EndpointInflight    *prometheus.GaugeVec     // active in-flight requests per target/url
	EndpointErrorsTotal *prometheus.CounterVec   // forwarding errors by target/url/error type
	TokensTotal         *prometheus.CounterVec   // tokens consumed, by target/model/token type
}
// New creates and registers all metrics on a fresh Prometheus registry.
// Using a dedicated registry (rather than the package-global default) keeps
// vecna's series isolated from anything else linked into the binary.
func New() *Registry {
	reg := prometheus.NewRegistry()
	r := &Registry{
		reg: reg,
		RequestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_requests_total",
			Help: "Total number of requests served by vecna.",
		}, []string{"endpoint", "status"}),
		RequestDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_request_duration_seconds",
			Help:    "Total request wall-clock time.",
			Buckets: prometheus.DefBuckets,
		}, []string{"endpoint"}),
		ForwardDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_forward_duration_seconds",
			Help:    "Time spent waiting on the backing embedding model.",
			Buckets: prometheus.DefBuckets,
		}, []string{"target", "url"}),
		TranslateDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name: "vecna_translate_duration_seconds",
			Help: "Time spent in the dimension adapter.",
			// Sub-millisecond buckets: adaptation is far faster than I/O.
			Buckets: []float64{0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05},
		}, []string{"adapter_type"}),
		EndpointPriority: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Name: "vecna_endpoint_priority",
			Help: "Current dynamic routing priority for a forwarding endpoint.",
		}, []string{"target", "url"}),
		EndpointInflight: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Name: "vecna_endpoint_inflight",
			Help: "Number of active in-flight requests per forwarding endpoint.",
		}, []string{"target", "url"}),
		EndpointErrorsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_endpoint_errors_total",
			Help: "Total forwarding errors per endpoint, labelled by error type.",
		}, []string{"target", "url", "error"}),
		TokensTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_tokens_total",
			Help: "Tokens consumed by the backing embedding model, by target, model, and token type.",
		}, []string{"target", "model", "token_type"}),
	}
	// MustRegister panics on duplicate registration; safe here because the
	// registry above is freshly created.
	reg.MustRegister(
		r.RequestsTotal,
		r.RequestDuration,
		r.ForwardDuration,
		r.TranslateDuration,
		r.EndpointPriority,
		r.EndpointInflight,
		r.EndpointErrorsTotal,
		r.TokensTotal,
	)
	return r
}
// Prometheus returns the underlying registry for use with promhttp.HandlerFor.
func (r *Registry) Prometheus() *prometheus.Registry {
	return r.reg
}

// Convenience setters used by the router.

// SetEndpointPriority records the current routing priority gauge for target/url.
func (r *Registry) SetEndpointPriority(target, url string, v float64) {
	r.EndpointPriority.WithLabelValues(target, url).Set(v)
}

// SetEndpointInflight records the in-flight request gauge for target/url.
func (r *Registry) SetEndpointInflight(target, url string, v float64) {
	r.EndpointInflight.WithLabelValues(target, url).Set(v)
}

// IncEndpointErrors increments the error counter for target/url under errType.
func (r *Registry) IncEndpointErrors(target, url, errType string) {
	r.EndpointErrorsTotal.WithLabelValues(target, url, errType).Inc()
}

// AddTokens accumulates prompt/total token usage for target/model. Zero or
// negative counts are skipped so absent usage data does not create series.
func (r *Registry) AddTokens(target, model string, promptTokens, totalTokens int) {
	if promptTokens > 0 {
		r.TokensTotal.WithLabelValues(target, model, "prompt").Add(float64(promptTokens))
	}
	if totalTokens > 0 {
		r.TokensTotal.WithLabelValues(target, model, "total").Add(float64(totalTokens))
	}
}

62
pkg/metrics/middleware.go Normal file
View File

@@ -0,0 +1,62 @@
package metrics
import (
"fmt"
"net/http"
"github.com/uptrace/bunrouter"
)
// Middleware returns a bunrouter middleware that records per-request Prometheus metrics.
// It reads timing from the RequestTrace stored in the context (set by server/trace.go).
// The trace target/url labels are optional; pass empty strings if not applicable.
func (r *Registry) Middleware(getTrace func(req bunrouter.Request) TraceSnapshot) bunrouter.MiddlewareFunc {
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
return func(w http.ResponseWriter, req bunrouter.Request) error {
rw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
err := next(rw, req)
snap := getTrace(req)
endpoint := req.URL.Path
status := fmt.Sprintf("%d", rw.status)
r.RequestsTotal.WithLabelValues(endpoint, status).Inc()
r.RequestDuration.WithLabelValues(endpoint).Observe(snap.TotalSeconds)
if snap.ForwardTarget != "" {
r.ForwardDuration.WithLabelValues(snap.ForwardTarget, snap.ForwardURL).Observe(snap.ForwardSeconds)
}
if snap.AdapterType != "" {
r.TranslateDuration.WithLabelValues(snap.AdapterType).Observe(snap.TranslateSeconds)
}
if snap.PromptTokens > 0 || snap.TotalTokens > 0 {
r.AddTokens(snap.ForwardTarget, snap.ForwardModel, snap.PromptTokens, snap.TotalTokens)
}
return err
}
}
}
// TraceSnapshot carries the timing and usage values the metrics middleware needs.
// It decouples the metrics package from the server package's RequestTrace type.
type TraceSnapshot struct {
	TotalSeconds     float64 // total request wall-clock time
	ForwardSeconds   float64 // time spent waiting on the backing model
	TranslateSeconds float64 // time spent in the dimension adapter
	ForwardTarget    string  // target name; empty when no forward happened
	ForwardURL       string  // endpoint URL the request was forwarded to
	ForwardModel     string  // model name reported by the backing service
	AdapterType      string  // adapter implementation label; empty skips observation
	PromptTokens     int     // prompt tokens reported by the backing service
	TotalTokens      int     // total tokens reported by the backing service
}
// statusWriter wraps http.ResponseWriter to capture the written status code.
// The status field defaults to 200 at the call site, matching net/http's
// implicit WriteHeader behavior when a handler never calls it.
type statusWriter struct {
	http.ResponseWriter
	status int
}

// WriteHeader records the status code before delegating to the wrapped writer.
func (sw *statusWriter) WriteHeader(code int) {
	sw.status = code
	sw.ResponseWriter.WriteHeader(code)
}

135
pkg/server/google.go Normal file
View File

@@ -0,0 +1,135 @@
package server
import (
"encoding/json"
"net/http"
"time"
"github.com/uptrace/bunrouter"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// --- single embedContent ---
type googleEmbedContentRequest struct {
Content googleContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
}
type googleContent struct {
Parts []googlePart `json:"parts"`
}
type googlePart struct {
Text string `json:"text"`
}
type googleEmbedContentResponse struct {
Embedding googleEmbeddingValues `json:"embedding"`
}
type googleEmbeddingValues struct {
Values []float32 `json:"values"`
}
// googleEmbedContent handles the Google-style single embedContent endpoint:
// it forwards each content part as a separate text to the resolved backend,
// adapts the first returned embedding, and records forward/translate timings
// on the request trace.
// NOTE(review): when a request carries multiple parts, only the embedding of
// the first part is returned and the rest are discarded — confirm whether
// parts should instead be joined into a single text before forwarding.
func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")
	var body googleEmbedContentRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	// Flatten the content parts into the forwarded text list.
	texts := make([]string, len(body.Content.Parts))
	for i, p := range body.Content.Parts {
		texts[i] = p.Text
	}
	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL
	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens
	t1 := time.Now()
	// Adapt only the first embedding; with no embeddings, adapted stays nil
	// and "values" serializes as null.
	var adapted []float32
	if len(embedResp.Embeddings) > 0 {
		adapted, err = h.adapter.Adapt(embedResp.Embeddings[0])
		if err != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
	}
	trace.TranslateDuration = time.Since(t1)
	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, googleEmbedContentResponse{
		Embedding: googleEmbeddingValues{Values: adapted},
	})
}
// --- batch batchEmbedContents ---

// googleBatchRequest mirrors the Google batchEmbedContents request body.
type googleBatchRequest struct {
	Requests []googleEmbedContentRequest `json:"requests"`
}

// googleBatchResponse carries one embedding per forwarded text.
type googleBatchResponse struct {
	Embeddings []googleEmbeddingValues `json:"embeddings"`
}
// googleBatchEmbedContents handles the Google-style batch endpoint: all parts
// of all requests are flattened into one text list, forwarded in a single
// backend call, and every returned embedding is adapted and included in the
// response.
// NOTE(review): the response contains one embedding per *part*, not per
// request — if any request carries more than one part, the embedding count
// will not line up with the request count. Confirm against the intended
// client contract.
func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")
	var body googleBatchRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	// Flatten every part of every request into one forwarded text list.
	var texts []string
	for _, r := range body.Requests {
		for _, p := range r.Content.Parts {
			texts = append(texts, p.Text)
		}
	}
	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL
	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens
	t1 := time.Now()
	// Adapt every embedding; the first adapter failure aborts the request.
	result := make([]googleEmbeddingValues, len(embedResp.Embeddings))
	for i, vec := range embedResp.Embeddings {
		adapted, adaptErr := h.adapter.Adapt(vec)
		if adaptErr != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
		}
		result[i] = googleEmbeddingValues{Values: adapted}
	}
	trace.TranslateDuration = time.Since(t1)
	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, googleBatchResponse{Embeddings: result})
}

74
pkg/server/handler.go Normal file
View File

@@ -0,0 +1,74 @@
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"github.com/Warky-Devs/vecna.git/pkg/adapter"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// handler holds shared dependencies for all HTTP handlers.
type handler struct {
	cfg     *config.Config                // full service configuration
	clients map[string]embedclient.Client // embed clients keyed by target name
	adapter adapter.Adapter               // dimension adapter applied to every embedding
	logger  *zap.Logger
}
// resolveClient selects the embed client for the given model name.
// A model that matches a configured target maps directly; anything else
// falls back to the default target. The returned name and URL are used for
// request tracing. When neither resolves, an always-erroring client is
// returned so callers never need a nil check.
func (h *handler) resolveClient(model string) (embedclient.Client, string, string) {
	if direct, ok := h.clients[model]; ok {
		return direct, model, firstEndpointURL(h.cfg, model)
	}
	fallback := h.cfg.Forward.Default
	if c, ok := h.clients[fallback]; ok {
		return c, fallback, firstEndpointURL(h.cfg, fallback)
	}
	// No configured client — return a nil-safe error client
	return &errClient{err: fmt.Errorf("no client configured for model %q and no default", model)}, fallback, ""
}
// firstEndpointURL reports the URL of the first endpoint configured for
// targetName, or "" when the target is unknown or has no endpoints.
func firstEndpointURL(cfg *config.Config, targetName string) string {
	if t, ok := cfg.Forward.Targets[targetName]; ok && len(t.Endpoints) > 0 {
		return t.Endpoints[0].URL
	}
	return ""
}
// writeJSON encodes v as JSON and writes it with the given status code.
func writeJSON(w http.ResponseWriter, status int, v interface{}) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
return fmt.Errorf("writeJSON: %w", err)
}
return nil
}
// writeTraceHeaders exposes the request's timing breakdown as X-Vecna-*
// millisecond headers so clients can see forward vs translate vs total cost.
func writeTraceHeaders(w http.ResponseWriter, t *RequestTrace) {
	elapsed := time.Since(t.Start)
	hdr := w.Header()
	hdr.Set("X-Vecna-Forward-Ms", fmt.Sprintf("%d", t.ForwardDuration.Milliseconds()))
	hdr.Set("X-Vecna-Translate-Ms", fmt.Sprintf("%d", t.TranslateDuration.Milliseconds()))
	hdr.Set("X-Vecna-Total-Ms", fmt.Sprintf("%d", elapsed.Milliseconds()))
}
// errClient is a Client that always returns a fixed error (used as safe fallback).
type errClient struct {
	err error // returned verbatim from every Embed call
}

// Embed always fails with the client's fixed error; it never contacts a backend.
func (e *errClient) Embed(_ context.Context, _ embedclient.Request) (embedclient.Response, error) {
	return embedclient.Response{}, e.err
}

102
pkg/server/openai.go Normal file
View File

@@ -0,0 +1,102 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/uptrace/bunrouter"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// openAIEmbedRequest mirrors the OpenAI /v1/embeddings request body.
type openAIEmbedRequest struct {
	Input interface{} `json:"input"` // string or []string
	Model string      `json:"model"`
}

// openAIEmbedResponse mirrors the OpenAI embeddings response envelope.
type openAIEmbedResponse struct {
	Object string             `json:"object"` // always "list"
	Data   []openAIEmbedDatum `json:"data"`
	Model  string             `json:"model"`
	Usage  openAIUsage        `json:"usage"`
}

// openAIEmbedDatum is one embedding entry; Index preserves input order.
type openAIEmbedDatum struct {
	Object    string    `json:"object"` // always "embedding"
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// openAIUsage carries token accounting reported by the backing model.
type openAIUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
// openAIEmbeddings handles the OpenAI-compatible /v1/embeddings endpoint:
// the input (string or string array) is forwarded to the resolved backend,
// every returned embedding is run through the dimension adapter, and
// forward/translate timings are recorded on the request trace.
func (h *handler) openAIEmbeddings(w http.ResponseWriter, req bunrouter.Request) error {
	var body openAIEmbedRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	texts, err := toStringSlice(body.Input)
	if err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
	}
	client, targetName, targetURL := h.resolveClient(body.Model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL
	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: body.Model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens
	t1 := time.Now()
	// Adapt every embedding; the first adapter failure aborts the request.
	data := make([]openAIEmbedDatum, len(embedResp.Embeddings))
	for i, vec := range embedResp.Embeddings {
		adapted, adaptErr := h.adapter.Adapt(vec)
		if adaptErr != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
		}
		data[i] = openAIEmbedDatum{Object: "embedding", Embedding: adapted, Index: i}
	}
	trace.TranslateDuration = time.Since(t1)
	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, openAIEmbedResponse{
		Object: "list",
		Data:   data,
		Model:  embedResp.Model,
		Usage:  openAIUsage{PromptTokens: embedResp.Usage.PromptTokens, TotalTokens: embedResp.Usage.TotalTokens},
	})
}
// toStringSlice normalizes an OpenAI "input" value into a slice of strings.
// A bare JSON string becomes a one-element slice; a JSON array must contain
// only strings. Any other shape is rejected with an error.
func toStringSlice(v interface{}) ([]string, error) {
	if s, ok := v.(string); ok {
		return []string{s}, nil
	}
	arr, ok := v.([]interface{})
	if !ok {
		return nil, fmt.Errorf("input must be a string or array of strings")
	}
	out := make([]string, 0, len(arr))
	for i, item := range arr {
		str, isString := item.(string)
		if !isString {
			return nil, fmt.Errorf("input array element %d is not a string", i)
		}
		out = append(out, str)
	}
	return out, nil
}

163
pkg/server/server.go Normal file
View File

@@ -0,0 +1,163 @@
package server
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/uptrace/bunrouter"
"go.uber.org/zap"
"github.com/Warky-Devs/vecna.git/pkg/adapter"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
"github.com/Warky-Devs/vecna.git/pkg/metrics"
"github.com/Warky-Devs/vecna.git/pkg/server/spec"
)
// New builds and returns a configured bunrouter.Router.
// It registers the OpenAI- and Google-compatible embedding routes, the
// OpenAPI spec/docs routes, and (when enabled) the Prometheus scrape
// endpoint. Middleware is registered in the order: auth, trace, metrics,
// logging.
func New(
	cfg *config.Config,
	clients map[string]embedclient.Client,
	adp adapter.Adapter,
	reg *metrics.Registry,
	logger *zap.Logger,
) *bunrouter.Router {
	router := bunrouter.New(
		bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)),
		bunrouter.WithMiddleware(traceMiddleware()),
		bunrouter.WithMiddleware(metricsMiddleware(reg, adp)),
		bunrouter.WithMiddleware(loggingMiddleware(logger)),
	)
	h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger}
	// Embedding endpoints (OpenAI-style and Google-style).
	router.POST("/v1/embeddings", h.openAIEmbeddings)
	router.POST("/v1/models/:model:embedContent", h.googleEmbedContent)
	router.POST("/v1/models/:model:batchEmbedContents", h.googleBatchEmbedContents)
	// OpenAPI spec + docs
	router.GET("/openapi.yaml", spec.SpecHandler())
	router.GET("/docs", spec.DocsHandler())
	// Metrics — only when enabled
	if cfg.Metrics.Enabled {
		metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{})
		path := cfg.Metrics.Path
		if path == "" {
			path = "/metrics" // conventional default scrape path
		}
		// The scrape endpoint can be protected with its own API key,
		// independent of the main api_keys list.
		if cfg.Metrics.APIKey != "" {
			router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler))
		} else {
			router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error {
				metricsHandler.ServeHTTP(w, req.Request)
				return nil
			})
		}
	}
	return router
}
// authMiddleware rejects requests without a valid Bearer token when api_keys is configured.
// An empty key list makes the middleware a no-op passthrough. Keys are held
// in a set for O(1) lookup.
func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc {
	if len(apiKeys) == 0 {
		return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
	}
	allowed := make(map[string]struct{}, len(apiKeys))
	for _, key := range apiKeys {
		allowed[key] = struct{}{}
	}
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
			_, ok := allowed[token]
			if !ok {
				http.Error(w, "unauthorized", http.StatusUnauthorized)
				return nil
			}
			return next(w, req)
		}
	}
}
// traceMiddleware injects a fresh *RequestTrace into every request context so
// downstream handlers and middleware share a single timing record.
func traceMiddleware() bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			traced := req.WithContext(WithTrace(req.Context()))
			return next(w, traced)
		}
	}
}
// metricsMiddleware records Prometheus observations after the handler returns.
// With a nil registry it degrades to a passthrough. The adapter's Go type
// name is used as the adapter_type label for translate-duration histograms.
func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc {
	if reg == nil {
		return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
	}
	adpType := fmt.Sprintf("%T", adp)
	// Bridge the server's RequestTrace into the metrics package's snapshot
	// type so metrics does not depend on the server package.
	return reg.Middleware(func(req bunrouter.Request) metrics.TraceSnapshot {
		t := TraceFromContext(req.Context())
		total := time.Since(t.Start)
		return metrics.TraceSnapshot{
			TotalSeconds:     total.Seconds(),
			ForwardSeconds:   t.ForwardDuration.Seconds(),
			TranslateSeconds: t.TranslateDuration.Seconds(),
			ForwardTarget:    t.ForwardTarget,
			ForwardURL:       t.ForwardURL,
			ForwardModel:     t.ForwardModel,
			AdapterType:      adpType,
			PromptTokens:     t.PromptTokens,
			TotalTokens:      t.TotalTokens,
		}
	})
}
// loggingMiddleware logs method, path, status, and timing via zap.
// Timing values come from the RequestTrace in the request context, so this
// middleware should run inside traceMiddleware.
func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			// Wrap the writer to capture the status the handler writes.
			sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			err := next(sw, req)
			t := TraceFromContext(req.Context())
			total := time.Since(t.Start)
			logger.Info("request",
				zap.String("method", req.Method),
				zap.String("path", req.URL.Path),
				zap.Int("status", sw.status),
				zap.Int64("total_ms", total.Milliseconds()),
				zap.Int64("forward_ms", t.ForwardDuration.Milliseconds()),
				zap.Int64("translate_ms", t.TranslateDuration.Milliseconds()),
			)
			return err
		}
	}
}
// metricsAuthHandler wraps a standard http.Handler with Bearer token auth.
// Requests whose Authorization bearer token does not equal apiKey receive
// a 401 and the wrapped handler is never invoked.
func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc {
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		supplied := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
		if supplied == apiKey {
			h.ServeHTTP(w, req.Request)
			return nil
		}
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return nil
	}
}
// statusWriter captures the HTTP status code written by a handler.
// NOTE(review): this duplicates the identically-shaped statusWriter in
// pkg/metrics — consider sharing one implementation.
type statusWriter struct {
	http.ResponseWriter
	status int
}

// WriteHeader records the status code before delegating to the wrapped writer.
func (sw *statusWriter) WriteHeader(code int) {
	sw.status = code
	sw.ResponseWriter.WriteHeader(code)
}

View File

@@ -0,0 +1,36 @@
package spec
import (
_ "embed"
"net/http"
"github.com/uptrace/bunrouter"
)
// openapiYAML is the OpenAPI document embedded at build time from the
// openapi.yaml file alongside this package.
//go:embed openapi.yaml
var openapiYAML []byte

// SpecHandler serves the raw OpenAPI YAML spec.
func SpecHandler() bunrouter.HandlerFunc {
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		w.Header().Set("Content-Type", "application/yaml")
		_, err := w.Write(openapiYAML)
		return err
	}
}
// DocsHandler serves the Scalar API reference UI.
// The page is a minimal HTML shell that loads the Scalar viewer from a CDN
// and points it at the locally served /openapi.yaml.
// NOTE(review): this pulls a script from jsdelivr at page load — confirm a
// CDN dependency is acceptable for the deployment environment.
func DocsHandler() bunrouter.HandlerFunc {
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		_, err := w.Write([]byte(`<!doctype html>
<html>
<head><title>vecna API</title><meta charset="utf-8"/></head>
<body>
<script id="api-reference" data-url="/openapi.yaml"></script>
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
</body>
</html>`))
		return err
	}
}

View File

@@ -0,0 +1,252 @@
openapi: "3.1.0"
info:
title: vecna Embedding Adapter
description: Proxies text to a backing embedding model and adapts the result vectors between dimensions.
version: "1.0.0"
servers:
- url: http://localhost:8080
security:
- BearerAuth: []
components:
securitySchemes:
BearerAuth:
type: http
scheme: bearer
schemas:
Error:
type: object
properties:
error:
type: string
OpenAIEmbedRequest:
type: object
required: [input, model]
properties:
input:
oneOf:
- type: string
- type: array
items:
type: string
model:
type: string
OpenAIEmbedResponse:
type: object
properties:
object:
type: string
example: list
model:
type: string
data:
type: array
items:
type: object
properties:
object:
type: string
example: embedding
index:
type: integer
embedding:
type: array
items:
type: number
format: float
usage:
type: object
properties:
prompt_tokens:
type: integer
total_tokens:
type: integer
GoogleEmbedContentRequest:
type: object
required: [content]
properties:
content:
type: object
properties:
parts:
type: array
items:
type: object
properties:
text:
type: string
taskType:
type: string
GoogleEmbedContentResponse:
type: object
properties:
embedding:
type: object
properties:
values:
type: array
items:
type: number
format: float
GoogleBatchRequest:
type: object
required: [requests]
properties:
requests:
type: array
items:
$ref: '#/components/schemas/GoogleEmbedContentRequest'
GoogleBatchResponse:
type: object
properties:
embeddings:
type: array
items:
type: object
properties:
values:
type: array
items:
type: number
format: float
headers:
X-Vecna-Forward-Ms:
description: Time spent forwarding the request to the backing model (milliseconds).
schema:
type: integer
X-Vecna-Translate-Ms:
description: Time spent in the dimension adapter (milliseconds).
schema:
type: integer
X-Vecna-Total-Ms:
description: Total request wall-clock time (milliseconds).
schema:
type: integer
paths:
/v1/embeddings:
post:
summary: OpenAI-compatible embeddings
operationId: openaiEmbeddings
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAIEmbedRequest'
responses:
"200":
description: Adapted embeddings
headers:
X-Vecna-Forward-Ms:
$ref: '#/components/headers/X-Vecna-Forward-Ms'
X-Vecna-Translate-Ms:
$ref: '#/components/headers/X-Vecna-Translate-Ms'
X-Vecna-Total-Ms:
$ref: '#/components/headers/X-Vecna-Total-Ms'
content:
application/json:
schema:
$ref: '#/components/schemas/OpenAIEmbedResponse'
"400":
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
"401":
description: Unauthorized
"502":
description: Backing model error
/v1/models/{model}:embedContent:
post:
summary: Google-compatible single embedContent
operationId: googleEmbedContent
parameters:
- name: model
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleEmbedContentRequest'
responses:
"200":
description: Adapted embedding
headers:
X-Vecna-Forward-Ms:
$ref: '#/components/headers/X-Vecna-Forward-Ms'
X-Vecna-Translate-Ms:
$ref: '#/components/headers/X-Vecna-Translate-Ms'
X-Vecna-Total-Ms:
$ref: '#/components/headers/X-Vecna-Total-Ms'
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleEmbedContentResponse'
"400":
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
"401":
description: Unauthorized
"502":
description: Backing model error
/v1/models/{model}:batchEmbedContents:
post:
      summary: Google-compatible batchEmbedContents
operationId: googleBatchEmbedContents
parameters:
- name: model
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleBatchRequest'
responses:
"200":
description: Adapted embeddings
headers:
X-Vecna-Forward-Ms:
$ref: '#/components/headers/X-Vecna-Forward-Ms'
X-Vecna-Translate-Ms:
$ref: '#/components/headers/X-Vecna-Translate-Ms'
X-Vecna-Total-Ms:
$ref: '#/components/headers/X-Vecna-Total-Ms'
content:
application/json:
schema:
$ref: '#/components/schemas/GoogleBatchResponse'
"400":
description: Bad request
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
"401":
description: Unauthorized
"502":
description: Backing model error

37
pkg/server/trace.go Normal file
View File

@@ -0,0 +1,37 @@
package server
import (
"context"
"time"
)
type contextKey int
const traceKey contextKey = iota
// RequestTrace holds per-request timing data populated by handlers and middleware.
type RequestTrace struct {
Start time.Time
ForwardDuration time.Duration
TranslateDuration time.Duration
ForwardTarget string
ForwardURL string
ForwardModel string
AdapterType string
PromptTokens int
TotalTokens int
}
// WithTrace injects a new *RequestTrace into ctx.
func WithTrace(ctx context.Context) context.Context {
return context.WithValue(ctx, traceKey, &RequestTrace{Start: time.Now()})
}
// TraceFromContext retrieves the *RequestTrace from ctx.
// Returns a zero-value trace (non-nil) if none was set.
func TraceFromContext(ctx context.Context) *RequestTrace {
if t, ok := ctx.Value(traceKey).(*RequestTrace); ok && t != nil {
return t
}
return &RequestTrace{Start: time.Now()}
}