mirror of
https://github.com/Warky-Devs/vecna.git
synced 2026-05-05 01:26:58 +00:00
614 lines
16 KiB
Go
614 lines
16 KiB
Go
package main
|
|
|
|
import (
	"bufio"
	"context"
	"fmt"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/spf13/cobra"

	"github.com/Warky-Devs/vecna.git/pkg/config"
	"github.com/Warky-Devs/vecna.git/pkg/discovery"
	"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
|
|
|
|
var onboardCmd = &cobra.Command{
|
|
Use: "onboard",
|
|
Short: "Interactive setup wizard: discover servers, configure, test, and write config",
|
|
RunE: runOnboard,
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(onboardCmd)
|
|
}
|
|
|
|
func runOnboard(_ *cobra.Command, _ []string) error {
|
|
in := bufio.NewReader(os.Stdin)
|
|
|
|
fmt.Println("=== vecna onboard ===")
|
|
fmt.Println()
|
|
|
|
// ── Step 1: Discover ──────────────────────────────────────────────────────
|
|
|
|
step(1, 6, "Discover embedding servers")
|
|
|
|
fmt.Println("Scanning (Ollama, LM Studio, vLLM, LocalAI, Jan, Kobold, Tabby)...")
|
|
servers := discovery.Scan(context.Background())
|
|
|
|
var targets []pendingTarget
|
|
|
|
if len(servers) == 0 {
|
|
fmt.Println("No servers found automatically.")
|
|
} else {
|
|
fmt.Printf("Found %d server(s):\n\n", len(servers))
|
|
for i, s := range servers {
|
|
fmt.Printf(" [%d] %-12s %s\n Models: %s\n\n",
|
|
i+1, s.Kind.Name, s.BaseURL, joinModels(s.Models))
|
|
}
|
|
|
|
// Let user pick one or more from the list; 0 = manual
|
|
for {
|
|
choice, err := promptInt(in,
|
|
fmt.Sprintf("Select server [1-%d] or 0 to enter URL manually: ", len(servers)), 0, len(servers))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var pt pendingTarget
|
|
if choice == 0 {
|
|
pt, err = collectManualTarget(in)
|
|
} else {
|
|
pt, err = collectDiscoveredTarget(in, servers[choice-1])
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
targets = append(targets, pt)
|
|
|
|
again, err := promptBool(in, "Add another forwarder?", false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !again {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(targets) == 0 {
|
|
pt, err := collectManualTarget(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
targets = append(targets, pt)
|
|
}
|
|
|
|
// ── Step 2: Detect dimensions ─────────────────────────────────────────────
|
|
|
|
step(2, 6, "Detect model dimensions")
|
|
|
|
for i := range targets {
|
|
fmt.Printf("Probing %s / %s ... ", targets[i].endpoint, targets[i].model)
|
|
dim, err := detectDim(targets[i])
|
|
if err != nil {
|
|
fmt.Printf("failed (%s) — you will need to enter the dimension manually\n", err)
|
|
targets[i].detectedDim = 0
|
|
} else {
|
|
fmt.Printf("%d dims\n", dim)
|
|
targets[i].detectedDim = dim
|
|
}
|
|
}
|
|
fmt.Println()
|
|
|
|
// ── Step 3: Configure adapter ─────────────────────────────────────────────
|
|
|
|
step(3, 6, "Configure dimension adapter")
|
|
|
|
// Use the first target's detected dim as the source dimension default
|
|
firstDim := 0
|
|
for _, t := range targets {
|
|
if t.detectedDim > 0 {
|
|
firstDim = t.detectedDim
|
|
break
|
|
}
|
|
}
|
|
|
|
srcDimStr := ""
|
|
if firstDim > 0 {
|
|
srcDimStr = fmt.Sprintf("%d", firstDim)
|
|
}
|
|
|
|
sourceDimRaw, err := promptString(in,
|
|
fmt.Sprintf("Source dimension (native model output dim)%s: ", defaultHint(srcDimStr)), srcDimStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sourceDim := mustParseInt(sourceDimRaw, firstDim)
|
|
|
|
targetDimRaw, err := promptString(in, "Target dimension (output dim vecna will serve) [1536]: ", "1536")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
targetDim := mustParseInt(targetDimRaw, 1536)
|
|
|
|
adapterType, err := promptString(in, "Adapter type (truncate/random/projection) [truncate]: ", "truncate")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
truncateMode := "from_end"
|
|
padMode := "at_end"
|
|
if adapterType == "truncate" {
|
|
truncateMode, err = promptString(in, "Truncate mode (from_end/from_start) [from_end]: ", "from_end")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
padMode, err = promptString(in, "Pad mode (at_end/at_start) [at_end]: ", "at_end")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
fmt.Println()
|
|
|
|
// ── Step 4: Configure vecna server ────────────────────────────────────────
|
|
|
|
step(4, 6, "Configure vecna server")
|
|
|
|
portRaw, err := promptString(in, "Bind port [8080]: ", "8080")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
port := mustParseInt(portRaw, 8080)
|
|
|
|
apiKeysRaw, err := promptString(in,
|
|
"Inbound API keys for vecna (comma-separated, leave empty to disable auth): ", "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var apiKeys []string
|
|
for _, k := range strings.Split(apiKeysRaw, ",") {
|
|
if k := strings.TrimSpace(k); k != "" {
|
|
apiKeys = append(apiKeys, k)
|
|
}
|
|
}
|
|
|
|
enableMetrics, err := promptBool(in, "Enable Prometheus /metrics endpoint?", false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
metricsAPIKey := ""
|
|
if enableMetrics {
|
|
metricsAPIKey, err = promptString(in, "Metrics API key (leave empty for open access): ", "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
fmt.Println()
|
|
|
|
// ── Step 5: Extra maps ────────────────────────────────────────────────────
|
|
|
|
step(5, 6, "Configure extra maps (optional)")
|
|
|
|
extraMaps := map[string]config.ExtraMapConfig{}
|
|
|
|
addMaps, err := promptBool(in, "Add extra dimension maps (/map/{key}/v1/embeddings)?", false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if addMaps {
|
|
// Build the list of target names from the already-collected targets.
|
|
targetNames := make([]string, 0, len(targets))
|
|
for _, t := range targets {
|
|
targetNames = append(targetNames, t.name)
|
|
}
|
|
|
|
for {
|
|
mc, mapErr := collectExtraMap(in, sourceDim, targetDim, adapterType, truncateMode, padMode, targetNames)
|
|
if mapErr != nil {
|
|
return mapErr
|
|
}
|
|
extraMaps[mc.key] = mc.cfg
|
|
|
|
another, promptErr := promptBool(in, "Add another extra map?", false)
|
|
if promptErr != nil {
|
|
return promptErr
|
|
}
|
|
if !another {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
fmt.Println()
|
|
|
|
// ── Step 6: Test & write ──────────────────────────────────────────────────
|
|
|
|
step(6, 6, "Test connections and write config")
|
|
|
|
allPassed := true
|
|
for _, t := range targets {
|
|
fmt.Printf("Testing %-45s ", t.endpoint+"...")
|
|
_, elapsed, dims, testErr := runSingleTest(t)
|
|
if testErr != nil {
|
|
fmt.Printf("FAIL %s\n", truncate(testErr.Error(), 55))
|
|
allPassed = false
|
|
} else {
|
|
fmt.Printf("OK %dms dims=%d\n", elapsed.Milliseconds(), dims)
|
|
}
|
|
}
|
|
fmt.Println()
|
|
|
|
if !allPassed {
|
|
proceed, err := promptBool(in, "Some tests failed. Write config anyway?", false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !proceed {
|
|
fmt.Println("Aborted. No config written.")
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Build the config struct
|
|
defaultTarget := ""
|
|
forwardTargets := make(map[string]config.ForwardTarget, len(targets))
|
|
for i, t := range targets {
|
|
forwardTargets[t.name] = config.ForwardTarget{
|
|
APIType: t.apiType,
|
|
Model: t.model,
|
|
APIKey: t.apiKey,
|
|
Endpoints: []config.EndpointConfig{
|
|
{URL: t.endpoint, Priority: 10},
|
|
},
|
|
TimeoutSecs: 30,
|
|
CooldownSecs: 60,
|
|
PriorityDecay: 2,
|
|
PriorityRecovery: 5,
|
|
}
|
|
if i == 0 {
|
|
defaultTarget = t.name
|
|
}
|
|
}
|
|
|
|
cfg := config.Config{
|
|
Server: config.ServerConfig{
|
|
Port: port,
|
|
Host: "0.0.0.0",
|
|
APIKeys: apiKeys,
|
|
},
|
|
Metrics: config.MetricsConfig{
|
|
Enabled: enableMetrics,
|
|
Path: "/metrics",
|
|
APIKey: metricsAPIKey,
|
|
},
|
|
Forward: config.ForwardConfig{
|
|
Default: defaultTarget,
|
|
Targets: forwardTargets,
|
|
},
|
|
Adapter: config.AdapterConfig{
|
|
Type: adapterType,
|
|
SourceDim: sourceDim,
|
|
TargetDim: targetDim,
|
|
TruncateMode: truncateMode,
|
|
PadMode: padMode,
|
|
},
|
|
ExtraMaps: extraMaps,
|
|
}
|
|
|
|
defaultCfgPath := config.ResolveFile(cfgFile)
|
|
fmt.Printf("Config will be written to: %s\n", defaultCfgPath)
|
|
cfgPath, err := promptString(in, "Config path (press Enter to accept): ", defaultCfgPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := writeFullConfig(cfgPath, cfg); err != nil {
|
|
return fmt.Errorf("write config: %w", err)
|
|
}
|
|
|
|
fmt.Printf("Config written to %s\n", cfgPath)
|
|
fmt.Println()
|
|
fmt.Println("Run 'vecna serve' to start the proxy server.")
|
|
return nil
|
|
}
|
|
|
|
// ── helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
// pendingTarget accumulates the wizard's answers about one forwarding target
// until runOnboard converts it into a config.ForwardTarget stored under name.
type pendingTarget struct {
	// Where and what to call.
	name     string // key in ForwardConfig.Targets
	endpoint string // base URL the embed client talks to
	model    string // model sent with every embed request

	// How to call it.
	apiType string // "google" selects the Google client; any other value uses OpenAI
	apiKey  string // may be empty

	// detectedDim is the probed vector length; 0 when not probed or probing failed.
	detectedDim int
}
|
|
|
|
// step prints a wizard section heading such as "[2/6] Detect model dimensions"
// followed by a 40-character dashed rule on its own line.
func step(n, total int, title string) {
	rule := strings.Repeat("-", 40)
	fmt.Printf("[%d/%d] %s\n%s\n", n, total, title, rule)
}
|
|
|
|
// defaultHint renders s as a " [s]" prompt suffix, or "" when there is no
// default value to show.
func defaultHint(s string) string {
	if s != "" {
		return " [" + s + "]"
	}
	return ""
}
|
|
|
|
// joinModels formats a model list for display: "(none)" for an empty list,
// a comma-separated list for up to five names, and for longer lists the first
// five names followed by a " (+N more)" count of the rest.
func joinModels(models []string) string {
	const shown = 5
	switch n := len(models); {
	case n == 0:
		return "(none)"
	case n <= shown:
		return strings.Join(models, ", ")
	default:
		return fmt.Sprintf("%s (+%d more)", strings.Join(models[:shown], ", "), n-shown)
	}
}
|
|
|
|
func collectDiscoveredTarget(in *bufio.Reader, srv discovery.Found) (pendingTarget, error) {
|
|
defaultName := strings.ToLower(strings.ReplaceAll(srv.Kind.Name, " ", "_"))
|
|
|
|
model, err := pickModel(in, srv.Models)
|
|
if err != nil {
|
|
return pendingTarget{}, err
|
|
}
|
|
|
|
name, err := promptString(in, fmt.Sprintf("Target name in config [%s]: ", defaultName), defaultName)
|
|
if err != nil {
|
|
return pendingTarget{}, err
|
|
}
|
|
|
|
apiKey, err := promptAPIKey(in, srv.Kind.NeedsKey)
|
|
if err != nil {
|
|
return pendingTarget{}, err
|
|
}
|
|
|
|
return pendingTarget{
|
|
name: name,
|
|
endpoint: srv.BaseURL,
|
|
model: model,
|
|
apiType: srv.Kind.APIType,
|
|
apiKey: apiKey,
|
|
}, nil
|
|
}
|
|
|
|
func collectManualTarget(in *bufio.Reader) (pendingTarget, error) {
|
|
fmt.Println("Enter server details manually:")
|
|
|
|
endpoint, err := promptString(in, "Server URL (e.g. http://localhost:11434): ", "")
|
|
if err != nil || endpoint == "" {
|
|
return pendingTarget{}, fmt.Errorf("server URL is required")
|
|
}
|
|
|
|
apiTypeStr, err := promptString(in, "API type (openai/google) [openai]: ", "openai")
|
|
if err != nil {
|
|
return pendingTarget{}, err
|
|
}
|
|
|
|
model, err := promptString(in, "Model name: ", "")
|
|
if err != nil || model == "" {
|
|
return pendingTarget{}, fmt.Errorf("model name is required")
|
|
}
|
|
|
|
name, err := promptString(in, "Target name in config [custom]: ", "custom")
|
|
if err != nil {
|
|
return pendingTarget{}, err
|
|
}
|
|
|
|
apiKey, err := promptAPIKey(in, false)
|
|
if err != nil {
|
|
return pendingTarget{}, err
|
|
}
|
|
|
|
return pendingTarget{
|
|
name: name,
|
|
endpoint: endpoint,
|
|
model: model,
|
|
apiType: apiTypeStr,
|
|
apiKey: apiKey,
|
|
}, nil
|
|
}
|
|
|
|
func pickModel(in *bufio.Reader, models []string) (string, error) {
|
|
switch {
|
|
case len(models) == 0:
|
|
return promptString(in, "Model name: ", "")
|
|
case len(models) == 1:
|
|
fmt.Printf("Using model: %s\n", models[0])
|
|
return models[0], nil
|
|
default:
|
|
fmt.Println("Available models:")
|
|
for i, m := range models {
|
|
fmt.Printf(" [%d] %s\n", i+1, m)
|
|
}
|
|
idx, err := promptInt(in, fmt.Sprintf("Select model [1-%d]: ", len(models)), 1, len(models))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return models[idx-1], nil
|
|
}
|
|
}
|
|
|
|
func promptAPIKey(in *bufio.Reader, required bool) (string, error) {
|
|
prompt := "API key (leave empty if none): "
|
|
if required {
|
|
prompt = "API key: "
|
|
}
|
|
return promptString(in, prompt, "")
|
|
}
|
|
|
|
// detectDim sends a single test embedding and returns the vector length.
|
|
func detectDim(t pendingTarget) (int, error) {
|
|
httpClient := &http.Client{Timeout: 10 * time.Second}
|
|
var client embedclient.Client
|
|
if t.apiType == "google" {
|
|
client = embedclient.NewGoogle(t.endpoint, t.apiKey, t.model, httpClient)
|
|
} else {
|
|
client = embedclient.NewOpenAI(t.endpoint, t.apiKey, httpClient)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
resp, err := client.Embed(ctx, embedclient.Request{
|
|
Texts: []string{"dimension probe"},
|
|
Model: t.model,
|
|
})
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
|
|
return 0, fmt.Errorf("empty embedding in response")
|
|
}
|
|
return len(resp.Embeddings[0]), nil
|
|
}
|
|
|
|
// runSingleTest runs one test embed and returns success, elapsed time, dims, and any error.
|
|
func runSingleTest(t pendingTarget) (bool, time.Duration, int, error) {
|
|
httpClient := &http.Client{Timeout: 30 * time.Second}
|
|
var client embedclient.Client
|
|
if t.apiType == "google" {
|
|
client = embedclient.NewGoogle(t.endpoint, t.apiKey, t.model, httpClient)
|
|
} else {
|
|
client = embedclient.NewOpenAI(t.endpoint, t.apiKey, httpClient)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
start := time.Now()
|
|
resp, err := client.Embed(ctx, embedclient.Request{
|
|
Texts: []string{testPhrase},
|
|
Model: t.model,
|
|
})
|
|
elapsed := time.Since(start)
|
|
if err != nil {
|
|
return false, elapsed, 0, err
|
|
}
|
|
|
|
dims, _ := embeddingStats(resp.Embeddings)
|
|
return true, elapsed, dims, nil
|
|
}
|
|
|
|
// mustParseInt parses s as a strictly positive base-10 integer, ignoring
// surrounding whitespace. It returns fallback when s is empty, is not a whole
// decimal integer (e.g. "1536x" or "12.5"), or is not greater than zero.
// Despite the name it never panics.
//
// strconv.Atoi is used instead of fmt.Sscanf("%d") because Sscanf stops at the
// first non-digit and silently accepts trailing garbage ("1536x" -> 1536).
func mustParseInt(s string, fallback int) int {
	n, err := strconv.Atoi(strings.TrimSpace(s))
	if err != nil || n <= 0 {
		return fallback
	}
	return n
}
|
|
|
|
// pendingExtraMap pairs the map key with its collected config.
|
|
type pendingExtraMap struct {
|
|
key string
|
|
cfg config.ExtraMapConfig
|
|
}
|
|
|
|
// collectExtraMap interactively collects one extra_map entry.
|
|
// Global adapter values are shown as defaults; the user may override any field.
|
|
func collectExtraMap(
|
|
in *bufio.Reader,
|
|
globalSourceDim, globalTargetDim int,
|
|
globalAdapterType, globalTruncateMode, globalPadMode string,
|
|
targetNames []string,
|
|
) (pendingExtraMap, error) {
|
|
key, err := promptString(in, "Map key (used in URL path, e.g. \"512\"): ", "")
|
|
if err != nil || key == "" {
|
|
return pendingExtraMap{}, fmt.Errorf("map key is required")
|
|
}
|
|
|
|
mc := config.ExtraMapConfig{}
|
|
|
|
// Target dimension (required — always shown)
|
|
targetDimRaw, err := promptString(in,
|
|
fmt.Sprintf("Target dimension [%d]: ", globalTargetDim), fmt.Sprintf("%d", globalTargetDim))
|
|
if err != nil {
|
|
return pendingExtraMap{}, err
|
|
}
|
|
mc.TargetDim = mustParseInt(targetDimRaw, globalTargetDim)
|
|
|
|
// Forward target override
|
|
if len(targetNames) > 0 {
|
|
fmt.Println("Available forward targets: " + strings.Join(targetNames, ", "))
|
|
ftRaw, err := promptString(in, "Forward target override (leave empty to use global default): ", "")
|
|
if err != nil {
|
|
return pendingExtraMap{}, err
|
|
}
|
|
mc.ForwardTarget = strings.TrimSpace(ftRaw)
|
|
}
|
|
|
|
// Adapter type override
|
|
adapterTypeRaw, err := promptString(in,
|
|
fmt.Sprintf("Adapter type override (truncate/random/projection) [%s]: ", globalAdapterType), globalAdapterType)
|
|
if err != nil {
|
|
return pendingExtraMap{}, err
|
|
}
|
|
if adapterTypeRaw != globalAdapterType {
|
|
mc.Type = adapterTypeRaw
|
|
}
|
|
effectiveType := coalesce(mc.Type, globalAdapterType)
|
|
|
|
// Source dim override
|
|
sourceDimRaw, err := promptString(in,
|
|
fmt.Sprintf("Source dimension override (leave empty to use global %d): ", globalSourceDim), "")
|
|
if err != nil {
|
|
return pendingExtraMap{}, err
|
|
}
|
|
if sourceDimRaw != "" {
|
|
mc.SourceDim = mustParseInt(sourceDimRaw, globalSourceDim)
|
|
}
|
|
|
|
// Truncate/pad mode overrides — only for truncate type
|
|
if effectiveType == "truncate" {
|
|
tmRaw, err := promptString(in,
|
|
fmt.Sprintf("Truncate mode override (from_end/from_start) [%s]: ", globalTruncateMode), globalTruncateMode)
|
|
if err != nil {
|
|
return pendingExtraMap{}, err
|
|
}
|
|
if tmRaw != globalTruncateMode {
|
|
mc.TruncateMode = tmRaw
|
|
}
|
|
|
|
pmRaw, err := promptString(in,
|
|
fmt.Sprintf("Pad mode override (at_end/at_start) [%s]: ", globalPadMode), globalPadMode)
|
|
if err != nil {
|
|
return pendingExtraMap{}, err
|
|
}
|
|
if pmRaw != globalPadMode {
|
|
mc.PadMode = pmRaw
|
|
}
|
|
}
|
|
|
|
// Seed — only for random type
|
|
if effectiveType == "random" {
|
|
seedRaw, err := promptString(in, "Seed (0 = time-based): ", "0")
|
|
if err != nil {
|
|
return pendingExtraMap{}, err
|
|
}
|
|
var seed int64
|
|
if _, err := fmt.Sscanf(seedRaw, "%d", &seed); err != nil {
|
|
seed = 0
|
|
}
|
|
mc.Seed = seed
|
|
}
|
|
|
|
return pendingExtraMap{key: key, cfg: mc}, nil
|
|
}
|
|
|
|
func writeFullConfig(path string, cfg config.Config) error {
|
|
// If file already exists, preserve any targets not touched by onboard
|
|
// by using SaveTarget for each new target; otherwise write the whole file.
|
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
|
if err := createDefaultConfig(path); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Overwrite with the complete onboard config
|
|
return config.WriteConfig(path, cfg)
|
|
}
|