Reflect safty

This commit is contained in:
Hein
2025-11-07 09:47:12 +02:00
parent e7e5754a47
commit e88018543e
3 changed files with 139 additions and 17 deletions

View File

@@ -67,6 +67,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
return
}
// Validate that the model is a struct type (not a slice or pointer to slice)
modelType := reflect.TypeOf(model)
originalType := modelType
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
@@ -158,7 +178,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
modelPtr := GetModelPtr(ctx)
model := GetModel(ctx)
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
return
}
// Create a pointer to a slice of pointers to the model type for query results
modelPtr := reflect.New(reflect.SliceOf(reflect.PointerTo(modelType))).Interface()
logger.Info("Reading records from %s.%s", schema, entity)
@@ -252,10 +287,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Offset(*options.Offset)
}
// Execute query - create a slice of pointers to the model type
model := GetModel(ctx)
resultSlice := reflect.New(reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))).Interface()
if err := query.Scan(ctx, resultSlice); err != nil {
// Execute query - modelPtr was already created earlier
if err := query.Scan(ctx, modelPtr); err != nil {
logger.Error("Error executing query: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
return
@@ -277,7 +310,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset,
}
h.sendFormattedResponse(w, resultSlice, metadata, options)
h.sendFormattedResponse(w, modelPtr, metadata, options)
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
@@ -516,10 +549,22 @@ func (h *Handler) getTableName(schema, entity string, model interface{}) string
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity)
return &common.TableMetadata{
Schema: schema,
Table: h.getTableName(schema, entity, model),
Columns: []common.Column{},
}
}
tableName := h.getTableName(schema, entity, model)
metadata := &common.TableMetadata{