mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-06-05 13:23:46 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 549ccb8468 | |||
| 1af9c76337 | |||
| 938a2ef3d9 | |||
| 69cc3e2839 | |||
| 4018af0636 | |||
| c4e79d6950 | |||
| 982a0e62ac | |||
| 5d459c95a7 | |||
| e9f7726e43 |
@@ -1489,7 +1489,7 @@ func (b *BunInsertQuery) OnConflict(action string) common.InsertQuery {
|
||||
|
||||
func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
if len(columns) > 0 {
|
||||
b.query = b.query.Returning(columns[0])
|
||||
b.query = b.query.Returning(strings.Join(columns, ", "))
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -1606,7 +1606,7 @@ func (b *BunUpdateQuery) Where(query string, args ...interface{}) common.UpdateQ
|
||||
|
||||
func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||
if len(columns) > 0 {
|
||||
b.query = b.query.Returning(columns[0])
|
||||
b.query = b.query.Returning(strings.Join(columns, ", "))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -43,7 +44,7 @@ func (v *ColumnValidator) buildValidColumns() {
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
if !field.IsExported() {
|
||||
if !field.IsExported() || field.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -125,6 +126,16 @@ func (v *ColumnValidator) IsValidColumn(column string) bool {
|
||||
return v.ValidateColumn(column) == nil
|
||||
}
|
||||
|
||||
// Columns returns all valid column names known to this validator
|
||||
func (v *ColumnValidator) Columns() []string {
|
||||
cols := make([]string, 0, len(v.validColumns))
|
||||
for col := range v.validColumns {
|
||||
cols = append(cols, col)
|
||||
}
|
||||
sort.Strings(cols)
|
||||
return cols
|
||||
}
|
||||
|
||||
// FilterValidColumns filters a list of columns, returning only valid ones
|
||||
// Logs warnings for any invalid columns
|
||||
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
||||
@@ -224,7 +235,19 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
// Filter Filter columns
|
||||
validFilters := make([]FilterOption, 0, len(options.Filters))
|
||||
for _, filter := range options.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
if strings.EqualFold(filter.Column, "all") {
|
||||
allCols := v.Columns()
|
||||
if len(filtered.Columns) > 0 {
|
||||
allCols = filtered.Columns
|
||||
}
|
||||
for _, col := range allCols {
|
||||
expanded := filter
|
||||
expanded.Column = col
|
||||
expanded.LogicOperator = "OR"
|
||||
|
||||
validFilters = append(validFilters, expanded)
|
||||
}
|
||||
} else if v.IsValidColumn(filter.Column) {
|
||||
validFilters = append(validFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
||||
|
||||
@@ -174,7 +174,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
||||
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||
if val, ok := variables[varName]; ok {
|
||||
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue"))
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -533,7 +533,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
||||
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||
if val, ok := variables[varName]; ok {
|
||||
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue"))
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -1006,6 +1006,37 @@ func IsNumeric(s string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// isInsideDollarQuote reports whether the first occurrence of placeholder in sqlquery
|
||||
// is immediately surrounded by dollar-sign characters (i.e. inside a $...$-quoted string).
|
||||
// Dollar-quoted strings pass content through literally — no backslash processing — so
|
||||
// values placed there must NOT have their backslashes escaped.
|
||||
func isInsideDollarQuote(sqlquery, placeholder string) bool {
|
||||
idx := strings.Index(sqlquery, placeholder)
|
||||
if idx < 0 {
|
||||
return false
|
||||
}
|
||||
endIdx := idx + len(placeholder)
|
||||
charBefore := byte(0)
|
||||
charAfter := byte(0)
|
||||
if idx > 0 {
|
||||
charBefore = sqlquery[idx-1]
|
||||
}
|
||||
if endIdx < len(sqlquery) {
|
||||
charAfter = sqlquery[endIdx]
|
||||
}
|
||||
return charBefore == '$' || charAfter == '$'
|
||||
}
|
||||
|
||||
// safeSubstituteVar returns value sanitised for the quoting context that surrounds
|
||||
// placeholder in sqlquery: raw (no backslash escaping) for dollar-quoted contexts,
|
||||
// ValidSQL("colvalue") escaping for everything else.
|
||||
func safeSubstituteVar(sqlquery, placeholder, value string) string {
|
||||
if isInsideDollarQuote(sqlquery, placeholder) {
|
||||
return value
|
||||
}
|
||||
return ValidSQL(value, "colvalue")
|
||||
}
|
||||
|
||||
// getReplacementForBlankParam determines the replacement value for an unused parameter
|
||||
// based on whether it appears within quotes in the SQL query.
|
||||
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
|
||||
|
||||
@@ -836,7 +836,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// First, read the existing record from the database
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...)
|
||||
|
||||
// Apply conditions to select
|
||||
if urlID != "" {
|
||||
@@ -955,13 +955,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch the updated record after the transaction commits to capture any trigger changes
|
||||
updatedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
fetchQuery := h.db.NewSelect().Model(updatedRecord).Column(reflection.GetSQLModelColumns(model)...)
|
||||
if urlID != "" {
|
||||
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
if len(id) > 0 {
|
||||
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||
logger.Error("Failed to fetch updated record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated record(s)")
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, data, nil)
|
||||
h.sendResponse(w, updatedRecord, nil)
|
||||
|
||||
case []map[string]interface{}:
|
||||
// Batch update with array of objects
|
||||
@@ -1017,7 +1038,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
@@ -1089,13 +1110,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records", len(updates))
|
||||
|
||||
// Fetch updated records after the transaction commits to capture any trigger changes
|
||||
fetchedUpdates := make([]interface{}, 0, len(updates))
|
||||
for _, item := range updates {
|
||||
if itemID, ok := item["id"]; ok && itemID != nil {
|
||||
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||
return
|
||||
}
|
||||
fetchedUpdates = append(fetchedUpdates, fetchedRecord)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated %d records", len(fetchedUpdates))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, updates, nil)
|
||||
h.sendResponse(w, fetchedUpdates, nil)
|
||||
|
||||
case []interface{}:
|
||||
// Batch update with []interface{}
|
||||
@@ -1157,7 +1194,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
@@ -1232,13 +1269,31 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
|
||||
return
|
||||
}
|
||||
logger.Info("Successfully updated %d records", len(list))
|
||||
|
||||
// Fetch updated records after the transaction commits to capture any trigger changes
|
||||
fetchedList := make([]interface{}, 0, len(list))
|
||||
for _, item := range list {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if itemID, ok := itemMap["id"]; ok && itemID != nil {
|
||||
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := fetchQuery.ScanModel(ctx); err != nil {
|
||||
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||
return
|
||||
}
|
||||
fetchedList = append(fetchedList, fetchedRecord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated %d records", len(fetchedList))
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
|
||||
}
|
||||
h.sendResponse(w, list, nil)
|
||||
h.sendResponse(w, fetchedList, nil)
|
||||
|
||||
default:
|
||||
logger.Error("Invalid data type for update operation: %T", data)
|
||||
|
||||
+13
-14
@@ -1218,8 +1218,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
query = query.Returning("*")
|
||||
fields := reflection.GetSQLModelColumns(model)
|
||||
query = query.Returning(fields...)
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
itemHookCtx := &HookContext{
|
||||
@@ -1480,18 +1480,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the updated record to return the new values
|
||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||
selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
return fmt.Errorf("failed to fetch updated record: %w", err)
|
||||
}
|
||||
|
||||
updatedRecord = modelValue
|
||||
|
||||
// Store result for hooks
|
||||
hookCtx.Result = updatedRecord
|
||||
_ = result // Keep result variable for potential future use
|
||||
_ = result
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -1501,6 +1490,16 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch the updated record after the transaction commits to capture any trigger changes
|
||||
fetchedRecord := reflect.New(reflect.TypeOf(model)).Interface()
|
||||
selectQuery := h.db.NewSelect().Model(fetchedRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
logger.Error("Failed to fetch updated record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||
return
|
||||
}
|
||||
updatedRecord = fetchedRecord
|
||||
|
||||
// Merge the updated record with the original request data
|
||||
// This preserves extra keys from the request and updates values from the database
|
||||
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -140,9 +141,21 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
combinedParams[strings.ToLower(key)] = value
|
||||
}
|
||||
|
||||
sortedKeys := make([]string, 0, len(combinedParams))
|
||||
for key := range combinedParams {
|
||||
sortedKeys = append(sortedKeys, key)
|
||||
}
|
||||
sort.Slice(sortedKeys, func(i, j int) bool {
|
||||
if sortedKeys[i] != sortedKeys[j] {
|
||||
return sortedKeys[i] < sortedKeys[j]
|
||||
}
|
||||
return combinedParams[sortedKeys[i]] < combinedParams[sortedKeys[j]]
|
||||
})
|
||||
|
||||
// Process each parameter (from both headers and query params)
|
||||
// Note: keys are already normalized to lowercase in combinedParams
|
||||
for key, value := range combinedParams {
|
||||
for _, key := range sortedKeys {
|
||||
value := combinedParams[key]
|
||||
// Decode value if it's base64 encoded
|
||||
decodedValue := decodeHeaderValue(value)
|
||||
|
||||
|
||||
@@ -70,6 +70,25 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Try to open the file
|
||||
file, err := m.provider.Open(strings.TrimPrefix(filePath, "/"))
|
||||
if err != nil {
|
||||
// For extensionless paths, also try path/index.html
|
||||
if path.Ext(filePath) == "" {
|
||||
indexFallback := path.Join(filePath, "index.html")
|
||||
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
m.serveFile(w, r, indexFallback, file)
|
||||
return
|
||||
}
|
||||
|
||||
indexFallback = fmt.Sprintf("%s.html", filePath)
|
||||
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
m.serveFile(w, r, indexFallback, file)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// File doesn't exist - check if we should use fallback
|
||||
if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) {
|
||||
fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath)
|
||||
@@ -80,16 +99,6 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// For extensionless paths, also try path/index.html
|
||||
if path.Ext(filePath) == "" {
|
||||
indexFallback := path.Join(filePath, "index.html")
|
||||
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
m.serveFile(w, r, indexFallback, file)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No fallback or fallback failed - return 404
|
||||
|
||||
Reference in New Issue
Block a user