mirror of
https://github.com/Warky-Devs/vecna.git
synced 2026-05-05 01:26:58 +00:00
237 lines
6.3 KiB
Go
237 lines
6.3 KiB
Go
package adapter
|
||
|
||
import (
|
||
"math"
|
||
"testing"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// unitVec builds an n-element vector in which every component equals
// 1/sqrt(n), giving it unit Euclidean length (up to float32 rounding).
func unitVec(n int) []float32 {
	out := make([]float32, n)
	component := float32(1.0 / math.Sqrt(float64(n)))
	for i := 0; i < n; i++ {
		out[i] = component
	}
	return out
}
|
||
|
||
func assertNormOne(t *testing.T, v []float32) {
|
||
t.Helper()
|
||
var sum float64
|
||
for _, x := range v {
|
||
sum += float64(x) * float64(x)
|
||
}
|
||
assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
|
||
}
|
||
|
||
// --- L2Norm ---
|
||
|
||
func TestL2Norm(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
input []float32
|
||
wantLen int
|
||
wantNorm float64
|
||
}{
|
||
{"unit vector unchanged", []float32{1, 0, 0}, 3, 1.0},
|
||
{"scale down to unit", []float32{2, 0, 0}, 3, 1.0},
|
||
{"multi-dim", unitVec(4), 4, 1.0},
|
||
}
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
got := L2Norm(tc.input)
|
||
assert.Len(t, got, tc.wantLen)
|
||
assertNormOne(t, got)
|
||
})
|
||
}
|
||
|
||
t.Run("zero vector returns copy unchanged", func(t *testing.T) {
|
||
in := []float32{0, 0, 0}
|
||
got := L2Norm(in)
|
||
assert.Equal(t, in, got)
|
||
})
|
||
}
|
||
|
||
// --- TruncateAdapter ---
|
||
|
||
func TestNewTruncate_InvalidDims(t *testing.T) {
|
||
_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
|
||
require.ErrorIs(t, err, ErrInvalidDim)
|
||
|
||
_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
|
||
require.ErrorIs(t, err, ErrInvalidDim)
|
||
}
|
||
|
||
func TestTruncateAdapter(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
src, tgt int
|
||
truncMode TruncateMode
|
||
padMode PadMode
|
||
// builder for input vec; checks are len + norm only unless wantFirst/wantLast set
|
||
wantLen int
|
||
wantErr bool
|
||
}{
|
||
{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||
{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd, 768, false},
|
||
{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd, 1536, false},
|
||
{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart, 1536, false},
|
||
{"same dim", 768, 768, TruncateFromEnd, PadAtEnd, 768, false},
|
||
}
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, tc.src, a.SourceDim())
|
||
assert.Equal(t, tc.tgt, a.TargetDim())
|
||
|
||
got, err := a.Adapt(unitVec(tc.src))
|
||
require.NoError(t, err)
|
||
assert.Len(t, got, tc.wantLen)
|
||
assertNormOne(t, got)
|
||
})
|
||
}
|
||
|
||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||
a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
|
||
require.NoError(t, err)
|
||
_, err = a.Adapt(make([]float32, 100))
|
||
require.ErrorIs(t, err, ErrDimMismatch)
|
||
})
|
||
|
||
t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
|
||
a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
|
||
require.NoError(t, err)
|
||
// input: [1,2,3,4] — TruncateFromEnd keeps [1,2]
|
||
in := []float32{1, 2, 3, 4}
|
||
got, err := a.Adapt(in)
|
||
require.NoError(t, err)
|
||
// after L2Norm the signs are preserved; ratio should match
|
||
assert.Greater(t, got[0], float32(0))
|
||
assert.Greater(t, got[1], float32(0))
|
||
})
|
||
|
||
t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
|
||
a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
|
||
require.NoError(t, err)
|
||
// input: [0,0,3,4] — TruncateFromStart keeps [3,4]
|
||
in := []float32{0, 0, 3, 4}
|
||
got, err := a.Adapt(in)
|
||
require.NoError(t, err)
|
||
assert.Greater(t, got[0], float32(0))
|
||
assert.Greater(t, got[1], float32(0))
|
||
})
|
||
|
||
t.Run("PadAtStart zero-pads front", func(t *testing.T) {
|
||
a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
|
||
require.NoError(t, err)
|
||
in := []float32{1, 0}
|
||
got, err := a.Adapt(in)
|
||
require.NoError(t, err)
|
||
// first two positions should be zero (padded), last two carry the signal
|
||
assert.Equal(t, float32(0), got[0])
|
||
assert.Equal(t, float32(0), got[1])
|
||
})
|
||
}
|
||
|
||
// --- RandomAdapter ---
|
||
|
||
func TestNewRandom_InvalidDims(t *testing.T) {
|
||
_, err := NewRandom(0, 768, 42)
|
||
require.ErrorIs(t, err, ErrInvalidDim)
|
||
}
|
||
|
||
func TestRandomAdapter(t *testing.T) {
|
||
t.Run("output length and unit norm", func(t *testing.T) {
|
||
a, err := NewRandom(1536, 768, 42)
|
||
require.NoError(t, err)
|
||
got, err := a.Adapt(unitVec(1536))
|
||
require.NoError(t, err)
|
||
assert.Len(t, got, 768)
|
||
assertNormOne(t, got)
|
||
})
|
||
|
||
t.Run("deterministic with same seed", func(t *testing.T) {
|
||
a1, _ := NewRandom(64, 32, 99)
|
||
a2, _ := NewRandom(64, 32, 99)
|
||
in := unitVec(64)
|
||
out1, _ := a1.Adapt(in)
|
||
out2, _ := a2.Adapt(in)
|
||
assert.Equal(t, out1, out2)
|
||
})
|
||
|
||
t.Run("different seeds produce different output", func(t *testing.T) {
|
||
a1, _ := NewRandom(64, 32, 1)
|
||
a2, _ := NewRandom(64, 32, 2)
|
||
in := unitVec(64)
|
||
out1, _ := a1.Adapt(in)
|
||
out2, _ := a2.Adapt(in)
|
||
assert.NotEqual(t, out1, out2)
|
||
})
|
||
|
||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||
a, err := NewRandom(1536, 768, 1)
|
||
require.NoError(t, err)
|
||
_, err = a.Adapt(make([]float32, 10))
|
||
require.ErrorIs(t, err, ErrDimMismatch)
|
||
})
|
||
}
|
||
|
||
// --- ProjectionAdapter ---
|
||
|
||
func TestNewProjection_InvalidMatrix(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
src int
|
||
tgt int
|
||
matrix [][]float32
|
||
errIs error
|
||
}{
|
||
{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
|
||
{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
|
||
{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
|
||
}
|
||
for _, tc := range tests {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
|
||
require.ErrorIs(t, err, tc.errIs)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestProjectionAdapter(t *testing.T) {
|
||
t.Run("identity matrix same dim", func(t *testing.T) {
|
||
// 3×3 identity
|
||
id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
|
||
a, err := NewProjection(3, 3, id)
|
||
require.NoError(t, err)
|
||
in := []float32{1, 2, 3}
|
||
got, err := a.Adapt(in)
|
||
require.NoError(t, err)
|
||
assert.Len(t, got, 3)
|
||
assertNormOne(t, got)
|
||
})
|
||
|
||
t.Run("downscale projection", func(t *testing.T) {
|
||
// 2×4 matrix
|
||
m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
|
||
a, err := NewProjection(4, 2, m)
|
||
require.NoError(t, err)
|
||
got, err := a.Adapt([]float32{3, 4, 0, 0})
|
||
require.NoError(t, err)
|
||
assert.Len(t, got, 2)
|
||
assertNormOne(t, got)
|
||
})
|
||
|
||
t.Run("dim mismatch returns error", func(t *testing.T) {
|
||
m := [][]float32{{1, 0}, {0, 1}}
|
||
a, err := NewProjection(2, 2, m)
|
||
require.NoError(t, err)
|
||
_, err = a.Adapt(make([]float32, 5))
|
||
require.ErrorIs(t, err, ErrDimMismatch)
|
||
})
|
||
}
|