package adapter import ( "errors" "fmt" ) // ErrDimMismatch is returned when the input vector length does not match the adapter's source dimension. var ErrDimMismatch = errors.New("vector dimension mismatch") // ErrInvalidDim is returned when source or target dimensions are invalid. var ErrInvalidDim = errors.New("invalid dimension") // ErrInvalidMatrix is returned when a projection matrix has the wrong shape. var ErrInvalidMatrix = errors.New("invalid projection matrix shape") // TruncateMode controls which end of the vector is dropped when downscaling. type TruncateMode int const ( // TruncateFromEnd keeps the first targetDim elements (default; correct for Matryoshka models). TruncateFromEnd TruncateMode = iota // TruncateFromStart keeps the last targetDim elements. TruncateFromStart ) // PadMode controls which end of the vector receives zero-padding when upscaling. type PadMode int const ( // PadAtEnd appends zeros to the end (default). PadAtEnd PadMode = iota // PadAtStart prepends zeros to the start. PadAtStart ) // Adapter translates vectors between two fixed dimensions. type Adapter interface { Adapt(vec []float32) ([]float32, error) SourceDim() int TargetDim() int } // passthroughAdapter returns the input vector unchanged. type passthroughAdapter struct{} func (passthroughAdapter) Adapt(vec []float32) ([]float32, error) { return vec, nil } func (passthroughAdapter) SourceDim() int { return 0 } func (passthroughAdapter) TargetDim() int { return 0 } // NewPassthrough returns an Adapter that returns vectors unchanged. // Use when no dimension adaptation is needed. func NewPassthrough() Adapter { return passthroughAdapter{} } // NewTruncate returns a TruncateAdapter for Matryoshka-style or simple truncation/padding. // t controls which end is dropped when downscaling; p controls which end is padded when upscaling. func NewTruncate(sourceDim, targetDim int, t TruncateMode, p PadMode) (Adapter, error) { if sourceDim <= 0 || targetDim <= 0 { return nil, fmt.Errorf("NewTruncate: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim) } return &truncateAdapter{sourceDim: sourceDim, targetDim: targetDim, truncMode: t, padMode: p}, nil } // NewRandom returns a RandomAdapter backed by a seeded Gaussian projection matrix. // seed=0 uses a time-based seed. func NewRandom(sourceDim, targetDim int, seed int64) (Adapter, error) { if sourceDim <= 0 || targetDim <= 0 { return nil, fmt.Errorf("NewRandom: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim) } return newRandomAdapter(sourceDim, targetDim, seed), nil } // NewProjection returns a ProjectionAdapter using a caller-supplied matrix. // matrix must have shape [targetDim][sourceDim]. func NewProjection(sourceDim, targetDim int, matrix [][]float32) (Adapter, error) { if sourceDim <= 0 || targetDim <= 0 { return nil, fmt.Errorf("NewProjection: %w: source=%d target=%d", ErrInvalidDim, sourceDim, targetDim) } if len(matrix) != targetDim { return nil, fmt.Errorf("NewProjection: %w: got %d rows, want %d", ErrInvalidMatrix, len(matrix), targetDim) } for i, row := range matrix { if len(row) != sourceDim { return nil, fmt.Errorf("NewProjection: %w: row %d has %d cols, want %d", ErrInvalidMatrix, i, len(row), sourceDim) } } return &projectionAdapter{sourceDim: sourceDim, targetDim: targetDim, matrix: matrix}, nil }