mirror of
https://github.com/Warky-Devs/vecna.git
synced 2026-05-05 01:26:58 +00:00
feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768 , 3072 <-> 2048
This commit is contained in:
27
pkg/embedclient/client.go
Normal file
27
pkg/embedclient/client.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package embedclient
|
||||
|
||||
import "context"
|
||||
|
||||
// Request is a batch of texts to embed.
type Request struct {
	// Texts is the batch of input strings to embed, one vector per entry.
	Texts []string
	// Model names the embedding model to use. Not every Client honors it:
	// the Google client uses its configured model instead — see googleClient.
	Model string
}
|
||||
|
||||
// Usage reports token consumption from the backing model.
// It is zero-valued when the backend does not report token counts
// (e.g. the Google batch endpoint response parsed here carries none).
type Usage struct {
	// PromptTokens is the number of tokens in the submitted input.
	PromptTokens int
	// TotalTokens is the total tokens billed for the request.
	TotalTokens int
}
|
||||
|
||||
// Response holds the raw embeddings returned by the backing model.
type Response struct {
	// Embeddings is one vector per input text, in the same order as
	// Request.Texts.
	Embeddings [][]float32
	// Model is the model that actually produced the embeddings.
	Model string
	// Usage reports token consumption; zero when the backend omits it.
	Usage Usage
}
|
||||
|
||||
// Client sends text to a backing embedding model and returns raw vectors.
type Client interface {
	// Embed submits the batch in req and returns one embedding per text.
	// Implementations should honor ctx for cancellation and timeouts.
	Embed(ctx context.Context, req Request) (Response, error)
}
|
||||
98
pkg/embedclient/google.go
Normal file
98
pkg/embedclient/google.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package embedclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// googleClient is a Client backed by the Google Gemini batchEmbedContents
// HTTP API. Construct it with NewGoogle.
type googleClient struct {
	baseURL    string       // API root, e.g. "https://generativelanguage.googleapis.com"
	apiKey     string       // appended as ?key=... when non-empty
	model      string       // model name; always used, Request.Model is ignored
	httpClient *http.Client // never nil after NewGoogle
}
|
||||
|
||||
// NewGoogle returns a Client that speaks the Google Gemini batchEmbedContents API.
|
||||
func NewGoogle(baseURL, apiKey, model string, httpClient *http.Client) Client {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return &googleClient{baseURL: baseURL, apiKey: apiKey, model: model, httpClient: httpClient}
|
||||
}
|
||||
|
||||
// googleBatchRequest is the JSON body for a batchEmbedContents call:
// one embed request per input text.
type googleBatchRequest struct {
	Requests []googleEmbedRequest `json:"requests"`
}
|
||||
|
||||
// googleEmbedRequest is one entry in a googleBatchRequest.
type googleEmbedRequest struct {
	// Model must be the fully-qualified "models/<name>" form.
	Model   string        `json:"model"`
	Content googleContent `json:"content"`
}
|
||||
|
||||
// googleContent wraps the text parts of a single embed request.
type googleContent struct {
	Parts []googlePart `json:"parts"`
}
|
||||
|
||||
// googlePart carries one text fragment inside googleContent.
type googlePart struct {
	Text string `json:"text"`
}
|
||||
|
||||
// googleBatchResponse mirrors the subset of the batchEmbedContents JSON
// response this client reads: one values-vector per input text.
type googleBatchResponse struct {
	Embeddings []struct {
		Values []float32 `json:"values"`
	} `json:"embeddings"`
}
|
||||
|
||||
func (c *googleClient) Embed(ctx context.Context, req Request) (Response, error) {
|
||||
requests := make([]googleEmbedRequest, len(req.Texts))
|
||||
for i, text := range req.Texts {
|
||||
requests[i] = googleEmbedRequest{
|
||||
Model: "models/" + c.model,
|
||||
Content: googleContent{Parts: []googlePart{{Text: text}}},
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(googleBatchRequest{Requests: requests})
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("google embed marshal: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/models/%s:batchEmbedContents", c.baseURL, c.model)
|
||||
if c.apiKey != "" {
|
||||
url += "?key=" + c.apiKey
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("google embed request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("google embed do: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return Response{}, fmt.Errorf("google embed: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var gResp googleBatchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&gResp); err != nil {
|
||||
return Response{}, fmt.Errorf("google embed decode: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([][]float32, len(gResp.Embeddings))
|
||||
for i, e := range gResp.Embeddings {
|
||||
embeddings[i] = e.Values
|
||||
}
|
||||
|
||||
return Response{
|
||||
Embeddings: embeddings,
|
||||
Model: c.model,
|
||||
}, nil
|
||||
}
|
||||
87
pkg/embedclient/openai.go
Normal file
87
pkg/embedclient/openai.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package embedclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// openAIClient is a Client backed by the OpenAI /v1/embeddings HTTP API.
// Construct it with NewOpenAI.
type openAIClient struct {
	baseURL    string       // API root, e.g. "https://api.openai.com"
	apiKey     string       // sent as a Bearer token when non-empty
	httpClient *http.Client // never nil after NewOpenAI
}
|
||||
|
||||
// NewOpenAI returns a Client that speaks the OpenAI embeddings API.
|
||||
func NewOpenAI(baseURL, apiKey string, httpClient *http.Client) Client {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return &openAIClient{baseURL: baseURL, apiKey: apiKey, httpClient: httpClient}
|
||||
}
|
||||
|
||||
// openAIEmbedRequest is the JSON body for an OpenAI /v1/embeddings call.
type openAIEmbedRequest struct {
	Input []string `json:"input"`
	Model string   `json:"model"`
}
|
||||
|
||||
// openAIEmbedResponse mirrors the subset of the OpenAI embeddings JSON
// response this client reads. Data entries carry an explicit Index because
// the API does not guarantee they arrive in input order.
type openAIEmbedResponse struct {
	Object string `json:"object"`
	Data   []struct {
		Object    string    `json:"object"`
		Embedding []float32 `json:"embedding"`
		Index     int       `json:"index"` // position in the original input batch
	} `json:"data"`
	Model string `json:"model"`
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`
}
|
||||
|
||||
func (c *openAIClient) Embed(ctx context.Context, req Request) (Response, error) {
|
||||
body, err := json.Marshal(openAIEmbedRequest{Input: req.Texts, Model: req.Model})
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed marshal: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/embeddings", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed do: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return Response{}, fmt.Errorf("openai embed: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var oaiResp openAIEmbedResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
|
||||
return Response{}, fmt.Errorf("openai embed decode: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([][]float32, len(oaiResp.Data))
|
||||
for _, d := range oaiResp.Data {
|
||||
embeddings[d.Index] = d.Embedding
|
||||
}
|
||||
|
||||
return Response{
|
||||
Embeddings: embeddings,
|
||||
Model: oaiResp.Model,
|
||||
Usage: Usage{
|
||||
PromptTokens: oaiResp.Usage.PromptTokens,
|
||||
TotalTokens: oaiResp.Usage.TotalTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
180
pkg/embedclient/router.go
Normal file
180
pkg/embedclient/router.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package embedclient
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/Warky-Devs/vecna.git/pkg/metrics"
)
|
||||
|
||||
// RouterConfig holds tuning parameters for a TargetRouter.
type RouterConfig struct {
	// TargetName labels all metrics emitted by the router.
	TargetName string
	// TimeoutSecs is the per-request timeout applied in Embed, in seconds.
	TimeoutSecs int
	// CooldownSecs is how long pick skips an endpoint after it fails.
	CooldownSecs int
	// PriorityDecay is subtracted from an endpoint's priority on each
	// failure (floored at 1 in onFailure).
	PriorityDecay int
	// PriorityRecovery is the number of consecutive successes needed to
	// restore one priority point, up to the endpoint's initial priority.
	PriorityRecovery int
}
|
||||
|
||||
// endpointSlot tracks the live routing state for one backend endpoint.
type endpointSlot struct {
	client          Client
	url             string // metrics label for this endpoint
	initialPriority int    // ceiling that priority recovers toward

	mu           sync.Mutex // guards the mutable fields below
	priority     int        // current effective priority; decays on failure
	inflight     int        // requests currently in progress on this endpoint
	successCount int        // total successes; paces priority recovery
	lastFail     time.Time  // zero until the first failure; drives cooldown
}
|
||||
|
||||
// TargetRouter implements Client by routing requests across multiple endpoint slots
// using a busyness-based priority algorithm: pick scores each endpoint as
// priority minus inflight count, skipping endpoints in failure cooldown.
type TargetRouter struct {
	slots   []*endpointSlot
	cfg     RouterConfig
	metrics *metrics.Registry // optional; nil disables metric updates
}
|
||||
|
||||
// NewTargetRouter constructs a TargetRouter from a slice of (client, url, initialPriority) tuples.
|
||||
func NewTargetRouter(slots []RouterSlot, cfg RouterConfig, reg *metrics.Registry) (*TargetRouter, error) {
|
||||
if len(slots) == 0 {
|
||||
return nil, fmt.Errorf("NewTargetRouter: at least one slot required")
|
||||
}
|
||||
es := make([]*endpointSlot, len(slots))
|
||||
for i, s := range slots {
|
||||
es[i] = &endpointSlot{
|
||||
client: s.Client,
|
||||
url: s.URL,
|
||||
initialPriority: s.Priority,
|
||||
priority: s.Priority,
|
||||
}
|
||||
if reg != nil {
|
||||
reg.SetEndpointPriority(cfg.TargetName, s.URL, float64(s.Priority))
|
||||
reg.SetEndpointInflight(cfg.TargetName, s.URL, 0)
|
||||
}
|
||||
}
|
||||
return &TargetRouter{slots: es, cfg: cfg, metrics: reg}, nil
|
||||
}
|
||||
|
||||
// RouterSlot is a single endpoint entry for NewTargetRouter.
type RouterSlot struct {
	Client   Client // backend client for this endpoint
	URL      string // endpoint URL; used as the metrics label
	Priority int    // initial (and maximum) routing priority
}
|
||||
|
||||
func (r *TargetRouter) Embed(ctx context.Context, req Request) (Response, error) {
|
||||
slot := r.pick()
|
||||
|
||||
slot.mu.Lock()
|
||||
slot.inflight++
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
|
||||
}
|
||||
slot.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
slot.mu.Lock()
|
||||
slot.inflight--
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointInflight(r.cfg.TargetName, slot.url, float64(slot.inflight))
|
||||
}
|
||||
slot.mu.Unlock()
|
||||
}()
|
||||
|
||||
timeout := time.Duration(r.cfg.TimeoutSecs) * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := slot.client.Embed(ctx, req)
|
||||
if err != nil {
|
||||
r.onFailure(slot, err)
|
||||
return Response{}, fmt.Errorf("router embed [%s]: %w", slot.url, err)
|
||||
}
|
||||
|
||||
r.onSuccess(slot)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// pick selects the best available slot.
|
||||
func (r *TargetRouter) pick() *endpointSlot {
|
||||
cooldown := time.Duration(r.cfg.CooldownSecs) * time.Second
|
||||
now := time.Now()
|
||||
|
||||
var best *endpointSlot
|
||||
bestScore := -1 << 30
|
||||
|
||||
for _, s := range r.slots {
|
||||
s.mu.Lock()
|
||||
inCooldown := !s.lastFail.IsZero() && now.Sub(s.lastFail) < cooldown
|
||||
score := s.priority - s.inflight
|
||||
lastFail := s.lastFail
|
||||
s.mu.Unlock()
|
||||
|
||||
if inCooldown {
|
||||
continue
|
||||
}
|
||||
if best == nil || score > bestScore {
|
||||
best = s
|
||||
bestScore = score
|
||||
_ = lastFail
|
||||
}
|
||||
}
|
||||
|
||||
// All in cooldown — fall back to oldest failure
|
||||
if best == nil {
|
||||
var oldest time.Time
|
||||
for _, s := range r.slots {
|
||||
s.mu.Lock()
|
||||
lf := s.lastFail
|
||||
s.mu.Unlock()
|
||||
if best == nil || lf.Before(oldest) {
|
||||
best = s
|
||||
oldest = lf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
func (r *TargetRouter) onSuccess(s *endpointSlot) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.successCount++
|
||||
if r.cfg.PriorityRecovery > 0 && s.successCount%r.cfg.PriorityRecovery == 0 {
|
||||
if s.priority < s.initialPriority {
|
||||
s.priority++
|
||||
}
|
||||
}
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *TargetRouter) onFailure(s *endpointSlot, err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.lastFail = time.Now()
|
||||
s.priority -= r.cfg.PriorityDecay
|
||||
if s.priority < 1 {
|
||||
s.priority = 1
|
||||
}
|
||||
|
||||
errType := "error"
|
||||
if ctx := context.Background(); ctx.Err() != nil {
|
||||
errType = "timeout"
|
||||
}
|
||||
|
||||
if r.metrics != nil {
|
||||
r.metrics.SetEndpointPriority(r.cfg.TargetName, s.url, float64(s.priority))
|
||||
r.metrics.IncEndpointErrors(r.cfg.TargetName, s.url, errType)
|
||||
}
|
||||
|
||||
_ = err
|
||||
}
|
||||
Reference in New Issue
Block a user