Files
vecna/pkg/adapter/adapter.go
Hein 4009a54e39 feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768 , 3072 <-> 2048
2026-04-11 18:05:05 +02:00

78 lines
2.8 KiB
Go

package adapter
import (
"errors"
"fmt"
)
// Sentinel errors of the adapter package. The constructors in this file wrap
// them with %w, so callers should compare using errors.Is rather than ==.
var (
	// ErrDimMismatch reports that an input vector's length differs from the
	// source dimension the adapter was built for.
	ErrDimMismatch = errors.New("vector dimension mismatch")

	// ErrInvalidDim reports an unusable source or target dimension.
	ErrInvalidDim = errors.New("invalid dimension")

	// ErrInvalidMatrix reports a projection matrix whose shape is wrong.
	ErrInvalidMatrix = errors.New("invalid projection matrix shape")
)
// TruncateMode selects the side of a vector that is discarded when the vector
// is shortened to a smaller dimension.
type TruncateMode int

// Supported TruncateMode values.
const (
	// TruncateFromEnd discards the tail and retains the leading targetDim
	// elements. It is the zero value and therefore the default; this is the
	// correct choice for Matryoshka models.
	TruncateFromEnd = TruncateMode(iota)
	// TruncateFromStart discards the head and retains the trailing targetDim
	// elements.
	TruncateFromStart
)
// PadMode selects the side of a vector that is filled with zeros when the
// vector is extended to a larger dimension.
type PadMode int

// Supported PadMode values.
const (
	// PadAtEnd places the zero-padding after the existing elements. It is the
	// zero value and therefore the default.
	PadAtEnd = PadMode(iota)
	// PadAtStart places the zero-padding before the existing elements.
	PadAtStart
)
// Adapter converts vectors from one fixed dimension to another.
type Adapter interface {
	// SourceDim reports the dimension of the vectors the adapter accepts.
	SourceDim() int
	// TargetDim reports the dimension of the vectors the adapter produces.
	TargetDim() int
	// Adapt translates vec and returns the result, or an error when the
	// translation cannot be performed.
	Adapt(vec []float32) ([]float32, error)
}
// NewTruncate builds an Adapter that resizes vectors by plain truncation or
// zero-padding (Matryoshka-style).
//
// t selects which end is dropped when downscaling and p selects which end is
// padded when upscaling; both are stored as given, without validation.
// Both dimensions must be positive, otherwise an error wrapping ErrInvalidDim
// is returned together with a nil Adapter.
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
	if sourceDim < 1 || targetDim < 1 {
		err := fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
		return nil, err
	}
	adapter := &truncateAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		truncMode: t,
		padMode:   p,
	}
	return adapter, nil
}
// NewRandom builds an Adapter by delegating to newRandomAdapter once both
// dimensions have been checked to be positive; otherwise it returns a nil
// Adapter and an error wrapping ErrInvalidDim.
//
// According to the original documentation the resulting adapter is backed by
// a seeded Gaussian projection matrix and seed=0 selects a time-based seed.
// NOTE(review): both behaviors live in newRandomAdapter, which is not visible
// from this function; confirm there.
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
	dimsOK := sourceDim > 0 && targetDim > 0
	if !dimsOK {
		return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}
	return newRandomAdapter(sourceDim, targetDim, seed), nil
}
// NewProjection builds an Adapter around a projection matrix supplied by the
// caller. The matrix must have shape [targetDim][sourceDim], i.e. targetDim
// rows of sourceDim columns each.
//
// Checks run in this order: non-positive dimensions yield an error wrapping
// ErrInvalidDim; a wrong row count, or the first row with a wrong column
// count, yields an error wrapping ErrInvalidMatrix. On any error the returned
// Adapter is nil.
//
// The matrix is stored as passed in, not copied, so the adapter shares its
// backing storage with the caller.
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
	switch rows := len(matrix); {
	case sourceDim <= 0 || targetDim <= 0:
		return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	case rows != targetDim:
		return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, rows, targetDim)
	}
	// Every row must span the full source dimension; report the first offender.
	for r := range matrix {
		if cols := len(matrix[r]); cols != sourceDim {
			return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, r, cols, sourceDim)
		}
	}
	adapter := &projectionAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		matrix:    matrix,
	}
	return adapter, nil
}