Files
vecna/pkg/server/server.go
Hein 4009a54e39 feat: 🎉 Vectors na Vectors, the beginning
Translate 1536 <-> 768 , 3072 <-> 2048
2026-04-11 18:05:05 +02:00

164 lines
5.0 KiB
Go

package server
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/uptrace/bunrouter"
"go.uber.org/zap"
"github.com/Warky-Devs/vecna.git/pkg/adapter"
"github.com/Warky-Devs/vecna.git/pkg/config"
"github.com/Warky-Devs/vecna.git/pkg/embedclient"
"github.com/Warky-Devs/vecna.git/pkg/metrics"
"github.com/Warky-Devs/vecna.git/pkg/server/spec"
)
// New builds and returns a configured bunrouter.Router.
//
// The router is created with the auth, trace, metrics and logging middleware
// (in that order) and registers:
//   - POST /v1/embeddings and the two /v1/models/... embedding routes,
//   - GET /openapi.yaml and GET /docs,
//   - GET on the metrics path, only when cfg.Metrics.Enabled is set and a
//     metrics registry was supplied.
//
// cfg, adp and logger are handed to the request handler as-is; reg may be nil,
// in which case no metrics are recorded and no metrics route is exposed.
func New(
	cfg *config.Config,
	clients map[string]embedclient.Client,
	adp adapter.Adapter,
	reg *metrics.Registry,
	logger *zap.Logger,
) *bunrouter.Router {
	// NOTE(review): authMiddleware is installed on the router itself, so when
	// api_keys is configured it presumably also guards /openapi.yaml, /docs and
	// the metrics route (on top of the metrics key) — confirm this is intended.
	router := bunrouter.New(
		bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)),
		bunrouter.WithMiddleware(traceMiddleware()),
		bunrouter.WithMiddleware(metricsMiddleware(reg, adp)),
		bunrouter.WithMiddleware(loggingMiddleware(logger)),
	)

	h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger}

	router.POST("/v1/embeddings", h.openAIEmbeddings)
	// NOTE(review): both Google-style routes have the same path shape with a
	// ":model:<verb>" segment — verify the router matches them as two distinct
	// routes rather than one parameter covering the whole segment.
	router.POST("/v1/models/:model:embedContent", h.googleEmbedContent)
	router.POST("/v1/models/:model:batchEmbedContents", h.googleBatchEmbedContents)

	// OpenAPI spec + docs
	router.GET("/openapi.yaml", spec.SpecHandler())
	router.GET("/docs", spec.DocsHandler())

	// Metrics — only when enabled. metricsMiddleware accepts a nil registry, so
	// guard here as well instead of calling reg.Prometheus() on a nil receiver.
	if cfg.Metrics.Enabled && reg != nil {
		metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{})

		path := cfg.Metrics.Path
		if path == "" {
			path = "/metrics"
		}
		// Tolerate a configured path that omits the leading slash.
		if !strings.HasPrefix(path, "/") {
			path = "/" + path
		}

		if cfg.Metrics.APIKey != "" {
			router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler))
		} else {
			router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error {
				metricsHandler.ServeHTTP(w, req.Request)
				return nil
			})
		}
	}

	return router
}
// authMiddleware rejects requests that do not present one of apiKeys as an
// "Authorization: Bearer <key>" header.
//
// When apiKeys is empty the returned middleware passes every request through
// unchanged. Otherwise a request is forwarded to next only if the header uses
// the Bearer scheme (matched case-insensitively) and the token equals one of
// the non-empty configured keys; everything else gets a 401 response with a
// WWW-Authenticate challenge and the handler chain is not invoked.
func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc {
	if len(apiKeys) == 0 {
		return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
	}

	keySet := make(map[string]struct{}, len(apiKeys))
	for _, k := range apiKeys {
		// An empty key would authorize requests that send no credentials.
		if k == "" {
			continue
		}
		keySet[k] = struct{}{}
	}

	const scheme = "Bearer "

	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			header := req.Header.Get("Authorization")

			// Require the scheme explicitly: a bare "Authorization: <key>"
			// header is not a Bearer credential.
			authorized := false
			if len(header) > len(scheme) && strings.EqualFold(header[:len(scheme)], scheme) {
				_, authorized = keySet[header[len(scheme):]]
			}

			if !authorized {
				w.Header().Set("WWW-Authenticate", "Bearer")
				http.Error(w, "unauthorized", http.StatusUnauthorized)
				return nil
			}
			return next(w, req)
		}
	}
}
// traceMiddleware returns a middleware that derives a new context from each
// incoming request via WithTrace and hands the request, rebound to that
// context, to the next handler.
func traceMiddleware() bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		withTrace := func(w http.ResponseWriter, req bunrouter.Request) error {
			traced := req.WithContext(WithTrace(req.Context()))
			return next(w, traced)
		}
		return withTrace
	}
}
// metricsMiddleware wires the request trace into the metrics registry.
//
// With a nil registry the returned middleware is a pass-through. Otherwise it
// delegates to reg.Middleware, supplying a callback that reads the trace from
// the request context and converts it into a metrics.TraceSnapshot. The
// adapter's Go type name is formatted once up front and reused for every
// snapshot.
func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc {
	if reg == nil {
		passthrough := func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
		return passthrough
	}

	adapterName := fmt.Sprintf("%T", adp)

	// snapshot copies the per-request trace into the shape the registry expects.
	snapshot := func(req bunrouter.Request) metrics.TraceSnapshot {
		trace := TraceFromContext(req.Context())
		elapsed := time.Since(trace.Start)

		var snap metrics.TraceSnapshot
		snap.TotalSeconds = elapsed.Seconds()
		snap.ForwardSeconds = trace.ForwardDuration.Seconds()
		snap.TranslateSeconds = trace.TranslateDuration.Seconds()
		snap.ForwardTarget = trace.ForwardTarget
		snap.ForwardURL = trace.ForwardURL
		snap.ForwardModel = trace.ForwardModel
		snap.AdapterType = adapterName
		snap.PromptTokens = trace.PromptTokens
		snap.TotalTokens = trace.TotalTokens
		return snap
	}

	return reg.Middleware(snapshot)
}
// loggingMiddleware logs one "request" entry per request via zap after the
// wrapped handler returns.
//
// The entry carries the method, URL path, the status code captured by a
// statusWriter (200 unless the handler wrote another one), and the total,
// forward and translate durations taken from the request trace. When the
// handler returns an error it is attached as an "error" field, so a failure
// that happens before any status is written is still visible in the log. The
// handler's error is returned to the caller unchanged.
func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			err := next(sw, req)

			t := TraceFromContext(req.Context())
			total := time.Since(t.Start)

			fields := []zap.Field{
				zap.String("method", req.Method),
				zap.String("path", req.URL.Path),
				zap.Int("status", sw.status),
				zap.Int64("total_ms", total.Milliseconds()),
				zap.Int64("forward_ms", t.ForwardDuration.Milliseconds()),
				zap.Int64("translate_ms", t.TranslateDuration.Milliseconds()),
			}
			if err != nil {
				fields = append(fields, zap.Error(err))
			}
			logger.Info("request", fields...)

			return err
		}
	}
}
// metricsAuthHandler wraps a standard http.Handler with Bearer token auth.
//
// The request is served by h only when the Authorization header uses the
// Bearer scheme (matched case-insensitively) and carries a non-empty token
// equal to apiKey. Any other request receives a 401 response with a
// WWW-Authenticate challenge. The returned handler always reports a nil error.
func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc {
	const scheme = "Bearer "

	return func(w http.ResponseWriter, req bunrouter.Request) error {
		header := req.Header.Get("Authorization")

		// Require the scheme explicitly; the length check also guarantees the
		// token is non-empty, so an empty apiKey can never match.
		hasScheme := len(header) > len(scheme) && strings.EqualFold(header[:len(scheme)], scheme)
		if !hasScheme || header[len(scheme):] != apiKey {
			w.Header().Set("WWW-Authenticate", "Bearer")
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return nil
		}

		h.ServeHTTP(w, req.Request)
		return nil
	}
}
// statusWriter captures the HTTP status code written by a handler.
//
// status holds the code of the response that was actually sent: the first
// final (non-1xx) code passed to WriteHeader, or 200 when the handler starts
// writing the body without calling WriteHeader. Callers may pre-set status to
// the value they want reported when the handler writes nothing at all.
type statusWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // true once a final status has been committed
}

// WriteHeader records code and forwards the call to the wrapped writer.
//
// Only the first final status is recorded; later calls are still forwarded,
// but they do not overwrite the captured value. 1xx codes are informational,
// so they are recorded provisionally and replaced by the final status.
func (sw *statusWriter) WriteHeader(code int) {
	if !sw.wroteHeader {
		sw.status = code
		sw.wroteHeader = code < 100 || code >= 200
	}
	sw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. A body write without a preceding
// final WriteHeader implies status 200, which is recorded here.
func (sw *statusWriter) Write(b []byte) (int, error) {
	if !sw.wroteHeader {
		sw.status = http.StatusOK
		sw.wroteHeader = true
	}
	return sw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter so that http.ResponseController
// can reach optional interfaces (flush, hijack, deadlines) through the wrapper.
func (sw *statusWriter) Unwrap() http.ResponseWriter {
	return sw.ResponseWriter
}