Files
vecna/pkg/adapter/adapter.go

89 lines
3.4 KiB
Go

package adapter
import (
"errors"
"fmt"
)
// Sentinel errors exposed by this package. The constructors in this file wrap
// ErrInvalidDim and ErrInvalidMatrix with %w, so compare using errors.Is
// rather than ==.
var (
	// ErrDimMismatch reports that an input vector's length differs from the
	// adapter's source dimension.
	ErrDimMismatch = errors.New("vector dimension mismatch")

	// ErrInvalidDim reports that a source or target dimension is invalid.
	ErrInvalidDim = errors.New("invalid dimension")

	// ErrInvalidMatrix reports that a projection matrix has the wrong shape.
	ErrInvalidMatrix = errors.New("invalid projection matrix shape")
)
// TruncateMode selects which end of a vector is discarded when it is
// shortened to a smaller dimension.
type TruncateMode int

// Supported TruncateMode values. The zero value is TruncateFromEnd.
const (
	// TruncateFromEnd drops the tail and keeps the first targetDim elements.
	// It is the default and the correct choice for Matryoshka models.
	TruncateFromEnd TruncateMode = iota

	// TruncateFromStart drops the head and keeps the last targetDim elements.
	TruncateFromStart
)
// PadMode selects which end of a vector is filled with zeros when it is
// extended to a larger dimension.
type PadMode int

// Supported PadMode values. The zero value is PadAtEnd.
const (
	// PadAtEnd appends the zeros after the existing elements (default).
	PadAtEnd PadMode = iota

	// PadAtStart prepends the zeros before the existing elements.
	PadAtStart
)
// Adapter converts vectors from one fixed dimension to another.
type Adapter interface {
	// SourceDim reports the dimension the adapter accepts as input.
	SourceDim() int

	// TargetDim reports the dimension the adapter produces as output.
	TargetDim() int

	// Adapt translates vec to the target dimension, or returns an error
	// when the translation cannot be performed.
	Adapt(vec []float32) ([]float32, error)
}
// passthroughAdapter is a stateless no-op adapter: it hands vectors back
// exactly as received and performs no dimension check.
type passthroughAdapter struct{}

// Adapt returns vec itself (the same slice, not a copy) and a nil error.
func (passthroughAdapter) Adapt(vec []float32) ([]float32, error) {
	return vec, nil
}

// SourceDim always reports 0 for the passthrough adapter.
func (passthroughAdapter) SourceDim() int {
	return 0
}

// TargetDim always reports 0 for the passthrough adapter.
func (passthroughAdapter) TargetDim() int {
	return 0
}
// NewPassthrough builds the no-op Adapter, which hands every vector back
// unchanged. Choose it when no dimension adaptation is required.
func NewPassthrough() Adapter {
	return passthroughAdapter{}
}
// NewTruncate builds a truncating/padding Adapter, suitable for
// Matryoshka-style embeddings or plain truncation and zero-padding.
//
// t selects which end is dropped when downscaling and p selects which end is
// zero-padded when upscaling; both are stored on the adapter as given.
//
// It returns an error wrapping ErrInvalidDim when either sourceDim or
// targetDim is not strictly positive.
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
	if dimsOK := sourceDim > 0 && targetDim > 0; !dimsOK {
		return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}

	a := &truncateAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		truncMode: t,
		padMode:   p,
	}
	return a, nil
}
// NewRandom builds an Adapter backed by a seeded Gaussian projection matrix.
//
// The seed is forwarded unchanged to newRandomAdapter; by this package's
// documented contract a seed of 0 selects a time-based seed (that handling
// lives in newRandomAdapter, outside this function).
//
// It returns an error wrapping ErrInvalidDim when either sourceDim or
// targetDim is not strictly positive.
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
	if dimsOK := sourceDim > 0 && targetDim > 0; !dimsOK {
		return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}

	return newRandomAdapter(sourceDim, targetDim, seed), nil
}
// NewProjection builds an Adapter that uses a caller-supplied projection
// matrix of shape [targetDim][sourceDim]: targetDim rows, each holding
// sourceDim columns.
//
// The matrix is retained by reference, not copied, so later changes made by
// the caller are visible to the adapter.
//
// Errors:
//   - wraps ErrInvalidDim when sourceDim or targetDim is not strictly positive;
//   - wraps ErrInvalidMatrix when the row count differs from targetDim or any
//     row's length differs from sourceDim (the first offending row is reported).
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
	if dimsOK := sourceDim > 0 && targetDim > 0; !dimsOK {
		return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
	}

	// Validate the outer dimension first, then every row.
	if rows := len(matrix); rows != targetDim {
		return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, rows, targetDim)
	}
	for i := range matrix {
		if cols := len(matrix[i]); cols != sourceDim {
			return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, cols, sourceDim)
		}
	}

	a := &projectionAdapter{
		sourceDim: sourceDim,
		targetDim: targetDim,
		matrix:    matrix,
	}
	return a, nil
}