Files
vecna/pkg/embedclient/openai.go
Hein 4009a54e39 feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768, 3072 <-> 2048
2026-04-11 18:05:05 +02:00

88 lines
2.3 KiB
Go

package embedclient
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
)
// openAIClient is the Client returned by NewOpenAI. It talks to an
// OpenAI-compatible embeddings endpoint over HTTP.
type openAIClient struct {
	baseURL, apiKey string       // endpoint root and optional bearer token
	httpClient      *http.Client // used to send every request
}
// NewOpenAI builds a Client for an OpenAI-compatible embeddings API rooted at
// baseURL. A non-empty apiKey is sent as a Bearer token. When httpClient is
// nil, http.DefaultClient is used instead.
func NewOpenAI(baseURL, apiKey string, httpClient *http.Client) Client {
	c := &openAIClient{baseURL: baseURL, apiKey: apiKey, httpClient: http.DefaultClient}
	if httpClient != nil {
		c.httpClient = httpClient
	}
	return c
}
// JSON wire formats exchanged with the /v1/embeddings endpoint.
type (
	// openAIEmbedRequest is the request body POSTed to /v1/embeddings.
	openAIEmbedRequest struct {
		Input []string `json:"input"` // texts to embed
		Model string   `json:"model"` // model to embed them with
	}

	// openAIEmbedResponse is the decoded /v1/embeddings reply. Each Data
	// item's Index gives the position of the input text its Embedding
	// belongs to.
	openAIEmbedResponse struct {
		Object string `json:"object"`
		Data   []struct {
			Object    string    `json:"object"`
			Embedding []float32 `json:"embedding"`
			Index     int       `json:"index"`
		} `json:"data"`
		Model string `json:"model"`
		Usage struct {
			PromptTokens int `json:"prompt_tokens"`
			TotalTokens  int `json:"total_tokens"`
		} `json:"usage"`
	}
)
// Embed sends req.Texts and req.Model to POST {baseURL}/v1/embeddings. It
// returns one embedding per input text, in the same order as req.Texts,
// together with the model name and token usage reported by the server.
//
// Each item is placed by its "index" field rather than by its position in the
// response. An error is returned in any of these cases:
//   - the request cannot be built or sent;
//   - the server replies with a non-200 status (the message quotes a bounded
//     prefix of the response body);
//   - the body cannot be decoded;
//   - the returned indices are not exactly 0..len(req.Texts)-1.
//
// Without the index check, a malformed index would panic or silently leave a
// nil vector.
func (c *openAIClient) Embed(ctx context.Context, req Request) (Response, error) {
	// Upper bound on how much of an error response body is quoted in the error.
	const maxErrBody = 4 << 10

	body, err := json.Marshal(openAIEmbedRequest{Input: req.Texts, Model: req.Model})
	if err != nil {
		return Response{}, fmt.Errorf("openai embed marshal: %w", err)
	}
	// Trim trailing slashes so a baseURL like "http://host/" does not yield
	// "http://host//v1/embeddings".
	endpoint := strings.TrimRight(c.baseURL, "/") + "/v1/embeddings"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return Response{}, fmt.Errorf("openai embed request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return Response{}, fmt.Errorf("openai embed do: %w", err)
	}
	// The close error is ignored: nothing actionable remains once we are done with the body.
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		// Best effort: a read error here only shortens the diagnostic snippet.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		return Response{}, fmt.Errorf("openai embed: unexpected status %d: %q", resp.StatusCode, bytes.TrimSpace(snippet))
	}
	var oaiResp openAIEmbedResponse
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
		return Response{}, fmt.Errorf("openai embed decode: %w", err)
	}

	// The server must return exactly one item per input. Validate every index
	// before using it so untrusted data cannot panic or leave gaps.
	n := len(oaiResp.Data)
	if n != len(req.Texts) {
		return Response{}, fmt.Errorf("openai embed: got %d embeddings for %d inputs", n, len(req.Texts))
	}
	embeddings := make([][]float32, n)
	seen := make([]bool, n)
	for _, d := range oaiResp.Data {
		if d.Index < 0 || d.Index >= n {
			return Response{}, fmt.Errorf("openai embed: embedding index %d out of range [0, %d)", d.Index, n)
		}
		if seen[d.Index] {
			return Response{}, fmt.Errorf("openai embed: duplicate embedding index %d", d.Index)
		}
		seen[d.Index] = true
		embeddings[d.Index] = d.Embedding
	}

	return Response{
		Embeddings: embeddings,
		Model:      oaiResp.Model,
		Usage: Usage{
			PromptTokens: oaiResp.Usage.PromptTokens,
			TotalTokens:  oaiResp.Usage.TotalTokens,
		},
	}, nil
}