Files
vecna/cmd/vecna/onboard.go

614 lines
16 KiB
Go

package main
import (
	"bufio"
	"context"
	"fmt"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/spf13/cobra"

	"github.com/Warky-Devs/vecna.git/pkg/config"
	"github.com/Warky-Devs/vecna.git/pkg/discovery"
	"github.com/Warky-Devs/vecna.git/pkg/embedclient"
)
// onboardCmd is the "vecna onboard" subcommand; its work is done by
// runOnboard, the interactive setup wizard.
var onboardCmd = &cobra.Command{
	RunE:  runOnboard,
	Use:   "onboard",
	Short: "Interactive setup wizard: discover servers, configure, test, and write config",
}

// init registers onboardCmd under the root command.
func init() {
	rootCmd.AddCommand(onboardCmd)
}
// runOnboard is the RunE handler for "vecna onboard". It reads answers from
// stdin and walks through six steps:
//
//  1. discover embedding servers (or enter one manually) and pick targets;
//  2. probe each target for its native embedding dimension;
//  3. configure the global dimension adapter;
//  4. configure the vecna server (port, inbound API keys, metrics);
//  5. optionally add extra per-key dimension maps;
//  6. test every target, then assemble and write the config file.
//
// It returns an error when a prompt fails, when no positive source dimension
// is available, when the adapter type is unknown, or when the port is out of
// range. Declining to write after failed connection tests returns nil
// without writing anything.
func runOnboard(_ *cobra.Command, _ []string) error {
	in := bufio.NewReader(os.Stdin)

	fmt.Println("=== vecna onboard ===")
	fmt.Println()

	// ── Step 1: Discover ──────────────────────────────────────────────────────
	step(1, 6, "Discover embedding servers")
	fmt.Println("Scanning (Ollama, LM Studio, vLLM, LocalAI, Jan, Kobold, Tabby)...")
	servers := discovery.Scan(context.Background())

	var targets []pendingTarget
	if len(servers) == 0 {
		fmt.Println("No servers found automatically.")
	} else {
		fmt.Printf("Found %d server(s):\n\n", len(servers))
		for i, s := range servers {
			fmt.Printf(" [%d] %-12s %s\n Models: %s\n\n",
				i+1, s.Kind.Name, s.BaseURL, joinModels(s.Models))
		}

		// Let the user pick one or more servers from the list; 0 = manual entry.
		for {
			choice, err := promptInt(in,
				fmt.Sprintf("Select server [1-%d] or 0 to enter URL manually: ", len(servers)), 0, len(servers))
			if err != nil {
				return err
			}
			// promptInt's range handling lives elsewhere; guard the index here.
			if choice < 0 || choice > len(servers) {
				return fmt.Errorf("server selection %d out of range [0-%d]", choice, len(servers))
			}
			var pt pendingTarget
			if choice == 0 {
				pt, err = collectManualTarget(in)
			} else {
				pt, err = collectDiscoveredTarget(in, servers[choice-1])
			}
			if err != nil {
				return err
			}
			targets = appendUniqueTarget(targets, pt)

			again, err := promptBool(in, "Add another forwarder?", false)
			if err != nil {
				return err
			}
			if !again {
				break
			}
		}
	}
	if len(targets) == 0 {
		pt, err := collectManualTarget(in)
		if err != nil {
			return err
		}
		targets = append(targets, pt)
	}

	// ── Step 2: Detect dimensions ─────────────────────────────────────────────
	step(2, 6, "Detect model dimensions")
	for i := range targets {
		fmt.Printf("Probing %s / %s ... ", targets[i].endpoint, targets[i].model)
		dim, err := detectDim(targets[i])
		if err != nil {
			fmt.Printf("failed (%s) — you will need to enter the dimension manually\n", err)
			targets[i].detectedDim = 0
		} else {
			fmt.Printf("%d dims\n", dim)
			targets[i].detectedDim = dim
		}
	}
	fmt.Println()

	// ── Step 3: Configure adapter ─────────────────────────────────────────────
	step(3, 6, "Configure dimension adapter")
	// Offer the first successfully detected dimension as the source default.
	firstDim := 0
	for _, t := range targets {
		if t.detectedDim > 0 {
			firstDim = t.detectedDim
			break
		}
	}
	srcDimStr := ""
	if firstDim > 0 {
		srcDimStr = fmt.Sprintf("%d", firstDim)
	}
	sourceDimRaw, err := promptString(in,
		fmt.Sprintf("Source dimension (native model output dim)%s: ", defaultHint(srcDimStr)), srcDimStr)
	if err != nil {
		return err
	}
	sourceDim := mustParseInt(sourceDimRaw, firstDim)
	if sourceDim <= 0 {
		// Only reachable when every probe failed and the answer was empty or
		// not a positive integer; refuse rather than write source_dim 0.
		return fmt.Errorf("source dimension is required: enter a positive integer")
	}
	targetDimRaw, err := promptString(in, "Target dimension (output dim vecna will serve) [1536]: ", "1536")
	if err != nil {
		return err
	}
	targetDim := mustParseInt(targetDimRaw, 1536)
	adapterType, err := promptString(in, "Adapter type (truncate/random/projection) [truncate]: ", "truncate")
	if err != nil {
		return err
	}
	adapterType = strings.ToLower(strings.TrimSpace(adapterType))
	switch adapterType {
	case "truncate", "random", "projection":
	default:
		return fmt.Errorf("unknown adapter type %q: want truncate, random or projection", adapterType)
	}
	// Truncate/pad modes are only asked for the truncate adapter; other
	// adapters still carry these defaults in the written config.
	truncateMode := "from_end"
	padMode := "at_end"
	if adapterType == "truncate" {
		truncateMode, err = promptString(in, "Truncate mode (from_end/from_start) [from_end]: ", "from_end")
		if err != nil {
			return err
		}
		padMode, err = promptString(in, "Pad mode (at_end/at_start) [at_end]: ", "at_end")
		if err != nil {
			return err
		}
	}
	fmt.Println()

	// ── Step 4: Configure vecna server ────────────────────────────────────────
	step(4, 6, "Configure vecna server")
	portRaw, err := promptString(in, "Bind port [8080]: ", "8080")
	if err != nil {
		return err
	}
	port := mustParseInt(portRaw, 8080)
	if port > 65535 {
		return fmt.Errorf("invalid port %d: must be between 1 and 65535", port)
	}
	apiKeysRaw, err := promptString(in,
		"Inbound API keys for vecna (comma-separated, leave empty to disable auth): ", "")
	if err != nil {
		return err
	}
	var apiKeys []string
	for _, raw := range strings.Split(apiKeysRaw, ",") {
		if key := strings.TrimSpace(raw); key != "" {
			apiKeys = append(apiKeys, key)
		}
	}
	enableMetrics, err := promptBool(in, "Enable Prometheus /metrics endpoint?", false)
	if err != nil {
		return err
	}
	metricsAPIKey := ""
	if enableMetrics {
		metricsAPIKey, err = promptString(in, "Metrics API key (leave empty for open access): ", "")
		if err != nil {
			return err
		}
	}
	fmt.Println()

	// ── Step 5: Extra maps ────────────────────────────────────────────────────
	step(5, 6, "Configure extra maps (optional)")
	extraMaps := map[string]config.ExtraMapConfig{}
	addMaps, err := promptBool(in, "Add extra dimension maps (/map/{key}/v1/embeddings)?", false)
	if err != nil {
		return err
	}
	if addMaps {
		// Offer the names of the already-collected targets as override choices.
		targetNames := make([]string, 0, len(targets))
		for _, t := range targets {
			targetNames = append(targetNames, t.name)
		}
		for {
			mc, mapErr := collectExtraMap(in, sourceDim, targetDim, adapterType, truncateMode, padMode, targetNames)
			if mapErr != nil {
				return mapErr
			}
			if _, exists := extraMaps[mc.key]; exists {
				fmt.Printf("Map %q already defined; replacing it\n", mc.key)
			}
			extraMaps[mc.key] = mc.cfg
			another, promptErr := promptBool(in, "Add another extra map?", false)
			if promptErr != nil {
				return promptErr
			}
			if !another {
				break
			}
		}
	}
	fmt.Println()

	// ── Step 6: Test & write ──────────────────────────────────────────────────
	step(6, 6, "Test connections and write config")
	allPassed := true
	for _, t := range targets {
		fmt.Printf("Testing %-45s ", t.endpoint+"...")
		_, elapsed, dims, testErr := runSingleTest(t)
		if testErr != nil {
			fmt.Printf("FAIL %s\n", truncate(testErr.Error(), 55))
			allPassed = false
			continue
		}
		fmt.Printf("OK %dms dims=%d\n", elapsed.Milliseconds(), dims)
	}
	fmt.Println()
	if !allPassed {
		proceed, err := promptBool(in, "Some tests failed. Write config anyway?", false)
		if err != nil {
			return err
		}
		if !proceed {
			fmt.Println("Aborted. No config written.")
			return nil
		}
	}

	// Assemble the config. Each target gets one endpoint plus fixed
	// timeout/cooldown/priority values; step 1 guarantees targets is
	// non-empty, and the first target collected becomes the default.
	forwardTargets := make(map[string]config.ForwardTarget, len(targets))
	for _, t := range targets {
		forwardTargets[t.name] = config.ForwardTarget{
			APIType: t.apiType,
			Model:   t.model,
			APIKey:  t.apiKey,
			Endpoints: []config.EndpointConfig{
				{URL: t.endpoint, Priority: 10},
			},
			TimeoutSecs:      30,
			CooldownSecs:     60,
			PriorityDecay:    2,
			PriorityRecovery: 5,
		}
	}
	cfg := config.Config{
		Server: config.ServerConfig{
			Port:    port,
			Host:    "0.0.0.0",
			APIKeys: apiKeys,
		},
		Metrics: config.MetricsConfig{
			Enabled: enableMetrics,
			Path:    "/metrics",
			APIKey:  metricsAPIKey,
		},
		Forward: config.ForwardConfig{
			Default: targets[0].name,
			Targets: forwardTargets,
		},
		Adapter: config.AdapterConfig{
			Type:         adapterType,
			SourceDim:    sourceDim,
			TargetDim:    targetDim,
			TruncateMode: truncateMode,
			PadMode:      padMode,
		},
		ExtraMaps: extraMaps,
	}

	defaultCfgPath := config.ResolveFile(cfgFile)
	fmt.Printf("Config will be written to: %s\n", defaultCfgPath)
	cfgPath, err := promptString(in, "Config path (press Enter to accept): ", defaultCfgPath)
	if err != nil {
		return err
	}
	if err := writeFullConfig(cfgPath, cfg); err != nil {
		return fmt.Errorf("write config: %w", err)
	}
	fmt.Printf("Config written to %s\n", cfgPath)
	fmt.Println()
	fmt.Println("Run 'vecna serve' to start the proxy server.")
	return nil
}

// appendUniqueTarget appends pt to targets, renaming it with a "_N" suffix
// (N = 2, 3, ...) when its name is already taken. Forward targets are keyed
// by name in the written config, so without this a later target with the
// same name would silently replace an earlier one.
func appendUniqueTarget(targets []pendingTarget, pt pendingTarget) []pendingTarget {
	taken := make(map[string]bool, len(targets))
	for _, t := range targets {
		taken[t.name] = true
	}
	if taken[pt.name] {
		base := pt.name
		for n := 2; taken[pt.name]; n++ {
			pt.name = fmt.Sprintf("%s_%d", base, n)
		}
		fmt.Printf("Target name %q already used; renamed to %q\n", base, pt.name)
	}
	return append(targets, pt)
}
// ── helpers ──────────────────────────────────────────────────────────────────
// pendingTarget holds what the wizard has gathered about one forwarding
// target before it is converted into a config.ForwardTarget.
type pendingTarget struct {
	// name is the key the target is stored under in the config.
	name string
	// endpoint is the server base URL; model is the model requested from it.
	endpoint, model string
	// apiType selects the embedding client ("google", otherwise
	// OpenAI-compatible); apiKey is passed to that client and may be empty.
	apiType, apiKey string
	// detectedDim is the probed embedding length, or 0 when probing failed.
	detectedDim int
}
// step prints a wizard heading of the form "[n/total] title" followed by a
// rule of 40 dashes on its own line.
func step(n, total int, title string) {
	rule := strings.Repeat("-", 40)
	fmt.Printf("[%d/%d] %s\n%s\n", n, total, title, rule)
}
// defaultHint renders s as a " [s]" prompt suffix, or "" when there is no
// default value to display.
func defaultHint(s string) string {
	if s != "" {
		return fmt.Sprintf(" [%s]", s)
	}
	return ""
}
// joinModels formats a model list for display: "(none)" when empty, a
// comma-separated list for up to five models, and otherwise the first five
// followed by a " (+N more)" count of the rest.
func joinModels(models []string) string {
	const shown = 5
	switch count := len(models); {
	case count == 0:
		return "(none)"
	case count <= shown:
		return strings.Join(models, ", ")
	default:
		head := strings.Join(models[:shown], ", ")
		return head + fmt.Sprintf(" (+%d more)", count-shown)
	}
}
// collectDiscoveredTarget builds a pendingTarget for a server found by
// discovery. The endpoint and API type come from the discovery result; the
// user chooses the model, the target's name in the config (suggested as the
// server kind's name lower-cased with spaces turned into underscores), and
// the API key.
func collectDiscoveredTarget(in *bufio.Reader, srv discovery.Found) (pendingTarget, error) {
	suggested := strings.ToLower(strings.ReplaceAll(srv.Kind.Name, " ", "_"))

	var (
		pt  pendingTarget
		err error
	)
	if pt.model, err = pickModel(in, srv.Models); err != nil {
		return pendingTarget{}, err
	}
	namePrompt := fmt.Sprintf("Target name in config [%s]: ", suggested)
	if pt.name, err = promptString(in, namePrompt, suggested); err != nil {
		return pendingTarget{}, err
	}
	if pt.apiKey, err = promptAPIKey(in, srv.Kind.NeedsKey); err != nil {
		return pendingTarget{}, err
	}
	pt.endpoint = srv.BaseURL
	pt.apiType = srv.Kind.APIType
	return pt, nil
}
// collectManualTarget prompts for a forwarding target that discovery did not
// provide: server URL, API type, model name, config name and API key.
//
// The URL and model are required. The API type must be "openai" or "google"
// (case-insensitive) and is stored lower-cased. A failed read is returned
// as-is instead of being misreported as a missing value.
func collectManualTarget(in *bufio.Reader) (pendingTarget, error) {
	fmt.Println("Enter server details manually:")

	endpoint, err := promptString(in, "Server URL (e.g. http://localhost:11434): ", "")
	if err != nil {
		return pendingTarget{}, err
	}
	if endpoint == "" {
		return pendingTarget{}, fmt.Errorf("server URL is required")
	}

	apiTypeStr, err := promptString(in, "API type (openai/google) [openai]: ", "openai")
	if err != nil {
		return pendingTarget{}, err
	}
	// Probing treats anything other than "google" as OpenAI-compatible, so a
	// typo would otherwise be written to the config unnoticed.
	apiTypeStr = strings.ToLower(strings.TrimSpace(apiTypeStr))
	if apiTypeStr != "openai" && apiTypeStr != "google" {
		return pendingTarget{}, fmt.Errorf("unsupported API type %q: want openai or google", apiTypeStr)
	}

	model, err := promptString(in, "Model name: ", "")
	if err != nil {
		return pendingTarget{}, err
	}
	if model == "" {
		return pendingTarget{}, fmt.Errorf("model name is required")
	}

	name, err := promptString(in, "Target name in config [custom]: ", "custom")
	if err != nil {
		return pendingTarget{}, err
	}
	apiKey, err := promptAPIKey(in, false)
	if err != nil {
		return pendingTarget{}, err
	}
	return pendingTarget{
		name:     name,
		endpoint: endpoint,
		model:    model,
		apiType:  apiTypeStr,
		apiKey:   apiKey,
	}, nil
}
// pickModel chooses which of a discovered server's models to use.
//
// With no advertised models the user types a name, and an empty answer is
// rejected (collectManualTarget applies the same rule). A single model is
// used without asking. Otherwise the models are listed and the user picks
// one by its 1-based number.
func pickModel(in *bufio.Reader, models []string) (string, error) {
	switch {
	case len(models) == 0:
		model, err := promptString(in, "Model name: ", "")
		if err != nil {
			return "", err
		}
		if model == "" {
			return "", fmt.Errorf("model name is required")
		}
		return model, nil
	case len(models) == 1:
		fmt.Printf("Using model: %s\n", models[0])
		return models[0], nil
	default:
		fmt.Println("Available models:")
		for i, m := range models {
			fmt.Printf(" [%d] %s\n", i+1, m)
		}
		idx, err := promptInt(in, fmt.Sprintf("Select model [1-%d]: ", len(models)), 1, len(models))
		if err != nil {
			return "", err
		}
		// promptInt's range handling is defined elsewhere; guard the index
		// here so a bad value becomes an error instead of a panic.
		if idx < 1 || idx > len(models) {
			return "", fmt.Errorf("model selection %d out of range [1-%d]", idx, len(models))
		}
		return models[idx-1], nil
	}
}
// promptAPIKey asks for the API key a forwarding target should use. The
// required flag only changes the prompt wording. This function does not
// itself reject an empty answer.
func promptAPIKey(in *bufio.Reader, required bool) (string, error) {
	if required {
		return promptString(in, "API key: ", "")
	}
	return promptString(in, "API key (leave empty if none): ", "")
}
// onboardEmbedClient builds the embedding client for t. It uses the Google
// client when t.apiType is "google" and the OpenAI-compatible client
// otherwise. The underlying HTTP client uses the given timeout.
func onboardEmbedClient(t pendingTarget, timeout time.Duration) embedclient.Client {
	httpClient := &http.Client{Timeout: timeout}
	if t.apiType == "google" {
		return embedclient.NewGoogle(t.endpoint, t.apiKey, t.model, httpClient)
	}
	return embedclient.NewOpenAI(t.endpoint, t.apiKey, httpClient)
}

// detectDim sends a single probe embedding to t and returns the length of
// the first returned vector. It fails on transport/API errors, and also when
// the response contains no vector data. The whole probe is bounded by a
// 10-second timeout.
func detectDim(t pendingTarget) (int, error) {
	const probeTimeout = 10 * time.Second
	client := onboardEmbedClient(t, probeTimeout)

	ctx, cancel := context.WithTimeout(context.Background(), probeTimeout)
	defer cancel()

	resp, err := client.Embed(ctx, embedclient.Request{
		Texts: []string{"dimension probe"},
		Model: t.model,
	})
	if err != nil {
		return 0, err
	}
	if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
		return 0, fmt.Errorf("empty embedding in response")
	}
	return len(resp.Embeddings[0]), nil
}

// runSingleTest embeds testPhrase against t. It returns success, the
// round-trip time, the dimension reported by embeddingStats, and any error.
//
// A response without vector data counts as a failure, which is the same
// rule detectDim applies. Previously such a response was reported as "OK".
func runSingleTest(t pendingTarget) (bool, time.Duration, int, error) {
	const testTimeout = 30 * time.Second
	client := onboardEmbedClient(t, testTimeout)

	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
	defer cancel()

	start := time.Now()
	resp, err := client.Embed(ctx, embedclient.Request{
		Texts: []string{testPhrase},
		Model: t.model,
	})
	elapsed := time.Since(start)
	if err != nil {
		return false, elapsed, 0, err
	}
	if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
		return false, elapsed, 0, fmt.Errorf("empty embedding in response")
	}
	// Only the dimension is reported here; the second statistic is unused.
	dims, _ := embeddingStats(resp.Embeddings)
	return true, elapsed, dims, nil
}
// mustParseInt parses s, ignoring surrounding whitespace, as a base-10
// integer and returns it when it is strictly positive. Otherwise it returns
// fallback. "Otherwise" covers empty input, non-numeric text, trailing
// garbage such as "8080x" or "12.5", values outside the int range, zero, and
// negatives.
//
// Despite its Must-style name it never panics.
func mustParseInt(s string, fallback int) int {
	// strconv.Atoi requires the whole string to be a number, unlike
	// fmt.Sscanf("%d"), which silently accepts a numeric prefix.
	n, err := strconv.Atoi(strings.TrimSpace(s))
	if err != nil || n <= 0 {
		return fallback
	}
	return n
}
// pendingExtraMap pairs an extra-map key with the config collected for it.
type pendingExtraMap struct {
	key string
	cfg config.ExtraMapConfig
}

// collectExtraMap interactively collects one extra_map entry, served under
// /map/{key}/v1/embeddings.
//
// The global adapter settings are offered as prompt defaults. An override
// field is set on the returned config only when the answer differs from the
// global value (for the forward target and source dimension, only when
// non-empty). Untouched fields keep their zero value.
//
// It returns an error when:
//   - reading input fails;
//   - the key is empty or contains '/', '?' or '#';
//   - the forward-target override is not one of targetNames;
//   - the adapter-type override is not truncate, random or projection.
func collectExtraMap(
	in *bufio.Reader,
	globalSourceDim, globalTargetDim int,
	globalAdapterType, globalTruncateMode, globalPadMode string,
	targetNames []string,
) (pendingExtraMap, error) {
	key, err := promptString(in, "Map key (used in URL path, e.g. \"512\"): ", "")
	if err != nil {
		return pendingExtraMap{}, err
	}
	key = strings.TrimSpace(key)
	if key == "" {
		return pendingExtraMap{}, fmt.Errorf("map key is required")
	}
	// The key is a single URL path segment; these characters would split or
	// terminate it.
	if strings.ContainsAny(key, "/?#") {
		return pendingExtraMap{}, fmt.Errorf("map key %q must not contain '/', '?' or '#'", key)
	}

	mc := config.ExtraMapConfig{}

	// Target dimension (always set on the map).
	targetDimRaw, err := promptString(in,
		fmt.Sprintf("Target dimension [%d]: ", globalTargetDim), fmt.Sprintf("%d", globalTargetDim))
	if err != nil {
		return pendingExtraMap{}, err
	}
	mc.TargetDim = mustParseInt(targetDimRaw, globalTargetDim)

	// Forward target override; must name one of the collected targets.
	if len(targetNames) > 0 {
		fmt.Println("Available forward targets: " + strings.Join(targetNames, ", "))
		ftRaw, err := promptString(in, "Forward target override (leave empty to use global default): ", "")
		if err != nil {
			return pendingExtraMap{}, err
		}
		ft := strings.TrimSpace(ftRaw)
		if ft != "" {
			known := false
			for _, name := range targetNames {
				if name == ft {
					known = true
					break
				}
			}
			if !known {
				return pendingExtraMap{}, fmt.Errorf("unknown forward target %q (available: %s)",
					ft, strings.Join(targetNames, ", "))
			}
		}
		mc.ForwardTarget = ft
	}

	// Adapter type override; stored only when it differs from the global type.
	adapterTypeRaw, err := promptString(in,
		fmt.Sprintf("Adapter type override (truncate/random/projection) [%s]: ", globalAdapterType), globalAdapterType)
	if err != nil {
		return pendingExtraMap{}, err
	}
	adapterTypeRaw = strings.ToLower(strings.TrimSpace(adapterTypeRaw))
	if adapterTypeRaw != "" && adapterTypeRaw != globalAdapterType {
		switch adapterTypeRaw {
		case "truncate", "random", "projection":
			mc.Type = adapterTypeRaw
		default:
			return pendingExtraMap{}, fmt.Errorf("unknown adapter type %q: want truncate, random or projection", adapterTypeRaw)
		}
	}
	effectiveType := coalesce(mc.Type, globalAdapterType)

	// Source dim override.
	sourceDimRaw, err := promptString(in,
		fmt.Sprintf("Source dimension override (leave empty to use global %d): ", globalSourceDim), "")
	if err != nil {
		return pendingExtraMap{}, err
	}
	if sourceDimRaw != "" {
		mc.SourceDim = mustParseInt(sourceDimRaw, globalSourceDim)
	}

	// Truncate/pad mode overrides apply only to the truncate adapter.
	if effectiveType == "truncate" {
		tmRaw, err := promptString(in,
			fmt.Sprintf("Truncate mode override (from_end/from_start) [%s]: ", globalTruncateMode), globalTruncateMode)
		if err != nil {
			return pendingExtraMap{}, err
		}
		if tmRaw != globalTruncateMode {
			mc.TruncateMode = tmRaw
		}
		pmRaw, err := promptString(in,
			fmt.Sprintf("Pad mode override (at_end/at_start) [%s]: ", globalPadMode), globalPadMode)
		if err != nil {
			return pendingExtraMap{}, err
		}
		if pmRaw != globalPadMode {
			mc.PadMode = pmRaw
		}
	}

	// Seed applies only to the random adapter.
	if effectiveType == "random" {
		seedRaw, err := promptString(in, "Seed (0 = time-based): ", "0")
		if err != nil {
			return pendingExtraMap{}, err
		}
		seed, parseErr := strconv.ParseInt(strings.TrimSpace(seedRaw), 10, 64)
		if parseErr != nil {
			// Best-effort: unparsable input falls back to 0, which the
			// prompt describes as time-based seeding.
			seed = 0
		}
		mc.Seed = seed
	}

	return pendingExtraMap{key: key, cfg: mc}, nil
}
// writeFullConfig writes the complete onboard config cfg to path.
//
// When path does not exist yet, createDefaultConfig is called first.
// NOTE(review): its effect is defined elsewhere. It is kept here presumably
// because it prepares the file or its parent directory before the write.
//
// No merging with an existing file happens here. Assuming config.WriteConfig
// overwrites the file, any targets or settings the file held that are not in
// cfg are lost.
func writeFullConfig(path string, cfg config.Config) error {
	if _, err := os.Stat(path); os.IsNotExist(err) {
		if err := createDefaultConfig(path); err != nil {
			return fmt.Errorf("create default config %s: %w", path, err)
		}
	}
	return config.WriteConfig(path, cfg)
}