mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-02-12 11:36:07 +00:00
* Introduce `logic_operator` field to combine filters with OR logic.
* Implement grouping for consecutive OR filters to ensure proper SQL precedence.
* Add support for custom SQL operators in filter conditions.
* Enhance `fetch_row_number` functionality to return specific record with its position.
* Update tests to cover new filter logic and grouping behavior.
Features Implemented:
1. OR Logic Filter Support (SearchOr)
- Added to resolvespec, restheadspec, and websocketspec
- Consecutive OR filters are automatically grouped with parentheses
- Prevents SQL logic errors: (A OR B OR C) AND D instead of A OR B OR C AND D
2. CustomOperators
- Allows arbitrary SQL conditions in resolvespec
- Properly integrated with filter logic
3. FetchRowNumber
- Uses SQL window functions: ROW_NUMBER() OVER (ORDER BY ...)
- Returns only the specific record (not all records)
- Available in resolvespec and restheadspec
- Perfect for "What's my rank?" queries
4. RowNumber Field Auto-Population
- Now available in all three packages: resolvespec, restheadspec, and websocketspec
- Uses simple offset-based math: offset + index + 1
- Automatically populates RowNumber int64 field if it exists on models
- Perfect for displaying paginated lists with sequential numbering
2004 lines
64 KiB
Go
2004 lines
64 KiB
Go
package resolvespec
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
"runtime/debug"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
|
)
|
|
|
|
// FallbackHandler is a function that handles requests when no model is found
// in the registry for the requested schema/entity pair.
// It receives the same parameters as the Handle method, so it can be used to
// chain to another router or to produce a custom not-found response.
type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[string]string)
|
|
|
|
// Handler handles API requests using database and model abstractions.
type Handler struct {
	db       common.Database      // database abstraction used to build select/insert/update queries
	registry common.ModelRegistry // resolves schema/entity names to registered model types

	// nestedProcessor handles create/update/delete payloads that contain
	// nested relation data (see shouldUseNestedProcessor).
	nestedProcessor *common.NestedCUDProcessor

	hooks *HookRegistry // lifecycle hooks (e.g. BeforeUpdate/AfterUpdate) registered via Hooks()

	// fallbackHandler, when non-nil, is invoked by Handle/HandleGet if no
	// model matches the requested schema/entity (see SetFallbackHandler).
	fallbackHandler FallbackHandler

	// openAPIGenerator, when set, produces the spec served for ?openapi requests.
	openAPIGenerator func() (string, error)
}
|
|
|
|
// NewHandler creates a new API handler with database and registry abstractions
|
|
func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
|
|
handler := &Handler{
|
|
db: db,
|
|
registry: registry,
|
|
hooks: NewHookRegistry(),
|
|
}
|
|
// Initialize nested processor
|
|
handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler)
|
|
return handler
|
|
}
|
|
|
|
// Hooks returns the hook registry for this handler
|
|
// Use this to register custom hooks for operations
|
|
func (h *Handler) Hooks() *HookRegistry {
|
|
return h.hooks
|
|
}
|
|
|
|
// SetFallbackHandler sets a fallback handler to be called when no model is found
|
|
// If not set, the handler will simply return (pass through to next route)
|
|
func (h *Handler) SetFallbackHandler(fallback FallbackHandler) {
|
|
h.fallbackHandler = fallback
|
|
}
|
|
|
|
// GetDatabase returns the underlying database connection
|
|
// Implements common.SpecHandler interface
|
|
func (h *Handler) GetDatabase() common.Database {
|
|
return h.db
|
|
}
|
|
|
|
// handlePanic is a helper function to handle panics with stack traces
|
|
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
|
|
stack := debug.Stack()
|
|
logger.Error("Panic in %s: %v\nStack trace:\n%s", method, err, string(stack))
|
|
h.sendError(w, http.StatusInternalServerError, "internal_error", fmt.Sprintf("Internal server error in %s", method), fmt.Errorf("%v", err))
|
|
}
|
|
|
|
// Handle processes API requests through router-agnostic interface.
//
// It decodes a common.RequestBody from the request body, resolves the target
// model from the registry using the "schema" and "entity" route params, and
// dispatches to the operation-specific handler ("read", "create", "update",
// "delete", or "meta"). The optional "id" route param targets one record.
//
// If no model is registered for schema.entity, the configured fallback
// handler is invoked when set; otherwise the method returns without writing
// a response (pass-through to the next route).
func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[string]string) {
	// Capture panics and return error response
	defer func() {
		if err := recover(); err != nil {
			h.handlePanic(w, "Handle", err)
		}
	}()

	// Check for ?openapi query parameter — serves the OpenAPI spec instead
	// of executing an operation.
	if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
		h.HandleOpenAPI(w, r)
		return
	}

	ctx := r.UnderlyingRequest().Context()

	body, err := r.Body()
	if err != nil {
		logger.Error("Failed to read request body: %v", err)
		h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err)
		return
	}

	var req common.RequestBody
	if err := json.Unmarshal(body, &req); err != nil {
		logger.Error("Failed to decode request body: %v", err)
		h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid request body", err)
		return
	}

	schema := params["schema"]
	entity := params["entity"]
	id := params["id"]

	logger.Info("Handling %s operation for %s.%s", req.Operation, schema, entity)

	// Get model and populate context with request-scoped data
	model, err := h.registry.GetModelByEntity(schema, entity)
	if err != nil {
		// Model not found - call fallback handler if set, otherwise pass through
		logger.Debug("Model not found for %s.%s", schema, entity)
		if h.fallbackHandler != nil {
			logger.Debug("Calling fallback handler for %s.%s", schema, entity)
			h.fallbackHandler(w, r, params)
		} else {
			logger.Debug("No fallback handler set, passing through to next route")
		}
		return
	}

	// Validate and unwrap model using common utility
	result, err := common.ValidateAndUnwrapModel(model)
	if err != nil {
		logger.Error("Model for %s.%s validation failed: %v", schema, entity, err)
		h.sendError(w, http.StatusInternalServerError, "invalid_model_type", err.Error(), err)
		return
	}

	model = result.Model
	modelPtr := result.ModelPtr
	tableName := h.getTableName(schema, entity, model)

	// Add request-scoped data to context so the per-operation handlers can
	// retrieve it via GetSchema/GetEntity/GetTableName/GetModel.
	ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)

	// Validate and filter columns in options (log warnings for invalid columns)
	validator := common.NewColumnValidator(model)
	req.Options = validator.FilterRequestOptions(req.Options)

	switch req.Operation {
	case "read":
		h.handleRead(ctx, w, id, req.Options)
	case "create":
		h.handleCreate(ctx, w, req.Data, req.Options)
	case "update":
		// Update accepts an ID from either the URL or the request body.
		h.handleUpdate(ctx, w, id, req.ID, req.Data, req.Options)
	case "delete":
		h.handleDelete(ctx, w, id, req.Data)
	case "meta":
		h.handleMeta(ctx, w, schema, entity, model)
	default:
		logger.Error("Invalid operation: %s", req.Operation)
		h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil)
	}
}
|
|
|
|
// HandleGet processes GET requests for metadata.
//
// It resolves the model for the "schema"/"entity" route params and responds
// with generated metadata. A ?openapi query parameter short-circuits to the
// OpenAPI spec handler. Model-not-found follows the same fallback semantics
// as Handle: invoke the fallback handler if set, otherwise pass through.
func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params map[string]string) {
	// Capture panics and return error response
	defer func() {
		if err := recover(); err != nil {
			h.handlePanic(w, "HandleGet", err)
		}
	}()

	// Check for ?openapi query parameter
	if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
		h.HandleOpenAPI(w, r)
		return
	}

	schema := params["schema"]
	entity := params["entity"]

	logger.Info("Getting metadata for %s.%s", schema, entity)

	model, err := h.registry.GetModelByEntity(schema, entity)
	if err != nil {
		// Model not found - call fallback handler if set, otherwise pass through
		logger.Debug("Model not found for %s.%s", schema, entity)
		if h.fallbackHandler != nil {
			logger.Debug("Calling fallback handler for %s.%s", schema, entity)
			h.fallbackHandler(w, r, params)
		} else {
			logger.Debug("No fallback handler set, passing through to next route")
		}
		return
	}

	metadata := h.generateMetadata(schema, entity, model)
	h.sendResponse(w, metadata, nil)
}
|
|
|
|
// handleMeta processes meta operation requests
|
|
func (h *Handler) handleMeta(ctx context.Context, w common.ResponseWriter, schema, entity string, model interface{}) {
|
|
// Capture panics and return error response
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
h.handlePanic(w, "handleMeta", err)
|
|
}
|
|
}()
|
|
|
|
logger.Info("Getting metadata for %s.%s via meta operation", schema, entity)
|
|
|
|
metadata := h.generateMetadata(schema, entity, model)
|
|
h.sendResponse(w, metadata, nil)
|
|
}
|
|
|
|
// handleRead executes a "read" operation.
//
// It builds a SELECT query from the request options (column selection,
// computed columns, preloads, filters, custom SQL operators, sorting, and
// cursor- or offset-based pagination), computes the total row count (cached
// with a 2-minute TTL), optionally resolves the row number of a specific
// record via a ROW_NUMBER() window function (FetchRowNumber), and finally
// returns either a single record (when id or FetchRowNumber is given) or a
// slice of records with pagination metadata.
//
// Request-scoped schema/entity/table/model are read from ctx (set by Handle
// via WithRequestData).
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
	// Capture panics and return error response
	defer func() {
		if err := recover(); err != nil {
			h.handlePanic(w, "handleRead", err)
		}
	}()

	schema := GetSchema(ctx)
	entity := GetEntity(ctx)
	tableName := GetTableName(ctx)
	model := GetModel(ctx)

	// Validate and unwrap model type to get base struct: peel pointers,
	// slices and arrays until a struct (or nil) remains.
	modelType := reflect.TypeOf(model)
	for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
		modelType = modelType.Elem()
	}

	if modelType == nil || modelType.Kind() != reflect.Struct {
		logger.Error("Model must be a struct type, got %v for %s.%s", modelType, schema, entity)
		h.sendError(w, http.StatusInternalServerError, "invalid_model", "Model must be a struct type", fmt.Errorf("invalid model type: %v", modelType))
		return
	}

	logger.Info("Reading records from %s.%s", schema, entity)

	// Create the model pointer for Scan() operations: *[]*T for element type T.
	sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
	modelPtr := reflect.New(sliceType).Interface()

	// Start with Model() using the slice pointer to avoid "Model(nil)" errors in Count()
	// Bun's Model() accepts both single pointers and slice pointers
	query := h.db.NewSelect().Model(modelPtr)

	// Only set Table() if the model doesn't provide a table name via the underlying type
	// Create a temporary instance to check for TableNameProvider
	tempInstance := reflect.New(modelType).Interface()
	if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
		query = query.Table(tableName)
	}

	// Computed columns are additions: if the caller selected none explicitly,
	// seed the selection with all model columns so computed ones extend them.
	if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
		logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
		options.Columns = reflection.GetSQLModelColumns(model)
	}

	// Apply column selection
	if len(options.Columns) > 0 {
		logger.Debug("Selecting columns: %v", options.Columns)
		for _, col := range options.Columns {
			query = query.Column(reflection.ExtractSourceColumn(col))
		}
	}

	// Each computed column becomes "(expr) AS name" in the select list.
	if len(options.ComputedColumns) > 0 {
		for _, cu := range options.ComputedColumns {
			logger.Debug("Applying computed column: %s", cu.Name)
			query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
		}
	}

	// Apply preloading
	if len(options.Preload) > 0 {
		var err error
		query, err = h.applyPreloads(model, query, options.Preload)
		if err != nil {
			logger.Error("Failed to apply preloads: %v", err)
			h.sendError(w, http.StatusBadRequest, "invalid_preload", "Failed to apply preloads", err)
			return
		}
	}

	// Apply filters with proper grouping for OR logic
	query = h.applyFilters(query, options.Filters)

	// Apply custom operators (arbitrary SQL conditions supplied by the caller).
	// NOTE(review): customOp.SQL is passed to Where() verbatim — confirm it is
	// validated/sanitized upstream before reaching this point.
	for _, customOp := range options.CustomOperators {
		logger.Debug("Applying custom operator: %s - %s", customOp.Name, customOp.SQL)
		query = query.Where(customOp.SQL)
	}

	// Apply sorting (direction normalized to ASC/DESC, case-insensitive).
	for _, sort := range options.Sort {
		direction := "ASC"
		if strings.EqualFold(sort.Direction, "desc") {
			direction = "DESC"
		}
		logger.Debug("Applying sort: %s %s", sort.Column, direction)
		query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
	}

	// Apply cursor-based pagination
	if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
		logger.Debug("Applying cursor pagination")

		// Get primary key name
		pkName := reflection.GetPrimaryKeyName(model)

		// Extract model columns for validation
		modelColumns := reflection.GetModelColumns(model)

		// Get cursor filter SQL
		cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
		if err != nil {
			logger.Error("Error building cursor filter: %v", err)
			h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
			return
		}

		// Apply cursor filter to query
		if cursorFilter != "" {
			logger.Debug("Applying cursor filter: %s", cursorFilter)
			sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
			// Ensure outer parentheses to prevent OR logic from escaping
			sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor)
			if sanitizedCursor != "" {
				query = query.Where(sanitizedCursor)
			}
		}
	}

	// Get total count before pagination
	var total int

	// Try to get from cache first
	// Use extended cache key if cursors are present
	var cacheKeyHash string
	if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
		cacheKeyHash = buildExtendedQueryCacheKey(
			tableName,
			options.Filters,
			options.Sort,
			"", // No custom SQL WHERE in resolvespec
			"", // No custom SQL OR in resolvespec
			options.CursorForward,
			options.CursorBackward,
		)
	} else {
		cacheKeyHash = buildQueryCacheKey(
			tableName,
			options.Filters,
			options.Sort,
			"", // No custom SQL WHERE in resolvespec
			"", // No custom SQL OR in resolvespec
		)
	}
	cacheKey := getQueryTotalCacheKey(cacheKeyHash)

	// Try to retrieve from cache
	// (the variable intentionally shadows the cachedTotal type name)
	var cachedTotal cachedTotal
	err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
	if err == nil {
		total = cachedTotal.Total
		logger.Debug("Total records (from cache): %d", total)
	} else {
		// Cache miss - execute count query
		logger.Debug("Cache miss for query total")
		count, err := query.Count(ctx)
		if err != nil {
			logger.Error("Error counting records: %v", err)
			h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
			return
		}
		total = count
		logger.Debug("Total records (from query): %d", total)

		// Store in cache with schema and table tags
		cacheTTL := time.Minute * 2 // Default 2 minutes TTL
		if err := setQueryTotalCache(ctx, cacheKey, total, schema, tableName, cacheTTL); err != nil {
			logger.Warn("Failed to cache query total: %v", err)
			// Don't fail the request if caching fails
		} else {
			logger.Debug("Cached query total with key: %s", cacheKey)
		}
	}

	// Handle FetchRowNumber if requested: compute the 1-based position of a
	// specific record within the filtered/sorted result set.
	var rowNumber *int64
	if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
		logger.Debug("Fetching row number for ID: %s", *options.FetchRowNumber)
		pkName := reflection.GetPrimaryKeyName(model)

		// Build ROW_NUMBER window function SQL, mirroring the main query's sort.
		// With no sort, this emits "ROW_NUMBER() OVER ()" (unspecified order).
		rowNumberSQL := "ROW_NUMBER() OVER ("
		if len(options.Sort) > 0 {
			rowNumberSQL += "ORDER BY "
			for i, sort := range options.Sort {
				if i > 0 {
					rowNumberSQL += ", "
				}
				direction := "ASC"
				if strings.EqualFold(sort.Direction, "desc") {
					direction = "DESC"
				}
				rowNumberSQL += fmt.Sprintf("%s %s", sort.Column, direction)
			}
		}
		rowNumberSQL += ")"

		// Create a query to fetch the row number using a subquery approach
		// We'll select the PK and row_number, then filter by the target ID
		type RowNumResult struct {
			RowNum int64 `bun:"row_num"`
		}

		rowNumQuery := h.db.NewSelect().Table(tableName).
			ColumnExpr(fmt.Sprintf("%s AS row_num", rowNumberSQL)).
			Column(pkName)

		// Apply the same filters as the main query
		for _, filter := range options.Filters {
			rowNumQuery = h.applyFilter(rowNumQuery, filter)
		}

		// Apply custom operators
		for _, customOp := range options.CustomOperators {
			rowNumQuery = rowNumQuery.Where(customOp.SQL)
		}

		// Filter for the specific ID we want the row number for
		rowNumQuery = rowNumQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *options.FetchRowNumber)

		// Execute query to get row number; a missing row is not an error
		// (sql.ErrNoRows is silently tolerated, other errors only logged).
		var result RowNumResult
		if err := rowNumQuery.Scan(ctx, &result); err != nil {
			if err != sql.ErrNoRows {
				logger.Warn("Error fetching row number: %v", err)
			}
		} else {
			rowNumber = &result.RowNum
			logger.Debug("Found row number: %d", *rowNumber)
		}
	}

	// Apply pagination (skip if FetchRowNumber is set - we want only that record)
	if options.FetchRowNumber == nil || *options.FetchRowNumber == "" {
		if options.Limit != nil && *options.Limit > 0 {
			logger.Debug("Applying limit: %d", *options.Limit)
			query = query.Limit(*options.Limit)
		}
		if options.Offset != nil && *options.Offset > 0 {
			logger.Debug("Applying offset: %d", *options.Offset)
			query = query.Offset(*options.Offset)
		}
	}

	// Execute query
	var result interface{}
	if id != "" || (options.FetchRowNumber != nil && *options.FetchRowNumber != "") {
		// Single record query - either by URL ID or FetchRowNumber
		// (URL ID takes precedence when both are present)
		var targetID string
		if id != "" {
			targetID = id
			logger.Debug("Querying single record with URL ID: %s", id)
		} else {
			targetID = *options.FetchRowNumber
			logger.Debug("Querying single record with FetchRowNumber ID: %s", targetID)
		}

		// For single record, create a new pointer to the struct type
		singleResult := reflect.New(modelType).Interface()
		pkName := reflection.GetPrimaryKeyName(singleResult)

		query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
		if err := query.Scan(ctx, singleResult); err != nil {
			logger.Error("Error querying record: %v", err)
			h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
			return
		}
		result = singleResult
	} else {
		logger.Debug("Querying multiple records")
		// Use the modelPtr already created and set on the query
		if err := query.Scan(ctx, modelPtr); err != nil {
			logger.Error("Error querying records: %v", err)
			h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
			return
		}
		result = reflect.ValueOf(modelPtr).Elem().Interface()
	}

	logger.Info("Successfully retrieved records")

	// Build metadata
	limit := 0
	offset := 0
	count := int64(total)

	// When FetchRowNumber is used, we only return 1 record
	if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
		count = 1
		// Don't use limit/offset when fetching specific record
	} else {
		if options.Limit != nil {
			limit = *options.Limit
		}
		if options.Offset != nil {
			offset = *options.Offset
		}

		// Set row numbers on records if RowNumber field exists
		// Only for multiple records (not when fetching single record)
		h.setRowNumbersOnRecords(result, offset)
	}

	h.sendResponse(w, result, &common.Metadata{
		Total:     int64(total),
		Filtered:  int64(total),
		Count:     count,
		Limit:     limit,
		Offset:    offset,
		RowNumber: rowNumber,
	})
}
|
|
|
|
// handleCreate executes a "create" operation.
//
// Accepted data shapes:
//   - map[string]interface{}: a single record
//   - []map[string]interface{} or []interface{}: a batch of records
//     (batches are inserted inside a single transaction)
//
// When a payload contains nested relation data (shouldUseNestedProcessor),
// the nested CUD processor is used instead of a plain INSERT. After any
// successful insert, the query-total cache for the table is invalidated.
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options common.RequestOptions) {
	// Capture panics and return error response
	defer func() {
		if err := recover(); err != nil {
			h.handlePanic(w, "handleCreate", err)
		}
	}()

	schema := GetSchema(ctx)
	entity := GetEntity(ctx)
	tableName := GetTableName(ctx)
	model := GetModel(ctx)

	logger.Info("Creating records for %s.%s", schema, entity)

	// Check if data contains nested relations or _request field
	switch v := data.(type) {
	case map[string]interface{}:
		// Check if we should use nested processing
		if h.shouldUseNestedProcessor(v, model) {
			logger.Info("Using nested CUD processor for create operation")
			result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", v, model, make(map[string]interface{}), tableName)
			if err != nil {
				logger.Error("Error in nested create: %v", err)
				h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record with nested data", err)
				return
			}
			logger.Info("Successfully created record with nested data, ID: %v", result.ID)
			// Invalidate cache for this table
			cacheTags := buildCacheTags(schema, tableName)
			if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
				logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
			}
			h.sendResponse(w, result.Data, nil)
			return
		}

		// Standard processing without nested relations
		query := h.db.NewInsert().Table(tableName)
		for key, value := range v {
			query = query.Value(key, value)
		}
		result, err := query.Exec(ctx)
		if err != nil {
			logger.Error("Error creating record: %v", err)
			h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
			return
		}
		logger.Info("Successfully created record, rows affected: %d", result.RowsAffected())
		// Invalidate cache for this table
		cacheTags := buildCacheTags(schema, tableName)
		if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
			logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
		}
		// Echo the submitted data back to the client.
		h.sendResponse(w, v, nil)

	case []map[string]interface{}:
		// Check if any item needs nested processing
		hasNestedData := false
		for _, item := range v {
			if h.shouldUseNestedProcessor(item, model) {
				hasNestedData = true
				break
			}
		}

		if hasNestedData {
			logger.Info("Using nested CUD processor for batch create with nested data")
			results := make([]map[string]interface{}, 0, len(v))
			err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
				// Temporarily swap the database to use transaction
				// NOTE(review): this mutates shared handler state; not safe if
				// the same Handler serves concurrent requests — confirm
				// handlers are request-scoped or add synchronization.
				originalDB := h.nestedProcessor
				h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
				defer func() {
					h.nestedProcessor = originalDB
				}()

				for _, item := range v {
					result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", item, model, make(map[string]interface{}), tableName)
					if err != nil {
						return fmt.Errorf("failed to process item: %w", err)
					}
					results = append(results, result.Data)
				}
				return nil
			})
			if err != nil {
				logger.Error("Error creating records with nested data: %v", err)
				h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records with nested data", err)
				return
			}
			logger.Info("Successfully created %d records with nested data", len(results))
			// Invalidate cache for this table
			cacheTags := buildCacheTags(schema, tableName)
			if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
				logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
			}
			h.sendResponse(w, results, nil)
			return
		}

		// Standard batch insert without nested relations
		err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
			for _, item := range v {
				txQuery := tx.NewInsert().Table(tableName)
				for key, value := range item {
					txQuery = txQuery.Value(key, value)
				}
				if _, err := txQuery.Exec(ctx); err != nil {
					return err
				}
			}
			return nil
		})
		if err != nil {
			logger.Error("Error creating records: %v", err)
			h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
			return
		}
		logger.Info("Successfully created %d records", len(v))
		// Invalidate cache for this table
		cacheTags := buildCacheTags(schema, tableName)
		if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
			logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
		}
		h.sendResponse(w, v, nil)

	case []interface{}:
		// Handle []interface{} type from JSON unmarshaling
		// Check if any item needs nested processing
		hasNestedData := false
		for _, item := range v {
			if itemMap, ok := item.(map[string]interface{}); ok {
				if h.shouldUseNestedProcessor(itemMap, model) {
					hasNestedData = true
					break
				}
			}
		}

		if hasNestedData {
			logger.Info("Using nested CUD processor for batch create with nested data ([]interface{})")
			results := make([]interface{}, 0, len(v))
			err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
				// Temporarily swap the database to use transaction
				// NOTE(review): same shared-state concern as the
				// []map[string]interface{} branch above.
				originalDB := h.nestedProcessor
				h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
				defer func() {
					h.nestedProcessor = originalDB
				}()

				for _, item := range v {
					if itemMap, ok := item.(map[string]interface{}); ok {
						result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "insert", itemMap, model, make(map[string]interface{}), tableName)
						if err != nil {
							return fmt.Errorf("failed to process item: %w", err)
						}
						results = append(results, result.Data)
					}
				}
				return nil
			})
			if err != nil {
				logger.Error("Error creating records with nested data: %v", err)
				h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records with nested data", err)
				return
			}
			logger.Info("Successfully created %d records with nested data", len(results))
			// Invalidate cache for this table
			cacheTags := buildCacheTags(schema, tableName)
			if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
				logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
			}
			h.sendResponse(w, results, nil)
			return
		}

		// Standard batch insert without nested relations
		// (non-map items are silently skipped)
		list := make([]interface{}, 0)
		err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
			for _, item := range v {
				if itemMap, ok := item.(map[string]interface{}); ok {
					txQuery := tx.NewInsert().Table(tableName)
					for key, value := range itemMap {
						txQuery = txQuery.Value(key, value)
					}
					if _, err := txQuery.Exec(ctx); err != nil {
						return err
					}
					list = append(list, item)
				}
			}
			return nil
		})
		if err != nil {
			logger.Error("Error creating records: %v", err)
			h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating records", err)
			return
		}
		logger.Info("Successfully created %d records", len(v))
		// Invalidate cache for this table
		cacheTags := buildCacheTags(schema, tableName)
		if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
			logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
		}
		h.sendResponse(w, list, nil)

	default:
		logger.Error("Invalid data type for create operation: %T", data)
		h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data type for create operation", nil)
	}
}
|
|
|
|
func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, urlID string, reqID interface{}, data interface{}, options common.RequestOptions) {
|
|
// Capture panics and return error response
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
h.handlePanic(w, "handleUpdate", err)
|
|
}
|
|
}()
|
|
|
|
schema := GetSchema(ctx)
|
|
entity := GetEntity(ctx)
|
|
tableName := GetTableName(ctx)
|
|
model := GetModel(ctx)
|
|
|
|
logger.Info("Updating records for %s.%s", schema, entity)
|
|
|
|
switch updates := data.(type) {
|
|
case map[string]interface{}:
|
|
// Determine the ID to use
|
|
var targetID interface{}
|
|
switch {
|
|
case urlID != "":
|
|
targetID = urlID
|
|
case reqID != nil:
|
|
targetID = reqID
|
|
case updates["id"] != nil:
|
|
targetID = updates["id"]
|
|
}
|
|
|
|
// Check if we should use nested processing
|
|
if h.shouldUseNestedProcessor(updates, model) {
|
|
logger.Info("Using nested CUD processor for update operation")
|
|
// Ensure ID is in the data map
|
|
if targetID != nil {
|
|
updates["id"] = targetID
|
|
}
|
|
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", updates, model, make(map[string]interface{}), tableName)
|
|
if err != nil {
|
|
logger.Error("Error in nested update: %v", err)
|
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record with nested data", err)
|
|
return
|
|
}
|
|
logger.Info("Successfully updated record with nested data, rows: %d", result.AffectedRows)
|
|
// Invalidate cache for this table
|
|
cacheTags := buildCacheTags(schema, tableName)
|
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
|
}
|
|
h.sendResponse(w, result.Data, nil)
|
|
return
|
|
}
|
|
|
|
// Standard processing without nested relations
|
|
// Get the primary key name
|
|
pkName := reflection.GetPrimaryKeyName(model)
|
|
|
|
// Wrap in transaction to ensure BeforeUpdate hook is inside transaction
|
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
// First, read the existing record from the database
|
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
|
|
|
|
// Apply conditions to select
|
|
if urlID != "" {
|
|
logger.Debug("Updating by URL ID: %s", urlID)
|
|
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
|
} else if reqID != nil {
|
|
switch id := reqID.(type) {
|
|
case string:
|
|
logger.Debug("Updating by request ID: %s", id)
|
|
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
|
case []string:
|
|
if len(id) > 0 {
|
|
logger.Debug("Updating by multiple IDs: %v", id)
|
|
selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return fmt.Errorf("no records found to update")
|
|
}
|
|
return fmt.Errorf("error fetching existing record: %w", err)
|
|
}
|
|
|
|
// Convert existing record to map
|
|
existingMap := make(map[string]interface{})
|
|
jsonData, err := json.Marshal(existingRecord)
|
|
if err != nil {
|
|
return fmt.Errorf("error marshaling existing record: %w", err)
|
|
}
|
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
|
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
|
}
|
|
|
|
// Execute BeforeUpdate hooks inside transaction
|
|
hookCtx := &HookContext{
|
|
Context: ctx,
|
|
Handler: h,
|
|
Schema: schema,
|
|
Entity: entity,
|
|
Model: model,
|
|
Options: options,
|
|
ID: urlID,
|
|
Data: updates,
|
|
Writer: w,
|
|
Tx: tx,
|
|
}
|
|
|
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
|
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
|
|
}
|
|
|
|
// Use potentially modified data from hook context
|
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
|
updates = modifiedData
|
|
}
|
|
|
|
// Merge only non-null and non-empty values from the incoming request into the existing record
|
|
for key, newValue := range updates {
|
|
// Skip if the value is nil
|
|
if newValue == nil {
|
|
continue
|
|
}
|
|
|
|
// Skip if the value is an empty string
|
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
|
continue
|
|
}
|
|
|
|
// Update the existing map with the new value
|
|
existingMap[key] = newValue
|
|
}
|
|
|
|
// Build update query with merged data
|
|
query := tx.NewUpdate().Table(tableName).SetMap(existingMap)
|
|
|
|
// Apply conditions
|
|
if urlID != "" {
|
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
|
} else if reqID != nil {
|
|
switch id := reqID.(type) {
|
|
case string:
|
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
|
case []string:
|
|
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
|
}
|
|
}
|
|
|
|
result, err := query.Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("error updating record(s): %w", err)
|
|
}
|
|
|
|
if result.RowsAffected() == 0 {
|
|
return fmt.Errorf("no records found to update")
|
|
}
|
|
|
|
// Execute AfterUpdate hooks inside transaction
|
|
hookCtx.Result = updates
|
|
hookCtx.Error = nil
|
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
|
return fmt.Errorf("AfterUpdate hook failed: %w", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
logger.Error("Update error: %v", err)
|
|
if err.Error() == "no records found to update" {
|
|
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", err)
|
|
} else {
|
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
logger.Info("Successfully updated record(s)")
|
|
// Invalidate cache for this table
|
|
cacheTags := buildCacheTags(schema, tableName)
|
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
|
}
|
|
h.sendResponse(w, data, nil)
|
|
|
|
case []map[string]interface{}:
|
|
// Batch update with array of objects
|
|
hasNestedData := false
|
|
for _, item := range updates {
|
|
if h.shouldUseNestedProcessor(item, model) {
|
|
hasNestedData = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if hasNestedData {
|
|
logger.Info("Using nested CUD processor for batch update with nested data")
|
|
results := make([]map[string]interface{}, 0, len(updates))
|
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
// Temporarily swap the database to use transaction
|
|
originalDB := h.nestedProcessor
|
|
h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
|
|
defer func() {
|
|
h.nestedProcessor = originalDB
|
|
}()
|
|
|
|
for _, item := range updates {
|
|
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", item, model, make(map[string]interface{}), tableName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process item: %w", err)
|
|
}
|
|
results = append(results, result.Data)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
logger.Error("Error updating records with nested data: %v", err)
|
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records with nested data", err)
|
|
return
|
|
}
|
|
logger.Info("Successfully updated %d records with nested data", len(results))
|
|
// Invalidate cache for this table
|
|
cacheTags := buildCacheTags(schema, tableName)
|
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
|
}
|
|
h.sendResponse(w, results, nil)
|
|
return
|
|
}
|
|
|
|
// Standard batch update without nested relations
|
|
pkName := reflection.GetPrimaryKeyName(model)
|
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
for _, item := range updates {
|
|
if itemID, ok := item["id"]; ok {
|
|
itemIDStr := fmt.Sprintf("%v", itemID)
|
|
|
|
// First, read the existing record
|
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
continue // Skip if record not found
|
|
}
|
|
return fmt.Errorf("failed to fetch existing record: %w", err)
|
|
}
|
|
|
|
// Convert existing record to map
|
|
existingMap := make(map[string]interface{})
|
|
jsonData, err := json.Marshal(existingRecord)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal existing record: %w", err)
|
|
}
|
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
|
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
|
}
|
|
|
|
// Execute BeforeUpdate hooks inside transaction
|
|
hookCtx := &HookContext{
|
|
Context: ctx,
|
|
Handler: h,
|
|
Schema: schema,
|
|
Entity: entity,
|
|
Model: model,
|
|
Options: options,
|
|
ID: itemIDStr,
|
|
Data: item,
|
|
Writer: w,
|
|
Tx: tx,
|
|
}
|
|
|
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
|
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
|
}
|
|
|
|
// Use potentially modified data from hook context
|
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
|
item = modifiedData
|
|
}
|
|
|
|
// Merge only non-null and non-empty values
|
|
for key, newValue := range item {
|
|
if newValue == nil {
|
|
continue
|
|
}
|
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
|
continue
|
|
}
|
|
existingMap[key] = newValue
|
|
}
|
|
|
|
txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
|
if _, err := txQuery.Exec(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Execute AfterUpdate hooks inside transaction
|
|
hookCtx.Result = item
|
|
hookCtx.Error = nil
|
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
|
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
logger.Error("Error updating records: %v", err)
|
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
|
return
|
|
}
|
|
logger.Info("Successfully updated %d records", len(updates))
|
|
// Invalidate cache for this table
|
|
cacheTags := buildCacheTags(schema, tableName)
|
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
|
}
|
|
h.sendResponse(w, updates, nil)
|
|
|
|
case []interface{}:
|
|
// Batch update with []interface{}
|
|
hasNestedData := false
|
|
for _, item := range updates {
|
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
if h.shouldUseNestedProcessor(itemMap, model) {
|
|
hasNestedData = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if hasNestedData {
|
|
logger.Info("Using nested CUD processor for batch update with nested data ([]interface{})")
|
|
results := make([]interface{}, 0, len(updates))
|
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
// Temporarily swap the database to use transaction
|
|
originalDB := h.nestedProcessor
|
|
h.nestedProcessor = common.NewNestedCUDProcessor(tx, h.registry, h)
|
|
defer func() {
|
|
h.nestedProcessor = originalDB
|
|
}()
|
|
|
|
for _, item := range updates {
|
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
result, err := h.nestedProcessor.ProcessNestedCUD(ctx, "update", itemMap, model, make(map[string]interface{}), tableName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process item: %w", err)
|
|
}
|
|
results = append(results, result.Data)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
logger.Error("Error updating records with nested data: %v", err)
|
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records with nested data", err)
|
|
return
|
|
}
|
|
logger.Info("Successfully updated %d records with nested data", len(results))
|
|
// Invalidate cache for this table
|
|
cacheTags := buildCacheTags(schema, tableName)
|
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
|
}
|
|
h.sendResponse(w, results, nil)
|
|
return
|
|
}
|
|
|
|
// Standard batch update without nested relations
|
|
pkName := reflection.GetPrimaryKeyName(model)
|
|
list := make([]interface{}, 0)
|
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
|
for _, item := range updates {
|
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
|
if itemID, ok := itemMap["id"]; ok {
|
|
itemIDStr := fmt.Sprintf("%v", itemID)
|
|
|
|
// First, read the existing record
|
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
continue // Skip if record not found
|
|
}
|
|
return fmt.Errorf("failed to fetch existing record: %w", err)
|
|
}
|
|
|
|
// Convert existing record to map
|
|
existingMap := make(map[string]interface{})
|
|
jsonData, err := json.Marshal(existingRecord)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal existing record: %w", err)
|
|
}
|
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
|
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
|
}
|
|
|
|
// Execute BeforeUpdate hooks inside transaction
|
|
hookCtx := &HookContext{
|
|
Context: ctx,
|
|
Handler: h,
|
|
Schema: schema,
|
|
Entity: entity,
|
|
Model: model,
|
|
Options: options,
|
|
ID: itemIDStr,
|
|
Data: itemMap,
|
|
Writer: w,
|
|
Tx: tx,
|
|
}
|
|
|
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
|
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
|
}
|
|
|
|
// Use potentially modified data from hook context
|
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
|
itemMap = modifiedData
|
|
}
|
|
|
|
// Merge only non-null and non-empty values
|
|
for key, newValue := range itemMap {
|
|
if newValue == nil {
|
|
continue
|
|
}
|
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
|
continue
|
|
}
|
|
existingMap[key] = newValue
|
|
}
|
|
|
|
txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
|
if _, err := txQuery.Exec(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Execute AfterUpdate hooks inside transaction
|
|
hookCtx.Result = itemMap
|
|
hookCtx.Error = nil
|
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
|
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
|
}
|
|
|
|
list = append(list, item)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
logger.Error("Error updating records: %v", err)
|
|
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
|
return
|
|
}
|
|
logger.Info("Successfully updated %d records", len(list))
|
|
// Invalidate cache for this table
|
|
cacheTags := buildCacheTags(schema, tableName)
|
|
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
|
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
|
}
|
|
h.sendResponse(w, list, nil)
|
|
|
|
default:
|
|
logger.Error("Invalid data type for update operation: %T", data)
|
|
h.sendError(w, http.StatusBadRequest, "invalid_data", "Invalid data type for update operation", nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
// handleDelete deletes one or more records for the entity resolved from ctx.
//
// The request body (data) selects the delete mode:
//   - []string                   : batch delete, each element is a primary-key value
//   - []interface{}              : batch delete, elements may be string IDs or
//     objects carrying an "id" field (anything else is tried as a raw ID)
//   - []map[string]interface{}   : batch delete, each object must carry an "id" field
//   - map[string]interface{}     : single delete; its "id" field overrides the URL id
//   - nil                        : single delete using the URL id only
//
// Batch modes run inside one transaction, respond immediately and return.
// The single-delete path first fetches the record (so it can be echoed back in
// the response), then deletes it.
//
// NOTE(review): unlike the update paths, no BeforeDelete/AfterDelete hooks are
// executed anywhere in this function — confirm that is intentional.
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
	// Capture panics and return error response
	defer func() {
		if err := recover(); err != nil {
			h.handlePanic(w, "handleDelete", err)
		}
	}()

	// Entity/table/model were resolved by earlier middleware and stored on ctx.
	schema := GetSchema(ctx)
	entity := GetEntity(ctx)
	tableName := GetTableName(ctx)
	model := GetModel(ctx)

	logger.Info("Deleting records from %s.%s", schema, entity)

	// Handle batch delete from request data
	if data != nil {
		switch v := data.(type) {
		case []string:
			// Array of IDs as strings
			logger.Info("Batch delete with %d IDs ([]string)", len(v))
			// All rows are deleted in a single transaction; any failure rolls
			// back the whole batch.
			err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
				for _, itemID := range v {

					query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
					if _, err := query.Exec(ctx); err != nil {
						return fmt.Errorf("failed to delete record %s: %w", itemID, err)
					}
				}
				return nil
			})
			if err != nil {
				logger.Error("Error in batch delete: %v", err)
				h.sendError(w, http.StatusInternalServerError, "delete_error", "Error deleting records", err)
				return
			}
			// NOTE(review): this path reports len(v) as the deleted count even
			// if some IDs matched no rows (RowsAffected is not inspected here,
			// unlike the []interface{} and []map cases) — confirm intended.
			logger.Info("Successfully deleted %d records", len(v))
			// Invalidate cache for this table
			cacheTags := buildCacheTags(schema, tableName)
			if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
				logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
			}
			h.sendResponse(w, map[string]interface{}{"deleted": len(v)}, nil)
			return

		case []interface{}:
			// Array of IDs or objects with ID field
			logger.Info("Batch delete with %d items ([]interface{})", len(v))
			deletedCount := 0
			err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
				for _, item := range v {
					var itemID interface{}

					// Check if item is a string ID or object with id field
					switch v := item.(type) {
					case string:
						itemID = v
					case map[string]interface{}:
						itemID = v["id"]
					default:
						// Try to use the item directly as ID
						itemID = item
					}

					if itemID == nil {
						continue // Skip items without ID
					}

					query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
					result, err := query.Exec(ctx)
					if err != nil {
						return fmt.Errorf("failed to delete record %v: %w", itemID, err)
					}
					// Count actual rows removed, not items submitted.
					deletedCount += int(result.RowsAffected())
				}
				return nil
			})
			if err != nil {
				logger.Error("Error in batch delete: %v", err)
				h.sendError(w, http.StatusInternalServerError, "delete_error", "Error deleting records", err)
				return
			}
			logger.Info("Successfully deleted %d records", deletedCount)
			// Invalidate cache for this table
			cacheTags := buildCacheTags(schema, tableName)
			if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
				logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
			}
			h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
			return

		case []map[string]interface{}:
			// Array of objects with id field
			logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
			deletedCount := 0
			err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
				for _, item := range v {
					// Items without a usable "id" are silently skipped.
					if itemID, ok := item["id"]; ok && itemID != nil {
						query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
						result, err := query.Exec(ctx)
						if err != nil {
							return fmt.Errorf("failed to delete record %v: %w", itemID, err)
						}
						deletedCount += int(result.RowsAffected())
					}
				}
				return nil
			})
			if err != nil {
				logger.Error("Error in batch delete: %v", err)
				h.sendError(w, http.StatusInternalServerError, "delete_error", "Error deleting records", err)
				return
			}
			logger.Info("Successfully deleted %d records", deletedCount)
			// Invalidate cache for this table
			cacheTags := buildCacheTags(schema, tableName)
			if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
				logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
			}
			h.sendResponse(w, map[string]interface{}{"deleted": deletedCount}, nil)
			return

		case map[string]interface{}:
			// Single object with id field
			// Falls through to the single-delete path below with the extracted id.
			if itemID, ok := v["id"]; ok && itemID != nil {
				id = fmt.Sprintf("%v", itemID)
			}
		}
	}

	// Single delete with URL ID
	if id == "" {
		logger.Error("Delete operation requires an ID")
		h.sendError(w, http.StatusBadRequest, "missing_id", "Delete operation requires an ID", nil)
		return
	}

	// Get primary key name
	pkName := reflection.GetPrimaryKeyName(model)

	// First, fetch the record that will be deleted
	// (so the response can include the deleted row's data).
	modelType := reflect.TypeOf(model)
	if modelType.Kind() == reflect.Ptr {
		modelType = modelType.Elem()
	}
	recordToDelete := reflect.New(modelType).Interface()

	selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
	if err := selectQuery.ScanModel(ctx); err != nil {
		if err == sql.ErrNoRows {
			logger.Warn("Record not found for delete: %s = %s", pkName, id)
			h.sendError(w, http.StatusNotFound, "not_found", "Record not found", err)
			return
		}
		logger.Error("Error fetching record for delete: %v", err)
		h.sendError(w, http.StatusInternalServerError, "fetch_error", "Error fetching record", err)
		return
	}

	// NOTE(review): the single delete is not transactional with the fetch
	// above, so the row may change/disappear between the two statements.
	query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)

	result, err := query.Exec(ctx)
	if err != nil {
		logger.Error("Error deleting record: %v", err)
		h.sendError(w, http.StatusInternalServerError, "delete_error", "Error deleting record", err)
		return
	}

	// Check if the record was actually deleted
	if result.RowsAffected() == 0 {
		logger.Warn("No rows deleted for ID: %s", id)
		h.sendError(w, http.StatusNotFound, "not_found", "Record not found or already deleted", nil)
		return
	}

	logger.Info("Successfully deleted record with ID: %s", id)
	// Return the deleted record data
	// Invalidate cache for this table
	cacheTags := buildCacheTags(schema, tableName)
	if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
		logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
	}
	h.sendResponse(w, recordToDelete, nil)
}
|
|
|
|
// applyFilters applies all filters with proper grouping for OR logic
|
|
// Groups consecutive OR filters together to ensure proper query precedence
|
|
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
|
|
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
|
if len(filters) == 0 {
|
|
return query
|
|
}
|
|
|
|
i := 0
|
|
for i < len(filters) {
|
|
// Check if this starts an OR group (current or next filter has OR logic)
|
|
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
|
|
|
if startORGroup {
|
|
// Collect all consecutive filters that are OR'd together
|
|
orGroup := []common.FilterOption{filters[i]}
|
|
j := i + 1
|
|
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
|
orGroup = append(orGroup, filters[j])
|
|
j++
|
|
}
|
|
|
|
// Apply the OR group as a single grouped WHERE clause
|
|
query = h.applyFilterGroup(query, orGroup)
|
|
i = j
|
|
} else {
|
|
// Single filter with AND logic (or first filter)
|
|
condition, args := h.buildFilterCondition(filters[i])
|
|
if condition != "" {
|
|
query = query.Where(condition, args...)
|
|
}
|
|
i++
|
|
}
|
|
}
|
|
|
|
return query
|
|
}
|
|
|
|
// applyFilterGroup applies a group of filters that should be OR'd together
|
|
// Always wraps them in parentheses and applies as a single WHERE clause
|
|
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
|
if len(filters) == 0 {
|
|
return query
|
|
}
|
|
|
|
// Build all conditions and collect args
|
|
var conditions []string
|
|
var args []interface{}
|
|
|
|
for _, filter := range filters {
|
|
condition, filterArgs := h.buildFilterCondition(filter)
|
|
if condition != "" {
|
|
conditions = append(conditions, condition)
|
|
args = append(args, filterArgs...)
|
|
}
|
|
}
|
|
|
|
if len(conditions) == 0 {
|
|
return query
|
|
}
|
|
|
|
// Single filter - no need for grouping
|
|
if len(conditions) == 1 {
|
|
return query.Where(conditions[0], args...)
|
|
}
|
|
|
|
// Multiple conditions - group with parentheses and OR
|
|
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
|
return query.Where(groupedCondition, args...)
|
|
}
|
|
|
|
// buildFilterCondition builds a filter condition and returns it with args
|
|
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
|
var condition string
|
|
var args []interface{}
|
|
|
|
switch filter.Operator {
|
|
case "eq":
|
|
condition = fmt.Sprintf("%s = ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "neq":
|
|
condition = fmt.Sprintf("%s != ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "gt":
|
|
condition = fmt.Sprintf("%s > ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "gte":
|
|
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "lt":
|
|
condition = fmt.Sprintf("%s < ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "lte":
|
|
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "like":
|
|
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "ilike":
|
|
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "in":
|
|
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
default:
|
|
return "", nil
|
|
}
|
|
|
|
return condition, args
|
|
}
|
|
|
|
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
|
// Determine which method to use based on LogicOperator
|
|
useOrLogic := strings.EqualFold(filter.LogicOperator, "OR")
|
|
|
|
var condition string
|
|
var args []interface{}
|
|
|
|
switch filter.Operator {
|
|
case "eq":
|
|
condition = fmt.Sprintf("%s = ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "neq":
|
|
condition = fmt.Sprintf("%s != ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "gt":
|
|
condition = fmt.Sprintf("%s > ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "gte":
|
|
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "lt":
|
|
condition = fmt.Sprintf("%s < ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "lte":
|
|
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "like":
|
|
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "ilike":
|
|
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
case "in":
|
|
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
|
args = []interface{}{filter.Value}
|
|
default:
|
|
return query
|
|
}
|
|
|
|
// Apply filter with appropriate logic operator
|
|
if useOrLogic {
|
|
return query.WhereOr(condition, args...)
|
|
}
|
|
return query.Where(condition, args...)
|
|
}
|
|
|
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
|
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
|
|
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
|
return fullTableName[:idx], fullTableName[idx+1:]
|
|
}
|
|
return "", fullTableName
|
|
}
|
|
|
|
// getSchemaAndTable returns the schema and table name separately
|
|
// It checks SchemaProvider and TableNameProvider interfaces and handles cases where
|
|
// the table name may already include the schema (e.g., "public.users")
|
|
//
|
|
// Priority order:
|
|
// 1. If TableName() contains a schema (e.g., "myschema.mytable"), that schema takes precedence
|
|
// 2. If model implements SchemaProvider, use that schema
|
|
// 3. Otherwise, use the defaultSchema parameter
|
|
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
|
// First check if model provides a table name
|
|
// We check this FIRST because the table name might already contain the schema
|
|
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
|
tableName := tableProvider.TableName()
|
|
|
|
// IMPORTANT: Check if the table name already contains a schema (e.g., "schema.table")
|
|
// This is common when models need to specify a different schema than the default
|
|
if tableSchema, tableOnly := h.parseTableName(tableName); tableSchema != "" {
|
|
// Table name includes schema - use it and ignore any other schema providers
|
|
logger.Debug("TableName() includes schema: %s.%s", tableSchema, tableOnly)
|
|
return tableSchema, tableOnly
|
|
}
|
|
|
|
// Table name is just the table name without schema
|
|
// Now determine which schema to use
|
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
|
schema = schemaProvider.SchemaName()
|
|
} else {
|
|
schema = defaultSchema
|
|
}
|
|
|
|
return schema, tableName
|
|
}
|
|
|
|
// No TableNameProvider, so check for schema and use entity as table name
|
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
|
schema = schemaProvider.SchemaName()
|
|
} else {
|
|
schema = defaultSchema
|
|
}
|
|
|
|
// Default to entity name as table
|
|
return schema, entity
|
|
}
|
|
|
|
// getTableName returns the full table name including schema.
|
|
// For most drivers the result is "schema.table". For SQLite, which does not
|
|
// support schema-qualified names, the schema and table are joined with an
|
|
// underscore: "schema_table".
|
|
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
|
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
|
if schemaName != "" {
|
|
if h.db.DriverName() == "sqlite" {
|
|
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
|
}
|
|
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
|
}
|
|
return tableName
|
|
}
|
|
|
|
func (h *Handler) generateMetadata(schema, entity string, model interface{}) *common.TableMetadata {
|
|
modelType := reflect.TypeOf(model)
|
|
|
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
|
modelType = modelType.Elem()
|
|
}
|
|
|
|
// Validate that we have a struct type
|
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
|
logger.Error("Model type must be a struct, got %v for %s.%s", modelType, schema, entity)
|
|
return &common.TableMetadata{
|
|
Schema: schema,
|
|
Table: entity,
|
|
Columns: make([]common.Column, 0),
|
|
Relations: make([]string, 0),
|
|
}
|
|
}
|
|
|
|
metadata := &common.TableMetadata{
|
|
Schema: schema,
|
|
Table: entity,
|
|
Columns: make([]common.Column, 0),
|
|
Relations: make([]string, 0),
|
|
}
|
|
|
|
// Generate metadata using reflection (same logic as before)
|
|
for i := 0; i < modelType.NumField(); i++ {
|
|
field := modelType.Field(i)
|
|
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
|
|
gormTag := field.Tag.Get("gorm")
|
|
jsonTag := field.Tag.Get("json")
|
|
|
|
if jsonTag == "-" {
|
|
continue
|
|
}
|
|
|
|
jsonName := strings.Split(jsonTag, ",")[0]
|
|
if jsonName == "" {
|
|
jsonName = field.Name
|
|
}
|
|
|
|
if field.Type.Kind() == reflect.Slice ||
|
|
(field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") {
|
|
metadata.Relations = append(metadata.Relations, jsonName)
|
|
continue
|
|
}
|
|
|
|
column := common.Column{
|
|
Name: jsonName,
|
|
Type: getColumnType(field),
|
|
IsNullable: isNullable(field),
|
|
IsPrimary: strings.Contains(gormTag, "primaryKey"),
|
|
IsUnique: strings.Contains(gormTag, "unique") || strings.Contains(gormTag, "uniqueIndex"),
|
|
HasIndex: strings.Contains(gormTag, "index") || strings.Contains(gormTag, "uniqueIndex"),
|
|
}
|
|
|
|
metadata.Columns = append(metadata.Columns, column)
|
|
}
|
|
|
|
return metadata
|
|
}
|
|
|
|
func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata) {
|
|
w.SetHeader("Content-Type", "application/json")
|
|
err := w.WriteJSON(common.Response{
|
|
Success: true,
|
|
Data: data,
|
|
Metadata: metadata,
|
|
})
|
|
if err != nil {
|
|
logger.Error("Error sending response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (h *Handler) sendError(w common.ResponseWriter, status int, code, message string, details interface{}) {
|
|
w.SetHeader("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
err := w.WriteJSON(common.Response{
|
|
Success: false,
|
|
Error: &common.APIError{
|
|
Code: code,
|
|
Message: message,
|
|
Details: details,
|
|
Detail: fmt.Sprintf("%v", details),
|
|
},
|
|
})
|
|
if err != nil {
|
|
logger.Error("Error sending response: %v", err)
|
|
}
|
|
}
|
|
|
|
// RegisterModel allows registering models at runtime
|
|
func (h *Handler) RegisterModel(schema, name string, model interface{}) error {
|
|
fullname := fmt.Sprintf("%s.%s", schema, name)
|
|
return h.registry.RegisterModel(fullname, model)
|
|
}
|
|
|
|
// shouldUseNestedProcessor reports whether the given payload should be routed
// through the nested CUD processor. It delegates entirely to
// common.ShouldUseNestedProcessor, passing the handler itself (which provides
// relationship lookups via GetRelationshipInfo).
func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool {
	return common.ShouldUseNestedProcessor(data, model, h)
}
|
|
|
|
// Helper functions
|
|
|
|
func getColumnType(field reflect.StructField) string {
|
|
// Check GORM type tag first
|
|
gormTag := field.Tag.Get("gorm")
|
|
if strings.Contains(gormTag, "type:") {
|
|
parts := strings.Split(gormTag, "type:")
|
|
if len(parts) > 1 {
|
|
typePart := strings.Split(parts[1], ";")[0]
|
|
return typePart
|
|
}
|
|
}
|
|
|
|
// Map Go types to SQL types
|
|
switch field.Type.Kind() {
|
|
case reflect.String:
|
|
return "string"
|
|
case reflect.Int, reflect.Int32:
|
|
return "integer"
|
|
case reflect.Int64:
|
|
return "bigint"
|
|
case reflect.Float32:
|
|
return "float"
|
|
case reflect.Float64:
|
|
return "double"
|
|
case reflect.Bool:
|
|
return "boolean"
|
|
default:
|
|
if field.Type.Name() == "Time" {
|
|
return "timestamp"
|
|
}
|
|
return "unknown"
|
|
}
|
|
}
|
|
|
|
func isNullable(field reflect.StructField) bool {
|
|
// Check if it's a pointer type
|
|
if field.Type.Kind() == reflect.Ptr {
|
|
return true
|
|
}
|
|
|
|
// Check if it's a null type from sql package
|
|
typeName := field.Type.Name()
|
|
if strings.HasPrefix(typeName, "Null") {
|
|
return true
|
|
}
|
|
|
|
// Check GORM tags
|
|
gormTag := field.Tag.Get("gorm")
|
|
return !strings.Contains(gormTag, "not null")
|
|
}
|
|
|
|
// Preload support functions

// GetRelationshipInfo implements the common.RelationshipInfoProvider
// interface by delegating to the shared common.GetRelationshipInfo helper,
// which resolves relation metadata for relationName on modelType. Returns
// nil when the relation is not found (per the nil checks at call sites).
func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo {
	return common.GetRelationshipInfo(modelType, relationName)
}
|
|
|
|
// applyPreloads configures ORM relation preloading on the given select query
// for each requested preload option.
//
// For every preload it:
//   - resolves the requested relation name to the struct field name that the
//     ORM expects (via common.GetRelationshipInfo),
//   - validates/normalizes any raw WHERE clause against that relation,
//   - registers a PreloadRelation callback that applies column selection and
//     omission, filters, sorting, the sanitized WHERE clause, and a limit to
//     the preload sub-query.
//
// Unknown relations are skipped with a warning. An invalid preload WHERE
// clause aborts and returns an error. The (possibly modified) query is
// returned; the model itself is never mutated.
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
	modelType := reflect.TypeOf(model)

	// Unwrap pointers, slices, and arrays to get to the base struct type
	for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
		modelType = modelType.Elem()
	}

	// Validate that we have a struct type; preloads only make sense on structs,
	// so anything else is logged and the query is returned untouched.
	if modelType == nil || modelType.Kind() != reflect.Struct {
		logger.Warn("Cannot apply preloads to non-struct type: %v", modelType)
		return query, nil
	}

	for idx := range preloads {
		// Copy the element so the PreloadRelation closure below captures this
		// iteration's value rather than a shared loop variable.
		preload := preloads[idx]
		logger.Debug("Processing preload for relation: %s", preload.Relation)
		relInfo := common.GetRelationshipInfo(modelType, preload.Relation)
		if relInfo == nil {
			logger.Warn("Relation %s not found in model", preload.Relation)
			continue
		}

		// Use the field name (capitalized) for ORM preloading
		// ORMs like GORM and Bun expect the struct field name, not the JSON name
		relationFieldName := relInfo.FieldName

		// Validate and fix WHERE clause to ensure it contains the relation prefix
		if len(preload.Where) > 0 {
			fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
			if err != nil {
				logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
				return query, fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err)
			}
			preload.Where = fixedWhere
		}

		logger.Debug("Applying preload: %s", relationFieldName)
		query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
			// When only computed columns or omissions were requested, seed the
			// column list with the full model column set so there is a base
			// selection to add to / subtract from.
			// NOTE(review): the columns here come from the PARENT model, not the
			// related model — confirm this is intentional.
			if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
				preload.Columns = reflection.GetSQLModelColumns(model)
			}

			// Handle column selection and omission
			if len(preload.OmitColumns) > 0 {
				allCols := reflection.GetSQLModelColumns(model)
				// Remove omitted columns: rebuild Columns as "all minus omitted",
				// which also overrides any explicitly requested columns.
				preload.Columns = []string{}
				for _, col := range allCols {
					addCols := true
					for _, omitCol := range preload.OmitColumns {
						if col == omitCol {
							addCols = false
							break
						}
					}
					if addCols {
						preload.Columns = append(preload.Columns, col)
					}
				}
			}

			if len(preload.Columns) > 0 {
				// Ensure foreign key is included in column selection for GORM to establish the relationship
				columns := make([]string, len(preload.Columns))
				copy(columns, preload.Columns)

				// Add foreign key if not already present
				if relInfo.ForeignKey != "" {
					// Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id)
					foreignKeyColumn := toSnakeCase(relInfo.ForeignKey)

					// Check against both the snake_case column name and the raw
					// struct field name so either spelling counts as present.
					hasForeignKey := false
					for _, col := range columns {
						if col == foreignKeyColumn || col == relInfo.ForeignKey {
							hasForeignKey = true
							break
						}
					}
					if !hasForeignKey {
						columns = append(columns, foreignKeyColumn)
					}
				}

				sq = sq.Column(columns...)
			}

			if len(preload.Filters) > 0 {
				for _, filter := range preload.Filters {
					sq = h.applyFilter(sq, filter)
				}
			}
			if len(preload.Sort) > 0 {
				for _, sort := range preload.Sort {
					sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
				}
			}

			if len(preload.Where) > 0 {
				// Build RequestOptions with all preloads to allow references to sibling relations
				preloadOpts := &common.RequestOptions{Preload: preloads}
				sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
				// Ensure outer parentheses to prevent OR logic from escaping
				sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
				if len(sanitizedWhere) > 0 {
					sq = sq.Where(sanitizedWhere)
				}
			}

			if preload.Limit != nil && *preload.Limit > 0 {
				sq = sq.Limit(*preload.Limit)
			}

			return sq
		})

		logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName)
	}

	return query, nil
}
|
|
|
|
// toSnakeCase converts a PascalCase or camelCase identifier to snake_case.
// An underscore is inserted before an ASCII uppercase letter only when it
// begins a new word — i.e. the previous rune is lowercase, or the next rune
// is lowercase (so acronym runs like "HTTPServer" become "http_server" and
// a trailing "ID" becomes "_id", not "_i_d"). The result is fully lowercased.
func toSnakeCase(s string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	rs := []rune(s)
	out := make([]rune, 0, len(rs)+4)
	for i, r := range rs {
		if i > 0 && isUpper(r) {
			// Word boundary: lowercase before, or lowercase immediately after.
			startsWord := isLower(rs[i-1]) || (i+1 < len(rs) && isLower(rs[i+1]))
			if startsWord {
				out = append(out, '_')
			}
		}
		out = append(out, r)
	}
	return strings.ToLower(string(out))
}
|
|
|
|
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
|
// The row number is calculated as offset + index + 1 (1-based)
|
|
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
|
// Get the reflect value of the records
|
|
recordsValue := reflect.ValueOf(records)
|
|
if recordsValue.Kind() == reflect.Ptr {
|
|
recordsValue = recordsValue.Elem()
|
|
}
|
|
|
|
// Ensure it's a slice
|
|
if recordsValue.Kind() != reflect.Slice {
|
|
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
|
|
return
|
|
}
|
|
|
|
// Iterate through each record
|
|
for i := 0; i < recordsValue.Len(); i++ {
|
|
record := recordsValue.Index(i)
|
|
|
|
// Dereference if it's a pointer
|
|
if record.Kind() == reflect.Ptr {
|
|
if record.IsNil() {
|
|
continue
|
|
}
|
|
record = record.Elem()
|
|
}
|
|
|
|
// Ensure it's a struct
|
|
if record.Kind() != reflect.Struct {
|
|
continue
|
|
}
|
|
|
|
// Try to find and set the RowNumber field
|
|
rowNumberField := record.FieldByName("RowNumber")
|
|
if rowNumberField.IsValid() && rowNumberField.CanSet() {
|
|
// Check if the field is of type int64
|
|
if rowNumberField.Kind() == reflect.Int64 {
|
|
rowNum := int64(offset + i + 1)
|
|
rowNumberField.SetInt(rowNum)
|
|
logger.Debug("Set RowNumber=%d for record index %d", rowNum, i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// HandleOpenAPI generates and returns the OpenAPI specification
|
|
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
|
if h.openAPIGenerator == nil {
|
|
logger.Error("OpenAPI generator not configured")
|
|
h.sendError(w, http.StatusInternalServerError, "openapi_not_configured", "OpenAPI generation not configured", nil)
|
|
return
|
|
}
|
|
|
|
spec, err := h.openAPIGenerator()
|
|
if err != nil {
|
|
logger.Error("Failed to generate OpenAPI spec: %v", err)
|
|
h.sendError(w, http.StatusInternalServerError, "openapi_generation_error", "Failed to generate OpenAPI specification", err)
|
|
return
|
|
}
|
|
|
|
w.SetHeader("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, err = w.Write([]byte(spec))
|
|
if err != nil {
|
|
logger.Error("Error sending OpenAPI spec response: %v", err)
|
|
}
|
|
}
|
|
|
|
// SetOpenAPIGenerator sets the OpenAPI generator function used by
// HandleOpenAPI. The generator is expected to return the serialized
// OpenAPI specification (served with a JSON content type) or an error.
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
	h.openAPIGenerator = generator
}
|