// Mirror of https://github.com/Warky-Devs/vecna.git
// Synced 2026-05-05 01:26:58 +00:00 (78 lines, 2.8 KiB, Go).

package adapter
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
)
|
|
|
|
// Sentinel errors for this package. The constructors wrap them with extra
// context via fmt.Errorf("%w"), so callers should test for them with
// errors.Is rather than ==.
var (
	// ErrDimMismatch signals that a vector handed to an adapter has a length
	// different from the adapter's source dimension.
	ErrDimMismatch = errors.New("vector dimension mismatch")

	// ErrInvalidDim signals that a source or target dimension passed to a
	// constructor is not usable (the constructors reject values <= 0).
	ErrInvalidDim = errors.New("invalid dimension")

	// ErrInvalidMatrix signals that a caller-supplied projection matrix is not
	// shaped [targetDim][sourceDim].
	ErrInvalidMatrix = errors.New("invalid projection matrix shape")
)
|
|
|
|
// TruncateMode selects which end of a vector is discarded when an adapter
// shrinks it from sourceDim down to a smaller targetDim. The zero value is
// TruncateFromEnd.
type TruncateMode int

// Supported TruncateMode values.
const (
	// TruncateFromEnd retains the leading targetDim elements and discards the
	// tail. This is the default, and the correct choice for Matryoshka models.
	TruncateFromEnd TruncateMode = 0

	// TruncateFromStart retains the trailing targetDim elements and discards
	// the head.
	TruncateFromStart TruncateMode = 1
)
|
|
|
|
// PadMode selects which end of a vector receives zero-padding when an adapter
// grows it from sourceDim up to a larger targetDim. The zero value is
// PadAtEnd.
type PadMode int

// Supported PadMode values.
const (
	// PadAtEnd appends the padding zeros after the original elements. This is
	// the default.
	PadAtEnd PadMode = 0

	// PadAtStart inserts the padding zeros before the original elements.
	PadAtStart PadMode = 1
)
|
|
|
|
// Adapter converts vectors of one fixed dimension (SourceDim) into vectors
// of another fixed dimension (TargetDim).
type Adapter interface {
	// SourceDim reports the vector length that Adapt accepts.
	SourceDim() int

	// TargetDim reports the vector length that Adapt produces.
	TargetDim() int

	// Adapt maps vec, which is expected to hold SourceDim elements, to a
	// vector of TargetDim elements. A wrong input length is presumably
	// reported via ErrDimMismatch (check with errors.Is) — confirm against
	// the concrete implementations.
	Adapt(vec []float32) ([]float32, error)
}
|
|
|
|
// NewTruncate returns a TruncateAdapter for Matryoshka-style or simple truncation/padding.
|
|
// t controls which end is dropped when downscaling; p controls which end is padded when upscaling.
|
|
func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) {
|
|
if sourceDim <= 0 || targetDim <= 0 {
|
|
return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
|
}
|
|
return &truncateAdapter{sourceDim: sourceDim, targetDim: targetDim, truncMode: t, padMode: p}, nil
|
|
}
|
|
|
|
// NewRandom returns a RandomAdapter backed by a seeded Gaussian projection matrix.
|
|
// seed=0 uses a time-based seed.
|
|
func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) {
|
|
if sourceDim <= 0 || targetDim <= 0 {
|
|
return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
|
}
|
|
return newRandomAdapter(sourceDim, targetDim, seed), nil
|
|
}
|
|
|
|
// NewProjection returns a ProjectionAdapter using a caller-supplied matrix.
|
|
// matrix must have shape [targetDim][sourceDim].
|
|
func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) {
|
|
if sourceDim <= 0 || targetDim <= 0 {
|
|
return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim)
|
|
}
|
|
if len(matrix) != targetDim {
|
|
return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, len(matrix), targetDim)
|
|
}
|
|
for i, row := range matrix {
|
|
if len(row) != sourceDim {
|
|
return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, len(row), sourceDim)
|
|
}
|
|
}
|
|
return &projectionAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}, nil
|
|
}
|