Files
vecna/pkg/adapter/adapter_test.go
Hein 4009a54e39 feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768 , 3072 <-> 2048
2026-04-11 18:05:05 +02:00

237 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package adapter
import (
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// unitVec builds an n-element vector in which every entry is 1/sqrt(n), so
// its Euclidean norm is 1 for any n > 0. It is used as a well-conditioned
// input for the adapter tests.
func unitVec(n int) []float32 {
	component := float32(1 / math.Sqrt(float64(n)))
	out := make([]float32, n)
	for i := 0; i < len(out); i++ {
		out[i] = component
	}
	return out
}
// assertNormOne records a non-fatal failure on t unless v has unit Euclidean
// length. It accumulates the squared norm in float64 to limit rounding error
// and compares that squared norm against 1.0 with a tolerance of 1e-5.
func assertNormOne(t *testing.T, v []float32) {
	t.Helper()
	sumSq := 0.0
	for i := range v {
		x := float64(v[i])
		sumSq += x * x
	}
	assert.InDelta(t, 1.0, sumSq, 1e-5, "expected unit norm")
}
// --- L2Norm ---
// TestL2Norm checks that L2Norm rescales non-zero vectors to unit length
// while preserving their direction, and that an all-zero vector comes back
// with the same contents.
func TestL2Norm(t *testing.T) {
	tests := []struct {
		name  string
		input []float32
		// want is the exact expected output; checking values (not just the
		// norm) catches results that are unit-length but point the wrong way.
		want []float32
	}{
		{"unit vector unchanged", []float32{1, 0, 0}, []float32{1, 0, 0}},
		{"scale down to unit", []float32{2, 0, 0}, []float32{1, 0, 0}},
		{"multi-dim", unitVec(4), []float32{0.5, 0.5, 0.5, 0.5}},
		{"mixed components keep ratio", []float32{3, 4}, []float32{0.6, 0.8}},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := L2Norm(tc.input)
			require.Len(t, got, len(tc.want))
			assert.InDeltaSlice(t, tc.want, got, 1e-6)
			assertNormOne(t, got)
		})
	}
	t.Run("zero vector returns copy unchanged", func(t *testing.T) {
		in := []float32{0, 0, 0}
		got := L2Norm(in)
		assert.Equal(t, in, got)
	})
}
// --- TruncateAdapter ---
// TestNewTruncate_InvalidDims checks that NewTruncate rejects a zero source
// dimension and a zero target dimension, each with ErrInvalidDim. The check
// stops at the first failing pair.
func TestNewTruncate_InvalidDims(t *testing.T) {
	for _, dims := range [][2]int{{0, 768}, {1536, 0}} {
		_, err := NewTruncate(dims[0], dims[1], TruncateFromEnd, PadAtEnd)
		require.ErrorIs(t, err, ErrInvalidDim)
	}
}
// TestTruncateAdapter covers four things for TruncateAdapter:
//   - construction and reported dimensions;
//   - output length and re-normalisation across truncate/pad modes;
//   - dimension-mismatch errors;
//   - the exact elements each TruncateMode keeps and where each PadMode
//     inserts zeros.
func TestTruncateAdapter(t *testing.T) {
	tests := []struct {
		name      string
		src, tgt  int
		truncMode TruncateMode
		padMode   PadMode
		wantLen   int
	}{
		{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768},
		{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768},
		{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536},
		{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536},
		{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
			require.NoError(t, err)
			assert.Equal(t, tc.src, a.SourceDim())
			assert.Equal(t, tc.tgt, a.TargetDim())
			// Truncating a unit vector drops part of its mass, so a unit-norm
			// result here also shows the adapter re-normalises its output.
			got, err := a.Adapt(unitVec(tc.src))
			require.NoError(t, err)
			assert.Len(t, got, tc.wantLen)
			assertNormOne(t, got)
		})
	}
	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 100))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
	t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		// [1,2,3,4] -> keep [1,2] -> divide by sqrt(5).
		// Keeping [3,4] would instead give [0.6,0.8]; positivity alone
		// cannot tell these apart, so assert the exact values.
		got, err := a.Adapt([]float32{1, 2, 3, 4})
		require.NoError(t, err)
		require.Len(t, got, 2)
		inv := 1 / math.Sqrt(5)
		assert.InDeltaSlice(t, []float32{float32(inv), float32(2 * inv)}, got, 1e-6)
	})
	t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
		require.NoError(t, err)
		// [0,0,3,4] -> keep [3,4] -> divide by 5.
		got, err := a.Adapt([]float32{0, 0, 3, 4})
		require.NoError(t, err)
		require.Len(t, got, 2)
		assert.InDeltaSlice(t, []float32{0.6, 0.8}, got, 1e-6)
	})
	t.Run("PadAtStart zero-pads front", func(t *testing.T) {
		a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
		require.NoError(t, err)
		// [1,0] -> [0,0,1,0]: the first two positions are padding and the
		// last two carry the (already unit) signal.
		got, err := a.Adapt([]float32{1, 0})
		require.NoError(t, err)
		require.Len(t, got, 4)
		assert.InDeltaSlice(t, []float32{0, 0, 1, 0}, got, 1e-6)
	})
	t.Run("PadAtEnd zero-pads back", func(t *testing.T) {
		a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		// [1,0] -> [1,0,0,0]: the signal stays in front, padding follows.
		got, err := a.Adapt([]float32{1, 0})
		require.NoError(t, err)
		require.Len(t, got, 4)
		assert.InDeltaSlice(t, []float32{1, 0, 0, 0}, got, 1e-6)
	})
}
// --- RandomAdapter ---
// TestNewRandom_InvalidDims checks that NewRandom rejects both a zero source
// dimension and a zero target dimension with ErrInvalidDim. This mirrors the
// two cases covered for NewTruncate.
func TestNewRandom_InvalidDims(t *testing.T) {
	_, err := NewRandom(0, 768, 42)
	require.ErrorIs(t, err, ErrInvalidDim)
	_, err = NewRandom(1536, 0, 42)
	require.ErrorIs(t, err, ErrInvalidDim)
}
// TestRandomAdapter covers RandomAdapter output length and normalisation,
// determinism for a fixed seed, sensitivity to the seed, and dimension-mismatch
// errors.
func TestRandomAdapter(t *testing.T) {
	t.Run("output length and unit norm", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 42)
		require.NoError(t, err)
		got, err := a.Adapt(unitVec(1536))
		require.NoError(t, err)
		assert.Len(t, got, 768)
		assertNormOne(t, got)
	})
	t.Run("deterministic with same seed", func(t *testing.T) {
		// Errors are checked explicitly. A failed constructor would leave an
		// unusable adapter, and two failing Adapt calls returning equal nil
		// slices would make the equality check pass vacuously.
		a1, err := NewRandom(64, 32, 99)
		require.NoError(t, err)
		a2, err := NewRandom(64, 32, 99)
		require.NoError(t, err)
		in := unitVec(64)
		out1, err := a1.Adapt(in)
		require.NoError(t, err)
		out2, err := a2.Adapt(in)
		require.NoError(t, err)
		require.Len(t, out1, 32)
		assert.Equal(t, out1, out2)
	})
	t.Run("different seeds produce different output", func(t *testing.T) {
		a1, err := NewRandom(64, 32, 1)
		require.NoError(t, err)
		a2, err := NewRandom(64, 32, 2)
		require.NoError(t, err)
		in := unitVec(64)
		out1, err := a1.Adapt(in)
		require.NoError(t, err)
		out2, err := a2.Adapt(in)
		require.NoError(t, err)
		require.Len(t, out1, 32)
		require.Len(t, out2, 32)
		assert.NotEqual(t, out1, out2)
	})
	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 1)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 10))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}
// --- ProjectionAdapter ---
// TestNewProjection_InvalidMatrix checks NewProjection's input validation:
//   - a matrix whose row count or column count does not match tgt×src is
//     rejected with ErrInvalidMatrix;
//   - a zero source dimension is rejected with ErrInvalidDim.
func TestNewProjection_InvalidMatrix(t *testing.T) {
	cases := []struct {
		name    string
		src     int
		tgt     int
		matrix  [][]float32
		wantErr error
	}{
		{name: "wrong row count", src: 4, tgt: 2, matrix: [][]float32{{1, 2, 3, 4}}, wantErr: ErrInvalidMatrix},
		{name: "wrong col count", src: 4, tgt: 2, matrix: [][]float32{{1, 2}, {3, 4}}, wantErr: ErrInvalidMatrix},
		{name: "zero sourceDim", src: 0, tgt: 2, matrix: [][]float32{{1}, {2}}, wantErr: ErrInvalidDim},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			_, err := NewProjection(c.src, c.tgt, c.matrix)
			require.ErrorIs(t, err, c.wantErr)
		})
	}
}
// TestProjectionAdapter checks that ProjectionAdapter multiplies the input by
// its tgt×src matrix, L2-normalises the result, and rejects inputs of the
// wrong length.
func TestProjectionAdapter(t *testing.T) {
	t.Run("identity matrix same dim", func(t *testing.T) {
		// 3×3 identity: the projection is the input itself, so the output
		// must be [1,2,3] / sqrt(14).
		id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
		a, err := NewProjection(3, 3, id)
		require.NoError(t, err)
		got, err := a.Adapt([]float32{1, 2, 3})
		require.NoError(t, err)
		require.Len(t, got, 3)
		assertNormOne(t, got)
		inv := 1 / math.Sqrt(14)
		assert.InDeltaSlice(t, []float32{float32(inv), float32(2 * inv), float32(3 * inv)}, got, 1e-6)
	})
	t.Run("downscale projection", func(t *testing.T) {
		// 2×4 matrix selecting the first two components:
		// [3,4,0,0] -> [3,4] -> normalised [0.6,0.8].
		m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
		a, err := NewProjection(4, 2, m)
		require.NoError(t, err)
		got, err := a.Adapt([]float32{3, 4, 0, 0})
		require.NoError(t, err)
		require.Len(t, got, 2)
		assertNormOne(t, got)
		assert.InDeltaSlice(t, []float32{0.6, 0.8}, got, 1e-6)
	})
	t.Run("dim mismatch returns error", func(t *testing.T) {
		m := [][]float32{{1, 0}, {0, 1}}
		a, err := NewProjection(2, 2, m)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 5))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}