Better handling with context

This commit is contained in:
Hein
2025-11-07 09:13:06 +02:00
parent d122c7af42
commit c88bff1883
6 changed files with 317 additions and 105 deletions

View File

@@ -0,0 +1,85 @@
package resolvespec
import (
"context"
)
// Context keys for request-scoped data
type contextKey string
const (
contextKeySchema contextKey = "schema"
contextKeyEntity contextKey = "entity"
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
)
// WithSchema adds schema to context
func WithSchema(ctx context.Context, schema string) context.Context {
return context.WithValue(ctx, contextKeySchema, schema)
}
// GetSchema retrieves schema from context
func GetSchema(ctx context.Context) string {
if v := ctx.Value(contextKeySchema); v != nil {
return v.(string)
}
return ""
}
// WithEntity adds entity to context
func WithEntity(ctx context.Context, entity string) context.Context {
return context.WithValue(ctx, contextKeyEntity, entity)
}
// GetEntity retrieves entity from context
func GetEntity(ctx context.Context) string {
if v := ctx.Value(contextKeyEntity); v != nil {
return v.(string)
}
return ""
}
// WithTableName adds table name to context
func WithTableName(ctx context.Context, tableName string) context.Context {
return context.WithValue(ctx, contextKeyTableName, tableName)
}
// GetTableName retrieves table name from context
func GetTableName(ctx context.Context) string {
if v := ctx.Value(contextKeyTableName); v != nil {
return v.(string)
}
return ""
}
// WithModel adds model to context
func WithModel(ctx context.Context, model interface{}) context.Context {
return context.WithValue(ctx, contextKeyModel, model)
}
// GetModel retrieves model from context
func GetModel(ctx context.Context) interface{} {
return ctx.Value(contextKeyModel)
}
// WithModelPtr adds model pointer to context
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
}
// GetModelPtr retrieves model pointer from context
func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
// WithRequestData adds all request-scoped data to context at once
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
return ctx
}

View File

@@ -50,15 +50,30 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
logger.Info("Handling %s operation for %s.%s", req.Operation, schema, entity)
// Get model and populate context with request-scoped data
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("Invalid entity: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
return
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
// Add request-scoped data to context
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
switch req.Operation {
case "read":
h.handleRead(ctx, w, schema, entity, id, req.Options)
h.handleRead(ctx, w, id, req.Options)
case "create":
h.handleCreate(ctx, w, schema, entity, req.Data, req.Options)
h.handleCreate(ctx, w, req.Data, req.Options)
case "update":
h.handleUpdate(ctx, w, schema, entity, id, req.ID, req.Data, req.Options)
h.handleUpdate(ctx, w, id, req.ID, req.Data, req.Options)
case "delete":
h.handleDelete(ctx, w, schema, entity, id)
h.handleDelete(ctx, w, id)
default:
logger.Error("Invalid operation: %s", req.Operation)
h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil)
@@ -83,24 +98,16 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
h.sendResponse(w, metadata, nil)
}
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schema, entity, id string, options common.RequestOptions) {
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
model := GetModel(ctx)
modelPtr := GetModelPtr(ctx)
logger.Info("Reading records from %s.%s", schema, entity)
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Error("Invalid entity: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err)
return
}
// Model is now a non-pointer struct, create a pointer instance for ORM
modelType := reflect.TypeOf(model)
modelPtr := reflect.New(modelType).Interface()
query := h.db.NewSelect().Model(modelPtr)
// Get table name
tableName := h.getTableName(schema, entity, model)
query = query.Table(tableName)
// Apply column selection
@@ -154,7 +161,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
if id != "" {
logger.Debug("Querying single record with ID: %s", id)
// Create a pointer to the struct type for scanning
singleResult := reflect.New(modelType).Interface()
singleResult := reflect.New(reflect.TypeOf(model)).Interface()
query = query.Where("id = ?", id)
if err := query.Scan(ctx, singleResult); err != nil {
logger.Error("Error querying record: %v", err)
@@ -164,8 +171,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
result = singleResult
} else {
logger.Debug("Querying multiple records")
// Create a slice of the struct type (not pointers)
sliceType := reflect.SliceOf(modelType)
// Create a slice of pointers to the model type
sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(model)))
results := reflect.New(sliceType).Interface()
if err := query.Scan(ctx, results); err != nil {
@@ -195,17 +202,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
})
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, schema, entity string, data interface{}, options common.RequestOptions) {
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) {
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
logger.Info("Creating records for %s.%s", schema, entity)
// Get the model to determine the actual table name
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Warn("Model not found, using default table name")
model = nil
}
tableName := h.getTableName(schema, entity, model)
query := h.db.NewInsert().Table(tableName)
switch v := data.(type) {
@@ -275,18 +278,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, sch
}
}
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, schema, entity, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
logger.Info("Updating records for %s.%s", schema, entity)
// Get the model to determine the actual table name
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Warn("Model not found, using default table name")
// Fallback to entity name (without schema for SQLite compatibility)
model = nil
}
tableName := h.getTableName(schema, entity, model)
query := h.db.NewUpdate().Table(tableName)
switch updates := data.(type) {
@@ -330,7 +328,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, sch
h.sendResponse(w, data, nil)
}
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, schema, entity, id string) {
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string) {
schema := GetSchema(ctx)
entity := GetEntity(ctx)
tableName := GetTableName(ctx)
logger.Info("Deleting records from %s.%s", schema, entity)
if id == "" {
@@ -339,14 +341,6 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, sch
return
}
// Get the model to determine the actual table name
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
logger.Warn("Model not found, using default table name")
model = nil
}
tableName := h.getTableName(schema, entity, model)
query := h.db.NewDelete().Table(tableName).Where("id = ?", id)
result, err := query.Exec(ctx)