mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-11-13 18:03:53 +00:00
Some validation and header decoding
This commit is contained in:
parent
07c239aaa1
commit
f518a3c73c
272
pkg/common/validation.go
Normal file
272
pkg/common/validation.go
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ColumnValidator validates column names against a model's fields
|
||||||
|
type ColumnValidator struct {
|
||||||
|
validColumns map[string]bool
|
||||||
|
model interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewColumnValidator creates a new column validator for a given model
|
||||||
|
func NewColumnValidator(model interface{}) *ColumnValidator {
|
||||||
|
validator := &ColumnValidator{
|
||||||
|
validColumns: make(map[string]bool),
|
||||||
|
model: model,
|
||||||
|
}
|
||||||
|
validator.buildValidColumns()
|
||||||
|
return validator
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildValidColumns extracts all valid column names from the model using reflection
|
||||||
|
func (v *ColumnValidator) buildValidColumns() {
|
||||||
|
modelType := reflect.TypeOf(v.model)
|
||||||
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract column names from struct fields
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get column name from bun, gorm, or json tag
|
||||||
|
columnName := v.getColumnName(field)
|
||||||
|
if columnName != "" && columnName != "-" {
|
||||||
|
v.validColumns[strings.ToLower(columnName)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getColumnName extracts the column name from a struct field's tags
|
||||||
|
// Supports both Bun and GORM tags
|
||||||
|
func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
||||||
|
// First check Bun tag for column name
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && bunTag != "-" {
|
||||||
|
parts := strings.Split(bunTag, ",")
|
||||||
|
// The first part is usually the column name
|
||||||
|
columnName := strings.TrimSpace(parts[0])
|
||||||
|
if columnName != "" && columnName != "-" {
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check GORM tag for column name
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "column:") {
|
||||||
|
parts := strings.Split(gormTag, ";")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "column:") {
|
||||||
|
return strings.TrimPrefix(part, "column:")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to JSON tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
// Extract just the name part (before any comma)
|
||||||
|
jsonName := strings.Split(jsonTag, ",")[0]
|
||||||
|
return jsonName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to field name in lowercase (snake_case conversion would be better)
|
||||||
|
return strings.ToLower(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateColumn validates a single column name
|
||||||
|
// Returns nil if valid, error if invalid
|
||||||
|
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||||
|
func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||||
|
// Allow empty columns
|
||||||
|
if column == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow columns prefixed with "cql" (case insensitive) for computed columns
|
||||||
|
if strings.HasPrefix(strings.ToLower(column), "cql") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if column exists in model
|
||||||
|
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
|
||||||
|
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidColumn checks if a column is valid
|
||||||
|
// Returns true if valid, false if invalid
|
||||||
|
func (v *ColumnValidator) IsValidColumn(column string) bool {
|
||||||
|
return v.ValidateColumn(column) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterValidColumns filters a list of columns, returning only valid ones
|
||||||
|
// Logs warnings for any invalid columns
|
||||||
|
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
||||||
|
if len(columns) == 0 {
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
validColumns := make([]string, 0, len(columns))
|
||||||
|
for _, col := range columns {
|
||||||
|
if v.IsValidColumn(col) {
|
||||||
|
validColumns = append(validColumns, col)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column '%s' filtered out: column does not exist in model", col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return validColumns
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateColumns validates multiple column names
|
||||||
|
// Returns error with details about all invalid columns
|
||||||
|
func (v *ColumnValidator) ValidateColumns(columns []string) error {
|
||||||
|
var invalidColumns []string
|
||||||
|
|
||||||
|
for _, column := range columns {
|
||||||
|
if err := v.ValidateColumn(column); err != nil {
|
||||||
|
invalidColumns = append(invalidColumns, column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(invalidColumns) > 0 {
|
||||||
|
return fmt.Errorf("invalid columns: %s", strings.Join(invalidColumns, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRequestOptions validates all column references in RequestOptions
|
||||||
|
func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
|
||||||
|
// Validate Columns
|
||||||
|
if err := v.ValidateColumns(options.Columns); err != nil {
|
||||||
|
return fmt.Errorf("in select columns: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate OmitColumns
|
||||||
|
if err := v.ValidateColumns(options.OmitColumns); err != nil {
|
||||||
|
return fmt.Errorf("in omit columns: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Filter columns
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if err := v.ValidateColumn(filter.Column); err != nil {
|
||||||
|
return fmt.Errorf("in filter: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Sort columns
|
||||||
|
for _, sort := range options.Sort {
|
||||||
|
if err := v.ValidateColumn(sort.Column); err != nil {
|
||||||
|
return fmt.Errorf("in sort: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Preload columns (if specified)
|
||||||
|
for _, preload := range options.Preload {
|
||||||
|
// Note: We don't validate the relation name itself, as it's a relationship
|
||||||
|
// Only validate columns if specified for the preload
|
||||||
|
if err := v.ValidateColumns(preload.Columns); err != nil {
|
||||||
|
return fmt.Errorf("in preload '%s' columns: %w", preload.Relation, err)
|
||||||
|
}
|
||||||
|
if err := v.ValidateColumns(preload.OmitColumns); err != nil {
|
||||||
|
return fmt.Errorf("in preload '%s' omit columns: %w", preload.Relation, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate filter columns in preload
|
||||||
|
for _, filter := range preload.Filters {
|
||||||
|
if err := v.ValidateColumn(filter.Column); err != nil {
|
||||||
|
return fmt.Errorf("in preload '%s' filter: %w", preload.Relation, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterRequestOptions filters all column references in RequestOptions
|
||||||
|
// Returns a new RequestOptions with only valid columns, logging warnings for invalid ones
|
||||||
|
func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOptions {
|
||||||
|
filtered := options
|
||||||
|
|
||||||
|
// Filter Columns
|
||||||
|
filtered.Columns = v.FilterValidColumns(options.Columns)
|
||||||
|
|
||||||
|
// Filter OmitColumns
|
||||||
|
filtered.OmitColumns = v.FilterValidColumns(options.OmitColumns)
|
||||||
|
|
||||||
|
// Filter Filter columns
|
||||||
|
validFilters := make([]FilterOption, 0, len(options.Filters))
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if v.IsValidColumn(filter.Column) {
|
||||||
|
validFilters = append(validFilters, filter)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered.Filters = validFilters
|
||||||
|
|
||||||
|
// Filter Sort columns
|
||||||
|
validSorts := make([]SortOption, 0, len(options.Sort))
|
||||||
|
for _, sort := range options.Sort {
|
||||||
|
if v.IsValidColumn(sort.Column) {
|
||||||
|
validSorts = append(validSorts, sort)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered.Sort = validSorts
|
||||||
|
|
||||||
|
// Filter Preload columns
|
||||||
|
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||||
|
for _, preload := range options.Preload {
|
||||||
|
filteredPreload := preload
|
||||||
|
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||||
|
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||||
|
|
||||||
|
// Filter preload filters
|
||||||
|
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||||
|
for _, filter := range preload.Filters {
|
||||||
|
if v.IsValidColumn(filter.Column) {
|
||||||
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filteredPreload.Filters = validPreloadFilters
|
||||||
|
|
||||||
|
validPreloads = append(validPreloads, filteredPreload)
|
||||||
|
}
|
||||||
|
filtered.Preload = validPreloads
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetValidColumns returns a list of all valid column names for debugging purposes
|
||||||
|
func (v *ColumnValidator) GetValidColumns() []string {
|
||||||
|
columns := make([]string, 0, len(v.validColumns))
|
||||||
|
for col := range v.validColumns {
|
||||||
|
columns = append(columns, col)
|
||||||
|
}
|
||||||
|
return columns
|
||||||
|
}
|
||||||
363
pkg/common/validation_test.go
Normal file
363
pkg/common/validation_test.go
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestModel represents a sample model for testing
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `json:"id" gorm:"primaryKey"`
|
||||||
|
Name string `json:"name" gorm:"column:name"`
|
||||||
|
Email string `json:"email" bun:"email"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
IsActive bool `json:"is_active"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewColumnValidator(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
if validator == nil {
|
||||||
|
t.Fatal("Expected validator to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validator.validColumns) == 0 {
|
||||||
|
t.Fatal("Expected validator to have valid columns")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that expected columns are present
|
||||||
|
expectedColumns := []string{"id", "name", "email", "age", "is_active", "created_at"}
|
||||||
|
for _, col := range expectedColumns {
|
||||||
|
if !validator.validColumns[col] {
|
||||||
|
t.Errorf("Expected column '%s' to be valid", col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumn(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
column string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{"Valid column - id", "id", false},
|
||||||
|
{"Valid column - name", "name", false},
|
||||||
|
{"Valid column - email", "email", false},
|
||||||
|
{"Valid column - uppercase", "ID", false}, // Case insensitive
|
||||||
|
{"Invalid column", "invalid_column", true},
|
||||||
|
{"CQL prefixed - should be valid", "cqlComputedField", false},
|
||||||
|
{"CQL prefixed uppercase - should be valid", "CQLComputedField", false},
|
||||||
|
{"Empty column", "", false}, // Empty columns are allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumn(tt.column)
|
||||||
|
if tt.shouldError && err == nil {
|
||||||
|
t.Errorf("Expected error for column '%s', got nil", tt.column)
|
||||||
|
}
|
||||||
|
if !tt.shouldError && err != nil {
|
||||||
|
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateColumns(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columns []string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{"All valid columns", []string{"id", "name", "email"}, false},
|
||||||
|
{"One invalid column", []string{"id", "invalid_col", "name"}, true},
|
||||||
|
{"All invalid columns", []string{"bad1", "bad2"}, true},
|
||||||
|
{"With CQL prefix", []string{"id", "cqlComputed", "name"}, false},
|
||||||
|
{"Empty list", []string{}, false},
|
||||||
|
{"Nil list", nil, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumns(tt.columns)
|
||||||
|
if tt.shouldError && err == nil {
|
||||||
|
t.Errorf("Expected error for columns %v, got nil", tt.columns)
|
||||||
|
}
|
||||||
|
if !tt.shouldError && err != nil {
|
||||||
|
t.Errorf("Expected no error for columns %v, got: %v", tt.columns, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRequestOptions(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
options RequestOptions
|
||||||
|
shouldError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid options with columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Columns: []string{"id", "name"},
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "name", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
Sort: []SortOption{
|
||||||
|
{Column: "id", Direction: "ASC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Columns: []string{"id", "invalid_column"},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "select columns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Filters",
|
||||||
|
options: RequestOptions{
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "invalid_col", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "filter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Sort",
|
||||||
|
options: RequestOptions{
|
||||||
|
Sort: []SortOption{
|
||||||
|
{Column: "invalid_col", Direction: "ASC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "sort",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid CQL prefixed columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Columns: []string{"id", "cqlComputedField"},
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "cqlCustomFilter", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid column in Preload",
|
||||||
|
options: RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "SomeRelation",
|
||||||
|
Columns: []string{"id", "invalid_col"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "preload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid preload with valid columns",
|
||||||
|
options: RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "SomeRelation",
|
||||||
|
Columns: []string{"id", "name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validator.ValidateRequestOptions(tt.options)
|
||||||
|
if tt.shouldError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error, got nil")
|
||||||
|
} else if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||||
|
t.Errorf("Expected error to contain '%s', got: %v", tt.errorMsg, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValidColumns(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
columns := validator.GetValidColumns()
|
||||||
|
if len(columns) == 0 {
|
||||||
|
t.Error("Expected to get valid columns, got empty list")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have at least the columns from TestModel
|
||||||
|
if len(columns) < 6 {
|
||||||
|
t.Errorf("Expected at least 6 columns, got %d", len(columns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with Bun tags specifically
|
||||||
|
type BunModel struct {
|
||||||
|
ID int64 `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
Email string `bun:"user_email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunTagSupport(t *testing.T) {
|
||||||
|
model := BunModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
// Test that bun tags are properly recognized
|
||||||
|
tests := []struct {
|
||||||
|
column string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{"id", false},
|
||||||
|
{"name", false},
|
||||||
|
{"user_email", false}, // Bun tag specifies this name
|
||||||
|
{"email", true}, // JSON tag would be "email", but bun tag says "user_email"
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.column, func(t *testing.T) {
|
||||||
|
err := validator.ValidateColumn(tt.column)
|
||||||
|
if tt.shouldError && err == nil {
|
||||||
|
t.Errorf("Expected error for column '%s'", tt.column)
|
||||||
|
}
|
||||||
|
if !tt.shouldError && err != nil {
|
||||||
|
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterValidColumns(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []string
|
||||||
|
expectedOutput []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "All valid columns",
|
||||||
|
input: []string{"id", "name", "email"},
|
||||||
|
expectedOutput: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mix of valid and invalid",
|
||||||
|
input: []string{"id", "invalid_col", "name", "bad_col", "email"},
|
||||||
|
expectedOutput: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "All invalid columns",
|
||||||
|
input: []string{"bad1", "bad2"},
|
||||||
|
expectedOutput: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "With CQL prefix (should pass)",
|
||||||
|
input: []string{"id", "cqlComputed", "name"},
|
||||||
|
expectedOutput: []string{"id", "cqlComputed", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty input",
|
||||||
|
input: []string{},
|
||||||
|
expectedOutput: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil input",
|
||||||
|
input: nil,
|
||||||
|
expectedOutput: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := validator.FilterValidColumns(tt.input)
|
||||||
|
if len(result) != len(tt.expectedOutput) {
|
||||||
|
t.Errorf("Expected %d columns, got %d", len(tt.expectedOutput), len(result))
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expectedOutput[i] {
|
||||||
|
t.Errorf("At index %d: expected %s, got %s", i, tt.expectedOutput[i], col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterRequestOptions(t *testing.T) {
|
||||||
|
model := TestModel{}
|
||||||
|
validator := NewColumnValidator(model)
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Columns: []string{"id", "name", "invalid_col"},
|
||||||
|
OmitColumns: []string{"email", "bad_col"},
|
||||||
|
Filters: []FilterOption{
|
||||||
|
{Column: "name", Operator: "eq", Value: "test"},
|
||||||
|
{Column: "invalid_col", Operator: "eq", Value: "test"},
|
||||||
|
},
|
||||||
|
Sort: []SortOption{
|
||||||
|
{Column: "id", Direction: "ASC"},
|
||||||
|
{Column: "bad_col", Direction: "DESC"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
// Check Columns
|
||||||
|
if len(filtered.Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
|
||||||
|
}
|
||||||
|
if filtered.Columns[0] != "id" || filtered.Columns[1] != "name" {
|
||||||
|
t.Errorf("Expected columns [id, name], got %v", filtered.Columns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check OmitColumns
|
||||||
|
if len(filtered.OmitColumns) != 1 {
|
||||||
|
t.Errorf("Expected 1 omit column, got %d", len(filtered.OmitColumns))
|
||||||
|
}
|
||||||
|
if filtered.OmitColumns[0] != "email" {
|
||||||
|
t.Errorf("Expected omit column [email], got %v", filtered.OmitColumns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Filters
|
||||||
|
if len(filtered.Filters) != 1 {
|
||||||
|
t.Errorf("Expected 1 filter, got %d", len(filtered.Filters))
|
||||||
|
}
|
||||||
|
if filtered.Filters[0].Column != "name" {
|
||||||
|
t.Errorf("Expected filter column 'name', got %s", filtered.Filters[0].Column)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Sort
|
||||||
|
if len(filtered.Sort) != 1 {
|
||||||
|
t.Errorf("Expected 1 sort, got %d", len(filtered.Sort))
|
||||||
|
}
|
||||||
|
if filtered.Sort[0].Column != "id" {
|
||||||
|
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -100,6 +100,10 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
// Add request-scoped data to context
|
// Add request-scoped data to context
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
||||||
|
|
||||||
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
|
validator := common.NewColumnValidator(model)
|
||||||
|
req.Options = validator.FilterRequestOptions(req.Options)
|
||||||
|
|
||||||
switch req.Operation {
|
switch req.Operation {
|
||||||
case "read":
|
case "read":
|
||||||
h.handleRead(ctx, w, id, req.Options)
|
h.handleRead(ctx, w, id, req.Options)
|
||||||
|
|||||||
@ -93,6 +93,10 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
// Add request-scoped data to context
|
// Add request-scoped data to context
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
||||||
|
|
||||||
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
|
validator := common.NewColumnValidator(model)
|
||||||
|
options = filterExtendedOptions(validator, options)
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
if id != "" {
|
if id != "" {
|
||||||
@ -750,3 +754,41 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
|
|||||||
w.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
w.WriteJSON(response)
|
w.WriteJSON(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
||||||
|
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
||||||
|
filtered := options
|
||||||
|
|
||||||
|
// Filter base RequestOptions
|
||||||
|
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
||||||
|
|
||||||
|
// Filter SearchColumns
|
||||||
|
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
||||||
|
|
||||||
|
// Filter AdvancedSQL column keys
|
||||||
|
filteredAdvSQL := make(map[string]string)
|
||||||
|
for colName, sqlExpr := range options.AdvancedSQL {
|
||||||
|
if validator.IsValidColumn(colName) {
|
||||||
|
filteredAdvSQL[colName] = sqlExpr
|
||||||
|
} else {
|
||||||
|
logger.Warn("Invalid column in advanced SQL removed: %s", colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered.AdvancedSQL = filteredAdvSQL
|
||||||
|
|
||||||
|
// ComputedQL columns are allowed to be any name since they're computed
|
||||||
|
// No filtering needed for ComputedQL keys
|
||||||
|
filtered.ComputedQL = options.ComputedQL
|
||||||
|
|
||||||
|
// Filter Expand columns
|
||||||
|
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
||||||
|
for _, expand := range options.Expand {
|
||||||
|
filteredExpand := expand
|
||||||
|
// Don't validate relation name, only columns
|
||||||
|
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
|
||||||
|
filteredExpands = append(filteredExpands, filteredExpand)
|
||||||
|
}
|
||||||
|
filtered.Expand = filteredExpands
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|||||||
@ -57,27 +57,43 @@ type ExpandOption struct {
|
|||||||
// decodeHeaderValue decodes base64 encoded header values
|
// decodeHeaderValue decodes base64 encoded header values
|
||||||
// Supports ZIP_ and __ prefixes for base64 encoding
|
// Supports ZIP_ and __ prefixes for base64 encoding
|
||||||
func decodeHeaderValue(value string) string {
|
func decodeHeaderValue(value string) string {
|
||||||
// Check for ZIP_ prefix
|
str, _ := DecodeParam(value)
|
||||||
if strings.HasPrefix(value, "ZIP_") {
|
return str
|
||||||
decoded, err := base64.StdEncoding.DecodeString(value[4:])
|
|
||||||
if err == nil {
|
|
||||||
return string(decoded)
|
|
||||||
}
|
|
||||||
logger.Warn("Failed to decode ZIP_ prefixed value: %v", err)
|
|
||||||
return value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for __ prefix
|
// DecodeParam - Decodes parameter string and returns unencoded string
|
||||||
if strings.HasPrefix(value, "__") {
|
func DecodeParam(pStr string) (string, error) {
|
||||||
decoded, err := base64.StdEncoding.DecodeString(value[2:])
|
var code string = pStr
|
||||||
if err == nil {
|
if strings.HasPrefix(pStr, "ZIP_") {
|
||||||
return string(decoded)
|
code = strings.ReplaceAll(pStr, "ZIP_", "")
|
||||||
|
code = strings.ReplaceAll(code, "\n", "")
|
||||||
|
code = strings.ReplaceAll(code, "\r", "")
|
||||||
|
code = strings.ReplaceAll(code, " ", "")
|
||||||
|
strDat, err := base64.StdEncoding.DecodeString(code)
|
||||||
|
if err != nil {
|
||||||
|
return code, fmt.Errorf("failed to read parameter base64: %v", err)
|
||||||
|
} else {
|
||||||
|
code = string(strDat)
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(pStr, "__") {
|
||||||
|
code = strings.ReplaceAll(pStr, "__", "")
|
||||||
|
code = strings.ReplaceAll(code, "\n", "")
|
||||||
|
code = strings.ReplaceAll(code, "\r", "")
|
||||||
|
code = strings.ReplaceAll(code, " ", "")
|
||||||
|
|
||||||
|
strDat, err := base64.StdEncoding.DecodeString(code)
|
||||||
|
if err != nil {
|
||||||
|
return code, fmt.Errorf("failed to read parameter base64: %v", err)
|
||||||
|
} else {
|
||||||
|
code = string(strDat)
|
||||||
}
|
}
|
||||||
logger.Warn("Failed to decode __ prefixed value: %v", err)
|
|
||||||
return value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return value
|
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
|
||||||
|
code, _ = DecodeParam(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
return code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseOptionsFromHeaders parses all request options from HTTP headers
|
// parseOptionsFromHeaders parses all request options from HTTP headers
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user