Compare commits

...

11 Commits

Author SHA1 Message Date
Hein
68dee78a34 Fixed filterExtendedOptions 2025-12-10 12:25:23 +02:00
Hein
efb9e5d9d5 Removed the buggy filter expand columns 2025-12-10 12:15:18 +02:00
Hein
490ae37c6d Fixed bugs in extractTableAndColumn 2025-12-10 11:48:03 +02:00
Hein
99307e31e6 More debugging on bun for scan issues 2025-12-10 11:16:25 +02:00
Hein
e3f7869c6d Bun scan debugging 2025-12-10 11:07:18 +02:00
Hein
c696d502c5 extractTableAndColumn 2025-12-10 10:10:55 +02:00
Hein
4ed1fba6ad Fixed extractTableAndColumn 2025-12-10 10:10:43 +02:00
Hein
1d0407a16d Fixed linting
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-10 10:00:01 +02:00
Hein
99001c749d Better sql where validation 2025-12-10 09:52:13 +02:00
Hein
1f7a57f8e3 Tracking provider 2025-12-10 09:31:55 +02:00
Hein
a95c28a0bf Multi Token warning and handling 2025-12-10 08:44:37 +02:00
17 changed files with 1279 additions and 107 deletions

View File

@@ -71,35 +71,18 @@
},
"gocritic": {
"enabled-checks": [
"appendAssign",
"assignOp",
"boolExprSimplify",
"builtinShadow",
"captLocal",
"caseOrder",
"defaultCaseOrder",
"dupArg",
"dupBranchBody",
"dupCase",
"dupSubExpr",
"elseif",
"emptyFallthrough",
"equalFold",
"flagName",
"indexAlloc",
"initClause",
"methodExprCall",
"nilValReturn",
"rangeExprCopy",
"rangeValCopy",
"regexpMust",
"singleCaseSwitch",
"sloppyLen",
"stringXbytes",
"switchTrue",
"typeAssertChain",
"typeSwitchVar",
"underef",
"unlabelStmt",
"unnamedResult",
"unnecessaryBlock",

5
go.mod
View File

@@ -5,12 +5,15 @@ go 1.24.0
toolchain go1.24.6
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
github.com/getsentry/sentry-go v0.40.0
github.com/glebarez/sqlite v1.11.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
@@ -30,7 +33,6 @@ require (
)
require (
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -65,7 +67,6 @@ require (
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/spf13/viper v1.21.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect

10
go.sum
View File

@@ -19,12 +19,18 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -75,6 +81,10 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=

View File

@@ -34,6 +34,63 @@ func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent)
}
}
// debugScanIntoStruct attempts to scan rows into a struct with detailed field-level logging
// This helps identify which specific field is causing scanning issues
func debugScanIntoStruct(rows interface{}, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct && v.Kind() != reflect.Slice {
return fmt.Errorf("dest must be pointer to struct or slice")
}
// Log the type being scanned into
typeName := v.Type().String()
logger.Debug("Debug scan into type: %s (kind: %s)", typeName, v.Kind())
// Handle slice types - inspect the element type
var structType reflect.Type
if v.Kind() == reflect.Slice {
elemType := v.Type().Elem()
logger.Debug(" Slice element type: %s", elemType)
// If slice of pointers, get the underlying type
if elemType.Kind() == reflect.Ptr {
structType = elemType.Elem()
} else {
structType = elemType
}
} else if v.Kind() == reflect.Struct {
structType = v.Type()
}
// If we have a struct type, log all its fields
if structType != nil && structType.Kind() == reflect.Struct {
logger.Debug(" Struct %s has %d fields:", structType.Name(), structType.NumField())
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
// Log embedded fields specially
if field.Anonymous {
logger.Debug(" [%d] EMBEDDED: %s (type: %s, kind: %s, bun:%q)",
i, field.Name, field.Type, field.Type.Kind(), field.Tag.Get("bun"))
} else {
bunTag := field.Tag.Get("bun")
if bunTag == "" {
bunTag = "(no tag)"
}
logger.Debug(" [%d] %s (type: %s, kind: %s, bun:%q)",
i, field.Name, field.Type, field.Type.Kind(), bunTag)
}
}
}
return nil
}
// BunAdapter adapts Bun to work with our Database interface
// This demonstrates how the abstraction works with different ORMs
type BunAdapter struct {
@@ -52,6 +109,14 @@ func (b *BunAdapter) EnableQueryDebug() {
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
}
// EnableDetailedScanDebug enables verbose logging of scan operations
// WARNING: This generates a LOT of log output. Use only for debugging specific issues.
func (b *BunAdapter) EnableDetailedScanDebug() {
logger.Info("Detailed scan debugging enabled - will log all field scanning operations")
// This is a flag that can be checked in scan operations
// Implementation would require modifying the scan logic
}
// DisableQueryDebug removes all query hooks
func (b *BunAdapter) DisableQueryDebug() {
// Create a new DB without hooks
@@ -676,6 +741,31 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
// Enhanced panic recovery with model information
model := b.query.GetModel()
var modelInfo string
if model != nil && model.Value() != nil {
modelValue := model.Value()
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
// Try to get the model's underlying struct type
v := reflect.ValueOf(modelValue)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() == reflect.Slice {
if v.Type().Elem().Kind() == reflect.Ptr {
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
} else {
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
}
} else if v.Kind() == reflect.Struct {
modelInfo += fmt.Sprintf(", Struct: %s", v.Type().Name())
}
}
sqlStr := b.query.String()
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
}
}()
@@ -683,6 +773,17 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
return fmt.Errorf("model is nil")
}
// Optional: Enable detailed field-level debugging (set to true to debug)
const enableDetailedDebug = true
if enableDetailedDebug {
model := b.query.GetModel()
if model != nil && model.Value() != nil {
if err := debugScanIntoStruct(nil, model.Value()); err != nil {
logger.Warn("Debug scan inspection failed: %v", err)
}
}
}
// Execute the main query first
err = b.query.Scan(ctx)
if err != nil {

View File

@@ -1,6 +1,7 @@
package common
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -78,6 +79,41 @@ func IsTrivialCondition(cond string) bool {
return false
}
// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
// Returns an error if any dangerous keywords are found
func validateWhereClauseSecurity(where string) error {
if where == "" {
return nil
}
lowerWhere := strings.ToLower(where)
// List of dangerous SQL keywords that should never appear in WHERE clauses
dangerousKeywords := []string{
"delete ", "delete\t", "delete\n", "delete;",
"update ", "update\t", "update\n", "update;",
"truncate ", "truncate\t", "truncate\n", "truncate;",
"drop ", "drop\t", "drop\n", "drop;",
"alter ", "alter\t", "alter\n", "alter;",
"create ", "create\t", "create\n", "create;",
"insert ", "insert\t", "insert\n", "insert;",
"grant ", "grant\t", "grant\n", "grant;",
"revoke ", "revoke\t", "revoke\n", "revoke;",
"exec ", "exec\t", "exec\n", "exec;",
"execute ", "execute\t", "execute\n", "execute;",
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
}
for _, keyword := range dangerousKeywords {
if strings.Contains(lowerWhere, keyword) {
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
}
}
return nil
}
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
//
@@ -100,6 +136,12 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
where = strings.TrimSpace(where)
// Validate that the WHERE clause doesn't contain dangerous SQL statements
if err := validateWhereClauseSecurity(where); err != nil {
logger.Debug("Security validation failed for WHERE clause: %v", err)
return ""
}
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
@@ -221,19 +263,57 @@ func stripOuterParentheses(s string) string {
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is a simple split that doesn't handle nested parentheses or complex expressions
// This is parenthesis-aware and won't split on AND operators inside subqueries
func splitByAND(where string) []string {
// First try uppercase AND
conditions := strings.Split(where, " AND ")
conditions := []string{}
currentCondition := strings.Builder{}
depth := 0 // Track parenthesis depth
i := 0
// If we didn't split on uppercase, try lowercase
if len(conditions) == 1 {
conditions = strings.Split(where, " and ")
for i < len(where) {
ch := where[i]
// Track parenthesis depth
if ch == '(' {
depth++
currentCondition.WriteByte(ch)
i++
continue
} else if ch == ')' {
depth--
currentCondition.WriteByte(ch)
i++
continue
}
// Only look for AND operators at depth 0 (not inside parentheses)
if depth == 0 {
// Check if we're at an AND operator (case-insensitive)
// We need at least " AND " (5 chars) or " and " (5 chars)
if i+5 <= len(where) {
substring := where[i : i+5]
lowerSubstring := strings.ToLower(substring)
if lowerSubstring == " and " {
// Found an AND operator at the top level
// Add the current condition to the list
conditions = append(conditions, currentCondition.String())
currentCondition.Reset()
// Skip past the AND operator
i += 5
continue
}
}
}
// Not an AND operator or we're inside parentheses, just add the character
currentCondition.WriteByte(ch)
i++
}
// If we still didn't split, try mixed case
if len(conditions) == 1 {
conditions = strings.Split(where, " And ")
// Add the last condition
if currentCondition.Len() > 0 {
conditions = append(conditions, currentCondition.String())
}
return conditions
@@ -313,6 +393,7 @@ func getValidColumnsForTable(tableName string) map[string]bool {
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
// For example: "users.status = 'active'" returns ("users", "status")
// Returns empty strings if no table prefix is found
// This function is parenthesis-aware and will only look for operators outside of subqueries
func extractTableAndColumn(cond string) (table string, column string) {
// Common SQL operators to find the column reference
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
@@ -320,13 +401,20 @@ func extractTableAndColumn(cond string) (table string, column string) {
var columnRef string
// Find the column reference (left side of the operator)
// We need to find the first operator that appears OUTSIDE of parentheses
minIdx := -1
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnRef = strings.TrimSpace(cond[:idx])
break
idx := findOperatorOutsideParentheses(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
if minIdx > 0 {
columnRef = strings.TrimSpace(cond[:minIdx])
}
// If no operator found, the whole condition might be the column reference
if columnRef == "" {
parts := strings.Fields(cond)
@@ -342,7 +430,45 @@ func extractTableAndColumn(cond string) (table string, column string) {
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Check if it contains a dot (qualified reference)
// Check if there's a function call (contains opening parenthesis)
openParenIdx := strings.Index(columnRef, "(")
if openParenIdx >= 0 {
// There's a function call - find the FIRST dot after the opening paren
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
if dotIdx > 0 {
dotIdx += openParenIdx // Adjust to absolute position
// Extract table name (between paren and dot)
// Find the last opening paren before this dot
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
table = columnRef[lastOpenParen+1 : dotIdx]
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
columnStart := dotIdx + 1
columnEnd := len(columnRef)
for i := columnStart; i < len(columnRef); i++ {
ch := columnRef[i]
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
columnEnd = i
break
}
}
column = columnRef[columnStart:columnEnd]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
}
// No function call - check if it contains a dot (qualified reference)
// Use LastIndex to handle schema.table.column properly
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
table = columnRef[:dotIdx]
column = columnRef[dotIdx+1:]
@@ -357,6 +483,53 @@ func extractTableAndColumn(cond string) (table string, column string) {
return "", ""
}
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
// Returns the index of the operator, or -1 if not found or only found inside parentheses
func findOperatorOutsideParentheses(s string, operator string) int {
depth := 0
inSingleQuote := false
inDoubleQuote := false
for i := 0; i < len(s); i++ {
ch := s[i]
// Track quote state (operators inside quotes should be ignored)
if ch == '\'' && !inDoubleQuote {
inSingleQuote = !inSingleQuote
continue
}
if ch == '"' && !inSingleQuote {
inDoubleQuote = !inDoubleQuote
continue
}
// Skip if we're inside quotes
if inSingleQuote || inDoubleQuote {
continue
}
// Track parenthesis depth
switch ch {
case '(':
depth++
case ')':
depth--
}
// Only look for the operator when we're outside parentheses (depth == 0)
if depth == 0 {
// Check if the operator starts at this position
if i+len(operator) <= len(s) {
if s[i:i+len(operator)] == operator {
return i
}
}
}
}
return -1
}
// isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool {

View File

@@ -1,6 +1,7 @@
package common
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
@@ -85,6 +86,54 @@ func TestSanitizeWhereClause(t *testing.T) {
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "mixed case AND operators",
where: "status = 'active' AND age > 18 and name = 'John'",
tableName: "users",
expected: "status = 'active' AND age > 18 AND name = 'John'",
},
{
name: "subquery with ORDER BY and LIMIT - allowed",
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
tableName: "users",
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
},
{
name: "dangerous DELETE keyword - blocked",
where: "status = 'active'; DELETE FROM users",
tableName: "users",
expected: "",
},
{
name: "dangerous UPDATE keyword - blocked",
where: "1=1; UPDATE users SET admin = true",
tableName: "users",
expected: "",
},
{
name: "dangerous TRUNCATE keyword - blocked",
where: "status = 'active' OR TRUNCATE TABLE users",
tableName: "users",
expected: "",
},
{
name: "dangerous DROP keyword - blocked",
where: "status = 'active'; DROP TABLE users",
tableName: "users",
expected: "",
},
{
name: "subquery with table alias should not be modified",
where: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
tableName: "apiprovider",
expected: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
},
{
name: "complex subquery with AND and multiple operators",
where: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
tableName: "apiprovider",
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
},
}
for _, tt := range tests {
@@ -138,6 +187,11 @@ func TestStripOuterParentheses(t *testing.T) {
input: " ( true ) ",
expected: "true",
},
{
name: "complex sub query",
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
},
}
for _, tt := range tests {
@@ -232,6 +286,48 @@ func TestExtractTableAndColumn(t *testing.T) {
expectedTable: "",
expectedCol: "",
},
{
name: "function call with table.column - ifblnk",
input: "ifblnk(users.status,0) in (1,2,3,4)",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function call with table.column - coalesce",
input: "coalesce(users.age, 0) = 25",
expectedTable: "users",
expectedCol: "age",
},
{
name: "nested function calls",
input: "upper(trim(users.name)) = 'JOHN'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "function with multiple args and table.column",
input: "substring(users.email, 1, 5) = 'admin'",
expectedTable: "users",
expectedCol: "email",
},
{
name: "cast function with table.column",
input: "cast(orders.total as decimal) > 100",
expectedTable: "orders",
expectedCol: "total",
},
{
name: "complex nested functions",
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function with multiple table.column refs (extracts first)",
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
expectedTable: "users",
expectedCol: "created_at",
},
}
for _, tt := range tests {
@@ -298,6 +394,14 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "Function Call with correct table prefix - unchanged",
where: "ifblnk(users.status,0) in (1,2,3,4)",
tableName: "users",
options: nil,
expected: "ifblnk(users.status,0) in (1,2,3,4)",
},
{
name: "no options provided - works as before",
where: "wrong_table.status = 'active'",
@@ -337,6 +441,131 @@ type MasterTask struct {
UserID int `bun:"user_id"`
}
func TestSplitByAND(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "uppercase AND",
input: "status = 'active' AND age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "lowercase and",
input: "status = 'active' and age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "mixed case AND",
input: "status = 'active' AND age > 18 and name = 'John'",
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
},
{
name: "single condition",
input: "status = 'active'",
expected: []string{"status = 'active'"},
},
{
name: "multiple uppercase AND",
input: "a = 1 AND b = 2 AND c = 3",
expected: []string{"a = 1", "b = 2", "c = 3"},
},
{
name: "multiple case subquery",
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitByAND(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
return
}
for i := range result {
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
}
}
})
}
}
func TestValidateWhereClauseSecurity(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
}{
{
name: "safe WHERE clause",
input: "status = 'active' AND age > 18",
expectError: false,
},
{
name: "safe subquery",
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
expectError: false,
},
{
name: "DELETE keyword",
input: "status = 'active'; DELETE FROM users",
expectError: true,
},
{
name: "UPDATE keyword",
input: "1=1; UPDATE users SET admin = true",
expectError: true,
},
{
name: "TRUNCATE keyword",
input: "status = 'active' OR TRUNCATE TABLE users",
expectError: true,
},
{
name: "DROP keyword",
input: "status = 'active'; DROP TABLE users",
expectError: true,
},
{
name: "INSERT keyword",
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
expectError: true,
},
{
name: "ALTER keyword",
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
expectError: true,
},
{
name: "CREATE keyword",
input: "1=1; CREATE TABLE malicious (id INT)",
expectError: true,
},
{
name: "empty clause",
input: "",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateWhereClauseSecurity(tt.input)
if tt.expectError && err == nil {
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
}
if !tt.expectError && err != nil {
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
}
})
}
}
func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")

View File

@@ -4,13 +4,14 @@ import "time"
// Config represents the complete application configuration
type Config struct {
Server ServerConfig `mapstructure:"server"`
Tracing TracingConfig `mapstructure:"tracing"`
Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"`
Middleware MiddlewareConfig `mapstructure:"middleware"`
CORS CORSConfig `mapstructure:"cors"`
Database DatabaseConfig `mapstructure:"database"`
Server ServerConfig `mapstructure:"server"`
Tracing TracingConfig `mapstructure:"tracing"`
Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"`
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
Middleware MiddlewareConfig `mapstructure:"middleware"`
CORS CORSConfig `mapstructure:"cors"`
Database DatabaseConfig `mapstructure:"database"`
}
// ServerConfig holds server-related configuration
@@ -78,3 +79,15 @@ type CORSConfig struct {
type DatabaseConfig struct {
URL string `mapstructure:"url"`
}
// ErrorTrackingConfig holds error tracking configuration
type ErrorTrackingConfig struct {
Enabled bool `mapstructure:"enabled"`
Provider string `mapstructure:"provider"` // sentry, noop
DSN string `mapstructure:"dsn"` // Sentry DSN
Environment string `mapstructure:"environment"` // e.g., production, staging, development
Release string `mapstructure:"release"` // Application version/release
Debug bool `mapstructure:"debug"` // Enable debug mode
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
}

150
pkg/errortracking/README.md Normal file
View File

@@ -0,0 +1,150 @@
# Error Tracking
This package provides error tracking integration for ResolveSpec, with built-in support for Sentry.
## Features
- **Provider Interface**: Flexible design supporting multiple error tracking backends
- **Sentry Integration**: Full-featured Sentry support with automatic error, warning, and panic tracking
- **Automatic Logger Integration**: All `logger.Error()` and `logger.Warn()` calls are automatically sent to the error tracker
- **Panic Tracking**: Automatic panic capture with stack traces
- **NoOp Provider**: Zero-overhead when error tracking is disabled
## Configuration
Add error tracking configuration to your config file:
```yaml
error_tracking:
enabled: true
provider: "sentry" # Currently supports: "sentry" or "noop"
dsn: "https://your-sentry-dsn@sentry.io/project-id"
environment: "production" # e.g., production, staging, development
release: "v1.0.0" # Your application version
debug: false
sample_rate: 1.0 # Error sample rate (0.0-1.0)
traces_sample_rate: 0.1 # Traces sample rate (0.0-1.0)
```
## Usage
### Initialization
Initialize error tracking in your application startup:
```go
package main
import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func main() {
// Load your configuration
cfg := config.Config{
ErrorTracking: config.ErrorTrackingConfig{
Enabled: true,
Provider: "sentry",
DSN: "https://your-sentry-dsn@sentry.io/project-id",
Environment: "production",
Release: "v1.0.0",
SampleRate: 1.0,
},
}
// Initialize logger
logger.Init(false)
// Initialize error tracking
provider, err := errortracking.NewProviderFromConfig(cfg.ErrorTracking)
if err != nil {
logger.Error("Failed to initialize error tracking: %v", err)
} else {
logger.InitErrorTracking(provider)
}
// Your application code...
// Cleanup on shutdown
defer logger.CloseErrorTracking()
}
```
### Automatic Tracking
Once initialized, all logger errors and warnings are automatically sent to the error tracker:
```go
// This will be logged AND sent to Sentry
logger.Error("Database connection failed: %v", err)
// This will also be logged AND sent to Sentry
logger.Warn("Cache miss for key: %s", key)
```
### Panic Tracking
Panics are automatically captured when using the logger's panic handlers:
```go
// Using CatchPanic
defer logger.CatchPanic("MyFunction")
// Using CatchPanicCallback
defer logger.CatchPanicCallback("MyFunction", func(err any) {
// Custom cleanup
})
// Using HandlePanic
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("MyMethod", r)
}
}()
```
### Manual Tracking
You can also use the provider directly for custom error tracking:
```go
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func someFunction() {
tracker := logger.GetErrorTracker()
if tracker != nil {
// Capture an error
tracker.CaptureError(context.Background(), err, errortracking.SeverityError, map[string]interface{}{
"user_id": userID,
"request_id": requestID,
})
// Capture a message
tracker.CaptureMessage(context.Background(), "Important event occurred", errortracking.SeverityInfo, map[string]interface{}{
"event_type": "user_signup",
})
// Capture a panic
tracker.CapturePanic(context.Background(), recovered, stackTrace, map[string]interface{}{
"context": "background_job",
})
}
}
```
## Severity Levels
The package supports the following severity levels:
- `SeverityError`: For errors that should be tracked and investigated
- `SeverityWarning`: For warnings that may indicate potential issues
- `SeverityInfo`: For informational messages
- `SeverityDebug`: For debug-level information
```

View File

@@ -0,0 +1,67 @@
package errortracking
import (
"context"
"errors"
"testing"
)
func TestNoOpProvider(t *testing.T) {
provider := NewNoOpProvider()
// Test that all methods can be called without panicking
t.Run("CaptureError", func(t *testing.T) {
provider.CaptureError(context.Background(), errors.New("test error"), SeverityError, nil)
})
t.Run("CaptureMessage", func(t *testing.T) {
provider.CaptureMessage(context.Background(), "test message", SeverityWarning, nil)
})
t.Run("CapturePanic", func(t *testing.T) {
provider.CapturePanic(context.Background(), "panic!", []byte("stack trace"), nil)
})
t.Run("Flush", func(t *testing.T) {
result := provider.Flush(5)
if !result {
t.Error("Expected Flush to return true")
}
})
t.Run("Close", func(t *testing.T) {
err := provider.Close()
if err != nil {
t.Errorf("Expected Close to return nil, got %v", err)
}
})
}
func TestSeverityLevels(t *testing.T) {
tests := []struct {
name string
severity Severity
expected string
}{
{"Error", SeverityError, "error"},
{"Warning", SeverityWarning, "warning"},
{"Info", SeverityInfo, "info"},
{"Debug", SeverityDebug, "debug"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.severity) != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, string(tt.severity))
}
})
}
}
func TestProviderInterface(t *testing.T) {
// Test that NoOpProvider implements Provider interface
var _ Provider = (*NoOpProvider)(nil)
// Test that SentryProvider implements Provider interface
var _ Provider = (*SentryProvider)(nil)
}

View File

@@ -0,0 +1,33 @@
package errortracking
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// NewProviderFromConfig creates an error tracking provider based on the configuration
func NewProviderFromConfig(cfg config.ErrorTrackingConfig) (Provider, error) {
if !cfg.Enabled {
return NewNoOpProvider(), nil
}
switch cfg.Provider {
case "sentry":
if cfg.DSN == "" {
return nil, fmt.Errorf("sentry DSN is required when error tracking is enabled")
}
return NewSentryProvider(SentryConfig{
DSN: cfg.DSN,
Environment: cfg.Environment,
Release: cfg.Release,
Debug: cfg.Debug,
SampleRate: cfg.SampleRate,
TracesSampleRate: cfg.TracesSampleRate,
})
case "noop", "":
return NewNoOpProvider(), nil
default:
return nil, fmt.Errorf("unknown error tracking provider: %s", cfg.Provider)
}
}

View File

@@ -0,0 +1,33 @@
package errortracking
import (
"context"
)
// Severity represents the severity level of an error
type Severity string
const (
SeverityError Severity = "error"
SeverityWarning Severity = "warning"
SeverityInfo Severity = "info"
SeverityDebug Severity = "debug"
)
// Provider defines the interface for error tracking providers
type Provider interface {
// CaptureError captures an error with the given severity and additional context
CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{})
// CaptureMessage captures a message with the given severity and additional context
CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{})
// CapturePanic captures a panic with stack trace
CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{})
// Flush waits for all events to be sent (useful for graceful shutdown)
Flush(timeout int) bool
// Close closes the provider and releases resources
Close() error
}

37
pkg/errortracking/noop.go Normal file
View File

@@ -0,0 +1,37 @@
package errortracking
import "context"
// NoOpProvider is a no-op implementation of the Provider interface
// Used when error tracking is disabled
type NoOpProvider struct{}
// NewNoOpProvider creates a new NoOp provider
func NewNoOpProvider() *NoOpProvider {
return &NoOpProvider{}
}
// CaptureError does nothing
func (n *NoOpProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
// No-op
}
// CaptureMessage does nothing
func (n *NoOpProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
// No-op
}
// CapturePanic does nothing
func (n *NoOpProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
// No-op
}
// Flush does nothing and returns true
func (n *NoOpProvider) Flush(timeout int) bool {
return true
}
// Close does nothing
func (n *NoOpProvider) Close() error {
return nil
}

154
pkg/errortracking/sentry.go Normal file
View File

@@ -0,0 +1,154 @@
package errortracking
import (
"context"
"fmt"
"time"
"github.com/getsentry/sentry-go"
)
// SentryProvider implements the Provider interface using Sentry
type SentryProvider struct {
hub *sentry.Hub
}
// SentryConfig holds the configuration for Sentry
type SentryConfig struct {
DSN string
Environment string
Release string
Debug bool
SampleRate float64
TracesSampleRate float64
}
// NewSentryProvider creates a new Sentry provider
func NewSentryProvider(config SentryConfig) (*SentryProvider, error) {
err := sentry.Init(sentry.ClientOptions{
Dsn: config.DSN,
Environment: config.Environment,
Release: config.Release,
Debug: config.Debug,
AttachStacktrace: true,
SampleRate: config.SampleRate,
TracesSampleRate: config.TracesSampleRate,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize Sentry: %w", err)
}
return &SentryProvider{
hub: sentry.CurrentHub(),
}, nil
}
// CaptureError captures an error with the given severity and additional context
func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
if err == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = err.Error()
event.Exception = []sentry.Exception{
{
Value: err.Error(),
Type: fmt.Sprintf("%T", err),
Stacktrace: sentry.ExtractStacktrace(err),
},
}
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CaptureMessage captures a message with the given severity and additional context
func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
if message == "" {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = message
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CapturePanic captures a panic with stack trace
func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
if recovered == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = sentry.LevelError
event.Message = fmt.Sprintf("Panic: %v", recovered)
event.Exception = []sentry.Exception{
{
Value: fmt.Sprintf("%v", recovered),
Type: "panic",
},
}
if extra != nil {
event.Extra = extra
}
if stackTrace != nil {
event.Extra["stack_trace"] = string(stackTrace)
}
hub.CaptureEvent(event)
}
// Flush waits for all events to be sent (useful for graceful shutdown)
func (s *SentryProvider) Flush(timeout int) bool {
return sentry.Flush(time.Duration(timeout) * time.Second)
}
// Close closes the provider and releases resources
func (s *SentryProvider) Close() error {
sentry.Flush(2 * time.Second)
return nil
}
// convertSeverity converts our Severity to Sentry's Level
func (s *SentryProvider) convertSeverity(severity Severity) sentry.Level {
switch severity {
case SeverityError:
return sentry.LevelError
case SeverityWarning:
return sentry.LevelWarning
case SeverityInfo:
return sentry.LevelInfo
case SeverityDebug:
return sentry.LevelDebug
default:
return sentry.LevelError
}
}

View File

@@ -1,15 +1,19 @@
package logger
import (
"context"
"fmt"
"log"
"os"
"runtime/debug"
"go.uber.org/zap"
errortracking "github.com/bitechdev/ResolveSpec/pkg/errortracking"
)
var Logger *zap.SugaredLogger
var errorTracker errortracking.Provider
func Init(dev bool) {
@@ -49,6 +53,28 @@ func UpdateLogger(config *zap.Config) {
Info("ResolveSpec Logger initialized")
}
// InitErrorTracking initializes the error tracking provider
func InitErrorTracking(provider errortracking.Provider) {
errorTracker = provider
if errorTracker != nil {
Info("Error tracking initialized")
}
}
// GetErrorTracker returns the current error tracking provider
func GetErrorTracker() errortracking.Provider {
return errorTracker
}
// CloseErrorTracking flushes and closes the error tracking provider
func CloseErrorTracking() error {
if errorTracker != nil {
errorTracker.Flush(5)
return errorTracker.Close()
}
return nil
}
func Info(template string, args ...interface{}) {
if Logger == nil {
log.Printf(template, args...)
@@ -58,19 +84,35 @@ func Info(template string, args ...interface{}) {
}
func Warn(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil {
log.Printf(template, args...)
return
log.Printf("%s", message)
} else {
Logger.Warnw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
"process_id": os.Getpid(),
})
}
Logger.Warnw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
}
func Error(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil {
log.Printf(template, args...)
return
log.Printf("%s", message)
} else {
Logger.Errorw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
"process_id": os.Getpid(),
})
}
Logger.Errorw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
}
func Debug(template string, args ...interface{}) {
@@ -84,7 +126,7 @@ func Debug(template string, args ...interface{}) {
// CatchPanic - Handle panic
func CatchPanicCallback(location string, cb func(err any)) {
if err := recover(); err != nil {
// callstack := debug.Stack()
callstack := debug.Stack()
if Logger != nil {
Error("Panic in %s : %v", location, err)
@@ -93,14 +135,13 @@ func CatchPanicCallback(location string, cb func(err any)) {
debug.PrintStack()
}
// push to sentry
// hub := sentry.CurrentHub()
// if hub != nil {
// evtID := hub.Recover(err)
// if evtID != nil {
// sentry.Flush(time.Second * 2)
// }
// }
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
"location": location,
"process_id": os.Getpid(),
})
}
if cb != nil {
cb(err)
@@ -125,5 +166,14 @@ func CatchPanic(location string) {
func HandlePanic(methodName string, r any) error {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
"method": methodName,
"process_id": os.Getpid(),
})
}
return fmt.Errorf("panic in %s: %v", methodName, r)
}

View File

@@ -127,7 +127,7 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
options = filterExtendedOptions(validator, options)
options = h.filterExtendedOptions(validator, options, model)
// Add request-scoped data to context (including options)
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
@@ -2241,7 +2241,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
}
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions, model interface{}) ExtendedRequestOptions {
filtered := options
// Filter base RequestOptions
@@ -2265,12 +2265,30 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
// No filtering needed for ComputedQL keys
filtered.ComputedQL = options.ComputedQL
// Filter Expand columns
// Filter Expand columns using the expand relation's model
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
for _, expand := range options.Expand {
filteredExpand := expand
// Don't validate relation name, only columns
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
// Get the relationship info for this expand relation
relInfo := h.getRelationshipInfo(modelType, expand.Relation)
if relInfo != nil && relInfo.relatedModel != nil {
// Create a validator for the related model
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
// Filter columns using the related model's validator
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
} else {
// If we can't find the relationship, log a warning and skip column filtering
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
// Keep the columns as-is if we can't validate them
filteredExpand.Columns = expand.Columns
}
filteredExpands = append(filteredExpands, filteredExpand)
}
filtered.Expand = filteredExpands

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Production-Ready Authenticators
@@ -169,69 +170,98 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
// Extract session token from header or cookie
sessionToken := r.Header.Get("Authorization")
reference := "authenticate"
var tokens []string
if sessionToken == "" {
// Try cookie
cookie, err := r.Cookie("session_token")
if err == nil {
sessionToken = cookie.Value
tokens = []string{cookie.Value}
reference = "cookie"
}
} else {
// Remove "Bearer " prefix if present
sessionToken = strings.TrimPrefix(sessionToken, "Bearer ")
// Remove "Token " prefix if present
sessionToken = strings.TrimPrefix(sessionToken, "Token ")
// Parse Authorization header which may contain multiple comma-separated tokens
// Format: "Token abc, Token def" or "Bearer abc" or just "abc"
rawTokens := strings.Split(sessionToken, ",")
for _, token := range rawTokens {
token = strings.TrimSpace(token)
// Remove "Bearer " prefix if present
token = strings.TrimPrefix(token, "Bearer ")
// Remove "Token " prefix if present
token = strings.TrimPrefix(token, "Token ")
token = strings.TrimSpace(token)
if token != "" {
tokens = append(tokens, token)
}
}
}
if sessionToken == "" {
if len(tokens) == 0 {
return nil, fmt.Errorf("session token required")
}
// Build cache key
cacheKey := fmt.Sprintf("auth:session:%s", sessionToken)
// Use cache.GetOrSet to get from cache or load from database
var userCtx UserContext
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (interface{}, error) {
// This function is called only if cache miss
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, sessionToken, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid or expired session")
}
if !userJSON.Valid {
return nil, fmt.Errorf("no user data in session")
}
// Parse UserContext
var user UserContext
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
return &user, nil
})
if err != nil {
return nil, err
// Log warning if multiple tokens are provided
if len(tokens) > 1 {
logger.Warn("Multiple authentication tokens provided in Authorization header (%d tokens). This is unusual and may indicate a misconfigured client. Header: %s", len(tokens), sessionToken)
}
// Update last activity timestamp asynchronously
go a.updateSessionActivity(r.Context(), sessionToken, &userCtx)
// Try each token until one succeeds
var lastErr error
for _, token := range tokens {
// Build cache key
cacheKey := fmt.Sprintf("auth:session:%s", token)
return &userCtx, nil
// Use cache.GetOrSet to get from cache or load from database
var userCtx UserContext
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (any, error) {
// This function is called only if cache miss
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("invalid or expired session")
}
if !userJSON.Valid {
return nil, fmt.Errorf("no user data in session")
}
// Parse UserContext
var user UserContext
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
return nil, fmt.Errorf("failed to parse user context: %w", err)
}
return &user, nil
})
if err != nil {
lastErr = err
continue // Try next token
}
// Authentication succeeded with this token
// Update last activity timestamp asynchronously
go a.updateSessionActivity(r.Context(), token, &userCtx)
return &userCtx, nil
}
// All tokens failed
if lastErr != nil {
return nil, lastErr
}
return nil, fmt.Errorf("authentication failed for all provided tokens")
}
// ClearCache removes a specific token from the cache or clears all cache if token is empty

View File

@@ -545,6 +545,96 @@ func TestDatabaseAuthenticator(t *testing.T) {
t.Fatal("expected error when token is missing")
}
})
t.Run("authenticate with multiple comma-separated tokens", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token invalid-token, Token valid-token-123")
// First token fails
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("invalid-token", "authenticate").
WillReturnRows(rows1)
// Second token succeeds
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":3,"user_name":"multitoken","session_id":"valid-token-123"}`)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("valid-token-123", "authenticate").
WillReturnRows(rows2)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if userCtx.UserID != 3 {
t.Errorf("expected UserID 3, got %d", userCtx.UserID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("authenticate with duplicate tokens", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A, Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A")
// First token succeeds
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(true, nil, `{"user_id":4,"user_name":"duplicateuser","session_id":"968CA5AE-4F83-4D55-A3C6-51AE4410E03A"}`)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("968CA5AE-4F83-4D55-A3C6-51AE4410E03A", "authenticate").
WillReturnRows(rows)
userCtx, err := auth.Authenticate(req)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if userCtx.UserID != 4 {
t.Errorf("expected UserID 4, got %d", userCtx.UserID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
t.Run("authenticate with all tokens failing", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Token bad-token-1, Token bad-token-2")
// First token fails
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("bad-token-1", "authenticate").
WillReturnRows(rows1)
// Second token also fails
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
AddRow(false, "Invalid token", nil)
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
WithArgs("bad-token-2", "authenticate").
WillReturnRows(rows2)
_, err := auth.Authenticate(req)
if err == nil {
t.Fatal("expected error when all tokens fail")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
})
}
// Test DatabaseAuthenticator RefreshToken