mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-06-04 12:53:45 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4018af0636 | |||
| c4e79d6950 | |||
| 982a0e62ac | |||
| 5d459c95a7 | |||
| e9f7726e43 | |||
| 3d2251317a | |||
| 1ce0ab1ab4 | |||
| 1f9b230f7f |
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -43,7 +44,7 @@ func (v *ColumnValidator) buildValidColumns() {
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
if !field.IsExported() {
|
||||
if !field.IsExported() || field.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -125,6 +126,16 @@ func (v *ColumnValidator) IsValidColumn(column string) bool {
|
||||
return v.ValidateColumn(column) == nil
|
||||
}
|
||||
|
||||
// Columns returns all valid column names known to this validator
|
||||
func (v *ColumnValidator) Columns() []string {
|
||||
cols := make([]string, 0, len(v.validColumns))
|
||||
for col := range v.validColumns {
|
||||
cols = append(cols, col)
|
||||
}
|
||||
sort.Strings(cols)
|
||||
return cols
|
||||
}
|
||||
|
||||
// FilterValidColumns filters a list of columns, returning only valid ones
|
||||
// Logs warnings for any invalid columns
|
||||
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
||||
@@ -224,7 +235,19 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
// Filter Filter columns
|
||||
validFilters := make([]FilterOption, 0, len(options.Filters))
|
||||
for _, filter := range options.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
if strings.EqualFold(filter.Column, "all") {
|
||||
allCols := v.Columns()
|
||||
if len(filtered.Columns) > 0 {
|
||||
allCols = filtered.Columns
|
||||
}
|
||||
for _, col := range allCols {
|
||||
expanded := filter
|
||||
expanded.Column = col
|
||||
expanded.LogicOperator = "OR"
|
||||
|
||||
validFilters = append(validFilters, expanded)
|
||||
}
|
||||
} else if v.IsValidColumn(filter.Column) {
|
||||
validFilters = append(validFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
||||
@@ -266,11 +289,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
|
||||
// Filter Preload columns
|
||||
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||
modelType := reflect.TypeOf(v.model)
|
||||
if modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
filteredPreload := preload
|
||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||
|
||||
// Use the related model's validator for preload columns/filters/sorts
|
||||
preloadValidator := v
|
||||
if modelType != nil {
|
||||
if relInfo := GetRelationshipInfo(modelType, preload.Relation); relInfo != nil && relInfo.RelatedModel != nil {
|
||||
preloadValidator = NewColumnValidator(relInfo.RelatedModel)
|
||||
}
|
||||
}
|
||||
|
||||
filteredPreload.Columns = preloadValidator.FilterValidColumns(preload.Columns)
|
||||
filteredPreload.OmitColumns = preloadValidator.FilterValidColumns(preload.OmitColumns)
|
||||
|
||||
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||
filteredPreload.SqlJoins = preload.SqlJoins
|
||||
@@ -279,7 +315,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
// Filter preload filters
|
||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||
for _, filter := range preload.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
if preloadValidator.IsValidColumn(filter.Column) {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
// Check if the filter column references a joined table alias
|
||||
@@ -302,7 +338,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
// Filter preload sort columns
|
||||
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
||||
for _, sort := range preload.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
if preloadValidator.IsValidColumn(sort.Column) {
|
||||
validPreloadSorts = append(validPreloadSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
|
||||
@@ -464,3 +464,84 @@ func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
|
||||
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
||||
}
|
||||
}
|
||||
|
||||
// RelatedModel is used by PreloadParentModel to test preload column validation.
|
||||
type RelatedModel struct {
|
||||
RelatedID int64 `bun:"related_id,pk"`
|
||||
Functionname string `bun:"functionname"`
|
||||
}
|
||||
|
||||
// PreloadParentModel has a has-one relation to RelatedModel. The json tag on
|
||||
// the relation field is the name used in x-preload headers.
|
||||
type PreloadParentModel struct {
|
||||
ID int64 `bun:"id,pk"`
|
||||
Name string `bun:"name"`
|
||||
RELATED *RelatedModel `json:"RELATED" bun:"rel:has-one,join:id=related_id"`
|
||||
}
|
||||
|
||||
// TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel verifies
|
||||
// that preload columns are validated against the related model's fields, not the
|
||||
// parent model's fields. This is the fix for the bug where specifying a column
|
||||
// that exists only on the relation (e.g. "functionname") was incorrectly filtered
|
||||
// out because it doesn't exist on the parent model.
|
||||
func TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel(t *testing.T) {
|
||||
validator := NewColumnValidator(PreloadParentModel{})
|
||||
|
||||
options := RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{
|
||||
Relation: "RELATED",
|
||||
// "functionname" exists on RelatedModel but NOT on PreloadParentModel.
|
||||
// "name" exists on PreloadParentModel but NOT on RelatedModel.
|
||||
// "nonexistent" exists on neither.
|
||||
Columns: []string{"functionname", "name", "nonexistent"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
if len(filtered.Preload) != 1 {
|
||||
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
|
||||
}
|
||||
|
||||
cols := filtered.Preload[0].Columns
|
||||
// Only "functionname" should survive: it belongs to RelatedModel.
|
||||
if len(cols) != 1 {
|
||||
t.Errorf("Expected 1 preload column, got %d: %v", len(cols), cols)
|
||||
}
|
||||
if len(cols) > 0 && cols[0] != "functionname" {
|
||||
t.Errorf("Expected preload column 'functionname', got '%s'", cols[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterRequestOptions_PreloadColumnsParentModelFallback verifies that when
|
||||
// a preload relation is not found on the parent model, column validation falls
|
||||
// back to the parent model's validator (no panic, no silent pass-through).
|
||||
func TestFilterRequestOptions_PreloadColumnsParentModelFallback(t *testing.T) {
|
||||
validator := NewColumnValidator(PreloadParentModel{})
|
||||
|
||||
options := RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{
|
||||
Relation: "UNKNOWN_RELATION",
|
||||
Columns: []string{"id", "functionname"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
if len(filtered.Preload) != 1 {
|
||||
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
|
||||
}
|
||||
|
||||
cols := filtered.Preload[0].Columns
|
||||
// Falls back to parent model: only "id" is valid on PreloadParentModel.
|
||||
if len(cols) != 1 {
|
||||
t.Errorf("Expected 1 preload column (fallback to parent), got %d: %v", len(cols), cols)
|
||||
}
|
||||
if len(cols) > 0 && cols[0] != "id" {
|
||||
t.Errorf("Expected preload column 'id', got '%s'", cols[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +174,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
||||
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||
if val, ok := variables[varName]; ok {
|
||||
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue"))
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -533,7 +533,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
||||
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||
if val, ok := variables[varName]; ok {
|
||||
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue"))
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -1006,6 +1006,37 @@ func IsNumeric(s string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// isInsideDollarQuote reports whether the first occurrence of placeholder in sqlquery
|
||||
// is immediately surrounded by dollar-sign characters (i.e. inside a $...$-quoted string).
|
||||
// Dollar-quoted strings pass content through literally — no backslash processing — so
|
||||
// values placed there must NOT have their backslashes escaped.
|
||||
func isInsideDollarQuote(sqlquery, placeholder string) bool {
|
||||
idx := strings.Index(sqlquery, placeholder)
|
||||
if idx < 0 {
|
||||
return false
|
||||
}
|
||||
endIdx := idx + len(placeholder)
|
||||
charBefore := byte(0)
|
||||
charAfter := byte(0)
|
||||
if idx > 0 {
|
||||
charBefore = sqlquery[idx-1]
|
||||
}
|
||||
if endIdx < len(sqlquery) {
|
||||
charAfter = sqlquery[endIdx]
|
||||
}
|
||||
return charBefore == '$' || charAfter == '$'
|
||||
}
|
||||
|
||||
// safeSubstituteVar returns value sanitised for the quoting context that surrounds
|
||||
// placeholder in sqlquery: raw (no backslash escaping) for dollar-quoted contexts,
|
||||
// ValidSQL("colvalue") escaping for everything else.
|
||||
func safeSubstituteVar(sqlquery, placeholder, value string) string {
|
||||
if isInsideDollarQuote(sqlquery, placeholder) {
|
||||
return value
|
||||
}
|
||||
return ValidSQL(value, "colvalue")
|
||||
}
|
||||
|
||||
// getReplacementForBlankParam determines the replacement value for an unused parameter
|
||||
// based on whether it appears within quotes in the SQL query.
|
||||
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -102,11 +102,6 @@ func DecodeParam(pStr string) (string, error) {
|
||||
|
||||
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
|
||||
code, _ = DecodeParam(code)
|
||||
} else {
|
||||
strDat, err := base64.StdEncoding.DecodeString(code)
|
||||
if err == nil && utf8.Valid(strDat) {
|
||||
code = string(strDat)
|
||||
}
|
||||
}
|
||||
|
||||
return code, nil
|
||||
@@ -146,9 +141,21 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
combinedParams[strings.ToLower(key)] = value
|
||||
}
|
||||
|
||||
sortedKeys := make([]string, 0, len(combinedParams))
|
||||
for key := range combinedParams {
|
||||
sortedKeys = append(sortedKeys, key)
|
||||
}
|
||||
sort.Slice(sortedKeys, func(i, j int) bool {
|
||||
if sortedKeys[i] != sortedKeys[j] {
|
||||
return sortedKeys[i] < sortedKeys[j]
|
||||
}
|
||||
return combinedParams[sortedKeys[i]] < combinedParams[sortedKeys[j]]
|
||||
})
|
||||
|
||||
// Process each parameter (from both headers and query params)
|
||||
// Note: keys are already normalized to lowercase in combinedParams
|
||||
for key, value := range combinedParams {
|
||||
for _, key := range sortedKeys {
|
||||
value := combinedParams[key]
|
||||
// Decode value if it's base64 encoded
|
||||
decodedValue := decodeHeaderValue(value)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user