mirror of
https://github.com/Warky-Devs/vecna.git
synced 2026-05-05 01:26:58 +00:00
164 lines
5.0 KiB
Go
164 lines
5.0 KiB
Go
package server
|
|
|
|
import (
	"crypto/subtle"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/uptrace/bunrouter"
	"go.uber.org/zap"

	"github.com/Warky-Devs/vecna.git/pkg/adapter"
	"github.com/Warky-Devs/vecna.git/pkg/config"
	"github.com/Warky-Devs/vecna.git/pkg/embedclient"
	"github.com/Warky-Devs/vecna.git/pkg/metrics"
	"github.com/Warky-Devs/vecna.git/pkg/server/spec"
)
|
|
|
|
// New builds and returns a configured bunrouter.Router.
|
|
func New(
|
|
cfg *config.Config,
|
|
clients map[string]embedclient.Client,
|
|
adp adapter.Adapter,
|
|
reg *metrics.Registry,
|
|
logger *zap.Logger,
|
|
) *bunrouter.Router {
|
|
router := bunrouter.New(
|
|
bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)),
|
|
bunrouter.WithMiddleware(traceMiddleware()),
|
|
bunrouter.WithMiddleware(metricsMiddleware(reg, adp)),
|
|
bunrouter.WithMiddleware(loggingMiddleware(logger)),
|
|
)
|
|
|
|
h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger}
|
|
|
|
router.POST("/v1/embeddings", h.openAIEmbeddings)
|
|
router.POST("/v1/models/:model:embedContent", h.googleEmbedContent)
|
|
router.POST("/v1/models/:model:batchEmbedContents", h.googleBatchEmbedContents)
|
|
|
|
// OpenAPI spec + docs
|
|
router.GET("/openapi.yaml", spec.SpecHandler())
|
|
router.GET("/docs", spec.DocsHandler())
|
|
|
|
// Metrics — only when enabled
|
|
if cfg.Metrics.Enabled {
|
|
metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{})
|
|
path := cfg.Metrics.Path
|
|
if path == "" {
|
|
path = "/metrics"
|
|
}
|
|
if cfg.Metrics.APIKey != "" {
|
|
router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler))
|
|
} else {
|
|
router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error {
|
|
metricsHandler.ServeHTTP(w, req.Request)
|
|
return nil
|
|
})
|
|
}
|
|
}
|
|
|
|
return router
|
|
}
|
|
|
|
// authMiddleware rejects requests without a valid Bearer token when api_keys is configured.
|
|
func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc {
|
|
if len(apiKeys) == 0 {
|
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
|
|
}
|
|
keySet := make(map[string]struct{}, len(apiKeys))
|
|
for _, k := range apiKeys {
|
|
keySet[k] = struct{}{}
|
|
}
|
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
|
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
|
if _, ok := keySet[token]; !ok {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return nil
|
|
}
|
|
return next(w, req)
|
|
}
|
|
}
|
|
}
|
|
|
|
// traceMiddleware injects a *RequestTrace into every request context.
|
|
func traceMiddleware() bunrouter.MiddlewareFunc {
|
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
|
ctx := WithTrace(req.Context())
|
|
return next(w, req.WithContext(ctx))
|
|
}
|
|
}
|
|
}
|
|
|
|
// metricsMiddleware records Prometheus observations after the handler returns.
|
|
func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc {
|
|
if reg == nil {
|
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
|
|
}
|
|
adpType := fmt.Sprintf("%T", adp)
|
|
return reg.Middleware(func(req bunrouter.Request) metrics.TraceSnapshot {
|
|
t := TraceFromContext(req.Context())
|
|
total := time.Since(t.Start)
|
|
return metrics.TraceSnapshot{
|
|
TotalSeconds: total.Seconds(),
|
|
ForwardSeconds: t.ForwardDuration.Seconds(),
|
|
TranslateSeconds: t.TranslateDuration.Seconds(),
|
|
ForwardTarget: t.ForwardTarget,
|
|
ForwardURL: t.ForwardURL,
|
|
ForwardModel: t.ForwardModel,
|
|
AdapterType: adpType,
|
|
PromptTokens: t.PromptTokens,
|
|
TotalTokens: t.TotalTokens,
|
|
}
|
|
})
|
|
}
|
|
|
|
// loggingMiddleware logs method, path, status, and timing via zap.
|
|
func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc {
|
|
return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
|
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
|
err := next(sw, req)
|
|
t := TraceFromContext(req.Context())
|
|
total := time.Since(t.Start)
|
|
|
|
logger.Info("request",
|
|
zap.String("method", req.Method),
|
|
zap.String("path", req.URL.Path),
|
|
zap.Int("status", sw.status),
|
|
zap.Int64("total_ms", total.Milliseconds()),
|
|
zap.Int64("forward_ms", t.ForwardDuration.Milliseconds()),
|
|
zap.Int64("translate_ms", t.TranslateDuration.Milliseconds()),
|
|
)
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// metricsAuthHandler wraps a standard http.Handler with Bearer token auth.
|
|
func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc {
|
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
|
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
|
if token != apiKey {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return nil
|
|
}
|
|
h.ServeHTTP(w, req.Request)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// statusWriter wraps an http.ResponseWriter and remembers the response status
// that was actually committed, so it can be logged after the handler returns.
//
// Creators should initialise status to http.StatusOK. The first final status
// is recorded and later changes are ignored. A status is final when it comes
// from a non-1xx WriteHeader call, or from a Write before any header, which
// records 200 (the implicit status net/http sends). Later WriteHeader calls are
// still forwarded to the underlying writer, but they do not change the
// recorded status.
type statusWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // set once the final status has been committed
}

// WriteHeader records code if it is the first final status, then forwards it.
// Informational 1xx codes other than 101 Switching Protocols do not commit the
// response, so they are forwarded without being recorded.
func (sw *statusWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !sw.wroteHeader && !informational {
		sw.status = code
		sw.wroteHeader = true
	}
	sw.ResponseWriter.WriteHeader(code)
}

// Write commits an implicit 200 status if no final header was written yet,
// then forwards b to the underlying writer.
func (sw *statusWriter) Write(b []byte) (int, error) {
	if !sw.wroteHeader {
		sw.status = http.StatusOK
		sw.wroteHeader = true
	}
	return sw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer. This lets http.NewResponseController reach
// optional interfaces (e.g. http.Flusher) that this wrapper does not expose.
func (sw *statusWriter) Unwrap() http.ResponseWriter { return sw.ResponseWriter }
|