diff --git a/pkg/modelregistry/model_registry.go b/pkg/modelregistry/model_registry.go index 73beec9..db0611b 100644 --- a/pkg/modelregistry/model_registry.go +++ b/pkg/modelregistry/model_registry.go @@ -6,15 +6,37 @@ import ( "sync" ) +// ModelRules defines the permissions and security settings for a model +type ModelRules struct { + CanRead bool // Whether the model can be read (GET operations) + CanUpdate bool // Whether the model can be updated (PUT/PATCH operations) + CanCreate bool // Whether the model can be created (POST operations) + CanDelete bool // Whether the model can be deleted (DELETE operations) + SecurityDisabled bool // Whether security checks are disabled for this model +} + +// DefaultModelRules returns the default rules for a model (all operations allowed, security enabled) +func DefaultModelRules() ModelRules { + return ModelRules{ + CanRead: true, + CanUpdate: true, + CanCreate: true, + CanDelete: true, + SecurityDisabled: false, + } +} + // DefaultModelRegistry implements ModelRegistry interface type DefaultModelRegistry struct { models map[string]interface{} + rules map[string]ModelRules mutex sync.RWMutex } // Global default registry instance var defaultRegistry = &DefaultModelRegistry{ models: make(map[string]interface{}), + rules: make(map[string]ModelRules), } // Global list of registries (searched in order) @@ -25,6 +47,7 @@ var registriesMutex sync.RWMutex func NewModelRegistry() *DefaultModelRegistry { return &DefaultModelRegistry{ models: make(map[string]interface{}), + rules: make(map[string]ModelRules), } } @@ -98,6 +121,10 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err } r.models[name] = model + // Initialize with default rules if not already set + if _, exists := r.rules[name]; !exists { + r.rules[name] = DefaultModelRules() + } return nil } @@ -135,6 +162,54 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac return r.GetModel(entity) } +// SetModelRules sets the rules for a specific model +func (r *DefaultModelRegistry) SetModelRules(name string, rules ModelRules) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + // Check if model exists + if _, exists := r.models[name]; !exists { + return fmt.Errorf("model %s not found", name) + } + + r.rules[name] = rules + return nil +} + +// GetModelRules retrieves the rules for a specific model +// Returns default rules if model exists but rules are not set +func (r *DefaultModelRegistry) GetModelRules(name string) (ModelRules, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + // Check if model exists + if _, exists := r.models[name]; !exists { + return ModelRules{}, fmt.Errorf("model %s not found", name) + } + + // Return rules if set, otherwise return default rules + if rules, exists := r.rules[name]; exists { + return rules, nil + } + + return DefaultModelRules(), nil +} + +// RegisterModelWithRules registers a model with specific rules +func (r *DefaultModelRegistry) RegisterModelWithRules(name string, model interface{}, rules ModelRules) error { + // First register the model + if err := r.RegisterModel(name, model); err != nil { + return err + } + + // Then set the rules (we need to lock again for rules) + r.mutex.Lock() + defer r.mutex.Unlock() + r.rules[name] = rules + + return nil +} + // Global convenience functions using the default registry // RegisterModel registers a model with the default global registry @@ -190,3 +265,34 @@ func GetModels() []interface{} { return models } + +// SetModelRules sets the rules for a specific model in the default registry +func SetModelRules(name string, rules ModelRules) error { + return defaultRegistry.SetModelRules(name, rules) +} + +// GetModelRules retrieves the rules for a specific model from the default registry +func GetModelRules(name string) (ModelRules, error) { + return defaultRegistry.GetModelRules(name) +} + +// GetModelRulesByName retrieves the rules for a model by searching through all registries in order +// Returns the first match found +func GetModelRulesByName(name string) (ModelRules, error) { + registriesMutex.RLock() + defer registriesMutex.RUnlock() + + for _, registry := range registries { + if _, err := registry.GetModel(name); err == nil { + // Model found in this registry, get its rules + return registry.GetModelRules(name) + } + } + + return ModelRules{}, fmt.Errorf("model %s not found in any registry", name) +} + +// RegisterModelWithRules registers a model with specific rules in the default registry +func RegisterModelWithRules(model interface{}, name string, rules ModelRules) error { + return defaultRegistry.RegisterModelWithRules(name, model, rules) +}