mirror of
https://github.com/Warky-Devs/vecna.git
synced 2026-05-05 01:26:58 +00:00
feat: 🎉 Vectors na Vectors, the begining
Translate 1536 <-> 768 , 3072 <-> 2048
This commit is contained in:
236
pkg/adapter/adapter_test.go
Normal file
236
pkg/adapter/adapter_test.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// unitVec returns a unit vector of length n with value 1/sqrt(n) in each position.
|
||||
func unitVec(n int) []float32 {
|
||||
v := make([]float32, n)
|
||||
val := float32(1.0 / math.Sqrt(float64(n)))
|
||||
for i := range v {
|
||||
v[i] = val
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func assertNormOne(t *testing.T, v []float32) {
|
||||
t.Helper()
|
||||
var sum float64
|
||||
for _, x := range v {
|
||||
sum += float64(x) * float64(x)
|
||||
}
|
||||
assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
|
||||
}
|
||||
|
||||
// --- L2Norm ---
|
||||
|
||||
func TestL2Norm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []float32
|
||||
wantLen int
|
||||
wantNorm float64
|
||||
}{
|
||||
{"unit vector unchanged", []float32{1, 0, 0}, 3, 1.0},
|
||||
{"scale down to unit", []float32{2, 0, 0}, 3, 1.0},
|
||||
{"multi-dim", unitVec(4), 4, 1.0},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := L2Norm(tc.input)
|
||||
assert.Len(t, got, tc.wantLen)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("zero vector returns copy unchanged", func(t *testing.T) {
|
||||
in := []float32{0, 0, 0}
|
||||
got := L2Norm(in)
|
||||
assert.Equal(t, in, got)
|
||||
})
|
||||
}
|
||||
|
||||
// --- TruncateAdapter ---
|
||||
|
||||
func TestNewTruncate_InvalidDims(t *testing.T) {
|
||||
_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
|
||||
_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
}
|
||||
|
||||
func TestTruncateAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src, tgt int
|
||||
truncMode TruncateMode
|
||||
padMode PadMode
|
||||
// builder for input vec; checks are len + norm only unless wantFirst/wantLast set
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||
{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768, false},
|
||||
{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536, false},
|
||||
{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536, false},
|
||||
{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.src, a.SourceDim())
|
||||
assert.Equal(t, tc.tgt, a.TargetDim())
|
||||
|
||||
got, err := a.Adapt(unitVec(tc.src))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, tc.wantLen)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||
a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
_, err = a.Adapt(make([]float32, 100))
|
||||
require.ErrorIs(t, err, ErrDimMismatch)
|
||||
})
|
||||
|
||||
t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
|
||||
a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
// input: [1,2,3,4] — TruncateFromEnd keeps [1,2]
|
||||
in := []float32{1, 2, 3, 4}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
// after L2Norm the signs are preserved; ratio should match
|
||||
assert.Greater(t, got[0], float32(0))
|
||||
assert.Greater(t, got[1], float32(0))
|
||||
})
|
||||
|
||||
t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
|
||||
a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
|
||||
require.NoError(t, err)
|
||||
// input: [0,0,3,4] — TruncateFromStart keeps [3,4]
|
||||
in := []float32{0, 0, 3, 4}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, got[0], float32(0))
|
||||
assert.Greater(t, got[1], float32(0))
|
||||
})
|
||||
|
||||
t.Run("PadAtStart zero-pads front", func(t *testing.T) {
|
||||
a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
|
||||
require.NoError(t, err)
|
||||
in := []float32{1, 0}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
// first two positions should be zero (padded), last two carry the signal
|
||||
assert.Equal(t, float32(0), got[0])
|
||||
assert.Equal(t, float32(0), got[1])
|
||||
})
|
||||
}
|
||||
|
||||
// --- RandomAdapter ---
|
||||
|
||||
func TestNewRandom_InvalidDims(t *testing.T) {
|
||||
_, err := NewRandom(0, 768, 42)
|
||||
require.ErrorIs(t, err, ErrInvalidDim)
|
||||
}
|
||||
|
||||
func TestRandomAdapter(t *testing.T) {
|
||||
t.Run("output length and unit norm", func(t *testing.T) {
|
||||
a, err := NewRandom(1536, 768, 42)
|
||||
require.NoError(t, err)
|
||||
got, err := a.Adapt(unitVec(1536))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, 768)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
|
||||
t.Run("deterministic with same seed", func(t *testing.T) {
|
||||
a1, _ := NewRandom(64, 32, 99)
|
||||
a2, _ := NewRandom(64, 32, 99)
|
||||
in := unitVec(64)
|
||||
out1, _ := a1.Adapt(in)
|
||||
out2, _ := a2.Adapt(in)
|
||||
assert.Equal(t, out1, out2)
|
||||
})
|
||||
|
||||
t.Run("different seeds produce different output", func(t *testing.T) {
|
||||
a1, _ := NewRandom(64, 32, 1)
|
||||
a2, _ := NewRandom(64, 32, 2)
|
||||
in := unitVec(64)
|
||||
out1, _ := a1.Adapt(in)
|
||||
out2, _ := a2.Adapt(in)
|
||||
assert.NotEqual(t, out1, out2)
|
||||
})
|
||||
|
||||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||
a, err := NewRandom(1536, 768, 1)
|
||||
require.NoError(t, err)
|
||||
_, err = a.Adapt(make([]float32, 10))
|
||||
require.ErrorIs(t, err, ErrDimMismatch)
|
||||
})
|
||||
}
|
||||
|
||||
// --- ProjectionAdapter ---
|
||||
|
||||
func TestNewProjection_InvalidMatrix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src int
|
||||
tgt int
|
||||
matrix [][]float32
|
||||
errIs error
|
||||
}{
|
||||
{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
|
||||
{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
|
||||
{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
|
||||
require.ErrorIs(t, err, tc.errIs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectionAdapter(t *testing.T) {
|
||||
t.Run("identity matrix same dim", func(t *testing.T) {
|
||||
// 3×3 identity
|
||||
id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
|
||||
a, err := NewProjection(3, 3, id)
|
||||
require.NoError(t, err)
|
||||
in := []float32{1, 2, 3}
|
||||
got, err := a.Adapt(in)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, 3)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
|
||||
t.Run("downscale projection", func(t *testing.T) {
|
||||
// 2×4 matrix
|
||||
m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
|
||||
a, err := NewProjection(4, 2, m)
|
||||
require.NoError(t, err)
|
||||
got, err := a.Adapt([]float32{3, 4, 0, 0})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got, 2)
|
||||
assertNormOne(t, got)
|
||||
})
|
||||
|
||||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||||
m := [][]float32{{1, 0}, {0, 1}}
|
||||
a, err := NewProjection(2, 2, m)
|
||||
require.NoError(t, err)
|
||||
_, err = a.Adapt(make([]float32, 5))
|
||||
require.ErrorIs(t, err, ErrDimMismatch)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user