From 850d7b546ca948545e56337e579e9b33430f90de Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 19 Nov 2025 18:18:18 +0200 Subject: [PATCH] Added modelregistry.AddRegistry --- pkg/modelregistry/model_registry.go | 50 ++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/pkg/modelregistry/model_registry.go b/pkg/modelregistry/model_registry.go index 01b5e19..e25872f 100644 --- a/pkg/modelregistry/model_registry.go +++ b/pkg/modelregistry/model_registry.go @@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{ models: make(map[string]interface{}), } +// Global list of registries (searched in order) +var registries = []*DefaultModelRegistry{defaultRegistry} +var registriesMutex sync.RWMutex + // NewModelRegistry creates a new model registry func NewModelRegistry() *DefaultModelRegistry { return &DefaultModelRegistry{ @@ -24,6 +28,14 @@ func NewModelRegistry() *DefaultModelRegistry { } } +// AddRegistry adds a registry to the global list of registries +// Registries are searched in the order they were added +func AddRegistry(registry *DefaultModelRegistry) { + registriesMutex.Lock() + defer registriesMutex.Unlock() + registries = append(registries, registry) +} + func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error { r.mutex.Lock() defer r.mutex.Unlock() @@ -107,9 +119,19 @@ func RegisterModel(model interface{}, name string) error { return defaultRegistry.RegisterModel(name, model) } -// GetModelByName retrieves a model from the default global registry by name +// GetModelByName retrieves a model by searching through all registries in order +// Returns the first match found func GetModelByName(name string) (interface{}, error) { - return defaultRegistry.GetModel(name) + registriesMutex.RLock() + defer registriesMutex.RUnlock() + + for _, registry := range registries { + if model, err := registry.GetModel(name); err == nil { + return model, nil + } + } + + return nil, fmt.Errorf("model %s not found in any registry", name) } // IterateModels iterates over all models in the default global registry @@ -122,14 +144,26 @@ func IterateModels(fn func(name string, model interface{})) { } } -// GetModels returns a list of all models in the default global registry +// GetModels returns a list of all models from all registries +// Models are collected in registry order, with duplicates included func GetModels() []interface{} { - defaultRegistry.mutex.RLock() - defer defaultRegistry.mutex.RUnlock() + registriesMutex.RLock() + defer registriesMutex.RUnlock() - models := make([]interface{}, 0, len(defaultRegistry.models)) - for _, model := range defaultRegistry.models { - models = append(models, model) + var models []interface{} + seen := make(map[string]bool) + + for _, registry := range registries { + registry.mutex.RLock() + for name, model := range registry.models { + // Only add the first occurrence of each model name + if !seen[name] { + models = append(models, model) + seen[name] = true + } + } + registry.mutex.RUnlock() } + return models }