Compare commits

...

2 Commits

Author SHA1 Message Date
Hein
e88018543e Reflect safty 2025-11-07 09:47:12 +02:00
Hein
e7e5754a47 Added panic catches 2025-11-07 09:32:37 +02:00
3 changed files with 239 additions and 17 deletions

View File

@@ -38,12 +38,28 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
return fmt.Errorf("model cannot be nil") return fmt.Errorf("model cannot be nil")
} }
if modelType.Kind() == reflect.Ptr { originalType := modelType
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind())
// Unwrap pointers, slices, and arrays to check the underlying type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem()
} }
// Validate that the underlying type is a struct
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
return fmt.Errorf("model must be a struct, got %s", modelType.Kind()) return fmt.Errorf("model must be a struct or pointer to struct, got %s", originalType.String())
}
// If a pointer/slice/array was passed, unwrap to the base struct
if originalType != modelType {
// Create a zero value of the struct type
model = reflect.New(modelType).Elem().Interface()
}
// Additional check: ensure model is not a pointer
finalType := reflect.TypeOf(model)
if finalType.Kind() == reflect.Ptr {
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s. Use MyModel{} instead of &MyModel{}", finalType.Elem().Name())
} }
r.models[name] = model r.models[name] = model

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"runtime/debug"
"strings" "strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/Warky-Devs/ResolveSpec/pkg/common"
@@ -26,8 +27,22 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
} }
} }
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
}
// Handle processes API requests through router-agnostic interface // Handle processes API requests through router-agnostic interface
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) { func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "Handle", err)
}
}()
ctx := context.Background() ctx := context.Background()
body, err := r.Body() body, err := r.Body()
@@ -58,6 +73,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
return return
} }
// Validate that the model is a struct type (not a slice or pointer to slice)
modelType := reflect.TypeOf(model)
originalType := modelType
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
// Create a pointer to the model type for database operations // Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface() modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model) tableName := h.getTableName(schema, entity, model)
@@ -82,6 +117,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// HandleGet processes GET requests for metadata // HandleGet processes GET requests for metadata
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) { func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "HandleGet", err)
}
}()
schema := params["schema"] schema := params["schema"]
entity := params["entity"] entity := params["entity"]
@@ -99,11 +141,32 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
} }
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) { func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleRead", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
model := GetModel(ctx) model := GetModel(ctx)
modelPtr := GetModelPtr(ctx)
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
return
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(modelType).Interface()
logger.Info("Reading records from %s.%s", schema, entity) logger.Info("Reading records from %s.%s", schema, entity)
@@ -160,8 +223,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
var result interface{} var result interface{}
if id != "" { if id != "" {
logger.Debug("Querying single record with ID: %s", id) logger.Debug("Querying single record with ID: %s", id)
// Create a pointer to the struct type for scanning // Create a pointer to the struct type for scanning - use modelType which is already unwrapped
singleResult := reflect.New(reflect.TypeOf(model)).Interface() singleResult := reflect.New(modelType).Interface()
query = query.Where("id = ?", id) query = query.Where("id = ?", id)
if err := query.Scan(ctx, singleResult); err != nil { if err := query.Scan(ctx, singleResult); err != nil {
logger.Error("Error querying record: %v", err) logger.Error("Error querying record: %v", err)
@@ -171,8 +234,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
result = singleResult result = singleResult
} else { } else {
logger.Debug("Querying multiple records") logger.Debug("Querying multiple records")
// Create a slice of pointers to the model type // Create a slice of pointers to the model type - use modelType which is already unwrapped
sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model))) sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
results := reflect.New(sliceType).Interface() results := reflect.New(sliceType).Interface()
if err := query.Scan(ctx, results); err != nil { if err := query.Scan(ctx, results); err != nil {
@@ -203,6 +266,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
} }
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) { func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleCreate", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
@@ -279,6 +349,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
} }
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) { func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleUpdate", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
@@ -329,6 +406,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
} }
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) { func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleDelete", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
@@ -394,10 +478,23 @@ func (h *Handler) getTableName(schema, entity string, model interface{}) string
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata { func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %v for %s.%s", modelType, schema, entity)
return &common.TableMetadata{
Schema: schema,
Table: entity,
Columns: make([]common.Column, 0),
Relations: make([]string, 0),
}
}
metadata := &common.TableMetadata{ metadata := &common.TableMetadata{
Schema: schema, Schema: schema,
Table: entity, Table: entity,
@@ -541,10 +638,18 @@ type relationshipInfo struct {
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery { func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot apply preloads to non-struct type: %v", modelType)
return query
}
for _, preload := range preloads { for _, preload := range preloads {
logger.Debug("Processing preload for relation: %s", preload.Relation) logger.Debug("Processing preload for relation: %s", preload.Relation)
relInfo := h.getRelationshipInfo(modelType, preload.Relation) relInfo := h.getRelationshipInfo(modelType, preload.Relation)
@@ -568,6 +673,12 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
} }
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
// Ensure we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot get relationship info from non-struct type: %v", modelType)
return nil
}
for i := 0; i < modelType.NumField(); i++ { for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i) field := modelType.Field(i)
jsonTag := field.Tag.Get("json") jsonTag := field.Tag.Get("json")

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"runtime/debug"
"strings" "strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/Warky-Devs/ResolveSpec/pkg/common"
@@ -27,9 +28,23 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
} }
} }
// handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack()
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
}
// Handle processes API requests through router-agnostic interface // Handle processes API requests through router-agnostic interface
// Options are read from HTTP headers instead of request body // Options are read from HTTP headers instead of request body
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) { func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "Handle", err)
}
}()
ctx := context.Background() ctx := context.Background()
schema := params["schema"] schema := params["schema"]
@@ -52,6 +67,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
return return
} }
// Validate that the model is a struct type (not a slice or pointer to slice)
modelType := reflect.TypeOf(model)
originalType := modelType
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
modelPtr := reflect.New(reflect.TypeOf(model)).Interface() modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model) tableName := h.getTableName(schema, entity, model)
@@ -107,6 +142,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// HandleGet processes GET requests for metadata // HandleGet processes GET requests for metadata
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) { func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "HandleGet", err)
}
}()
schema := params["schema"] schema := params["schema"]
entity := params["entity"] entity := params["entity"]
@@ -126,10 +168,32 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
// parseOptionsFromHeaders is now implemented in headers.go // parseOptionsFromHeaders is now implemented in headers.go
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) { func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleRead", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
modelPtr := GetModelPtr(ctx) model := GetModel(ctx)
// Validate and unwrap model type to get base struct
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
return
}
// Create a pointer to a slice of pointers to the model type for query results
modelPtr := reflect.New(reflect.SliceOf(reflect.PointerTo(modelType))).Interface()
logger.Info("Reading records from %s.%s", schema, entity) logger.Info("Reading records from %s.%s", schema, entity)
@@ -223,10 +287,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Offset(*options.Offset) query = query.Offset(*options.Offset)
} }
// Execute query - create a slice of pointers to the model type // Execute query - modelPtr was already created earlier
model := GetModel(ctx) if err := query.Scan(ctx, modelPtr); err != nil {
resultSlice := reflect.New(reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))).Interface()
if err := query.Scan(ctx, resultSlice); err != nil {
logger.Error("Error executing query: %v", err) logger.Error("Error executing query: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err) h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
return return
@@ -248,10 +310,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
Offset: offset, Offset: offset,
} }
h.sendFormattedResponse(w, resultSlice, metadata, options) h.sendFormattedResponse(w, modelPtr, metadata, options)
} }
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) { func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleCreate", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
@@ -322,6 +391,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
} }
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) { func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id string, idPtr *int64, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleUpdate", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
@@ -369,6 +445,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
} }
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) { func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleDelete", err)
}
}()
schema := GetSchema(ctx) schema := GetSchema(ctx)
entity := GetEntity(ctx) entity := GetEntity(ctx)
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
@@ -466,10 +549,22 @@ func (h *Handler) getTableName(schema, entity string, model interface{}) string
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata { func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
// Validate that we have a struct type
if modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity)
return &common.TableMetadata{
Schema: schema,
Table: h.getTableName(schema, entity, model),
Columns: []common.Column{},
}
}
tableName := h.getTableName(schema, entity, model) tableName := h.getTableName(schema, entity, model)
metadata := &common.TableMetadata{ metadata := &common.TableMetadata{