mirror of
https://github.com/Warky-Devs/vecna.git
synced 2026-05-05 01:26:58 +00:00
feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768 , 3072 <-> 2048
This commit is contained in:
77
pkg/adapter/adapter.go
Normal file
77
pkg/adapter/adapter.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ErrDimMismatch is returned when the input vector length does not match the adapter's source dimension.
var ErrDimMismatch = errors.New("vector dimension mismatch")

// ErrInvalidDim is returned when source or target dimensions are invalid.
var ErrInvalidDim = errors.New("invalid dimension")

// ErrInvalidMatrix is returned when a projection matrix has the wrong shape.
var ErrInvalidMatrix = errors.New("invalid projection matrix shape")

// TruncateMode controls which end of the vector is dropped when downscaling.
type TruncateMode int

const (
	// TruncateFromEnd keeps the first targetDim elements (default; correct for Matryoshka models).
	TruncateFromEnd TruncateMode = iota
	// TruncateFromStart keeps the last targetDim elements.
	TruncateFromStart
)

// PadMode controls which end of the vector receives zero-padding when upscaling.
type PadMode int

const (
	// PadAtEnd appends zeros to the end (default).
	PadAtEnd PadMode = iota
	// PadAtStart prepends zeros to the start.
	PadAtStart
)

// Adapter translates vectors between two fixed dimensions.
type Adapter interface {
	// Adapt converts vec (length SourceDim) to a vector of length TargetDim.
	// Implementations in this package return a wrapped ErrDimMismatch on a
	// wrong input length and L2-normalize the result.
	Adapt(vec []float32) ([]float32, error)
	// SourceDim reports the expected input length.
	SourceDim() int
	// TargetDim reports the produced output length.
	TargetDim() int
}
|
||||
|
||||
// NewTruncate returns a TruncateAdapter for Matryoshka-style or simple truncation/padding.
|
||||
// t controls which end is dropped when downscaling; p controls which end is padded when upscaling.
|
||||
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
|
||||
if sourceDim <= 0 || targetDim <= 0 {
|
||||
return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||
}
|
||||
return &truncateAdapter{sourceDim: sourceDim, targetDim: targetDim, truncMode: t, padMode: p}, nil
|
||||
}
|
||||
|
||||
// NewRandom returns a RandomAdapter backed by a seeded Gaussian projection matrix.
|
||||
// seed=0 uses a time-based seed.
|
||||
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
|
||||
if sourceDim <= 0 || targetDim <= 0 {
|
||||
return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||
}
|
||||
return newRandomAdapter(sourceDim, targetDim, seed), nil
|
||||
}
|
||||
|
||||
// NewProjection returns a ProjectionAdapter using a caller-supplied matrix.
|
||||
// matrix must have shape [targetDim][sourceDim].
|
||||
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
|
||||
if sourceDim <= 0 || targetDim <= 0 {
|
||||
return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||
}
|
||||
if len(matrix) != targetDim {
|
||||
return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, len(matrix), targetDim)
|
||||
}
|
||||
for i, row := range matrix {
|
||||
if len(row) != sourceDim {
|
||||
return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, len(row), sourceDim)
|
||||
}
|
||||
}
|
||||
return &projectionAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}, nil
|
||||
}
|
||||
236
pkg/adapter/adapter_test.go
Normal file
236
pkg/adapter/adapter_test.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// unitVec returns a unit vector of length n with value 1/sqrt(n) in each
// position, so its L2 norm is exactly 1.
func unitVec(n int) []float32 {
	component := float32(1.0 / math.Sqrt(float64(n)))
	out := make([]float32, n)
	for i := 0; i < n; i++ {
		out[i] = component
	}
	return out
}
|
||||
|
||||
// assertNormOne fails the test unless v has (approximately) unit L2 norm.
// Note: it compares the sum of squares against 1, which is equivalent to the
// norm being 1 up to the 1e-5 tolerance near that value.
func assertNormOne(t *testing.T, v []float32) {
	t.Helper()
	var sum float64
	for _, x := range v {
		sum += float64(x) * float64(x)
	}
	assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
}
|
||||
|
||||
// --- L2Norm ---
|
||||
|
||||
func TestL2Norm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []float32
|
||||
wantLen int
|
||||
wantNorm float64
|
||||
}{
|
||||
{"unit vector unchanged", []float32{1, 0, 0}, 3, 1.0},
|
||||
{"scale down to unit", []float32{2, 0, 0}, 3, 1.0},
|
||||
{"multi-dim", unitVec(4), 4, 1.0},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := L2Norm(tc.input)
|
||||
assert.Len(t, got, tc.wantLen)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("zero vector returns copy unchanged", func(t *testing.T) {
|
||||
in := []float32{0, 0, 0}
|
||||
got := L2Norm(in)
|
||||
assert.Equal(t, in, got)
|
||||
})
|
||||
}
|
||||
|
||||
// --- TruncateAdapter ---
|
||||
|
||||
func TestNewTruncate_InvalidDims(t *testing.T) {
|
||||
_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
|
||||
_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
}
|
||||
|
||||
func TestTruncateAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src, tgt int
|
||||
truncMode TruncateMode
|
||||
padMode PadMode
|
||||
// builder for input vec; checks are len + norm only unless wantFirst/wantLast set
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||
{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768, false},
|
||||
{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536, false},
|
||||
{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536, false},
|
||||
{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.src, a.SourceDim())
|
||||
assert.Equal(t, tc.tgt, a.TargetDim())
|
||||
|
||||
got, err := a.Adapt(unitVec(tc.src))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, tc.wantLen)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||
a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
_, err = a.Adapt(make([]float32, 100))
|
||||
require.ErrorIs(t, err, ErrDimMismatch)
|
||||
})
|
||||
|
||||
t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
|
||||
a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
// input: [1,2,3,4] — TruncateFromEnd keeps [1,2]
|
||||
in := []float32{1, 2, 3, 4}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
// after L2Norm the signs are preserved; ratio should match
|
||||
assert.Greater(t, got[0], float32(0))
|
||||
assert.Greater(t, got[1], float32(0))
|
||||
})
|
||||
|
||||
t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
|
||||
a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
// input: [0,0,3,4] — TruncateFromStart keeps [3,4]
|
||||
in := []float32{0, 0, 3, 4}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, got[0], float32(0))
|
||||
assert.Greater(t, got[1], float32(0))
|
||||
})
|
||||
|
||||
t.Run("PadAtStart zero-pads front", func(t *testing.T) {
|
||||
a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
|
||||
require.NoError(t, err)
|
||||
in := []float32{1, 0}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
// first two positions should be zero (padded), last two carry the signal
|
||||
assert.Equal(t, float32(0), got[0])
|
||||
assert.Equal(t, float32(0), got[1])
|
||||
})
|
||||
}
|
||||
|
||||
// --- RandomAdapter ---
|
||||
|
||||
func TestNewRandom_InvalidDims(t *testing.T) {
|
||||
_, err := NewRandom(0, 768, 42)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
}
|
||||
|
||||
// TestRandomAdapter checks output shape and normalization, seed determinism
// (same seed ⇒ identical matrix ⇒ identical output), seed sensitivity, and the
// dimension-mismatch error path.
func TestRandomAdapter(t *testing.T) {
	t.Run("output length and unit norm", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 42)
		require.NoError(t, err)
		got, err := a.Adapt(unitVec(1536))
		require.NoError(t, err)
		assert.Len(t, got, 768)
		assertNormOne(t, got)
	})

	t.Run("deterministic with same seed", func(t *testing.T) {
		a1, _ := NewRandom(64, 32, 99)
		a2, _ := NewRandom(64, 32, 99)
		in := unitVec(64)
		out1, _ := a1.Adapt(in)
		out2, _ := a2.Adapt(in)
		assert.Equal(t, out1, out2)
	})

	t.Run("different seeds produce different output", func(t *testing.T) {
		a1, _ := NewRandom(64, 32, 1)
		a2, _ := NewRandom(64, 32, 2)
		in := unitVec(64)
		out1, _ := a1.Adapt(in)
		out2, _ := a2.Adapt(in)
		assert.NotEqual(t, out1, out2)
	})

	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 1)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 10))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}
|
||||
|
||||
// --- ProjectionAdapter ---

// TestNewProjection_InvalidMatrix verifies matrix shape validation: wrong row
// count, wrong column count, and invalid dimensions each map to the expected
// sentinel error.
func TestNewProjection_InvalidMatrix(t *testing.T) {
	tests := []struct {
		name   string
		src    int
		tgt    int
		matrix [][]float32
		errIs  error
	}{
		{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
		{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
		{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
			require.ErrorIs(t, err, tc.errIs)
		})
	}
}

// TestProjectionAdapter checks identity and downscaling projections (output
// length and unit norm) plus the dimension-mismatch error path.
func TestProjectionAdapter(t *testing.T) {
	t.Run("identity matrix same dim", func(t *testing.T) {
		// 3×3 identity
		id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
		a, err := NewProjection(3, 3, id)
		require.NoError(t, err)
		in := []float32{1, 2, 3}
		got, err := a.Adapt(in)
		require.NoError(t, err)
		assert.Len(t, got, 3)
		assertNormOne(t, got)
	})

	t.Run("downscale projection", func(t *testing.T) {
		// 2×4 matrix
		m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
		a, err := NewProjection(4, 2, m)
		require.NoError(t, err)
		got, err := a.Adapt([]float32{3, 4, 0, 0})
		require.NoError(t, err)
		assert.Len(t, got, 2)
		assertNormOne(t, got)
	})

	t.Run("dim mismatch returns error", func(t *testing.T) {
		m := [][]float32{{1, 0}, {0, 1}}
		a, err := NewProjection(2, 2, m)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 5))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}
|
||||
23
pkg/adapter/normalize.go
Normal file
23
pkg/adapter/normalize.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package adapter
|
||||
|
||||
import "math"
|
||||
|
||||
// L2Norm returns a new slice containing v scaled to unit L2 length.
// A zero-magnitude vector cannot be normalized and is returned as an
// unchanged copy. The input slice is never mutated.
func L2Norm(v []float32) []float32 {
	out := make([]float32, len(v))

	var sumSq float64
	for _, c := range v {
		sumSq += float64(c) * float64(c)
	}
	if sumSq == 0 {
		copy(out, v)
		return out
	}

	norm := float32(math.Sqrt(sumSq))
	for i := range v {
		out[i] = v[i] / norm
	}
	return out
}
|
||||
32
pkg/adapter/projection.go
Normal file
32
pkg/adapter/projection.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package adapter
|
||||
|
||||
import "fmt"
|
||||
|
||||
// projectionAdapter applies a fixed linear projection (matrix · vec) followed
// by L2 normalization.
type projectionAdapter struct {
	sourceDim int
	targetDim int
	matrix    [][]float32 // shape [targetDim][sourceDim], validated by NewProjection
}

func (a *projectionAdapter) SourceDim() int { return a.sourceDim }
func (a *projectionAdapter) TargetDim() int { return a.targetDim }
|
||||
|
||||
func (a *projectionAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||
if len(vec) != a.sourceDim {
|
||||
return nil, fmt.Errorf("projection adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||
}
|
||||
return L2Norm(matVecMul(a.matrix, vec)), nil
|
||||
}
|
||||
|
||||
// matVecMul computes the matrix-vector product m·v, where m has shape
// [rows][cols] and v has length cols. The result has length rows.
func matVecMul(m [][]float32, v []float32) []float32 {
	result := make([]float32, len(m))
	for r := range m {
		var dot float32
		for c := range m[r] {
			dot += m[r][c] * v[c]
		}
		result[r] = dot
	}
	return result
}
|
||||
45
pkg/adapter/random.go
Normal file
45
pkg/adapter/random.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// randomAdapter projects vectors through a seeded Gaussian random matrix.
type randomAdapter struct {
	sourceDim int
	targetDim int
	matrix    [][]float32 // shape [targetDim][sourceDim], built once in newRandomAdapter
}

// newRandomAdapter builds the [targetDim][sourceDim] Gaussian projection
// matrix. seed=0 falls back to a time-based seed, so the matrix (and thus the
// adapter's output) is only reproducible for non-zero seeds.
// NOTE: entries are drawn row by row from a single RNG stream — the exact
// values depend on this generation order, so do not reorder the loops.
func newRandomAdapter(sourceDim, targetDim int, seed int64) *randomAdapter {
	if seed == 0 {
		seed = time.Now().UnixNano()
	}
	//nolint:gosec // deterministic seeded RNG for projection matrix generation, not security use
	rng := rand.New(rand.NewSource(seed))

	// Gaussian N(0, 1/targetDim) — preserves expected squared norms (Johnson-Lindenstrauss)
	stddev := 1.0 / math.Sqrt(float64(targetDim))
	matrix := make([][]float32, targetDim)
	for i := range matrix {
		row := make([]float32, sourceDim)
		for j := range row {
			row[j] = float32(rng.NormFloat64() * stddev)
		}
		matrix[i] = row
	}

	return &randomAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}
}
|
||||
|
||||
func (a *randomAdapter) SourceDim() int { return a.sourceDim }
|
||||
func (a *randomAdapter) TargetDim() int { return a.targetDim }
|
||||
|
||||
func (a *randomAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||
if len(vec) != a.sourceDim {
|
||||
return nil, fmt.Errorf("random adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||
}
|
||||
return L2Norm(matVecMul(a.matrix, vec)), nil
|
||||
}
|
||||
41
pkg/adapter/truncate.go
Normal file
41
pkg/adapter/truncate.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package adapter
|
||||
|
||||
import "fmt"
|
||||
|
||||
// truncateAdapter resizes vectors by truncating (downscale) or zero-padding
// (upscale) and then L2-normalizes the result; see Adapt for mode semantics.
type truncateAdapter struct {
	sourceDim int
	targetDim int
	truncMode TruncateMode // which end is dropped when targetDim < sourceDim
	padMode   PadMode      // which end gets zeros when targetDim > sourceDim
}

func (a *truncateAdapter) SourceDim() int { return a.sourceDim }
func (a *truncateAdapter) TargetDim() int { return a.targetDim }
|
||||
|
||||
func (a *truncateAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||
if len(vec) != a.sourceDim {
|
||||
return nil, fmt.Errorf("truncate adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||
}
|
||||
|
||||
out := make([]float32, a.targetDim)
|
||||
|
||||
if a.targetDim <= a.sourceDim {
|
||||
// Downscale: truncate
|
||||
switch a.truncMode {
|
||||
case TruncateFromEnd:
|
||||
copy(out, vec[:a.targetDim])
|
||||
case TruncateFromStart:
|
||||
copy(out, vec[a.sourceDim-a.targetDim:])
|
||||
}
|
||||
} else {
|
||||
// Upscale: zero-pad
|
||||
switch a.padMode {
|
||||
case PadAtEnd:
|
||||
copy(out, vec)
|
||||
case PadAtStart:
|
||||
copy(out[a.targetDim-a.sourceDim:], vec)
|
||||
}
|
||||
}
|
||||
|
||||
return L2Norm(out), nil
|
||||
}
|
||||
168
pkg/config/config.go
Normal file
168
pkg/config/config.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package config
|
||||
|
||||
import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/spf13/viper"
)
|
||||
|
||||
// Config is the root configuration for vecna.
type Config struct {
	Server  ServerConfig  `mapstructure:"server"`
	Metrics MetricsConfig `mapstructure:"metrics"`
	Forward ForwardConfig `mapstructure:"forward"`
	Adapter AdapterConfig `mapstructure:"adapter"`
}

// ServerConfig controls the HTTP listener and inbound auth.
type ServerConfig struct {
	Port    int      `mapstructure:"port"` // default 8080 (set in Load)
	Host    string   `mapstructure:"host"` // default "0.0.0.0" (set in Load)
	APIKeys []string `mapstructure:"api_keys"`
}

// MetricsConfig controls Prometheus metrics exposure.
type MetricsConfig struct {
	Enabled bool   `mapstructure:"enabled"` // default false (set in Load)
	Path    string `mapstructure:"path"`    // default "/metrics" (set in Load)
	APIKey  string `mapstructure:"api_key"`
}

// ForwardConfig holds all named forwarding targets.
type ForwardConfig struct {
	Default string                   `mapstructure:"default"` // name of the target used when none is specified
	Targets map[string]ForwardTarget `mapstructure:"targets"`
}

// ForwardTarget is a named backing embedding model with one or more endpoints.
// Zero values for the numeric fields are replaced by applyForwardDefaults.
type ForwardTarget struct {
	Endpoints        []EndpointConfig `mapstructure:"endpoints"`
	Model            string           `mapstructure:"model"`
	APIKey           string           `mapstructure:"api_key"` // inherited by endpoints without their own key
	APIType          string           `mapstructure:"api_type"`
	TimeoutSecs      int              `mapstructure:"timeout_secs"`      // default 30
	CooldownSecs     int              `mapstructure:"cooldown_secs"`     // default 60
	PriorityDecay    int              `mapstructure:"priority_decay"`    // default 2
	PriorityRecovery int              `mapstructure:"priority_recovery"` // default 5
}

// EndpointConfig is a single URL within a ForwardTarget.
type EndpointConfig struct {
	URL      string `mapstructure:"url"`
	Priority int    `mapstructure:"priority"` // default 10 (set in applyForwardDefaults)
	APIKey   string `mapstructure:"api_key"`  // falls back to the target-level key
}

// AdapterConfig selects and tunes the dimension adapter.
type AdapterConfig struct {
	Type         string `mapstructure:"type"`          // default "truncate" (set in Load)
	SourceDim    int    `mapstructure:"source_dim"`
	TargetDim    int    `mapstructure:"target_dim"`
	TruncateMode string `mapstructure:"truncate_mode"` // default "from_end" (set in Load)
	PadMode      string `mapstructure:"pad_mode"`      // default "at_end" (set in Load)
	Seed         int64  `mapstructure:"seed"`
	MatrixFile   string `mapstructure:"matrix_file"` // presumably a projection-matrix source; consumer not visible here
}
|
||||
|
||||
// extensions lists the config file extensions viper detects automatically.
var extensions = []string{"json", "yaml", "toml"}

// ResolveFile returns the config file path that Load would use.
// If cfgFile is non-empty it is returned as-is. Otherwise the default search
// paths (cwd, $HOME, $HOME/.config/vecna) are checked for vecna.<ext>; if no
// existing file is found, the preferred default ($HOME/vecna.json) is returned
// so callers can create it.
// (The previous comment claimed ~/.vecna.json; the code — matching Load's
// SetConfigName("vecna") — uses vecna.json, so the doc is corrected here.)
func ResolveFile(cfgFile string) string {
	if cfgFile != "" {
		return cfgFile
	}
	// Error deliberately ignored: an empty home degrades to relative paths,
	// mirroring Load's behavior.
	home, _ := os.UserHomeDir()
	dirs := []string{".", home, filepath.Join(home, ".config", "vecna")}
	for _, dir := range dirs {
		for _, ext := range extensions {
			// filepath.Join builds an OS-correct path (the previous string
			// concatenation hard-coded "/").
			path := filepath.Join(dir, "vecna."+ext)
			if _, err := os.Stat(path); err == nil {
				return path
			}
		}
	}
	return filepath.Join(home, "vecna.json")
}
|
||||
|
||||
// Load reads configuration from the given file path (empty = search defaults),
// environment variables (prefix VECNA_), and applies built-in defaults.
// Per viper's lookup order, environment values take precedence over file
// values, which take precedence over the defaults set below.
func Load(cfgFile string) (*Config, error) {
	v := viper.New()

	// Defaults (lowest precedence)
	v.SetDefault("server.port", 8080)
	v.SetDefault("server.host", "0.0.0.0")
	v.SetDefault("metrics.enabled", false)
	v.SetDefault("metrics.path", "/metrics")
	v.SetDefault("adapter.type", "truncate")
	v.SetDefault("adapter.truncate_mode", "from_end")
	v.SetDefault("adapter.pad_mode", "at_end")

	// Environment: e.g. VECNA_SERVER_PORT maps to server.port.
	v.SetEnvPrefix("VECNA")
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	v.AutomaticEnv()

	// Config file
	if cfgFile != "" {
		v.SetConfigFile(cfgFile)
	} else {
		// Home-dir error deliberately ignored; search degrades to cwd.
		home, _ := os.UserHomeDir()
		v.SetConfigName("vecna")
		// No SetConfigType — viper detects format from file extension (json, yaml, toml, etc.)
		v.AddConfigPath(".")
		v.AddConfigPath(home)
		v.AddConfigPath(home + "/.config/vecna")
	}

	if err := v.ReadInConfig(); err != nil {
		// Missing config file is acceptable when all required values come from flags/env
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return nil, fmt.Errorf("load config: %w", err)
		}
	}

	var cfg Config
	if err := v.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("unmarshal config: %w", err)
	}

	// Fill zero-value ForwardTarget/Endpoint fields with their defaults.
	applyForwardDefaults(&cfg)

	return &cfg, nil
}
|
||||
|
||||
// applyForwardDefaults fills in zero-value fields on ForwardTarget entries.
|
||||
func applyForwardDefaults(cfg *Config) {
|
||||
for name, t := range cfg.Forward.Targets {
|
||||
if t.TimeoutSecs == 0 {
|
||||
t.TimeoutSecs = 30
|
||||
}
|
||||
if t.CooldownSecs == 0 {
|
||||
t.CooldownSecs = 60
|
||||
}
|
||||
if t.PriorityDecay == 0 {
|
||||
t.PriorityDecay = 2
|
||||
}
|
||||
if t.PriorityRecovery == 0 {
|
||||
t.PriorityRecovery = 5
|
||||
}
|
||||
for i, ep := range t.Endpoints {
|
||||
if ep.Priority == 0 {
|
||||
t.Endpoints[i].Priority = 10
|
||||
}
|
||||
if ep.APIKey == "" && t.APIKey != "" {
|
||||
t.Endpoints[i].APIKey = t.APIKey
|
||||
}
|
||||
}
|
||||
cfg.Forward.Targets[name] = t
|
||||
}
|
||||
}
|
||||
137
pkg/config/update.go
Normal file
137
pkg/config/update.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SaveTarget adds or replaces a named ForwardTarget in the config file at path.
|
||||
// Only JSON config files are written in-place. For other formats an error is
|
||||
// returned describing what to add manually.
|
||||
func SaveTarget(path, name string, target ForwardTarget) error {
|
||||
ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config %s: %w", path, err)
|
||||
}
|
||||
|
||||
switch ext {
|
||||
case "json":
|
||||
return saveTargetJSON(path, data, name, target)
|
||||
default:
|
||||
snippet, _ := json.MarshalIndent(map[string]ForwardTarget{name: target}, "", " ")
|
||||
return fmt.Errorf(
|
||||
"auto-update not supported for .%s files\n"+
|
||||
"Add the following to the 'forward.targets' section of %s:\n\n%s",
|
||||
ext, path, snippet,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveBrokenEndpoints removes failing endpoints from the config file.
// broken maps target name → set of failing endpoint URLs.
// If all endpoints of a target are removed, the target itself is deleted
// (and Forward.Default is cleared if it pointed at that target).
// Returns the list of removed items as human-readable strings, or (nil, nil)
// when nothing matched — in that case the file is left untouched.
// Only JSON config files are supported.
func RemoveBrokenEndpoints(path string, broken map[string][]string) ([]string, error) {
	ext := strings.ToLower(strings.TrimPrefix(filepath.Ext(path), "."))
	if ext != "json" {
		return nil, fmt.Errorf("auto-update not supported for .%s files; edit %s manually", ext, path)
	}

	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config %s: %w", path, err)
	}

	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse %s: %w", path, err)
	}

	var removed []string

	for targetName, failedURLs := range broken {
		target, ok := cfg.Forward.Targets[targetName]
		if !ok {
			// Unknown target names are silently skipped.
			continue
		}

		failSet := make(map[string]bool, len(failedURLs))
		for _, u := range failedURLs {
			failSet[u] = true
		}

		// In-place filter: kept aliases target.Endpoints' backing array.
		// Safe because kept never outgrows the portion already iterated.
		kept := target.Endpoints[:0]
		for _, ep := range target.Endpoints {
			if failSet[ep.URL] {
				removed = append(removed, fmt.Sprintf("endpoint %s (target %q)", ep.URL, targetName))
			} else {
				kept = append(kept, ep)
			}
		}

		if len(kept) == 0 {
			delete(cfg.Forward.Targets, targetName)
			removed = append(removed, fmt.Sprintf("target %q (all endpoints failed)", targetName))
			if cfg.Forward.Default == targetName {
				cfg.Forward.Default = ""
			}
		} else {
			target.Endpoints = kept
			cfg.Forward.Targets[targetName] = target
		}
	}

	if len(removed) == 0 {
		return nil, nil
	}

	// NOTE(review): this marshal+write tail duplicates WriteConfig verbatim;
	// consider reusing it.
	out, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return nil, fmt.Errorf("marshal config: %w", err)
	}
	if err := os.WriteFile(path, append(out, '\n'), 0o600); err != nil {
		return nil, fmt.Errorf("write %s: %w", path, err)
	}
	return removed, nil
}
|
||||
|
||||
// WriteConfig serialises cfg as indented JSON and overwrites path in place.
// NOTE(review): os.WriteFile is not atomic — a crash mid-write can leave a
// truncated file; use write-to-temp + rename if atomicity matters.
func WriteConfig(path string, cfg Config) error {
	out, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return fmt.Errorf("marshal config: %w", err)
	}
	// 0o600: config may contain API keys, so keep it owner-readable only.
	if err := os.WriteFile(path, append(out, '\n'), 0o600); err != nil {
		return fmt.Errorf("write %s: %w", path, err)
	}
	return nil
}
|
||||
|
||||
func saveTargetJSON(path string, data []byte, name string, target ForwardTarget) error {
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse %s: %w", path, err)
|
||||
}
|
||||
|
||||
if cfg.Forward.Targets == nil {
|
||||
cfg.Forward.Targets = make(map[string]ForwardTarget)
|
||||
}
|
||||
cfg.Forward.Targets[name] = target
|
||||
if cfg.Forward.Default == "" {
|
||||
cfg.Forward.Default = name
|
||||
}
|
||||
|
||||
out, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, append(out, '\n'), 0o600); err != nil {
|
||||
return fmt.Errorf("write %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
206
pkg/discovery/discovery.go
Normal file
206
pkg/discovery/discovery.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Kind describes a known server type.
type Kind struct {
	Name     string
	APIType  string // API flavor used when forwarding; all entries below are "openai"
	Port     int    // default port the server listens on
	NeedsKey bool   // true when the server requires an API key (e.g. vLLM)
}

// Found is a server discovered on the network.
type Found struct {
	Kind    Kind
	BaseURL string   // e.g. "http://127.0.0.1:11434"
	Models  []string // model names reported by the server
}

// knownServers lists server types by their default port and display name.
// Ollama is probed through its native /api/tags endpoint first (see probe);
// the others use the OpenAI-compatible /v1/models endpoint.
var knownServers = []Kind{
	{Name: "Ollama", APIType: "openai", Port: 11434},
	{Name: "LM Studio", APIType: "openai", Port: 1234},
	{Name: "vLLM", APIType: "openai", Port: 8000, NeedsKey: true},
	{Name: "LocalAI", APIType: "openai", Port: 8080},
	{Name: "Jan", APIType: "openai", Port: 1337},
	{Name: "Kobold", APIType: "openai", Port: 5001},
	{Name: "Tabby", APIType: "openai", Port: 9090},
}
|
||||
|
||||
// Scan concurrently probes localhost and LAN gateway addresses for known LLM servers.
// Results are returned in the order they are found (non-deterministic).
// All probe goroutines are joined before returning; the whole sweep is bounded
// by a 4-second budget.
func Scan(ctx context.Context) []Found {
	hosts := localHosts()

	var (
		mu      sync.Mutex // guards results
		results []Found
		wg      sync.WaitGroup
	)

	// Overall budget for the sweep; individual requests are additionally
	// capped by the client timeout below.
	probeCtx, cancel := context.WithTimeout(ctx, 4*time.Second)
	defer cancel()

	// Short per-request timeout so one slow host cannot eat the whole budget.
	httpClient := &http.Client{Timeout: 600 * time.Millisecond}

	for _, host := range hosts {
		for _, kind := range knownServers {
			host := host // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
			wg.Add(1)
			go func(kind Kind) {
				defer wg.Done()
				baseURL := fmt.Sprintf("http://%s:%d", host, kind.Port)
				models, err := probe(probeCtx, httpClient, baseURL, kind)
				if err != nil {
					// Unreachable or unrecognized — not a discovery.
					return
				}
				mu.Lock()
				results = append(results, Found{Kind: kind, BaseURL: baseURL, Models: models})
				mu.Unlock()
			}(kind)
		}
	}

	wg.Wait()
	return results
}
|
||||
|
||||
// Models fetches the model list from a single base URL and kind (for the models command).
|
||||
func Models(ctx context.Context, baseURL string, kind Kind) ([]string, error) {
|
||||
httpClient := &http.Client{Timeout: 5 * time.Second}
|
||||
return probe(ctx, httpClient, baseURL, kind)
|
||||
}
|
||||
|
||||
// localHosts returns localhost plus the ".1" address of every local IPv4
// subnet — a heuristic guess at the LAN gateway. The result is de-duplicated
// and always starts with "127.0.0.1".
func localHosts() []string {
	hosts := []string{"127.0.0.1"}
	seen := map[string]bool{"127.0.0.1": true}

	ifaces, err := net.Interfaces()
	if err != nil {
		// Cannot enumerate interfaces — localhost is still worth probing.
		return hosts
	}

	for _, iface := range ifaces {
		if iface.Flags&net.FlagUp == 0 {
			continue // interface is down
		}
		addrs, addrErr := iface.Addrs()
		if addrErr != nil {
			continue
		}
		for _, addr := range addrs {
			ipnet, ok := addr.(*net.IPNet)
			if !ok {
				continue
			}
			v4 := ipnet.IP.To4()
			if v4 == nil || v4.IsLoopback() {
				continue
			}
			// Guess the gateway as the subnet's network address + 1.
			candidate := v4.Mask(ipnet.Mask)
			candidate[3] = 1
			if host := candidate.String(); !seen[host] {
				seen[host] = true
				hosts = append(hosts, host)
			}
		}
	}
	return hosts
}
|
||||
|
||||
// probe attempts to identify the server at baseURL and returns its model list.
|
||||
func probe(ctx context.Context, client *http.Client, baseURL string, kind Kind) ([]string, error) {
|
||||
// Ollama has its own endpoint; everything else is OpenAI-compatible
|
||||
if kind.Name == "Ollama" {
|
||||
models, err := probeOllama(ctx, client, baseURL)
|
||||
if err != nil {
|
||||
// Ollama also exposes /v1/models since v0.1.27 — fall back
|
||||
return probeOpenAI(ctx, client, baseURL)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
return probeOpenAI(ctx, client, baseURL)
|
||||
}
|
||||
|
||||
// --- Ollama ---

// ollamaTagsResponse mirrors the subset of GET /api/tags we read.
type ollamaTagsResponse struct {
	Models []struct {
		Name string `json:"name"`
	} `json:"models"`
}

// probeOllama lists installed models via Ollama's native GET /api/tags.
func probeOllama(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/api/tags", nil)
	if err != nil {
		return nil, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status %d", resp.StatusCode)
	}

	var tags ollamaTagsResponse
	if err := json.NewDecoder(resp.Body).Decode(&tags); err != nil {
		return nil, fmt.Errorf("decode ollama response: %w", err)
	}

	names := make([]string, 0, len(tags.Models))
	for _, m := range tags.Models {
		names = append(names, m.Name)
	}
	return names, nil
}
|
||||
|
||||
// --- OpenAI-compatible ---

// openAIModelsResponse mirrors the subset of GET /v1/models we need.
type openAIModelsResponse struct {
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`
}

// probeOpenAI queries the OpenAI-compatible /v1/models endpoint and returns
// the advertised model IDs.
//
// Fix: request-building and transport errors are now wrapped with context,
// matching the decode error below; previously they were returned bare.
func probeOpenAI(ctx context.Context, client *http.Client, baseURL string) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/v1/models", nil)
	if err != nil {
		return nil, fmt.Errorf("build openai request: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("openai request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("status %d", resp.StatusCode)
	}

	var body openAIModelsResponse
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		return nil, fmt.Errorf("decode openai response: %w", err)
	}

	models := make([]string, len(body.Data))
	for i, m := range body.Data {
		models[i] = m.ID
	}
	return models, nil
}
||||
27
pkg/embedclient/client.go
Normal file
27
pkg/embedclient/client.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package embedclient

import "context"

// Request is a batch of texts to embed.
type Request struct {
	// Texts holds the documents to embed, in order; the response's
	// Embeddings slice aligns index-for-index with this slice.
	Texts []string
	// Model names the backing embedding model to use.
	Model string
}

// Usage reports token consumption from the backing model.
type Usage struct {
	PromptTokens int // tokens counted in the input texts
	TotalTokens  int // total tokens billed for the call
}

// Response holds the raw embeddings returned by the backing model.
type Response struct {
	Embeddings [][]float32 // one vector per input text
	Model      string      // model that actually served the request
	Usage      Usage       // token accounting; zero when the backend omits it
}

// Client sends text to a backing embedding model and returns raw vectors.
type Client interface {
	// Embed submits req and returns the resulting vectors.
	Embed(ctx context.Context, req Request) (Response, error)
}
||||
98
pkg/embedclient/google.go
Normal file
98
pkg/embedclient/google.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package embedclient
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
)
|
||||
|
||||
type googleClient struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
model string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewGoogle returns a Client that speaks the Google Gemini batchEmbedContents API.
|
||||
func NewGoogle(baseURL, apiKey, model string, httpClient *http.Client) Client {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return &googleClient{baseURL: baseURL, apiKey: apiKey, model: model, httpClient: httpClient}
|
||||
}
|
||||
|
||||
// googleBatchRequest is the JSON body for batchEmbedContents.
type googleBatchRequest struct {
	Requests []googleEmbedRequest `json:"requests"`
}

// googleEmbedRequest is one embedding request within a batch.
type googleEmbedRequest struct {
	Model   string        `json:"model"`   // fully-qualified "models/<name>"
	Content googleContent `json:"content"` // text parts to embed
}

// googleContent wraps the parts of a single content item.
type googleContent struct {
	Parts []googlePart `json:"parts"`
}

// googlePart is a single text fragment.
type googlePart struct {
	Text string `json:"text"`
}

// googleBatchResponse mirrors the subset of the batchEmbedContents
// response we consume: one values-vector per request entry.
type googleBatchResponse struct {
	Embeddings []struct {
		Values []float32 `json:"values"`
	} `json:"embeddings"`
}
||||
|
||||
func (c *googleClient) Embed(ctx context.Context, req Request) (Response, error) {
|
||||
requests := make([]googleEmbedRequest, len(req.Texts))
|
||||
for i, text := range req.Texts {
|
||||
requests[i] = googleEmbedRequest{
|
||||
Model: "models/" + c.model,
|
||||
Content: googleContent{Parts: []googlePart{{Text: text}}},
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(googleBatchRequest{Requests: requests})
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("google embed marshal: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/models/%s:batchEmbedContents", c.baseURL, c.model)
|
||||
if c.apiKey != "" {
|
||||
url += "?key=" + c.apiKey
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("google embed request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("google embed do: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return Response{}, fmt.Errorf("google embed: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var gResp googleBatchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&gResp); err != nil {
|
||||
return Response{}, fmt.Errorf("google embed decode: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([][]float32, len(gResp.Embeddings))
|
||||
for i, e := range gResp.Embeddings {
|
||||
embeddings[i] = e.Values
|
||||
}
|
||||
|
||||
return Response{
|
||||
Embeddings: embeddings,
|
||||
Model: c.model,
|
||||
}, nil
|
||||
}
|
||||
87
pkg/embedclient/openai.go
Normal file
87
pkg/embedclient/openai.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package embedclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type openAIClient struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewOpenAI returns a Client that speaks the OpenAI embeddings API.
|
||||
func NewOpenAI(baseURL, apiKey string, httpClient *http.Client) Client {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return &openAIClient{baseURL: baseURL, apiKey: apiKey, httpClient: httpClient}
|
||||
}
|
||||
|
||||
// openAIEmbedRequest is the JSON body for POST /v1/embeddings.
type openAIEmbedRequest struct {
	Input []string `json:"input"` // texts to embed
	Model string   `json:"model"` // model identifier
}

// openAIEmbedResponse mirrors the OpenAI embeddings response shape.
type openAIEmbedResponse struct {
	Object string `json:"object"`
	Data   []struct {
		Object    string    `json:"object"`
		Embedding []float32 `json:"embedding"`
		Index     int       `json:"index"` // position in the input batch
	} `json:"data"`
	Model string `json:"model"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`
}
||||
|
||||
func (c *openAIClient) Embed(ctx context.Context, req Request) (Response, error) {
|
||||
body, err := json.Marshal(openAIEmbedRequest{Input: req.Texts, Model: req.Model})
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed marshal: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/embeddings", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed do: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return Response{}, fmt.Errorf("openai embed: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var oaiResp openAIEmbedResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed decode: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([][]float32, len(oaiResp.Data))
|
||||
for _, d := range oaiResp.Data {
|
||||
embeddings[d.Index] = d.Embedding
|
||||
}
|
||||
|
||||
return Response{
|
||||
Embeddings: embeddings,
|
||||
Model: oaiResp.Model,
|
||||
Usage: Usage{
|
||||
PromptTokens: oaiResp.Usage.PromptTokens,
|
||||
TotalTokens: oaiResp.Usage.TotalTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
180
pkg/embedclient/router.go
Normal file
180
pkg/embedclient/router.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package embedclient
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/Warky-Devs/vecna.git/pkg/metrics"
)
|
||||
|
||||
// RouterConfig holds tuning parameters for a TargetRouter.
type RouterConfig struct {
	TargetName       string // label used on all router metrics
	TimeoutSecs      int    // per-request timeout applied around client.Embed
	CooldownSecs     int    // how long a failed endpoint is skipped by pick()
	PriorityDecay    int    // priority subtracted on each failure (floored at 1)
	PriorityRecovery int    // successes needed to regain one priority point
}
||||
|
||||
// endpointSlot is the router's per-endpoint state.
type endpointSlot struct {
	client          Client // upstream embedding client for this endpoint
	url             string // endpoint URL, used as a metrics label
	initialPriority int    // configured priority; recovery never exceeds it

	mu           sync.Mutex // guards the mutable fields below
	priority     int        // current dynamic priority
	inflight     int        // requests currently in flight
	successCount int        // running success count driving priority recovery
	lastFail     time.Time  // time of the most recent failure; zero if none
}
||||
|
||||
// TargetRouter implements Client by routing requests across multiple endpoint slots
// using a busyness-based priority algorithm.
type TargetRouter struct {
	slots   []*endpointSlot   // candidate endpoints; never empty (enforced by NewTargetRouter)
	cfg     RouterConfig      // tuning knobs (timeout, cooldown, decay, recovery)
	metrics *metrics.Registry // optional; nil disables metric updates
}
||||
|
||||
// NewTargetRouter constructs a TargetRouter from a slice of (client, url, initialPriority) tuples.
|
||||
func NewTargetRouter(slots []RouterSlot, cfg RouterConfig, reg *metrics.Registry) (*TargetRouter, error) {
|
||||
if len(slots) == 0 {
|
||||
return nil, fmt.Errorf("NewTargetRouter: at least one slot required")
|
||||
}
|
||||
es := make([]*endpointSlot, len(slots))
|
||||
for i, s := range slots {
|
||||
es[i] = &endpointSlot{
|
||||
client: s.Client,
|
||||
url: s.URL,
|
||||
initialPriority: s.Priority,
|
||||
priority: s.Priority,
|
||||
}
|
||||
if reg != nil {
|
||||
reg.SetEndpointPriority(cfg.TargetName, s.URL, float64(s.Priority))
|
||||
reg.SetEndpointInflight(cfg.TargetName, s.URL, 0)
|
||||
}
|
||||
}
|
||||
return &TargetRouter{slots: es, cfg: cfg, metrics: reg}, nil
|
||||
}
|
||||
|
||||
// RouterSlot is a single endpoint entry for NewTargetRouter.
type RouterSlot struct {
	Client   Client // client used to reach this endpoint
	URL      string // endpoint URL (used as a metrics label)
	Priority int    // initial routing priority; higher is preferred
}
||||
|
||||
func (r *TargetRouter) Embed(ctx context.Context, req Request) (Response, error) {
|
||||
slot := r.pick()
|
||||
|
||||
slot.mu.Lock()
|
||||
slot.inflight++
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
|
||||
}
|
||||
slot.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
slot.mu.Lock()
|
||||
slot.inflight--
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
|
||||
}
|
||||
slot.mu.Unlock()
|
||||
}()
|
||||
|
||||
timeout := time.Duration(r.cfg.TimeoutSecs) * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := slot.client.Embed(ctx, req)
|
||||
if err != nil {
|
||||
r.onFailure(slot, err)
|
||||
return Response{}, fmt.Errorf("router embed [%s]: %w", slot.url, err)
|
||||
}
|
||||
|
||||
r.onSuccess(slot)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// pick selects the best available slot.
|
||||
func (r *TargetRouter) pick() *endpointSlot {
|
||||
cooldown := time.Duration(r.cfg.CooldownSecs) * time.Second
|
||||
now := time.Now()
|
||||
|
||||
var best *endpointSlot
|
||||
bestScore := -1 << 30
|
||||
|
||||
for _, s := range r.slots {
|
||||
s.mu.Lock()
|
||||
inCooldown := !s.lastFail.IsZero() && now.Sub(s.lastFail) < cooldown
|
||||
score := s.priority - s.inflight
|
||||
lastFail := s.lastFail
|
||||
s.mu.Unlock()
|
||||
|
||||
if inCooldown {
|
||||
continue
|
||||
}
|
||||
if best == nil || score > bestScore {
|
||||
best = s
|
||||
bestScore = score
|
||||
_ = lastFail
|
||||
}
|
||||
}
|
||||
|
||||
// All in cooldown — fall back to oldest failure
|
||||
if best == nil {
|
||||
var oldest time.Time
|
||||
for _, s := range r.slots {
|
||||
s.mu.Lock()
|
||||
lf := s.lastFail
|
||||
s.mu.Unlock()
|
||||
if best == nil || lf.Before(oldest) {
|
||||
best = s
|
||||
oldest = lf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
func (r *TargetRouter) onSuccess(s *endpointSlot) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.successCount++
|
||||
if r.cfg.PriorityRecovery > 0 && s.successCount%r.cfg.PriorityRecovery == 0 {
|
||||
if s.priority < s.initialPriority {
|
||||
s.priority++
|
||||
}
|
||||
}
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *TargetRouter) onFailure(s *endpointSlot, err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.lastFail = time.Now()
|
||||
s.priority -= r.cfg.PriorityDecay
|
||||
if s.priority < 1 {
|
||||
s.priority = 1
|
||||
}
|
||||
|
||||
errType := "error"
|
||||
if ctx := context.Background(); ctx.Err() != nil {
|
||||
errType = "timeout"
|
||||
}
|
||||
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
|
||||
r.metrics.IncEndpointErrors(r.cfg.TargetName, s.url, errType)
|
||||
}
|
||||
|
||||
_ = err
|
||||
}
|
||||
112
pkg/metrics/metrics.go
Normal file
112
pkg/metrics/metrics.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// Registry holds all vecna Prometheus metrics on a dedicated (non-global) registry.
type Registry struct {
	reg *prometheus.Registry // backing registry; exposed via Prometheus()

	RequestsTotal       *prometheus.CounterVec   // requests by endpoint and HTTP status
	RequestDuration     *prometheus.HistogramVec // total wall-clock time per endpoint
	ForwardDuration     *prometheus.HistogramVec // time waiting on the backing model, per target/url
	TranslateDuration   *prometheus.HistogramVec // time in the dimension adapter, per adapter type
	EndpointPriority    *prometheus.GaugeVec     // dynamic routing priority per endpoint
	EndpointInflight    *prometheus.GaugeVec     // active requests per endpoint
	EndpointErrorsTotal *prometheus.CounterVec   // forwarding errors per endpoint and error type
	TokensTotal         *prometheus.CounterVec   // token usage by target/model/token type
}
||||
|
||||
// New creates and registers all metrics on a fresh Prometheus registry.
// Using a dedicated registry (rather than the package-global default)
// keeps vecna's metrics isolated from any other instrumented code.
func New() *Registry {
	reg := prometheus.NewRegistry()

	r := &Registry{
		reg: reg,

		RequestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_requests_total",
			Help: "Total number of requests served by vecna.",
		}, []string{"endpoint", "status"}),

		RequestDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_request_duration_seconds",
			Help:    "Total request wall-clock time.",
			Buckets: prometheus.DefBuckets,
		}, []string{"endpoint"}),

		ForwardDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_forward_duration_seconds",
			Help:    "Time spent waiting on the backing embedding model.",
			Buckets: prometheus.DefBuckets,
		}, []string{"target", "url"}),

		// Translation is fast, so sub-millisecond buckets are used instead
		// of the default (5ms-and-up) buckets.
		TranslateDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Name:    "vecna_translate_duration_seconds",
			Help:    "Time spent in the dimension adapter.",
			Buckets: []float64{0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05},
		}, []string{"adapter_type"}),

		EndpointPriority: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Name: "vecna_endpoint_priority",
			Help: "Current dynamic routing priority for a forwarding endpoint.",
		}, []string{"target", "url"}),

		EndpointInflight: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Name: "vecna_endpoint_inflight",
			Help: "Number of active in-flight requests per forwarding endpoint.",
		}, []string{"target", "url"}),

		EndpointErrorsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_endpoint_errors_total",
			Help: "Total forwarding errors per endpoint, labelled by error type.",
		}, []string{"target", "url", "error"}),

		TokensTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "vecna_tokens_total",
			Help: "Tokens consumed by the backing embedding model, by target, model, and token type.",
		}, []string{"target", "model", "token_type"}),
	}

	// MustRegister panics on duplicate registration; that cannot happen
	// here because the registry is freshly created above.
	reg.MustRegister(
		r.RequestsTotal,
		r.RequestDuration,
		r.ForwardDuration,
		r.TranslateDuration,
		r.EndpointPriority,
		r.EndpointInflight,
		r.EndpointErrorsTotal,
		r.TokensTotal,
	)

	return r
}
||||
|
||||
// Prometheus returns the underlying registry for use with promhttp.HandlerFor.
// Callers must not register additional collectors that would collide with
// the vecna_* names set up in New.
func (r *Registry) Prometheus() *prometheus.Registry {
	return r.reg
}
|
||||
|
||||
// Convenience setters used by the router.

// SetEndpointPriority records the current routing priority for an endpoint.
func (r *Registry) SetEndpointPriority(target, url string, v float64) {
	r.EndpointPriority.WithLabelValues(target, url).Set(v)
}

// SetEndpointInflight records the number of in-flight requests on an endpoint.
func (r *Registry) SetEndpointInflight(target, url string, v float64) {
	r.EndpointInflight.WithLabelValues(target, url).Set(v)
}

// IncEndpointErrors increments the error counter for an endpoint/error-type pair.
func (r *Registry) IncEndpointErrors(target, url, errType string) {
	r.EndpointErrorsTotal.WithLabelValues(target, url, errType).Inc()
}
||||
|
||||
func (r *Registry) AddTokens(target, model string, promptTokens, totalTokens int) {
|
||||
if promptTokens > 0 {
|
||||
r.TokensTotal.WithLabelValues(target, model, "prompt").Add(float64(promptTokens))
|
||||
}
|
||||
if totalTokens > 0 {
|
||||
r.TokensTotal.WithLabelValues(target, model, "total").Add(float64(totalTokens))
|
||||
}
|
||||
}
|
||||
62
pkg/metrics/middleware.go
Normal file
62
pkg/metrics/middleware.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/uptrace/bunrouter"
|
||||
)
|
||||
|
||||
// Middleware returns a bunrouter middleware that records per-request Prometheus metrics.
// It reads timing from the RequestTrace stored in the context (set by server/trace.go).
// The trace target/url labels are optional; pass empty strings if not applicable.
func (r *Registry) Middleware(getTrace func(req bunrouter.Request) TraceSnapshot) bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			// Wrap the writer so the status code written downstream can be
			// recovered for the requests counter; default to 200.
			rw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			err := next(rw, req)

			// Snapshot the trace only after the handler returns so all
			// timings and usage counts are final.
			snap := getTrace(req)
			// NOTE(review): the raw URL path is used as a label; confirm the
			// route set is bounded, or cardinality will grow with paths.
			endpoint := req.URL.Path
			status := fmt.Sprintf("%d", rw.status)

			r.RequestsTotal.WithLabelValues(endpoint, status).Inc()
			r.RequestDuration.WithLabelValues(endpoint).Observe(snap.TotalSeconds)
			// Forward/translate observations exist only when the handler
			// actually forwarded / translated (fields left empty otherwise).
			if snap.ForwardTarget != "" {
				r.ForwardDuration.WithLabelValues(snap.ForwardTarget, snap.ForwardURL).Observe(snap.ForwardSeconds)
			}
			if snap.AdapterType != "" {
				r.TranslateDuration.WithLabelValues(snap.AdapterType).Observe(snap.TranslateSeconds)
			}
			if snap.PromptTokens > 0 || snap.TotalTokens > 0 {
				r.AddTokens(snap.ForwardTarget, snap.ForwardModel, snap.PromptTokens, snap.TotalTokens)
			}

			return err
		}
	}
}
||||
|
||||
// TraceSnapshot carries the timing and usage values the metrics middleware needs.
type TraceSnapshot struct {
	TotalSeconds     float64 // total request wall-clock time
	ForwardSeconds   float64 // time spent waiting on the backing model
	TranslateSeconds float64 // time spent in the dimension adapter
	ForwardTarget    string  // target name; empty when nothing was forwarded
	ForwardURL       string  // endpoint URL the request was forwarded to
	ForwardModel     string  // model reported by the backend
	AdapterType      string  // adapter label; empty when no translation ran
	PromptTokens     int     // prompt tokens reported by the backend
	TotalTokens      int     // total tokens reported by the backend
}
||||
|
||||
// statusWriter wraps http.ResponseWriter to capture the written status code.
// NOTE(review): it does not forward optional interfaces such as
// http.Flusher or http.Hijacker; confirm no handler relies on them.
type statusWriter struct {
	http.ResponseWriter
	status int // last status written; initialised to 200 at construction
}

// WriteHeader records code before delegating to the wrapped writer.
func (sw *statusWriter) WriteHeader(code int) {
	sw.status = code
	sw.ResponseWriter.WriteHeader(code)
}
|
||||
135
pkg/server/google.go
Normal file
135
pkg/server/google.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bunrouter"
|
||||
|
||||
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||
)
|
||||
|
||||
// --- single embedContent ---

// googleEmbedContentRequest is the inbound body for :embedContent.
type googleEmbedContentRequest struct {
	Content  googleContent `json:"content"`            // text parts to embed
	TaskType string        `json:"taskType,omitempty"` // accepted but not read by the handlers
}

// googleContent wraps the text parts of one content item.
type googleContent struct {
	Parts []googlePart `json:"parts"`
}

// googlePart is a single text fragment.
type googlePart struct {
	Text string `json:"text"`
}

// googleEmbedContentResponse is the single-embedding response shape.
type googleEmbedContentResponse struct {
	Embedding googleEmbeddingValues `json:"embedding"`
}

// googleEmbeddingValues carries one embedding vector.
type googleEmbeddingValues struct {
	Values []float32 `json:"values"`
}
||||
|
||||
// googleEmbedContent handles POST /v1/models/:model:embedContent — the
// single-content Gemini-compatible embedding endpoint.
//
// Flow: decode body → forward all text parts to the resolved backing client
// → adapt the first returned vector → reply in Google's
// {"embedding":{"values":[...]}} shape.
func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")

	var body googleEmbedContentRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}

	// Collect every part's text into the upstream batch.
	texts := make([]string, len(body.Content.Parts))
	for i, p := range body.Content.Parts {
		texts[i] = p.Text
	}

	// Record routing info on the request trace for metrics/headers.
	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL

	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens

	// Only the first embedding is returned, matching embedContent's
	// single-embedding response shape.
	// NOTE(review): when the backend returns no embeddings, `adapted` stays
	// nil and "values" encodes as JSON null rather than [] — confirm
	// clients tolerate this.
	t1 := time.Now()
	var adapted []float32
	if len(embedResp.Embeddings) > 0 {
		adapted, err = h.adapter.Adapt(embedResp.Embeddings[0])
		if err != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
	}
	trace.TranslateDuration = time.Since(t1)

	writeTraceHeaders(w, trace)

	return writeJSON(w, http.StatusOK, googleEmbedContentResponse{
		Embedding: googleEmbeddingValues{Values: adapted},
	})
}
||||
|
||||
// --- batch batchEmbedContents ---

// googleBatchRequest is the inbound batch body.
type googleBatchRequest struct {
	Requests []googleEmbedContentRequest `json:"requests"`
}

// googleBatchResponse is the batch response we emit: one values-vector
// per embedded part.
type googleBatchResponse struct {
	Embeddings []googleEmbeddingValues `json:"embeddings"`
}
||||
|
||||
// googleBatchEmbedContents handles POST /v1/models/:model:batchEmbedContents.
//
// All text parts from all requests are flattened into one upstream batch,
// and one embedding is emitted per part.
// NOTE(review): Google's batchEmbedContents returns one embedding per
// request; flattening multi-part requests changes that alignment — confirm
// callers only send single-part requests.
func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")

	var body googleBatchRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}

	// Flatten every part of every request into the upstream batch.
	var texts []string
	for _, r := range body.Requests {
		for _, p := range r.Content.Parts {
			texts = append(texts, p.Text)
		}
	}

	// Record routing info on the request trace for metrics/headers.
	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL

	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens

	// Translate every returned vector through the dimension adapter.
	t1 := time.Now()
	result := make([]googleEmbeddingValues, len(embedResp.Embeddings))
	for i, vec := range embedResp.Embeddings {
		adapted, adaptErr := h.adapter.Adapt(vec)
		if adaptErr != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
		}
		result[i] = googleEmbeddingValues{Values: adapted}
	}
	trace.TranslateDuration = time.Since(t1)

	writeTraceHeaders(w, trace)

	return writeJSON(w, http.StatusOK, googleBatchResponse{Embeddings: result})
}
||||
74
pkg/server/handler.go
Normal file
74
pkg/server/handler.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Warky-Devs/vecna.git/pkg/adapter"
|
||||
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||
)
|
||||
|
||||
// handler holds shared dependencies for all HTTP handlers.
type handler struct {
	cfg     *config.Config                // full service configuration
	clients map[string]embedclient.Client // embed clients keyed by model/target name
	adapter adapter.Adapter               // dimension adapter applied to every vector
	logger  *zap.Logger                   // structured logger
}
||||
|
||||
// resolveClient selects the embed client for the given model name.
|
||||
// Returns the client, target name, and first endpoint URL for tracing.
|
||||
func (h *handler) resolveClient(model string) (embedclient.Client, string, string) {
|
||||
if c, ok := h.clients[model]; ok {
|
||||
url := firstEndpointURL(h.cfg, model)
|
||||
return c, model, url
|
||||
}
|
||||
name := h.cfg.Forward.Default
|
||||
c, ok := h.clients[name]
|
||||
if !ok {
|
||||
// No configured client — return a nil-safe error client
|
||||
return &errClient{err: fmt.Errorf("no client configured for model %q and no default", model)}, name, ""
|
||||
}
|
||||
return c, name, firstEndpointURL(h.cfg, name)
|
||||
}
|
||||
|
||||
func firstEndpointURL(cfg *config.Config, targetName string) string {
|
||||
t, ok := cfg.Forward.Targets[targetName]
|
||||
if !ok || len(t.Endpoints) == 0 {
|
||||
return ""
|
||||
}
|
||||
return t.Endpoints[0].URL
|
||||
}
|
||||
|
||||
// writeJSON encodes v as JSON and writes it with the given status code.
// The Content-Type header and status are committed before encoding, so an
// encode failure can only surface as an error return, not a status change.
func writeJSON(w http.ResponseWriter, status int, v interface{}) error {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	err := json.NewEncoder(w).Encode(v)
	if err != nil {
		err = fmt.Errorf("writeJSON: %w", err)
	}
	return err
}
||||
|
||||
// writeTraceHeaders writes X-Vecna-* timing headers from the RequestTrace.
|
||||
func writeTraceHeaders(w http.ResponseWriter, t *RequestTrace) {
|
||||
total := time.Since(t.Start)
|
||||
w.Header().Set("X-Vecna-Forward-Ms", fmt.Sprintf("%d", t.ForwardDuration.Milliseconds()))
|
||||
w.Header().Set("X-Vecna-Translate-Ms", fmt.Sprintf("%d", t.TranslateDuration.Milliseconds()))
|
||||
w.Header().Set("X-Vecna-Total-Ms", fmt.Sprintf("%d", total.Milliseconds()))
|
||||
}
|
||||
|
||||
// errClient is a Client that always returns a fixed error (used as safe fallback).
type errClient struct {
	err error // returned unchanged from every Embed call
}

// Embed implements embedclient.Client by immediately returning the stored error.
func (e *errClient) Embed(_ context.Context, _ embedclient.Request) (embedclient.Response, error) {
	return embedclient.Response{}, e.err
}
||||
102
pkg/server/openai.go
Normal file
102
pkg/server/openai.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bunrouter"
|
||||
|
||||
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||
)
|
||||
|
||||
// openAIEmbedRequest is the inbound body for POST /v1/embeddings.
type openAIEmbedRequest struct {
	Input interface{} `json:"input"` // string or []string
	Model string      `json:"model"`
}

// openAIEmbedResponse is the OpenAI-shaped response we emit.
type openAIEmbedResponse struct {
	Object string             `json:"object"` // always "list"
	Data   []openAIEmbedDatum `json:"data"`
	Model  string             `json:"model"`
	Usage  openAIUsage        `json:"usage"`
}

// openAIEmbedDatum is one embedding in the response.
type openAIEmbedDatum struct {
	Object    string    `json:"object"` // always "embedding"
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"` // position in the input batch
}

// openAIUsage reports token consumption passed through from the backend.
type openAIUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
|
||||
// openAIEmbeddings handles POST /v1/embeddings (OpenAI-compatible).
//
// Flow: decode body → normalise input to []string → forward to the resolved
// backing client → adapt every vector to the target dimension → reply in
// OpenAI's list shape with usage passed through.
func (h *handler) openAIEmbeddings(w http.ResponseWriter, req bunrouter.Request) error {
	var body openAIEmbedRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}

	// Input may be a bare string or an array of strings.
	texts, err := toStringSlice(body.Input)
	if err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()})
	}

	// Record routing info on the request trace for metrics/headers.
	client, targetName, targetURL := h.resolveClient(body.Model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL

	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: body.Model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens

	// Translate every returned vector through the dimension adapter.
	t1 := time.Now()
	data := make([]openAIEmbedDatum, len(embedResp.Embeddings))
	for i, vec := range embedResp.Embeddings {
		adapted, adaptErr := h.adapter.Adapt(vec)
		if adaptErr != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
		}
		data[i] = openAIEmbedDatum{Object: "embedding", Embedding: adapted, Index: i}
	}
	trace.TranslateDuration = time.Since(t1)

	writeTraceHeaders(w, trace)

	return writeJSON(w, http.StatusOK, openAIEmbedResponse{
		Object: "list",
		Data:   data,
		Model:  embedResp.Model,
		Usage:  openAIUsage{PromptTokens: embedResp.Usage.PromptTokens, TotalTokens: embedResp.Usage.TotalTokens},
	})
}
||||
|
||||
// toStringSlice accepts a JSON string or array of strings.
// A bare string becomes a one-element slice; any other JSON type, or an
// array containing a non-string element, yields an error.
func toStringSlice(v interface{}) ([]string, error) {
	if s, ok := v.(string); ok {
		return []string{s}, nil
	}
	arr, ok := v.([]interface{})
	if !ok {
		return nil, fmt.Errorf("input must be a string or array of strings")
	}
	out := make([]string, 0, len(arr))
	for i, item := range arr {
		s, ok := item.(string)
		if !ok {
			return nil, fmt.Errorf("input array element %d is not a string", i)
		}
		out = append(out, s)
	}
	return out, nil
}
||||
163
pkg/server/server.go
Normal file
163
pkg/server/server.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/uptrace/bunrouter"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Warky-Devs/vecna.git/pkg/adapter"
|
||||
"github.com/Warky-Devs/vecna.git/pkg/config"
|
||||
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
|
||||
"github.com/Warky-Devs/vecna.git/pkg/metrics"
|
||||
"github.com/Warky-Devs/vecna.git/pkg/server/spec"
|
||||
)
|
||||
|
||||
// New builds and returns a configured bunrouter.Router.
//
// Middleware is registered in order: auth, trace, metrics, logging.
// Routes: the OpenAI-compatible embeddings endpoint, both Gemini-style
// embedding endpoints, OpenAPI spec/docs, and (optionally) Prometheus
// metrics.
func New(
	cfg *config.Config,
	clients map[string]embedclient.Client,
	adp adapter.Adapter,
	reg *metrics.Registry,
	logger *zap.Logger,
) *bunrouter.Router {
	router := bunrouter.New(
		bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)),
		bunrouter.WithMiddleware(traceMiddleware()),
		bunrouter.WithMiddleware(metricsMiddleware(reg, adp)),
		bunrouter.WithMiddleware(loggingMiddleware(logger)),
	)

	h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger}

	// Embedding endpoints: OpenAI-compatible plus the two Gemini shapes.
	router.POST("/v1/embeddings", h.openAIEmbeddings)
	router.POST("/v1/models/:model:embedContent", h.googleEmbedContent)
	router.POST("/v1/models/:model:batchEmbedContents", h.googleBatchEmbedContents)

	// OpenAPI spec + docs
	router.GET("/openapi.yaml", spec.SpecHandler())
	router.GET("/docs", spec.DocsHandler())

	// Metrics — only when enabled
	if cfg.Metrics.Enabled {
		metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{})
		path := cfg.Metrics.Path
		if path == "" {
			path = "/metrics" // default scrape path
		}
		// NOTE(review): this route is registered on the same router, so the
		// global authMiddleware applies to it in addition to the optional
		// metrics API key — confirm that is intended.
		if cfg.Metrics.APIKey != "" {
			router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler))
		} else {
			router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error {
				metricsHandler.ServeHTTP(w, req.Request)
				return nil
			})
		}
	}

	return router
}
||||
|
||||
// authMiddleware rejects requests without a valid Bearer token when api_keys is configured.
|
||||
func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc {
|
||||
if len(apiKeys) == 0 {
|
||||
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
|
||||
}
|
||||
keySet := make(map[string]struct{}, len(apiKeys))
|
||||
for _, k := range apiKeys {
|
||||
keySet[k] = struct{}{}
|
||||
}
|
||||
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
if _, ok := keySet[token]; !ok {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return nil
|
||||
}
|
||||
return next(w, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// traceMiddleware injects a *RequestTrace into every request context.
|
||||
func traceMiddleware() bunrouter.MiddlewareFunc {
|
||||
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
ctx := WithTrace(req.Context())
|
||||
return next(w, req.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// metricsMiddleware records Prometheus observations after the handler returns.
|
||||
func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc {
|
||||
if reg == nil {
|
||||
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
|
||||
}
|
||||
adpType := fmt.Sprintf("%T", adp)
|
||||
return reg.Middleware(func(req bunrouter.Request) metrics.TraceSnapshot {
|
||||
t := TraceFromContext(req.Context())
|
||||
total := time.Since(t.Start)
|
||||
return metrics.TraceSnapshot{
|
||||
TotalSeconds: total.Seconds(),
|
||||
ForwardSeconds: t.ForwardDuration.Seconds(),
|
||||
TranslateSeconds: t.TranslateDuration.Seconds(),
|
||||
ForwardTarget: t.ForwardTarget,
|
||||
ForwardURL: t.ForwardURL,
|
||||
ForwardModel: t.ForwardModel,
|
||||
AdapterType: adpType,
|
||||
PromptTokens: t.PromptTokens,
|
||||
TotalTokens: t.TotalTokens,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// loggingMiddleware logs method, path, status, and timing via zap.
|
||||
func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc {
|
||||
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
err := next(sw, req)
|
||||
t := TraceFromContext(req.Context())
|
||||
total := time.Since(t.Start)
|
||||
|
||||
logger.Info("request",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("path", req.URL.Path),
|
||||
zap.Int("status", sw.status),
|
||||
zap.Int64("total_ms", total.Milliseconds()),
|
||||
zap.Int64("forward_ms", t.ForwardDuration.Milliseconds()),
|
||||
zap.Int64("translate_ms", t.TranslateDuration.Milliseconds()),
|
||||
)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// metricsAuthHandler wraps a standard http.Handler with Bearer token auth.
|
||||
func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
if token != apiKey {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return nil
|
||||
}
|
||||
h.ServeHTTP(w, req.Request)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// statusWriter captures the HTTP status code written by a handler.
// It embeds the real http.ResponseWriter and only intercepts WriteHeader,
// so all other writer behavior is delegated unchanged.
type statusWriter struct {
	http.ResponseWriter
	// status records the last code passed to WriteHeader. The creator
	// seeds it (loggingMiddleware uses http.StatusOK) so it is meaningful
	// even when the handler never calls WriteHeader explicitly.
	status int
}
|
||||
|
||||
// WriteHeader records the status code, then delegates to the wrapped
// ResponseWriter so the response is actually sent.
func (sw *statusWriter) WriteHeader(code int) {
	sw.status = code
	sw.ResponseWriter.WriteHeader(code)
}
|
||||
36
pkg/server/spec/handler.go
Normal file
36
pkg/server/spec/handler.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package spec
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net/http"
|
||||
|
||||
"github.com/uptrace/bunrouter"
|
||||
)
|
||||
|
||||
// openapiYAML holds the OpenAPI spec embedded at build time; it is served
// verbatim by SpecHandler.
//go:embed openapi.yaml
var openapiYAML []byte
|
||||
|
||||
// SpecHandler serves the raw OpenAPI YAML spec.
|
||||
func SpecHandler() bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
w.Header().Set("Content-Type", "application/yaml")
|
||||
_, err := w.Write(openapiYAML)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// DocsHandler serves the Scalar API reference UI.
|
||||
func DocsHandler() bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, err := w.Write([]byte(`<!doctype html>
|
||||
<html>
|
||||
<head><title>vecna API</title><meta charset="utf-8"/></head>
|
||||
<body>
|
||||
<script id="api-reference" data-url="/openapi.yaml"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"></script>
|
||||
</body>
|
||||
</html>`))
|
||||
return err
|
||||
}
|
||||
}
|
||||
252
pkg/server/spec/openapi.yaml
Normal file
252
pkg/server/spec/openapi.yaml
Normal file
@@ -0,0 +1,252 @@
|
||||
openapi: "3.1.0"
|
||||
info:
|
||||
title: vecna Embedding Adapter
|
||||
description: Proxies text to a backing embedding model and adapts the result vectors between dimensions.
|
||||
version: "1.0.0"
|
||||
|
||||
servers:
|
||||
- url: http://localhost:8080
|
||||
|
||||
security:
|
||||
- BearerAuth: []
|
||||
|
||||
components:
|
||||
securitySchemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
scheme: bearer
|
||||
|
||||
schemas:
|
||||
Error:
|
||||
type: object
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
|
||||
OpenAIEmbedRequest:
|
||||
type: object
|
||||
required: [input, model]
|
||||
properties:
|
||||
input:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
|
||||
OpenAIEmbedResponse:
|
||||
type: object
|
||||
properties:
|
||||
object:
|
||||
type: string
|
||||
example: list
|
||||
model:
|
||||
type: string
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
object:
|
||||
type: string
|
||||
example: embedding
|
||||
index:
|
||||
type: integer
|
||||
embedding:
|
||||
type: array
|
||||
items:
|
||||
type: number
|
||||
format: float
|
||||
usage:
|
||||
type: object
|
||||
properties:
|
||||
prompt_tokens:
|
||||
type: integer
|
||||
total_tokens:
|
||||
type: integer
|
||||
|
||||
GoogleEmbedContentRequest:
|
||||
type: object
|
||||
required: [content]
|
||||
properties:
|
||||
content:
|
||||
type: object
|
||||
properties:
|
||||
parts:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
text:
|
||||
type: string
|
||||
taskType:
|
||||
type: string
|
||||
|
||||
GoogleEmbedContentResponse:
|
||||
type: object
|
||||
properties:
|
||||
embedding:
|
||||
type: object
|
||||
properties:
|
||||
values:
|
||||
type: array
|
||||
items:
|
||||
type: number
|
||||
format: float
|
||||
|
||||
GoogleBatchRequest:
|
||||
type: object
|
||||
required: [requests]
|
||||
properties:
|
||||
requests:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/GoogleEmbedContentRequest'
|
||||
|
||||
GoogleBatchResponse:
|
||||
type: object
|
||||
properties:
|
||||
embeddings:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
values:
|
||||
type: array
|
||||
items:
|
||||
type: number
|
||||
format: float
|
||||
|
||||
headers:
|
||||
X-Vecna-Forward-Ms:
|
||||
description: Time spent forwarding the request to the backing model (milliseconds).
|
||||
schema:
|
||||
type: integer
|
||||
X-Vecna-Translate-Ms:
|
||||
description: Time spent in the dimension adapter (milliseconds).
|
||||
schema:
|
||||
type: integer
|
||||
X-Vecna-Total-Ms:
|
||||
description: Total request wall-clock time (milliseconds).
|
||||
schema:
|
||||
type: integer
|
||||
|
||||
paths:
|
||||
/v1/embeddings:
|
||||
post:
|
||||
summary: OpenAI-compatible embeddings
|
||||
operationId: openaiEmbeddings
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAIEmbedRequest'
|
||||
responses:
|
||||
"200":
|
||||
description: Adapted embeddings
|
||||
headers:
|
||||
X-Vecna-Forward-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Forward-Ms'
|
||||
X-Vecna-Translate-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Translate-Ms'
|
||||
X-Vecna-Total-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Total-Ms'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAIEmbedResponse'
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Error'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
"502":
|
||||
description: Backing model error
|
||||
|
||||
/v1/models/{model}:embedContent:
|
||||
post:
|
||||
summary: Google-compatible single embedContent
|
||||
operationId: googleEmbedContent
|
||||
parameters:
|
||||
- name: model
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/GoogleEmbedContentRequest'
|
||||
responses:
|
||||
"200":
|
||||
description: Adapted embedding
|
||||
headers:
|
||||
X-Vecna-Forward-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Forward-Ms'
|
||||
X-Vecna-Translate-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Translate-Ms'
|
||||
X-Vecna-Total-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Total-Ms'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/GoogleEmbedContentResponse'
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Error'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
"502":
|
||||
description: Backing model error
|
||||
|
||||
/v1/models/{model}:batchEmbedContents:
|
||||
post:
|
||||
summary: Google-compatible batchEmbedContents
|
||||
operationId: googleBatchEmbedContents
|
||||
parameters:
|
||||
- name: model
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/GoogleBatchRequest'
|
||||
responses:
|
||||
"200":
|
||||
description: Adapted embeddings
|
||||
headers:
|
||||
X-Vecna-Forward-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Forward-Ms'
|
||||
X-Vecna-Translate-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Translate-Ms'
|
||||
X-Vecna-Total-Ms:
|
||||
$ref: '#/components/headers/X-Vecna-Total-Ms'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/GoogleBatchResponse'
|
||||
"400":
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Error'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
"502":
|
||||
description: Backing model error
|
||||
37
pkg/server/trace.go
Normal file
37
pkg/server/trace.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// contextKey is a private key type so context values stored by this
// package can never collide with keys from other packages.
type contextKey int

// traceKey indexes the *RequestTrace stored on a request context.
const traceKey contextKey = 0
|
||||
|
||||
// RequestTrace holds per-request timing data populated by handlers and middleware.
type RequestTrace struct {
	Start             time.Time     // when the trace was created (request entry)
	ForwardDuration   time.Duration // time spent forwarding to the backing model
	TranslateDuration time.Duration // time spent in the dimension adapter
	ForwardTarget     string        // name of the forward target — presumably a configured client key; confirm against handlers
	ForwardURL        string        // URL the request was forwarded to
	ForwardModel      string        // model name used for the forwarded request
	AdapterType       string        // adapter type label (set per-request; snapshot uses %T of the adapter)
	PromptTokens      int           // prompt token count from the backing response
	TotalTokens       int           // total token count from the backing response
}
|
||||
|
||||
// WithTrace injects a new *RequestTrace into ctx.
|
||||
func WithTrace(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, traceKey, &RequestTrace{Start: time.Now()})
|
||||
}
|
||||
|
||||
// TraceFromContext retrieves the *RequestTrace from ctx.
|
||||
// Returns a zero-value trace (non-nil) if none was set.
|
||||
func TraceFromContext(ctx context.Context) *RequestTrace {
|
||||
if t, ok := ctx.Value(traceKey).(*RequestTrace); ok && t != nil {
|
||||
return t
|
||||
}
|
||||
return &RequestTrace{Start: time.Now()}
|
||||
}
|
||||
Reference in New Issue
Block a user