From c7a3fed6e162908413bb4c92f57fec2160b6c787 Mon Sep 17 00:00:00 2001 From: Hein Date: Sat, 11 Apr 2026 21:43:14 +0200 Subject: [PATCH] feat(server): add support for extra maps in adapter configuration * Introduced ExtraMapConfig to allow multiple adapter configurations. * Updated server and handler to utilize extra maps for routing. * Added dashboard handler for metrics visualization. --- README.md | 70 ++++++++++++- cmd/vecna/convert.go | 54 +++++++++- cmd/vecna/serve.go | 7 +- go.mod | 2 +- pkg/config/config.go | 23 ++++- pkg/server/dashboard.go | 214 ++++++++++++++++++++++++++++++++++++++++ pkg/server/google.go | 29 ++++-- pkg/server/handler.go | 38 ++++++- pkg/server/openai.go | 17 +++- pkg/server/server.go | 44 +++++++-- 10 files changed, 461 insertions(+), 37 deletions(-) create mode 100644 pkg/server/dashboard.go diff --git a/README.md b/README.md index 0eb1063..338db56 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,11 @@ Override with `--config path/to/file.yaml` or env vars prefixed `VECNA_`. "truncate_mode": "from_end", "pad_mode": "at_end" }, + "extra_maps": { + "512": { "target_dim": 512 }, + "256": { "target_dim": 256, "type": "random", "seed": 42 }, + "fast": { "target_dim": 768, "forward_target": "small-model" } + }, "metrics": { "enabled": true, "path": "/metrics", @@ -184,6 +189,43 @@ There is no partial migration path — a mixed index produces degraded or incorr --- +## Extra maps + +`extra_maps` lets you expose multiple adapter configurations on a single vecna instance. Each entry is a named `ExtraMapConfig` — an optional `forward_target` plus adapter fields — whose unset adapter fields fall back to the global `adapter` values, except `target_dim`, which is always taken from the entry itself. 
+ +```json +"adapter": { "type": "truncate", "source_dim": 1024, "target_dim": 1536 }, +"extra_maps": { + "512": { "target_dim": 512 }, + "256": { "target_dim": 256, "type": "random", "seed": 42 }, + "openai-alt": { "target_dim": 1536, "forward_target": "openai" } +} +``` + +| Route | Forwarder | Adapter | +|-------|-----------|---------| +| `POST /v1/embeddings` | global default | global `adapter` | +| `POST /map/512/v1/embeddings` | global default | `extra_maps["512"]` — target 512, rest from global | +| `POST /map/256/v1/embeddings` | global default | `extra_maps["256"]` — random projection to 256 | +| `POST /map/openai-alt/v1/embeddings` | `openai` target | `extra_maps["openai-alt"]` adapter | + +All fields are overridable per map entry: + +| Field | Description | +|-------|-------------| +| `forward_target` | Named target from `forward.targets`; empty = resolved the same way as on the standard routes (by request model, falling back to `forward.default`) | +| `type` | `truncate` / `random` / `projection` | +| `source_dim` | Source dimension; falls back to global `adapter.source_dim` | +| `target_dim` | Target dimension; never inherited from the global adapter, so set it on every entry | +| `truncate_mode` | `from_end` / `from_start` | +| `pad_mode` | `at_end` / `at_start` | +| `seed` | Seed for random projection; `0` counts as unset and inherits the global seed | +| `matrix_file` | Path to projection matrix JSON | + +> The same re-embedding warning applies per map — changing any setting for an `extra_maps` entry requires re-embedding all vectors indexed through that endpoint. + +--- + ## Truncation and padding modes ### `truncate_mode` — which part of the vector is kept when downscaling @@ -242,6 +284,18 @@ POST /v1/models/{model}:embedContent POST /v1/models/{model}:batchEmbedContents ``` +### Extra-map routes + +Serve the same backing model with a different adapter per endpoint. The `{mapping}` segment matches a key in `extra_maps`. 
+ +``` +POST /map/{mapping}/v1/embeddings +POST /map/{mapping}/v1/models/{model}:embedContent +POST /map/{mapping}/v1/models/{model}:batchEmbedContents +``` + +All extra-map routes require the same authentication as the standard API routes. A `{mapping}` that is not a key in `extra_maps` returns `404`. + ### OpenAPI spec and docs ``` @@ -263,7 +317,7 @@ GET /docs ## Prometheus metrics -Enable in config: `metrics.enabled: true`. Scrape at `GET /metrics`. +Enable in config: `metrics.enabled: true`. Scrape at `GET /metrics`. Human-readable dashboard at `GET /dashboard`. | Metric | Type | Description | |--------|------|-------------| @@ -276,6 +330,12 @@ Enable in config: `metrics.enabled: true`. Scrape at `GET /metrics`. | `vecna_endpoint_errors_total` | counter | Forwarding failures by error type | | `vecna_tokens_total` | counter | Tokens consumed, by target, model, and type (`prompt`/`total`) | +### Dashboard + +`GET /dashboard` renders a live HTML view of all metrics. The route is only registered when `metrics.enabled` is true, and it is always served at `/dashboard` regardless of `metrics.path`. Counters show request counts with status-code badges, histograms show p50/p95/p99 latencies, gauges show current endpoint priority and inflight counts. + +Auth follows the same rules as `/metrics`: server `api_keys` apply, and `metrics.api_key` adds a second layer if set (sent as `Authorization: Bearer <key>`). + --- ## Development @@ -315,7 +375,7 @@ Starts vecna and an Ollama instance. The `vecna_config` named volume persists th ### Onboard (interactive setup) ```sh -docker compose run --rm -it vecna onboard --config /config/vecna.json +docker compose run --rm -it vecna onboard ``` Ollama is reachable by hostname on the Docker network — the scanner will find it automatically. 
After onboarding, restart the proxy: @@ -327,17 +387,17 @@ docker compose restart vecna ### Query ```sh -docker compose run --rm vecna query --compact "hello world" --config /config/vecna.json +docker compose run --rm vecna query --compact "hello world" ``` ### Test endpoints ```sh # report latency and dims -docker compose run --rm vecna test --config /config/vecna.json +docker compose run --rm vecna test # test and remove failing endpoints -docker compose run --rm vecna test --config /config/vecna.json --remove-broken +docker compose run --rm vecna test --remove-broken ``` ### Edit config manually diff --git a/cmd/vecna/convert.go b/cmd/vecna/convert.go index f3ea4f2..ff7f406 100644 --- a/cmd/vecna/convert.go +++ b/cmd/vecna/convert.go @@ -10,6 +10,7 @@ import ( "github.com/Warky-Devs/vecna.git/pkg/adapter" "github.com/Warky-Devs/vecna.git/pkg/config" + "github.com/Warky-Devs/vecna.git/pkg/server" ) var ( @@ -94,7 +95,37 @@ func openWriter(path string) (io.Writer, error) { // buildAdapter constructs the Adapter from the loaded config. func buildAdapter(cfg *config.Config) (adapter.Adapter, error) { - ac := cfg.Adapter + return buildAdapterFromConfig(cfg.Adapter) +} + +// buildExtraMapAdapters builds a server.ExtraMap for each entry in cfg.ExtraMaps. +// Unset (zero-valued) adapter fields fall back to the global Adapter values, except TargetDim, which is copied from the entry as-is. 
+func buildExtraMapAdapters(cfg *config.Config) (map[string]server.ExtraMap, error) { + result := make(map[string]server.ExtraMap, len(cfg.ExtraMaps)) + for name, mc := range cfg.ExtraMaps { + ac := config.AdapterConfig{ + Type: coalesce(mc.Type, cfg.Adapter.Type), + SourceDim: coalescei(mc.SourceDim, cfg.Adapter.SourceDim), + TargetDim: mc.TargetDim, + TruncateMode: coalesce(mc.TruncateMode, cfg.Adapter.TruncateMode), + PadMode: coalesce(mc.PadMode, cfg.Adapter.PadMode), + Seed: coalescei64(mc.Seed, cfg.Adapter.Seed), + MatrixFile: coalesce(mc.MatrixFile, cfg.Adapter.MatrixFile), + } + adp, err := buildAdapterFromConfig(ac) + if err != nil { + return nil, fmt.Errorf("extra_map %q: %w", name, err) + } + result[name] = server.ExtraMap{ + Adapter: adp, + ForwardTarget: mc.ForwardTarget, + } + } + return result, nil +} + +// buildAdapterFromConfig constructs an Adapter from an AdapterConfig. +func buildAdapterFromConfig(ac config.AdapterConfig) (adapter.Adapter, error) { if ac.SourceDim == 0 && ac.TargetDim == 0 { return adapter.NewPassthrough(), nil } @@ -124,6 +155,27 @@ func buildAdapter(cfg *config.Config) (adapter.Adapter, error) { } } +func coalesce(a, b string) string { + if a != "" { + return a + } + return b +} + +func coalescei(a, b int) int { + if a != 0 { + return a + } + return b +} + +func coalescei64(a, b int64) int64 { + if a != 0 { + return a + } + return b +} + func parseTruncateModes(truncMode, padMode string) (adapter.TruncateMode, adapter.PadMode, error) { var tm adapter.TruncateMode switch truncMode { diff --git a/cmd/vecna/serve.go b/cmd/vecna/serve.go index 3181b38..07d4586 100644 --- a/cmd/vecna/serve.go +++ b/cmd/vecna/serve.go @@ -48,6 +48,11 @@ func runServe(cmd *cobra.Command, _ []string) error { return fmt.Errorf("build adapter: %w", err) } + extraMaps, err := buildExtraMapAdapters(cfg) + if err != nil { + return fmt.Errorf("build extra_maps: %w", err) + } + clients, err := buildClients(cfg) if err != nil { return fmt.Errorf("build 
clients: %w", err) @@ -58,7 +63,7 @@ func runServe(cmd *cobra.Command, _ []string) error { reg = metrics.New() } - router, err := server.New(cfg, clients, adp, reg, logger) + router, err := server.New(cfg, clients, adp, extraMaps, reg, logger) if err != nil { return fmt.Errorf("build router: %w", err) } diff --git a/go.mod b/go.mod index 0807514..d2848d2 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.26.1 require ( github.com/go-viper/mapstructure/v2 v2.4.0 github.com/prometheus/client_golang v1.23.2 + github.com/prometheus/client_model v0.6.2 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 @@ -21,7 +22,6 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect diff --git a/pkg/config/config.go b/pkg/config/config.go index 0351898..e47eaa8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -12,10 +12,11 @@ import ( // Config is the root configuration for vecna. 
type Config struct { - Server ServerConfig `mapstructure:"server" json:"server" yaml:"server" xml:"server"` - Metrics MetricsConfig `mapstructure:"metrics" json:"metrics" yaml:"metrics" xml:"metrics"` - Forward ForwardConfig `mapstructure:"forward" json:"forward" yaml:"forward" xml:"forward"` - Adapter AdapterConfig `mapstructure:"adapter" json:"adapter" yaml:"adapter" xml:"adapter"` + Server ServerConfig `mapstructure:"server" json:"server" yaml:"server" xml:"server"` + Metrics MetricsConfig `mapstructure:"metrics" json:"metrics" yaml:"metrics" xml:"metrics"` + Forward ForwardConfig `mapstructure:"forward" json:"forward" yaml:"forward" xml:"forward"` + Adapter AdapterConfig `mapstructure:"adapter" json:"adapter" yaml:"adapter" xml:"adapter"` + ExtraMaps map[string]ExtraMapConfig `mapstructure:"extra_maps" json:"extra_maps" yaml:"extra_maps" xml:"extra_maps"` } // ServerConfig controls the HTTP listener and inbound auth. @@ -57,6 +58,20 @@ type EndpointConfig struct { APIKey string `mapstructure:"api_key" json:"api_key" yaml:"api_key" xml:"api_key"` } +// ExtraMapConfig is a per-mapping override: adapter fields other than TargetDim +// fall back to the global adapter when unset; ForwardTarget selects a named target +// from forward.targets (empty = the usual model-based resolution). 
+type ExtraMapConfig struct { + ForwardTarget string `mapstructure:"forward_target" json:"forward_target,omitempty" yaml:"forward_target,omitempty" xml:"forward_target,omitempty"` + Type string `mapstructure:"type" json:"type,omitempty" yaml:"type,omitempty" xml:"type,omitempty"` + SourceDim int `mapstructure:"source_dim" json:"source_dim,omitempty" yaml:"source_dim,omitempty" xml:"source_dim,omitempty"` + TargetDim int `mapstructure:"target_dim" json:"target_dim,omitempty" yaml:"target_dim,omitempty" xml:"target_dim,omitempty"` + TruncateMode string `mapstructure:"truncate_mode" json:"truncate_mode,omitempty" yaml:"truncate_mode,omitempty" xml:"truncate_mode,omitempty"` + PadMode string `mapstructure:"pad_mode" json:"pad_mode,omitempty" yaml:"pad_mode,omitempty" xml:"pad_mode,omitempty"` + Seed int64 `mapstructure:"seed" json:"seed,omitempty" yaml:"seed,omitempty" xml:"seed,omitempty"` + MatrixFile string `mapstructure:"matrix_file" json:"matrix_file,omitempty" yaml:"matrix_file,omitempty" xml:"matrix_file,omitempty"` +} + // AdapterConfig selects and tunes the dimension adapter. type AdapterConfig struct { Type string `mapstructure:"type" json:"type" yaml:"type" xml:"type"` diff --git a/pkg/server/dashboard.go b/pkg/server/dashboard.go new file mode 100644 index 0000000..93f9e18 --- /dev/null +++ b/pkg/server/dashboard.go @@ -0,0 +1,214 @@ +package server + +import ( + "fmt" + "net/http" + "sort" + "strings" + + dto "github.com/prometheus/client_model/go" + "github.com/uptrace/bunrouter" + + "github.com/Warky-Devs/vecna.git/pkg/metrics" +) + +// dashboardHandler returns an HTML metrics dashboard by gathering from the registry. 
+func dashboardHandler(reg *metrics.Registry) bunrouter.HandlerFunc { + return func(w http.ResponseWriter, req bunrouter.Request) error { + families, err := reg.Prometheus().Gather() + if err != nil { + http.Error(w, "failed to gather metrics", http.StatusInternalServerError) + return nil + } + + // Sort families by name for deterministic output. + sort.Slice(families, func(i, j int) bool { + return families[i].GetName() < families[j].GetName() + }) + + var b strings.Builder + b.WriteString(` + + + + +Vecna Metrics Dashboard + + + +

Vecna Metrics Dashboard

+`) + + for _, fam := range families { + name := fam.GetName() + help := fam.GetHelp() + mtype := fam.GetType() + + b.WriteString(`
`) + fmt.Fprintf(&b, "

%s

", htmlEsc(name)) + if help != "" { + fmt.Fprintf(&b, `
%s
`, htmlEsc(help)) + } + + switch mtype { + case dto.MetricType_COUNTER: + renderCounter(&b, fam.GetMetric()) + case dto.MetricType_GAUGE: + renderGauge(&b, fam.GetMetric()) + case dto.MetricType_HISTOGRAM: + renderHistogram(&b, fam.GetMetric()) + default: + renderGeneric(&b, fam.GetMetric()) + } + + b.WriteString(`
`) + } + + b.WriteString("") + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, err = fmt.Fprint(w, b.String()) + return err + } +} + +func renderCounter(b *strings.Builder, ms []*dto.Metric) { + b.WriteString("") + if len(ms) > 0 && len(ms[0].GetLabel()) > 0 { + for _, lp := range ms[0].GetLabel() { + fmt.Fprintf(b, "", htmlEsc(lp.GetName())) + } + } + b.WriteString("") + for _, m := range ms { + b.WriteString("") + for _, lp := range m.GetLabel() { + val := lp.GetValue() + cls := "" + if lp.GetName() == "status" { + if strings.HasPrefix(val, "2") { + cls = ` class="badge b-ok"` + } else if strings.HasPrefix(val, "4") || strings.HasPrefix(val, "5") { + cls = ` class="badge b-err"` + } + } + if cls != "" { + fmt.Fprintf(b, "", cls, htmlEsc(val)) + } else { + fmt.Fprintf(b, "", htmlEsc(val)) + } + } + fmt.Fprintf(b, ``, m.GetCounter().GetValue()) + b.WriteString("") + } + b.WriteString("
%scount
%s%s%.0f
") +} + +func renderGauge(b *strings.Builder, ms []*dto.Metric) { + b.WriteString("") + if len(ms) > 0 && len(ms[0].GetLabel()) > 0 { + for _, lp := range ms[0].GetLabel() { + fmt.Fprintf(b, "", htmlEsc(lp.GetName())) + } + } + b.WriteString("") + for _, m := range ms { + b.WriteString("") + for _, lp := range m.GetLabel() { + fmt.Fprintf(b, "", htmlEsc(lp.GetValue())) + } + fmt.Fprintf(b, ``, m.GetGauge().GetValue()) + b.WriteString("") + } + b.WriteString("
%svalue
%s%.4g
") +} + +func renderHistogram(b *strings.Builder, ms []*dto.Metric) { + b.WriteString("") + if len(ms) > 0 && len(ms[0].GetLabel()) > 0 { + for _, lp := range ms[0].GetLabel() { + fmt.Fprintf(b, "", htmlEsc(lp.GetName())) + } + } + b.WriteString("") + for _, m := range ms { + b.WriteString("") + for _, lp := range m.GetLabel() { + fmt.Fprintf(b, "", htmlEsc(lp.GetValue())) + } + h := m.GetHistogram() + count := h.GetSampleCount() + p50 := histogramQuantile(h, 0.50) + p95 := histogramQuantile(h, 0.95) + p99 := histogramQuantile(h, 0.99) + fmt.Fprintf(b, ``, count) + fmt.Fprintf(b, ``, fmtSeconds(p50)) + fmt.Fprintf(b, ``, fmtSeconds(p95)) + fmt.Fprintf(b, ``, fmtSeconds(p99)) + b.WriteString("") + } + b.WriteString("
%scountp50 (s)p95 (s)p99 (s)
%s%d%s%s%s
") +} + +func renderGeneric(b *strings.Builder, ms []*dto.Metric) { + fmt.Fprintf(b, `
%d series
`, len(ms)) +} + +// histogramQuantile computes a linear-interpolated quantile from a cumulative histogram. +func histogramQuantile(h *dto.Histogram, q float64) float64 { + buckets := h.GetBucket() + total := float64(h.GetSampleCount()) + if total == 0 || len(buckets) == 0 { + return 0 + } + target := q * total + var prevCount float64 + var prevBound float64 + for _, b := range buckets { + count := float64(b.GetCumulativeCount()) + bound := b.GetUpperBound() + if count >= target { + if count == prevCount { + return prevBound + } + // linear interpolation within bucket + return prevBound + (bound-prevBound)*(target-prevCount)/(count-prevCount) + } + prevCount = count + prevBound = bound + } + return prevBound +} + +func fmtSeconds(s float64) string { + if s == 0 { + return "—" + } + if s < 0.001 { + return fmt.Sprintf("%.3fms", s*1000) + } + return fmt.Sprintf("%.3fs", s) +} + +func htmlEsc(s string) string { + s = strings.ReplaceAll(s, "&", "&") + s = strings.ReplaceAll(s, "<", "<") + s = strings.ReplaceAll(s, ">", ">") + return s +} diff --git a/pkg/server/google.go b/pkg/server/google.go index 0c39ca1..5fe0454 100644 --- a/pkg/server/google.go +++ b/pkg/server/google.go @@ -9,6 +9,7 @@ import ( "github.com/uptrace/bunrouter" + "github.com/Warky-Devs/vecna.git/pkg/adapter" "github.com/Warky-Devs/vecna.git/pkg/embedclient" ) @@ -16,6 +17,18 @@ import ( // correct Google handler. The colon is a literal method separator in the // Google embedding API, not a bunrouter parameter prefix. 
func (h *handler) googleDispatch(w http.ResponseWriter, req bunrouter.Request) error { + return h.googleDispatchWithAdapter(w, req, h.adapter, "") +} + +func (h *handler) googleDispatchMapped(w http.ResponseWriter, req bunrouter.Request) error { + em, err := h.resolveExtraMap(req.Param("mapping")) + if err != nil { + return writeJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()}) + } + return h.googleDispatchWithAdapter(w, req, em.Adapter, em.ForwardTarget) +} + +func (h *handler) googleDispatchWithAdapter(w http.ResponseWriter, req bunrouter.Request, adp adapter.Adapter, targetOverride string) error { modelaction := req.Param("modelaction") // e.g. "text-embedding-foo:embedContent" idx := strings.LastIndex(modelaction, ":") if idx < 0 { @@ -29,9 +42,9 @@ func (h *handler) googleDispatch(w http.ResponseWriter, req bunrouter.Request) e switch action { case "embedContent": - return h.googleEmbedContent(w, req) + return h.googleEmbedContentWithAdapter(w, req, adp, targetOverride) case "batchEmbedContents": - return h.googleBatchEmbedContents(w, req) + return h.googleBatchEmbedContentsWithAdapter(w, req, adp, targetOverride) default: return writeJSON(w, http.StatusNotFound, map[string]string{"error": "unknown Google API method: " + action}) } @@ -60,7 +73,7 @@ type googleEmbeddingValues struct { Values []float32 `json:"values"` } -func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Request) error { +func (h *handler) googleEmbedContentWithAdapter(w http.ResponseWriter, req bunrouter.Request, adp adapter.Adapter, targetOverride string) error { model, _ := req.Context().Value(modelKey).(string) var body googleEmbedContentRequest @@ -73,7 +86,7 @@ func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Reques texts[i] = p.Text } - client, targetName, targetURL := h.resolveClient(model) + client, targetName, targetURL := h.resolveClientOverride(targetOverride, model) trace := TraceFromContext(req.Context()) 
trace.ForwardTarget = targetName trace.ForwardURL = targetURL @@ -91,7 +104,7 @@ func (h *handler) googleEmbedContent(w http.ResponseWriter, req bunrouter.Reques t1 := time.Now() var adapted []float32 if len(embedResp.Embeddings) > 0 { - adapted, err = h.adapter.Adapt(embedResp.Embeddings[0]) + adapted, err = adp.Adapt(embedResp.Embeddings[0]) if err != nil { return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) } @@ -115,7 +128,7 @@ type googleBatchResponse struct { Embeddings []googleEmbeddingValues `json:"embeddings"` } -func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter.Request) error { +func (h *handler) googleBatchEmbedContentsWithAdapter(w http.ResponseWriter, req bunrouter.Request, adp adapter.Adapter, targetOverride string) error { model, _ := req.Context().Value(modelKey).(string) var body googleBatchRequest @@ -130,7 +143,7 @@ func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter. } } - client, targetName, targetURL := h.resolveClient(model) + client, targetName, targetURL := h.resolveClientOverride(targetOverride, model) trace := TraceFromContext(req.Context()) trace.ForwardTarget = targetName trace.ForwardURL = targetURL @@ -148,7 +161,7 @@ func (h *handler) googleBatchEmbedContents(w http.ResponseWriter, req bunrouter. t1 := time.Now() result := make([]googleEmbeddingValues, len(embedResp.Embeddings)) for i, vec := range embedResp.Embeddings { - adapted, adaptErr := h.adapter.Adapt(vec) + adapted, adaptErr := adp.Adapt(vec) if adaptErr != nil { return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()}) } diff --git a/pkg/server/handler.go b/pkg/server/handler.go index bcad66f..27b7570 100644 --- a/pkg/server/handler.go +++ b/pkg/server/handler.go @@ -14,12 +14,28 @@ import ( "github.com/Warky-Devs/vecna.git/pkg/embedclient" ) +// ExtraMap pairs a dimension adapter with an optional forward-target override. 
+type ExtraMap struct { + Adapter adapter.Adapter + ForwardTarget string // named target in forward.targets; empty = model-based resolution +} + // handler holds shared dependencies for all HTTP handlers. type handler struct { - cfg *config.Config - clients map[string]embedclient.Client - adapter adapter.Adapter - logger *zap.Logger + cfg *config.Config + clients map[string]embedclient.Client + adapter adapter.Adapter + extraMaps map[string]ExtraMap + logger *zap.Logger +} + +// resolveExtraMap returns the ExtraMap for the named extra_map entry. +func (h *handler) resolveExtraMap(name string) (ExtraMap, error) { + em, ok := h.extraMaps[name] + if !ok { + return ExtraMap{}, fmt.Errorf("extra_map %q not configured", name) + } + return em, nil } // resolveClient selects the embed client for the given model name. @@ -32,12 +48,24 @@ func (h *handler) resolveClient(model string) (embedclient.Client, string, strin name := h.cfg.Forward.Default c, ok := h.clients[name] if !ok { - // No configured client — return a nil-safe error client return &errClient{err: fmt.Errorf("no client configured for model %q and no default", model)}, name, "" } return c, name, firstEndpointURL(h.cfg, name) } +// resolveClientOverride selects the client for targetOverride when set, +// otherwise falls back to model-based resolution. 
+func (h *handler) resolveClientOverride(targetOverride, model string) (embedclient.Client, string, string) { + if targetOverride == "" { + return h.resolveClient(model) + } + c, ok := h.clients[targetOverride] + if !ok { + return &errClient{err: fmt.Errorf("extra_map forward_target %q not configured", targetOverride)}, targetOverride, "" + } + return c, targetOverride, firstEndpointURL(h.cfg, targetOverride) +} + func firstEndpointURL(cfg *config.Config, targetName string) string { t, ok := cfg.Forward.Targets[targetName] if !ok || len(t.Endpoints) == 0 { diff --git a/pkg/server/openai.go b/pkg/server/openai.go index f44da5c..52046b7 100644 --- a/pkg/server/openai.go +++ b/pkg/server/openai.go @@ -8,6 +8,7 @@ import ( "github.com/uptrace/bunrouter" + "github.com/Warky-Devs/vecna.git/pkg/adapter" "github.com/Warky-Devs/vecna.git/pkg/embedclient" ) @@ -35,6 +36,18 @@ type openAIUsage struct { } func (h *handler) openAIEmbeddings(w http.ResponseWriter, req bunrouter.Request) error { + return h.openAIEmbeddingsWithAdapter(w, req, h.adapter, "") +} + +func (h *handler) openAIEmbeddingsMapped(w http.ResponseWriter, req bunrouter.Request) error { + em, err := h.resolveExtraMap(req.Param("mapping")) + if err != nil { + return writeJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()}) + } + return h.openAIEmbeddingsWithAdapter(w, req, em.Adapter, em.ForwardTarget) +} + +func (h *handler) openAIEmbeddingsWithAdapter(w http.ResponseWriter, req bunrouter.Request, adp adapter.Adapter, targetOverride string) error { var body openAIEmbedRequest if err := json.NewDecoder(req.Body).Decode(&body); err != nil { return writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) @@ -45,7 +58,7 @@ func (h *handler) openAIEmbeddings(w http.ResponseWriter, req bunrouter.Request) return writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) } - client, targetName, targetURL := h.resolveClient(body.Model) + client, 
targetName, targetURL := h.resolveClientOverride(targetOverride, body.Model) trace := TraceFromContext(req.Context()) trace.ForwardTarget = targetName trace.ForwardURL = targetURL @@ -63,7 +76,7 @@ func (h *handler) openAIEmbeddings(w http.ResponseWriter, req bunrouter.Request) t1 := time.Now() data := make([]openAIEmbedDatum, len(embedResp.Embeddings)) for i, vec := range embedResp.Embeddings { - adapted, adaptErr := h.adapter.Adapt(vec) + adapted, adaptErr := adp.Adapt(vec) if adaptErr != nil { return writeJSON(w, http.StatusInternalServerError, map[string]string{"error": adaptErr.Error()}) } diff --git a/pkg/server/server.go b/pkg/server/server.go index e560bab..079f3a4 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -23,6 +23,7 @@ func New( cfg *config.Config, clients map[string]embedclient.Client, adp adapter.Adapter, + extraMaps map[string]ExtraMap, reg *metrics.Registry, logger *zap.Logger, ) (router *bunrouter.Router, err error) { @@ -34,38 +35,49 @@ func New( }() router = bunrouter.New( - bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys)), bunrouter.WithMiddleware(traceMiddleware()), bunrouter.WithMiddleware(metricsMiddleware(reg, adp)), bunrouter.WithMiddleware(loggingMiddleware(logger)), ) - h := &handler{cfg: cfg, clients: clients, adapter: adp, logger: logger} + h := &handler{cfg: cfg, clients: clients, adapter: adp, extraMaps: extraMaps, logger: logger} - router.POST("/v1/embeddings", h.openAIEmbeddings) + // Public routes — no authentication required. + router.GET("/", spec.DocsHandler()) + router.GET("/docs", spec.DocsHandler()) + router.GET("/openapi.yaml", spec.SpecHandler()) + + // All API routes require authentication. + authed := router.NewGroup("", bunrouter.WithMiddleware(authMiddleware(cfg.Server.APIKeys))) + + authed.POST("/v1/embeddings", h.openAIEmbeddings) // Google API uses a literal colon as a method separator (e.g. /v1/models/foo:embedContent). 
// bunrouter can't distinguish two routes with the same :param prefix, so a single wildcard // captures the full "model:action" segment and dispatches internally. - router.POST("/v1/models/*modelaction", h.googleDispatch) + authed.POST("/v1/models/*modelaction", h.googleDispatch) - // OpenAPI spec + docs - router.GET("/openapi.yaml", spec.SpecHandler()) - router.GET("/docs", spec.DocsHandler()) + // Extra-map routes: /map/:mapping/v1/... uses the adapter configured under extra_maps[mapping]. + authed.POST("/map/:mapping/v1/embeddings", h.openAIEmbeddingsMapped) + authed.POST("/map/:mapping/v1/models/*modelaction", h.googleDispatchMapped) - // Metrics — only when enabled + // Metrics — only when enabled; registered in the authed group so server + // auth (if configured) applies. Metrics may additionally enforce its own key. if cfg.Metrics.Enabled { metricsHandler := promhttp.HandlerFor(reg.Prometheus(), promhttp.HandlerOpts{}) path := cfg.Metrics.Path if path == "" { path = "/metrics" } + dash := dashboardHandler(reg) if cfg.Metrics.APIKey != "" { - router.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler)) + authed.GET(path, metricsAuthHandler(cfg.Metrics.APIKey, metricsHandler)) + authed.GET("/dashboard", metricsKeyMiddleware(cfg.Metrics.APIKey, dash)) } else { - router.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error { + authed.GET(path, func(w http.ResponseWriter, req bunrouter.Request) error { metricsHandler.ServeHTTP(w, req.Request) return nil }) + authed.GET("/dashboard", dash) } } @@ -148,6 +160,18 @@ func loggingMiddleware(logger *zap.Logger) bunrouter.MiddlewareFunc { } } +// metricsKeyMiddleware guards a bunrouter.HandlerFunc with a dedicated Bearer token. 
+func metricsKeyMiddleware(apiKey string, h bunrouter.HandlerFunc) bunrouter.HandlerFunc { + return func(w http.ResponseWriter, req bunrouter.Request) error { + token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") + if token != apiKey { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return nil + } + return h(w, req) + } +} + // metricsAuthHandler wraps a standard http.Handler with Bearer token auth. func metricsAuthHandler(apiKey string, h http.Handler) bunrouter.HandlerFunc { return func(w http.ResponseWriter, req bunrouter.Request) error {