Files
vecna/pkg/server/google.go
Hein 4009a54e39 feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768, 3072 <-> 2048
2026-04-11 18:05:05 +02:00

136 lines
3.7 KiB
Go

package server
import (
	"encoding/json"
	"net/http"
	"strings"
	"time"

	"github.com/Warky-Devs/vecna.git/pkg/embedclient"
	"github.com/uptrace/bunrouter"
)
// --- single embedContent ---

// Wire types for Google's embedContent endpoint. Only the fields this proxy
// uses are modeled.
type (
	// googleEmbedContentRequest is the body of an embedContent call: the
	// content to embed plus an optional task type, which is omitted from the
	// encoded JSON when empty.
	googleEmbedContentRequest struct {
		Content  googleContent `json:"content"`
		TaskType string        `json:"taskType,omitempty"`
	}

	// googleContent carries the text parts of one content item.
	googleContent struct {
		Parts []googlePart `json:"parts"`
	}

	// googlePart is a single text fragment of a content item.
	googlePart struct {
		Text string `json:"text"`
	}

	// googleEmbedContentResponse wraps the one embedding returned for an
	// embedContent call.
	googleEmbedContentResponse struct {
		Embedding googleEmbeddingValues `json:"embedding"`
	}

	// googleEmbeddingValues holds one embedding vector.
	googleEmbeddingValues struct {
		Values []float32 `json:"values"`
	}
)
// googleEmbedContent serves Google's embedContent call for the model named by
// the "model" route parameter.
//
// The text parts of body.content are merged into a single input, because one
// embedContent call yields exactly one embedding. That input is forwarded to
// the client picked by h.resolveClient. The returned vector is passed through
// h.adapter.Adapt, and the result is written as {"embedding":{"values":[...]}}.
// Forward target, timings and token usage are recorded on the request trace,
// which is passed to writeTraceHeaders on success.
//
// Error responses ({"error": "..."}):
//   - 400 when the body cannot be decoded or content.parts is empty;
//   - 502 when the upstream call fails or returns no embedding;
//   - 500 when the vector cannot be adapted.
func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")

	var body googleEmbedContentRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	if len(body.Content.Parts) == 0 {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "content.parts must not be empty"})
	}

	// Sending each part separately and keeping only the first vector would
	// silently drop the rest of the content while still embedding it
	// upstream, so the parts are merged into one text.
	// NOTE(review): a newline separator between parts is assumed here;
	// confirm if exact parity with Google's merging matters.
	parts := make([]string, len(body.Content.Parts))
	for i, p := range body.Content.Parts {
		parts[i] = p.Text
	}
	texts := []string{strings.Join(parts, "\n")}

	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL

	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens

	// An empty upstream result is an upstream failure, not a successful
	// response with "values": null.
	if len(embedResp.Embeddings) == 0 {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": "upstream returned no embeddings"})
	}

	t1 := time.Now()
	adapted, err := h.adapter.Adapt(embedResp.Embeddings[0])
	trace.TranslateDuration = time.Since(t1)
	if err != nil {
		return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
	}

	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, googleEmbedContentResponse{
		Embedding: googleEmbeddingValues{Values: adapted},
	})
}
// --- batch batchEmbedContents ---

// Wire types for Google's batchEmbedContents endpoint.
type (
	// googleBatchRequest lists the embedContent-style requests of one batch
	// call.
	googleBatchRequest struct {
		Requests []googleEmbedContentRequest `json:"requests"`
	}

	// googleBatchResponse holds the embeddings produced for a batch call.
	googleBatchResponse struct {
		Embeddings []googleEmbeddingValues `json:"embeddings"`
	}
)
// googleBatchEmbedContents serves Google's batchEmbedContents call for the
// model named by the "model" route parameter.
//
// Each request's text parts are merged into one input. The upstream call
// therefore receives exactly len(requests) texts, and embeddings[i] in the
// response corresponds to requests[i]. Every returned vector is passed
// through h.adapter.Adapt. Forward target, timings and token usage are
// recorded on the request trace, which is passed to writeTraceHeaders on
// success.
//
// Error responses ({"error": "..."}):
//   - 400 when the body cannot be decoded, requests is empty, or a request has
//     no content parts;
//   - 502 when the upstream call fails or returns a different number of
//     embeddings than inputs;
//   - 500 when a vector cannot be adapted.
func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter.Request) error {
	model := req.Param("model")

	var body googleBatchRequest
	if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"})
	}
	if len(body.Requests) == 0 {
		return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "requests must not be empty"})
	}

	// Exactly one text per request keeps the output aligned with the input.
	// Flattening all parts into one list would shift every later embedding
	// whenever a request had zero or several parts.
	texts := make([]string, len(body.Requests))
	for i, r := range body.Requests {
		if len(r.Content.Parts) == 0 {
			return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "every request must have at least one content part"})
		}
		parts := make([]string, len(r.Content.Parts))
		for j, p := range r.Content.Parts {
			parts[j] = p.Text
		}
		// NOTE(review): a newline separator between parts is assumed here;
		// confirm if exact parity with Google's merging matters.
		texts[i] = strings.Join(parts, "\n")
	}

	client, targetName, targetURL := h.resolveClient(model)
	trace := TraceFromContext(req.Context())
	trace.ForwardTarget = targetName
	trace.ForwardURL = targetURL

	t0 := time.Now()
	embedResp, err := client.Embed(req.Context(), embedclient.Request{Texts: texts, Model: model})
	trace.ForwardDuration = time.Since(t0)
	if err != nil {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
	}
	trace.ForwardModel = embedResp.Model
	trace.PromptTokens = embedResp.Usage.PromptTokens
	trace.TotalTokens = embedResp.Usage.TotalTokens

	// A short or long upstream result cannot be mapped back onto the requests.
	if len(embedResp.Embeddings) != len(texts) {
		return writeJSON(w, http.StatusBadGateway, map[string]string{"error": "upstream returned an unexpected number of embeddings"})
	}

	t1 := time.Now()
	result := make([]googleEmbeddingValues, len(embedResp.Embeddings))
	for i, vec := range embedResp.Embeddings {
		adapted, adaptErr := h.adapter.Adapt(vec)
		if adaptErr != nil {
			return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()})
		}
		result[i] = googleEmbeddingValues{Values: adapted}
	}
	trace.TranslateDuration = time.Since(t1)

	writeTraceHeaders(w, trace)
	return writeJSON(w, http.StatusOK, googleBatchResponse{Embeddings: result})
}