mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-11-13 09:53:53 +00:00
Reflect safty
This commit is contained in:
parent
e7e5754a47
commit
e88018543e
@ -38,12 +38,28 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
|||||||
return fmt.Errorf("model cannot be nil")
|
return fmt.Errorf("model cannot be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType.Kind() == reflect.Ptr {
|
originalType := modelType
|
||||||
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind())
|
|
||||||
|
// Unwrap pointers, slices, and arrays to check the underlying type
|
||||||
|
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||||
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the underlying type is a struct
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
return fmt.Errorf("model must be a struct, got %s", modelType.Kind())
|
return fmt.Errorf("model must be a struct or pointer to struct, got %s", originalType.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a pointer/slice/array was passed, unwrap to the base struct
|
||||||
|
if originalType != modelType {
|
||||||
|
// Create a zero value of the struct type
|
||||||
|
model = reflect.New(modelType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional check: ensure model is not a pointer
|
||||||
|
finalType := reflect.TypeOf(model)
|
||||||
|
if finalType.Kind() == reflect.Ptr {
|
||||||
|
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
r.models[name] = model
|
r.models[name] = model
|
||||||
|
|||||||
@ -73,6 +73,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the model is a struct type (not a slice or pointer to slice)
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
originalType := modelType
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
|
||||||
|
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
|
||||||
|
fmt.Errorf("invalid model type: %v", originalType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the registered model was a pointer or slice, use the unwrapped struct type
|
||||||
|
if originalType != modelType {
|
||||||
|
model = reflect.New(modelType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
// Create a pointer to the model type for database operations
|
// Create a pointer to the model type for database operations
|
||||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
@ -132,7 +152,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
model := GetModel(ctx)
|
model := GetModel(ctx)
|
||||||
modelPtr := GetModelPtr(ctx)
|
|
||||||
|
// Validate and unwrap model type to get base struct
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pointer to the model type for database operations
|
||||||
|
modelPtr := reflect.New(modelType).Interface()
|
||||||
|
|
||||||
logger.Info("Reading records from %s.%s", schema, entity)
|
logger.Info("Reading records from %s.%s", schema, entity)
|
||||||
|
|
||||||
@ -189,8 +223,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
var result interface{}
|
var result interface{}
|
||||||
if id != "" {
|
if id != "" {
|
||||||
logger.Debug("Querying single record with ID: %s", id)
|
logger.Debug("Querying single record with ID: %s", id)
|
||||||
// Create a pointer to the struct type for scanning
|
// Create a pointer to the struct type for scanning - use modelType which is already unwrapped
|
||||||
singleResult := reflect.New(reflect.TypeOf(model)).Interface()
|
singleResult := reflect.New(modelType).Interface()
|
||||||
query = query.Where("id = ?", id)
|
query = query.Where("id = ?", id)
|
||||||
if err := query.Scan(ctx, singleResult); err != nil {
|
if err := query.Scan(ctx, singleResult); err != nil {
|
||||||
logger.Error("Error querying record: %v", err)
|
logger.Error("Error querying record: %v", err)
|
||||||
@ -200,8 +234,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
result = singleResult
|
result = singleResult
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Querying multiple records")
|
logger.Debug("Querying multiple records")
|
||||||
// Create a slice of pointers to the model type
|
// Create a slice of pointers to the model type - use modelType which is already unwrapped
|
||||||
sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))
|
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||||
results := reflect.New(sliceType).Interface()
|
results := reflect.New(sliceType).Interface()
|
||||||
|
|
||||||
if err := query.Scan(ctx, results); err != nil {
|
if err := query.Scan(ctx, results); err != nil {
|
||||||
@ -444,10 +478,23 @@ func (h *Handler) getTableName(schema, entity string, model interface{}) string
|
|||||||
|
|
||||||
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model type must be a struct, got %v for %s.%s", modelType, schema, entity)
|
||||||
|
return &common.TableMetadata{
|
||||||
|
Schema: schema,
|
||||||
|
Table: entity,
|
||||||
|
Columns: make([]common.Column, 0),
|
||||||
|
Relations: make([]string, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
metadata := &common.TableMetadata{
|
metadata := &common.TableMetadata{
|
||||||
Schema: schema,
|
Schema: schema,
|
||||||
Table: entity,
|
Table: entity,
|
||||||
@ -591,10 +638,18 @@ type relationshipInfo struct {
|
|||||||
|
|
||||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Warn("Cannot apply preloads to non-struct type: %v", modelType)
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
for _, preload := range preloads {
|
for _, preload := range preloads {
|
||||||
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
logger.Debug("Processing preload for relation: %s", preload.Relation)
|
||||||
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
relInfo := h.getRelationshipInfo(modelType, preload.Relation)
|
||||||
@ -618,6 +673,12 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
|
||||||
|
// Ensure we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
field := modelType.Field(i)
|
field := modelType.Field(i)
|
||||||
jsonTag := field.Tag.Get("json")
|
jsonTag := field.Tag.Get("json")
|
||||||
|
|||||||
@ -67,6 +67,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that the model is a struct type (not a slice or pointer to slice)
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
originalType := modelType
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
|
||||||
|
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
|
||||||
|
fmt.Errorf("invalid model type: %v", originalType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the registered model was a pointer or slice, use the unwrapped struct type
|
||||||
|
if originalType != modelType {
|
||||||
|
model = reflect.New(modelType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
|
||||||
@ -158,7 +178,22 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
schema := GetSchema(ctx)
|
schema := GetSchema(ctx)
|
||||||
entity := GetEntity(ctx)
|
entity := GetEntity(ctx)
|
||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
modelPtr := GetModelPtr(ctx)
|
model := GetModel(ctx)
|
||||||
|
|
||||||
|
// Validate and unwrap model type to get base struct
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pointer to a slice of pointers to the model type for query results
|
||||||
|
modelPtr := reflect.New(reflect.SliceOf(reflect.PointerTo(modelType))).Interface()
|
||||||
|
|
||||||
logger.Info("Reading records from %s.%s", schema, entity)
|
logger.Info("Reading records from %s.%s", schema, entity)
|
||||||
|
|
||||||
@ -252,10 +287,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Offset(*options.Offset)
|
query = query.Offset(*options.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute query - create a slice of pointers to the model type
|
// Execute query - modelPtr was already created earlier
|
||||||
model := GetModel(ctx)
|
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||||
resultSlice := reflect.New(reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))).Interface()
|
|
||||||
if err := query.Scan(ctx, resultSlice); err != nil {
|
|
||||||
logger.Error("Error executing query: %v", err)
|
logger.Error("Error executing query: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||||
return
|
return
|
||||||
@ -277,7 +310,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
Offset: offset,
|
Offset: offset,
|
||||||
}
|
}
|
||||||
|
|
||||||
h.sendFormattedResponse(w, resultSlice, metadata, options)
|
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||||
@ -516,10 +549,22 @@ func (h *Handler) getTableName(schema, entity string, model interface{}) string
|
|||||||
|
|
||||||
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity)
|
||||||
|
return &common.TableMetadata{
|
||||||
|
Schema: schema,
|
||||||
|
Table: h.getTableName(schema, entity, model),
|
||||||
|
Columns: []common.Column{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
|
||||||
metadata := &common.TableMetadata{
|
metadata := &common.TableMetadata{
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user