mirror of
https://github.com/Warky-Devs/vecna.git
synced 2026-05-05 01:26:58 +00:00
feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768 , 3072 <-> 2048
This commit is contained in:
77
pkg/adapter/adapter.go
Normal file
77
pkg/adapter/adapter.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Sentinel errors for this package. The constructors and adapters wrap them
// with %w, so callers should match them with errors.Is.
var (
	// ErrDimMismatch reports that an input vector's length differs from the
	// adapter's source dimension.
	ErrDimMismatch = errors.New("vector dimension mismatch")

	// ErrInvalidDim reports a source or target dimension that is not valid.
	ErrInvalidDim = errors.New("invalid dimension")

	// ErrInvalidMatrix reports a projection matrix with the wrong shape.
	ErrInvalidMatrix = errors.New("invalid projection matrix shape")
)
|
||||
|
||||
// TruncateMode selects which end of a vector is discarded when the target
// dimension is smaller than the source dimension.
type TruncateMode int

const (
	// TruncateFromEnd drops the tail and keeps the first targetDim elements.
	// It is the zero value, and the correct choice for Matryoshka models.
	TruncateFromEnd TruncateMode = iota

	// TruncateFromStart drops the head and keeps the last targetDim elements.
	TruncateFromStart
)
|
||||
|
||||
// PadMode selects which end of a vector is filled with zeros when the target
// dimension is larger than the source dimension.
type PadMode int

const (
	// PadAtEnd places the zeros after the original elements. It is the zero value.
	PadAtEnd PadMode = iota

	// PadAtStart places the zeros before the original elements.
	PadAtStart
)
|
||||
|
||||
// Adapter converts vectors from one fixed dimension to another.
type Adapter interface {
	// SourceDim reports the vector length Adapt accepts.
	SourceDim() int

	// TargetDim reports the vector length Adapt produces.
	TargetDim() int

	// Adapt converts vec into a vector of the target dimension. The adapters
	// in this package return an error wrapping ErrDimMismatch when len(vec)
	// differs from SourceDim, and L2-normalize the result.
	Adapt(vec []float32) ([]float32, error)
}
|
||||
|
||||
// NewTruncate returns a TruncateAdapter for Matryoshka-style or simple truncation/padding.
|
||||
// t controls which end is dropped when downscaling; p controls which end is padded when upscaling.
|
||||
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
|
||||
if sourceDim <= 0 || targetDim <= 0 {
|
||||
return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||
}
|
||||
return &truncateAdapter{sourceDim: sourceDim, targetDim: targetDim, truncMode: t, padMode: p}, nil
|
||||
}
|
||||
|
||||
// NewRandom returns a RandomAdapter backed by a seeded Gaussian projection matrix.
|
||||
// seed=0 uses a time-based seed.
|
||||
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
|
||||
if sourceDim <= 0 || targetDim <= 0 {
|
||||
return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||
}
|
||||
return newRandomAdapter(sourceDim, targetDim, seed), nil
|
||||
}
|
||||
|
||||
// NewProjection returns a ProjectionAdapter using a caller-supplied matrix.
|
||||
// matrix must have shape [targetDim][sourceDim].
|
||||
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
|
||||
if sourceDim <= 0 || targetDim <= 0 {
|
||||
return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
||||
}
|
||||
if len(matrix) != targetDim {
|
||||
return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, len(matrix), targetDim)
|
||||
}
|
||||
for i, row := range matrix {
|
||||
if len(row) != sourceDim {
|
||||
return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, len(row), sourceDim)
|
||||
}
|
||||
}
|
||||
return &projectionAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}, nil
|
||||
}
|
||||
236
pkg/adapter/adapter_test.go
Normal file
236
pkg/adapter/adapter_test.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// unitVec builds a length-n vector whose entries all equal 1/sqrt(n), so its
// L2 norm is 1.
func unitVec(n int) []float32 {
	vec := make([]float32, n)
	component := float32(1.0 / math.Sqrt(float64(n)))
	for i := 0; i < n; i++ {
		vec[i] = component
	}
	return vec
}
|
||||
|
||||
func assertNormOne(t *testing.T, v []float32) {
|
||||
t.Helper()
|
||||
var sum float64
|
||||
for _, x := range v {
|
||||
sum += float64(x) * float64(x)
|
||||
}
|
||||
assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
|
||||
}
|
||||
|
||||
// --- L2Norm ---
|
||||
|
||||
func TestL2Norm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []float32
|
||||
wantLen int
|
||||
wantNorm float64
|
||||
}{
|
||||
{"unit vector unchanged", []float32{1, 0, 0}, 3, 1.0},
|
||||
{"scale down to unit", []float32{2, 0, 0}, 3, 1.0},
|
||||
{"multi-dim", unitVec(4), 4, 1.0},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := L2Norm(tc.input)
|
||||
assert.Len(t, got, tc.wantLen)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("zero vector returns copy unchanged", func(t *testing.T) {
|
||||
in := []float32{0, 0, 0}
|
||||
got := L2Norm(in)
|
||||
assert.Equal(t, in, got)
|
||||
})
|
||||
}
|
||||
|
||||
// --- TruncateAdapter ---
|
||||
|
||||
func TestNewTruncate_InvalidDims(t *testing.T) {
|
||||
_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
|
||||
_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
}
|
||||
|
||||
func TestTruncateAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src, tgt int
|
||||
truncMode TruncateMode
|
||||
padMode PadMode
|
||||
// builder for input vec; checks are len + norm only unless wantFirst/wantLast set
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||
{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768, false},
|
||||
{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536, false},
|
||||
{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536, false},
|
||||
{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.src, a.SourceDim())
|
||||
assert.Equal(t, tc.tgt, a.TargetDim())
|
||||
|
||||
got, err := a.Adapt(unitVec(tc.src))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, tc.wantLen)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||
a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
_, err = a.Adapt(make([]float32, 100))
|
||||
require.ErrorIs(t, err, ErrDimMismatch)
|
||||
})
|
||||
|
||||
t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
|
||||
a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
// input: [1,2,3,4] — TruncateFromEnd keeps [1,2]
|
||||
in := []float32{1, 2, 3, 4}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
// after L2Norm the signs are preserved; ratio should match
|
||||
assert.Greater(t, got[0], float32(0))
|
||||
assert.Greater(t, got[1], float32(0))
|
||||
})
|
||||
|
||||
t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
|
||||
a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
// input: [0,0,3,4] — TruncateFromStart keeps [3,4]
|
||||
in := []float32{0, 0, 3, 4}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, got[0], float32(0))
|
||||
assert.Greater(t, got[1], float32(0))
|
||||
})
|
||||
|
||||
t.Run("PadAtStart zero-pads front", func(t *testing.T) {
|
||||
a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
|
||||
require.NoError(t, err)
|
||||
in := []float32{1, 0}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
// first two positions should be zero (padded), last two carry the signal
|
||||
assert.Equal(t, float32(0), got[0])
|
||||
assert.Equal(t, float32(0), got[1])
|
||||
})
|
||||
}
|
||||
|
||||
// --- RandomAdapter ---
|
||||
|
||||
func TestNewRandom_InvalidDims(t *testing.T) {
|
||||
_, err := NewRandom(0, 768, 42)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
}
|
||||
|
||||
func TestRandomAdapter(t *testing.T) {
|
||||
t.Run("output length and unit norm", func(t *testing.T) {
|
||||
a, err := NewRandom(1536, 768, 42)
|
||||
require.NoError(t, err)
|
||||
got, err := a.Adapt(unitVec(1536))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, 768)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
|
||||
t.Run("deterministic with same seed", func(t *testing.T) {
|
||||
a1, _ := NewRandom(64, 32, 99)
|
||||
a2, _ := NewRandom(64, 32, 99)
|
||||
in := unitVec(64)
|
||||
out1, _ := a1.Adapt(in)
|
||||
out2, _ := a2.Adapt(in)
|
||||
assert.Equal(t, out1, out2)
|
||||
})
|
||||
|
||||
t.Run("different seeds produce different output", func(t *testing.T) {
|
||||
a1, _ := NewRandom(64, 32, 1)
|
||||
a2, _ := NewRandom(64, 32, 2)
|
||||
in := unitVec(64)
|
||||
out1, _ := a1.Adapt(in)
|
||||
out2, _ := a2.Adapt(in)
|
||||
assert.NotEqual(t, out1, out2)
|
||||
})
|
||||
|
||||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||
a, err := NewRandom(1536, 768, 1)
|
||||
require.NoError(t, err)
|
||||
_, err = a.Adapt(make([]float32, 10))
|
||||
require.ErrorIs(t, err, ErrDimMismatch)
|
||||
})
|
||||
}
|
||||
|
||||
// --- ProjectionAdapter ---
|
||||
|
||||
func TestNewProjection_InvalidMatrix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src int
|
||||
tgt int
|
||||
matrix [][]float32
|
||||
errIs error
|
||||
}{
|
||||
{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
|
||||
{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
|
||||
{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
|
||||
require.ErrorIs(t, err, tc.errIs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectionAdapter(t *testing.T) {
|
||||
t.Run("identity matrix same dim", func(t *testing.T) {
|
||||
// 3×3 identity
|
||||
id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
|
||||
a, err := NewProjection(3, 3, id)
|
||||
require.NoError(t, err)
|
||||
in := []float32{1, 2, 3}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, 3)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
|
||||
t.Run("downscale projection", func(t *testing.T) {
|
||||
// 2×4 matrix
|
||||
m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
|
||||
a, err := NewProjection(4, 2, m)
|
||||
require.NoError(t, err)
|
||||
got, err := a.Adapt([]float32{3, 4, 0, 0})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, 2)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
|
||||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||
m := [][]float32{{1, 0}, {0, 1}}
|
||||
a, err := NewProjection(2, 2, m)
|
||||
require.NoError(t, err)
|
||||
_, err = a.Adapt(make([]float32, 5))
|
||||
require.ErrorIs(t, err, ErrDimMismatch)
|
||||
})
|
||||
}
|
||||
23
pkg/adapter/normalize.go
Normal file
23
pkg/adapter/normalize.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package adapter
|
||||
|
||||
import "math"
|
||||
|
||||
// L2Norm returns a newly allocated slice holding v scaled to unit Euclidean
// length. The input is never modified. The sum of squares is accumulated in
// float64; when it is exactly zero (including for an empty or nil v) the
// result is an unscaled copy of v.
func L2Norm(v []float32) []float32 {
	out := make([]float32, len(v))

	var sqSum float64
	for i := range v {
		c := float64(v[i])
		sqSum += c * c
	}

	if sqSum == 0 {
		// Nothing to scale by; hand back a plain copy.
		copy(out, v)
		return out
	}

	magnitude := float32(math.Sqrt(sqSum))
	for i := range v {
		out[i] = v[i] / magnitude
	}
	return out
}
|
||||
32
pkg/adapter/projection.go
Normal file
32
pkg/adapter/projection.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package adapter
|
||||
|
||||
import "fmt"
|
||||
|
||||
// projectionAdapter is the Adapter that multiplies input vectors by a fixed
// projection matrix.
type projectionAdapter struct {
	sourceDim, targetDim int

	// matrix is indexed [row][col]; Adapt multiplies it with the input vector.
	matrix [][]float32
}

// SourceDim reports the input vector length the adapter accepts.
func (a *projectionAdapter) SourceDim() int {
	return a.sourceDim
}

// TargetDim reports the configured output dimension.
func (a *projectionAdapter) TargetDim() int {
	return a.targetDim
}
|
||||
|
||||
func (a *projectionAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||
if len(vec) != a.sourceDim {
|
||||
return nil, fmt.Errorf("projection adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||
}
|
||||
return L2Norm(matVecMul(a.matrix, vec)), nil
|
||||
}
|
||||
|
||||
// matVecMul computes the matrix-vector product m·v and returns it as a new
// slice with one element per row of m.
//
// m is indexed [row][col]. Every row must be no longer than v, otherwise the
// function panics with an index-out-of-range error; callers validate shapes
// before calling.
//
// Each dot product is accumulated in float64 and rounded to float32 once at
// the end, matching L2Norm's accumulation. Summing thousands of float32
// products directly in float32 loses low-order bits.
func matVecMul(m [][]float32, v []float32) []float32 {
	out := make([]float32, len(m))
	for i, row := range m {
		var sum float64
		for j, val := range row {
			// The product of two float32 values is exact in float64.
			sum += float64(val) * float64(v[j])
		}
		out[i] = float32(sum)
	}
	return out
}
|
||||
45
pkg/adapter/random.go
Normal file
45
pkg/adapter/random.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// randomAdapter is the Adapter that multiplies input vectors by a randomly
// generated projection matrix.
type randomAdapter struct {
	sourceDim, targetDim int

	// matrix holds targetDim rows of sourceDim values, filled once by
	// newRandomAdapter.
	matrix [][]float32
}
|
||||
|
||||
func newRandomAdapter(sourceDim, targetDim int, seed int64) *randomAdapter {
|
||||
if seed == 0 {
|
||||
seed = time.Now().UnixNano()
|
||||
}
|
||||
//nolint:gosec // deterministic seeded RNG for projection matrix generation, not security use
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
// Gaussian N(0, 1/targetDim) — preserves expected squared norms (Johnson-Lindenstrauss)
|
||||
stddev := 1.0 / math.Sqrt(float64(targetDim))
|
||||
matrix := make([][]float32, targetDim)
|
||||
for i := range matrix {
|
||||
row := make([]float32, sourceDim)
|
||||
for j := range row {
|
||||
row[j] = float32(rng.NormFloat64() * stddev)
|
||||
}
|
||||
matrix[i] = row
|
||||
}
|
||||
|
||||
return &randomAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}
|
||||
}
|
||||
|
||||
func (a *randomAdapter) SourceDim() int { return a.sourceDim }
|
||||
func (a *randomAdapter) TargetDim() int { return a.targetDim }
|
||||
|
||||
func (a *randomAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||
if len(vec) != a.sourceDim {
|
||||
return nil, fmt.Errorf("random adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||
}
|
||||
return L2Norm(matVecMul(a.matrix, vec)), nil
|
||||
}
|
||||
41
pkg/adapter/truncate.go
Normal file
41
pkg/adapter/truncate.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package adapter
|
||||
|
||||
import "fmt"
|
||||
|
||||
type truncateAdapter struct {
|
||||
sourceDim int
|
||||
targetDim int
|
||||
truncMode TruncateMode
|
||||
padMode PadMode
|
||||
}
|
||||
|
||||
func (a *truncateAdapter) SourceDim() int { return a.sourceDim }
|
||||
func (a *truncateAdapter) TargetDim() int { return a.targetDim }
|
||||
|
||||
func (a *truncateAdapter) Adapt(vec []float32) ([]float32, error) {
|
||||
if len(vec) != a.sourceDim {
|
||||
return nil, fmt.Errorf("truncate adapt: %w: got %d, want %d", ErrDimMismatch, len(vec), a.sourceDim)
|
||||
}
|
||||
|
||||
out := make([]float32, a.targetDim)
|
||||
|
||||
if a.targetDim <= a.sourceDim {
|
||||
// Downscale: truncate
|
||||
switch a.truncMode {
|
||||
case TruncateFromEnd:
|
||||
copy(out, vec[:a.targetDim])
|
||||
case TruncateFromStart:
|
||||
copy(out, vec[a.sourceDim-a.targetDim:])
|
||||
}
|
||||
} else {
|
||||
// Upscale: zero-pad
|
||||
switch a.padMode {
|
||||
case PadAtEnd:
|
||||
copy(out, vec)
|
||||
case PadAtStart:
|
||||
copy(out[a.targetDim-a.sourceDim:], vec)
|
||||
}
|
||||
}
|
||||
|
||||
return L2Norm(out), nil
|
||||
}
|
||||
Reference in New Issue
Block a user