package server import ( "fmt" "net/http" "strings" "time" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/uptrace/bunrouter" "go.uber.org/zap" "github.com/Warky-Devs/vecna.git/pkg/adapter" "github.com/Warky-Devs/vecna.git/pkg/config" "github.com/Warky-Devs/vecna.git/pkg/embedclient" "github.com/Warky-Devs/vecna.git/pkg/metrics" "github.com/Warky-Devs/vecna.git/pkg/server/spec" ) // New builds and returns a configured bunrouter.Router. func New( cfg *config.Config, clients map[string]embedclient.Client, adp adapter.Adapter, reg *metrics.Registry, logger *zap.Logger, ) *bunrouter.Router { router := bunrouter.New( bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)), bunrouter.WithMiddleware(traceMiddleware()), bunrouter.WithMiddleware(metricsMiddleware(reg, adp)), bunrouter.WithMiddleware(loggingMiddleware(logger)), ) h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger} router.POST("/v1/embeddings", h.openAIEmbeddings) router.POST("/v1/models/:model:embedContent", h.googleEmbedContent) router.POST("/v1/models/:model:batchEmbedContents", h.googleBatchEmbedContents) // OpenAPI spec + docs router.GET("/openapi.yaml", spec.SpecHandler()) router.GET("/docs", spec.DocsHandler()) // Metrics — only when enabled if cfg.Metrics.Enabled { metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{}) path := cfg.Metrics.Path if path == "" { path = "/metrics" } if cfg.Metrics.APIKey != "" { router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler)) } else { router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error { metricsHandler.ServeHTTP(w, req.Request) return nil }) } } return router } // authMiddleware rejects requests without a valid Bearer token when api_keys is configured. func authMiddleware(apiKeys []string) bunrouter.MiddlewareFunc { if len(apiKeys) == 0 { return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next } } keySet := make(map[string]struct{}, len(apiKeys)) for _, k := range apiKeys { keySet[k] = struct{}{} } return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return func(w http.ResponseWriter, req bunrouter.Request) error { token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") if _, ok := keySet[token]; !ok { http.Error(w, "unauthorized", http.StatusUnauthorized) return nil } return next(w, req) } } } // traceMiddleware injects a *RequestTrace into every request context. func traceMiddleware() bunrouter.MiddlewareFunc { return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return func(w http.ResponseWriter, req bunrouter.Request) error { ctx := WithTrace(req.Context()) return next(w, req.WithContext(ctx)) } } } // metricsMiddleware records Prometheus observations after the handler returns. func metricsMiddleware(reg *metrics.Registry, adp adapter.Adapter) bunrouter.MiddlewareFunc { if reg == nil { return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return next } } adpType := fmt.Sprintf("%T", adp) return reg.Middleware(func(req bunrouter.Request) metrics.TraceSnapshot { t := TraceFromContext(req.Context()) total := time.Since(t.Start) return metrics.TraceSnapshot{ TotalSeconds: total.Seconds(), ForwardSeconds: t.ForwardDuration.Seconds(), TranslateSeconds: t.TranslateDuration.Seconds(), ForwardTarget: t.ForwardTarget, ForwardURL: t.ForwardURL, ForwardModel: t.ForwardModel, AdapterType: adpType, PromptTokens: t.PromptTokens, TotalTokens: t.TotalTokens, } }) } // loggingMiddleware logs method, path, status, and timing via zap. func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc { return func(next bunrouter.HandlerFunc) bunrouter.HandlerFunc { return func(w http.ResponseWriter, req bunrouter.Request) error { sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} err := next(sw, req) t := TraceFromContext(req.Context()) total := time.Since(t.Start) logger.Info("request", zap.String("method", req.Method), zap.String("path", req.URL.Path), zap.Int("status", sw.status), zap.Int64("total_ms", total.Milliseconds()), zap.Int64("forward_ms", t.ForwardDuration.Milliseconds()), zap.Int64("translate_ms", t.TranslateDuration.Milliseconds()), ) return err } } } // metricsAuthHandler wraps a standard http.Handler with Bearer token auth. func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc { return func(w http.ResponseWriter, req bunrouter.Request) error { token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") if token != apiKey { http.Error(w, "unauthorized", http.StatusUnauthorized) return nil } h.ServeHTTP(w, req.Request) return nil } } // statusWriter captures the HTTP status code written by a handler. type statusWriter struct { http.ResponseWriter status int } func (sw *statusWriter) WriteHeader(code int) { sw.status = code sw.ResponseWriter.WriteHeader(code) }