Files
vecna/pkg/server/server.go

199 lines
6.6 KiB
Go

package server
import (
	"crypto/subtle"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/uptrace/bunrouter"
	"go.uber.org/zap"

	"github.com/Warky-Devs/vecna.git/pkg/adapter"
	"github.com/Warky-Devs/vecna.git/pkg/config"
	"github.com/Warky-Devs/vecna.git/pkg/embedclient"
	"github.com/Warky-Devs/vecna.git/pkg/metrics"
	"github.com/Warky-Devs/vecna.git/pkg/server/spec"
)
// New builds and returns a configured bunrouter.Router.
//
// Global middleware is installed in the order trace, metrics, logging.
// The documentation routes ("/", "/docs", "/openapi.yaml") are public; the
// embedding routes sit behind authMiddleware(cfg.Server.APIKeys). When
// cfg.Metrics.Enabled is true, the Prometheus endpoint (cfg.Metrics.Path,
// defaulting to "/metrics") and "/dashboard" are registered as well.
//
// Returns a nil router and a non-nil error if route registration panics
// (e.g. conflicting routes).
func New(
	cfg *config.Config,
	clients map[string]embedclient.Client,
	adp adapter.Adapter,
	extraMaps map[string]ExtraMap,
	reg *metrics.Registry,
	logger *zap.Logger,
) (router *bunrouter.Router, err error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error("panic during router setup", zap.Any("recover", r))
			// Never hand a half-configured router back together with an error.
			router = nil
			err = fmt.Errorf("router setup panic: %v", r)
		}
	}()

	router = bunrouter.New(
		bunrouter.WithMiddleware(traceMiddleware()),
		bunrouter.WithMiddleware(metricsMiddleware(reg, adp)),
		bunrouter.WithMiddleware(loggingMiddleware(logger)),
	)

	h := &handler{cfg: cfg, clients: clients, adapter: adp, extraMaps: extraMaps, logger: logger}

	// Public routes — no authentication required.
	router.GET("/", spec.DocsHandler())
	router.GET("/docs", spec.DocsHandler())
	router.GET("/openapi.yaml", spec.SpecHandler())

	// All API routes require authentication.
	authed := router.NewGroup("", bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)))
	authed.POST("/v1/embeddings", h.openAIEmbeddings)

	// Google API uses a literal colon as a method separator (e.g. /v1/models/foo:embedContent).
	// bunrouter can't distinguish two routes with the same :param prefix, so a single wildcard
	// captures the full "model:action" segment and dispatches internally.
	authed.POST("/v1/models/*modelaction", h.googleDispatch)

	// Extra-map routes: /map/:mapping/v1/... uses the adapter configured under extra_maps[mapping].
	authed.POST("/map/:mapping/v1/embeddings", h.openAIEmbeddingsMapped)
	authed.POST("/map/:mapping/v1/models/*modelaction", h.googleDispatchMapped)

	// Metrics — only when enabled.
	// If metrics.api_key is set, routes are guarded by that key alone.
	// If metrics.api_key is blank, routes are public — no auth headers checked at all.
	if cfg.Metrics.Enabled {
		metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{})
		path := cfg.Metrics.Path
		if path == "" {
			path = "/metrics"
		}
		dash := dashboardHandler(reg)
		if cfg.Metrics.APIKey != "" {
			// Registered on the root router, not the authed group: both guards read
			// the same Authorization header, so stacking them would make these routes
			// unreachable unless the metrics key also appeared in server api_keys.
			router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler))
			router.GET("/dashboard", metricsKeyMiddleware(cfg.Metrics.APIKey, dash))
		} else {
			router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error {
				metricsHandler.ServeHTTP(w, req.Request)
				return nil
			})
			router.GET("/dashboard", dash)
		}
	}

	return router, nil
}
// authMiddleware rejects requests without a valid Bearer token when api_keys is configured.
//
// With an empty apiKeys slice the returned middleware is a pass-through.
// Otherwise the Authorization header, with a leading "Bearer " stripped when
// present, must equal one of the configured keys; any other request receives
// 401 "unauthorized" and the next handler is not called.
//
// Blank entries in apiKeys are ignored, so a request that carries no token can
// never authenticate. If every entry is blank, all requests are rejected.
func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc {
	if len(apiKeys) == 0 {
		return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
	}

	keys := make([][]byte, 0, len(apiKeys))
	for _, k := range apiKeys {
		if k == "" {
			// A blank key would match the empty token of a header-less request.
			continue
		}
		keys = append(keys, []byte(k))
	}

	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			token := []byte(strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer "))

			// Compare against every key without short-circuiting so the time taken
			// does not reveal which key (if any) shares a prefix with the token.
			matched := 0
			for _, k := range keys {
				matched |= subtle.ConstantTimeCompare(token, k)
			}
			if matched != 1 {
				http.Error(w, "unauthorized", http.StatusUnauthorized)
				return nil
			}
			return next(w, req)
		}
	}
}
// traceMiddleware returns middleware that wraps each request's context with
// WithTrace (injecting a *RequestTrace) before calling the next handler.
func traceMiddleware() bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		traced := func(w http.ResponseWriter, req bunrouter.Request) error {
			return next(w, req.WithContext(WithTrace(req.Context())))
		}
		return traced
	}
}
// metricsMiddleware builds the registry's middleware, supplying it with a
// callback that converts the request's trace into a metrics.TraceSnapshot.
// With a nil registry the returned middleware is a pass-through.
func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc {
	if reg == nil {
		passthrough := func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next }
		return passthrough
	}

	// The adapter's concrete type name is fixed for the router's lifetime,
	// so it is formatted once here rather than per request.
	adapterName := fmt.Sprintf("%T", adp)

	snapshot := func(req bunrouter.Request) metrics.TraceSnapshot {
		trace := TraceFromContext(req.Context())
		elapsed := time.Since(trace.Start)

		var snap metrics.TraceSnapshot
		snap.TotalSeconds = elapsed.Seconds()
		snap.ForwardSeconds = trace.ForwardDuration.Seconds()
		snap.TranslateSeconds = trace.TranslateDuration.Seconds()
		snap.ForwardTarget = trace.ForwardTarget
		snap.ForwardURL = trace.ForwardURL
		snap.ForwardModel = trace.ForwardModel
		snap.AdapterType = adapterName
		snap.PromptTokens = trace.PromptTokens
		snap.TotalTokens = trace.TotalTokens
		return snap
	}

	return reg.Middleware(snapshot)
}
// loggingMiddleware emits one "request" info entry per request via zap,
// carrying method, path, response status and the total/forward/translate
// timings in milliseconds. The handler's error is returned unchanged.
func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc {
	return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
		return func(w http.ResponseWriter, req bunrouter.Request) error {
			// Wrap the writer so the status written by the handler can be logged;
			// 200 is assumed until the handler says otherwise.
			recorder := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			handlerErr := next(recorder, req)

			trace := TraceFromContext(req.Context())
			elapsed := time.Since(trace.Start)

			fields := []zap.Field{
				zap.String("method", req.Method),
				zap.String("path", req.URL.Path),
				zap.Int("status", recorder.status),
				zap.Int64("total_ms", elapsed.Milliseconds()),
				zap.Int64("forward_ms", trace.ForwardDuration.Milliseconds()),
				zap.Int64("translate_ms", trace.TranslateDuration.Milliseconds()),
			}
			logger.Info("request", fields...)

			return handlerErr
		}
	}
}
// metricsKeyMiddleware guards a bunrouter.HandlerFunc with a dedicated Bearer token.
//
// The Authorization header, with a leading "Bearer " stripped when present,
// must equal apiKey; otherwise the response is 401 "unauthorized" and h is not
// called. The comparison is constant-time. An empty apiKey rejects every
// request rather than admitting requests that carry no token.
func metricsKeyMiddleware(apiKey string, h bunrouter.HandlerFunc) bunrouter.HandlerFunc {
	want := []byte(apiKey)
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
		if len(want) == 0 || subtle.ConstantTimeCompare([]byte(token), want) != 1 {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return nil
		}
		return h(w, req)
	}
}
// metricsAuthHandler wraps a standard http.Handler with Bearer token auth.
//
// The Authorization header, with a leading "Bearer " stripped when present,
// must equal apiKey; otherwise the response is 401 "unauthorized" and h is not
// invoked. The comparison is constant-time. An empty apiKey rejects every
// request rather than admitting requests that carry no token.
func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc {
	want := []byte(apiKey)
	return func(w http.ResponseWriter, req bunrouter.Request) error {
		token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
		if len(want) == 0 || subtle.ConstantTimeCompare([]byte(token), want) != 1 {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return nil
		}
		h.ServeHTTP(w, req.Request)
		return nil
	}
}
// statusWriter captures the HTTP status code written by a handler.
//
// It records the first final (non-informational) status and ignores later
// ones. This mirrors net/http, which sends only the first final header, so
// the recorded value matches what the client actually received.
type statusWriter struct {
	http.ResponseWriter
	status      int  // status of the response; callers typically seed it with http.StatusOK
	wroteHeader bool // true once a final status has been written, explicitly or via Write
}

// WriteHeader records code if it is the first final status for this response
// and always forwards the call to the wrapped writer.
func (sw *statusWriter) WriteHeader(code int) {
	// 1xx responses other than 101 are informational and may precede the final status.
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !informational && !sw.wroteHeader {
		sw.status = code
		sw.wroteHeader = true
	}
	sw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. A Write with no prior WriteHeader
// commits the response, so the status is frozen here (defaulting to 200 when
// none was seeded) and a later WriteHeader cannot change the recorded value.
func (sw *statusWriter) Write(b []byte) (int, error) {
	if !sw.wroteHeader {
		if sw.status == 0 {
			sw.status = http.StatusOK
		}
		sw.wroteHeader = true
	}
	return sw.ResponseWriter.Write(b)
}

// Unwrap exposes the wrapped writer so http.ResponseController can reach
// optional interfaces (flush, hijack, deadlines) that this wrapper hides.
func (sw *statusWriter) Unwrap() http.ResponseWriter {
	return sw.ResponseWriter
}