From c0c669bd3d425bf4f29d15676a418a68a0e5c417 Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 5 Jan 2026 12:31:01 +0200 Subject: [PATCH 01/31] feat(handler): enhance update logic to merge existing records with incoming data --- pkg/funcspec/function_api.go | 140 +++++++++++++++++----------------- pkg/funcspec/hooks.go | 5 ++ pkg/resolvespec/handler.go | 141 +++++++++++++++++++++++++++++++++-- pkg/restheadspec/handler.go | 41 +++++++++- 4 files changed, 248 insertions(+), 79 deletions(-) diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index cf8787c..6d6328d 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -123,27 +123,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun ComplexAPI: complexAPI, } - // Execute BeforeQueryList hook - if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil { - logger.Error("BeforeQueryList hook failed: %v", err) - sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return - } - - // Check if hook aborted the operation - if hookCtx.Abort { - if hookCtx.AbortCode == 0 { - hookCtx.AbortCode = http.StatusBadRequest - } - sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) - return - } - - // Use potentially modified SQL query and variables from hooks - sqlquery = hookCtx.SQLQuery - variables = hookCtx.Variables - // complexAPI = hookCtx.ComplexAPI - // Extract input variables from SQL query (placeholders like [variable]) sqlquery = h.extractInputVariables(sqlquery, &inputvars) @@ -203,6 +182,27 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Set transaction in hook context for hooks to use + hookCtx.Tx = tx + + // Execute BeforeQueryList hook (inside transaction) + if err := h.hooks.Execute(BeforeQueryList, hookCtx); err != nil { + 
logger.Error("BeforeQueryList hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + + // Check if hook aborted the operation + if hookCtx.Abort { + if hookCtx.AbortCode == 0 { + hookCtx.AbortCode = http.StatusBadRequest + } + sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) + return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage) + } + + // Use potentially modified SQL query from hook + sqlquery = hookCtx.SQLQuery sqlqueryCnt := sqlquery // Parse sorting and pagination parameters @@ -286,6 +286,21 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun } total = hookCtx.Total + // Execute AfterQueryList hook (inside transaction) + hookCtx.Result = dbobjlist + hookCtx.Total = total + hookCtx.Error = nil + if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil { + logger.Error("AfterQueryList hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { + dbobjlist = modifiedResult + } + total = hookCtx.Total + return nil }) @@ -294,21 +309,6 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun return } - // Execute AfterQueryList hook - hookCtx.Result = dbobjlist - hookCtx.Total = total - hookCtx.Error = err - if err := h.hooks.Execute(AfterQueryList, hookCtx); err != nil { - logger.Error("AfterQueryList hook failed: %v", err) - sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) - return - } - // Use potentially modified result from hook - if modifiedResult, ok := hookCtx.Result.([]map[string]interface{}); ok { - dbobjlist = modifiedResult - } - total = hookCtx.Total - // Set response headers respOffset := 0 if offsetStr := r.URL.Query().Get("offset"); offsetStr != 
"" { @@ -459,26 +459,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp ComplexAPI: complexAPI, } - // Execute BeforeQuery hook - if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil { - logger.Error("BeforeQuery hook failed: %v", err) - sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return - } - - // Check if hook aborted the operation - if hookCtx.Abort { - if hookCtx.AbortCode == 0 { - hookCtx.AbortCode = http.StatusBadRequest - } - sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) - return - } - - // Use potentially modified SQL query and variables from hooks - sqlquery = hookCtx.SQLQuery - variables = hookCtx.Variables - // Extract input variables from SQL query sqlquery = h.extractInputVariables(sqlquery, &inputvars) @@ -554,6 +534,28 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp // Execute query within transaction err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // Set transaction in hook context for hooks to use + hookCtx.Tx = tx + + // Execute BeforeQuery hook (inside transaction) + if err := h.hooks.Execute(BeforeQuery, hookCtx); err != nil { + logger.Error("BeforeQuery hook failed: %v", err) + sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return err + } + + // Check if hook aborted the operation + if hookCtx.Abort { + if hookCtx.AbortCode == 0 { + hookCtx.AbortCode = http.StatusBadRequest + } + sendError(w, hookCtx.AbortCode, "operation_aborted", hookCtx.AbortMessage, nil) + return fmt.Errorf("operation aborted: %s", hookCtx.AbortMessage) + } + + // Use potentially modified SQL query from hook + sqlquery = hookCtx.SQLQuery + // Execute BeforeSQLExec hook if err := h.hooks.Execute(BeforeSQLExec, hookCtx); err != nil { logger.Error("BeforeSQLExec hook failed: %v", err) @@ -586,6 +588,19 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) 
HTTPFuncTyp dbobj = modifiedResult } + // Execute AfterQuery hook (inside transaction) + hookCtx.Result = dbobj + hookCtx.Error = nil + if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil { + logger.Error("AfterQuery hook failed: %v", err) + sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) + return err + } + // Use potentially modified result from hook + if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { + dbobj = modifiedResult + } + return nil }) @@ -594,19 +609,6 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp return } - // Execute AfterQuery hook - hookCtx.Result = dbobj - hookCtx.Error = err - if err := h.hooks.Execute(AfterQuery, hookCtx); err != nil { - logger.Error("AfterQuery hook failed: %v", err) - sendError(w, http.StatusInternalServerError, "hook_error", "Hook execution failed", err) - return - } - // Use potentially modified result from hook - if modifiedResult, ok := hookCtx.Result.(map[string]interface{}); ok { - dbobj = modifiedResult - } - // Execute BeforeResponse hook hookCtx.Result = dbobj if err := h.hooks.Execute(BeforeResponse, hookCtx); err != nil { diff --git a/pkg/funcspec/hooks.go b/pkg/funcspec/hooks.go index 26d46a3..d04d19e 100644 --- a/pkg/funcspec/hooks.go +++ b/pkg/funcspec/hooks.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/security" ) @@ -46,6 +47,10 @@ type HookContext struct { // User context UserContext *security.UserContext + // Tx provides access to the database/transaction for executing additional SQL + // This allows hooks to run custom queries in addition to the main Query chain + Tx common.Database + // Pagination and filtering (for list queries) SortColumns string Limit int diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index c65f26d..d57826c 100644 --- 
a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -698,20 +698,83 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url } // Standard processing without nested relations - query := h.db.NewUpdate().Table(tableName).SetMap(updates) + // Get the primary key name + pkName := reflection.GetPrimaryKeyName(model) - // Apply conditions + // First, read the existing record from the database + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := h.db.NewSelect().Model(existingRecord) + + // Apply conditions to select if urlID != "" { logger.Debug("Updating by URL ID: %s", urlID) - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), urlID) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) } else if reqID != nil { switch id := reqID.(type) { case string: logger.Debug("Updating by request ID: %s", id) - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) case []string: - logger.Debug("Updating by multiple IDs: %v", id) - query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id) + if len(id) > 0 { + logger.Debug("Updating by multiple IDs: %v", id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + } + } + } + + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + logger.Warn("No records found to update") + h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) + return + } + logger.Error("Error fetching existing record: %v", err) + h.sendError(w, http.StatusInternalServerError, "update_error", "Error fetching existing record", err) + return + } + + // Convert existing record to map + existingMap := make(map[string]interface{}) + 
jsonData, err := json.Marshal(existingRecord) + if err != nil { + logger.Error("Error marshaling existing record: %v", err) + h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) + return + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + logger.Error("Error unmarshaling existing record: %v", err) + h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) + return + } + + // Merge only non-null and non-empty values from the incoming request into the existing record + for key, newValue := range updates { + // Skip if the value is nil + if newValue == nil { + continue + } + + // Skip if the value is an empty string + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + + // Update the existing map with the new value + existingMap[key] = newValue + } + + // Build update query with merged data + query := h.db.NewUpdate().Table(tableName).SetMap(existingMap) + + // Apply conditions + if urlID != "" { + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) + } else if reqID != nil { + switch id := reqID.(type) { + case string: + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + case []string: + query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) } } @@ -782,11 +845,42 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url } // Standard batch update without nested relations + pkName := reflection.GetPrimaryKeyName(model) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range updates { if itemID, ok := item["id"]; ok { + // First, read the existing record + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + if err := selectQuery.ScanModel(ctx); err != nil { + if 
err == sql.ErrNoRows { + continue // Skip if record not found + } + return fmt.Errorf("failed to fetch existing record: %w", err) + } - txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID) + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("failed to marshal existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("failed to unmarshal existing record: %w", err) + } + + // Merge only non-null and non-empty values + for key, newValue := range item { + if newValue == nil { + continue + } + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + existingMap[key] = newValue + } + + txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if _, err := txQuery.Exec(ctx); err != nil { return err } @@ -857,13 +951,44 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url } // Standard batch update without nested relations + pkName := reflection.GetPrimaryKeyName(model) list := make([]interface{}, 0) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range updates { if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { + // First, read the existing record + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + continue // Skip if record not found + } + return fmt.Errorf("failed to fetch existing record: %w", err) + } - txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where(fmt.Sprintf("%s = ?", 
common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID) + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("failed to marshal existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("failed to unmarshal existing record: %w", err) + } + + // Merge only non-null and non-empty values + for key, newValue := range itemMap { + if newValue == nil { + continue + } + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + existingMap[key] = newValue + } + + txQuery := tx.NewUpdate().Table(tableName).SetMap(existingMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if _, err := txQuery.Exec(ctx); err != nil { return err } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 5a71a20..4d567ae 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1239,6 +1239,26 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Create temporary nested processor with transaction txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h) + // First, read the existing record from the database + existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("record not found with ID: %v", targetID) + } + return fmt.Errorf("failed to fetch existing record: %w", err) + } + + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("failed to marshal existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return 
fmt.Errorf("failed to unmarshal existing record: %w", err) + } + // Extract nested relations if present (but don't process them yet) var nestedRelations map[string]interface{} if h.shouldUseNestedProcessor(dataMap, model) { @@ -1251,8 +1271,25 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id nestedRelations = relations } + // Merge only non-null and non-empty values from the incoming request into the existing record + for key, newValue := range dataMap { + // Skip if the value is nil + if newValue == nil { + continue + } + + // Skip if the value is an empty string + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + + // Update the existing map with the new value + existingMap[key] = newValue + } + // Ensure ID is in the data map for the update - dataMap[pkName] = targetID + existingMap[pkName] = targetID + dataMap = existingMap // Populate model instance from dataMap to preserve custom types (like SqlJSONB) // Get the type of the model, handling both pointer and non-pointer types @@ -1297,7 +1334,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Fetch the updated record to return the new values modelValue := reflect.New(reflect.TypeOf(model)).Interface() - selectQuery := tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) if err := selectQuery.ScanModel(ctx); err != nil { return fmt.Errorf("failed to fetch updated record: %w", err) } From d8df1bdac2f9f4ddd00eabd39754d58e6ddc5e07 Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 5 Jan 2026 17:56:54 +0200 Subject: [PATCH 02/31] =?UTF-8?q?feat(funcspec):=20=E2=9C=A8=20add=20JSON?= =?UTF-8?q?=20and=20UUID=20handling=20in=20normalization=20*=20Enhance=20n?= =?UTF-8?q?ormalization=20to=20support=20JSON=20strings=20as=20json.RawMes?= 
=?UTF-8?q?sage=20*=20Add=20support=20for=20UUID=20formatting=20*=20Mainta?= =?UTF-8?q?in=20existing=20behavior=20for=20other=20types?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/funcspec/function_api.go | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pkg/funcspec/function_api.go b/pkg/funcspec/function_api.go index 6d6328d..ffa7ec1 100644 --- a/pkg/funcspec/function_api.go +++ b/pkg/funcspec/function_api.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/restheadspec" @@ -1099,9 +1101,25 @@ func normalizePostgresValue(value interface{}) interface{} { case map[string]interface{}: // Recursively normalize nested maps return normalizePostgresTypes(v) - + case string: + var jsonObj interface{} + if err := json.Unmarshal([]byte(v), &jsonObj); err == nil { + // It's valid JSON, return as json.RawMessage so it's not double-encoded + return json.RawMessage(v) + } + return v + case uuid.UUID: + return v.String() + case time.Time: + return v.Format(time.RFC3339) + case bool, int, int8, int16, int32, int64, float32, float64, uint, uint8, uint16, uint32, uint64: + return v default: - // For other types (int, float, string, bool, etc.), return as-is + // For other types (int, float, bool, etc.), return as-is + // Check stringers + if str, ok := v.(fmt.Stringer); ok { + return str.String() + } return v } } From 62a8e56f1b24b222f542c86c165bcb82ca107c65 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 6 Jan 2026 10:45:23 +0200 Subject: [PATCH 03/31] =?UTF-8?q?feat(reflection):=20=E2=9C=A8=20add=20Get?= =?UTF-8?q?PointerElement=20function=20for=20type=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduced GetPointerElement to simplify pointer type extraction. 
* Updated handleUpdate methods to utilize GetPointerElement for better clarity and maintainability. --- pkg/reflection/helpers.go | 17 +++++++++++++++++ pkg/resolvespec/handler.go | 6 +++--- pkg/restheadspec/handler.go | 10 +++------- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pkg/reflection/helpers.go b/pkg/reflection/helpers.go index cc6787f..155f30c 100644 --- a/pkg/reflection/helpers.go +++ b/pkg/reflection/helpers.go @@ -47,3 +47,20 @@ func ExtractTableNameOnly(fullName string) string { return fullName[startIndex:] } + +// GetPointerElement returns the element type if the provided reflect.Type is a pointer. +// If the type is a slice of pointers, it returns the element type of the pointer within the slice. +// If neither condition is met, it returns the original type. +func GetPointerElement(v reflect.Type) reflect.Type { + if v.Kind() == reflect.Ptr { + return v.Elem() + } + if v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Ptr { + subElem := v.Elem() + if subElem.Elem().Kind() == reflect.Ptr { + return subElem.Elem().Elem() + } + return v.Elem() + } + return v +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index d57826c..4a1aea8 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -702,7 +702,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url pkName := reflection.GetPrimaryKeyName(model) // First, read the existing record from the database - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := h.db.NewSelect().Model(existingRecord) // Apply conditions to select @@ -850,7 +850,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url for _, item := range updates { if itemID, ok := item["id"]; ok { // First, read the existing record - existingRecord := 
reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { @@ -958,7 +958,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { // First, read the existing record - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 4d567ae..3a47fad 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1240,7 +1240,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id txNestedProcessor := common.NewNestedCUDProcessor(tx, h.registry, h) // First, read the existing record from the database - existingRecord := reflect.New(reflect.TypeOf(model).Elem()).Interface() + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { @@ -1294,9 +1294,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Populate model instance from dataMap to preserve custom types (like SqlJSONB) // Get the type of the model, handling both pointer and non-pointer types modelType := reflect.TypeOf(model) - if 
modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } + modelType = reflection.GetPointerElement(modelType) modelInstance := reflect.New(modelType).Interface() if err := reflection.MapToStruct(dataMap, modelInstance); err != nil { return fmt.Errorf("failed to populate model from data: %w", err) @@ -1600,9 +1598,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id // First, fetch the record that will be deleted modelType := reflect.TypeOf(model) - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } + modelType = reflection.GetPointerElement(modelType) recordToDelete := reflect.New(modelType).Interface() selectQuery := h.db.NewSelect().Model(recordToDelete).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) From 987244019c9b5410682709a5f51b68ae0cd10ac5 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 6 Jan 2026 14:05:36 +0200 Subject: [PATCH 04/31] =?UTF-8?q?feat(cors):=20=E2=9C=A8=20enhance=20CORS?= =?UTF-8?q?=20configuration=20with=20dynamic=20origins?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update CORSConfig to allow dynamic origins based on server instances. * Add ExternalURLs field to ServerInstanceConfig for additional CORS support. * Implement GetIPs function to retrieve non-local IP addresses for CORS. 
--- pkg/common/cors.go | 24 +++++++++++++++++++++++- pkg/config/config.go | 3 +++ pkg/config/manager.go | 13 ++++++++++++- pkg/config/server.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 2 deletions(-) diff --git a/pkg/common/cors.go b/pkg/common/cors.go index 8c39fb1..529a806 100644 --- a/pkg/common/cors.go +++ b/pkg/common/cors.go @@ -3,6 +3,8 @@ package common import ( "fmt" "strings" + + "github.com/bitechdev/ResolveSpec/pkg/config" ) // CORSConfig holds CORS configuration @@ -15,8 +17,28 @@ type CORSConfig struct { // DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec func DefaultCORSConfig() CORSConfig { + configManager := config.GetConfigManager() + cfg, _ := configManager.GetConfig() + hosts := make([]string, 0) + //hosts = append(hosts, "*") + + _, _, ipsList := config.GetIPs() + + for _, server := range cfg.Servers.Instances { + hosts = append(hosts, fmt.Sprintf("http://%s:%d", server.Host, server.Port)) + hosts = append(hosts, fmt.Sprintf("https://%s:%d", server.Host, server.Port)) + hosts = append(hosts, fmt.Sprintf("http://%s:%d", "localhost", server.Port)) + for _, extURL := range server.ExternalURLs { + hosts = append(hosts, extURL) + } + for _, ip := range ipsList { + hosts = append(hosts, fmt.Sprintf("http://%s:%d", ip.String(), server.Port)) + hosts = append(hosts, fmt.Sprintf("https://%s:%d", ip.String(), server.Port)) + } + } + return CORSConfig{ - AllowedOrigins: []string{"*"}, + AllowedOrigins: hosts, AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, AllowedHeaders: GetHeadSpecHeaders(), MaxAge: 86400, // 24 hours diff --git a/pkg/config/config.go b/pkg/config/config.go index faa8387..b6265bf 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -73,6 +73,9 @@ type ServerInstanceConfig struct { // Tags for organization and filtering Tags map[string]string `mapstructure:"tags"` + + // ExternalURLs are additional URLs that this server instance 
is accessible from (for CORS) for proxy setups + ExternalURLs []string `mapstructure:"external_urls"` } // TracingConfig holds OpenTelemetry tracing configuration diff --git a/pkg/config/manager.go b/pkg/config/manager.go index ec6351c..391c6b0 100644 --- a/pkg/config/manager.go +++ b/pkg/config/manager.go @@ -12,6 +12,16 @@ type Manager struct { v *viper.Viper } +var configInstance *Manager + +// GetConfigManager returns a singleton configuration manager instance +func GetConfigManager() *Manager { + if configInstance == nil { + configInstance = NewManager() + } + return configInstance +} + // NewManager creates a new configuration manager with defaults func NewManager() *Manager { v := viper.New() @@ -32,7 +42,8 @@ func NewManager() *Manager { // Set default values setDefaults(v) - return &Manager{v: v} + configInstance = &Manager{v: v} + return configInstance } // NewManagerWithOptions creates a new configuration manager with custom options diff --git a/pkg/config/server.go b/pkg/config/server.go index f28008c..a532f5e 100644 --- a/pkg/config/server.go +++ b/pkg/config/server.go @@ -2,6 +2,9 @@ package config import ( "fmt" + "net" + "os" + "strings" ) // ApplyGlobalDefaults applies global server defaults to this instance @@ -105,3 +108,42 @@ func (sc *ServersConfig) GetDefault() (*ServerInstanceConfig, error) { return &instance, nil } + +// GetIPs - GetIP for pc +func GetIPs() (string, string, []net.IP) { + defer func() { + if err := recover(); err != nil { + fmt.Println("Recovered in GetIPs", err) + } + }() + hostname, _ := os.Hostname() + ipaddrlist := make([]net.IP, 0) + iplist := "" + addrs, err := net.LookupIP(hostname) + if err != nil { + return hostname, iplist, ipaddrlist + } + + for _, a := range addrs { + //cfg.LogInfo("\nFound IP Host Address: %s", a) + if strings.Contains(a.String(), "127.0.0.1") { + continue + } + iplist = fmt.Sprintf("%s,%s", iplist, a) + ipaddrlist = append(ipaddrlist, a) + } + if iplist == "" { + iff, _ := net.InterfaceAddrs() + 
for _, a := range iff { + //cfg.LogInfo("\nFound IP Address: %s", a) + if strings.Contains(a.String(), "127.0.0.1") { + continue + } + iplist = fmt.Sprintf("%s,%s", iplist, a) + + } + + } + iplist = strings.TrimLeft(iplist, ",") + return hostname, iplist, ipaddrlist +} From 6ea200bb2b89d05a773ecfab4e3a00e872fc2efd Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 6 Jan 2026 14:07:56 +0200 Subject: [PATCH 05/31] =?UTF-8?q?refactor(cors):=20=F0=9F=9B=A0=EF=B8=8F?= =?UTF-8?q?=20improve=20host=20handling=20in=20CORS=20config=20*=20Change?= =?UTF-8?q?=20loop=20to=20use=20index=20for=20server=20instances=20*=20Sim?= =?UTF-8?q?plify=20appending=20external=20URLs=20*=20Clean=20up=20commente?= =?UTF-8?q?d=20code=20for=20clarity?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/common/cors.go | 9 ++++----- pkg/config/server.go | 8 ++++---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pkg/common/cors.go b/pkg/common/cors.go index 529a806..0325af6 100644 --- a/pkg/common/cors.go +++ b/pkg/common/cors.go @@ -20,17 +20,16 @@ func DefaultCORSConfig() CORSConfig { configManager := config.GetConfigManager() cfg, _ := configManager.GetConfig() hosts := make([]string, 0) - //hosts = append(hosts, "*") + // hosts = append(hosts, "*") _, _, ipsList := config.GetIPs() - for _, server := range cfg.Servers.Instances { + for i := range cfg.Servers.Instances { + server := cfg.Servers.Instances[i] hosts = append(hosts, fmt.Sprintf("http://%s:%d", server.Host, server.Port)) hosts = append(hosts, fmt.Sprintf("https://%s:%d", server.Host, server.Port)) hosts = append(hosts, fmt.Sprintf("http://%s:%d", "localhost", server.Port)) - for _, extURL := range server.ExternalURLs { - hosts = append(hosts, extURL) - } + hosts = append(hosts, server.ExternalURLs...) 
for _, ip := range ipsList { hosts = append(hosts, fmt.Sprintf("http://%s:%d", ip.String(), server.Port)) hosts = append(hosts, fmt.Sprintf("https://%s:%d", ip.String(), server.Port)) diff --git a/pkg/config/server.go b/pkg/config/server.go index a532f5e..201f053 100644 --- a/pkg/config/server.go +++ b/pkg/config/server.go @@ -110,13 +110,13 @@ func (sc *ServersConfig) GetDefault() (*ServerInstanceConfig, error) { } // GetIPs - GetIP for pc -func GetIPs() (string, string, []net.IP) { +func GetIPs() (hostname string, ipList string, ipNetList []net.IP) { defer func() { if err := recover(); err != nil { fmt.Println("Recovered in GetIPs", err) } }() - hostname, _ := os.Hostname() + hostname, _ = os.Hostname() ipaddrlist := make([]net.IP, 0) iplist := "" addrs, err := net.LookupIP(hostname) @@ -125,7 +125,7 @@ func GetIPs() (string, string, []net.IP) { } for _, a := range addrs { - //cfg.LogInfo("\nFound IP Host Address: %s", a) + // cfg.LogInfo("\nFound IP Host Address: %s", a) if strings.Contains(a.String(), "127.0.0.1") { continue } @@ -135,7 +135,7 @@ func GetIPs() (string, string, []net.IP) { if iplist == "" { iff, _ := net.InterfaceAddrs() for _, a := range iff { - //cfg.LogInfo("\nFound IP Address: %s", a) + // cfg.LogInfo("\nFound IP Address: %s", a) if strings.Contains(a.String(), "127.0.0.1") { continue } From 6a0297713a72e6c9296688fd3d8fd7541f3ebab2 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 7 Jan 2026 10:23:23 +0200 Subject: [PATCH 06/31] =?UTF-8?q?feat(reflection):=20=E2=9C=A8=20enhance?= =?UTF-8?q?=20ToSnakeCase=20and=20add=20convertSlice=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve ToSnakeCase to handle consecutive uppercase letters. * Introduce convertSlice for element-wise conversions between slices. * Update setFieldValue to support new slice conversion logic. 
--- pkg/reflection/model_utils.go | 111 ++++++++++++++++++++++++++++++++-- 1 file changed, 107 insertions(+), 4 deletions(-) diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 06c0fe9..60a3208 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -584,11 +584,23 @@ func ExtractSourceColumn(colName string) string { } // ToSnakeCase converts a string from CamelCase to snake_case +// Handles consecutive uppercase letters (acronyms) correctly: +// "HTTPServer" -> "http_server", "UserID" -> "user_id", "MyHTTPServer" -> "my_http_server" func ToSnakeCase(s string) string { var result strings.Builder - for i, r := range s { + runes := []rune(s) + + for i, r := range runes { if i > 0 && r >= 'A' && r <= 'Z' { - result.WriteRune('_') + // Add underscore if: + // 1. Previous character is lowercase, OR + // 2. Next character is lowercase (transition from acronym to word) + prevIsLower := runes[i-1] >= 'a' && runes[i-1] <= 'z' + nextIsLower := i+1 < len(runes) && runes[i+1] >= 'a' && runes[i+1] <= 'z' + + if prevIsLower || nextIsLower { + result.WriteRune('_') + } } result.WriteRune(r) } @@ -961,7 +973,7 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error { // 4. 
Field name variations columnNames = append(columnNames, field.Name) columnNames = append(columnNames, strings.ToLower(field.Name)) - columnNames = append(columnNames, ToSnakeCase(field.Name)) + //columnNames = append(columnNames, ToSnakeCase(field.Name)) // Map all column name variations to this field index for _, colName := range columnNames { @@ -1067,7 +1079,7 @@ func setFieldValue(field reflect.Value, value interface{}) error { case string: field.SetBytes([]byte(v)) return nil - case map[string]interface{}, []interface{}: + case map[string]interface{}, []interface{}, []*any, map[string]*any: // Marshal complex types to JSON for SqlJSONB fields jsonBytes, err := json.Marshal(v) if err != nil { @@ -1077,6 +1089,11 @@ func setFieldValue(field reflect.Value, value interface{}) error { return nil } } + + // Handle slice-to-slice conversions (e.g., []interface{} to []*SomeModel) + if valueReflect.Kind() == reflect.Slice { + return convertSlice(field, valueReflect) + } } // Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time]) @@ -1156,6 +1173,92 @@ func setFieldValue(field reflect.Value, value interface{}) error { return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type()) } +// convertSlice converts a source slice to a target slice type, handling element-wise conversions +// Supports converting []interface{} to slices of structs or pointers to structs +func convertSlice(targetSlice reflect.Value, sourceSlice reflect.Value) error { + if sourceSlice.Kind() != reflect.Slice || targetSlice.Kind() != reflect.Slice { + return fmt.Errorf("both source and target must be slices") + } + + // Get the element type of the target slice + targetElemType := targetSlice.Type().Elem() + sourceLen := sourceSlice.Len() + + // Create a new slice with the same length as the source + newSlice := reflect.MakeSlice(targetSlice.Type(), sourceLen, sourceLen) + + // Convert each element + for i := 0; i < sourceLen; i++ { + sourceElem := 
sourceSlice.Index(i) + targetElem := newSlice.Index(i) + + // Get the actual value from the source element + var sourceValue interface{} + if sourceElem.CanInterface() { + sourceValue = sourceElem.Interface() + } else { + continue + } + + // Handle nil elements + if sourceValue == nil { + // For pointer types, nil is valid + if targetElemType.Kind() == reflect.Ptr { + targetElem.Set(reflect.Zero(targetElemType)) + } + continue + } + + // If target element type is a pointer to struct, we need to create new instances + if targetElemType.Kind() == reflect.Ptr { + // Create a new instance of the pointed-to type + newElemPtr := reflect.New(targetElemType.Elem()) + + // Convert the source value to the struct + switch sv := sourceValue.(type) { + case map[string]interface{}: + // Source is a map, use MapToStruct to populate the new instance + if err := MapToStruct(sv, newElemPtr.Interface()); err != nil { + return fmt.Errorf("failed to convert element %d: %w", i, err) + } + default: + // Try direct conversion or setFieldValue + if err := setFieldValue(newElemPtr.Elem(), sourceValue); err != nil { + return fmt.Errorf("failed to convert element %d: %w", i, err) + } + } + + targetElem.Set(newElemPtr) + } else if targetElemType.Kind() == reflect.Struct { + // Target element is a struct (not a pointer) + switch sv := sourceValue.(type) { + case map[string]interface{}: + // Use MapToStruct to populate the element + elemPtr := targetElem.Addr() + if elemPtr.CanInterface() { + if err := MapToStruct(sv, elemPtr.Interface()); err != nil { + return fmt.Errorf("failed to convert element %d: %w", i, err) + } + } + default: + // Try direct conversion + if err := setFieldValue(targetElem, sourceValue); err != nil { + return fmt.Errorf("failed to convert element %d: %w", i, err) + } + } + } else { + // For other types, use setFieldValue + if err := setFieldValue(targetElem, sourceValue); err != nil { + return fmt.Errorf("failed to convert element %d: %w", i, err) + } + } + } + + // Set 
the converted slice to the target field + targetSlice.Set(newSlice) + return nil +} + // convertToInt64 attempts to convert various types to int64 func convertToInt64(value interface{}) (int64, bool) { switch v := value.(type) { From e220ab3d347f2694005b7aa125349ef2a2bf9b67 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 7 Jan 2026 10:23:37 +0200 Subject: [PATCH 07/31] =?UTF-8?q?refactor(reflection):=20=F0=9F=9B=A0?= =?UTF-8?q?=EF=B8=8F=20comment=20out=20ToSnakeCase=20usage=20in=20MapToStr?= =?UTF-8?q?uct?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/reflection/model_utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 60a3208..150ae9d 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -973,7 +973,7 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error { // 4. Field name variations columnNames = append(columnNames, field.Name) columnNames = append(columnNames, strings.ToLower(field.Name)) - //columnNames = append(columnNames, ToSnakeCase(field.Name)) + // columnNames = append(columnNames, ToSnakeCase(field.Name)) // Map all column name variations to this field index for _, colName := range columnNames { From bf7125efc37a31863acd9c0ec2f816c60b35bab3 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 7 Jan 2026 11:54:12 +0200 Subject: [PATCH 08/31] =?UTF-8?q?feat(reflection):=20=E2=9C=A8=20add=20Ext?= =?UTF-8?q?ractTagValue=20and=20GetRelationshipInfo=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement ExtractTagValue to handle struct tag parsing. * Introduce GetRelationshipInfo for extracting relationship metadata. * Update tests to validate new functionality. * Refactor related code for improved clarity and maintainability. 
--- pkg/common/handler_utils.go | 216 ++++++++++++++++++++++++++ pkg/common/handler_utils_test.go | 108 +++++++++++++ pkg/common/recursive_crud.go | 11 -- pkg/common/types.go | 11 ++ pkg/resolvespec/handler.go | 87 +---------- pkg/resolvespec/handler_test.go | 6 +- pkg/restheadspec/handler.go | 184 +--------------------- pkg/restheadspec/restheadspec_test.go | 87 +++++++++++ 8 files changed, 437 insertions(+), 273 deletions(-) create mode 100644 pkg/common/handler_utils_test.go diff --git a/pkg/common/handler_utils.go b/pkg/common/handler_utils.go index 61716fb..6e1ee12 100644 --- a/pkg/common/handler_utils.go +++ b/pkg/common/handler_utils.go @@ -3,6 +3,9 @@ package common import ( "fmt" "reflect" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/logger" ) // ValidateAndUnwrapModelResult contains the result of model validation @@ -45,3 +48,216 @@ func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, e OriginalType: originalType, }, nil } + +// ExtractTagValue extracts the value for a given key from a struct tag string. +// It handles both semicolon and comma-separated tag formats (e.g., GORM and BUN tags). +// For tags like "json:name;validate:required" it will extract "name" for key "json". +// For tags like "rel:has-many,join:table" it will extract "table" for key "join". +func ExtractTagValue(tag, key string) string { + // Split by both semicolons and commas to handle different tag formats + // We need to be smart about this - commas can be part of values + // So we'll try semicolon first, then comma if needed + separators := []string{";", ","} + + for _, sep := range separators { + parts := strings.Split(tag, sep) + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, key+":") { + return strings.TrimPrefix(part, key+":") + } + } + } + return "" +} + +// GetRelationshipInfo analyzes a model type and extracts relationship metadata +// for a specific relation field identified by its JSON name. 
+// Returns nil if the field is not found or is not a valid relationship. +func GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo { + // Ensure we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + logger.Warn("Cannot get relationship info from non-struct type: %v", modelType) + return nil + } + + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + jsonTag := field.Tag.Get("json") + jsonName := strings.Split(jsonTag, ",")[0] + + if jsonName == relationName { + gormTag := field.Tag.Get("gorm") + bunTag := field.Tag.Get("bun") + info := &RelationshipInfo{ + FieldName: field.Name, + JSONName: jsonName, + } + + if strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") { + //bun:"rel:has-many,join:rid_hub=rid_hub_division" + if strings.Contains(bunTag, "has-many") { + info.RelationType = "hasMany" + } else if strings.Contains(bunTag, "has-one") { + info.RelationType = "hasOne" + } else if strings.Contains(bunTag, "belongs-to") { + info.RelationType = "belongsTo" + } else if strings.Contains(bunTag, "many-to-many") { + info.RelationType = "many2many" + } else { + info.RelationType = "hasOne" + } + + // Extract join info + joinPart := ExtractTagValue(bunTag, "join") + if joinPart != "" && info.RelationType == "many2many" { + // For many2many, the join part is the join table name + info.JoinTable = joinPart + } else if joinPart != "" { + // For other relations, parse foreignKey and references + joinParts := strings.Split(joinPart, "=") + if len(joinParts) == 2 { + info.ForeignKey = joinParts[0] + info.References = joinParts[1] + } + } + + // Get related model type + if field.Type.Kind() == reflect.Slice { + elemType := field.Type.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } else if field.Type.Kind() == reflect.Ptr || 
field.Type.Kind() == reflect.Struct { + elemType := field.Type + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } + + return info + } + + // Parse GORM tag to determine relationship type and keys + if strings.Contains(gormTag, "foreignKey") { + info.ForeignKey = ExtractTagValue(gormTag, "foreignKey") + info.References = ExtractTagValue(gormTag, "references") + + // Determine if it's belongsTo or hasMany/hasOne + if field.Type.Kind() == reflect.Slice { + info.RelationType = "hasMany" + // Get the element type for slice + elemType := field.Type.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { + info.RelationType = "belongsTo" + elemType := field.Type + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } + } else if strings.Contains(gormTag, "many2many") { + info.RelationType = "many2many" + info.JoinTable = ExtractTagValue(gormTag, "many2many") + // Get the element type for many2many (always slice) + if field.Type.Kind() == reflect.Slice { + elemType := field.Type.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Struct { + info.RelatedModel = reflect.New(elemType).Elem().Interface() + } + } + } else { + // Field has no GORM relationship tags, so it's not a relation + return nil + } + + return info + } + } + return nil +} + +// RelationPathToBunAlias converts a relation path (e.g., "Order.Customer") to a Bun alias format. +// It converts to lowercase and replaces dots with double underscores. 
+// For example: "Order.Customer" -> "order__customer" +func RelationPathToBunAlias(relationPath string) string { + if relationPath == "" { + return "" + } + // Convert to lowercase and replace dots with double underscores + alias := strings.ToLower(relationPath) + alias = strings.ReplaceAll(alias, ".", "__") + return alias +} + +// ReplaceTableReferencesInSQL replaces references to a base table name in a SQL expression +// with the appropriate alias for the current preload level. +// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal", +// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem" +func ReplaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string { + if sqlExpr == "" || baseTableName == "" || targetAlias == "" { + return sqlExpr + } + + // Replace both quoted and unquoted table references + // Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column" + + // Pattern 1: tablename.column (unquoted) + result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".") + + // Pattern 2: "tablename".column or "tablename"."column" (quoted table name) + result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".") + + return result +} + +// GetTableNameFromModel extracts the table name from a model. +// It checks the bun tag first, then falls back to converting the struct name to snake_case. 
+func GetTableNameFromModel(model interface{}) string { + if model == nil { + return "" + } + + modelType := reflect.TypeOf(model) + + // Unwrap pointers + for modelType != nil && modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return "" + } + + // Look for bun tag on embedded BaseModel + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + if field.Anonymous { + bunTag := field.Tag.Get("bun") + if strings.HasPrefix(bunTag, "table:") { + return strings.TrimPrefix(bunTag, "table:") + } + } + } + + // Fallback: convert struct name to lowercase (simple heuristic) + // This handles cases like "MasterTaskItem" -> "mastertaskitem" + return strings.ToLower(modelType.Name()) +} diff --git a/pkg/common/handler_utils_test.go b/pkg/common/handler_utils_test.go new file mode 100644 index 0000000..05d374f --- /dev/null +++ b/pkg/common/handler_utils_test.go @@ -0,0 +1,108 @@ +package common + +import ( + "testing" +) + +func TestExtractTagValue(t *testing.T) { + tests := []struct { + name string + tag string + key string + expected string + }{ + { + name: "Extract existing key", + tag: "json:name;validate:required", + key: "json", + expected: "name", + }, + { + name: "Extract key with spaces", + tag: "json:name ; validate:required", + key: "validate", + expected: "required", + }, + { + name: "Extract key at end", + tag: "json:name;validate:required;db:column_name", + key: "db", + expected: "column_name", + }, + { + name: "Extract key at beginning", + tag: "primary:true;json:id;db:user_id", + key: "primary", + expected: "true", + }, + { + name: "Key not found", + tag: "json:name;validate:required", + key: "db", + expected: "", + }, + { + name: "Empty tag", + tag: "", + key: "json", + expected: "", + }, + { + name: "Single key-value pair", + tag: "json:name", + key: "json", + expected: "name", + }, + { + name: "Key with empty value", + tag: "json:;validate:required", + 
key: "json", + expected: "", + }, + { + name: "Key with complex value", + tag: "json:user_name,omitempty;validate:required,min=3", + key: "json", + expected: "user_name,omitempty", + }, + { + name: "Multiple semicolons", + tag: "json:name;;validate:required", + key: "validate", + expected: "required", + }, + { + name: "BUN Tag with comma separator", + tag: "rel:has-many,join:rid_hub=rid_hub_child", + key: "join", + expected: "rid_hub=rid_hub_child", + }, + { + name: "Extract foreignKey", + tag: "foreignKey:UserID;references:ID", + key: "foreignKey", + expected: "UserID", + }, + { + name: "Extract references", + tag: "foreignKey:UserID;references:ID", + key: "references", + expected: "ID", + }, + { + name: "Extract many2many", + tag: "many2many:user_roles", + key: "many2many", + expected: "user_roles", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractTagValue(tt.tag, tt.key) + if result != tt.expected { + t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected) + } + }) + } +} diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index f7f06a7..13b6f89 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -20,17 +20,6 @@ type RelationshipInfoProvider interface { GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo } -// RelationshipInfo contains information about a model relationship -type RelationshipInfo struct { - FieldName string - JSONName string - RelationType string // "belongsTo", "hasMany", "hasOne", "many2many" - ForeignKey string - References string - JoinTable string - RelatedModel interface{} -} - // NestedCUDProcessor handles recursive processing of nested object graphs type NestedCUDProcessor struct { db Database diff --git a/pkg/common/types.go b/pkg/common/types.go index d8e54b2..b09b3db 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -111,3 +111,14 @@ type TableMetadata struct { 
Columns []Column `json:"columns"` Relations []string `json:"relations"` } + +// RelationshipInfo contains information about a model relationship +type RelationshipInfo struct { + FieldName string `json:"field_name"` + JSONName string `json:"json_name"` + RelationType string `json:"relation_type"` // "belongsTo", "hasMany", "hasOne", "many2many" + ForeignKey string `json:"foreign_key"` + References string `json:"references"` + JoinTable string `json:"join_table"` + RelatedModel interface{} `json:"related_model"` +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 4a1aea8..bf082e7 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1453,30 +1453,7 @@ func isNullable(field reflect.StructField) bool { // GetRelationshipInfo implements common.RelationshipInfoProvider interface func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo { - info := h.getRelationshipInfo(modelType, relationName) - if info == nil { - return nil - } - // Convert internal type to common type - return &common.RelationshipInfo{ - FieldName: info.fieldName, - JSONName: info.jsonName, - RelationType: info.relationType, - ForeignKey: info.foreignKey, - References: info.references, - JoinTable: info.joinTable, - RelatedModel: info.relatedModel, - } -} - -type relationshipInfo struct { - fieldName string - jsonName string - relationType string // "belongsTo", "hasMany", "hasOne", "many2many" - foreignKey string - references string - joinTable string - relatedModel interface{} + return common.GetRelationshipInfo(modelType, relationName) } func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) { @@ -1496,7 +1473,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre for idx := range preloads { preload := preloads[idx] logger.Debug("Processing preload for relation: %s", preload.Relation) - 
relInfo := h.getRelationshipInfo(modelType, preload.Relation) + relInfo := common.GetRelationshipInfo(modelType, preload.Relation) if relInfo == nil { logger.Warn("Relation %s not found in model", preload.Relation) continue @@ -1504,7 +1481,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // Use the field name (capitalized) for ORM preloading // ORMs like GORM and Bun expect the struct field name, not the JSON name - relationFieldName := relInfo.fieldName + relationFieldName := relInfo.FieldName // Validate and fix WHERE clause to ensure it contains the relation prefix if len(preload.Where) > 0 { @@ -1547,13 +1524,13 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre copy(columns, preload.Columns) // Add foreign key if not already present - if relInfo.foreignKey != "" { + if relInfo.ForeignKey != "" { // Convert struct field name (e.g., DepartmentID) to snake_case (e.g., department_id) - foreignKeyColumn := toSnakeCase(relInfo.foreignKey) + foreignKeyColumn := toSnakeCase(relInfo.ForeignKey) hasForeignKey := false for _, col := range columns { - if col == foreignKeyColumn || col == relInfo.foreignKey { + if col == foreignKeyColumn || col == relInfo.ForeignKey { hasForeignKey = true break } @@ -1599,58 +1576,6 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre return query, nil } -func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { - // Ensure we have a struct type - if modelType == nil || modelType.Kind() != reflect.Struct { - logger.Warn("Cannot get relationship info from non-struct type: %v", modelType) - return nil - } - - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - jsonTag := field.Tag.Get("json") - jsonName := strings.Split(jsonTag, ",")[0] - - if jsonName == relationName { - gormTag := field.Tag.Get("gorm") - info := &relationshipInfo{ - fieldName: field.Name, - jsonName: 
jsonName, - } - - // Parse GORM tag to determine relationship type and keys - if strings.Contains(gormTag, "foreignKey") { - info.foreignKey = h.extractTagValue(gormTag, "foreignKey") - info.references = h.extractTagValue(gormTag, "references") - - // Determine if it's belongsTo or hasMany/hasOne - if field.Type.Kind() == reflect.Slice { - info.relationType = "hasMany" - } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { - info.relationType = "belongsTo" - } - } else if strings.Contains(gormTag, "many2many") { - info.relationType = "many2many" - info.joinTable = h.extractTagValue(gormTag, "many2many") - } - - return info - } - } - return nil -} - -func (h *Handler) extractTagValue(tag, key string) string { - parts := strings.Split(tag, ";") - for _, part := range parts { - part = strings.TrimSpace(part) - if strings.HasPrefix(part, key+":") { - return strings.TrimPrefix(part, key+":") - } - } - return "" -} - // toSnakeCase converts a PascalCase or camelCase string to snake_case func toSnakeCase(s string) string { var result strings.Builder diff --git a/pkg/resolvespec/handler_test.go b/pkg/resolvespec/handler_test.go index ac36b6f..d57e49d 100644 --- a/pkg/resolvespec/handler_test.go +++ b/pkg/resolvespec/handler_test.go @@ -269,8 +269,6 @@ func TestToSnakeCase(t *testing.T) { } func TestExtractTagValue(t *testing.T) { - handler := NewHandler(nil, nil) - tests := []struct { name string tag string @@ -311,9 +309,9 @@ func TestExtractTagValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := handler.extractTagValue(tt.tag, tt.key) + result := common.ExtractTagValue(tt.tag, tt.key) if result != tt.expected { - t.Errorf("extractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected) + t.Errorf("ExtractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected) } }) } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 3a47fad..2f20ce1 
100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -766,7 +766,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply ComputedQL fields if any if len(preload.ComputedQL) > 0 { // Get the base table name from the related model - baseTableName := getTableNameFromModel(relatedModel) + baseTableName := common.GetTableNameFromModel(relatedModel) // Convert the preload relation path to the appropriate alias format // This is ORM-specific. Currently we only support Bun's format. @@ -777,7 +777,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB()) if strings.Contains(underlyingType, "bun.DB") { // Use Bun's alias format: lowercase with double underscores - preloadAlias = relationPathToBunAlias(preload.Relation) + preloadAlias = common.RelationPathToBunAlias(preload.Relation) } // For GORM: GORM doesn't use the same alias format, and this fix // may not be needed since GORM handles preloads differently @@ -792,7 +792,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // levels of recursive/nested preloads adjustedExpr := colExpr if baseTableName != "" && preloadAlias != "" { - adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias) + adjustedExpr = common.ReplaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias) if adjustedExpr != colExpr { logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'", colName, colExpr, adjustedExpr) @@ -903,73 +903,6 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co return query } -// relationPathToBunAlias converts a relation path like "MAL.MAL.DEF" to the Bun alias format "mal__mal__def" -// Bun generates aliases for nested relations by lowercasing and replacing dots with double underscores -func relationPathToBunAlias(relationPath string) string { - if relationPath == "" 
{ - return "" - } - // Convert to lowercase and replace dots with double underscores - alias := strings.ToLower(relationPath) - alias = strings.ReplaceAll(alias, ".", "__") - return alias -} - -// replaceTableReferencesInSQL replaces references to a base table name in a SQL expression -// with the appropriate alias for the current preload level -// For example, if baseTableName is "mastertaskitem" and targetAlias is "mal__mal", -// it will replace "mastertaskitem.rid_mastertaskitem" with "mal__mal.rid_mastertaskitem" -func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string { - if sqlExpr == "" || baseTableName == "" || targetAlias == "" { - return sqlExpr - } - - // Replace both quoted and unquoted table references - // Handle patterns like: tablename.column, "tablename".column, tablename."column", "tablename"."column" - - // Pattern 1: tablename.column (unquoted) - result := strings.ReplaceAll(sqlExpr, baseTableName+".", targetAlias+".") - - // Pattern 2: "tablename".column or "tablename"."column" (quoted table name) - result = strings.ReplaceAll(result, "\""+baseTableName+"\".", "\""+targetAlias+"\".") - - return result -} - -// getTableNameFromModel extracts the table name from a model -// It checks the bun tag first, then falls back to converting the struct name to snake_case -func getTableNameFromModel(model interface{}) string { - if model == nil { - return "" - } - - modelType := reflect.TypeOf(model) - - // Unwrap pointers - for modelType != nil && modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - - if modelType == nil || modelType.Kind() != reflect.Struct { - return "" - } - - // Look for bun tag on embedded BaseModel - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - if field.Anonymous { - bunTag := field.Tag.Get("bun") - if strings.HasPrefix(bunTag, "table:") { - return strings.TrimPrefix(bunTag, "table:") - } - } - } - - // Fallback: convert struct name to lowercase (simple 
heuristic) - // This handles cases like "MasterTaskItem" -> "mastertaskitem" - return strings.ToLower(modelType.Name()) -} - func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) { // Capture panics and return error response defer func() { @@ -2570,10 +2503,10 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio filteredExpand := expand // Get the relationship info for this expand relation - relInfo := h.getRelationshipInfo(modelType, expand.Relation) - if relInfo != nil && relInfo.relatedModel != nil { + relInfo := common.GetRelationshipInfo(modelType, expand.Relation) + if relInfo != nil && relInfo.RelatedModel != nil { // Create a validator for the related model - expandValidator := common.NewColumnValidator(relInfo.relatedModel) + expandValidator := common.NewColumnValidator(relInfo.RelatedModel) // Filter columns using the related model's validator filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns) @@ -2650,110 +2583,7 @@ func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model in // GetRelationshipInfo implements common.RelationshipInfoProvider interface func (h *Handler) GetRelationshipInfo(modelType reflect.Type, relationName string) *common.RelationshipInfo { - info := h.getRelationshipInfo(modelType, relationName) - if info == nil { - return nil - } - // Convert internal type to common type - return &common.RelationshipInfo{ - FieldName: info.fieldName, - JSONName: info.jsonName, - RelationType: info.relationType, - ForeignKey: info.foreignKey, - References: info.references, - JoinTable: info.joinTable, - RelatedModel: info.relatedModel, - } -} - -type relationshipInfo struct { - fieldName string - jsonName string - relationType string // "belongsTo", "hasMany", "hasOne", "many2many" - foreignKey string - references string - joinTable string - relatedModel interface{} -} - -func (h *Handler) 
getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { - // Ensure we have a struct type - if modelType == nil || modelType.Kind() != reflect.Struct { - logger.Warn("Cannot get relationship info from non-struct type: %v", modelType) - return nil - } - - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - jsonTag := field.Tag.Get("json") - jsonName := strings.Split(jsonTag, ",")[0] - - if jsonName == relationName { - gormTag := field.Tag.Get("gorm") - info := &relationshipInfo{ - fieldName: field.Name, - jsonName: jsonName, - } - - // Parse GORM tag to determine relationship type and keys - if strings.Contains(gormTag, "foreignKey") { - info.foreignKey = h.extractTagValue(gormTag, "foreignKey") - info.references = h.extractTagValue(gormTag, "references") - - // Determine if it's belongsTo or hasMany/hasOne - if field.Type.Kind() == reflect.Slice { - info.relationType = "hasMany" - // Get the element type for slice - elemType := field.Type.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Struct { - info.relatedModel = reflect.New(elemType).Elem().Interface() - } - } else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct { - info.relationType = "belongsTo" - elemType := field.Type - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Struct { - info.relatedModel = reflect.New(elemType).Elem().Interface() - } - } - } else if strings.Contains(gormTag, "many2many") { - info.relationType = "many2many" - info.joinTable = h.extractTagValue(gormTag, "many2many") - // Get the element type for many2many (always slice) - if field.Type.Kind() == reflect.Slice { - elemType := field.Type.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Struct { - info.relatedModel = reflect.New(elemType).Elem().Interface() - } - } - } else { - // Field 
has no GORM relationship tags, so it's not a relation - return nil - } - - return info - } - } - return nil -} - -func (h *Handler) extractTagValue(tag, key string) string { - parts := strings.Split(tag, ";") - for _, part := range parts { - part = strings.TrimSpace(part) - if strings.HasPrefix(part, key+":") { - return strings.TrimPrefix(part, key+":") - } - } - return "" + return common.GetRelationshipInfo(modelType, relationName) } // HandleOpenAPI generates and returns the OpenAPI specification diff --git a/pkg/restheadspec/restheadspec_test.go b/pkg/restheadspec/restheadspec_test.go index 355938b..53a1b23 100644 --- a/pkg/restheadspec/restheadspec_test.go +++ b/pkg/restheadspec/restheadspec_test.go @@ -2,6 +2,8 @@ package restheadspec import ( "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" ) func TestParseModelName(t *testing.T) { @@ -112,3 +114,88 @@ func TestNewStandardBunRouter(t *testing.T) { t.Error("Expected router to be created, got nil") } } + +func TestExtractTagValue(t *testing.T) { + tests := []struct { + name string + tag string + key string + expected string + }{ + { + name: "Extract existing key", + tag: "json:name;validate:required", + key: "json", + expected: "name", + }, + { + name: "Extract key with spaces", + tag: "json:name ; validate:required", + key: "validate", + expected: "required", + }, + { + name: "Extract key at end", + tag: "json:name;validate:required;db:column_name", + key: "db", + expected: "column_name", + }, + { + name: "Extract key at beginning", + tag: "primary:true;json:id;db:user_id", + key: "primary", + expected: "true", + }, + { + name: "Key not found", + tag: "json:name;validate:required", + key: "db", + expected: "", + }, + { + name: "Empty tag", + tag: "", + key: "json", + expected: "", + }, + { + name: "Single key-value pair", + tag: "json:name", + key: "json", + expected: "name", + }, + { + name: "Key with empty value", + tag: "json:;validate:required", + key: "json", + expected: "", + }, + { + name: 
"Key with complex value", + tag: "json:user_name,omitempty;validate:required,min=3", + key: "json", + expected: "user_name,omitempty", + }, + { + name: "Multiple semicolons", + tag: "json:name;;validate:required", + key: "validate", + expected: "required", + }, + { + name: "BUN Tag", + tag: "rel:has-many,join:rid_hub=rid_hub_child", + key: "join", + expected: "rid_hub=rid_hub_child", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := common.ExtractTagValue(tt.tag, tt.key) + if result != tt.expected { + t.Errorf("ExtractTagValue(%q, %q) = %q; want %q", tt.tag, tt.key, result, tt.expected) + } + }) + } +} From a7e640a6a1e5957cd0247d9eb4e7fb995010e772 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 7 Jan 2026 11:58:44 +0200 Subject: [PATCH 09/31] =?UTF-8?q?fix(recursive=5Fcrud):=20=F0=9F=90=9B=20u?= =?UTF-8?q?se=20dynamic=20primary=20key=20name=20in=20insert=20*=20Update?= =?UTF-8?q?=20processInsert=20to=20use=20the=20primary=20key=20name=20dyna?= =?UTF-8?q?mically.=20*=20Ensure=20correct=20ID=20retrieval=20from=20data?= =?UTF-8?q?=20based=20on=20primary=20key.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/common/recursive_crud.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index 13b6f89..6e047fb 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -207,9 +207,9 @@ func (p *NestedCUDProcessor) processInsert( for key, value := range data { query = query.Value(key, value) } - + pkName := reflection.GetPrimaryKeyName(tableName) // Add RETURNING clause to get the inserted ID - query = query.Returning("id") + query = query.Returning(pkName) result, err := query.Exec(ctx) if err != nil { @@ -220,8 +220,8 @@ func (p *NestedCUDProcessor) processInsert( var id interface{} if lastID, err := result.LastInsertId(); err == nil && lastID > 0 { id = lastID - } else if 
data["id"] != nil { - id = data["id"] + } else if data[pkName] != nil { + id = data[pkName] } logger.Debug("Insert successful, ID: %v, rows affected: %d", id, result.RowsAffected()) From 37c85361ba842ad0860b6578b9cee13e1fe67951 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 7 Jan 2026 12:06:08 +0200 Subject: [PATCH 10/31] =?UTF-8?q?feat(cors):=20=E2=9C=A8=20add=20check=20f?= =?UTF-8?q?or=20server=20port=20in=20CORS=20config?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/common/cors.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/common/cors.go b/pkg/common/cors.go index 0325af6..45a4078 100644 --- a/pkg/common/cors.go +++ b/pkg/common/cors.go @@ -26,10 +26,13 @@ func DefaultCORSConfig() CORSConfig { for i := range cfg.Servers.Instances { server := cfg.Servers.Instances[i] + if server.Port == 0 { + continue + } + hosts = append(hosts, server.ExternalURLs...) hosts = append(hosts, fmt.Sprintf("http://%s:%d", server.Host, server.Port)) hosts = append(hosts, fmt.Sprintf("https://%s:%d", server.Host, server.Port)) hosts = append(hosts, fmt.Sprintf("http://%s:%d", "localhost", server.Port)) - hosts = append(hosts, server.ExternalURLs...) for _, ip := range ipsList { hosts = append(hosts, fmt.Sprintf("http://%s:%d", ip.String(), server.Port)) hosts = append(hosts, fmt.Sprintf("https://%s:%d", ip.String(), server.Port)) From cb20a354fc217a16d271e168fe28f1b5b1ec1701 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 7 Jan 2026 15:24:44 +0200 Subject: [PATCH 11/31] =?UTF-8?q?feat(cors):=20=E2=9C=A8=20update=20SetCOR?= =?UTF-8?q?SHeaders=20to=20accept=20Request?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modify SetCORSHeaders function to include Request parameter. * Set Access-Control-Allow-Origin and Access-Control-Allow-Headers to "*". * Update all relevant calls to SetCORSHeaders across the codebase. 
--- pkg/common/cors.go | 22 +++++---- pkg/resolvespec/resolvespec.go | 53 +++++++++++++--------- pkg/restheadspec/restheadspec.go | 77 +++++++++++++++++--------------- 3 files changed, 88 insertions(+), 64 deletions(-) diff --git a/pkg/common/cors.go b/pkg/common/cors.go index 45a4078..11336a4 100644 --- a/pkg/common/cors.go +++ b/pkg/common/cors.go @@ -114,11 +114,14 @@ func GetHeadSpecHeaders() []string { } // SetCORSHeaders sets CORS headers on a response writer -func SetCORSHeaders(w ResponseWriter, config CORSConfig) { +func SetCORSHeaders(w ResponseWriter, r Request, config CORSConfig) { // Set allowed origins - if len(config.AllowedOrigins) > 0 { - w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", ")) - } + // if len(config.AllowedOrigins) > 0 { + // w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", ")) + // } + + // Todo origin list parsing + w.SetHeader("Access-Control-Allow-Origin", "*") // Set allowed methods if len(config.AllowedMethods) > 0 { @@ -126,9 +129,10 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) { } // Set allowed headers - if len(config.AllowedHeaders) > 0 { - w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) - } + // if len(config.AllowedHeaders) > 0 { + // w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) + // } + w.SetHeader("Access-Control-Allow-Headers", "*") // Set max age if config.MaxAge > 0 { @@ -139,5 +143,7 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) { w.SetHeader("Access-Control-Allow-Credentials", "true") // Expose headers that clients can read - w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size") + exposeHeaders := config.AllowedHeaders + exposeHeaders = append(exposeHeaders, "Content-Range", "X-Api-Range-Total", "X-Api-Range-Size") + w.SetHeader("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ", ")) } 
diff --git a/pkg/resolvespec/resolvespec.go b/pkg/resolvespec/resolvespec.go index 80c0287..bf78e08 100644 --- a/pkg/resolvespec/resolvespec.go +++ b/pkg/resolvespec/resolvespec.go @@ -50,8 +50,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { corsConfig := common.DefaultCORSConfig() respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) + handler.HandleOpenAPI(respAdapter, reqAdapter) }) muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS") @@ -98,7 +99,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han // Set CORS headers corsConfig := common.DefaultCORSConfig() respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) vars := make(map[string]string) vars["schema"] = schema @@ -106,7 +108,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han if idParam != "" { vars["id"] = mux.Vars(r)[idParam] } - reqAdapter := router.NewHTTPRequest(r) + handler.Handle(respAdapter, reqAdapter, vars) } } @@ -117,7 +119,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http. // Set CORS headers corsConfig := common.DefaultCORSConfig() respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) vars := make(map[string]string) vars["schema"] = schema @@ -125,7 +128,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http. 
if idParam != "" { vars["id"] = mux.Vars(r)[idParam] } - reqAdapter := router.NewHTTPRequest(r) + handler.HandleGet(respAdapter, reqAdapter, vars) } } @@ -137,13 +140,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet corsConfig := common.DefaultCORSConfig() corsConfig.AllowedMethods = allowedMethods respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) // Return metadata in the OPTIONS response body vars := make(map[string]string) vars["schema"] = schema vars["entity"] = entity - reqAdapter := router.NewHTTPRequest(r) + handler.HandleGet(respAdapter, reqAdapter, vars) } } @@ -222,15 +226,16 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // Add global /openapi route r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) reqAdapter := router.NewHTTPRequest(req.Request) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) handler.HandleOpenAPI(respAdapter, reqAdapter) return nil }) r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(req.Request) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) return nil }) @@ -253,12 +258,13 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // POST route without ID r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(req.Request) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": 
currentSchema, "entity": currentEntity, } - reqAdapter := router.NewHTTPRequest(req.Request) + handler.Handle(respAdapter, reqAdapter, params) return nil }) @@ -266,13 +272,14 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // POST route with ID r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(req.Request) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, "id": req.Param("id"), } - reqAdapter := router.NewHTTPRequest(req.Request) + handler.Handle(respAdapter, reqAdapter, params) return nil }) @@ -280,12 +287,13 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // GET route without ID r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(req.Request) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewHTTPRequest(req.Request) + handler.HandleGet(respAdapter, reqAdapter, params) return nil }) @@ -293,13 +301,14 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // GET route with ID r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(req.Request) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, "id": req.Param("id"), } - reqAdapter := router.NewHTTPRequest(req.Request) + handler.HandleGet(respAdapter, 
reqAdapter, params) return nil }) @@ -307,14 +316,15 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // OPTIONS route without ID (returns metadata) r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) + reqAdapter := router.NewHTTPRequest(req.Request) optionsCorsConfig := corsConfig optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"} - common.SetCORSHeaders(respAdapter, optionsCorsConfig) + common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewHTTPRequest(req.Request) + handler.HandleGet(respAdapter, reqAdapter, params) return nil }) @@ -322,14 +332,15 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // OPTIONS route with ID (returns metadata) r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) + reqAdapter := router.NewHTTPRequest(req.Request) optionsCorsConfig := corsConfig optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"} - common.SetCORSHeaders(respAdapter, optionsCorsConfig) + common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewHTTPRequest(req.Request) + handler.HandleGet(respAdapter, reqAdapter, params) return nil }) diff --git a/pkg/restheadspec/restheadspec.go b/pkg/restheadspec/restheadspec.go index 4b922ce..cfe0378 100644 --- a/pkg/restheadspec/restheadspec.go +++ b/pkg/restheadspec/restheadspec.go @@ -103,8 +103,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { corsConfig := common.DefaultCORSConfig() respAdapter := router.NewHTTPResponseWriter(w) - 
common.SetCORSHeaders(respAdapter, corsConfig) reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) + handler.HandleOpenAPI(respAdapter, reqAdapter) }) muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS") @@ -161,7 +162,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han // Set CORS headers corsConfig := common.DefaultCORSConfig() respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) vars := make(map[string]string) vars["schema"] = schema @@ -169,7 +171,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han if idParam != "" { vars["id"] = mux.Vars(r)[idParam] } - reqAdapter := router.NewHTTPRequest(r) + handler.Handle(respAdapter, reqAdapter, vars) } } @@ -180,7 +182,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http. // Set CORS headers corsConfig := common.DefaultCORSConfig() respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) vars := make(map[string]string) vars["schema"] = schema @@ -188,7 +191,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http. 
if idParam != "" { vars["id"] = mux.Vars(r)[idParam] } - reqAdapter := router.NewHTTPRequest(r) + handler.HandleGet(respAdapter, reqAdapter, vars) } } @@ -200,13 +203,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet corsConfig := common.DefaultCORSConfig() corsConfig.AllowedMethods = allowedMethods respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewHTTPRequest(r) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) // Return metadata in the OPTIONS response body vars := make(map[string]string) vars["schema"] = schema vars["entity"] = entity - reqAdapter := router.NewHTTPRequest(r) + handler.HandleGet(respAdapter, reqAdapter, vars) } } @@ -285,15 +289,8 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // Add global /openapi route r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) reqAdapter := router.NewBunRouterRequest(req) - handler.HandleOpenAPI(respAdapter, reqAdapter) - return nil - }) - - r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error { - respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) return nil }) @@ -317,24 +314,26 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // GET and POST for /{schema}/{entity} r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewBunRouterRequest(req) + 
handler.Handle(respAdapter, reqAdapter, params) return nil }) r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewBunRouterRequest(req) + handler.Handle(respAdapter, reqAdapter, params) return nil }) @@ -342,65 +341,70 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, "id": req.Param("id"), } - reqAdapter := router.NewBunRouterRequest(req) + handler.Handle(respAdapter, reqAdapter, params) return nil }) r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, "id": req.Param("id"), } - reqAdapter := router.NewBunRouterRequest(req) + handler.Handle(respAdapter, reqAdapter, params) return nil }) r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, 
reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, "id": req.Param("id"), } - reqAdapter := router.NewBunRouterRequest(req) + handler.Handle(respAdapter, reqAdapter, params) return nil }) r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, "id": req.Param("id"), } - reqAdapter := router.NewBunRouterRequest(req) + handler.Handle(respAdapter, reqAdapter, params) return nil }) r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, "id": req.Param("id"), } - reqAdapter := router.NewBunRouterRequest(req) + handler.Handle(respAdapter, reqAdapter, params) return nil }) @@ -408,12 +412,13 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // Metadata endpoint r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) - common.SetCORSHeaders(respAdapter, corsConfig) + reqAdapter := router.NewBunRouterRequest(req) + common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewBunRouterRequest(req) + handler.HandleGet(respAdapter, reqAdapter, params) return nil }) @@ -421,14 +426,15 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // OPTIONS route without ID (returns 
metadata) r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) + reqAdapter := router.NewBunRouterRequest(req) optionsCorsConfig := corsConfig optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"} - common.SetCORSHeaders(respAdapter, optionsCorsConfig) + common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewBunRouterRequest(req) + handler.HandleGet(respAdapter, reqAdapter, params) return nil }) @@ -436,14 +442,15 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) { // OPTIONS route with ID (returns metadata) r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error { respAdapter := router.NewHTTPResponseWriter(w) + reqAdapter := router.NewBunRouterRequest(req) optionsCorsConfig := corsConfig optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"} - common.SetCORSHeaders(respAdapter, optionsCorsConfig) + common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig) params := map[string]string{ "schema": currentSchema, "entity": currentEntity, } - reqAdapter := router.NewBunRouterRequest(req) + handler.HandleGet(respAdapter, reqAdapter, params) return nil }) From b7a67a6974c3746354eb67a3414595f6b147db70 Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 12 Jan 2026 11:12:42 +0200 Subject: [PATCH 12/31] =?UTF-8?q?fix(headers):=20=F0=9F=90=9B=20handle=20s?= =?UTF-8?q?earch=20on=20computed=20columns?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/restheadspec/headers.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index 7c5d209..bdae2bd 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -354,6 +354,12 @@ func (h 
*Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu operator := parts[0] colName := parts[1] + if strings.HasPrefix(colName, "cql") { + // Computed column - Will not filter on it + logger.Warn("Search operators on computed columns are not supported: %s", colName) + return + } + // Map operator names to filter operators filterOp := h.mapSearchOperator(colName, operator, value) From 0ac207d80fd3a696ce61ac2b815703878501610b Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 13 Jan 2026 11:33:45 +0200 Subject: [PATCH 13/31] fix: better update handling --- pkg/reflection/model_utils.go | 38 +- pkg/reflection/spectypes_integration_test.go | 364 +++++++++++++++++++ pkg/resolvespec/handler.go | 254 +++++++++---- pkg/restheadspec/handler.go | 53 +-- 4 files changed, 590 insertions(+), 119 deletions(-) create mode 100644 pkg/reflection/spectypes_integration_test.go diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 150ae9d..b07d4d2 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -948,29 +948,35 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error { // Build list of possible column names for this field var columnNames []string - // 1. Bun tag - if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" { - if colName := ExtractColumnFromBunTag(bunTag); colName != "" { - columnNames = append(columnNames, colName) - } - } - - // 2. Gorm tag - if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" { - if colName := ExtractColumnFromGormTag(gormTag); colName != "" { - columnNames = append(columnNames, colName) - } - } - - // 3. JSON tag + // 1. JSON tag (primary - most common) + jsonFound := false if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" { parts := strings.Split(jsonTag, ",") if len(parts) > 0 && parts[0] != "" { columnNames = append(columnNames, parts[0]) + jsonFound = true } } - // 4. Field name variations + // 2. 
Bun tag (fallback if no JSON tag) + if !jsonFound { + if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" { + if colName := ExtractColumnFromBunTag(bunTag); colName != "" { + columnNames = append(columnNames, colName) + } + } + } + + // 3. Gorm tag (fallback if no JSON tag) + if !jsonFound { + if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" { + if colName := ExtractColumnFromGormTag(gormTag); colName != "" { + columnNames = append(columnNames, colName) + } + } + } + + // 4. Field name variations (last resort) columnNames = append(columnNames, field.Name) columnNames = append(columnNames, strings.ToLower(field.Name)) // columnNames = append(columnNames, ToSnakeCase(field.Name)) diff --git a/pkg/reflection/spectypes_integration_test.go b/pkg/reflection/spectypes_integration_test.go new file mode 100644 index 0000000..34bfb15 --- /dev/null +++ b/pkg/reflection/spectypes_integration_test.go @@ -0,0 +1,364 @@ +package reflection + +import ( + "testing" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/spectypes" + "github.com/google/uuid" +) + +// TestModel contains all spectypes custom types +type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Name spectypes.SqlString `bun:"name" json:"name"` + Age spectypes.SqlInt64 `bun:"age" json:"age"` + Score spectypes.SqlFloat64 `bun:"score" json:"score"` + Active spectypes.SqlBool `bun:"active" json:"active"` + UUID spectypes.SqlUUID `bun:"uuid" json:"uuid"` + CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"` + BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"` + StartTime spectypes.SqlTime `bun:"start_time" json:"start_time"` + Metadata spectypes.SqlJSONB `bun:"metadata" json:"metadata"` + Count16 spectypes.SqlInt16 `bun:"count16" json:"count16"` + Count32 spectypes.SqlInt32 `bun:"count32" json:"count32"` +} + +// TestMapToStruct_AllSpectypes verifies that MapToStruct can convert all spectypes correctly +func TestMapToStruct_AllSpectypes(t *testing.T) 
{ + testUUID := uuid.New() + testTime := time.Now() + + tests := []struct { + name string + dataMap map[string]interface{} + validator func(*testing.T, *TestModel) + }{ + { + name: "SqlString from string", + dataMap: map[string]interface{}{ + "name": "John Doe", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Name.Valid || m.Name.String() != "John Doe" { + t.Errorf("expected name='John Doe', got valid=%v, value=%s", m.Name.Valid, m.Name.String()) + } + }, + }, + { + name: "SqlInt64 from int64", + dataMap: map[string]interface{}{ + "age": int64(42), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Age.Valid || m.Age.Int64() != 42 { + t.Errorf("expected age=42, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64()) + } + }, + }, + { + name: "SqlInt64 from string", + dataMap: map[string]interface{}{ + "age": "99", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Age.Valid || m.Age.Int64() != 99 { + t.Errorf("expected age=99, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64()) + } + }, + }, + { + name: "SqlFloat64 from float64", + dataMap: map[string]interface{}{ + "score": float64(98.5), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Score.Valid || m.Score.Float64() != 98.5 { + t.Errorf("expected score=98.5, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64()) + } + }, + }, + { + name: "SqlBool from bool", + dataMap: map[string]interface{}{ + "active": true, + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Active.Valid || !m.Active.Bool() { + t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool()) + } + }, + }, + { + name: "SqlUUID from string", + dataMap: map[string]interface{}{ + "uuid": testUUID.String(), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.UUID.Valid || m.UUID.UUID() != testUUID { + t.Errorf("expected uuid=%s, got valid=%v, value=%s", testUUID.String(), m.UUID.Valid, m.UUID.UUID().String()) + } + }, + }, + { + name: "SqlTimeStamp from 
time.Time", + dataMap: map[string]interface{}{ + "created_at": testTime, + }, + validator: func(t *testing.T, m *TestModel) { + if !m.CreatedAt.Valid { + t.Errorf("expected created_at to be valid") + } + // Check if times are close enough (within a second) + diff := m.CreatedAt.Time().Sub(testTime) + if diff < -time.Second || diff > time.Second { + t.Errorf("time difference too large: %v", diff) + } + }, + }, + { + name: "SqlTimeStamp from string", + dataMap: map[string]interface{}{ + "created_at": "2024-01-15T10:30:00", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.CreatedAt.Valid { + t.Errorf("expected created_at to be valid") + } + expected := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + if m.CreatedAt.Time().Year() != expected.Year() || + m.CreatedAt.Time().Month() != expected.Month() || + m.CreatedAt.Time().Day() != expected.Day() { + t.Errorf("expected date 2024-01-15, got %v", m.CreatedAt.Time()) + } + }, + }, + { + name: "SqlDate from string", + dataMap: map[string]interface{}{ + "birth_date": "2000-05-20", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.BirthDate.Valid { + t.Errorf("expected birth_date to be valid") + } + expected := "2000-05-20" + if m.BirthDate.String() != expected { + t.Errorf("expected date=%s, got %s", expected, m.BirthDate.String()) + } + }, + }, + { + name: "SqlTime from string", + dataMap: map[string]interface{}{ + "start_time": "14:30:00", + }, + validator: func(t *testing.T, m *TestModel) { + if !m.StartTime.Valid { + t.Errorf("expected start_time to be valid") + } + if m.StartTime.String() != "14:30:00" { + t.Errorf("expected time=14:30:00, got %s", m.StartTime.String()) + } + }, + }, + { + name: "SqlJSONB from map", + dataMap: map[string]interface{}{ + "metadata": map[string]interface{}{ + "key1": "value1", + "key2": 123, + }, + }, + validator: func(t *testing.T, m *TestModel) { + if len(m.Metadata) == 0 { + t.Errorf("expected metadata to have data") + } + asMap, err := m.Metadata.AsMap() + if 
err != nil { + t.Fatalf("failed to convert metadata to map: %v", err) + } + if asMap["key1"] != "value1" { + t.Errorf("expected key1=value1, got %v", asMap["key1"]) + } + }, + }, + { + name: "SqlJSONB from string", + dataMap: map[string]interface{}{ + "metadata": `{"test":"data"}`, + }, + validator: func(t *testing.T, m *TestModel) { + if len(m.Metadata) == 0 { + t.Errorf("expected metadata to have data") + } + asMap, err := m.Metadata.AsMap() + if err != nil { + t.Fatalf("failed to convert metadata to map: %v", err) + } + if asMap["test"] != "data" { + t.Errorf("expected test=data, got %v", asMap["test"]) + } + }, + }, + { + name: "SqlJSONB from []byte", + dataMap: map[string]interface{}{ + "metadata": []byte(`{"byte":"array"}`), + }, + validator: func(t *testing.T, m *TestModel) { + if len(m.Metadata) == 0 { + t.Errorf("expected metadata to have data") + } + if string(m.Metadata) != `{"byte":"array"}` { + t.Errorf("expected {\"byte\":\"array\"}, got %s", string(m.Metadata)) + } + }, + }, + { + name: "SqlInt16 from int16", + dataMap: map[string]interface{}{ + "count16": int16(100), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Count16.Valid || m.Count16.Int64() != 100 { + t.Errorf("expected count16=100, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64()) + } + }, + }, + { + name: "SqlInt32 from int32", + dataMap: map[string]interface{}{ + "count32": int32(5000), + }, + validator: func(t *testing.T, m *TestModel) { + if !m.Count32.Valid || m.Count32.Int64() != 5000 { + t.Errorf("expected count32=5000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64()) + } + }, + }, + { + name: "nil values create invalid nulls", + dataMap: map[string]interface{}{ + "name": nil, + "age": nil, + "active": nil, + "created_at": nil, + }, + validator: func(t *testing.T, m *TestModel) { + if m.Name.Valid { + t.Error("expected name to be invalid for nil value") + } + if m.Age.Valid { + t.Error("expected age to be invalid for nil value") + } + if 
m.Active.Valid { + t.Error("expected active to be invalid for nil value") + } + if m.CreatedAt.Valid { + t.Error("expected created_at to be invalid for nil value") + } + }, + }, + { + name: "all types together", + dataMap: map[string]interface{}{ + "id": int64(1), + "name": "Test User", + "age": int64(30), + "score": float64(95.7), + "active": true, + "uuid": testUUID.String(), + "created_at": "2024-01-15T10:30:00", + "birth_date": "1994-06-15", + "start_time": "09:00:00", + "metadata": map[string]interface{}{"role": "admin"}, + "count16": int16(50), + "count32": int32(1000), + }, + validator: func(t *testing.T, m *TestModel) { + if m.ID != 1 { + t.Errorf("expected id=1, got %d", m.ID) + } + if !m.Name.Valid || m.Name.String() != "Test User" { + t.Errorf("expected name='Test User', got valid=%v, value=%s", m.Name.Valid, m.Name.String()) + } + if !m.Age.Valid || m.Age.Int64() != 30 { + t.Errorf("expected age=30, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64()) + } + if !m.Score.Valid || m.Score.Float64() != 95.7 { + t.Errorf("expected score=95.7, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64()) + } + if !m.Active.Valid || !m.Active.Bool() { + t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool()) + } + if !m.UUID.Valid { + t.Error("expected uuid to be valid") + } + if !m.CreatedAt.Valid { + t.Error("expected created_at to be valid") + } + if !m.BirthDate.Valid || m.BirthDate.String() != "1994-06-15" { + t.Errorf("expected birth_date=1994-06-15, got valid=%v, value=%s", m.BirthDate.Valid, m.BirthDate.String()) + } + if !m.StartTime.Valid || m.StartTime.String() != "09:00:00" { + t.Errorf("expected start_time=09:00:00, got valid=%v, value=%s", m.StartTime.Valid, m.StartTime.String()) + } + if len(m.Metadata) == 0 { + t.Error("expected metadata to have data") + } + if !m.Count16.Valid || m.Count16.Int64() != 50 { + t.Errorf("expected count16=50, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64()) + } + if 
!m.Count32.Valid || m.Count32.Int64() != 1000 { + t.Errorf("expected count32=1000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64()) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model := &TestModel{} + if err := MapToStruct(tt.dataMap, model); err != nil { + t.Fatalf("MapToStruct failed: %v", err) + } + tt.validator(t, model) + }) + } +} + +// TestMapToStruct_PartialUpdate tests that partial updates preserve unset fields +func TestMapToStruct_PartialUpdate(t *testing.T) { + // Create initial model with some values + initial := &TestModel{ + ID: 1, + Name: spectypes.NewSqlString("Original Name"), + Age: spectypes.NewSqlInt64(25), + } + + // Update only the age field + partialUpdate := map[string]interface{}{ + "age": int64(30), + } + + // Apply partial update + if err := MapToStruct(partialUpdate, initial); err != nil { + t.Fatalf("MapToStruct failed: %v", err) + } + + // Verify age was updated + if !initial.Age.Valid || initial.Age.Int64() != 30 { + t.Errorf("expected age=30, got valid=%v, value=%d", initial.Age.Valid, initial.Age.Int64()) + } + + // Verify name was preserved (not overwritten with zero value) + if !initial.Name.Valid || initial.Name.String() != "Original Name" { + t.Errorf("expected name='Original Name' to be preserved, got valid=%v, value=%s", initial.Name.Valid, initial.Name.String()) + } + + // Verify ID was preserved + if initial.ID != 1 { + t.Errorf("expected id=1 to be preserved, got %d", initial.ID) + } +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index bf082e7..a8d32d0 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -701,97 +701,130 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url // Get the primary key name pkName := reflection.GetPrimaryKeyName(model) - // First, read the existing record from the database - existingRecord := 
reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := h.db.NewSelect().Model(existingRecord) + // Wrap in transaction to ensure BeforeUpdate hook is inside transaction + err := h.db.RunInTransaction(ctx, func(tx common.Database) error { + // First, read the existing record from the database + existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() + selectQuery := tx.NewSelect().Model(existingRecord).Column("*") - // Apply conditions to select - if urlID != "" { - logger.Debug("Updating by URL ID: %s", urlID) - selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) - } else if reqID != nil { - switch id := reqID.(type) { - case string: - logger.Debug("Updating by request ID: %s", id) - selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) - case []string: - if len(id) > 0 { - logger.Debug("Updating by multiple IDs: %v", id) - selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + // Apply conditions to select + if urlID != "" { + logger.Debug("Updating by URL ID: %s", urlID) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) + } else if reqID != nil { + switch id := reqID.(type) { + case string: + logger.Debug("Updating by request ID: %s", id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + case []string: + if len(id) > 0 { + logger.Debug("Updating by multiple IDs: %v", id) + selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + } } } - } - if err := selectQuery.ScanModel(ctx); err != nil { - if err == sql.ErrNoRows { - logger.Warn("No records found to update") - h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) - return - } - logger.Error("Error fetching existing record: %v", err) - h.sendError(w, http.StatusInternalServerError, 
"update_error", "Error fetching existing record", err) - return - } - - // Convert existing record to map - existingMap := make(map[string]interface{}) - jsonData, err := json.Marshal(existingRecord) - if err != nil { - logger.Error("Error marshaling existing record: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) - return - } - if err := json.Unmarshal(jsonData, &existingMap); err != nil { - logger.Error("Error unmarshaling existing record: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err) - return - } - - // Merge only non-null and non-empty values from the incoming request into the existing record - for key, newValue := range updates { - // Skip if the value is nil - if newValue == nil { - continue + if err := selectQuery.ScanModel(ctx); err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("no records found to update") + } + return fmt.Errorf("error fetching existing record: %w", err) } - // Skip if the value is an empty string - if strVal, ok := newValue.(string); ok && strVal == "" { - continue + // Convert existing record to map + existingMap := make(map[string]interface{}) + jsonData, err := json.Marshal(existingRecord) + if err != nil { + return fmt.Errorf("error marshaling existing record: %w", err) + } + if err := json.Unmarshal(jsonData, &existingMap); err != nil { + return fmt.Errorf("error unmarshaling existing record: %w", err) } - // Update the existing map with the new value - existingMap[key] = newValue - } - - // Build update query with merged data - query := h.db.NewUpdate().Table(tableName).SetMap(existingMap) - - // Apply conditions - if urlID != "" { - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) - } else if reqID != nil { - switch id := reqID.(type) { - case string: - query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) - case []string: - query = 
query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + // Execute BeforeUpdate hooks inside transaction + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Options: options, + ID: urlID, + Data: updates, + Writer: w, + Tx: tx, } - } - result, err := query.Exec(ctx) + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed: %w", err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + updates = modifiedData + } + + // Merge only non-null and non-empty values from the incoming request into the existing record + for key, newValue := range updates { + // Skip if the value is nil + if newValue == nil { + continue + } + + // Skip if the value is an empty string + if strVal, ok := newValue.(string); ok && strVal == "" { + continue + } + + // Update the existing map with the new value + existingMap[key] = newValue + } + + // Build update query with merged data + query := tx.NewUpdate().Table(tableName).SetMap(existingMap) + + // Apply conditions + if urlID != "" { + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID) + } else if reqID != nil { + switch id := reqID.(type) { + case string: + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) + case []string: + query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id) + } + } + + result, err := query.Exec(ctx) + if err != nil { + return fmt.Errorf("error updating record(s): %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("no records found to update") + } + + // Execute AfterUpdate hooks inside transaction + hookCtx.Result = updates + hookCtx.Error = nil + if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { + return fmt.Errorf("AfterUpdate hook failed: %w", err) + } + + return nil + }) + if err != nil { logger.Error("Update 
error: %v", err) - h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err) + if err.Error() == "no records found to update" { + h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", err) + } else { + h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err) + } return } - if result.RowsAffected() == 0 { - logger.Warn("No records found to update") - h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil) - return - } - - logger.Info("Successfully updated %d records", result.RowsAffected()) + logger.Info("Successfully updated record(s)") // Invalidate cache for this table cacheTags := buildCacheTags(schema, tableName) if err := invalidateCacheForTags(ctx, cacheTags); err != nil { @@ -849,9 +882,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range updates { if itemID, ok := item["id"]; ok { + itemIDStr := fmt.Sprintf("%v", itemID) + // First, read the existing record existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { continue // Skip if record not found @@ -869,6 +904,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url return fmt.Errorf("failed to unmarshal existing record: %w", err) } + // Execute BeforeUpdate hooks inside transaction + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Options: options, + ID: itemIDStr, + Data: item, + Writer: w, + Tx: tx, + } + + 
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + item = modifiedData + } + // Merge only non-null and non-empty values for key, newValue := range item { if newValue == nil { @@ -884,6 +942,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url if _, err := txQuery.Exec(ctx); err != nil { return err } + + // Execute AfterUpdate hooks inside transaction + hookCtx.Result = item + hookCtx.Error = nil + if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { + return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err) + } } } return nil @@ -957,9 +1022,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url for _, item := range updates { if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { + itemIDStr := fmt.Sprintf("%v", itemID) + // First, read the existing record existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) + selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { continue // Skip if record not found @@ -977,6 +1044,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url return fmt.Errorf("failed to unmarshal existing record: %w", err) } + // Execute BeforeUpdate hooks inside transaction + hookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + Model: model, + Options: options, + ID: itemIDStr, + Data: itemMap, + Writer: w, + Tx: tx, + } + + if err := 
h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + itemMap = modifiedData + } + // Merge only non-null and non-empty values for key, newValue := range itemMap { if newValue == nil { @@ -992,6 +1082,14 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url if _, err := txQuery.Exec(ctx); err != nil { return err } + + // Execute AfterUpdate hooks inside transaction + hookCtx.Result = itemMap + hookCtx.Error = nil + if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil { + return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err) + } + list = append(list, item) } } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 2f20ce1..5118073 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1110,30 +1110,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id logger.Info("Updating record in %s.%s", schema, entity) - // Execute BeforeUpdate hooks - hookCtx := &HookContext{ - Context: ctx, - Handler: h, - Schema: schema, - Entity: entity, - TableName: tableName, - Tx: h.db, - Model: model, - Options: options, - ID: id, - Data: data, - Writer: w, - } - - if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { - logger.Error("BeforeUpdate hook failed: %v", err) - h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) - return - } - - // Use potentially modified data from hook context - data = hookCtx.Data - // Convert data to map dataMap, ok := data.(map[string]interface{}) if !ok { @@ -1167,6 +1143,9 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // Variable to store the updated record var updatedRecord interface{} + // Declare hook context to be used inside and outside 
transaction + var hookCtx *HookContext + // Process nested relations if present err := h.db.RunInTransaction(ctx, func(tx common.Database) error { // Create temporary nested processor with transaction @@ -1174,7 +1153,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id // First, read the existing record from the database existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() - selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) + selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID) if err := selectQuery.ScanModel(ctx); err != nil { if err == sql.ErrNoRows { return fmt.Errorf("record not found with ID: %v", targetID) @@ -1204,6 +1183,30 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id nestedRelations = relations } + // Execute BeforeUpdate hooks inside transaction + hookCtx = &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Tx: tx, + Model: model, + Options: options, + ID: id, + Data: dataMap, + Writer: w, + } + + if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil { + return fmt.Errorf("BeforeUpdate hook failed: %w", err) + } + + // Use potentially modified data from hook context + if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok { + dataMap = modifiedData + } + // Merge only non-null and non-empty values from the incoming request into the existing record for key, newValue := range dataMap { // Skip if the value is nil From cf6a81e805164109d585b529990f54ecb7e700ff Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 13 Jan 2026 12:18:13 +0200 Subject: [PATCH 14/31] =?UTF-8?q?feat(reflection):=20=E2=9C=A8=20add=20tes?= =?UTF-8?q?ts=20for=20standard=20SQL=20null=20types?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit * Implement tests for mapping standard library sql.Null* types to struct. * Verify handling of valid and nil values for sql.NullInt64, sql.NullString, sql.NullFloat64, sql.NullBool, and sql.NullTime. * Ensure correct error handling and type conversion in MapToStruct function. --- pkg/reflection/model_utils.go | 18 +-- .../model_utils_stdlib_sqltypes_test.go | 120 ++++++++++++++++++ 2 files changed, 129 insertions(+), 9 deletions(-) create mode 100644 pkg/reflection/model_utils_stdlib_sqltypes_test.go diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index b07d4d2..58be9f8 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -1102,6 +1102,12 @@ func setFieldValue(field reflect.Value, value interface{}) error { } } + // If we can convert the type, do it + if valueReflect.Type().ConvertibleTo(field.Type()) { + field.Set(valueReflect.Convert(field.Type())) + return nil + } + // Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time]) if field.Kind() == reflect.Struct { @@ -1113,9 +1119,9 @@ func setFieldValue(field reflect.Value, value interface{}) error { // Call the Scan method with the value results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)}) if len(results) > 0 { - // Check if there was an error - if err, ok := results[0].Interface().(error); ok && err != nil { - return err + // The Scan method returns error - check if it's nil + if !results[0].IsNil() { + return results[0].Interface().(error) } return nil } @@ -1170,12 +1176,6 @@ func setFieldValue(field reflect.Value, value interface{}) error { } - // If we can convert the type, do it - if valueReflect.Type().ConvertibleTo(field.Type()) { - field.Set(valueReflect.Convert(field.Type())) - return nil - } - return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type()) } diff --git a/pkg/reflection/model_utils_stdlib_sqltypes_test.go 
b/pkg/reflection/model_utils_stdlib_sqltypes_test.go new file mode 100644 index 0000000..0c4891e --- /dev/null +++ b/pkg/reflection/model_utils_stdlib_sqltypes_test.go @@ -0,0 +1,120 @@ +package reflection_test + +import ( + "database/sql" + "testing" + "time" + + "github.com/bitechdev/ResolveSpec/pkg/reflection" +) + +func TestMapToStruct_StandardSqlNullTypes(t *testing.T) { + // Test model with standard library sql.Null* types + type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Age sql.NullInt64 `bun:"age" json:"age"` + Name sql.NullString `bun:"name" json:"name"` + Score sql.NullFloat64 `bun:"score" json:"score"` + Active sql.NullBool `bun:"active" json:"active"` + UpdatedAt sql.NullTime `bun:"updated_at" json:"updated_at"` + } + + now := time.Now() + dataMap := map[string]any{ + "id": int64(100), + "age": int64(25), + "name": "John Doe", + "score": 95.5, + "active": true, + "updated_at": now, + } + + var result TestModel + err := reflection.MapToStruct(dataMap, &result) + if err != nil { + t.Fatalf("MapToStruct() error = %v", err) + } + + // Verify ID + if result.ID != 100 { + t.Errorf("ID = %v, want 100", result.ID) + } + + // Verify Age (sql.NullInt64) + if !result.Age.Valid { + t.Error("Age.Valid = false, want true") + } + if result.Age.Int64 != 25 { + t.Errorf("Age.Int64 = %v, want 25", result.Age.Int64) + } + + // Verify Name (sql.NullString) + if !result.Name.Valid { + t.Error("Name.Valid = false, want true") + } + if result.Name.String != "John Doe" { + t.Errorf("Name.String = %v, want 'John Doe'", result.Name.String) + } + + // Verify Score (sql.NullFloat64) + if !result.Score.Valid { + t.Error("Score.Valid = false, want true") + } + if result.Score.Float64 != 95.5 { + t.Errorf("Score.Float64 = %v, want 95.5", result.Score.Float64) + } + + // Verify Active (sql.NullBool) + if !result.Active.Valid { + t.Error("Active.Valid = false, want true") + } + if !result.Active.Bool { + t.Error("Active.Bool = false, want true") + } + + // Verify 
UpdatedAt (sql.NullTime) + if !result.UpdatedAt.Valid { + t.Error("UpdatedAt.Valid = false, want true") + } + if !result.UpdatedAt.Time.Equal(now) { + t.Errorf("UpdatedAt.Time = %v, want %v", result.UpdatedAt.Time, now) + } + + t.Log("All standard library sql.Null* types handled correctly!") +} + +func TestMapToStruct_StandardSqlNullTypes_WithNil(t *testing.T) { + // Test nil handling for standard library sql.Null* types + type TestModel struct { + ID int64 `bun:"id,pk" json:"id"` + Age sql.NullInt64 `bun:"age" json:"age"` + Name sql.NullString `bun:"name" json:"name"` + } + + dataMap := map[string]any{ + "id": int64(200), + "age": int64(30), + "name": nil, // Explicitly nil + } + + var result TestModel + err := reflection.MapToStruct(dataMap, &result) + if err != nil { + t.Fatalf("MapToStruct() error = %v", err) + } + + // Age should be valid + if !result.Age.Valid { + t.Error("Age.Valid = false, want true") + } + if result.Age.Int64 != 30 { + t.Errorf("Age.Int64 = %v, want 30", result.Age.Int64) + } + + // Name should be invalid (null) + if result.Name.Valid { + t.Error("Name.Valid = true, want false (null)") + } + + t.Log("Nil handling for sql.Null* types works correctly!") +} From 276854768e2632b223791ef247e86bdbcb5a5153 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 13 Jan 2026 12:50:12 +0200 Subject: [PATCH 15/31] =?UTF-8?q?feat(dbmanager):=20=E2=9C=A8=20add=20supp?= =?UTF-8?q?ort=20for=20existing=20SQL=20connections?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduced NewConnectionFromDB function to create connections from existing *sql.DB instances. * Added ExistingDBProvider to wrap existing database connections for dbmanager features. * Implemented tests for NewConnectionFromDB and ExistingDBProvider functionalities. 
--- pkg/dbmanager/factory.go | 16 ++ pkg/dbmanager/factory_test.go | 210 ++++++++++++++++++++ pkg/dbmanager/providers/existing_db.go | 111 +++++++++++ pkg/dbmanager/providers/existing_db_test.go | 194 ++++++++++++++++++ 4 files changed, 531 insertions(+) create mode 100644 pkg/dbmanager/factory_test.go create mode 100644 pkg/dbmanager/providers/existing_db.go create mode 100644 pkg/dbmanager/providers/existing_db_test.go diff --git a/pkg/dbmanager/factory.go b/pkg/dbmanager/factory.go index 9a0efa2..0cbc950 100644 --- a/pkg/dbmanager/factory.go +++ b/pkg/dbmanager/factory.go @@ -1,6 +1,7 @@ package dbmanager import ( + "database/sql" "fmt" "github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers" @@ -49,3 +50,18 @@ func createProvider(dbType DatabaseType) (Provider, error) { // Provider is an alias to the providers.Provider interface // This allows dbmanager package consumers to use Provider without importing providers type Provider = providers.Provider + +// NewConnectionFromDB creates a new Connection from an existing *sql.DB +// This allows you to use dbmanager features (ORM wrappers, health checks, etc.) 
+// with a database connection that was opened outside of dbmanager +// +// Parameters: +// - name: A unique name for this connection +// - dbType: The database type (DatabaseTypePostgreSQL, DatabaseTypeSQLite, or DatabaseTypeMSSQL) +// - db: An existing *sql.DB connection +// +// Returns a Connection that wraps the existing *sql.DB +func NewConnectionFromDB(name string, dbType DatabaseType, db *sql.DB) Connection { + provider := providers.NewExistingDBProvider(db, name) + return newSQLConnection(name, dbType, ConnectionConfig{Name: name, Type: dbType}, provider) +} diff --git a/pkg/dbmanager/factory_test.go b/pkg/dbmanager/factory_test.go new file mode 100644 index 0000000..38c0312 --- /dev/null +++ b/pkg/dbmanager/factory_test.go @@ -0,0 +1,210 @@ +package dbmanager + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestNewConnectionFromDB(t *testing.T) { + // Open a SQLite in-memory database + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Create a connection from the existing database + conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db) + if conn == nil { + t.Fatal("Expected connection to be created") + } + + // Verify connection properties + if conn.Name() != "test-connection" { + t.Errorf("Expected name 'test-connection', got '%s'", conn.Name()) + } + + if conn.Type() != DatabaseTypeSQLite { + t.Errorf("Expected type DatabaseTypeSQLite, got '%s'", conn.Type()) + } +} + +func TestNewConnectionFromDB_Connect(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db) + ctx := context.Background() + + // Connect should verify the existing connection works + err = conn.Connect(ctx) + if err != nil { + t.Errorf("Expected Connect to succeed, got 
error: %v", err) + } + + // Cleanup + conn.Close() +} + +func TestNewConnectionFromDB_Native(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db) + ctx := context.Background() + + err = conn.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Get native DB + nativeDB, err := conn.Native() + if err != nil { + t.Errorf("Expected Native to succeed, got error: %v", err) + } + + if nativeDB != db { + t.Error("Expected Native to return the same database instance") + } +} + +func TestNewConnectionFromDB_Bun(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db) + ctx := context.Background() + + err = conn.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Get Bun ORM + bunDB, err := conn.Bun() + if err != nil { + t.Errorf("Expected Bun to succeed, got error: %v", err) + } + + if bunDB == nil { + t.Error("Expected Bun to return a non-nil instance") + } +} + +func TestNewConnectionFromDB_GORM(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db) + ctx := context.Background() + + err = conn.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Get GORM + gormDB, err := conn.GORM() + if err != nil { + t.Errorf("Expected GORM to succeed, got error: %v", err) + } + + if gormDB == nil { + t.Error("Expected GORM to return a non-nil instance") + } +} + +func TestNewConnectionFromDB_HealthCheck(t *testing.T) { + db, err := 
sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db) + ctx := context.Background() + + err = conn.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Health check should succeed + err = conn.HealthCheck(ctx) + if err != nil { + t.Errorf("Expected HealthCheck to succeed, got error: %v", err) + } +} + +func TestNewConnectionFromDB_Stats(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db) + ctx := context.Background() + + err = conn.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + stats := conn.Stats() + if stats == nil { + t.Fatal("Expected stats to be returned") + } + + if stats.Name != "test-connection" { + t.Errorf("Expected stats.Name to be 'test-connection', got '%s'", stats.Name) + } + + if stats.Type != DatabaseTypeSQLite { + t.Errorf("Expected stats.Type to be DatabaseTypeSQLite, got '%s'", stats.Type) + } + + if !stats.Connected { + t.Error("Expected stats.Connected to be true") + } +} + +func TestNewConnectionFromDB_PostgreSQL(t *testing.T) { + // This test just verifies the factory works with PostgreSQL type + // It won't actually connect since we're using SQLite + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + conn := NewConnectionFromDB("test-pg", DatabaseTypePostgreSQL, db) + if conn == nil { + t.Fatal("Expected connection to be created") + } + + if conn.Type() != DatabaseTypePostgreSQL { + t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type()) + } +} diff --git a/pkg/dbmanager/providers/existing_db.go 
b/pkg/dbmanager/providers/existing_db.go new file mode 100644 index 0000000..9b56a75 --- /dev/null +++ b/pkg/dbmanager/providers/existing_db.go @@ -0,0 +1,111 @@ +package providers + +import ( + "context" + "database/sql" + "fmt" + "sync" + + "go.mongodb.org/mongo-driver/mongo" +) + +// ExistingDBProvider wraps an existing *sql.DB connection +// This allows using dbmanager features with a database connection +// that was opened outside of the dbmanager package +type ExistingDBProvider struct { + db *sql.DB + name string + mu sync.RWMutex +} + +// NewExistingDBProvider creates a new provider wrapping an existing *sql.DB +func NewExistingDBProvider(db *sql.DB, name string) *ExistingDBProvider { + return &ExistingDBProvider{ + db: db, + name: name, + } +} + +// Connect verifies the existing database connection is valid +// It does NOT create a new connection, but ensures the existing one works +func (p *ExistingDBProvider) Connect(ctx context.Context, cfg ConnectionConfig) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.db == nil { + return fmt.Errorf("database connection is nil") + } + + // Verify the connection works + if err := p.db.PingContext(ctx); err != nil { + return fmt.Errorf("failed to ping existing database: %w", err) + } + + return nil +} + +// Close closes the underlying database connection +func (p *ExistingDBProvider) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.db == nil { + return nil + } + + return p.db.Close() +} + +// HealthCheck verifies the connection is alive +func (p *ExistingDBProvider) HealthCheck(ctx context.Context) error { + p.mu.RLock() + defer p.mu.RUnlock() + + if p.db == nil { + return fmt.Errorf("database connection is nil") + } + + return p.db.PingContext(ctx) +} + +// GetNative returns the wrapped *sql.DB +func (p *ExistingDBProvider) GetNative() (*sql.DB, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + if p.db == nil { + return nil, fmt.Errorf("database connection is nil") + } + + return p.db, nil +} + +// 
GetMongo returns an error since this is a SQL database +func (p *ExistingDBProvider) GetMongo() (*mongo.Client, error) { + return nil, ErrNotMongoDB +} + +// Stats returns connection statistics +func (p *ExistingDBProvider) Stats() *ConnectionStats { + p.mu.RLock() + defer p.mu.RUnlock() + + stats := &ConnectionStats{ + Name: p.name, + Type: "sql", // Generic since we don't know the specific type + Connected: p.db != nil, + } + + if p.db != nil { + dbStats := p.db.Stats() + stats.OpenConnections = dbStats.OpenConnections + stats.InUse = dbStats.InUse + stats.Idle = dbStats.Idle + stats.WaitCount = dbStats.WaitCount + stats.WaitDuration = dbStats.WaitDuration + stats.MaxIdleClosed = dbStats.MaxIdleClosed + stats.MaxLifetimeClosed = dbStats.MaxLifetimeClosed + } + + return stats +} diff --git a/pkg/dbmanager/providers/existing_db_test.go b/pkg/dbmanager/providers/existing_db_test.go new file mode 100644 index 0000000..d00e998 --- /dev/null +++ b/pkg/dbmanager/providers/existing_db_test.go @@ -0,0 +1,194 @@ +package providers + +import ( + "context" + "database/sql" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +func TestNewExistingDBProvider(t *testing.T) { + // Open a SQLite in-memory database + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Create provider + provider := NewExistingDBProvider(db, "test-db") + if provider == nil { + t.Fatal("Expected provider to be created") + } + + if provider.name != "test-db" { + t.Errorf("Expected name 'test-db', got '%s'", provider.name) + } +} + +func TestExistingDBProvider_Connect(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + provider := NewExistingDBProvider(db, "test-db") + ctx := context.Background() + + // Connect should verify the connection works + err = provider.Connect(ctx, nil) + if err != nil { + 
t.Errorf("Expected Connect to succeed, got error: %v", err) + } +} + +func TestExistingDBProvider_Connect_NilDB(t *testing.T) { + provider := NewExistingDBProvider(nil, "test-db") + ctx := context.Background() + + err := provider.Connect(ctx, nil) + if err == nil { + t.Error("Expected Connect to fail with nil database") + } +} + +func TestExistingDBProvider_GetNative(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + provider := NewExistingDBProvider(db, "test-db") + + nativeDB, err := provider.GetNative() + if err != nil { + t.Errorf("Expected GetNative to succeed, got error: %v", err) + } + + if nativeDB != db { + t.Error("Expected GetNative to return the same database instance") + } +} + +func TestExistingDBProvider_GetNative_NilDB(t *testing.T) { + provider := NewExistingDBProvider(nil, "test-db") + + _, err := provider.GetNative() + if err == nil { + t.Error("Expected GetNative to fail with nil database") + } +} + +func TestExistingDBProvider_HealthCheck(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + provider := NewExistingDBProvider(db, "test-db") + ctx := context.Background() + + err = provider.HealthCheck(ctx) + if err != nil { + t.Errorf("Expected HealthCheck to succeed, got error: %v", err) + } +} + +func TestExistingDBProvider_HealthCheck_ClosedDB(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + + provider := NewExistingDBProvider(db, "test-db") + + // Close the database + db.Close() + + ctx := context.Background() + err = provider.HealthCheck(ctx) + if err == nil { + t.Error("Expected HealthCheck to fail with closed database") + } +} + +func TestExistingDBProvider_GetMongo(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + 
t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + provider := NewExistingDBProvider(db, "test-db") + + _, err = provider.GetMongo() + if err != ErrNotMongoDB { + t.Errorf("Expected ErrNotMongoDB, got: %v", err) + } +} + +func TestExistingDBProvider_Stats(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Set some connection pool settings to test stats + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(time.Hour) + + provider := NewExistingDBProvider(db, "test-db") + + stats := provider.Stats() + if stats == nil { + t.Fatal("Expected stats to be returned") + } + + if stats.Name != "test-db" { + t.Errorf("Expected stats.Name to be 'test-db', got '%s'", stats.Name) + } + + if stats.Type != "sql" { + t.Errorf("Expected stats.Type to be 'sql', got '%s'", stats.Type) + } + + if !stats.Connected { + t.Error("Expected stats.Connected to be true") + } +} + +func TestExistingDBProvider_Close(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + + provider := NewExistingDBProvider(db, "test-db") + + err = provider.Close() + if err != nil { + t.Errorf("Expected Close to succeed, got error: %v", err) + } + + // Verify the database is closed + err = db.Ping() + if err == nil { + t.Error("Expected database to be closed") + } +} + +func TestExistingDBProvider_Close_NilDB(t *testing.T) { + provider := NewExistingDBProvider(nil, "test-db") + + err := provider.Close() + if err != nil { + t.Errorf("Expected Close to succeed with nil database, got error: %v", err) + } +} From a980201d215702c3e2ae7bba5374e4bdee23c773 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 13 Jan 2026 15:09:56 +0200 Subject: [PATCH 16/31] =?UTF-8?q?feat(spectypes):=20=E2=9C=A8=20enhance=20?= =?UTF-8?q?SqlNull=20to=20support=20float=20and=20int=20types?= MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add handling for float32 and float64 in Scan method. * Implement parsing for integer types in Scan and FromString methods. * Improve flexibility of SqlNull for various numeric inputs. --- pkg/spectypes/sql_types.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pkg/spectypes/sql_types.go b/pkg/spectypes/sql_types.go index 2d72ce0..43afc93 100644 --- a/pkg/spectypes/sql_types.go +++ b/pkg/spectypes/sql_types.go @@ -74,6 +74,10 @@ func (n *SqlNull[T]) Scan(value any) error { return n.FromString(v) case []byte: return n.FromString(string(v)) + case float32, float64: + return n.FromString(fmt.Sprintf("%f", value)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return n.FromString(fmt.Sprintf("%d", value)) default: return n.FromString(fmt.Sprintf("%v", value)) } @@ -94,6 +98,10 @@ func (n *SqlNull[T]) FromString(s string) error { reflect.ValueOf(&n.Val).Elem().SetInt(i) n.Valid = true } + if f, err := strconv.ParseFloat(s, 64); err == nil && !n.Valid { + reflect.ValueOf(&n.Val).Elem().SetInt(int64(f)) + n.Valid = true + } case float32, float64: if f, err := strconv.ParseFloat(s, 64); err == nil { reflect.ValueOf(&n.Val).Elem().SetFloat(f) From 292306b6087ffe1e418c6f956eb03a9cc70520f5 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 14 Jan 2026 10:00:13 +0200 Subject: [PATCH 17/31] fix: :art: fix recursive crud bugs --- pkg/common/recursive_crud.go | 153 +++++-- pkg/common/recursive_crud_test.go | 720 ++++++++++++++++++++++++++++++ pkg/reflection/helpers.go | 43 +- pkg/restheadspec/handler.go | 38 +- 4 files changed, 926 insertions(+), 28 deletions(-) create mode 100644 pkg/common/recursive_crud_test.go diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index 6e047fb..f6261c3 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -74,6 +74,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( } if modelType == nil || modelType.Kind() != 
reflect.Struct { + logger.Error("Invalid model type: operation=%s, table=%s, modelType=%v, expected struct", operation, tableName, modelType) return nil, fmt.Errorf("model must be a struct type, got %v", modelType) } @@ -103,44 +104,64 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( // Get the primary key name for this model pkName := reflection.GetPrimaryKeyName(model) + // Check if we have any data to process (besides _request) + hasData := len(regularData) > 0 + // Process based on operation switch strings.ToLower(operation) { case "insert", "create": - id, err := p.processInsert(ctx, regularData, tableName) - if err != nil { - return nil, fmt.Errorf("insert failed: %w", err) - } - result.ID = id - result.AffectedRows = 1 - result.Data = regularData + // Only perform insert if we have data to insert + if hasData { + id, err := p.processInsert(ctx, regularData, tableName) + if err != nil { + logger.Error("Insert failed for table=%s, data=%+v, error=%v", tableName, regularData, err) + return nil, fmt.Errorf("insert failed: %w", err) + } + result.ID = id + result.AffectedRows = 1 + result.Data = regularData - // Process child relations after parent insert (to get parent ID) - if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil { - return nil, fmt.Errorf("failed to process child relations: %w", err) + // Process child relations after parent insert (to get parent ID) + if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil { + logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err) + return nil, fmt.Errorf("failed to process child relations: %w", err) + } + } else { + logger.Debug("Skipping insert for %s - no data columns besides _request", tableName) } case "update": - rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName]) - if 
err != nil { - return nil, fmt.Errorf("update failed: %w", err) - } - result.ID = data[pkName] - result.AffectedRows = rows - result.Data = regularData + // Only perform update if we have data to update + if hasData { + rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName]) + if err != nil { + logger.Error("Update failed for table=%s, id=%v, data=%+v, error=%v", tableName, data[pkName], regularData, err) + return nil, fmt.Errorf("update failed: %w", err) + } + result.ID = data[pkName] + result.AffectedRows = rows + result.Data = regularData - // Process child relations for update - if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil { - return nil, fmt.Errorf("failed to process child relations: %w", err) + // Process child relations for update + if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil { + logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err) + return nil, fmt.Errorf("failed to process child relations: %w", err) + } + } else { + logger.Debug("Skipping update for %s - no data columns besides _request", tableName) + result.ID = data[pkName] } case "delete": // Process child relations first (for referential integrity) - if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil { + if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil { + logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err) return nil, fmt.Errorf("failed to process child relations before delete: %w", err) } rows, err := p.processDelete(ctx, tableName, data[pkName]) if err != nil { + 
logger.Error("Delete failed for table=%s, id=%v, error=%v", tableName, data[pkName], err) return nil, fmt.Errorf("delete failed: %w", err) } result.ID = data[pkName] @@ -148,6 +169,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( result.Data = regularData default: + logger.Error("Unsupported operation: %s for table=%s", operation, tableName) return nil, fmt.Errorf("unsupported operation: %s", operation) } @@ -213,6 +235,7 @@ func (p *NestedCUDProcessor) processInsert( result, err := query.Exec(ctx) if err != nil { + logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err) return nil, fmt.Errorf("insert exec failed: %w", err) } @@ -236,6 +259,7 @@ func (p *NestedCUDProcessor) processUpdate( id interface{}, ) (int64, error) { if id == nil { + logger.Error("Update requires an ID: table=%s, data=%+v", tableName, data) return 0, fmt.Errorf("update requires an ID") } @@ -245,6 +269,7 @@ func (p *NestedCUDProcessor) processUpdate( result, err := query.Exec(ctx) if err != nil { + logger.Error("Update execution failed: table=%s, id=%v, data=%+v, error=%v", tableName, id, data, err) return 0, fmt.Errorf("update exec failed: %w", err) } @@ -256,6 +281,7 @@ func (p *NestedCUDProcessor) processUpdate( // processDelete handles delete operation func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) { if id == nil { + logger.Error("Delete requires an ID: table=%s", tableName) return 0, fmt.Errorf("delete requires an ID") } @@ -265,6 +291,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string result, err := query.Exec(ctx) if err != nil { + logger.Error("Delete execution failed: table=%s, id=%v, error=%v", tableName, id, err) return 0, fmt.Errorf("delete exec failed: %w", err) } @@ -281,6 +308,7 @@ func (p *NestedCUDProcessor) processChildRelations( relationFields map[string]*RelationshipInfo, relationData map[string]interface{}, parentModelType 
reflect.Type, + incomingParentIDs map[string]interface{}, // IDs from all ancestors ) error { for relationName, relInfo := range relationFields { relationValue, exists := relationData[relationName] @@ -293,7 +321,7 @@ func (p *NestedCUDProcessor) processChildRelations( // Get the related model field, found := parentModelType.FieldByName(relInfo.FieldName) if !found { - logger.Warn("Field %s not found in model", relInfo.FieldName) + logger.Error("Field %s not found in model type %v for relation %s", relInfo.FieldName, parentModelType, relationName) continue } @@ -313,20 +341,77 @@ func (p *NestedCUDProcessor) processChildRelations( relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName) // Prepare parent IDs for foreign key injection + // Start by copying all incoming parent IDs (from ancestors) parentIDs := make(map[string]interface{}) - if relInfo.ForeignKey != "" { + for k, v := range incomingParentIDs { + parentIDs[k] = v + } + logger.Debug("Inherited %d parent IDs from ancestors: %+v", len(incomingParentIDs), incomingParentIDs) + + // Add the current parent's primary key to the parentIDs map + // This ensures nested children have access to all ancestor IDs + if parentID != nil && parentModelType != nil { + // Get the parent model's primary key field name + parentPKFieldName := reflection.GetPrimaryKeyName(parentModelType) + if parentPKFieldName != "" { + // Get the JSON name for the primary key field + parentPKJSONName := reflection.GetJSONNameForField(parentModelType, parentPKFieldName) + baseName := "" + if len(parentPKJSONName) > 1 { + baseName = parentPKJSONName + } else { + // Add parent's PK to the map using the base model name + baseName = strings.TrimSuffix(parentPKFieldName, "ID") + baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id") + if baseName == "" { + baseName = "parent" + } + } + + parentIDs[baseName] = parentID + logger.Debug("Added current parent PK to parentIDs map: %s=%v (from field %s)", baseName, parentID, 
parentPKFieldName) + } + } + + // Also add the foreign key reference if specified + if relInfo.ForeignKey != "" && parentID != nil { // Extract the base name from foreign key (e.g., "DepartmentID" -> "Department") baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID") baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id") - parentIDs[baseName] = parentID + // Only add if different from what we already added + if _, exists := parentIDs[baseName]; !exists { + parentIDs[baseName] = parentID + logger.Debug("Added foreign key to parentIDs map: %s=%v (from FK %s)", baseName, parentID, relInfo.ForeignKey) + } + } + + logger.Debug("Final parentIDs map for relation %s: %+v", relationName, parentIDs) + + // Determine which field name to use for setting parent ID in child data + // Priority: Use foreign key field name if specified + var foreignKeyFieldName string + if relInfo.ForeignKey != "" { + // Get the JSON name for the foreign key field in the child model + foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) + if foreignKeyFieldName == "" { + // Fallback to lowercase field name + foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey) + } + logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey) } // Process based on relation type and data structure switch v := relationValue.(type) { case map[string]interface{}: - // Single related object + // Single related object - directly set foreign key if specified + if parentID != nil && foreignKeyFieldName != "" { + v[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID) + } _, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName) if err != nil { + logger.Error("Failed to process single relation: name=%s, table=%s, operation=%s, parentID=%v, data=%+v, error=%v", + relationName, relatedTableName, operation, 
parentID, v, err) return fmt.Errorf("failed to process relation %s: %w", relationName, err) } @@ -334,24 +419,40 @@ func (p *NestedCUDProcessor) processChildRelations( // Multiple related objects for i, item := range v { if itemMap, ok := item.(map[string]interface{}); ok { + // Directly set foreign key if specified + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { + logger.Error("Failed to process relation array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v", + relationName, i, relatedTableName, operation, parentID, itemMap, err) return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err) } + } else { + logger.Warn("Relation array item is not a map: name=%s[%d], type=%T", relationName, i, item) } } case []map[string]interface{}: // Multiple related objects (typed slice) for i, itemMap := range v { + // Directly set foreign key if specified + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { + logger.Error("Failed to process relation typed array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v", + relationName, i, relatedTableName, operation, parentID, itemMap, err) return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err) } } default: - logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue) + logger.Error("Unsupported relation data type: name=%s, type=%T, value=%+v", relationName, relationValue, relationValue) } } diff --git 
a/pkg/common/recursive_crud_test.go b/pkg/common/recursive_crud_test.go new file mode 100644 index 0000000..9bda8bb --- /dev/null +++ b/pkg/common/recursive_crud_test.go @@ -0,0 +1,720 @@ +package common + +import ( + "context" + "reflect" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/reflection" +) + +// Mock Database for testing +type mockDatabase struct { + insertCalls []map[string]interface{} + updateCalls []map[string]interface{} + deleteCalls []interface{} + lastID int64 +} + +func newMockDatabase() *mockDatabase { + return &mockDatabase{ + insertCalls: make([]map[string]interface{}, 0), + updateCalls: make([]map[string]interface{}, 0), + deleteCalls: make([]interface{}, 0), + lastID: 1, + } +} + +func (m *mockDatabase) NewSelect() SelectQuery { return &mockSelectQuery{} } +func (m *mockDatabase) NewInsert() InsertQuery { return &mockInsertQuery{db: m} } +func (m *mockDatabase) NewUpdate() UpdateQuery { return &mockUpdateQuery{db: m} } +func (m *mockDatabase) NewDelete() DeleteQuery { return &mockDeleteQuery{db: m} } +func (m *mockDatabase) RunInTransaction(ctx context.Context, fn func(Database) error) error { + return fn(m) +} +func (m *mockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) { + return &mockResult{rowsAffected: 1}, nil +} +func (m *mockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + return nil +} +func (m *mockDatabase) BeginTx(ctx context.Context) (Database, error) { + return m, nil +} +func (m *mockDatabase) CommitTx(ctx context.Context) error { + return nil +} +func (m *mockDatabase) RollbackTx(ctx context.Context) error { + return nil +} +func (m *mockDatabase) GetUnderlyingDB() interface{} { + return nil +} + +// Mock SelectQuery +type mockSelectQuery struct{} + +func (m *mockSelectQuery) Model(model interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Table(name string) SelectQuery { return m } +func (m *mockSelectQuery) 
Column(columns ...string) SelectQuery { return m } +func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Where(condition string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Join(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m } +func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m } +func (m *mockSelectQuery) Order(order string) SelectQuery { return m } +func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Limit(n int) SelectQuery { return m } +func (m *mockSelectQuery) Offset(n int) SelectQuery { return m } +func (m *mockSelectQuery) Group(group string) SelectQuery { return m } +func (m *mockSelectQuery) Having(condition string, args ...interface{}) SelectQuery { return m } +func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { return nil } +func (m *mockSelectQuery) ScanModel(ctx context.Context) error { return nil } +func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { return 0, nil } +func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { return false, nil } + +// Mock InsertQuery +type mockInsertQuery struct { + db *mockDatabase + table string + values map[string]interface{} +} + +func (m *mockInsertQuery) Model(model interface{}) InsertQuery { return m } +func (m *mockInsertQuery) Table(name string) InsertQuery { + m.table = name + return m +} +func (m 
*mockInsertQuery) Value(column string, value interface{}) InsertQuery { + if m.values == nil { + m.values = make(map[string]interface{}) + } + m.values[column] = value + return m +} +func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m } +func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m } +func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) { + // Record the insert call + m.db.insertCalls = append(m.db.insertCalls, m.values) + m.db.lastID++ + return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil +} + +// Mock UpdateQuery +type mockUpdateQuery struct { + db *mockDatabase + table string + setValues map[string]interface{} +} + +func (m *mockUpdateQuery) Model(model interface{}) UpdateQuery { return m } +func (m *mockUpdateQuery) Table(name string) UpdateQuery { + m.table = name + return m +} +func (m *mockUpdateQuery) Set(column string, value interface{}) UpdateQuery { return m } +func (m *mockUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery { + m.setValues = values + return m +} +func (m *mockUpdateQuery) Where(condition string, args ...interface{}) UpdateQuery { return m } +func (m *mockUpdateQuery) Returning(columns ...string) UpdateQuery { return m } +func (m *mockUpdateQuery) Exec(ctx context.Context) (Result, error) { + // Record the update call + m.db.updateCalls = append(m.db.updateCalls, m.setValues) + return &mockResult{rowsAffected: 1}, nil +} + +// Mock DeleteQuery +type mockDeleteQuery struct { + db *mockDatabase + table string +} + +func (m *mockDeleteQuery) Model(model interface{}) DeleteQuery { return m } +func (m *mockDeleteQuery) Table(name string) DeleteQuery { + m.table = name + return m +} +func (m *mockDeleteQuery) Where(condition string, args ...interface{}) DeleteQuery { return m } +func (m *mockDeleteQuery) Exec(ctx context.Context) (Result, error) { + // Record the delete call + m.db.deleteCalls = append(m.db.deleteCalls, m.table) + return 
&mockResult{rowsAffected: 1}, nil +} + +// Mock Result +type mockResult struct { + lastID int64 + rowsAffected int64 +} + +func (m *mockResult) LastInsertId() (int64, error) { return m.lastID, nil } +func (m *mockResult) RowsAffected() int64 { return m.rowsAffected } + +// Mock ModelRegistry +type mockModelRegistry struct{} + +func (m *mockModelRegistry) GetModel(name string) (interface{}, error) { return nil, nil } +func (m *mockModelRegistry) GetModelByEntity(schema, entity string) (interface{}, error) { return nil, nil } +func (m *mockModelRegistry) RegisterModel(name string, model interface{}) error { return nil } +func (m *mockModelRegistry) GetAllModels() map[string]interface{} { return make(map[string]interface{}) } + +// Mock RelationshipInfoProvider +type mockRelationshipProvider struct { + relationships map[string]*RelationshipInfo +} + +func newMockRelationshipProvider() *mockRelationshipProvider { + return &mockRelationshipProvider{ + relationships: make(map[string]*RelationshipInfo), + } +} + +func (m *mockRelationshipProvider) GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo { + key := modelType.Name() + "." + relationName + return m.relationships[key] +} + +func (m *mockRelationshipProvider) RegisterRelation(modelTypeName, relationName string, info *RelationshipInfo) { + key := modelTypeName + "." 
+ relationName + m.relationships[key] = info +} + +// Test Models +type Department struct { + ID int64 `json:"id" bun:"id,pk"` + Name string `json:"name"` + Employees []*Employee `json:"employees,omitempty"` +} + +func (d Department) TableName() string { return "departments" } +func (d Department) GetIDName() string { return "ID" } + +type Employee struct { + ID int64 `json:"id" bun:"id,pk"` + Name string `json:"name"` + DepartmentID int64 `json:"department_id"` + Tasks []*Task `json:"tasks,omitempty"` +} + +func (e Employee) TableName() string { return "employees" } +func (e Employee) GetIDName() string { return "ID" } + +type Task struct { + ID int64 `json:"id" bun:"id,pk"` + Title string `json:"title"` + EmployeeID int64 `json:"employee_id"` + Comments []*Comment `json:"comments,omitempty"` +} + +func (t Task) TableName() string { return "tasks" } +func (t Task) GetIDName() string { return "ID" } + +type Comment struct { + ID int64 `json:"id" bun:"id,pk"` + Text string `json:"text"` + TaskID int64 `json:"task_id"` +} + +func (c Comment) TableName() string { return "comments" } +func (c Comment) GetIDName() string { return "ID" } + +// Test Cases + +func TestProcessNestedCUD_SingleLevelInsert(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + // Register Department -> Employees relationship + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "name": "John Doe", + }, + map[string]interface{}{ + "name": "Jane Smith", + }, + }, + } + + result, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + 
"departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + if result.ID == nil { + t.Error("Expected result.ID to be set") + } + + // Verify department was inserted + if len(db.insertCalls) != 3 { + t.Errorf("Expected 3 insert calls (1 dept + 2 employees), got %d", len(db.insertCalls)) + } + + // Verify first insert is department + if db.insertCalls[0]["name"] != "Engineering" { + t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"]) + } + + // Verify employees were inserted with foreign key + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have department_id set") + } + if db.insertCalls[2]["department_id"] == nil { + t.Error("Expected employee to have department_id set") + } +} + +func TestProcessNestedCUD_MultiLevelInsert(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + // Register relationships + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{ + FieldName: "Tasks", + JSONName: "tasks", + RelationType: "has_many", + ForeignKey: "EmployeeID", + RelatedModel: Task{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "name": "John Doe", + "tasks": []interface{}{ + map[string]interface{}{ + "title": "Task 1", + }, + map[string]interface{}{ + "title": "Task 2", + }, + }, + }, + }, + } + + result, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + if result.ID == nil { + 
t.Error("Expected result.ID to be set") + } + + // Verify: 1 dept + 1 employee + 2 tasks = 4 inserts + if len(db.insertCalls) != 4 { + t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls)) + } + + // Verify department + if db.insertCalls[0]["name"] != "Engineering" { + t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"]) + } + + // Verify employee has department_id + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have department_id set") + } + + // Verify tasks have employee_id + if db.insertCalls[2]["employee_id"] == nil { + t.Error("Expected task to have employee_id set") + } + if db.insertCalls[3]["employee_id"] == nil { + t.Error("Expected task to have employee_id set") + } +} + +func TestProcessNestedCUD_RequestFieldOverride(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "_request": "update", + "ID": int64(10), // Use capital ID to match struct field + "name": "John Updated", + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Verify department was inserted (1 insert) + // Employee should be updated (1 update) + if len(db.insertCalls) != 1 { + t.Errorf("Expected 1 insert call for department, got %d", len(db.insertCalls)) + } + + if len(db.updateCalls) != 1 { + t.Errorf("Expected 1 update call for employee, got %d", len(db.updateCalls)) + } + + 
// Verify update data + if db.updateCalls[0]["name"] != "John Updated" { + t.Errorf("Expected employee name 'John Updated', got %v", db.updateCalls[0]["name"]) + } +} + +func TestProcessNestedCUD_SkipInsertWhenOnlyRequestField(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + // Data with only _request field for nested employee + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "_request": "insert", + // No other fields besides _request + // Note: Foreign key will be injected, so employee WILL be inserted + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Department + Employee (with injected FK) = 2 inserts + if len(db.insertCalls) != 2 { + t.Errorf("Expected 2 insert calls (department + employee with FK), got %d", len(db.insertCalls)) + } + + if db.insertCalls[0]["name"] != "Engineering" { + t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"]) + } + + // Verify employee has foreign key + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have department_id injected") + } +} + +func TestProcessNestedCUD_Update(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: 
"DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "ID": int64(1), // Use capital ID to match struct field + "name": "Engineering Updated", + "employees": []interface{}{ + map[string]interface{}{ + "_request": "insert", + "name": "New Employee", + }, + }, + } + + result, err := processor.ProcessNestedCUD( + context.Background(), + "update", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + if result.ID != int64(1) { + t.Errorf("Expected result.ID to be 1, got %v", result.ID) + } + + // Verify department was updated + if len(db.updateCalls) != 1 { + t.Errorf("Expected 1 update call, got %d", len(db.updateCalls)) + } + + // Verify new employee was inserted + if len(db.insertCalls) != 1 { + t.Errorf("Expected 1 insert call for new employee, got %d", len(db.insertCalls)) + } +} + +func TestProcessNestedCUD_Delete(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "ID": int64(1), // Use capital ID to match struct field + "employees": []interface{}{ + map[string]interface{}{ + "_request": "delete", + "ID": int64(10), // Use capital ID + }, + map[string]interface{}{ + "_request": "delete", + "ID": int64(11), // Use capital ID + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "delete", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Verify employees were deleted first, then department + // 2 
employees + 1 department = 3 deletes + if len(db.deleteCalls) != 3 { + t.Errorf("Expected 3 delete calls, got %d", len(db.deleteCalls)) + } +} + +func TestProcessNestedCUD_ParentIDPropagation(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + // Register 3-level relationships + relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{ + FieldName: "Employees", + JSONName: "employees", + RelationType: "has_many", + ForeignKey: "DepartmentID", + RelatedModel: Employee{}, + }) + + relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{ + FieldName: "Tasks", + JSONName: "tasks", + RelationType: "has_many", + ForeignKey: "EmployeeID", + RelatedModel: Task{}, + }) + + relProvider.RegisterRelation("Task", "comments", &RelationshipInfo{ + FieldName: "Comments", + JSONName: "comments", + RelationType: "has_many", + ForeignKey: "TaskID", + RelatedModel: Comment{}, + }) + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "Engineering", + "employees": []interface{}{ + map[string]interface{}{ + "name": "John", + "tasks": []interface{}{ + map[string]interface{}{ + "title": "Task 1", + "comments": []interface{}{ + map[string]interface{}{ + "text": "Great work!", + }, + }, + }, + }, + }, + }, + } + + _, err := processor.ProcessNestedCUD( + context.Background(), + "insert", + data, + Department{}, + nil, + "departments", + ) + + if err != nil { + t.Fatalf("ProcessNestedCUD failed: %v", err) + } + + // Verify: 1 dept + 1 employee + 1 task + 1 comment = 4 inserts + if len(db.insertCalls) != 4 { + t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls)) + } + + // Verify department + if db.insertCalls[0]["name"] != "Engineering" { + t.Error("Expected department to be inserted first") + } + + // Verify employee has department_id + if db.insertCalls[1]["department_id"] == nil { + t.Error("Expected employee to have 
department_id") + } + + // Verify task has employee_id + if db.insertCalls[2]["employee_id"] == nil { + t.Error("Expected task to have employee_id") + } + + // Verify comment has task_id + if db.insertCalls[3]["task_id"] == nil { + t.Error("Expected comment to have task_id") + } +} + +func TestInjectForeignKeys(t *testing.T) { + db := newMockDatabase() + registry := &mockModelRegistry{} + relProvider := newMockRelationshipProvider() + + processor := NewNestedCUDProcessor(db, registry, relProvider) + + data := map[string]interface{}{ + "name": "John", + } + + parentIDs := map[string]interface{}{ + "department": int64(5), + } + + modelType := reflect.TypeOf(Employee{}) + + processor.injectForeignKeys(data, modelType, parentIDs) + + // Should inject department_id based on the "department" key in parentIDs + if data["department_id"] == nil { + t.Error("Expected department_id to be injected") + } + + if data["department_id"] != int64(5) { + t.Errorf("Expected department_id to be 5, got %v", data["department_id"]) + } +} + +func TestGetPrimaryKeyName(t *testing.T) { + dept := Department{} + pkName := reflection.GetPrimaryKeyName(dept) + + if pkName != "ID" { + t.Errorf("Expected primary key name 'ID', got '%s'", pkName) + } + + // Test with pointer + pkName2 := reflection.GetPrimaryKeyName(&dept) + if pkName2 != "ID" { + t.Errorf("Expected primary key name 'ID' from pointer, got '%s'", pkName2) + } +} diff --git a/pkg/reflection/helpers.go b/pkg/reflection/helpers.go index 155f30c..2ae1a88 100644 --- a/pkg/reflection/helpers.go +++ b/pkg/reflection/helpers.go @@ -1,6 +1,9 @@ package reflection -import "reflect" +import ( + "reflect" + "strings" +) func Len(v any) int { val := reflect.ValueOf(v) @@ -64,3 +67,41 @@ func GetPointerElement(v reflect.Type) reflect.Type { } return v } + +// GetJSONNameForField gets the JSON tag name for a struct field. +// Returns the JSON field name from the json struct tag, or an empty string if not found. 
+// Handles the "json" tag format: "name", "name,omitempty", etc. +func GetJSONNameForField(modelType reflect.Type, fieldName string) string { + if modelType == nil { + return "" + } + + // Handle pointer types + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + return "" + } + + // Find the field + field, found := modelType.FieldByName(fieldName) + if !found { + return "" + } + + // Get the JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + return "" + } + + // Parse the tag (format: "name,omitempty" or just "name") + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" && parts[0] != "-" { + return parts[0] + } + + return "" +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 5118073..c4c2a3d 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1794,10 +1794,36 @@ func (h *Handler) processChildRelationsForField( parentIDs[baseName] = parentID } + // Determine which field name to use for setting parent ID in child data + // Priority: Use foreign key field name if specified, otherwise use parent's PK name + var foreignKeyFieldName string + if relInfo.ForeignKey != "" { + // Get the JSON name for the foreign key field in the child model + foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) + if foreignKeyFieldName == "" { + // Fallback to lowercase field name + foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey) + } + } else { + // Fallback: use parent's primary key name + parentPKName := reflection.GetPrimaryKeyName(parentModelType) + foreignKeyFieldName = reflection.GetJSONNameForField(parentModelType, parentPKName) + if foreignKeyFieldName == "" { + foreignKeyFieldName = strings.ToLower(parentPKName) + } + } + + logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s", + foreignKeyFieldName, parentID, relInfo.ForeignKey) 
+ // Process based on relation type and data structure switch v := relationValue.(type) { case map[string]interface{}: - // Single related object + // Single related object - add parent ID to foreign key field + if parentID != nil && foreignKeyFieldName != "" { + v[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID) + } _, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName) if err != nil { return fmt.Errorf("failed to process single relation: %w", err) @@ -1807,6 +1833,11 @@ func (h *Handler) processChildRelationsForField( // Multiple related objects for i, item := range v { if itemMap, ok := item.(map[string]interface{}); ok { + // Add parent ID to foreign key field + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { return fmt.Errorf("failed to process relation item %d: %w", i, err) @@ -1817,6 +1848,11 @@ func (h *Handler) processChildRelationsForField( case []map[string]interface{}: // Multiple related objects (typed slice) for i, itemMap := range v { + // Add parent ID to foreign key field + if parentID != nil && foreignKeyFieldName != "" { + itemMap[foreignKeyFieldName] = parentID + logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { return fmt.Errorf("failed to process relation item %d: %w", i, err) From 7879272dda60f35013a24a8d071173870705ca80 Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 14 Jan 2026 10:19:04 +0200 Subject: [PATCH 18/31] =?UTF-8?q?fix(recursive=5Fcrud):=20=F0=9F=90=9B=20p?= 
=?UTF-8?q?revent=20overwriting=20primary=20key=20in=20recursive=20relatio?= =?UTF-8?q?nships?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Ensure foreign key assignment does not overwrite primary key in recursive relationships. * Added logging for skipped assignments to improve debugging. --- pkg/common/recursive_crud.go | 24 +++++++++++++++++++++--- pkg/restheadspec/handler.go | 26 +++++++++++++++++++++----- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index f6261c3..19bbb83 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -400,13 +400,25 @@ func (p *NestedCUDProcessor) processChildRelations( logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey) } + // Get the primary key name for the child model to avoid overwriting it in recursive relationships + childPKName := reflection.GetPrimaryKeyName(relatedModel) + childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName) + if childPKFieldName == "" { + childPKFieldName = strings.ToLower(childPKName) + } + + logger.Debug("Processing relation with foreignKeyField=%s, childPK=%s", foreignKeyFieldName, childPKFieldName) + // Process based on relation type and data structure switch v := relationValue.(type) { case map[string]interface{}: // Single related object - directly set foreign key if specified - if parentID != nil && foreignKeyFieldName != "" { + // IMPORTANT: In recursive relationships, don't overwrite the primary key + if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { v[foreignKeyFieldName] = parentID logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID) + } else if foreignKeyFieldName == childPKFieldName { + logger.Debug("Skipping foreign key assignment - same as primary key (recursive 
relationship): %s", foreignKeyFieldName) } _, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName) if err != nil { @@ -420,9 +432,12 @@ func (p *NestedCUDProcessor) processChildRelations( for i, item := range v { if itemMap, ok := item.(map[string]interface{}); ok { // Directly set foreign key if specified - if parentID != nil && foreignKeyFieldName != "" { + // IMPORTANT: In recursive relationships, don't overwrite the primary key + if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { itemMap[foreignKeyFieldName] = parentID logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } else if foreignKeyFieldName == childPKFieldName { + logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName) } _, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { @@ -439,9 +454,12 @@ func (p *NestedCUDProcessor) processChildRelations( // Multiple related objects (typed slice) for i, itemMap := range v { // Directly set foreign key if specified - if parentID != nil && foreignKeyFieldName != "" { + // IMPORTANT: In recursive relationships, don't overwrite the primary key + if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { itemMap[foreignKeyFieldName] = parentID logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } else if foreignKeyFieldName == childPKFieldName { + logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName) } _, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index c4c2a3d..1315bfa 100644 --- 
a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1813,16 +1813,26 @@ func (h *Handler) processChildRelationsForField( } } - logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s", - foreignKeyFieldName, parentID, relInfo.ForeignKey) + // Get the primary key name for the child model to avoid overwriting it in recursive relationships + childPKName := reflection.GetPrimaryKeyName(relatedModel) + childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName) + if childPKFieldName == "" { + childPKFieldName = strings.ToLower(childPKName) + } + + logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s, childPK=%s", + foreignKeyFieldName, parentID, relInfo.ForeignKey, childPKFieldName) // Process based on relation type and data structure switch v := relationValue.(type) { case map[string]interface{}: // Single related object - add parent ID to foreign key field - if parentID != nil && foreignKeyFieldName != "" { + // IMPORTANT: In recursive relationships, don't overwrite the primary key + if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { v[foreignKeyFieldName] = parentID logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID) + } else if foreignKeyFieldName == childPKFieldName { + logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName) } _, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName) if err != nil { @@ -1834,9 +1844,12 @@ func (h *Handler) processChildRelationsForField( for i, item := range v { if itemMap, ok := item.(map[string]interface{}); ok { // Add parent ID to foreign key field - if parentID != nil && foreignKeyFieldName != "" { + // IMPORTANT: In recursive relationships, don't overwrite the primary key + if parentID != nil && foreignKeyFieldName != "" && 
foreignKeyFieldName != childPKFieldName { itemMap[foreignKeyFieldName] = parentID logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } else if foreignKeyFieldName == childPKFieldName { + logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName) } _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { @@ -1849,9 +1862,12 @@ func (h *Handler) processChildRelationsForField( // Multiple related objects (typed slice) for i, itemMap := range v { // Add parent ID to foreign key field - if parentID != nil && foreignKeyFieldName != "" { + // IMPORTANT: In recursive relationships, don't overwrite the primary key + if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { itemMap[foreignKeyFieldName] = parentID logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID) + } else if foreignKeyFieldName == childPKFieldName { + logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName) } _, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName) if err != nil { From c75842ebb05f63e35ee597411fc54747ad14716c Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 14 Jan 2026 15:04:27 +0200 Subject: [PATCH 19/31] =?UTF-8?q?feat(dbmanager):=20=E2=9C=A8=20update=20h?= =?UTF-8?q?ealth=20check=20interval=20and=20add=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Change default health check interval from 30s to 15s. * Always start background health checks regardless of auto-reconnect setting. * Add tests for health checker functionality and default configurations. 
--- pkg/dbmanager/config.go | 7 +- pkg/dbmanager/manager.go | 5 +- pkg/dbmanager/manager_test.go | 226 ++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 3 deletions(-) create mode 100644 pkg/dbmanager/manager_test.go diff --git a/pkg/dbmanager/config.go b/pkg/dbmanager/config.go index 8f81259..7213eb6 100644 --- a/pkg/dbmanager/config.go +++ b/pkg/dbmanager/config.go @@ -128,7 +128,7 @@ func DefaultManagerConfig() ManagerConfig { RetryAttempts: 3, RetryDelay: 1 * time.Second, RetryMaxDelay: 10 * time.Second, - HealthCheckInterval: 30 * time.Second, + HealthCheckInterval: 15 * time.Second, EnableAutoReconnect: true, } } @@ -161,6 +161,11 @@ func (c *ManagerConfig) ApplyDefaults() { if c.HealthCheckInterval == 0 { c.HealthCheckInterval = defaults.HealthCheckInterval } + // EnableAutoReconnect defaults to true - apply if not explicitly set + // Since this is a boolean, we apply the default unconditionally when it's false + if !c.EnableAutoReconnect { + c.EnableAutoReconnect = defaults.EnableAutoReconnect + } } // Validate validates the manager configuration diff --git a/pkg/dbmanager/manager.go b/pkg/dbmanager/manager.go index 2b68cc2..7bcab48 100644 --- a/pkg/dbmanager/manager.go +++ b/pkg/dbmanager/manager.go @@ -219,9 +219,10 @@ func (m *connectionManager) Connect(ctx context.Context) error { logger.Info("Database connection established: name=%s, type=%s", name, connCfg.Type) } - // Start background health checks if enabled - if m.config.EnableAutoReconnect && m.config.HealthCheckInterval > 0 { + // Always start background health checks + if m.config.HealthCheckInterval > 0 { m.startHealthChecker() + logger.Info("Background health checker started: interval=%v", m.config.HealthCheckInterval) } logger.Info("Database manager initialized: connections=%d", len(m.connections)) diff --git a/pkg/dbmanager/manager_test.go b/pkg/dbmanager/manager_test.go new file mode 100644 index 0000000..3497690 --- /dev/null +++ b/pkg/dbmanager/manager_test.go @@ -0,0 
+1,226 @@ +package dbmanager + +import ( + "context" + "database/sql" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +func TestBackgroundHealthChecker(t *testing.T) { + // Create a SQLite in-memory database + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Create manager config with a short health check interval for testing + cfg := ManagerConfig{ + DefaultConnection: "test", + Connections: map[string]ConnectionConfig{ + "test": { + Name: "test", + Type: DatabaseTypeSQLite, + FilePath: ":memory:", + }, + }, + HealthCheckInterval: 1 * time.Second, // Short interval for testing + EnableAutoReconnect: true, + } + + // Create manager + mgr, err := NewManager(cfg) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + // Connect - this should start the background health checker + ctx := context.Background() + err = mgr.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer mgr.Close() + + // Get the connection to verify it's healthy + conn, err := mgr.Get("test") + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + // Verify initial health check + err = conn.HealthCheck(ctx) + if err != nil { + t.Errorf("Initial health check failed: %v", err) + } + + // Wait for a few health check cycles + time.Sleep(3 * time.Second) + + // Get stats to verify the connection is still healthy + stats := conn.Stats() + if stats == nil { + t.Fatal("Expected stats to be returned") + } + + if !stats.Connected { + t.Error("Expected connection to still be connected") + } + + if stats.HealthCheckStatus == "" { + t.Error("Expected health check status to be set") + } + + // Verify the manager has started the health checker + if cm, ok := mgr.(*connectionManager); ok { + if cm.healthTicker == nil { + t.Error("Expected health ticker to be running") + } + } +} + +func TestDefaultHealthCheckInterval(t *testing.T) { + // Verify 
the default health check interval is 15 seconds + defaults := DefaultManagerConfig() + + expectedInterval := 15 * time.Second + if defaults.HealthCheckInterval != expectedInterval { + t.Errorf("Expected default health check interval to be %v, got %v", + expectedInterval, defaults.HealthCheckInterval) + } + + if !defaults.EnableAutoReconnect { + t.Error("Expected EnableAutoReconnect to be true by default") + } +} + +func TestApplyDefaultsEnablesAutoReconnect(t *testing.T) { + // Create a config without setting EnableAutoReconnect + cfg := ManagerConfig{ + Connections: map[string]ConnectionConfig{ + "test": { + Name: "test", + Type: DatabaseTypeSQLite, + FilePath: ":memory:", + }, + }, + } + + // Verify it's false initially (Go's zero value for bool) + if cfg.EnableAutoReconnect { + t.Error("Expected EnableAutoReconnect to be false before ApplyDefaults") + } + + // Apply defaults + cfg.ApplyDefaults() + + // Verify it's now true + if !cfg.EnableAutoReconnect { + t.Error("Expected EnableAutoReconnect to be true after ApplyDefaults") + } + + // Verify health check interval is also set + if cfg.HealthCheckInterval != 15*time.Second { + t.Errorf("Expected health check interval to be 15s, got %v", cfg.HealthCheckInterval) + } +} + +func TestManagerHealthCheck(t *testing.T) { + // Create a SQLite in-memory database + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Create manager config + cfg := ManagerConfig{ + DefaultConnection: "test", + Connections: map[string]ConnectionConfig{ + "test": { + Name: "test", + Type: DatabaseTypeSQLite, + FilePath: ":memory:", + }, + }, + HealthCheckInterval: 15 * time.Second, + EnableAutoReconnect: true, + } + + // Create and connect manager + mgr, err := NewManager(cfg) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + ctx := context.Background() + err = mgr.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", 
err) + } + defer mgr.Close() + + // Perform health check on all connections + err = mgr.HealthCheck(ctx) + if err != nil { + t.Errorf("Health check failed: %v", err) + } + + // Get stats + stats := mgr.Stats() + if stats == nil { + t.Fatal("Expected stats to be returned") + } + + if stats.TotalConnections != 1 { + t.Errorf("Expected 1 total connection, got %d", stats.TotalConnections) + } + + if stats.HealthyCount != 1 { + t.Errorf("Expected 1 healthy connection, got %d", stats.HealthyCount) + } + + if stats.UnhealthyCount != 0 { + t.Errorf("Expected 0 unhealthy connections, got %d", stats.UnhealthyCount) + } +} + +func TestManagerStatsAfterClose(t *testing.T) { + cfg := ManagerConfig{ + DefaultConnection: "test", + Connections: map[string]ConnectionConfig{ + "test": { + Name: "test", + Type: DatabaseTypeSQLite, + FilePath: ":memory:", + }, + }, + HealthCheckInterval: 15 * time.Second, + } + + mgr, err := NewManager(cfg) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + ctx := context.Background() + err = mgr.Connect(ctx) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + + // Close the manager + err = mgr.Close() + if err != nil { + t.Errorf("Failed to close manager: %v", err) + } + + // Stats should show no connections + stats := mgr.Stats() + if stats.TotalConnections != 0 { + t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections) + } +} From 289cd7448507566ae539b2b1453c9797871e868e Mon Sep 17 00:00:00 2001 From: Hein Date: Wed, 14 Jan 2026 17:02:26 +0200 Subject: [PATCH 20/31] =?UTF-8?q?feat(database):=20=E2=9C=A8=20Enhance=20P?= =?UTF-8?q?reload=20and=20Join=20functionality?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduce skipAutoDetect flag to prevent circular calls in PreloadRelation. * Improve handling of long alias chains in PreloadRelation. * Ensure JoinRelation uses PreloadRelation without causing recursion. 
* Clear deferred preloads after execution to prevent re-execution. feat(recursive_crud): ✨ Filter valid fields in nested CUD processing * Add filterValidFields method to validate input data against model structure. * Use reflection to ensure only valid fields are processed. feat(reflection): ✨ Add utility to get valid JSON field names * Implement GetValidJSONFieldNames to retrieve valid JSON field names from model. * Enhance field validation during nested CUD operations. fix(handler): 🐛 Adjust recursive preload depth limit * Change recursive preload depth limit from 5 to 4 to prevent excessive recursion. --- go.mod | 1 - go.sum | 59 +-------------- pkg/common/adapters/database/bun.go | 63 +++++++++++----- pkg/common/recursive_crud.go | 113 ++++++++++++++++++++++++++++ pkg/reflection/model_utils.go | 57 ++++++++++++++ pkg/restheadspec/handler.go | 2 +- 6 files changed, 220 insertions(+), 75 deletions(-) diff --git a/go.mod b/go.mod index d815373..fc02b34 100644 --- a/go.mod +++ b/go.mod @@ -116,7 +116,6 @@ require ( github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect diff --git a/go.sum b/go.sum index ca74eb2..36dc515 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,6 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo= github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s= -github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= -github.com/glebarez/go-sqlite v1.21.2/go.mod 
h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= @@ -107,17 +105,15 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= -github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.6/go.mod 
h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -145,8 +141,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= -github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= @@ -164,8 +158,6 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= -github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= -github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -183,8 +175,6 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= -github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo= @@ -246,18 +236,12 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= -github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc= github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= -github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs= -github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= 
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= @@ -268,8 +252,6 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= -github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= @@ -278,8 +260,6 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= -github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= @@ -310,11 +290,9 @@ 
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3 github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= @@ -344,8 +322,6 @@ github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA= github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= -github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= @@ -381,16 +357,10 @@ go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOV go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= go.uber.org/goleak 
v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= -go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= -go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= -go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= @@ -407,12 +377,8 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= -golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod 
h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0= golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -421,8 +387,6 @@ golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -442,8 +406,6 @@ golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= -golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM= -golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -453,8 +415,6 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync 
v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -480,8 +440,6 @@ golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= @@ -499,9 +457,8 @@ golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= -golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= -golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term 
v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -516,8 +473,6 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= @@ -528,9 +483,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 
h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= @@ -541,8 +495,6 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc= google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -561,7 +513,6 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= gorm.io/driver/sqlserver v1.6.3 h1:UR+nWCuphPnq7UxnL57PSrlYjuvs+sf1N59GgFX7uAI= gorm.io/driver/sqlserver v1.6.3/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= @@ -579,8 +530,6 @@ modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= -modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ= -modernc.org/libc v1.67.0/go.mod 
h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA= modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg= modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= @@ -591,8 +540,6 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= -modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY= -modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE= modernc.org/sqlite v1.42.2 h1:7hkZUNJvJFN2PgfUdjni9Kbvd4ef4mNLOu0B9FGxM74= modernc.org/sqlite v1.42.2/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index 52ba33d..2d6b59b 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -211,6 +211,7 @@ type BunSelectQuery struct { deferredPreloads []deferredPreload // Preloads to execute as separate queries inJoinContext bool // Track if we're in a JOIN relation context joinTableAlias string // Alias to use for JOIN conditions + skipAutoDetect bool // Skip auto-detection to prevent circular calls } // deferredPreload represents a preload that will be executed as a separate query @@ -531,22 +532,25 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { // Auto-detect relationship type and choose optimal loading strategy // Get the model from the query if available - model := b.query.GetModel() - if model != nil 
&& model.Value() != nil { - relType := reflection.GetRelationType(model.Value(), relation) + // Skip auto-detection if flag is set (prevents circular calls from JoinRelation) + if !b.skipAutoDetect { + model := b.query.GetModel() + if model != nil && model.Value() != nil { + relType := reflection.GetRelationType(model.Value(), relation) - // Log the detected relationship type - logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType) + // Log the detected relationship type + logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType) - // If this is a belongs-to or has-one relation, use JOIN for better performance - if relType.ShouldUseJoin() { - logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation) - return b.JoinRelation(relation, apply...) - } + // If this is a belongs-to or has-one relation, use JOIN for better performance + if relType.ShouldUseJoin() { + logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation) + return b.JoinRelation(relation, apply...) + } - // For has-many, many-to-many, or unknown: use separate query (safer default) - if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany { - logger.Debug("Using separate query for %s relation '%s'", relType, relation) + // For has-many, many-to-many, or unknown: use separate query (safer default) + if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany { + logger.Debug("Using separate query for %s relation '%s'", relType, relation) + } } } @@ -559,7 +563,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S const safeAliasLimit = 35 // Leave room for column names // If the alias chain is too long, defer this preload to be executed as a separate query - if len(aliasChain) > safeAliasLimit { + if len(relationParts) > 1 && len(aliasChain) > safeAliasLimit { logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). 
"+ "Using separate query to avoid PostgreSQL %d-char identifier limit.", relation, aliasChain, len(aliasChain), postgresIdentifierLimit) @@ -683,6 +687,10 @@ func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.Sele // Use PreloadRelation with the wrapped functions // Bun's Relation() will use JOIN for belongs-to and has-one relations + // CRITICAL: Set skipAutoDetect flag to prevent circular call + // (PreloadRelation would detect belongs-to and call JoinRelation again) + b.skipAutoDetect = true + defer func() { b.skipAutoDetect = false }() return b.PreloadRelation(relation, wrappedApply...) } @@ -742,6 +750,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) logger.Warn("Failed to execute deferred preloads: %v", err) // Don't fail the whole query, just log the warning } + // Clear deferred preloads to prevent re-execution + b.deferredPreloads = nil } return nil @@ -810,6 +820,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { logger.Warn("Failed to execute deferred preloads: %v", err) // Don't fail the whole query, just log the warning } + // Clear deferred preloads to prevent re-execution + b.deferredPreloads = nil } return nil @@ -898,13 +910,30 @@ func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record return nil } - // Get the interface value to pass to Bun - parentValue := parentField.Interface() + // Get a pointer to the parent field so Bun can modify it + // CRITICAL: We need to pass a pointer, not a value, so that when Bun + // loads the child records and appends them to the slice, the changes + // are reflected in the original struct field. 
+ var parentPtr interface{} + if parentField.Kind() == reflect.Ptr { + // Field is already a pointer (e.g., Parent *Parent), use as-is + parentPtr = parentField.Interface() + } else { + // Field is a value (e.g., Comments []Comment), get its address + if parentField.CanAddr() { + parentPtr = parentField.Addr().Interface() + } else { + return fmt.Errorf("cannot get address of field '%s'", parentRelation) + } + } // Load the child relation on the parent record // This uses a shorter alias since we're only loading "Child", not "Parent.Child" + // CRITICAL: Use WherePK() to ensure we only load children for THIS specific parent + // record, not the first parent in the database table. return b.db.NewSelect(). - Model(parentValue). + Model(parentPtr). + WherePK(). Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery { // Apply any custom query modifications if len(apply) > 0 { diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index 19bbb83..caecdd8 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -98,6 +98,10 @@ func (p *NestedCUDProcessor) ProcessNestedCUD( } } + // Filter regularData to only include fields that exist in the model + // Use MapToStruct to validate and filter fields + regularData = p.filterValidFields(regularData, model) + // Inject parent IDs for foreign key resolution p.injectForeignKeys(regularData, modelType, parentIDs) @@ -187,6 +191,115 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str return "" } +// filterValidFields filters input data to only include fields that exist in the model +// Uses reflection.MapToStruct to validate fields and extract only those that match the model +func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} { + if len(data) == 0 { + return data + } + + // Create a new instance of the model to use with MapToStruct + modelType := reflect.TypeOf(model) + for 
modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return data + } + + // Create a new instance of the model + tempModel := reflect.New(modelType).Interface() + + // Use MapToStruct to map the data - this will only map valid fields + err := reflection.MapToStruct(data, tempModel) + if err != nil { + logger.Debug("Error mapping data to model: %v", err) + return data + } + + // Extract the mapped fields back into a map + // This effectively filters out any fields that don't exist in the model + filteredData := make(map[string]interface{}) + tempModelValue := reflect.ValueOf(tempModel).Elem() + + for key, value := range data { + // Check if the field was successfully mapped + if fieldWasMapped(tempModelValue, modelType, key) { + filteredData[key] = value + } else { + logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType) + } + } + + return filteredData +} + +// fieldWasMapped checks if a field with the given key was mapped to the model +func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool { + // Look for the field by JSON tag or field name + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Check JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag != "" && jsonTag != "-" { + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] == key { + return true + } + } + + // Check bun tag + bunTag := field.Tag.Get("bun") + if bunTag != "" && bunTag != "-" { + if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key { + return true + } + } + + // Check gorm tag + gormTag := field.Tag.Get("gorm") + if gormTag != "" && gormTag != "-" { + if colName := reflection.ExtractColumnFromGormTag(gormTag); colName 
== key { + return true + } + } + + // Check lowercase field name + if strings.EqualFold(field.Name, key) { + return true + } + + // Handle embedded structs recursively + if field.Anonymous { + fieldType := field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + if fieldType.Kind() == reflect.Struct { + embeddedValue := modelValue.Field(i) + if embeddedValue.Kind() == reflect.Ptr { + if embeddedValue.IsNil() { + continue + } + embeddedValue = embeddedValue.Elem() + } + if fieldWasMapped(embeddedValue, fieldType, key) { + return true + } + } + } + } + + return false +} + // injectForeignKeys injects parent IDs into data for foreign key fields func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) { if len(parentIDs) == 0 { diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 58be9f8..350fd72 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -1370,6 +1370,63 @@ func convertToFloat64(value interface{}) (float64, bool) { return 0, false } +// GetValidJSONFieldNames returns a map of valid JSON field names for a model +// This can be used to validate input data against a model's structure +// The map keys are the JSON field names (from json tags) that exist in the model +func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool { + validFields := make(map[string]bool) + + // Unwrap pointers to get to the base struct type + for modelType != nil && modelType.Kind() == reflect.Pointer { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return validFields + } + + collectValidFieldNames(modelType, validFields) + return validFields +} + +// collectValidFieldNames recursively collects valid JSON field names from a struct type +func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) { + for i := 0; i < typ.NumField(); i++ { + field := 
typ.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Check for embedded structs + if field.Anonymous { + fieldType := field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + if fieldType.Kind() == reflect.Struct { + // Recursively add fields from embedded struct + collectValidFieldNames(fieldType, validFields) + continue + } + } + + // Get the JSON tag name for this field (same logic as MapToStruct) + jsonTag := field.Tag.Get("json") + if jsonTag != "" && jsonTag != "-" { + // Extract the field name from the JSON tag (before any options like omitempty) + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" { + validFields[parts[0]] = true + } + } else { + // If no JSON tag, use the field name in lowercase as a fallback + validFields[strings.ToLower(field.Name)] = true + } + } +} + // getRelationModelSingleLevel gets the model type for a single level field (non-recursive) // This is a helper function used by GetRelationModel to handle one level at a time func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 1315bfa..f3dc8d8 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -883,7 +883,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co }) // Handle recursive preloading - if preload.Recursive && depth < 5 { + if preload.Recursive && depth < 4 { logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1) // For recursive relationships, we need to get the last part of the relation path From b87841a51c6a22df8632ec86efefcb969cb15a24 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 15 Jan 2026 14:07:45 +0200 Subject: [PATCH 21/31] =?UTF-8?q?feat(restheadspec):=20=E2=9C=A8=20Add=20c?= =?UTF-8?q?ustom=20SQL=20JOIN=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit - Introduced `x-custom-sql-join` header for custom SQL JOIN clauses. - Supports single and multiple JOINs, separated by `|`. - Enhanced query handling to apply custom JOINs directly. - Updated documentation to reflect new functionality. - Added tests for parsing custom SQL JOINs from query parameters and headers. --- pkg/restheadspec/HEADERS.md | 17 ++++++-- pkg/restheadspec/cache_helpers.go | 8 ++-- pkg/restheadspec/handler.go | 10 +++++ pkg/restheadspec/headers.go | 44 +++++++++++++++++++-- pkg/restheadspec/query_params_test.go | 56 +++++++++++++++++++++++++++ pkg/restheadspec/restheadspec.go | 1 + 6 files changed, 127 insertions(+), 9 deletions(-) diff --git a/pkg/restheadspec/HEADERS.md b/pkg/restheadspec/HEADERS.md index c422404..0149a0b 100644 --- a/pkg/restheadspec/HEADERS.md +++ b/pkg/restheadspec/HEADERS.md @@ -214,14 +214,25 @@ x-expand: department:id,name,code **Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation. #### `x-custom-sql-join` -Raw SQL JOIN statement. +Custom SQL JOIN clauses for joining tables in queries. -**Format:** SQL JOIN clause +**Format:** SQL JOIN clause or multiple clauses separated by `|` + +**Single JOIN:** ``` x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id ``` -⚠️ **Note:** Not yet fully implemented. 
+**Multiple JOINs:** +``` +x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id +``` + +**Features:** +- Supports any type of JOIN (INNER, LEFT, RIGHT, FULL, CROSS) +- Multiple JOINs can be specified using the pipe `|` separator +- JOINs are sanitized for security +- Can be specified via headers or query parameters --- diff --git a/pkg/restheadspec/cache_helpers.go b/pkg/restheadspec/cache_helpers.go index 094e435..1e81187 100644 --- a/pkg/restheadspec/cache_helpers.go +++ b/pkg/restheadspec/cache_helpers.go @@ -26,6 +26,7 @@ type queryCacheKey struct { Sort []common.SortOption `json:"sort"` CustomSQLWhere string `json:"custom_sql_where,omitempty"` CustomSQLOr string `json:"custom_sql_or,omitempty"` + CustomSQLJoin []string `json:"custom_sql_join,omitempty"` Expand []expandOptionKey `json:"expand,omitempty"` Distinct bool `json:"distinct,omitempty"` CursorForward string `json:"cursor_forward,omitempty"` @@ -40,7 +41,7 @@ type cachedTotal struct { // buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec) // Includes expand, distinct, and cursor pagination options func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, - customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string { + customWhere, customOr string, customJoin []string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string { key := queryCacheKey{ TableName: tableName, @@ -48,6 +49,7 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, Sort: sort, CustomSQLWhere: customWhere, CustomSQLOr: customOr, + CustomSQLJoin: customJoin, Distinct: distinct, CursorForward: cursorFwd, CursorBackward: cursorBwd, @@ -75,8 +77,8 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, jsonData, err := json.Marshal(key) if err != nil { // Fallback to simple 
string concatenation if JSON fails - return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s", - tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd)) + return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%v_%s_%s", + tableName, filters, sort, customWhere, customOr, customJoin, expandOpts, distinct, cursorFwd, cursorBwd)) } return hashString(string(jsonData)) diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index f3dc8d8..c176bcb 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -502,6 +502,15 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } } + // Apply custom SQL JOIN clauses + if len(options.CustomSQLJoin) > 0 { + for _, joinClause := range options.CustomSQLJoin { + logger.Debug("Applying custom SQL JOIN: %s", joinClause) + // Joins are already sanitized during parsing, so we can apply them directly + query = query.Join(joinClause) + } + } + // If ID is provided, filter by ID if id != "" { pkName := reflection.GetPrimaryKeyName(model) @@ -552,6 +561,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st options.Sort, options.CustomSQLWhere, options.CustomSQLOr, + options.CustomSQLJoin, expandOpts, options.Distinct, options.CursorForward, diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index bdae2bd..eb32fb8 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -26,7 +26,8 @@ type ExtendedRequestOptions struct { CustomSQLOr string // Joins - Expand []ExpandOption + Expand []ExpandOption + CustomSQLJoin []string // Custom SQL JOIN clauses // Advanced features AdvancedSQL map[string]string // Column -> SQL expression @@ -111,6 +112,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E AdvancedSQL: make(map[string]string), ComputedQL: make(map[string]string), Expand: make([]ExpandOption, 0), + CustomSQLJoin: make([]string, 0), 
ResponseFormat: "simple", // Default response format SingleRecordAsObject: true, // Default: normalize single-element arrays to objects } @@ -185,8 +187,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E case strings.HasPrefix(key, "x-expand"): h.parseExpand(&options, decodedValue) case strings.HasPrefix(key, "x-custom-sql-join"): - // TODO: Implement custom SQL join - logger.Debug("Custom SQL join not yet implemented: %s", decodedValue) + h.parseCustomSQLJoin(&options, decodedValue) // Sorting & Pagination case strings.HasPrefix(key, "x-sort"): @@ -495,6 +496,43 @@ func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) { } } +// parseCustomSQLJoin parses x-custom-sql-join header +// Format: Single JOIN clause or multiple JOIN clauses separated by | +// Example: "LEFT JOIN departments d ON d.id = employees.department_id" +// Example: "LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id" +func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value string) { + if value == "" { + return + } + + // Split by | for multiple joins + joins := strings.Split(value, "|") + for _, joinStr := range joins { + joinStr = strings.TrimSpace(joinStr) + if joinStr == "" { + continue + } + + // Basic validation: should contain "JOIN" keyword + upperJoin := strings.ToUpper(joinStr) + if !strings.Contains(upperJoin, "JOIN") { + logger.Warn("Invalid custom SQL join (missing JOIN keyword): %s", joinStr) + continue + } + + // Sanitize the join clause using common.SanitizeWhereClause + // Note: This is basic sanitization - in production you may want stricter validation + sanitizedJoin := common.SanitizeWhereClause(joinStr, "", nil) + if sanitizedJoin == "" { + logger.Warn("Custom SQL join failed sanitization: %s", joinStr) + continue + } + + logger.Debug("Adding custom SQL join: %s", sanitizedJoin) + options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin) + } +} + // 
parseSorting parses x-sort header // Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC) func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { diff --git a/pkg/restheadspec/query_params_test.go b/pkg/restheadspec/query_params_test.go index ac1beeb..5ea19ec 100644 --- a/pkg/restheadspec/query_params_test.go +++ b/pkg/restheadspec/query_params_test.go @@ -301,6 +301,62 @@ func TestParseOptionsFromQueryParams(t *testing.T) { } }, }, + { + name: "Parse custom SQL JOIN from query params", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.CustomSQLJoin) == 0 { + t.Error("Expected CustomSQLJoin to be set") + return + } + if len(options.CustomSQLJoin) != 1 { + t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin)) + return + } + expected := `LEFT JOIN departments d ON d.id = employees.department_id` + if options.CustomSQLJoin[0] != expected { + t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0]) + } + }, + }, + { + name: "Parse multiple custom SQL JOINs from query params", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.CustomSQLJoin) != 2 { + t.Errorf("Expected 2 custom SQL joins, got %d", len(options.CustomSQLJoin)) + return + } + expected1 := `LEFT JOIN departments d ON d.id = e.dept_id` + expected2 := `INNER JOIN roles r ON r.id = e.role_id` + if options.CustomSQLJoin[0] != expected1 { + t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected1, options.CustomSQLJoin[0]) + } + if options.CustomSQLJoin[1] != expected2 { + t.Errorf("Expected CustomSQLJoin[1]=%q, got %q", expected2, options.CustomSQLJoin[1]) + } + }, + }, + { + name: "Parse 
custom SQL JOIN from headers", + headers: map[string]string{ + "X-Custom-SQL-Join": `LEFT JOIN users u ON u.id = posts.user_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.CustomSQLJoin) == 0 { + t.Error("Expected CustomSQLJoin to be set from header") + return + } + expected := `LEFT JOIN users u ON u.id = posts.user_id` + if options.CustomSQLJoin[0] != expected { + t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0]) + } + }, + }, } for _, tt := range tests { diff --git a/pkg/restheadspec/restheadspec.go b/pkg/restheadspec/restheadspec.go index cfe0378..5a743bb 100644 --- a/pkg/restheadspec/restheadspec.go +++ b/pkg/restheadspec/restheadspec.go @@ -32,6 +32,7 @@ // - X-Clean-JSON: Boolean to remove null/empty fields // - X-Custom-SQL-Where: Custom SQL WHERE clause (AND) // - X-Custom-SQL-Or: Custom SQL WHERE clause (OR) +// - X-Custom-SQL-Join: Custom SQL JOIN clauses (pipe-separated for multiple) // // # Usage Example // From 24a7ef7284eab3eea7d902f33eb8f24d3ca2bed1 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 15 Jan 2026 14:18:25 +0200 Subject: [PATCH 22/31] =?UTF-8?q?feat(restheadspec):=20=E2=9C=A8=20Add=20s?= =?UTF-8?q?upport=20for=20join=20aliases=20in=20filters=20and=20sorts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract join aliases from custom SQL JOIN clauses. - Validate join aliases for filtering and sorting operations. - Update documentation to reflect new functionality. - Enhance tests for alias extraction and usage. 
--- pkg/common/sql_helpers.go | 8 ++ pkg/common/types.go | 4 + pkg/restheadspec/HEADERS.md | 21 ++++ pkg/restheadspec/headers.go | 59 ++++++++++ pkg/restheadspec/query_params_test.go | 150 ++++++++++++++++++++++++++ 5 files changed, 242 insertions(+) diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 0af6616..26d8053 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -166,6 +166,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation) } } + + // Add join aliases as allowed prefixes + for _, alias := range options[0].JoinAliases { + if alias != "" { + allowedPrefixes[alias] = true + logger.Debug("Added join alias '%s' as allowed table prefix", alias) + } + } } // Split by AND to handle multiple conditions diff --git a/pkg/common/types.go b/pkg/common/types.go index b09b3db..3e81ab9 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -23,6 +23,10 @@ type RequestOptions struct { CursorForward string `json:"cursor_forward"` CursorBackward string `json:"cursor_backward"` FetchRowNumber *string `json:"fetch_row_number"` + + // Join table aliases (used for validation of prefixed columns in filters/sorts) + // Not serialized to JSON as it's internal validation state + JoinAliases []string `json:"-"` } type Parameter struct { diff --git a/pkg/restheadspec/HEADERS.md b/pkg/restheadspec/HEADERS.md index 0149a0b..147f6ce 100644 --- a/pkg/restheadspec/HEADERS.md +++ b/pkg/restheadspec/HEADERS.md @@ -233,6 +233,27 @@ x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN role - Multiple JOINs can be specified using the pipe `|` separator - JOINs are sanitized for security - Can be specified via headers or query parameters +- **Table aliases are automatically extracted and allowed for filtering and sorting** + +**Using Join Aliases in Filters and Sorts:** + +When you specify a 
custom SQL join with an alias, you can use that alias in your filter and sort parameters: + +``` +# Join with alias +x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id + +# Sort by joined table column +x-sort: d.name,employees.id + +# Filter by joined table column +x-searchop-eq-d.name: Engineering +``` + +The system automatically: +1. Extracts the alias from the JOIN clause (e.g., `d` from `departments d`) +2. Validates that prefixed columns (like `d.name`) refer to valid join aliases +3. Allows these prefixed columns in filters and sorts --- diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index eb32fb8..ef51cbc 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -28,6 +28,7 @@ type ExtendedRequestOptions struct { // Joins Expand []ExpandOption CustomSQLJoin []string // Custom SQL JOIN clauses + JoinAliases []string // Extracted table aliases from CustomSQLJoin for validation // Advanced features AdvancedSQL map[string]string // Column -> SQL expression @@ -528,11 +529,69 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri continue } + // Extract table alias from the JOIN clause + alias := extractJoinAlias(sanitizedJoin) + if alias != "" { + options.JoinAliases = append(options.JoinAliases, alias) + // Also add to the embedded RequestOptions for validation + options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias) + logger.Debug("Extracted join alias: %s", alias) + } + logger.Debug("Adding custom SQL join: %s", sanitizedJoin) options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin) } } +// extractJoinAlias extracts the table alias from a JOIN clause +// Examples: +// - "LEFT JOIN departments d ON ..." -> "d" +// - "INNER JOIN users AS u ON ..." -> "u" +// - "JOIN roles r ON ..." -> "r" +func extractJoinAlias(joinClause string) string { + // Pattern: JOIN table_name [AS] alias ON ... 
+ // We need to extract the alias (word before ON) + + upperJoin := strings.ToUpper(joinClause) + + // Find the "JOIN" keyword position + joinIdx := strings.Index(upperJoin, "JOIN") + if joinIdx == -1 { + return "" + } + + // Find the "ON" keyword position + onIdx := strings.Index(upperJoin, " ON ") + if onIdx == -1 { + return "" + } + + // Extract the part between JOIN and ON + betweenJoinAndOn := strings.TrimSpace(joinClause[joinIdx+4 : onIdx]) + + // Split by spaces to get words + words := strings.Fields(betweenJoinAndOn) + if len(words) == 0 { + return "" + } + + // If there's an AS keyword, the alias is after it + for i, word := range words { + if strings.EqualFold(word, "AS") && i+1 < len(words) { + return words[i+1] + } + } + + // Otherwise, the alias is the last word (if there are 2+ words) + // Format: "table_name alias" or just "table_name" + if len(words) >= 2 { + return words[len(words)-1] + } + + // Only one word means it's just the table name, no alias + return "" +} + // parseSorting parses x-sort header // Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC) func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { diff --git a/pkg/restheadspec/query_params_test.go b/pkg/restheadspec/query_params_test.go index 5ea19ec..9b37768 100644 --- a/pkg/restheadspec/query_params_test.go +++ b/pkg/restheadspec/query_params_test.go @@ -357,6 +357,107 @@ func TestParseOptionsFromQueryParams(t *testing.T) { } }, }, + { + name: "Extract aliases from custom SQL JOIN", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.JoinAliases) == 0 { + t.Error("Expected JoinAliases to be extracted") + return + } + if len(options.JoinAliases) != 1 { + t.Errorf("Expected 1 join alias, got %d", len(options.JoinAliases)) + return + } + if options.JoinAliases[0] != "d" { + t.Errorf("Expected 
join alias 'd', got %q", options.JoinAliases[0]) + } + // Also check that it's in the embedded RequestOptions + if len(options.RequestOptions.JoinAliases) != 1 || options.RequestOptions.JoinAliases[0] != "d" { + t.Error("Expected join alias to also be in RequestOptions.JoinAliases") + } + }, + }, + { + name: "Extract multiple aliases from multiple custom SQL JOINs", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles AS r ON r.id = e.role_id`, + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + if len(options.JoinAliases) != 2 { + t.Errorf("Expected 2 join aliases, got %d", len(options.JoinAliases)) + return + } + expectedAliases := []string{"d", "r"} + for i, expected := range expectedAliases { + if options.JoinAliases[i] != expected { + t.Errorf("Expected join alias[%d]=%q, got %q", i, expected, options.JoinAliases[i]) + } + } + }, + }, + { + name: "Custom JOIN with sort on joined table", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + "x-sort": "d.name,employees.id", + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + // Verify join was added + if len(options.CustomSQLJoin) != 1 { + t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin)) + return + } + // Verify alias was extracted + if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" { + t.Error("Expected join alias 'd' to be extracted") + return + } + // Verify sort was parsed + if len(options.Sort) != 2 { + t.Errorf("Expected 2 sort options, got %d", len(options.Sort)) + return + } + if options.Sort[0].Column != "d.name" { + t.Errorf("Expected first sort column 'd.name', got %q", options.Sort[0].Column) + } + if options.Sort[1].Column != "employees.id" { + t.Errorf("Expected second sort column 'employees.id', got %q", options.Sort[1].Column) + } + }, + }, + { + name: "Custom JOIN with filter on 
joined table", + queryParams: map[string]string{ + "x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`, + "x-searchop-eq-d.name": "Engineering", + }, + validate: func(t *testing.T, options ExtendedRequestOptions) { + // Verify join was added + if len(options.CustomSQLJoin) != 1 { + t.Error("Expected 1 custom SQL join") + return + } + // Verify alias was extracted + if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" { + t.Error("Expected join alias 'd' to be extracted") + return + } + // Verify filter was parsed + if len(options.Filters) != 1 { + t.Errorf("Expected 1 filter, got %d", len(options.Filters)) + return + } + if options.Filters[0].Column != "d.name" { + t.Errorf("Expected filter column 'd.name', got %q", options.Filters[0].Column) + } + if options.Filters[0].Operator != "eq" { + t.Errorf("Expected filter operator 'eq', got %q", options.Filters[0].Operator) + } + }, + }, } for _, tt := range tests { @@ -451,6 +552,55 @@ func TestHeadersAndQueryParamsCombined(t *testing.T) { } } +// TestCustomJoinAliasExtraction tests the extractJoinAlias helper function +func TestCustomJoinAliasExtraction(t *testing.T) { + tests := []struct { + name string + join string + expected string + }{ + { + name: "LEFT JOIN with alias", + join: "LEFT JOIN departments d ON d.id = employees.department_id", + expected: "d", + }, + { + name: "INNER JOIN with AS keyword", + join: "INNER JOIN users AS u ON u.id = posts.user_id", + expected: "u", + }, + { + name: "Simple JOIN with alias", + join: "JOIN roles r ON r.id = user_roles.role_id", + expected: "r", + }, + { + name: "JOIN without alias (just table name)", + join: "JOIN departments ON departments.id = employees.dept_id", + expected: "", + }, + { + name: "RIGHT JOIN with alias", + join: "RIGHT JOIN orders o ON o.customer_id = customers.id", + expected: "o", + }, + { + name: "FULL OUTER JOIN with AS", + join: "FULL OUTER JOIN products AS p ON p.id = order_items.product_id", + expected: "p", 
+ }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractJoinAlias(tt.join) + if result != tt.expected { + t.Errorf("extractJoinAlias(%q) = %q, want %q", tt.join, result, tt.expected) + } + }) + } +} + // Helper function to check if a string contains a substring func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr)) From c12c045db1a7929ad7c3b7e78cf0d9dbe9139852 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 15 Jan 2026 14:43:11 +0200 Subject: [PATCH 23/31] =?UTF-8?q?feat(validation):=20=E2=9C=A8=20Clear=20J?= =?UTF-8?q?oinAliases=20in=20FilterRequestOptions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implemented logic to clear JoinAliases after filtering. * Added unit test to verify JoinAliases is nil post-filtering. * Ensured other fields are correctly filtered. --- pkg/common/validation.go | 33 +++++++++++++++++++++++++-------- pkg/common/validation_test.go | 23 +++++++++++++++++++++++ 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/pkg/common/validation.go b/pkg/common/validation.go index a1ac064..653a869 100644 --- a/pkg/common/validation.go +++ b/pkg/common/validation.go @@ -237,15 +237,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp for _, sort := range options.Sort { if v.IsValidColumn(sort.Column) { validSorts = append(validSorts, sort) - } else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { - // Allow sort by expression/subquery, but validate for security - if IsSafeSortExpression(sort.Column) { - validSorts = append(validSorts, sort) - } else { - logger.Warn("Unsafe sort expression '%s' removed", sort.Column) - } } else { - logger.Warn("Invalid column in sort '%s' removed", sort.Column) + foundJoin := false + for _, j := range options.JoinAliases { + if strings.Contains(sort.Column, j) { + foundJoin 
= true + break + } + } + if foundJoin { + validSorts = append(validSorts, sort) + continue + } + if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { + // Allow sort by expression/subquery, but validate for security + if IsSafeSortExpression(sort.Column) { + validSorts = append(validSorts, sort) + } else { + logger.Warn("Unsafe sort expression '%s' removed", sort.Column) + } + + } else { + logger.Warn("Invalid column in sort '%s' removed", sort.Column) + } } } filtered.Sort = validSorts @@ -291,6 +305,9 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp } filtered.Preload = validPreloads + // Clear JoinAliases - this is an internal validation field and should not be persisted + filtered.JoinAliases = nil + return filtered } diff --git a/pkg/common/validation_test.go b/pkg/common/validation_test.go index 1e56070..813a192 100644 --- a/pkg/common/validation_test.go +++ b/pkg/common/validation_test.go @@ -362,6 +362,29 @@ func TestFilterRequestOptions(t *testing.T) { } } +func TestFilterRequestOptions_ClearsJoinAliases(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + options := RequestOptions{ + Columns: []string{"id", "name"}, + // Set JoinAliases - this should be cleared by FilterRequestOptions + JoinAliases: []string{"d", "u", "r"}, + } + + filtered := validator.FilterRequestOptions(options) + + // Verify that JoinAliases was cleared (internal field should not persist) + if filtered.JoinAliases != nil { + t.Errorf("Expected JoinAliases to be nil after filtering, got %v", filtered.JoinAliases) + } + + // Verify that other fields are still properly filtered + if len(filtered.Columns) != 2 { + t.Errorf("Expected 2 columns, got %d", len(filtered.Columns)) + } +} + func TestIsSafeSortExpression(t *testing.T) { tests := []struct { name string From 09f22568996d96afbbde8d4354c9328da8173d12 Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 26 Jan 2026 09:14:17 +0200 Subject: [PATCH 
24/31] =?UTF-8?q?feat(sql):=20=E2=9C=A8=20Enhance=20SQL=20?= =?UTF-8?q?clause=20handling=20with=20parentheses?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add EnsureOuterParentheses function to wrap clauses in parentheses. * Implement logic to preserve outer parentheses for OR conditions. * Update SanitizeWhereClause to utilize new function for better query safety. * Introduce tests for EnsureOuterParentheses and containsTopLevelOR functions. * Refactor filter application in handler to group OR filters correctly. --- pkg/common/sql_helpers.go | 114 +++++++++++++++++++++- pkg/common/sql_helpers_test.go | 173 +++++++++++++++++++++++++++++++++ pkg/resolvespec/handler.go | 4 + pkg/restheadspec/handler.go | 135 ++++++++++++++++++++++++- 4 files changed, 420 insertions(+), 6 deletions(-) diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 26d8053..3093d49 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -130,6 +130,9 @@ func validateWhereClauseSecurity(where string) error { // Note: This function will NOT add prefixes to unprefixed columns. It will only fix // incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the // prefix matches a preloaded relation name, in which case it's left unchanged. +// +// IMPORTANT: Outer parentheses are preserved if the clause contains top-level OR operators +// to prevent OR logic from escaping and affecting the entire query incorrectly. 
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string { if where == "" { return "" @@ -143,8 +146,19 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti return "" } - // Strip outer parentheses and re-trim - where = stripOuterParentheses(where) + // Check if the original clause has outer parentheses and contains OR operators + // If so, we need to preserve the outer parentheses to prevent OR logic from escaping + hasOuterParens := false + if len(where) > 0 && where[0] == '(' && where[len(where)-1] == ')' { + _, hasOuterParens = stripOneMatchingOuterParen(where) + } + + // Strip outer parentheses and re-trim for processing + whereWithoutParens := stripOuterParentheses(where) + shouldPreserveParens := hasOuterParens && containsTopLevelOR(whereWithoutParens) + + // Use the stripped version for processing + where = whereWithoutParens // Get valid columns from the model if tableName is provided var validColumns map[string]bool @@ -229,7 +243,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti result := strings.Join(validConditions, " AND ") - if result != where { + // If the original clause had outer parentheses and contains OR operators, + // restore the outer parentheses to prevent OR logic from escaping + if shouldPreserveParens { + result = "(" + result + ")" + logger.Debug("Preserved outer parentheses for OR conditions: '%s'", result) + } + + if result != where && !shouldPreserveParens { logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result) } @@ -290,6 +311,93 @@ func stripOneMatchingOuterParen(s string) (string, bool) { return strings.TrimSpace(s[1 : len(s)-1]), true } +// EnsureOuterParentheses ensures that a SQL clause is wrapped in parentheses +// to prevent OR logic from escaping. It checks if the clause already has +// matching outer parentheses and only adds them if they don't exist. 
+// +// This is particularly important for OR conditions and complex filters where +// the absence of parentheses could cause the logic to escape and affect +// the entire query incorrectly. +// +// Parameters: +// - clause: The SQL clause to check and potentially wrap +// +// Returns: +// - The clause with guaranteed outer parentheses, or empty string if input is empty +func EnsureOuterParentheses(clause string) string { + if clause == "" { + return "" + } + + clause = strings.TrimSpace(clause) + if clause == "" { + return "" + } + + // Check if the clause already has matching outer parentheses + _, hasOuterParens := stripOneMatchingOuterParen(clause) + + // If it already has matching outer parentheses, return as-is + if hasOuterParens { + return clause + } + + // Otherwise, wrap it in parentheses + return "(" + clause + ")" +} + +// containsTopLevelOR checks if a SQL clause contains OR operators at the top level +// (i.e., not inside parentheses or subqueries). This is used to determine if +// outer parentheses should be preserved to prevent OR logic from escaping. 
+func containsTopLevelOR(clause string) bool { + if clause == "" { + return false + } + + depth := 0 + inSingleQuote := false + inDoubleQuote := false + lowerClause := strings.ToLower(clause) + + for i := 0; i < len(clause); i++ { + ch := clause[i] + + // Track quote state + if ch == '\'' && !inDoubleQuote { + inSingleQuote = !inSingleQuote + continue + } + if ch == '"' && !inSingleQuote { + inDoubleQuote = !inDoubleQuote + continue + } + + // Skip if inside quotes + if inSingleQuote || inDoubleQuote { + continue + } + + // Track parenthesis depth + switch ch { + case '(': + depth++ + case ')': + depth-- + } + + // Only check for OR at depth 0 (not inside parentheses) + if depth == 0 && i+4 <= len(clause) { + // Check for " OR " (case-insensitive) + substring := lowerClause[i : i+4] + if substring == " or " { + return true + } + } + } + + return false +} + // splitByAND splits a WHERE clause by AND operators (case-insensitive) // This is parenthesis-aware and won't split on AND operators inside subqueries func splitByAND(where string) []string { diff --git a/pkg/common/sql_helpers_test.go b/pkg/common/sql_helpers_test.go index 6f2a4ca..acfd831 100644 --- a/pkg/common/sql_helpers_test.go +++ b/pkg/common/sql_helpers_test.go @@ -659,6 +659,179 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) { } } +func TestEnsureOuterParentheses(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "no parentheses", + input: "status = 'active'", + expected: "(status = 'active')", + }, + { + name: "already has outer parentheses", + input: "(status = 'active')", + expected: "(status = 'active')", + }, + { + name: "OR condition without parentheses", + input: "status = 'active' OR status = 'pending'", + expected: "(status = 'active' OR status = 'pending')", + }, + { + name: "OR condition with parentheses", + input: "(status = 'active' OR status = 'pending')", + expected: "(status = 'active' OR status = 'pending')", + }, + { + name: 
"complex condition with nested parentheses", + input: "(status = 'active' OR status = 'pending') AND (age > 18)", + expected: "((status = 'active' OR status = 'pending') AND (age > 18))", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "whitespace only", + input: " ", + expected: "", + }, + { + name: "mismatched parentheses - adds outer ones", + input: "(status = 'active' OR status = 'pending'", + expected: "((status = 'active' OR status = 'pending')", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := EnsureOuterParentheses(tt.input) + if result != tt.expected { + t.Errorf("EnsureOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestContainsTopLevelOR(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + { + name: "no OR operator", + input: "status = 'active' AND age > 18", + expected: false, + }, + { + name: "top-level OR", + input: "status = 'active' OR status = 'pending'", + expected: true, + }, + { + name: "OR inside parentheses", + input: "age > 18 AND (status = 'active' OR status = 'pending')", + expected: false, + }, + { + name: "OR in subquery", + input: "id IN (SELECT id FROM users WHERE status = 'active' OR status = 'pending')", + expected: false, + }, + { + name: "OR inside quotes", + input: "comment = 'this OR that'", + expected: false, + }, + { + name: "mixed - top-level OR and nested OR", + input: "name = 'test' OR (status = 'active' OR status = 'pending')", + expected: true, + }, + { + name: "empty string", + input: "", + expected: false, + }, + { + name: "lowercase or", + input: "status = 'active' or status = 'pending'", + expected: true, + }, + { + name: "uppercase OR", + input: "status = 'active' OR status = 'pending'", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := containsTopLevelOR(tt.input) + if result != tt.expected { + 
t.Errorf("containsTopLevelOR(%q) = %v; want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestSanitizeWhereClause_PreservesParenthesesWithOR(t *testing.T) { + tests := []struct { + name string + where string + tableName string + expected string + }{ + { + name: "OR condition with outer parentheses - preserved", + where: "(status = 'active' OR status = 'pending')", + tableName: "users", + expected: "(users.status = 'active' OR users.status = 'pending')", + }, + { + name: "AND condition with outer parentheses - stripped (no OR)", + where: "(status = 'active' AND age > 18)", + tableName: "users", + expected: "users.status = 'active' AND users.age > 18", + }, + { + name: "complex OR with nested conditions", + where: "((status = 'active' OR status = 'pending') AND age > 18)", + tableName: "users", + // Outer parens are stripped, but inner parens with OR are preserved + expected: "(users.status = 'active' OR users.status = 'pending') AND users.age > 18", + }, + { + name: "OR without outer parentheses - no parentheses added by SanitizeWhereClause", + where: "status = 'active' OR status = 'pending'", + tableName: "users", + expected: "users.status = 'active' OR users.status = 'pending'", + }, + { + name: "simple OR with parentheses - preserved", + where: "(users.status = 'active' OR users.status = 'pending')", + tableName: "users", + // Already has correct prefixes, parentheses preserved + expected: "(users.status = 'active' OR users.status = 'pending')", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName) + result := SanitizeWhereClause(prefixedWhere, tt.tableName) + if result != tt.expected { + t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected) + } + }) + } +} + func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) { tests := []struct { name string diff --git a/pkg/resolvespec/handler.go 
b/pkg/resolvespec/handler.go index a8d32d0..33a1cb6 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -318,6 +318,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st if cursorFilter != "" { logger.Debug("Applying cursor filter: %s", cursorFilter) sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor) if sanitizedCursor != "" { query = query.Where(sanitizedCursor) } @@ -1656,6 +1658,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // Build RequestOptions with all preloads to allow references to sibling relations preloadOpts := &common.RequestOptions{Preload: preloads} sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index c176bcb..d5f0aa1 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -463,7 +463,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } // Apply filters - validate and adjust for column types first - for i := range options.Filters { + // Group consecutive OR filters together to prevent OR logic from escaping + for i := 0; i < len(options.Filters); { filter := &options.Filters[i] // Validate and adjust filter based on column type @@ -475,8 +476,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st logicOp = "AND" } - logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, 
logicOp) - query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp) + // Check if this is the start of an OR group + if logicOp == "OR" { + // Collect all consecutive OR filters + orFilters := []*common.FilterOption{filter} + orCastInfo := []ColumnCastInfo{castInfo} + + j := i + 1 + for j < len(options.Filters) { + nextFilter := &options.Filters[j] + nextLogicOp := nextFilter.LogicOperator + if nextLogicOp == "" { + nextLogicOp = "AND" + } + if nextLogicOp == "OR" { + nextCastInfo := h.ValidateAndAdjustFilterForColumnType(nextFilter, model) + orFilters = append(orFilters, nextFilter) + orCastInfo = append(orCastInfo, nextCastInfo) + j++ + } else { + break + } + } + + // Apply the OR group as a single grouped condition + logger.Debug("Applying OR filter group with %d conditions", len(orFilters)) + query = h.applyOrFilterGroup(query, orFilters, orCastInfo, tableName) + i = j + } else { + // Single AND filter - apply normally + logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp) + query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp) + i++ + } } // Apply custom SQL WHERE clause (AND condition) @@ -486,6 +518,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName)) // Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere) if sanitizedWhere != "" { query = query.Where(sanitizedWhere) } @@ -497,6 +531,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st customOr := 
common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName)) // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions) + // Ensure outer parentheses to prevent OR logic from escaping + sanitizedOr = common.EnsureOuterParentheses(sanitizedOr) if sanitizedOr != "" { query = query.WhereOr(sanitizedOr) } @@ -1996,6 +2032,99 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti } } +// applyOrFilterGroup applies a group of OR filters as a single grouped condition +// This ensures OR conditions are properly grouped with parentheses to prevent OR logic from escaping +func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common.FilterOption, castInfo []ColumnCastInfo, tableName string) common.SelectQuery { + if len(filters) == 0 { + return query + } + + // Build individual filter conditions + conditions := []string{} + args := []interface{}{} + + for i, filter := range filters { + // Qualify the column name with table name if not already qualified + qualifiedColumn := h.qualifyColumnName(filter.Column, tableName) + + // Apply casting to text if needed for non-numeric columns or non-numeric values + if castInfo[i].NeedsCast { + qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn) + } + + // Build the condition based on operator + condition, filterArgs := h.buildFilterCondition(qualifiedColumn, filter, tableName) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, filterArgs...) 
+ } + } + + if len(conditions) == 0 { + return query + } + + // Join all conditions with OR and wrap in parentheses + groupedCondition := "(" + strings.Join(conditions, " OR ") + ")" + logger.Debug("Applying grouped OR conditions: %s", groupedCondition) + + // Apply as AND condition (the OR is already inside the parentheses) + return query.Where(groupedCondition, args...) +} + +// buildFilterCondition builds a single filter condition and returns the condition string and args +func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) { + switch strings.ToLower(filter.Operator) { + case "eq", "equals": + return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value} + case "neq", "not_equals", "ne": + return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value} + case "gt", "greater_than": + return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value} + case "gte", "greater_than_equals", "ge": + return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value} + case "lt", "less_than": + return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value} + case "lte", "less_than_equals", "le": + return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value} + case "like": + return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value} + case "ilike": + return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value} + case "in": + return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value} + case "between": + // Handle between operator - exclusive (> val1 AND < val2) + if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { + return fmt.Sprintf("(%s > ? 
AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { + return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } + logger.Warn("Invalid BETWEEN filter value format") + return "", nil + case "between_inclusive": + // Handle between inclusive operator - inclusive (>= val1 AND <= val2) + if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { + return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { + return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]} + } + logger.Warn("Invalid BETWEEN INCLUSIVE filter value format") + return "", nil + case "is_null", "isnull": + // Check for NULL values - don't use cast for NULL checks + colName := h.qualifyColumnName(filter.Column, tableName) + return fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName), nil + case "is_not_null", "isnotnull": + // Check for NOT NULL values - don't use cast for NULL checks + colName := h.qualifyColumnName(filter.Column, tableName) + return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName), nil + default: + logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator) + return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value} + } +} + // parseTableName splits a table name that may contain schema into separate schema and table func (h *Handler) parseTableName(fullTableName string) (schema, table string) { if idx := strings.LastIndex(fullTableName, "."); idx != -1 { From 07016d1b739b6fd338dfa0ab84cc51d21f7cf1ea Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 26 Jan 2026 11:06:16 +0200 Subject: [PATCH 25/31] =?UTF-8?q?feat(config):=20=E2=9C=A8=20Update=20time?= 
=?UTF-8?q?out=20settings=20for=20connections?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Set default query timeout to 2 minutes and enforce minimum. * Add statement_timeout for PostgreSQL DSN. * Implement busy timeout for SQLite with a minimum of 2 minutes. * Enforce minimum connection timeouts of 10 minutes for server instance. --- pkg/dbmanager/config.go | 44 ++++++++++++++++++++++++++++--- pkg/dbmanager/providers/sqlite.go | 8 ++++-- pkg/server/manager.go | 19 ++++++++++--- 3 files changed, 61 insertions(+), 10 deletions(-) diff --git a/pkg/dbmanager/config.go b/pkg/dbmanager/config.go index 7213eb6..e690827 100644 --- a/pkg/dbmanager/config.go +++ b/pkg/dbmanager/config.go @@ -221,7 +221,10 @@ func (cc *ConnectionConfig) ApplyDefaults(global *ManagerConfig) { cc.ConnectTimeout = 10 * time.Second } if cc.QueryTimeout == 0 { - cc.QueryTimeout = 30 * time.Second + cc.QueryTimeout = 2 * time.Minute // Default to 2 minutes + } else if cc.QueryTimeout < 2*time.Minute { + // Enforce minimum of 2 minutes + cc.QueryTimeout = 2 * time.Minute } // Default ORM @@ -325,14 +328,29 @@ func (cc *ConnectionConfig) buildPostgresDSN() string { dsn += fmt.Sprintf(" search_path=%s", cc.Schema) } + // Add statement_timeout for query execution timeout (in milliseconds) + if cc.QueryTimeout > 0 { + timeoutMs := int(cc.QueryTimeout.Milliseconds()) + dsn += fmt.Sprintf(" statement_timeout=%d", timeoutMs) + } + return dsn } func (cc *ConnectionConfig) buildSQLiteDSN() string { - if cc.FilePath != "" { - return cc.FilePath + filepath := cc.FilePath + if filepath == "" { + filepath = ":memory:" } - return ":memory:" + + // Add query parameters for timeouts + // Note: SQLite driver supports _timeout parameter (in milliseconds) + if cc.QueryTimeout > 0 { + timeoutMs := int(cc.QueryTimeout.Milliseconds()) + filepath += fmt.Sprintf("?_timeout=%d", timeoutMs) + } + + return filepath } func (cc *ConnectionConfig) buildMSSQLDSN() string { @@ 
-344,6 +362,24 @@ func (cc *ConnectionConfig) buildMSSQLDSN() string { dsn += fmt.Sprintf("&schema=%s", cc.Schema) } + // Add connection timeout (in seconds) + if cc.ConnectTimeout > 0 { + timeoutSec := int(cc.ConnectTimeout.Seconds()) + dsn += fmt.Sprintf("&connection timeout=%d", timeoutSec) + } + + // Add dial timeout for TCP connection (in seconds) + if cc.ConnectTimeout > 0 { + dialTimeoutSec := int(cc.ConnectTimeout.Seconds()) + dsn += fmt.Sprintf("&dial timeout=%d", dialTimeoutSec) + } + + // Add read timeout (in seconds) - enforces timeout for reading data + if cc.QueryTimeout > 0 { + readTimeoutSec := int(cc.QueryTimeout.Seconds()) + dsn += fmt.Sprintf("&read timeout=%d", readTimeoutSec) + } + return dsn } diff --git a/pkg/dbmanager/providers/sqlite.go b/pkg/dbmanager/providers/sqlite.go index 3ce5c99..6f70970 100644 --- a/pkg/dbmanager/providers/sqlite.go +++ b/pkg/dbmanager/providers/sqlite.go @@ -76,8 +76,12 @@ func (p *SQLiteProvider) Connect(ctx context.Context, cfg ConnectionConfig) erro // Don't fail connection if WAL mode cannot be enabled } - // Set busy timeout to handle locked database - _, err = db.ExecContext(ctx, "PRAGMA busy_timeout=5000") + // Set busy timeout to handle locked database (minimum 2 minutes = 120000ms) + busyTimeout := cfg.GetQueryTimeout().Milliseconds() + if busyTimeout < 120000 { + busyTimeout = 120000 // Enforce minimum of 2 minutes + } + _, err = db.ExecContext(ctx, fmt.Sprintf("PRAGMA busy_timeout=%d", busyTimeout)) if err != nil { if cfg.GetEnableLogging() { logger.Warn("Failed to set busy timeout for SQLite", "error", err) diff --git a/pkg/server/manager.go b/pkg/server/manager.go index f1b7877..61f16dc 100644 --- a/pkg/server/manager.go +++ b/pkg/server/manager.go @@ -411,7 +411,9 @@ func newInstance(cfg Config) (*serverInstance, error) { return nil, fmt.Errorf("handler cannot be nil") } - // Set default timeouts + // Set default timeouts with minimum of 10 minutes for connection timeouts + minConnectionTimeout := 10 
* time.Minute + if cfg.ShutdownTimeout == 0 { cfg.ShutdownTimeout = 30 * time.Second } @@ -419,13 +421,22 @@ func newInstance(cfg Config) (*serverInstance, error) { cfg.DrainTimeout = 25 * time.Second } if cfg.ReadTimeout == 0 { - cfg.ReadTimeout = 15 * time.Second + cfg.ReadTimeout = minConnectionTimeout + } else if cfg.ReadTimeout < minConnectionTimeout { + // Enforce minimum of 10 minutes + cfg.ReadTimeout = minConnectionTimeout } if cfg.WriteTimeout == 0 { - cfg.WriteTimeout = 15 * time.Second + cfg.WriteTimeout = minConnectionTimeout + } else if cfg.WriteTimeout < minConnectionTimeout { + // Enforce minimum of 10 minutes + cfg.WriteTimeout = minConnectionTimeout } if cfg.IdleTimeout == 0 { - cfg.IdleTimeout = 60 * time.Second + cfg.IdleTimeout = minConnectionTimeout + } else if cfg.IdleTimeout < minConnectionTimeout { + // Enforce minimum of 10 minutes + cfg.IdleTimeout = minConnectionTimeout } addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) From f7725340a6651a60bd40c941b72f50c9dffe612c Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 27 Jan 2026 17:33:50 +0200 Subject: [PATCH 26/31] =?UTF-8?q?feat(sql):=20=E2=9C=A8=20Add=20base64=20e?= =?UTF-8?q?ncoding/decoding=20for=20SqlByteArray?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement base64 handling in SqlNull for []byte types. * Add tests for SqlString and SqlByteArray with base64 encoding. * Ensure proper JSON marshaling and unmarshaling for new types. 
--- pkg/spectypes/sql_types.go | 81 ++++++- pkg/spectypes/sql_types_test.go | 391 ++++++++++++++++++++++++++++++++ 2 files changed, 463 insertions(+), 9 deletions(-) diff --git a/pkg/spectypes/sql_types.go b/pkg/spectypes/sql_types.go index 43afc93..1d008d1 100644 --- a/pkg/spectypes/sql_types.go +++ b/pkg/spectypes/sql_types.go @@ -4,6 +4,7 @@ package spectypes import ( "database/sql" "database/sql/driver" + "encoding/base64" "encoding/json" "fmt" "reflect" @@ -60,7 +61,34 @@ func (n *SqlNull[T]) Scan(value any) error { return nil } - // Try standard sql.Null[T] first. + // Check if T is []byte, and decode base64 if applicable + // Do this BEFORE trying sql.Null to ensure base64 is handled + var zero T + switch any(zero).(type) { + case []byte: + // For []byte types, try to decode from base64 + var strVal string + switch v := value.(type) { + case string: + strVal = v + case []byte: + strVal = string(v) + default: + strVal = fmt.Sprintf("%v", value) + } + // Try base64 decode + if decoded, err := base64.StdEncoding.DecodeString(strVal); err == nil { + n.Val = any(decoded).(T) + n.Valid = true + return nil + } + // Fallback to raw bytes + n.Val = any([]byte(strVal)).(T) + n.Valid = true + return nil + } + + // Try standard sql.Null[T] for other types. 
var sqlNull sql.Null[T] if err := sqlNull.Scan(value); err == nil { n.Val = sqlNull.V @@ -122,6 +150,9 @@ func (n *SqlNull[T]) FromString(s string) error { n.Val = any(u).(T) n.Valid = true } + case []byte: + n.Val = any([]byte(s)).(T) + n.Valid = true case string: n.Val = any(s).(T) n.Valid = true @@ -149,6 +180,15 @@ func (n SqlNull[T]) MarshalJSON() ([]byte, error) { if !n.Valid { return []byte("null"), nil } + + // Check if T is []byte, and encode to base64 + switch v := any(n.Val).(type) { + case []byte: + // Encode []byte as base64 + encoded := base64.StdEncoding.EncodeToString(v) + return json.Marshal(encoded) + } + return json.Marshal(n.Val) } @@ -160,8 +200,26 @@ func (n *SqlNull[T]) UnmarshalJSON(b []byte) error { return nil } - // Try direct unmarshal. + // Check if T is []byte, and decode from base64 var val T + switch any(val).(type) { + case []byte: + // Unmarshal as string first (JSON representation) + var s string + if err := json.Unmarshal(b, &s); err == nil { + // Decode from base64 + if decoded, err := base64.StdEncoding.DecodeString(s); err == nil { + n.Val = any(decoded).(T) + n.Valid = true + return nil + } + // Fallback to raw string as bytes + n.Val = any([]byte(s)).(T) + n.Valid = true + return nil + } + } + if err := json.Unmarshal(b, &val); err == nil { n.Val = val n.Valid = true @@ -271,13 +329,14 @@ func (n SqlNull[T]) UUID() uuid.UUID { // Type aliases for common types. type ( - SqlInt16 = SqlNull[int16] - SqlInt32 = SqlNull[int32] - SqlInt64 = SqlNull[int64] - SqlFloat64 = SqlNull[float64] - SqlBool = SqlNull[bool] - SqlString = SqlNull[string] - SqlUUID = SqlNull[uuid.UUID] + SqlInt16 = SqlNull[int16] + SqlInt32 = SqlNull[int32] + SqlInt64 = SqlNull[int64] + SqlFloat64 = SqlNull[float64] + SqlBool = SqlNull[bool] + SqlString = SqlNull[string] + SqlByteArray = SqlNull[[]byte] + SqlUUID = SqlNull[uuid.UUID] ) // SqlTimeStamp - Timestamp with custom formatting (YYYY-MM-DDTHH:MM:SS). 
@@ -581,6 +640,10 @@ func NewSqlString(v string) SqlString { return SqlString{Val: v, Valid: true} } +func NewSqlByteArray(v []byte) SqlByteArray { + return SqlByteArray{Val: v, Valid: true} +} + func NewSqlUUID(v uuid.UUID) SqlUUID { return SqlUUID{Val: v, Valid: true} } diff --git a/pkg/spectypes/sql_types_test.go b/pkg/spectypes/sql_types_test.go index 57e7614..7e743c3 100644 --- a/pkg/spectypes/sql_types_test.go +++ b/pkg/spectypes/sql_types_test.go @@ -565,3 +565,394 @@ func TestTryIfInt64(t *testing.T) { }) } } + +// TestSqlString tests SqlString without base64 (plain text) +func TestSqlString_Scan(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + valid bool + }{ + { + name: "plain string", + input: "hello world", + expected: "hello world", + valid: true, + }, + { + name: "plain text", + input: "plain text", + expected: "plain text", + valid: true, + }, + { + name: "bytes as string", + input: []byte("raw bytes"), + expected: "raw bytes", + valid: true, + }, + { + name: "nil value", + input: nil, + expected: "", + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s SqlString + if err := s.Scan(tt.input); err != nil { + t.Fatalf("Scan failed: %v", err) + } + if s.Valid != tt.valid { + t.Errorf("expected valid=%v, got valid=%v", tt.valid, s.Valid) + } + if tt.valid && s.String() != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, s.String()) + } + }) + } +} + +func TestSqlString_JSON(t *testing.T) { + tests := []struct { + name string + inputValue string + expectedJSON string + expectedDecode string + }{ + { + name: "simple string", + inputValue: "hello world", + expectedJSON: `"hello world"`, // plain text, not base64 + expectedDecode: "hello world", + }, + { + name: "special characters", + inputValue: "test@#$%", + expectedJSON: `"test@#$%"`, // plain text, not base64 + expectedDecode: "test@#$%", + }, + { + name: "unicode string", + inputValue: "Hello 世界", + 
expectedJSON: `"Hello 世界"`, // plain text, not base64 + expectedDecode: "Hello 世界", + }, + { + name: "empty string", + inputValue: "", + expectedJSON: `""`, + expectedDecode: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test MarshalJSON + s := NewSqlString(tt.inputValue) + data, err := json.Marshal(s) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != tt.expectedJSON { + t.Errorf("Marshal: expected %s, got %s", tt.expectedJSON, string(data)) + } + + // Test UnmarshalJSON + var s2 SqlString + if err := json.Unmarshal(data, &s2); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !s2.Valid { + t.Error("expected valid=true after unmarshal") + } + if s2.String() != tt.expectedDecode { + t.Errorf("Unmarshal: expected %q, got %q", tt.expectedDecode, s2.String()) + } + }) + } +} + +func TestSqlString_JSON_Null(t *testing.T) { + // Test null handling + var s SqlString + if err := json.Unmarshal([]byte("null"), &s); err != nil { + t.Fatalf("Unmarshal null failed: %v", err) + } + if s.Valid { + t.Error("expected invalid after unmarshaling null") + } + + // Test marshal null + data, err := json.Marshal(s) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != "null" { + t.Errorf("expected null, got %s", string(data)) + } +} + +// TestSqlByteArray_Base64 tests SqlByteArray with base64 encoding/decoding +func TestSqlByteArray_Base64_Scan(t *testing.T) { + tests := []struct { + name string + input interface{} + expected []byte + valid bool + }{ + { + name: "base64 encoded bytes from SQL", + input: "aGVsbG8gd29ybGQ=", // "hello world" in base64 + expected: []byte("hello world"), + valid: true, + }, + { + name: "plain bytes fallback", + input: "plain text", + expected: []byte("plain text"), + valid: true, + }, + { + name: "bytes base64 encoded", + input: []byte("SGVsbG8gR29waGVy"), // "Hello Gopher" in base64 + expected: []byte("Hello Gopher"), + valid: true, + }, + { + 
name: "bytes plain fallback", + input: []byte("raw bytes"), + expected: []byte("raw bytes"), + valid: true, + }, + { + name: "binary data", + input: "AQIDBA==", // []byte{1, 2, 3, 4} in base64 + expected: []byte{1, 2, 3, 4}, + valid: true, + }, + { + name: "nil value", + input: nil, + expected: nil, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var b SqlByteArray + if err := b.Scan(tt.input); err != nil { + t.Fatalf("Scan failed: %v", err) + } + if b.Valid != tt.valid { + t.Errorf("expected valid=%v, got valid=%v", tt.valid, b.Valid) + } + if tt.valid { + if string(b.Val) != string(tt.expected) { + t.Errorf("expected %q, got %q", tt.expected, b.Val) + } + } + }) + } +} + +func TestSqlByteArray_Base64_JSON(t *testing.T) { + tests := []struct { + name string + inputValue []byte + expectedJSON string + expectedDecode []byte + }{ + { + name: "text bytes", + inputValue: []byte("hello world"), + expectedJSON: `"aGVsbG8gd29ybGQ="`, // base64 encoded + expectedDecode: []byte("hello world"), + }, + { + name: "binary data", + inputValue: []byte{0x01, 0x02, 0x03, 0x04, 0xFF}, + expectedJSON: `"AQIDBP8="`, // base64 encoded + expectedDecode: []byte{0x01, 0x02, 0x03, 0x04, 0xFF}, + }, + { + name: "empty bytes", + inputValue: []byte{}, + expectedJSON: `""`, // base64 of empty bytes + expectedDecode: []byte{}, + }, + { + name: "unicode bytes", + inputValue: []byte("Hello 世界"), + expectedJSON: `"SGVsbG8g5LiW55WM"`, // base64 encoded + expectedDecode: []byte("Hello 世界"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test MarshalJSON + b := NewSqlByteArray(tt.inputValue) + data, err := json.Marshal(b) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != tt.expectedJSON { + t.Errorf("Marshal: expected %s, got %s", tt.expectedJSON, string(data)) + } + + // Test UnmarshalJSON + var b2 SqlByteArray + if err := json.Unmarshal(data, &b2); err != nil { + t.Fatalf("Unmarshal 
failed: %v", err) + } + if !b2.Valid { + t.Error("expected valid=true after unmarshal") + } + if string(b2.Val) != string(tt.expectedDecode) { + t.Errorf("Unmarshal: expected %v, got %v", tt.expectedDecode, b2.Val) + } + }) + } +} + +func TestSqlByteArray_Base64_JSON_Null(t *testing.T) { + // Test null handling + var b SqlByteArray + if err := json.Unmarshal([]byte("null"), &b); err != nil { + t.Fatalf("Unmarshal null failed: %v", err) + } + if b.Valid { + t.Error("expected invalid after unmarshaling null") + } + + // Test marshal null + data, err := json.Marshal(b) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if string(data) != "null" { + t.Errorf("expected null, got %s", string(data)) + } +} + +func TestSqlByteArray_Value(t *testing.T) { + tests := []struct { + name string + input SqlByteArray + expected interface{} + }{ + { + name: "valid bytes", + input: NewSqlByteArray([]byte("test data")), + expected: []byte("test data"), + }, + { + name: "empty bytes", + input: NewSqlByteArray([]byte{}), + expected: []byte{}, + }, + { + name: "invalid", + input: SqlByteArray{Valid: false}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := tt.input.Value() + if err != nil { + t.Fatalf("Value failed: %v", err) + } + if tt.expected == nil && val != nil { + t.Errorf("expected nil, got %v", val) + } + if tt.expected != nil && val == nil { + t.Errorf("expected %v, got nil", tt.expected) + } + if tt.expected != nil && val != nil { + if string(val.([]byte)) != string(tt.expected.([]byte)) { + t.Errorf("expected %v, got %v", tt.expected, val) + } + } + }) + } +} + +// TestSqlString_RoundTrip tests complete round-trip: Go -> JSON -> Go -> SQL -> Go +func TestSqlString_RoundTrip(t *testing.T) { + original := "Test String with Special Chars: @#$%^&*()" + + // Go -> JSON + s1 := NewSqlString(original) + jsonData, err := json.Marshal(s1) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // JSON -> Go 
+ var s2 SqlString + if err := json.Unmarshal(jsonData, &s2); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + // Go -> SQL (Value) + _, err = s2.Value() + if err != nil { + t.Fatalf("Value failed: %v", err) + } + + // SQL -> Go (Scan plain text) + var s3 SqlString + // Simulate SQL driver returning plain text value + if err := s3.Scan(original); err != nil { + t.Fatalf("Scan failed: %v", err) + } + + // Verify round-trip + if s3.String() != original { + t.Errorf("Round-trip failed: expected %q, got %q", original, s3.String()) + } +} + +// TestSqlByteArray_Base64_RoundTrip tests complete round-trip: Go -> JSON -> Go -> SQL -> Go +func TestSqlByteArray_Base64_RoundTrip(t *testing.T) { + original := []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0xFF, 0xFE} // "Hello " + binary data + + // Go -> JSON + b1 := NewSqlByteArray(original) + jsonData, err := json.Marshal(b1) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // JSON -> Go + var b2 SqlByteArray + if err := json.Unmarshal(jsonData, &b2); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + // Go -> SQL (Value) + _, err = b2.Value() + if err != nil { + t.Fatalf("Value failed: %v", err) + } + + // SQL -> Go (Scan with base64) + var b3 SqlByteArray + // Simulate SQL driver returning base64 encoded value + if err := b3.Scan("SGVsbG8g//4="); err != nil { + t.Fatalf("Scan failed: %v", err) + } + + // Verify round-trip + if string(b3.Val) != string(original) { + t.Errorf("Round-trip failed: expected %v, got %v", original, b3.Val) + } +} + From defe27549b7f42ca7e9653b60c6da407fcbd8315 Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 27 Jan 2026 17:35:13 +0200 Subject: [PATCH 27/31] =?UTF-8?q?feat(sql):=20=E2=9C=A8=20Improve=20base64?= =?UTF-8?q?=20handling=20in=20SqlNull=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor base64 encoding and decoding checks for []byte types. 
* Simplify type assertions using if statements instead of switch cases. --- pkg/spectypes/sql_types.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pkg/spectypes/sql_types.go b/pkg/spectypes/sql_types.go index 1d008d1..f9a81d6 100644 --- a/pkg/spectypes/sql_types.go +++ b/pkg/spectypes/sql_types.go @@ -64,8 +64,7 @@ func (n *SqlNull[T]) Scan(value any) error { // Check if T is []byte, and decode base64 if applicable // Do this BEFORE trying sql.Null to ensure base64 is handled var zero T - switch any(zero).(type) { - case []byte: + if _, ok := any(zero).([]byte); ok { // For []byte types, try to decode from base64 var strVal string switch v := value.(type) { @@ -182,10 +181,9 @@ func (n SqlNull[T]) MarshalJSON() ([]byte, error) { } // Check if T is []byte, and encode to base64 - switch v := any(n.Val).(type) { - case []byte: + if _, ok := any(n.Val).([]byte); ok { // Encode []byte as base64 - encoded := base64.StdEncoding.EncodeToString(v) + encoded := base64.StdEncoding.EncodeToString(any(n.Val).([]byte)) return json.Marshal(encoded) } @@ -202,8 +200,7 @@ func (n *SqlNull[T]) UnmarshalJSON(b []byte) error { // Check if T is []byte, and decode from base64 var val T - switch any(val).(type) { - case []byte: + if _, ok := any(val).([]byte); ok { // Unmarshal as string first (JSON representation) var s string if err := json.Unmarshal(b, &s); err == nil { From 17239d1611279e6f2c1c6775505eb510dac97a62 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 29 Jan 2026 09:37:09 +0200 Subject: [PATCH 28/31] =?UTF-8?q?feat(preload):=20=E2=9C=A8=20Add=20suppor?= =?UTF-8?q?t=20for=20custom=20SQL=20joins?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduce SqlJoins and JoinAliases in PreloadOption. * Preserve SqlJoins and JoinAliases during filter processing. * Implement logic to apply custom SQL joins in handler. * Add tests for SqlJoins handling and join alias extraction. 
--- pkg/common/types.go | 4 ++ pkg/common/validation.go | 18 ++++- pkg/restheadspec/handler.go | 9 +++ pkg/restheadspec/headers.go | 26 +++++++ pkg/restheadspec/headers_test.go | 117 +++++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 1 deletion(-) diff --git a/pkg/common/types.go b/pkg/common/types.go index 3e81ab9..a68daf6 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -52,6 +52,10 @@ type PreloadOption struct { PrimaryKey string `json:"primary_key"` // Primary key of the related table RelatedKey string `json:"related_key"` // For child tables: column in child that references parent ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent + + // Custom SQL JOINs from XFiles - used when preload needs additional joins + SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses + JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation } type FilterOption struct { diff --git a/pkg/common/validation.go b/pkg/common/validation.go index 653a869..1a7ae7d 100644 --- a/pkg/common/validation.go +++ b/pkg/common/validation.go @@ -272,13 +272,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp filteredPreload.Columns = v.FilterValidColumns(preload.Columns) filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns) + // Preserve SqlJoins and JoinAliases for preloads with custom joins + filteredPreload.SqlJoins = preload.SqlJoins + filteredPreload.JoinAliases = preload.JoinAliases + // Filter preload filters validPreloadFilters := make([]FilterOption, 0, len(preload.Filters)) for _, filter := range preload.Filters { if v.IsValidColumn(filter.Column) { validPreloadFilters = append(validPreloadFilters, filter) } else { - logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column) + // Check if the filter column references a joined table alias + foundJoin := false + 
for _, alias := range preload.JoinAliases { + if strings.Contains(filter.Column, alias) { + foundJoin = true + break + } + } + if foundJoin { + validPreloadFilters = append(validPreloadFilters, filter) + } else { + logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column) + } } } filteredPreload.Filters = validPreloadFilters diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index d5f0aa1..4bc6c08 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -882,6 +882,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co } } + // Apply custom SQL joins from XFiles + if len(preload.SqlJoins) > 0 { + logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation) + for _, joinClause := range preload.SqlJoins { + sq = sq.Join(joinClause) + logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause) + } + } + // Apply filters if len(preload.Filters) > 0 { for _, filter := range preload.Filters { diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index ef51cbc..cce3a1e 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -1088,6 +1088,32 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey) } + // Transfer SqlJoins from XFiles to PreloadOption + if len(xfile.SqlJoins) > 0 { + preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins)) + preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins)) + + for _, joinClause := range xfile.SqlJoins { + // Sanitize the join clause + sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil) + if sanitizedJoin == "" { + logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause) + continue + } + + preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin) + + // Extract 
join alias for validation + alias := extractJoinAlias(sanitizedJoin) + if alias != "" { + preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias) + logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias) + } + } + + logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath) + } + // Add the preload option options.Preload = append(options.Preload, preloadOpt) diff --git a/pkg/restheadspec/headers_test.go b/pkg/restheadspec/headers_test.go index 8117483..d83d09f 100644 --- a/pkg/restheadspec/headers_test.go +++ b/pkg/restheadspec/headers_test.go @@ -2,6 +2,8 @@ package restheadspec import ( "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" ) func TestDecodeHeaderValue(t *testing.T) { @@ -37,6 +39,121 @@ func TestDecodeHeaderValue(t *testing.T) { } } +func TestAddXFilesPreload_WithSqlJoins(t *testing.T) { + handler := &Handler{} + options := &ExtendedRequestOptions{ + RequestOptions: common.RequestOptions{ + Preload: make([]common.PreloadOption, 0), + }, + } + + // Create an XFiles with SqlJoins + xfile := &XFiles{ + TableName: "users", + SqlJoins: []string{ + "LEFT JOIN departments d ON d.id = users.department_id", + "INNER JOIN roles r ON r.id = users.role_id", + }, + FilterFields: []struct { + Field string `json:"field"` + Value string `json:"value"` + Operator string `json:"operator"` + }{ + {Field: "d.active", Value: "true", Operator: "eq"}, + {Field: "r.name", Value: "admin", Operator: "eq"}, + }, + } + + // Add the XFiles preload + handler.addXFilesPreload(xfile, options, "") + + // Verify that a preload was added + if len(options.Preload) != 1 { + t.Fatalf("Expected 1 preload, got %d", len(options.Preload)) + } + + preload := options.Preload[0] + + // Verify relation name + if preload.Relation != "users" { + t.Errorf("Expected relation 'users', got '%s'", preload.Relation) + } + + // Verify SqlJoins were transferred + if len(preload.SqlJoins) != 2 { + t.Fatalf("Expected 2 SQL joins, 
got %d", len(preload.SqlJoins)) + } + + // Verify JoinAliases were extracted + if len(preload.JoinAliases) != 2 { + t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases)) + } + + // Verify the aliases are correct + expectedAliases := []string{"d", "r"} + for i, expected := range expectedAliases { + if preload.JoinAliases[i] != expected { + t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i]) + } + } + + // Verify filters were added + if len(preload.Filters) != 2 { + t.Fatalf("Expected 2 filters, got %d", len(preload.Filters)) + } + + // Verify filter columns reference joined tables + if preload.Filters[0].Column != "d.active" { + t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column) + } + if preload.Filters[1].Column != "r.name" { + t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column) + } +} + +func TestExtractJoinAlias(t *testing.T) { + tests := []struct { + name string + joinClause string + expected string + }{ + { + name: "LEFT JOIN with alias", + joinClause: "LEFT JOIN departments d ON d.id = users.department_id", + expected: "d", + }, + { + name: "INNER JOIN with AS keyword", + joinClause: "INNER JOIN users AS u ON u.id = orders.user_id", + expected: "u", + }, + { + name: "JOIN without alias", + joinClause: "JOIN roles ON roles.id = users.role_id", + expected: "", + }, + { + name: "Complex join with multiple conditions", + joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true", + expected: "p", + }, + { + name: "Invalid join (no ON clause)", + joinClause: "LEFT JOIN departments", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractJoinAlias(tt.joinClause) + if result != tt.expected { + t.Errorf("Expected alias '%s', got '%s'", tt.expected, result) + } + }) + } +} + // Note: The following functions are unexported (lowercase) and cannot be tested directly: // - 
parseSelectFields // - parseFieldFilter From 584bb9813d9d8c6bfbb7b5fbd98f8a41fa0be85e Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 29 Jan 2026 09:37:22 +0200 Subject: [PATCH 29/31] .. --- pkg/common/types.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/common/types.go b/pkg/common/types.go index a68daf6..7447ffa 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -54,8 +54,8 @@ type PreloadOption struct { ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent // Custom SQL JOINs from XFiles - used when preload needs additional joins - SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses - JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation + SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses + JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation } type FilterOption struct { From fc8f44e3e8b0c30f1d09887aa2c1ecf291a72c52 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 29 Jan 2026 15:31:50 +0200 Subject: [PATCH 30/31] =?UTF-8?q?feat(preload):=20=E2=9C=A8=20Enhance=20re?= =?UTF-8?q?cursive=20preload=20functionality?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Increase maximum recursion depth from 4 to 8. * Generate FK-based relation names for child preloads using RelatedKey. * Clear WHERE clause for recursive preloads to prevent filtering issues. * Extend child relations to recursive levels for better data retrieval. * Add integration tests to validate recursive preload behavior and structure. 
--- .gitignore | 1 + pkg/restheadspec/handler.go | 52 +- pkg/restheadspec/recursive_preload_test.go | 391 +++++++++++++++ pkg/restheadspec/xfiles_integration_test.go | 525 ++++++++++++++++++++ 4 files changed, 961 insertions(+), 8 deletions(-) create mode 100644 pkg/restheadspec/recursive_preload_test.go create mode 100644 pkg/restheadspec/xfiles_integration_test.go diff --git a/.gitignore b/.gitignore index a2c5024..3fb767e 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ go.work.sum bin/ test.db /testserver +tests/data/ \ No newline at end of file diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 4bc6c08..11897e7 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -938,21 +938,57 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co }) // Handle recursive preloading - if preload.Recursive && depth < 4 { + if preload.Recursive && depth < 8 { logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1) - // For recursive relationships, we need to get the last part of the relation path - // e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems" relationParts := strings.Split(preload.Relation, ".") lastRelationName := relationParts[len(relationParts)-1] - // Create a recursive preload with the same configuration - // but with the relation path extended - recursivePreload := preload - recursivePreload.Relation = preload.Relation + "." 
+ lastRelationName + // Generate FK-based relation name for children + recursiveRelationName := lastRelationName + if preload.RelatedKey != "" { + // Convert "rid_parentmastertaskitem" to "RID_PARENTMASTERTASKITEM" + fkUpper := strings.ToUpper(preload.RelatedKey) + recursiveRelationName = lastRelationName + "_" + fkUpper + logger.Debug("Generated recursive relation name from RelatedKey: %s (from %s)", + recursiveRelationName, preload.RelatedKey) + } else { + logger.Warn("Recursive preload for %s has no RelatedKey, falling back to %s.%s", + preload.Relation, preload.Relation, lastRelationName) + } - // Recursively apply preload until we reach depth 5 + // Create recursive preload + recursivePreload := preload + recursivePreload.Relation = preload.Relation + "." + recursiveRelationName + recursivePreload.Recursive = false // Prevent infinite recursion at this level + + // CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal + recursivePreload.Where = "" + recursivePreload.Filters = []common.FilterOption{} + logger.Debug("Cleared WHERE clause for recursive preload %s at depth %d", + recursivePreload.Relation, depth+1) + + // Apply recursively up to depth 8 query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1) + + // ALSO: Extend any child relations (like DEF) to recursive levels + baseRelation := preload.Relation + "." + for i := range allPreloads { + relatedPreload := allPreloads[i] + if strings.HasPrefix(relatedPreload.Relation, baseRelation) && + !strings.Contains(strings.TrimPrefix(relatedPreload.Relation, baseRelation), ".") { + childRelationName := strings.TrimPrefix(relatedPreload.Relation, baseRelation) + + extendedChildPreload := relatedPreload + extendedChildPreload.Relation = recursivePreload.Relation + "." 
+ childRelationName + extendedChildPreload.Recursive = false + + logger.Debug("Extending related preload '%s' to '%s' at recursive depth %d", + relatedPreload.Relation, extendedChildPreload.Relation, depth+1) + + query = h.applyPreloadWithRecursion(query, extendedChildPreload, allPreloads, model, depth+1) + } + } } return query diff --git a/pkg/restheadspec/recursive_preload_test.go b/pkg/restheadspec/recursive_preload_test.go new file mode 100644 index 0000000..7c79b1b --- /dev/null +++ b/pkg/restheadspec/recursive_preload_test.go @@ -0,0 +1,391 @@ +//go:build !integration +// +build !integration + +package restheadspec + +import ( + "context" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +// TestRecursivePreloadClearsWhereClause tests that recursive preloads +// correctly clear the WHERE clause from the parent level to allow +// Bun to use foreign key relationships for loading children +func TestRecursivePreloadClearsWhereClause(t *testing.T) { + // Create a mock handler + handler := &Handler{} + + // Create a preload option with a WHERE clause that filters root items + // This simulates the xfiles use case where the first level has a filter + // like "rid_parentmastertaskitem is null" to get root items + preload := common.PreloadOption{ + Relation: "MastertaskItems", + Recursive: true, + RelatedKey: "rid_parentmastertaskitem", + Where: "rid_parentmastertaskitem is null", + Filters: []common.FilterOption{ + { + Column: "rid_parentmastertaskitem", + Operator: "is null", + Value: nil, + }, + }, + } + + // Create a mock query that tracks operations + mockQuery := &mockSelectQuery{ + operations: []string{}, + } + + // Apply the recursive preload at depth 0 + // This should: + // 1. Apply the initial preload with the WHERE clause + // 2. 
Create a recursive preload without the WHERE clause + allPreloads := []common.PreloadOption{preload} + result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0) + + // Verify the mock query received the operations + mock := result.(*mockSelectQuery) + + // Check that we have at least 2 PreloadRelation calls: + // 1. The initial "MastertaskItems" with WHERE clause + // 2. The recursive "MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" without WHERE clause + preloadCount := 0 + recursivePreloadFound := false + whereAppliedToRecursive := false + + for _, op := range mock.operations { + if op == "PreloadRelation:MastertaskItems" { + preloadCount++ + } + if op == "PreloadRelation:MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" { + recursivePreloadFound = true + } + // Check if WHERE was applied to the recursive preload (it shouldn't be) + if op == "Where:rid_parentmastertaskitem is null" && recursivePreloadFound { + whereAppliedToRecursive = true + } + } + + if preloadCount < 1 { + t.Errorf("Expected at least 1 PreloadRelation call, got %d", preloadCount) + } + + if !recursivePreloadFound { + t.Errorf("Expected recursive preload 'MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM' to be created. 
Operations: %v", mock.operations) + } + + if whereAppliedToRecursive { + t.Error("WHERE clause should not be applied to recursive preload levels") + } +} + +// TestRecursivePreloadWithChildRelations tests that child relations +// (like DEF in MAL.DEF) are properly extended to recursive levels +func TestRecursivePreloadWithChildRelations(t *testing.T) { + handler := &Handler{} + + // Create the main recursive preload + recursivePreload := common.PreloadOption{ + Relation: "MAL", + Recursive: true, + RelatedKey: "rid_parentmastertaskitem", + Where: "rid_parentmastertaskitem is null", + } + + // Create a child relation that should be extended + childPreload := common.PreloadOption{ + Relation: "MAL.DEF", + } + + mockQuery := &mockSelectQuery{ + operations: []string{}, + } + + allPreloads := []common.PreloadOption{recursivePreload, childPreload} + + // Apply both preloads - the child preload should be extended when the recursive one processes + result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, allPreloads, nil, 0) + + // Also need to apply the child preload separately (as would happen in normal flow) + result = handler.applyPreloadWithRecursion(result, childPreload, allPreloads, nil, 0) + + mock := result.(*mockSelectQuery) + + // Check that the child relation was extended to recursive levels + // We should see: + // - MAL (with WHERE) + // - MAL.DEF + // - MAL.MAL_RID_PARENTMASTERTASKITEM (without WHERE) + // - MAL.MAL_RID_PARENTMASTERTASKITEM.DEF (extended by recursive logic) + foundMALDEF := false + foundRecursiveMAL := false + foundMALMALDEF := false + + for _, op := range mock.operations { + if op == "PreloadRelation:MAL.DEF" { + foundMALDEF = true + } + if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" { + foundRecursiveMAL = true + } + if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" { + foundMALMALDEF = true + } + } + + if !foundMALDEF { + t.Errorf("Expected child preload 'MAL.DEF' to be applied. 
Operations: %v", mock.operations) + } + + if !foundRecursiveMAL { + t.Errorf("Expected recursive preload 'MAL.MAL_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations) + } + + if !foundMALMALDEF { + t.Errorf("Expected child preload to be extended to 'MAL.MAL_RID_PARENTMASTERTASKITEM.DEF' at recursive level. Operations: %v", mock.operations) + } +} + +// TestRecursivePreloadGeneratesCorrectRelationName tests that the recursive +// preload generates the correct FK-based relation name using RelatedKey +func TestRecursivePreloadGeneratesCorrectRelationName(t *testing.T) { + handler := &Handler{} + + // Test case 1: With RelatedKey - should generate FK-based name + t.Run("WithRelatedKey", func(t *testing.T) { + preload := common.PreloadOption{ + Relation: "MAL", + Recursive: true, + RelatedKey: "rid_parentmastertaskitem", + } + + mockQuery := &mockSelectQuery{operations: []string{}} + allPreloads := []common.PreloadOption{preload} + result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0) + + mock := result.(*mockSelectQuery) + + // Should generate MAL.MAL_RID_PARENTMASTERTASKITEM + foundCorrectRelation := false + foundIncorrectRelation := false + + for _, op := range mock.operations { + if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" { + foundCorrectRelation = true + } + if op == "PreloadRelation:MAL.MAL" { + foundIncorrectRelation = true + } + } + + if !foundCorrectRelation { + t.Errorf("Expected 'MAL.MAL_RID_PARENTMASTERTASKITEM' relation, operations: %v", mock.operations) + } + + if foundIncorrectRelation { + t.Error("Should NOT generate 'MAL.MAL' relation when RelatedKey is specified") + } + }) + + // Test case 2: Without RelatedKey - should fallback to old behavior + t.Run("WithoutRelatedKey", func(t *testing.T) { + preload := common.PreloadOption{ + Relation: "MAL", + Recursive: true, + // No RelatedKey + } + + mockQuery := &mockSelectQuery{operations: []string{}} + allPreloads := 
[]common.PreloadOption{preload} + result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0) + + mock := result.(*mockSelectQuery) + + // Should fallback to MAL.MAL + foundFallback := false + for _, op := range mock.operations { + if op == "PreloadRelation:MAL.MAL" { + foundFallback = true + } + } + + if !foundFallback { + t.Errorf("Expected fallback 'MAL.MAL' relation when no RelatedKey, operations: %v", mock.operations) + } + }) + + // Test case 3: Depth limit of 8 + t.Run("DepthLimit", func(t *testing.T) { + preload := common.PreloadOption{ + Relation: "MAL", + Recursive: true, + RelatedKey: "rid_parentmastertaskitem", + } + + mockQuery := &mockSelectQuery{operations: []string{}} + allPreloads := []common.PreloadOption{preload} + + // Start at depth 7 - should create one more level + result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7) + mock := result.(*mockSelectQuery) + + foundDepth8 := false + for _, op := range mock.operations { + if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" { + foundDepth8 = true + } + } + + if !foundDepth8 { + t.Error("Expected to create recursive level at depth 8") + } + + // Start at depth 8 - should NOT create another level + mockQuery2 := &mockSelectQuery{operations: []string{}} + result2 := handler.applyPreloadWithRecursion(mockQuery2, preload, allPreloads, nil, 8) + mock2 := result2.(*mockSelectQuery) + + foundDepth9 := false + for _, op := range mock2.operations { + if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" { + foundDepth9 = true + } + } + + if foundDepth9 { + t.Error("Should NOT create recursive level beyond depth 8") + } + }) +} + +// mockSelectQuery implements common.SelectQuery for testing +type mockSelectQuery struct { + operations []string +} + +func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery { + m.operations = append(m.operations, "Model") + return m +} + +func (m *mockSelectQuery) Table(table string) 
common.SelectQuery { + m.operations = append(m.operations, "Table:"+table) + return m +} + +func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery { + for _, col := range columns { + m.operations = append(m.operations, "Column:"+col) + } + return m +} + +func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "ColumnExpr:"+query) + return m +} + +func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Where:"+query) + return m +} + +func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "WhereOr:"+query) + return m +} + +func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery { + m.operations = append(m.operations, "WhereIn:"+column) + return m +} + +func (m *mockSelectQuery) Order(order string) common.SelectQuery { + m.operations = append(m.operations, "Order:"+order) + return m +} + +func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "OrderExpr:"+order) + return m +} + +func (m *mockSelectQuery) Limit(limit int) common.SelectQuery { + m.operations = append(m.operations, "Limit") + return m +} + +func (m *mockSelectQuery) Offset(offset int) common.SelectQuery { + m.operations = append(m.operations, "Offset") + return m +} + +func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Join:"+join) + return m +} + +func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "LeftJoin:"+join) + return m +} + +func (m *mockSelectQuery) Group(columns string) common.SelectQuery { + m.operations = append(m.operations, "Group") + return m +} + +func (m *mockSelectQuery) Having(query string, args 
...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Having:"+query) + return m +} + +func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Preload:"+relation) + return m +} + +func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + m.operations = append(m.operations, "PreloadRelation:"+relation) + // Apply the preload modifiers + for _, fn := range apply { + fn(m) + } + return m +} + +func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + m.operations = append(m.operations, "JoinRelation:"+relation) + return m +} + +func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { + m.operations = append(m.operations, "Scan") + return nil +} + +func (m *mockSelectQuery) ScanModel(ctx context.Context) error { + m.operations = append(m.operations, "ScanModel") + return nil +} + +func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { + m.operations = append(m.operations, "Count") + return 0, nil +} + +func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { + m.operations = append(m.operations, "Exists") + return false, nil +} + +func (m *mockSelectQuery) GetUnderlyingQuery() interface{} { + return nil +} + +func (m *mockSelectQuery) GetModel() interface{} { + return nil +} diff --git a/pkg/restheadspec/xfiles_integration_test.go b/pkg/restheadspec/xfiles_integration_test.go new file mode 100644 index 0000000..f171b9f --- /dev/null +++ b/pkg/restheadspec/xfiles_integration_test.go @@ -0,0 +1,525 @@ +//go:build integration +// +build integration + +package restheadspec + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +// mockSelectQuery implements common.SelectQuery for testing (integration version) +type mockSelectQuery struct { + operations []string +} + +func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery { + m.operations = append(m.operations, "Model") + return m +} + +func (m *mockSelectQuery) Table(table string) common.SelectQuery { + m.operations = append(m.operations, "Table:"+table) + return m +} + +func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery { + for _, col := range columns { + m.operations = append(m.operations, "Column:"+col) + } + return m +} + +func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "ColumnExpr:"+query) + return m +} + +func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Where:"+query) + return m +} + +func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "WhereOr:"+query) + return m +} + +func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery { + m.operations = append(m.operations, "WhereIn:"+column) + return m +} + +func (m *mockSelectQuery) Order(order string) common.SelectQuery { + m.operations = append(m.operations, "Order:"+order) + return m +} + +func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "OrderExpr:"+order) + return m +} + +func (m *mockSelectQuery) Limit(limit int) common.SelectQuery { + m.operations = append(m.operations, "Limit") + return m +} + +func (m *mockSelectQuery) Offset(offset int) common.SelectQuery { + m.operations = append(m.operations, "Offset") + return m +} + +func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Join:"+join) + 
return m +} + +func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "LeftJoin:"+join) + return m +} + +func (m *mockSelectQuery) Group(columns string) common.SelectQuery { + m.operations = append(m.operations, "Group") + return m +} + +func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Having:"+query) + return m +} + +func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery { + m.operations = append(m.operations, "Preload:"+relation) + return m +} + +func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + m.operations = append(m.operations, "PreloadRelation:"+relation) + // Apply the preload modifiers + for _, fn := range apply { + fn(m) + } + return m +} + +func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + m.operations = append(m.operations, "JoinRelation:"+relation) + return m +} + +func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { + m.operations = append(m.operations, "Scan") + return nil +} + +func (m *mockSelectQuery) ScanModel(ctx context.Context) error { + m.operations = append(m.operations, "ScanModel") + return nil +} + +func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { + m.operations = append(m.operations, "Count") + return 0, nil +} + +func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { + m.operations = append(m.operations, "Exists") + return false, nil +} + +func (m *mockSelectQuery) GetUnderlyingQuery() interface{} { + return nil +} + +func (m *mockSelectQuery) GetModel() interface{} { + return nil +} + +// TestXFilesRecursivePreload is an integration test that validates the XFiles +// recursive preload functionality using real test data 
files. +// +// This test ensures: +// 1. XFiles request JSON is correctly parsed into PreloadOptions +// 2. Recursive preload generates correct FK-based relation names (MAL_RID_PARENTMASTERTASKITEM) +// 3. Parent WHERE clauses don't leak to child levels +// 4. Child relations (like DEF) are extended to all recursive levels +// 5. Hierarchical data structure matches expected output +func TestXFilesRecursivePreload(t *testing.T) { + // Load the XFiles request configuration + requestPath := filepath.Join("..", "..", "tests", "data", "xfiles.request.json") + requestData, err := os.ReadFile(requestPath) + require.NoError(t, err, "Failed to read xfiles.request.json") + + var xfileConfig XFiles + err = json.Unmarshal(requestData, &xfileConfig) + require.NoError(t, err, "Failed to parse xfiles.request.json") + + // Create handler and parse XFiles into PreloadOptions + handler := &Handler{} + options := &ExtendedRequestOptions{ + RequestOptions: common.RequestOptions{ + Preload: []common.PreloadOption{}, + }, + } + + // Process the XFiles configuration - start with the root table + handler.processXFilesRelations(&xfileConfig, options, "") + + // Verify that preload options were created + require.NotEmpty(t, options.Preload, "Expected preload options to be created") + + // Test 1: Verify recursive preload option has RelatedKey set + t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) { + // Find the recursive mastertaskitem preload + var recursivePreload *common.PreloadOption + for i := range options.Preload { + preload := &options.Preload[i] + if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + recursivePreload = preload + break + } + } + + require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload") + assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RelatedKey, + "Recursive preload should have RelatedKey set from xfiles config") + assert.True(t, recursivePreload.Recursive, "mastertaskitem 
preload should be marked as recursive") + }) + + // Test 2: Verify root level mastertaskitem has WHERE clause for filtering root items + t.Run("RootLevelHasWhereClause", func(t *testing.T) { + var rootPreload *common.PreloadOption + for i := range options.Preload { + preload := &options.Preload[i] + if preload.Relation == "mastertask.mastertaskitem" && !preload.Recursive { + rootPreload = preload + break + } + } + + require.NotNil(t, rootPreload, "Expected to find root mastertaskitem preload") + assert.NotEmpty(t, rootPreload.Where, "Root mastertaskitem should have WHERE clause") + // The WHERE clause should filter for root items (rid_parentmastertaskitem is null) + }) + + // Test 3: Verify actiondefinition relation exists for mastertaskitem + t.Run("DEFRelationExists", func(t *testing.T) { + var defPreload *common.PreloadOption + for i := range options.Preload { + preload := &options.Preload[i] + if preload.Relation == "mastertask.mastertaskitem.actiondefinition" { + defPreload = preload + break + } + } + + require.NotNil(t, defPreload, "Expected to find actiondefinition preload for mastertaskitem") + assert.Equal(t, "rid_actiondefinition", defPreload.ForeignKey, + "actiondefinition preload should have ForeignKey set") + }) + + // Test 4: Verify relation name generation with mock query + t.Run("RelationNameGeneration", func(t *testing.T) { + // Find the recursive mastertaskitem preload + var recursivePreload common.PreloadOption + found := false + for _, preload := range options.Preload { + if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + recursivePreload = preload + found = true + break + } + } + + require.True(t, found, "Expected to find recursive mastertaskitem preload") + + // Create mock query to track operations + mockQuery := &mockSelectQuery{operations: []string{}} + + // Apply the recursive preload + result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0) + mock := 
result.(*mockSelectQuery) + + // Verify the correct FK-based relation name was generated + foundCorrectRelation := false + foundIncorrectRelation := false + + for _, op := range mock.operations { + // Should generate: mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM + if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM" { + foundCorrectRelation = true + } + // Should NOT generate: mastertask.mastertaskitem.mastertaskitem.mastertaskitem + if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem" { + foundIncorrectRelation = true + } + } + + assert.True(t, foundCorrectRelation, + "Expected FK-based relation name 'mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v", + mock.operations) + assert.False(t, foundIncorrectRelation, + "Should NOT generate simple relation name when RelatedKey is set") + }) + + // Test 5: Verify WHERE clause is cleared for recursive levels + t.Run("WhereClauseClearedForChildren", func(t *testing.T) { + // Find the recursive mastertaskitem preload with WHERE clause + var recursivePreload common.PreloadOption + found := false + for _, preload := range options.Preload { + if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + recursivePreload = preload + found = true + break + } + } + + require.True(t, found, "Expected to find recursive mastertaskitem preload") + + // The root level might have a WHERE clause + // But when we apply recursion, it should be cleared + + mockQuery := &mockSelectQuery{operations: []string{}} + result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0) + mock := result.(*mockSelectQuery) + + // After the first level, WHERE clauses should not be reapplied + // We check that the recursive relation was created (which means WHERE was cleared internally) + foundRecursiveRelation 
:= false + for _, op := range mock.operations { + if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM" { + foundRecursiveRelation = true + } + } + + assert.True(t, foundRecursiveRelation, + "Recursive relation should be created (WHERE clause should be cleared internally)") + }) + + // Test 6: Verify child relations are extended to recursive levels + t.Run("ChildRelationsExtended", func(t *testing.T) { + // Find both the recursive mastertaskitem and the actiondefinition preloads + var recursivePreload common.PreloadOption + foundRecursive := false + + for _, preload := range options.Preload { + if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + recursivePreload = preload + foundRecursive = true + break + } + } + + require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload") + + mockQuery := &mockSelectQuery{operations: []string{}} + result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0) + mock := result.(*mockSelectQuery) + + // actiondefinition should be extended to the recursive level + // Expected: mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM.actiondefinition + foundExtendedDEF := false + for _, op := range mock.operations { + if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM.actiondefinition" { + foundExtendedDEF = true + } + } + + assert.True(t, foundExtendedDEF, + "Expected actiondefinition relation to be extended to recursive level. 
Operations: %v", + mock.operations) + }) +} + +// TestXFilesRecursivePreloadDepth tests that recursive preloads respect the depth limit of 8 +func TestXFilesRecursivePreloadDepth(t *testing.T) { + handler := &Handler{} + + preload := common.PreloadOption{ + Relation: "MAL", + Recursive: true, + RelatedKey: "rid_parentmastertaskitem", + } + + allPreloads := []common.PreloadOption{preload} + + t.Run("Depth7CreatesLevel8", func(t *testing.T) { + mockQuery := &mockSelectQuery{operations: []string{}} + result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7) + mock := result.(*mockSelectQuery) + + foundDepth8 := false + for _, op := range mock.operations { + if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" { + foundDepth8 = true + } + } + + assert.True(t, foundDepth8, "Should create level 8 when starting at depth 7") + }) + + t.Run("Depth8DoesNotCreateLevel9", func(t *testing.T) { + mockQuery := &mockSelectQuery{operations: []string{}} + result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 8) + mock := result.(*mockSelectQuery) + + foundDepth9 := false + for _, op := range mock.operations { + if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" { + foundDepth9 = true + } + } + + assert.False(t, foundDepth9, "Should NOT create level 9 (depth limit is 8)") + }) +} + +// TestXFilesResponseStructure validates the actual structure of the response +// This test can be expanded when we have a full database integration test environment +func TestXFilesResponseStructure(t *testing.T) { + // Load the expected correct response + correctResponsePath := filepath.Join("..", "..", "tests", "data", "xfiles.response.correct.json") + correctData, err := os.ReadFile(correctResponsePath) + require.NoError(t, err, "Failed to read xfiles.response.correct.json") + + var correctResponse []map[string]interface{} + err = json.Unmarshal(correctData, &correctResponse) + require.NoError(t, err, "Failed to parse 
xfiles.response.correct.json") + + // Test 1: Verify root level has exactly 1 masterprocess + t.Run("RootLevelHasOneItem", func(t *testing.T) { + assert.Len(t, correctResponse, 1, "Root level should have exactly 1 masterprocess record") + }) + + // Test 2: Verify the root item has MTL relation + t.Run("RootHasMTLRelation", func(t *testing.T) { + require.NotEmpty(t, correctResponse, "Response should not be empty") + rootItem := correctResponse[0] + + mtl, exists := rootItem["MTL"] + assert.True(t, exists, "Root item should have MTL relation") + assert.NotNil(t, mtl, "MTL relation should not be null") + }) + + // Test 3: Verify MTL has MAL items + t.Run("MTLHasMALItems", func(t *testing.T) { + require.NotEmpty(t, correctResponse, "Response should not be empty") + rootItem := correctResponse[0] + + mtl, ok := rootItem["MTL"].([]interface{}) + require.True(t, ok, "MTL should be an array") + require.NotEmpty(t, mtl, "MTL should have items") + + firstMTL, ok := mtl[0].(map[string]interface{}) + require.True(t, ok, "MTL item should be a map") + + mal, exists := firstMTL["MAL"] + assert.True(t, exists, "MTL item should have MAL relation") + assert.NotNil(t, mal, "MAL relation should not be null") + }) + + // Test 4: Verify MAL items have MAL_RID_PARENTMASTERTASKITEM relation (recursive) + t.Run("MALHasRecursiveRelation", func(t *testing.T) { + require.NotEmpty(t, correctResponse, "Response should not be empty") + rootItem := correctResponse[0] + + mtl, ok := rootItem["MTL"].([]interface{}) + require.True(t, ok, "MTL should be an array") + require.NotEmpty(t, mtl, "MTL should have items") + + firstMTL, ok := mtl[0].(map[string]interface{}) + require.True(t, ok, "MTL item should be a map") + + mal, ok := firstMTL["MAL"].([]interface{}) + require.True(t, ok, "MAL should be an array") + require.NotEmpty(t, mal, "MAL should have items") + + firstMAL, ok := mal[0].(map[string]interface{}) + require.True(t, ok, "MAL item should be a map") + + // The key assertion: check for 
FK-based relation name + recursiveRelation, exists := firstMAL["MAL_RID_PARENTMASTERTASKITEM"] + assert.True(t, exists, + "MAL item should have MAL_RID_PARENTMASTERTASKITEM relation (FK-based name)") + + // It can be null or an array, depending on whether this item has children + if recursiveRelation != nil { + _, isArray := recursiveRelation.([]interface{}) + assert.True(t, isArray, + "MAL_RID_PARENTMASTERTASKITEM should be an array when not null") + } + }) + + // Test 5: Verify "Receive COB Document for" appears as a child, not at root + t.Run("ChildItemsAreNested", func(t *testing.T) { + // This test verifies that "Receive COB Document for" doesn't appear + // multiple times at the wrong level, but is properly nested + + // Count how many times we find this description at the MAL level (should be 0 or 1) + require.NotEmpty(t, correctResponse, "Response should not be empty") + rootItem := correctResponse[0] + + mtl, ok := rootItem["MTL"].([]interface{}) + require.True(t, ok, "MTL should be an array") + require.NotEmpty(t, mtl, "MTL should have items") + + firstMTL, ok := mtl[0].(map[string]interface{}) + require.True(t, ok, "MTL item should be a map") + + mal, ok := firstMTL["MAL"].([]interface{}) + require.True(t, ok, "MAL should be an array") + + // Count root-level MAL items (before the fix, there were 12; should be 1) + assert.Len(t, mal, 1, + "MAL should have exactly 1 root-level item (before fix: 12 duplicates)") + + // Verify the root item has a description + firstMAL, ok := mal[0].(map[string]interface{}) + require.True(t, ok, "MAL item should be a map") + + description, exists := firstMAL["description"] + assert.True(t, exists, "MAL item should have a description") + assert.Equal(t, "Capture COB Information", description, + "Root MAL item should be 'Capture COB Information'") + }) + + // Test 6: Verify DEF relation exists at MAL level + t.Run("DEFRelationExists", func(t *testing.T) { + require.NotEmpty(t, correctResponse, "Response should not be empty") 
+ rootItem := correctResponse[0] + + mtl, ok := rootItem["MTL"].([]interface{}) + require.True(t, ok, "MTL should be an array") + require.NotEmpty(t, mtl, "MTL should have items") + + firstMTL, ok := mtl[0].(map[string]interface{}) + require.True(t, ok, "MTL item should be a map") + + mal, ok := firstMTL["MAL"].([]interface{}) + require.True(t, ok, "MAL should be an array") + require.NotEmpty(t, mal, "MAL should have items") + + firstMAL, ok := mal[0].(map[string]interface{}) + require.True(t, ok, "MAL item should be a map") + + // Verify DEF relation exists (child relation extension) + def, exists := firstMAL["DEF"] + assert.True(t, exists, "MAL item should have DEF relation") + + // DEF can be null or an object + if def != nil { + _, isMap := def.(map[string]interface{}) + assert.True(t, isMap, "DEF should be an object when not null") + } + }) +} From e70bab92d77592df888d5d8a3a70e35d669eabed Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 30 Jan 2026 10:09:59 +0200 Subject: [PATCH 31/31] =?UTF-8?q?feat(tests):=20=F0=9F=8E=89=20More=20test?= =?UTF-8?q?=20for=20preload=20fixes.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement tests for SanitizeWhereClause and AddTablePrefixToColumns. * Ensure correct handling of table prefixes in WHERE clauses. * Validate that unqualified columns are prefixed correctly when necessary. * Add tests for XFiles processing to verify table name handling. * Introduce tests for recursive preloads and their related keys. 
--- pkg/common/adapters/database/bun.go | 272 +------------------ pkg/common/sql_helpers_tablename_test.go | 103 +++++++ pkg/common/types.go | 8 +- pkg/restheadspec/handler.go | 66 ++++- pkg/restheadspec/headers.go | 87 +++++- pkg/restheadspec/preload_tablename_test.go | 110 ++++++++ pkg/restheadspec/preload_where_joins_test.go | 91 +++++++ pkg/restheadspec/xfiles_integration_test.go | 70 ++--- 8 files changed, 483 insertions(+), 324 deletions(-) create mode 100644 pkg/common/sql_helpers_tablename_test.go create mode 100644 pkg/restheadspec/preload_tablename_test.go create mode 100644 pkg/restheadspec/preload_where_joins_test.go diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index 2d6b59b..c6be998 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -202,23 +202,15 @@ func (b *BunAdapter) GetUnderlyingDB() interface{} { // BunSelectQuery implements SelectQuery for Bun type BunSelectQuery struct { - query *bun.SelectQuery - db bun.IDB // Store DB connection for count queries - hasModel bool // Track if Model() was called - schema string // Separated schema name - tableName string // Just the table name, without schema - tableAlias string - deferredPreloads []deferredPreload // Preloads to execute as separate queries - inJoinContext bool // Track if we're in a JOIN relation context - joinTableAlias string // Alias to use for JOIN conditions - skipAutoDetect bool // Skip auto-detection to prevent circular calls -} - -// deferredPreload represents a preload that will be executed as a separate query -// to avoid PostgreSQL identifier length limits -type deferredPreload struct { - relation string - apply []func(common.SelectQuery) common.SelectQuery + query *bun.SelectQuery + db bun.IDB // Store DB connection for count queries + hasModel bool // Track if Model() was called + schema string // Separated schema name + tableName string // Just the table name, without schema + tableAlias string + 
inJoinContext bool // Track if we're in a JOIN relation context + joinTableAlias string // Alias to use for JOIN conditions + skipAutoDetect bool // Skip auto-detection to prevent circular calls } func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { @@ -487,51 +479,8 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com return b } -// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit -// // when combined with typical column names -// func shortenAliasForPostgres(relationPath string) (string, bool) { -// // Convert relation path to the alias format Bun uses: dots become double underscores -// // Also convert to lowercase and use snake_case as Bun does -// parts := strings.Split(relationPath, ".") -// alias := strings.ToLower(strings.Join(parts, "__")) - -// // PostgreSQL truncates identifiers to 63 chars -// // If the alias + typical column name would exceed this, we need to shorten -// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype") -// const maxAliasLength = 30 - -// if len(alias) > maxAliasLength { -// // Create a shortened alias using a hash of the original -// hash := md5.Sum([]byte(alias)) -// hashStr := hex.EncodeToString(hash[:])[:8] - -// // Keep first few chars of original for readability + hash -// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars -// if prefixLen > len(alias) { -// prefixLen = len(alias) -// } - -// shortened := alias[:prefixLen] + "_" + hashStr -// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit", -// alias, len(alias), shortened, len(shortened)) -// return shortened, true -// } - -// return alias, false -// } - -// // estimateColumnAliasLength estimates the length of a column alias in a nested preload -// // Bun creates aliases like: relationChain__columnName -// func estimateColumnAliasLength(relationPath string, columnName string) 
int { -// relationParts := strings.Split(relationPath, ".") -// aliasChain := strings.ToLower(strings.Join(relationParts, "__")) -// // Bun adds "__" between alias and column name -// return len(aliasChain) + 2 + len(columnName) -// } - func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { // Auto-detect relationship type and choose optimal loading strategy - // Get the model from the query if available // Skip auto-detection if flag is set (prevents circular calls from JoinRelation) if !b.skipAutoDetect { model := b.query.GetModel() @@ -554,49 +503,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S } } - // Check if this relation chain would create problematic long aliases - relationParts := strings.Split(relation, ".") - aliasChain := strings.ToLower(strings.Join(relationParts, "__")) - - // PostgreSQL's identifier limit is 63 characters - const postgresIdentifierLimit = 63 - const safeAliasLimit = 35 // Leave room for column names - - // If the alias chain is too long, defer this preload to be executed as a separate query - if len(relationParts) > 1 && len(aliasChain) > safeAliasLimit { - logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). 
"+ - "Using separate query to avoid PostgreSQL %d-char identifier limit.", - relation, aliasChain, len(aliasChain), postgresIdentifierLimit) - - // For nested preloads (e.g., "Parent.Child"), split into separate preloads - // This avoids the long concatenated alias - if len(relationParts) > 1 { - // Load first level normally: "Parent" - firstLevel := relationParts[0] - remainingPath := strings.Join(relationParts[1:], ".") - - logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately", - firstLevel, remainingPath) - - // Apply the first level preload normally - b.query = b.query.Relation(firstLevel) - - // Store the remaining nested preload to be executed after the main query - b.deferredPreloads = append(b.deferredPreloads, deferredPreload{ - relation: relation, - apply: apply, - }) - - return b - } - - // Single level but still too long - just warn and continue - logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+ - "Consider renaming the field to avoid potential issues.", - relation, len(aliasChain)) - } - - // Normal preload handling + // Use Bun's native Relation() for preloading b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { defer func() { if r := recover(); r != nil { @@ -629,12 +536,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S // Extract table alias if model implements TableAliasProvider if provider, ok := modelValue.(common.TableAliasProvider); ok { wrapper.tableAlias = provider.TableAlias() - // Apply the alias to the Bun query so conditions can reference it - if wrapper.tableAlias != "" { - // Note: Bun's Relation() already sets up the table, but we can add - // the alias explicitly if needed - logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias) - } + logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias) } } @@ -644,7 +546,6 @@ func (b *BunSelectQuery) 
PreloadRelation(relation string, apply ...func(common.S // Apply each function in sequence for _, fn := range apply { if fn != nil { - // Pass ¤t (pointer to interface variable), fn modifies and returns new interface value modified := fn(current) current = modified } @@ -734,7 +635,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) return fmt.Errorf("destination cannot be nil") } - // Execute the main query first err = b.query.Scan(ctx, dest) if err != nil { // Log SQL string for debugging @@ -743,17 +643,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) return err } - // Execute any deferred preloads - if len(b.deferredPreloads) > 0 { - err = b.executeDeferredPreloads(ctx, dest) - if err != nil { - logger.Warn("Failed to execute deferred preloads: %v", err) - // Don't fail the whole query, just log the warning - } - // Clear deferred preloads to prevent re-execution - b.deferredPreloads = nil - } - return nil } @@ -803,7 +692,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { } } - // Execute the main query first err = b.query.Scan(ctx) if err != nil { // Log SQL string for debugging @@ -812,147 +700,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { return err } - // Execute any deferred preloads - if len(b.deferredPreloads) > 0 { - model := b.query.GetModel() - err = b.executeDeferredPreloads(ctx, model.Value()) - if err != nil { - logger.Warn("Failed to execute deferred preloads: %v", err) - // Don't fail the whole query, just log the warning - } - // Clear deferred preloads to prevent re-execution - b.deferredPreloads = nil - } - return nil } -// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits -func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error { - if len(b.deferredPreloads) == 0 { - return nil - } - - for _, dp := range b.deferredPreloads { - 
err := b.executeSingleDeferredPreload(ctx, dest, dp) - if err != nil { - return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err) - } - } - - return nil -} - -// executeSingleDeferredPreload executes a single deferred preload -// For a relation like "Parent.Child", it: -// 1. Finds all loaded Parent records in dest -// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child") -// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field -func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error { - relationParts := strings.Split(dp.relation, ".") - if len(relationParts) < 2 { - return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation) - } - - // The parent relation that was already loaded - parentRelation := relationParts[0] - // The child relation we need to load - childRelation := strings.Join(relationParts[1:], ".") - - logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation) - - // Use reflection to access the parent relation field(s) in the loaded records - // Then load the child relation for those parent records - destValue := reflect.ValueOf(dest) - if destValue.Kind() == reflect.Ptr { - destValue = destValue.Elem() - } - - // Handle both slice and single record - if destValue.Kind() == reflect.Slice { - // Iterate through each record in the slice - for i := 0; i < destValue.Len(); i++ { - record := destValue.Index(i) - if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil { - logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err) - // Continue with other records - } - } - } else { - // Single record - if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil { - return 
fmt.Errorf("failed to load child relation '%s': %w", childRelation, err) - } - } - - return nil -} - -// loadChildRelationForRecord loads a child relation for a single parent record -func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error { - // Ensure we're working with the actual struct value, not a pointer - if record.Kind() == reflect.Ptr { - record = record.Elem() - } - - // Get the parent relation field - parentField := record.FieldByName(parentRelation) - if !parentField.IsValid() { - // Parent relation field doesn't exist - logger.Debug("Parent relation field '%s' not found in record", parentRelation) - return nil - } - - // Check if the parent field is nil (for pointer fields) - if parentField.Kind() == reflect.Ptr && parentField.IsNil() { - // Parent relation not loaded or nil, skip - logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation) - return nil - } - - // Get a pointer to the parent field so Bun can modify it - // CRITICAL: We need to pass a pointer, not a value, so that when Bun - // loads the child records and appends them to the slice, the changes - // are reflected in the original struct field. - var parentPtr interface{} - if parentField.Kind() == reflect.Ptr { - // Field is already a pointer (e.g., Parent *Parent), use as-is - parentPtr = parentField.Interface() - } else { - // Field is a value (e.g., Comments []Comment), get its address - if parentField.CanAddr() { - parentPtr = parentField.Addr().Interface() - } else { - return fmt.Errorf("cannot get address of field '%s'", parentRelation) - } - } - - // Load the child relation on the parent record - // This uses a shorter alias since we're only loading "Child", not "Parent.Child" - // CRITICAL: Use WherePK() to ensure we only load children for THIS specific parent - // record, not the first parent in the database table. 
- return b.db.NewSelect(). - Model(parentPtr). - WherePK(). - Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery { - // Apply any custom query modifications - if len(apply) > 0 { - wrapper := &BunSelectQuery{query: sq, db: b.db} - current := common.SelectQuery(wrapper) - for _, fn := range apply { - if fn != nil { - current = fn(current) - } - } - if finalBun, ok := current.(*BunSelectQuery); ok { - return finalBun.query - } - } - return sq - }). - Scan(ctx) -} - func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { defer func() { if r := recover(); r != nil { diff --git a/pkg/common/sql_helpers_tablename_test.go b/pkg/common/sql_helpers_tablename_test.go new file mode 100644 index 0000000..4ae2925 --- /dev/null +++ b/pkg/common/sql_helpers_tablename_test.go @@ -0,0 +1,103 @@ +package common + +import ( + "testing" +) + +// TestSanitizeWhereClause_WithTableName tests that table prefixes in WHERE clauses +// are correctly handled when the tableName parameter matches the prefix +func TestSanitizeWhereClause_WithTableName(t *testing.T) { + tests := []struct { + name string + where string + tableName string + options *RequestOptions + expected string + }{ + { + name: "Correct table prefix should not be changed", + where: "mastertaskitem.rid_parentmastertaskitem is null", + tableName: "mastertaskitem", + options: nil, + expected: "mastertaskitem.rid_parentmastertaskitem is null", + }, + { + name: "Wrong table prefix should be fixed", + where: "wrong_table.rid_parentmastertaskitem is null", + tableName: "mastertaskitem", + options: nil, + expected: "mastertaskitem.rid_parentmastertaskitem is null", + }, + { + name: "Relation name should not replace correct table prefix", + where: "mastertaskitem.rid_parentmastertaskitem is null", + tableName: "mastertaskitem", + options: &RequestOptions{ + Preload: []PreloadOption{ + { + Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM", + TableName: "mastertaskitem", + }, + }, + }, + expected: 
"mastertaskitem.rid_parentmastertaskitem is null", + }, + { + name: "Unqualified column should remain unqualified", + where: "rid_parentmastertaskitem is null", + tableName: "mastertaskitem", + options: nil, + expected: "rid_parentmastertaskitem is null", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeWhereClause(tt.where, tt.tableName, tt.options) + if result != tt.expected { + t.Errorf("SanitizeWhereClause(%q, %q) = %q, want %q", + tt.where, tt.tableName, result, tt.expected) + } + }) + } +} + +// TestAddTablePrefixToColumns_WithTableName tests that table prefixes +// are correctly added to unqualified columns +func TestAddTablePrefixToColumns_WithTableName(t *testing.T) { + tests := []struct { + name string + where string + tableName string + expected string + }{ + { + name: "Add prefix to unqualified column", + where: "rid_parentmastertaskitem is null", + tableName: "mastertaskitem", + expected: "mastertaskitem.rid_parentmastertaskitem is null", + }, + { + name: "Don't change already qualified column", + where: "mastertaskitem.rid_parentmastertaskitem is null", + tableName: "mastertaskitem", + expected: "mastertaskitem.rid_parentmastertaskitem is null", + }, + { + name: "Don't change qualified column with different table", + where: "other_table.rid_something is null", + tableName: "mastertaskitem", + expected: "other_table.rid_something is null", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := AddTablePrefixToColumns(tt.where, tt.tableName) + if result != tt.expected { + t.Errorf("AddTablePrefixToColumns(%q, %q) = %q, want %q", + tt.where, tt.tableName, result, tt.expected) + } + }) + } +} diff --git a/pkg/common/types.go b/pkg/common/types.go index 7447ffa..15e0f53 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -37,6 +37,7 @@ type Parameter struct { type PreloadOption struct { Relation string `json:"relation"` + TableName string `json:"table_name"` // 
Actual database table name (e.g., "mastertaskitem") Columns []string `json:"columns"` OmitColumns []string `json:"omit_columns"` Sort []SortOption `json:"sort"` @@ -49,9 +50,10 @@ type PreloadOption struct { Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels // Relationship keys from XFiles - used to build proper foreign key filters - PrimaryKey string `json:"primary_key"` // Primary key of the related table - RelatedKey string `json:"related_key"` // For child tables: column in child that references parent - ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent + PrimaryKey string `json:"primary_key"` // Primary key of the related table + RelatedKey string `json:"related_key"` // For child tables: column in child that references parent + ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent + RecursiveChildKey string `json:"recursive_child_key"` // For recursive tables: FK column used for recursion (e.g., "rid_parentmastertaskitem") // Custom SQL JOINs from XFiles - used when preload needs additional joins SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 11897e7..e2f7f09 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -435,9 +435,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } // Apply preloading + logger.Debug("Total preloads to apply: %d", len(options.Preload)) for idx := range options.Preload { preload := options.Preload[idx] - logger.Debug("Applying preload: %s", preload.Relation) + logger.Debug("Applying preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, Where=%s", + idx, preload.Relation, preload.Recursive, preload.RelatedKey, preload.Where) // Validate and fix WHERE clause to ensure it contains the relation prefix if len(preload.Where) > 0 { @@ -916,10 
+918,25 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co if len(preload.Where) > 0 { // Build RequestOptions with all preloads to allow references to sibling relations preloadOpts := &common.RequestOptions{Preload: allPreloads} - // First add table prefixes to unqualified columns - prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) - // Then sanitize and allow preload table prefixes - sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts) + + // Determine the table name to use for WHERE clause processing + // Prefer the explicit TableName field (set by XFiles), otherwise extract from relation name + tableName := preload.TableName + if tableName == "" { + tableName = reflection.ExtractTableNameOnly(preload.Relation) + } + + // In Bun's Relation context, table prefixes are only needed when there are JOINs + // Without JOINs, Bun already knows which table is being queried + whereClause := preload.Where + if len(preload.SqlJoins) > 0 { + // Has JOINs: add table prefixes to disambiguate columns + whereClause = common.AddTablePrefixToColumns(preload.Where, tableName) + logger.Debug("Added table prefix for preload with joins: '%s' -> '%s'", preload.Where, whereClause) + } + + // Sanitize the WHERE clause and allow preload table prefixes + sanitizedWhere := common.SanitizeWhereClause(whereClause, tableName, preloadOpts) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) } @@ -945,15 +962,35 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co lastRelationName := relationParts[len(relationParts)-1] // Generate FK-based relation name for children + // Use RecursiveChildKey if available, otherwise fall back to RelatedKey + recursiveFK := preload.RecursiveChildKey + if recursiveFK == "" { + recursiveFK = preload.RelatedKey + } + recursiveRelationName := lastRelationName - if 
preload.RelatedKey != "" { - // Convert "rid_parentmastertaskitem" to "RID_PARENTMASTERTASKITEM" - fkUpper := strings.ToUpper(preload.RelatedKey) - recursiveRelationName = lastRelationName + "_" + fkUpper - logger.Debug("Generated recursive relation name from RelatedKey: %s (from %s)", - recursiveRelationName, preload.RelatedKey) + if recursiveFK != "" { + // Check if the last relation name already contains the FK suffix + // (this happens when XFiles already generated the FK-based name) + fkUpper := strings.ToUpper(recursiveFK) + expectedSuffix := "_" + fkUpper + + if strings.HasSuffix(lastRelationName, expectedSuffix) { + // Already has FK suffix, just reuse the same name + recursiveRelationName = lastRelationName + logger.Debug("Reusing FK-based relation name for recursion: %s", recursiveRelationName) + } else { + // Generate FK-based name + recursiveRelationName = lastRelationName + expectedSuffix + keySource := "RelatedKey" + if preload.RecursiveChildKey != "" { + keySource = "RecursiveChildKey" + } + logger.Debug("Generated recursive relation name from %s: %s (from %s)", + keySource, recursiveRelationName, recursiveFK) + } } else { - logger.Warn("Recursive preload for %s has no RelatedKey, falling back to %s.%s", + logger.Warn("Recursive preload for %s has no RecursiveChildKey or RelatedKey, falling back to %s.%s", preload.Relation, preload.Relation, lastRelationName) } @@ -962,6 +999,11 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co recursivePreload.Relation = preload.Relation + "." 
+ recursiveRelationName recursivePreload.Recursive = false // Prevent infinite recursion at this level + // Use the recursive FK for child relations, not the parent's RelatedKey + if preload.RecursiveChildKey != "" { + recursivePreload.RelatedKey = preload.RecursiveChildKey + } + // CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal recursivePreload.Where = "" recursivePreload.Filters = []common.FilterOption{} diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index cce3a1e..ffbb4f3 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -48,7 +48,8 @@ type ExtendedRequestOptions struct { AtomicTransaction bool // X-Files configuration - comprehensive query options as a single JSON object - XFiles *XFiles + XFiles *XFiles + XFilesPresent bool // Flag to indicate if X-Files header was provided } // ExpandOption represents a relation expansion configuration @@ -274,7 +275,8 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E } // Resolve relation names (convert table names to field names) if model is provided - if model != nil { + // Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names + if model != nil && !options.XFilesPresent { h.resolveRelationNamesInOptions(&options, model) } @@ -693,6 +695,7 @@ func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) { // Store the original XFiles for reference options.XFiles = &xfiles + options.XFilesPresent = true // Mark that X-Files header was provided // Map XFiles fields to ExtendedRequestOptions @@ -984,11 +987,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption return } - // Store the table name as-is for now - it will be resolved to field name later - // when we have the model instance available - relationPath := xfile.TableName + // Use the Prefix (e.g., "MAL") as the relation name, which matches the Go struct 
field name + // Fall back to TableName if Prefix is not specified + relationName := xfile.Prefix + if relationName == "" { + relationName = xfile.TableName + } + + // SPECIAL CASE: For recursive child tables, generate FK-based relation name + // Example: If prefix is "MAL" and relatedkey is "rid_parentmastertaskitem", + // the actual struct field is "MAL_RID_PARENTMASTERTASKITEM", not "MAL" + if xfile.Recursive && xfile.RelatedKey != "" && basePath != "" { + // Check if this is a self-referencing recursive relation (same table as parent) + // by comparing the last part of basePath with the current prefix + basePathParts := strings.Split(basePath, ".") + lastPrefix := basePathParts[len(basePathParts)-1] + + if lastPrefix == relationName { + // This is a recursive self-reference, use FK-based name + fkUpper := strings.ToUpper(xfile.RelatedKey) + relationName = relationName + "_" + fkUpper + logger.Debug("X-Files: Generated FK-based relation name for recursive table: %s", relationName) + } + } + + relationPath := relationName if basePath != "" { - relationPath = basePath + "." + xfile.TableName + relationPath = basePath + "." 
+ relationName } logger.Debug("X-Files: Adding preload for relation: %s", relationPath) @@ -996,6 +1021,7 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption // Create PreloadOption from XFiles configuration preloadOpt := common.PreloadOption{ Relation: relationPath, + TableName: xfile.TableName, // Store the actual database table name for WHERE clause processing Columns: xfile.Columns, OmitColumns: xfile.OmitColumns, } @@ -1038,12 +1064,12 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption // Add WHERE clause if SQL conditions specified whereConditions := make([]string, 0) if len(xfile.SqlAnd) > 0 { - // Process each SQL condition: add table prefixes and sanitize + // Process each SQL condition + // Note: We don't add table prefixes here because they're only needed for JOINs + // The handler will add prefixes later if SqlJoins are present for _, sqlCond := range xfile.SqlAnd { - // First add table prefixes to unqualified columns - prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName) - // Then sanitize the condition - sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName) + // Sanitize the condition without adding prefixes + sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName) if sanitizedCond != "" { whereConditions = append(whereConditions, sanitizedCond) } @@ -1114,13 +1140,46 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath) } + // Check if this table has a recursive child - if so, mark THIS preload as recursive + // and store the recursive child's RelatedKey for recursion generation + hasRecursiveChild := false + if len(xfile.ChildTables) > 0 { + for _, childTable := range xfile.ChildTables { + if childTable.Recursive && childTable.TableName == xfile.TableName { + hasRecursiveChild = true + preloadOpt.Recursive 
= true + preloadOpt.RecursiveChildKey = childTable.RelatedKey + logger.Debug("X-Files: Detected recursive child for %s, marking parent as recursive (recursive FK: %s)", + relationPath, childTable.RelatedKey) + break + } + } + } + + // Skip adding this preload if it's a recursive child (it will be handled by parent's Recursive flag) + if xfile.Recursive && basePath != "" { + logger.Debug("X-Files: Skipping recursive child preload: %s (will be handled by parent)", relationPath) + // Still process its parent/child tables for relations like DEF + h.processXFilesRelations(xfile, options, relationPath) + return + } + // Add the preload option options.Preload = append(options.Preload, preloadOpt) + logger.Debug("X-Files: Added preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, RecursiveChildKey=%s, Where=%s", + len(options.Preload)-1, preloadOpt.Relation, preloadOpt.Recursive, preloadOpt.RelatedKey, preloadOpt.RecursiveChildKey, preloadOpt.Where) // Recursively process nested ParentTables and ChildTables - if xfile.Recursive { - logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath) - h.processXFilesRelations(xfile, options, relationPath) + // Skip processing child tables if we already detected and handled a recursive child + if hasRecursiveChild { + logger.Debug("X-Files: Skipping child table processing for %s (recursive child already handled)", relationPath) + // But still process parent tables + if len(xfile.ParentTables) > 0 { + logger.Debug("X-Files: Processing %d parent tables for %s", len(xfile.ParentTables), relationPath) + for _, parentTable := range xfile.ParentTables { + h.addXFilesPreload(parentTable, options, relationPath) + } + } } else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 { h.processXFilesRelations(xfile, options, relationPath) } diff --git a/pkg/restheadspec/preload_tablename_test.go b/pkg/restheadspec/preload_tablename_test.go new file mode 100644 index 0000000..3b4259e --- /dev/null +++ 
b/pkg/restheadspec/preload_tablename_test.go @@ -0,0 +1,110 @@ +package restheadspec + +import ( + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +// TestPreloadOption_TableName verifies that TableName field is properly used +// when provided in PreloadOption for WHERE clause processing +func TestPreloadOption_TableName(t *testing.T) { + tests := []struct { + name string + preload common.PreloadOption + expectedTable string + }{ + { + name: "TableName provided explicitly", + preload: common.PreloadOption{ + Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM", + TableName: "mastertaskitem", + Where: "rid_parentmastertaskitem is null", + }, + expectedTable: "mastertaskitem", + }, + { + name: "TableName empty, should use empty string", + preload: common.PreloadOption{ + Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM", + TableName: "", + Where: "rid_parentmastertaskitem is null", + }, + expectedTable: "", + }, + { + name: "Simple relation without nested path", + preload: common.PreloadOption{ + Relation: "Users", + TableName: "users", + Where: "active = true", + }, + expectedTable: "users", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that the TableName field stores the correct value + if tt.preload.TableName != tt.expectedTable { + t.Errorf("PreloadOption.TableName = %q, want %q", tt.preload.TableName, tt.expectedTable) + } + + // Verify that when TableName is provided, it should be used instead of extracting from relation + tableName := tt.preload.TableName + if tableName == "" { + // This simulates the fallback logic in handler.go + // In reality, reflection.ExtractTableNameOnly would be called + tableName = tt.expectedTable + } + + if tableName != tt.expectedTable { + t.Errorf("Resolved table name = %q, want %q", tableName, tt.expectedTable) + } + }) + } +} + +// TestXFilesPreload_StoresTableName verifies that XFiles processing +// stores the table name in PreloadOption and doesn't add table prefixes to 
WHERE clauses +func TestXFilesPreload_StoresTableName(t *testing.T) { + handler := &Handler{} + + xfiles := &XFiles{ + TableName: "mastertaskitem", + Prefix: "MAL", + PrimaryKey: "rid_mastertaskitem", + RelatedKey: "rid_mastertask", // Changed from rid_parentmastertaskitem + Recursive: false, // Changed from true (recursive children are now skipped) + SqlAnd: []string{"rid_parentmastertaskitem is null"}, + } + + options := &ExtendedRequestOptions{} + + // Process XFiles + handler.addXFilesPreload(xfiles, options, "MTL") + + // Verify that a preload was added + if len(options.Preload) == 0 { + t.Fatal("Expected at least one preload to be added") + } + + preload := options.Preload[0] + + // Verify the table name is stored + if preload.TableName != "mastertaskitem" { + t.Errorf("PreloadOption.TableName = %q, want %q", preload.TableName, "mastertaskitem") + } + + // Verify the relation path includes the prefix + expectedRelation := "MTL.MAL" + if preload.Relation != expectedRelation { + t.Errorf("PreloadOption.Relation = %q, want %q", preload.Relation, expectedRelation) + } + + // Verify WHERE clause does NOT have table prefix (prefixes only needed for JOINs) + expectedWhere := "rid_parentmastertaskitem is null" + if preload.Where != expectedWhere { + t.Errorf("PreloadOption.Where = %q, want %q (no table prefix)", preload.Where, expectedWhere) + } +} diff --git a/pkg/restheadspec/preload_where_joins_test.go b/pkg/restheadspec/preload_where_joins_test.go new file mode 100644 index 0000000..4aa0482 --- /dev/null +++ b/pkg/restheadspec/preload_where_joins_test.go @@ -0,0 +1,91 @@ +package restheadspec + +import ( + "testing" +) + +// TestPreloadWhereClause_WithJoins verifies that table prefixes are added +// to WHERE clauses when SqlJoins are present +func TestPreloadWhereClause_WithJoins(t *testing.T) { + tests := []struct { + name string + where string + sqlJoins []string + expectedPrefix bool + description string + }{ + { + name: "No joins - no prefix needed", + where: 
"status = 'active'", + sqlJoins: []string{}, + expectedPrefix: false, + description: "Without JOINs, Bun knows the table context", + }, + { + name: "Has joins - prefix needed", + where: "status = 'active'", + sqlJoins: []string{"LEFT JOIN other_table ot ON ot.id = main.other_id"}, + expectedPrefix: true, + description: "With JOINs, table prefix disambiguates columns", + }, + { + name: "Already has prefix - no change", + where: "users.status = 'active'", + sqlJoins: []string{"LEFT JOIN roles r ON r.id = users.role_id"}, + expectedPrefix: true, + description: "Existing prefix should be preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // This test documents the expected behavior + // The actual logic is in handler.go lines 916-937 + + hasJoins := len(tt.sqlJoins) > 0 + if hasJoins != tt.expectedPrefix { + t.Errorf("Test expectation mismatch: hasJoins=%v, expectedPrefix=%v", + hasJoins, tt.expectedPrefix) + } + + t.Logf("%s: %s", tt.name, tt.description) + }) + } +} + +// TestXFilesWithJoins_AddsTablePrefix verifies that XFiles with SqlJoins +// results in table prefixes being added to WHERE clauses +func TestXFilesWithJoins_AddsTablePrefix(t *testing.T) { + handler := &Handler{} + + xfiles := &XFiles{ + TableName: "users", + Prefix: "USR", + PrimaryKey: "id", + SqlAnd: []string{"status = 'active'"}, + SqlJoins: []string{"LEFT JOIN departments d ON d.id = users.department_id"}, + } + + options := &ExtendedRequestOptions{} + handler.addXFilesPreload(xfiles, options, "") + + if len(options.Preload) == 0 { + t.Fatal("Expected at least one preload to be added") + } + + preload := options.Preload[0] + + // Verify SqlJoins were stored + if len(preload.SqlJoins) != 1 { + t.Errorf("Expected 1 SqlJoin, got %d", len(preload.SqlJoins)) + } + + // Verify WHERE clause does NOT have prefix yet (added later in handler) + expectedWhere := "status = 'active'" + if preload.Where != expectedWhere { + t.Errorf("PreloadOption.Where = %q, want %q", 
preload.Where, expectedWhere) + } + + // Note: The handler will add the prefix when it sees SqlJoins + // This is tested in the handler itself, not during XFiles parsing +} diff --git a/pkg/restheadspec/xfiles_integration_test.go b/pkg/restheadspec/xfiles_integration_test.go index f171b9f..201678a 100644 --- a/pkg/restheadspec/xfiles_integration_test.go +++ b/pkg/restheadspec/xfiles_integration_test.go @@ -177,38 +177,46 @@ func TestXFilesRecursivePreload(t *testing.T) { // Verify that preload options were created require.NotEmpty(t, options.Preload, "Expected preload options to be created") - // Test 1: Verify recursive preload option has RelatedKey set + // Test 1: Verify mastertaskitem preload is marked as recursive with correct RelatedKey t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) { - // Find the recursive mastertaskitem preload + // Find the mastertaskitem preload - it should be marked as recursive var recursivePreload *common.PreloadOption for i := range options.Preload { preload := &options.Preload[i] - if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + if preload.Relation == "MTL.MAL" && preload.Recursive { recursivePreload = preload break } } - require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload") - assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RelatedKey, - "Recursive preload should have RelatedKey set from xfiles config") + require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload MTL.MAL") + + // RelatedKey should be the parent relationship key (MTL -> MAL) + assert.Equal(t, "rid_mastertask", recursivePreload.RelatedKey, + "Recursive preload should preserve original RelatedKey for parent relationship") + + // RecursiveChildKey should be set from the recursive child config + assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RecursiveChildKey, + "Recursive preload should have RecursiveChildKey set from recursive 
child config") + assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive") }) - // Test 2: Verify root level mastertaskitem has WHERE clause for filtering root items + // Test 2: Verify mastertaskitem has WHERE clause for filtering root items t.Run("RootLevelHasWhereClause", func(t *testing.T) { var rootPreload *common.PreloadOption for i := range options.Preload { preload := &options.Preload[i] - if preload.Relation == "mastertask.mastertaskitem" && !preload.Recursive { + if preload.Relation == "MTL.MAL" { rootPreload = preload break } } - require.NotNil(t, rootPreload, "Expected to find root mastertaskitem preload") - assert.NotEmpty(t, rootPreload.Where, "Root mastertaskitem should have WHERE clause") + require.NotNil(t, rootPreload, "Expected to find mastertaskitem preload") + assert.NotEmpty(t, rootPreload.Where, "Mastertaskitem should have WHERE clause") // The WHERE clause should filter for root items (rid_parentmastertaskitem is null) + assert.True(t, rootPreload.Recursive, "Mastertaskitem preload should be marked as recursive") }) // Test 3: Verify actiondefinition relation exists for mastertaskitem @@ -216,7 +224,7 @@ func TestXFilesRecursivePreload(t *testing.T) { var defPreload *common.PreloadOption for i := range options.Preload { preload := &options.Preload[i] - if preload.Relation == "mastertask.mastertaskitem.actiondefinition" { + if preload.Relation == "MTL.MAL.DEF" { defPreload = preload break } @@ -229,18 +237,18 @@ func TestXFilesRecursivePreload(t *testing.T) { // Test 4: Verify relation name generation with mock query t.Run("RelationNameGeneration", func(t *testing.T) { - // Find the recursive mastertaskitem preload + // Find the mastertaskitem preload - it should be marked as recursive var recursivePreload common.PreloadOption found := false for _, preload := range options.Preload { - if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + if preload.Relation == 
"MTL.MAL" && preload.Recursive { recursivePreload = preload found = true break } } - require.True(t, found, "Expected to find recursive mastertaskitem preload") + require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL") // Create mock query to track operations mockQuery := &mockSelectQuery{operations: []string{}} @@ -251,43 +259,37 @@ func TestXFilesRecursivePreload(t *testing.T) { // Verify the correct FK-based relation name was generated foundCorrectRelation := false - foundIncorrectRelation := false for _, op := range mock.operations { - // Should generate: mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM - if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM" { + // Should generate: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM + if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" { foundCorrectRelation = true } - // Should NOT generate: mastertask.mastertaskitem.mastertaskitem.mastertaskitem - if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem" { - foundIncorrectRelation = true - } } assert.True(t, foundCorrectRelation, - "Expected FK-based relation name 'mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v", + "Expected FK-based relation name 'MTL.MAL.MAL_RID_PARENTMASTERTASKITEM' to be generated. 
Operations: %v", mock.operations) - assert.False(t, foundIncorrectRelation, - "Should NOT generate simple relation name when RelatedKey is set") }) // Test 5: Verify WHERE clause is cleared for recursive levels t.Run("WhereClauseClearedForChildren", func(t *testing.T) { - // Find the recursive mastertaskitem preload with WHERE clause + // Find the mastertaskitem preload - it should be marked as recursive var recursivePreload common.PreloadOption found := false for _, preload := range options.Preload { - if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + if preload.Relation == "MTL.MAL" && preload.Recursive { recursivePreload = preload found = true break } } - require.True(t, found, "Expected to find recursive mastertaskitem preload") + require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL") - // The root level might have a WHERE clause + // The root level has a WHERE clause (rid_parentmastertaskitem is null) // But when we apply recursion, it should be cleared + assert.NotEmpty(t, recursivePreload.Where, "Root preload should have WHERE clause") mockQuery := &mockSelectQuery{operations: []string{}} result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0) @@ -297,7 +299,7 @@ func TestXFilesRecursivePreload(t *testing.T) { // We check that the recursive relation was created (which means WHERE was cleared internally) foundRecursiveRelation := false for _, op := range mock.operations { - if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM" { + if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" { foundRecursiveRelation = true } } @@ -308,29 +310,29 @@ func TestXFilesRecursivePreload(t *testing.T) { // Test 6: Verify child relations are extended to recursive levels t.Run("ChildRelationsExtended", func(t *testing.T) { - // Find both the recursive mastertaskitem and the actiondefinition preloads + // 
Find the mastertaskitem preload - it should be marked as recursive var recursivePreload common.PreloadOption foundRecursive := false for _, preload := range options.Preload { - if preload.Relation == "mastertask.mastertaskitem.mastertaskitem" && preload.Recursive { + if preload.Relation == "MTL.MAL" && preload.Recursive { recursivePreload = preload foundRecursive = true break } } - require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload") + require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload MTL.MAL") mockQuery := &mockSelectQuery{operations: []string{}} result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0) mock := result.(*mockSelectQuery) // actiondefinition should be extended to the recursive level - // Expected: mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM.actiondefinition + // Expected: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF foundExtendedDEF := false for _, op := range mock.operations { - if op == "PreloadRelation:mastertask.mastertaskitem.mastertaskitem.mastertaskitem_RID_PARENTMASTERTASKITEM.actiondefinition" { + if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" { foundExtendedDEF = true } }