Added modelregistry.AddRegistry

This commit is contained in:
Hein 2025-11-19 18:18:18 +02:00
parent a44ef90d7c
commit 850d7b546c

View File

@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{
models: make(map[string]interface{}), models: make(map[string]interface{}),
} }
// Global list of registries (searched in order)
var registries = []*DefaultModelRegistry{defaultRegistry}
var registriesMutex sync.RWMutex
// NewModelRegistry creates a new model registry // NewModelRegistry creates a new model registry
func NewModelRegistry() *DefaultModelRegistry { func NewModelRegistry() *DefaultModelRegistry {
return &DefaultModelRegistry{ return &DefaultModelRegistry{
@ -24,6 +28,14 @@ func NewModelRegistry() *DefaultModelRegistry {
} }
} }
// AddRegistry adds a registry to the global list of registries
// Registries are searched in the order they were added
func AddRegistry(registry *DefaultModelRegistry) {
registriesMutex.Lock()
defer registriesMutex.Unlock()
registries = append(registries, registry)
}
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error { func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@ -107,9 +119,19 @@ func RegisterModel(model interface{}, name string) error {
return defaultRegistry.RegisterModel(name, model) return defaultRegistry.RegisterModel(name, model)
} }
// GetModelByName retrieves a model from the default global registry by name // GetModelByName retrieves a model by searching through all registries in order
// Returns the first match found
func GetModelByName(name string) (interface{}, error) { func GetModelByName(name string) (interface{}, error) {
return defaultRegistry.GetModel(name) registriesMutex.RLock()
defer registriesMutex.RUnlock()
for _, registry := range registries {
if model, err := registry.GetModel(name); err == nil {
return model, nil
}
}
return nil, fmt.Errorf("model %s not found in any registry", name)
} }
// IterateModels iterates over all models in the default global registry // IterateModels iterates over all models in the default global registry
@ -122,14 +144,26 @@ func IterateModels(fn func(name string, model interface{})) {
} }
} }
// GetModels returns a list of all models in the default global registry // GetModels returns a list of all models from all registries
// Models are collected in registry order, with duplicates included
func GetModels() []interface{} { func GetModels() []interface{} {
defaultRegistry.mutex.RLock() registriesMutex.RLock()
defer defaultRegistry.mutex.RUnlock() defer registriesMutex.RUnlock()
models := make([]interface{}, 0, len(defaultRegistry.models)) var models []interface{}
for _, model := range defaultRegistry.models { seen := make(map[string]bool)
models = append(models, model)
for _, registry := range registries {
registry.mutex.RLock()
for name, model := range registry.models {
// Only add the first occurrence of each model name
if !seen[name] {
models = append(models, model)
seen[name] = true
}
}
registry.mutex.RUnlock()
} }
return models return models
} }