package adapter

import (
	"math"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// vecDelta is the per-element tolerance used when comparing adapter output
// against a float64-computed reference vector.
const vecDelta = 1e-5

// unitVec returns a vector of length n whose entries all equal 1/sqrt(n),
// so its L2 norm is 1.
func unitVec(n int) []float32 {
	v := make([]float32, n)
	val := float32(1.0 / math.Sqrt(float64(n)))
	for i := range v {
		v[i] = val
	}
	return v
}

// normalized returns vals divided by their L2 norm, computed in float64 and
// narrowed to float32. It serves as an independent reference for the exact
// values the adapters are expected to emit. vals must not be all zero.
func normalized(vals ...float64) []float32 {
	var sum float64
	for _, x := range vals {
		sum += x * x
	}
	norm := math.Sqrt(sum)
	out := make([]float32, len(vals))
	for i, x := range vals {
		out[i] = float32(x / norm)
	}
	return out
}

// assertNormOne fails t unless the sum of squares of v is within 1e-5 of 1.
func assertNormOne(t *testing.T, v []float32) {
	t.Helper()
	var sum float64
	for _, x := range v {
		sum += float64(x) * float64(x)
	}
	assert.InDelta(t, 1.0, sum, 1e-5, "expected unit norm")
}

// assertVecNear requires got to have len(want) elements (so element-wise
// comparison cannot index out of range) and asserts each element is within
// vecDelta of the corresponding element of want.
func assertVecNear(t *testing.T, want, got []float32) {
	t.Helper()
	require.Len(t, got, len(want))
	assert.InDeltaSlice(t, want, got, vecDelta)
}

// --- L2Norm ---

func TestL2Norm(t *testing.T) {
	tests := []struct {
		name  string
		input []float32
		want  []float32
	}{
		{"unit vector unchanged", []float32{1, 0, 0}, []float32{1, 0, 0}},
		{"scale down to unit", []float32{2, 0, 0}, []float32{1, 0, 0}},
		{"multi-dim", unitVec(4), []float32{0.5, 0.5, 0.5, 0.5}},
		{"mixed signs preserved", []float32{3, -4}, []float32{0.6, -0.8}},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := L2Norm(tc.input)
			assertVecNear(t, tc.want, got)
			assertNormOne(t, got)
		})
	}

	t.Run("zero vector returns copy unchanged", func(t *testing.T) {
		in := []float32{0, 0, 0}
		got := L2Norm(in)
		assert.Equal(t, in, got)
	})
}

// --- TruncateAdapter ---

func TestNewTruncate_InvalidDims(t *testing.T) {
	_, err := NewTruncate(0, 768, TruncateFromEnd, PadAtEnd)
	require.ErrorIs(t, err, ErrInvalidDim)

	_, err = NewTruncate(1536, 0, TruncateFromEnd, PadAtEnd)
	require.ErrorIs(t, err, ErrInvalidDim)
}

func TestTruncateAdapter(t *testing.T) {
	// Table cases check dimensions, output length and unit norm only; the
	// exact-value behavior of each mode is pinned by the subtests below.
	tests := []struct {
		name      string
		src, tgt  int
		truncMode TruncateMode
		padMode   PadMode
	}{
		{"downscale TruncateFromEnd", 1536, 768, TruncateFromEnd, PadAtEnd},
		{"downscale TruncateFromStart", 1536, 768, TruncateFromStart, PadAtEnd},
		{"upscale PadAtEnd", 768, 1536, TruncateFromEnd, PadAtEnd},
		{"upscale PadAtStart", 768, 1536, TruncateFromEnd, PadAtStart},
		{"same dim", 768, 768, TruncateFromEnd, PadAtEnd},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			a, err := NewTruncate(tc.src, tc.tgt, tc.truncMode, tc.padMode)
			require.NoError(t, err)
			assert.Equal(t, tc.src, a.SourceDim())
			assert.Equal(t, tc.tgt, a.TargetDim())

			got, err := a.Adapt(unitVec(tc.src))
			require.NoError(t, err)
			assert.Len(t, got, tc.tgt)
			assertNormOne(t, got)
		})
	}

	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewTruncate(1536, 768, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 100))
		require.ErrorIs(t, err, ErrDimMismatch)
	})

	t.Run("TruncateFromEnd keeps first elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		// [1,2,3,4] -> keep [1,2] -> normalize. Using the same input as the
		// TruncateFromStart case makes the two modes produce distinct outputs.
		got, err := a.Adapt([]float32{1, 2, 3, 4})
		require.NoError(t, err)
		assertVecNear(t, normalized(1, 2), got)
	})

	t.Run("TruncateFromStart keeps last elements", func(t *testing.T) {
		a, err := NewTruncate(4, 2, TruncateFromStart, PadAtEnd)
		require.NoError(t, err)
		// [1,2,3,4] -> keep [3,4] -> normalize to [0.6, 0.8].
		got, err := a.Adapt([]float32{1, 2, 3, 4})
		require.NoError(t, err)
		assertVecNear(t, normalized(3, 4), got)
	})

	t.Run("PadAtEnd zero-pads back", func(t *testing.T) {
		a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtEnd)
		require.NoError(t, err)
		got, err := a.Adapt([]float32{1, 0})
		require.NoError(t, err)
		// Original values first, zero padding after.
		assertVecNear(t, []float32{1, 0, 0, 0}, got)
	})

	t.Run("PadAtStart zero-pads front", func(t *testing.T) {
		a, err := NewTruncate(2, 4, TruncateFromEnd, PadAtStart)
		require.NoError(t, err)
		got, err := a.Adapt([]float32{1, 0})
		require.NoError(t, err)
		// First two positions are padding; the last two carry the signal.
		assertVecNear(t, []float32{0, 0, 1, 0}, got)
	})
}

// --- RandomAdapter ---

func TestNewRandom_InvalidDims(t *testing.T) {
	_, err := NewRandom(0, 768, 42)
	require.ErrorIs(t, err, ErrInvalidDim)
}

func TestRandomAdapter(t *testing.T) {
	t.Run("output length and unit norm", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 42)
		require.NoError(t, err)
		got, err := a.Adapt(unitVec(1536))
		require.NoError(t, err)
		assert.Len(t, got, 768)
		assertNormOne(t, got)
	})

	t.Run("deterministic with same seed", func(t *testing.T) {
		a1, err := NewRandom(64, 32, 99)
		require.NoError(t, err)
		a2, err := NewRandom(64, 32, 99)
		require.NoError(t, err)

		in := unitVec(64)
		out1, err := a1.Adapt(in)
		require.NoError(t, err)
		out2, err := a2.Adapt(in)
		require.NoError(t, err)
		assert.Equal(t, out1, out2)
	})

	t.Run("different seeds produce different output", func(t *testing.T) {
		a1, err := NewRandom(64, 32, 1)
		require.NoError(t, err)
		a2, err := NewRandom(64, 32, 2)
		require.NoError(t, err)

		in := unitVec(64)
		out1, err := a1.Adapt(in)
		require.NoError(t, err)
		out2, err := a2.Adapt(in)
		require.NoError(t, err)
		assert.NotEqual(t, out1, out2)
	})

	t.Run("dim mismatch returns error", func(t *testing.T) {
		a, err := NewRandom(1536, 768, 1)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 10))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}

// --- ProjectionAdapter ---

func TestNewProjection_InvalidMatrix(t *testing.T) {
	tests := []struct {
		name   string
		src    int
		tgt    int
		matrix [][]float32
		errIs  error
	}{
		{"wrong row count", 4, 2, [][]float32{{1, 2, 3, 4}}, ErrInvalidMatrix},
		{"wrong col count", 4, 2, [][]float32{{1, 2}, {3, 4}}, ErrInvalidMatrix},
		{"zero sourceDim", 0, 2, [][]float32{{1}, {2}}, ErrInvalidDim},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewProjection(tc.src, tc.tgt, tc.matrix)
			require.ErrorIs(t, err, tc.errIs)
		})
	}
}

func TestProjectionAdapter(t *testing.T) {
	t.Run("identity matrix same dim", func(t *testing.T) {
		// 3×3 identity: output is the input, normalized.
		id := [][]float32{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}
		a, err := NewProjection(3, 3, id)
		require.NoError(t, err)

		got, err := a.Adapt([]float32{1, 2, 3})
		require.NoError(t, err)
		assertVecNear(t, normalized(1, 2, 3), got)
		assertNormOne(t, got)
	})

	t.Run("downscale projection", func(t *testing.T) {
		// 2×4 matrix selecting the first two components: [3,4,0,0] -> [3,4].
		m := [][]float32{{1, 0, 0, 0}, {0, 1, 0, 0}}
		a, err := NewProjection(4, 2, m)
		require.NoError(t, err)

		got, err := a.Adapt([]float32{3, 4, 0, 0})
		require.NoError(t, err)
		assertVecNear(t, normalized(3, 4), got)
		assertNormOne(t, got)
	})

	t.Run("dim mismatch returns error", func(t *testing.T) {
		m := [][]float32{{1, 0}, {0, 1}}
		a, err := NewProjection(2, 2, m)
		require.NoError(t, err)
		_, err = a.Adapt(make([]float32, 5))
		require.ErrorIs(t, err, ErrDimMismatch)
	})
}