diff --git a/pkg/modelregistry/model_registry.go b/pkg/modelregistry/model_registry.go index 930da2e..8f98084 100644 --- a/pkg/modelregistry/model_registry.go +++ b/pkg/modelregistry/model_registry.go @@ -38,12 +38,28 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err return fmt.Errorf("model cannot be nil") } - if modelType.Kind() == reflect.Ptr { - return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind()) + originalType := modelType + + // Unwrap pointers, slices, and arrays to check the underlying type + for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array { + modelType = modelType.Elem() } + // Validate that the underlying type is a struct if modelType.Kind() != reflect.Struct { - return fmt.Errorf("model must be a struct, got %s", modelType.Kind()) + return fmt.Errorf("model must be a struct or pointer to struct, got %s", originalType.String()) + } + + // If a pointer/slice/array was passed, unwrap to the base struct + if originalType != modelType { + // Create a zero value of the struct type + model = reflect.New(modelType).Elem().Interface() + } + + // Additional check: ensure model is not a pointer + finalType := reflect.TypeOf(model) + if finalType.Kind() == reflect.Ptr { + return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name()) } r.models[name] = model diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index da43748..e1ec586 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -73,6 +73,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s return } + // Validate that the model is a struct type (not a slice or pointer to slice) + modelType := reflect.TypeOf(model) + originalType := modelType + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType) + h.sendError(w, http.StatusInternalServerError, "invalid_model_type", + fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType), + fmt.Errorf("invalid model type: %v", originalType)) + return + } + + // If the registered model was a pointer or slice, use the unwrapped struct type + if originalType != modelType { + model = reflect.New(modelType).Elem().Interface() + } + // Create a pointer to the model type for database operations modelPtr := reflect.New(reflect.TypeOf(model)).Interface() tableName := h.getTableName(schema, entity, model) @@ -132,7 +152,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st entity := GetEntity(ctx) tableName := GetTableName(ctx) model := GetModel(ctx) - modelPtr := GetModelPtr(ctx) + + // Validate and unwrap model type to get base struct + modelType := reflect.TypeOf(model) + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity) + h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType)) + return + } + + // Create a pointer to the model type for database operations + modelPtr := reflect.New(modelType).Interface() logger.Info("Reading records from %s.%s", schema, entity) @@ -189,8 +223,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st var result interface{} if id != "" { logger.Debug("Querying single record with ID: %s", id) - // Create a pointer to the struct type for scanning - singleResult := reflect.New(reflect.TypeOf(model)).Interface() + // Create a pointer to the struct type for scanning - use modelType which is already unwrapped + singleResult := reflect.New(modelType).Interface() query = query.Where("id = ?", id) if err := query.Scan(ctx, singleResult); err != nil { logger.Error("Error querying record: %v", err) @@ -200,8 +234,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st result = singleResult } else { logger.Debug("Querying multiple records") - // Create a slice of pointers to the model type - sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model))) + // Create a slice of pointers to the model type - use modelType which is already unwrapped + sliceType := reflect.SliceOf(reflect.PointerTo(modelType)) results := reflect.New(sliceType).Interface() if err := query.Scan(ctx, results); err != nil { @@ -444,10 +478,23 @@ func (h *Handler) getTableName(schema, entity string, model interface{}) string func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata { modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { + + // Unwrap pointers, slices, and arrays to get to the base struct type + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { modelType = modelType.Elem() } + // Validate that we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Error("Model type must be a struct, got %v for %s.%s", modelType, schema, entity) + return &common.TableMetadata{ + Schema: schema, + Table: entity, + Columns: make([]common.Column, 0), + Relations: make([]string, 0), + } + } + metadata := &common.TableMetadata{ Schema: schema, Table: entity, @@ -591,10 +638,18 @@ type relationshipInfo struct { func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery { modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { + + // Unwrap pointers, slices, and arrays to get to the base struct type + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { modelType = modelType.Elem() } + // Validate that we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Warn("Cannot apply preloads to non-struct type: %v", modelType) + return query + } + for _, preload := range preloads { logger.Debug("Processing preload for relation: %s", preload.Relation) relInfo := h.getRelationshipInfo(modelType, preload.Relation) @@ -618,6 +673,12 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre } func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { + // Ensure we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Warn("Cannot get relationship info from non-struct type: %v", modelType) + return nil + } + for i := 0; i < modelType.NumField(); i++ { field := modelType.Field(i) jsonTag := field.Tag.Get("json") diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 2238bdc..5e7247b 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -67,6 +67,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s return } + // Validate that the model is a struct type (not a slice or pointer to slice) + modelType := reflect.TypeOf(model) + originalType := modelType + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType) + h.sendError(w, http.StatusInternalServerError, "invalid_model_type", + fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType), + fmt.Errorf("invalid model type: %v", originalType)) + return + } + + // If the registered model was a pointer or slice, use the unwrapped struct type + if originalType != modelType { + model = reflect.New(modelType).Elem().Interface() + } + modelPtr := reflect.New(reflect.TypeOf(model)).Interface() tableName := h.getTableName(schema, entity, model) @@ -158,7 +178,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st schema := GetSchema(ctx) entity := GetEntity(ctx) tableName := GetTableName(ctx) - modelPtr := GetModelPtr(ctx) + model := GetModel(ctx) + + // Validate and unwrap model type to get base struct + modelType := reflect.TypeOf(model) + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity) + h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType)) + return + } + + // Create a pointer to a slice of pointers to the model type for query results + modelPtr := reflect.New(reflect.SliceOf(reflect.PointerTo(modelType))).Interface() logger.Info("Reading records from %s.%s", schema, entity) @@ -252,10 +287,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st query = query.Offset(*options.Offset) } - // Execute query - create a slice of pointers to the model type - model := GetModel(ctx) - resultSlice := reflect.New(reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))).Interface() - if err := query.Scan(ctx, resultSlice); err != nil { + // Execute query - modelPtr was already created earlier + if err := query.Scan(ctx, modelPtr); err != nil { logger.Error("Error executing query: %v", err) h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err) return @@ -277,7 +310,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st Offset: offset, } - h.sendFormattedResponse(w, resultSlice, metadata, options) + h.sendFormattedResponse(w, modelPtr, metadata, options) } func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) { @@ -516,10 +549,22 @@ func (h *Handler) getTableName(schema, entity string, model interface{}) string func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata { modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { + + // Unwrap pointers, slices, and arrays to get to the base struct type + for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array { modelType = modelType.Elem() } + // Validate that we have a struct type + if modelType.Kind() != reflect.Struct { + logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity) + return &common.TableMetadata{ + Schema: schema, + Table: h.getTableName(schema, entity, model), + Columns: []common.Column{}, + } + } + tableName := h.getTableName(schema, entity, model) metadata := &common.TableMetadata{