mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d0407a16d | ||
|
|
99001c749d | ||
|
|
1f7a57f8e3 | ||
|
|
a95c28a0bf |
@@ -71,35 +71,18 @@
|
|||||||
},
|
},
|
||||||
"gocritic": {
|
"gocritic": {
|
||||||
"enabled-checks": [
|
"enabled-checks": [
|
||||||
"appendAssign",
|
|
||||||
"assignOp",
|
|
||||||
"boolExprSimplify",
|
"boolExprSimplify",
|
||||||
"builtinShadow",
|
"builtinShadow",
|
||||||
"captLocal",
|
|
||||||
"caseOrder",
|
|
||||||
"defaultCaseOrder",
|
|
||||||
"dupArg",
|
|
||||||
"dupBranchBody",
|
|
||||||
"dupCase",
|
|
||||||
"dupSubExpr",
|
|
||||||
"elseif",
|
|
||||||
"emptyFallthrough",
|
"emptyFallthrough",
|
||||||
"equalFold",
|
"equalFold",
|
||||||
"flagName",
|
|
||||||
"indexAlloc",
|
"indexAlloc",
|
||||||
"initClause",
|
"initClause",
|
||||||
"methodExprCall",
|
"methodExprCall",
|
||||||
"nilValReturn",
|
"nilValReturn",
|
||||||
"rangeExprCopy",
|
"rangeExprCopy",
|
||||||
"rangeValCopy",
|
"rangeValCopy",
|
||||||
"regexpMust",
|
|
||||||
"singleCaseSwitch",
|
|
||||||
"sloppyLen",
|
|
||||||
"stringXbytes",
|
"stringXbytes",
|
||||||
"switchTrue",
|
|
||||||
"typeAssertChain",
|
"typeAssertChain",
|
||||||
"typeSwitchVar",
|
|
||||||
"underef",
|
|
||||||
"unlabelStmt",
|
"unlabelStmt",
|
||||||
"unnamedResult",
|
"unnamedResult",
|
||||||
"unnecessaryBlock",
|
"unnecessaryBlock",
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -38,6 +38,7 @@ require (
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
|
github.com/getsentry/sentry-go v0.40.0 // indirect
|
||||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -21,6 +21,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
|||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
||||||
|
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
||||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -78,6 +79,41 @@ func IsTrivialCondition(cond string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
|
||||||
|
// Returns an error if any dangerous keywords are found
|
||||||
|
func validateWhereClauseSecurity(where string) error {
|
||||||
|
if where == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lowerWhere := strings.ToLower(where)
|
||||||
|
|
||||||
|
// List of dangerous SQL keywords that should never appear in WHERE clauses
|
||||||
|
dangerousKeywords := []string{
|
||||||
|
"delete ", "delete\t", "delete\n", "delete;",
|
||||||
|
"update ", "update\t", "update\n", "update;",
|
||||||
|
"truncate ", "truncate\t", "truncate\n", "truncate;",
|
||||||
|
"drop ", "drop\t", "drop\n", "drop;",
|
||||||
|
"alter ", "alter\t", "alter\n", "alter;",
|
||||||
|
"create ", "create\t", "create\n", "create;",
|
||||||
|
"insert ", "insert\t", "insert\n", "insert;",
|
||||||
|
"grant ", "grant\t", "grant\n", "grant;",
|
||||||
|
"revoke ", "revoke\t", "revoke\n", "revoke;",
|
||||||
|
"exec ", "exec\t", "exec\n", "exec;",
|
||||||
|
"execute ", "execute\t", "execute\n", "execute;",
|
||||||
|
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, keyword := range dangerousKeywords {
|
||||||
|
if strings.Contains(lowerWhere, keyword) {
|
||||||
|
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
|
||||||
|
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
||||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||||
//
|
//
|
||||||
@@ -100,6 +136,12 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
|
|
||||||
where = strings.TrimSpace(where)
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
|
// Validate that the WHERE clause doesn't contain dangerous SQL statements
|
||||||
|
if err := validateWhereClauseSecurity(where); err != nil {
|
||||||
|
logger.Debug("Security validation failed for WHERE clause: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Strip outer parentheses and re-trim
|
// Strip outer parentheses and re-trim
|
||||||
where = stripOuterParentheses(where)
|
where = stripOuterParentheses(where)
|
||||||
|
|
||||||
@@ -221,19 +263,57 @@ func stripOuterParentheses(s string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||||
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
||||||
func splitByAND(where string) []string {
|
func splitByAND(where string) []string {
|
||||||
// First try uppercase AND
|
conditions := []string{}
|
||||||
conditions := strings.Split(where, " AND ")
|
currentCondition := strings.Builder{}
|
||||||
|
depth := 0 // Track parenthesis depth
|
||||||
|
i := 0
|
||||||
|
|
||||||
// If we didn't split on uppercase, try lowercase
|
for i < len(where) {
|
||||||
if len(conditions) == 1 {
|
ch := where[i]
|
||||||
conditions = strings.Split(where, " and ")
|
|
||||||
|
// Track parenthesis depth
|
||||||
|
if ch == '(' {
|
||||||
|
depth++
|
||||||
|
currentCondition.WriteByte(ch)
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
} else if ch == ')' {
|
||||||
|
depth--
|
||||||
|
currentCondition.WriteByte(ch)
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only look for AND operators at depth 0 (not inside parentheses)
|
||||||
|
if depth == 0 {
|
||||||
|
// Check if we're at an AND operator (case-insensitive)
|
||||||
|
// We need at least " AND " (5 chars) or " and " (5 chars)
|
||||||
|
if i+5 <= len(where) {
|
||||||
|
substring := where[i : i+5]
|
||||||
|
lowerSubstring := strings.ToLower(substring)
|
||||||
|
|
||||||
|
if lowerSubstring == " and " {
|
||||||
|
// Found an AND operator at the top level
|
||||||
|
// Add the current condition to the list
|
||||||
|
conditions = append(conditions, currentCondition.String())
|
||||||
|
currentCondition.Reset()
|
||||||
|
// Skip past the AND operator
|
||||||
|
i += 5
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not an AND operator or we're inside parentheses, just add the character
|
||||||
|
currentCondition.WriteByte(ch)
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we still didn't split, try mixed case
|
// Add the last condition
|
||||||
if len(conditions) == 1 {
|
if currentCondition.Len() > 0 {
|
||||||
conditions = strings.Split(where, " And ")
|
conditions = append(conditions, currentCondition.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return conditions
|
return conditions
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
@@ -85,6 +86,42 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "users.status = 'active' AND users.age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case AND operators",
|
||||||
|
where: "status = 'active' AND age > 18 and name = 'John'",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "status = 'active' AND age > 18 AND name = 'John'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subquery with ORDER BY and LIMIT - allowed",
|
||||||
|
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dangerous DELETE keyword - blocked",
|
||||||
|
where: "status = 'active'; DELETE FROM users",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dangerous UPDATE keyword - blocked",
|
||||||
|
where: "1=1; UPDATE users SET admin = true",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dangerous TRUNCATE keyword - blocked",
|
||||||
|
where: "status = 'active' OR TRUNCATE TABLE users",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dangerous DROP keyword - blocked",
|
||||||
|
where: "status = 'active'; DROP TABLE users",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -138,6 +175,11 @@ func TestStripOuterParentheses(t *testing.T) {
|
|||||||
input: " ( true ) ",
|
input: " ( true ) ",
|
||||||
expected: "true",
|
expected: "true",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "complex sub query",
|
||||||
|
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
|
||||||
|
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -337,6 +379,131 @@ type MasterTask struct {
|
|||||||
UserID int `bun:"user_id"`
|
UserID int `bun:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSplitByAND(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "uppercase AND",
|
||||||
|
input: "status = 'active' AND age > 18",
|
||||||
|
expected: []string{"status = 'active'", "age > 18"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lowercase and",
|
||||||
|
input: "status = 'active' and age > 18",
|
||||||
|
expected: []string{"status = 'active'", "age > 18"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case AND",
|
||||||
|
input: "status = 'active' AND age > 18 and name = 'John'",
|
||||||
|
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single condition",
|
||||||
|
input: "status = 'active'",
|
||||||
|
expected: []string{"status = 'active'"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple uppercase AND",
|
||||||
|
input: "a = 1 AND b = 2 AND c = 3",
|
||||||
|
expected: []string{"a = 1", "b = 2", "c = 3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple case subquery",
|
||||||
|
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
|
||||||
|
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := splitByAND(tt.input)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range result {
|
||||||
|
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
|
||||||
|
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWhereClauseSecurity(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "safe WHERE clause",
|
||||||
|
input: "status = 'active' AND age > 18",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "safe subquery",
|
||||||
|
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DELETE keyword",
|
||||||
|
input: "status = 'active'; DELETE FROM users",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UPDATE keyword",
|
||||||
|
input: "1=1; UPDATE users SET admin = true",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TRUNCATE keyword",
|
||||||
|
input: "status = 'active' OR TRUNCATE TABLE users",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DROP keyword",
|
||||||
|
input: "status = 'active'; DROP TABLE users",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "INSERT keyword",
|
||||||
|
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ALTER keyword",
|
||||||
|
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CREATE keyword",
|
||||||
|
input: "1=1; CREATE TABLE malicious (id INT)",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty clause",
|
||||||
|
input: "",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateWhereClauseSecurity(tt.input)
|
||||||
|
if tt.expectError && err == nil {
|
||||||
|
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
|
||||||
|
}
|
||||||
|
if !tt.expectError && err != nil {
|
||||||
|
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||||
// Register the test model
|
// Register the test model
|
||||||
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ import "time"
|
|||||||
|
|
||||||
// Config represents the complete application configuration
|
// Config represents the complete application configuration
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
Tracing TracingConfig `mapstructure:"tracing"`
|
Tracing TracingConfig `mapstructure:"tracing"`
|
||||||
Cache CacheConfig `mapstructure:"cache"`
|
Cache CacheConfig `mapstructure:"cache"`
|
||||||
Logger LoggerConfig `mapstructure:"logger"`
|
Logger LoggerConfig `mapstructure:"logger"`
|
||||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig holds server-related configuration
|
// ServerConfig holds server-related configuration
|
||||||
@@ -78,3 +79,15 @@ type CORSConfig struct {
|
|||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
URL string `mapstructure:"url"`
|
URL string `mapstructure:"url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrorTrackingConfig holds error tracking configuration
|
||||||
|
type ErrorTrackingConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Provider string `mapstructure:"provider"` // sentry, noop
|
||||||
|
DSN string `mapstructure:"dsn"` // Sentry DSN
|
||||||
|
Environment string `mapstructure:"environment"` // e.g., production, staging, development
|
||||||
|
Release string `mapstructure:"release"` // Application version/release
|
||||||
|
Debug bool `mapstructure:"debug"` // Enable debug mode
|
||||||
|
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
|
||||||
|
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
|
||||||
|
}
|
||||||
|
|||||||
150
pkg/errortracking/README.md
Normal file
150
pkg/errortracking/README.md
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
# Error Tracking
|
||||||
|
|
||||||
|
This package provides error tracking integration for ResolveSpec, with built-in support for Sentry.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Provider Interface**: Flexible design supporting multiple error tracking backends
|
||||||
|
- **Sentry Integration**: Full-featured Sentry support with automatic error, warning, and panic tracking
|
||||||
|
- **Automatic Logger Integration**: All `logger.Error()` and `logger.Warn()` calls are automatically sent to the error tracker
|
||||||
|
- **Panic Tracking**: Automatic panic capture with stack traces
|
||||||
|
- **NoOp Provider**: Zero-overhead when error tracking is disabled
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Add error tracking configuration to your config file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
error_tracking:
|
||||||
|
enabled: true
|
||||||
|
provider: "sentry" # Currently supports: "sentry" or "noop"
|
||||||
|
dsn: "https://your-sentry-dsn@sentry.io/project-id"
|
||||||
|
environment: "production" # e.g., production, staging, development
|
||||||
|
release: "v1.0.0" # Your application version
|
||||||
|
debug: false
|
||||||
|
sample_rate: 1.0 # Error sample rate (0.0-1.0)
|
||||||
|
traces_sample_rate: 0.1 # Traces sample rate (0.0-1.0)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Initialization
|
||||||
|
|
||||||
|
Initialize error tracking in your application startup:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load your configuration
|
||||||
|
cfg := config.Config{
|
||||||
|
ErrorTracking: config.ErrorTrackingConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Provider: "sentry",
|
||||||
|
DSN: "https://your-sentry-dsn@sentry.io/project-id",
|
||||||
|
Environment: "production",
|
||||||
|
Release: "v1.0.0",
|
||||||
|
SampleRate: 1.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize logger
|
||||||
|
logger.Init(false)
|
||||||
|
|
||||||
|
// Initialize error tracking
|
||||||
|
provider, err := errortracking.NewProviderFromConfig(cfg.ErrorTracking)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to initialize error tracking: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.InitErrorTracking(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Your application code...
|
||||||
|
|
||||||
|
// Cleanup on shutdown
|
||||||
|
defer logger.CloseErrorTracking()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Automatic Tracking
|
||||||
|
|
||||||
|
Once initialized, all logger errors and warnings are automatically sent to the error tracker:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// This will be logged AND sent to Sentry
|
||||||
|
logger.Error("Database connection failed: %v", err)
|
||||||
|
|
||||||
|
// This will also be logged AND sent to Sentry
|
||||||
|
logger.Warn("Cache miss for key: %s", key)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Panic Tracking
|
||||||
|
|
||||||
|
Panics are automatically captured when using the logger's panic handlers:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Using CatchPanic
|
||||||
|
defer logger.CatchPanic("MyFunction")
|
||||||
|
|
||||||
|
// Using CatchPanicCallback
|
||||||
|
defer logger.CatchPanicCallback("MyFunction", func(err any) {
|
||||||
|
// Custom cleanup
|
||||||
|
})
|
||||||
|
|
||||||
|
// Using HandlePanic
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("MyMethod", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manual Tracking
|
||||||
|
|
||||||
|
You can also use the provider directly for custom error tracking:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func someFunction() {
|
||||||
|
tracker := logger.GetErrorTracker()
|
||||||
|
if tracker != nil {
|
||||||
|
// Capture an error
|
||||||
|
tracker.CaptureError(context.Background(), err, errortracking.SeverityError, map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"request_id": requestID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Capture a message
|
||||||
|
tracker.CaptureMessage(context.Background(), "Important event occurred", errortracking.SeverityInfo, map[string]interface{}{
|
||||||
|
"event_type": "user_signup",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Capture a panic
|
||||||
|
tracker.CapturePanic(context.Background(), recovered, stackTrace, map[string]interface{}{
|
||||||
|
"context": "background_job",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Severity Levels
|
||||||
|
|
||||||
|
The package supports the following severity levels:
|
||||||
|
|
||||||
|
- `SeverityError`: For errors that should be tracked and investigated
|
||||||
|
- `SeverityWarning`: For warnings that may indicate potential issues
|
||||||
|
- `SeverityInfo`: For informational messages
|
||||||
|
- `SeverityDebug`: For debug-level information
|
||||||
|
|
||||||
|
```
|
||||||
67
pkg/errortracking/errortracking_test.go
Normal file
67
pkg/errortracking/errortracking_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package errortracking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNoOpProvider(t *testing.T) {
|
||||||
|
provider := NewNoOpProvider()
|
||||||
|
|
||||||
|
// Test that all methods can be called without panicking
|
||||||
|
t.Run("CaptureError", func(t *testing.T) {
|
||||||
|
provider.CaptureError(context.Background(), errors.New("test error"), SeverityError, nil)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CaptureMessage", func(t *testing.T) {
|
||||||
|
provider.CaptureMessage(context.Background(), "test message", SeverityWarning, nil)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CapturePanic", func(t *testing.T) {
|
||||||
|
provider.CapturePanic(context.Background(), "panic!", []byte("stack trace"), nil)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Flush", func(t *testing.T) {
|
||||||
|
result := provider.Flush(5)
|
||||||
|
if !result {
|
||||||
|
t.Error("Expected Flush to return true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Close", func(t *testing.T) {
|
||||||
|
err := provider.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected Close to return nil, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSeverityLevels(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
severity Severity
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"Error", SeverityError, "error"},
|
||||||
|
{"Warning", SeverityWarning, "warning"},
|
||||||
|
{"Info", SeverityInfo, "info"},
|
||||||
|
{"Debug", SeverityDebug, "debug"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if string(tt.severity) != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, string(tt.severity))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderInterface(t *testing.T) {
|
||||||
|
// Test that NoOpProvider implements Provider interface
|
||||||
|
var _ Provider = (*NoOpProvider)(nil)
|
||||||
|
|
||||||
|
// Test that SentryProvider implements Provider interface
|
||||||
|
var _ Provider = (*SentryProvider)(nil)
|
||||||
|
}
|
||||||
33
pkg/errortracking/factory.go
Normal file
33
pkg/errortracking/factory.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package errortracking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewProviderFromConfig creates an error tracking provider based on the configuration
|
||||||
|
func NewProviderFromConfig(cfg config.ErrorTrackingConfig) (Provider, error) {
|
||||||
|
if !cfg.Enabled {
|
||||||
|
return NewNoOpProvider(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cfg.Provider {
|
||||||
|
case "sentry":
|
||||||
|
if cfg.DSN == "" {
|
||||||
|
return nil, fmt.Errorf("sentry DSN is required when error tracking is enabled")
|
||||||
|
}
|
||||||
|
return NewSentryProvider(SentryConfig{
|
||||||
|
DSN: cfg.DSN,
|
||||||
|
Environment: cfg.Environment,
|
||||||
|
Release: cfg.Release,
|
||||||
|
Debug: cfg.Debug,
|
||||||
|
SampleRate: cfg.SampleRate,
|
||||||
|
TracesSampleRate: cfg.TracesSampleRate,
|
||||||
|
})
|
||||||
|
case "noop", "":
|
||||||
|
return NewNoOpProvider(), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown error tracking provider: %s", cfg.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
33
pkg/errortracking/interfaces.go
Normal file
33
pkg/errortracking/interfaces.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package errortracking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Severity represents the severity level of an error
|
||||||
|
type Severity string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SeverityError Severity = "error"
|
||||||
|
SeverityWarning Severity = "warning"
|
||||||
|
SeverityInfo Severity = "info"
|
||||||
|
SeverityDebug Severity = "debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider defines the interface for error tracking providers
|
||||||
|
type Provider interface {
|
||||||
|
// CaptureError captures an error with the given severity and additional context
|
||||||
|
CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{})
|
||||||
|
|
||||||
|
// CaptureMessage captures a message with the given severity and additional context
|
||||||
|
CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{})
|
||||||
|
|
||||||
|
// CapturePanic captures a panic with stack trace
|
||||||
|
CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{})
|
||||||
|
|
||||||
|
// Flush waits for all events to be sent (useful for graceful shutdown)
|
||||||
|
Flush(timeout int) bool
|
||||||
|
|
||||||
|
// Close closes the provider and releases resources
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
37
pkg/errortracking/noop.go
Normal file
37
pkg/errortracking/noop.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package errortracking
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// NoOpProvider is a no-op implementation of the Provider interface
|
||||||
|
// Used when error tracking is disabled
|
||||||
|
type NoOpProvider struct{}
|
||||||
|
|
||||||
|
// NewNoOpProvider creates a new NoOp provider
|
||||||
|
func NewNoOpProvider() *NoOpProvider {
|
||||||
|
return &NoOpProvider{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptureError does nothing
|
||||||
|
func (n *NoOpProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
|
||||||
|
// No-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptureMessage does nothing
|
||||||
|
func (n *NoOpProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
|
||||||
|
// No-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// CapturePanic does nothing
|
||||||
|
func (n *NoOpProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
|
||||||
|
// No-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush does nothing and returns true
|
||||||
|
func (n *NoOpProvider) Flush(timeout int) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close does nothing
|
||||||
|
func (n *NoOpProvider) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
154
pkg/errortracking/sentry.go
Normal file
154
pkg/errortracking/sentry.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package errortracking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/getsentry/sentry-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SentryProvider implements the Provider interface using Sentry
|
||||||
|
type SentryProvider struct {
|
||||||
|
hub *sentry.Hub
|
||||||
|
}
|
||||||
|
|
||||||
|
// SentryConfig holds the configuration for Sentry
|
||||||
|
type SentryConfig struct {
|
||||||
|
DSN string
|
||||||
|
Environment string
|
||||||
|
Release string
|
||||||
|
Debug bool
|
||||||
|
SampleRate float64
|
||||||
|
TracesSampleRate float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSentryProvider creates a new Sentry provider
|
||||||
|
func NewSentryProvider(config SentryConfig) (*SentryProvider, error) {
|
||||||
|
err := sentry.Init(sentry.ClientOptions{
|
||||||
|
Dsn: config.DSN,
|
||||||
|
Environment: config.Environment,
|
||||||
|
Release: config.Release,
|
||||||
|
Debug: config.Debug,
|
||||||
|
AttachStacktrace: true,
|
||||||
|
SampleRate: config.SampleRate,
|
||||||
|
TracesSampleRate: config.TracesSampleRate,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize Sentry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SentryProvider{
|
||||||
|
hub: sentry.CurrentHub(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptureError captures an error with the given severity and additional context
|
||||||
|
func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hub := sentry.GetHubFromContext(ctx)
|
||||||
|
if hub == nil {
|
||||||
|
hub = s.hub
|
||||||
|
}
|
||||||
|
|
||||||
|
event := sentry.NewEvent()
|
||||||
|
event.Level = s.convertSeverity(severity)
|
||||||
|
event.Message = err.Error()
|
||||||
|
event.Exception = []sentry.Exception{
|
||||||
|
{
|
||||||
|
Value: err.Error(),
|
||||||
|
Type: fmt.Sprintf("%T", err),
|
||||||
|
Stacktrace: sentry.ExtractStacktrace(err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if extra != nil {
|
||||||
|
event.Extra = extra
|
||||||
|
}
|
||||||
|
|
||||||
|
hub.CaptureEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CaptureMessage captures a message with the given severity and additional context
|
||||||
|
func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
|
||||||
|
if message == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hub := sentry.GetHubFromContext(ctx)
|
||||||
|
if hub == nil {
|
||||||
|
hub = s.hub
|
||||||
|
}
|
||||||
|
|
||||||
|
event := sentry.NewEvent()
|
||||||
|
event.Level = s.convertSeverity(severity)
|
||||||
|
event.Message = message
|
||||||
|
|
||||||
|
if extra != nil {
|
||||||
|
event.Extra = extra
|
||||||
|
}
|
||||||
|
|
||||||
|
hub.CaptureEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CapturePanic captures a panic with stack trace
|
||||||
|
func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
|
||||||
|
if recovered == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hub := sentry.GetHubFromContext(ctx)
|
||||||
|
if hub == nil {
|
||||||
|
hub = s.hub
|
||||||
|
}
|
||||||
|
|
||||||
|
event := sentry.NewEvent()
|
||||||
|
event.Level = sentry.LevelError
|
||||||
|
event.Message = fmt.Sprintf("Panic: %v", recovered)
|
||||||
|
event.Exception = []sentry.Exception{
|
||||||
|
{
|
||||||
|
Value: fmt.Sprintf("%v", recovered),
|
||||||
|
Type: "panic",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if extra != nil {
|
||||||
|
event.Extra = extra
|
||||||
|
}
|
||||||
|
|
||||||
|
if stackTrace != nil {
|
||||||
|
event.Extra["stack_trace"] = string(stackTrace)
|
||||||
|
}
|
||||||
|
|
||||||
|
hub.CaptureEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush waits for all events to be sent (useful for graceful shutdown)
|
||||||
|
func (s *SentryProvider) Flush(timeout int) bool {
|
||||||
|
return sentry.Flush(time.Duration(timeout) * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the provider and releases resources
|
||||||
|
func (s *SentryProvider) Close() error {
|
||||||
|
sentry.Flush(2 * time.Second)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertSeverity converts our Severity to Sentry's Level
|
||||||
|
func (s *SentryProvider) convertSeverity(severity Severity) sentry.Level {
|
||||||
|
switch severity {
|
||||||
|
case SeverityError:
|
||||||
|
return sentry.LevelError
|
||||||
|
case SeverityWarning:
|
||||||
|
return sentry.LevelWarning
|
||||||
|
case SeverityInfo:
|
||||||
|
return sentry.LevelInfo
|
||||||
|
case SeverityDebug:
|
||||||
|
return sentry.LevelDebug
|
||||||
|
default:
|
||||||
|
return sentry.LevelError
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,15 +1,19 @@
|
|||||||
package logger
|
package logger
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
errortracking "github.com/bitechdev/ResolveSpec/pkg/errortracking"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Logger *zap.SugaredLogger
|
var Logger *zap.SugaredLogger
|
||||||
|
var errorTracker errortracking.Provider
|
||||||
|
|
||||||
func Init(dev bool) {
|
func Init(dev bool) {
|
||||||
|
|
||||||
@@ -49,6 +53,28 @@ func UpdateLogger(config *zap.Config) {
|
|||||||
Info("ResolveSpec Logger initialized")
|
Info("ResolveSpec Logger initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitErrorTracking initializes the error tracking provider
|
||||||
|
func InitErrorTracking(provider errortracking.Provider) {
|
||||||
|
errorTracker = provider
|
||||||
|
if errorTracker != nil {
|
||||||
|
Info("Error tracking initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErrorTracker returns the current error tracking provider
|
||||||
|
func GetErrorTracker() errortracking.Provider {
|
||||||
|
return errorTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseErrorTracking flushes and closes the error tracking provider
|
||||||
|
func CloseErrorTracking() error {
|
||||||
|
if errorTracker != nil {
|
||||||
|
errorTracker.Flush(5)
|
||||||
|
return errorTracker.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func Info(template string, args ...interface{}) {
|
func Info(template string, args ...interface{}) {
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf(template, args...)
|
log.Printf(template, args...)
|
||||||
@@ -58,19 +84,35 @@ func Info(template string, args ...interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Warn(template string, args ...interface{}) {
|
func Warn(template string, args ...interface{}) {
|
||||||
|
message := fmt.Sprintf(template, args...)
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf(template, args...)
|
log.Printf("%s", message)
|
||||||
return
|
} else {
|
||||||
|
Logger.Warnw(message, "process_id", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send to error tracker
|
||||||
|
if errorTracker != nil {
|
||||||
|
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
|
||||||
|
"process_id": os.Getpid(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Logger.Warnw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Error(template string, args ...interface{}) {
|
func Error(template string, args ...interface{}) {
|
||||||
|
message := fmt.Sprintf(template, args...)
|
||||||
if Logger == nil {
|
if Logger == nil {
|
||||||
log.Printf(template, args...)
|
log.Printf("%s", message)
|
||||||
return
|
} else {
|
||||||
|
Logger.Errorw(message, "process_id", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send to error tracker
|
||||||
|
if errorTracker != nil {
|
||||||
|
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
|
||||||
|
"process_id": os.Getpid(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Logger.Errorw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Debug(template string, args ...interface{}) {
|
func Debug(template string, args ...interface{}) {
|
||||||
@@ -84,7 +126,7 @@ func Debug(template string, args ...interface{}) {
|
|||||||
// CatchPanic - Handle panic
|
// CatchPanic - Handle panic
|
||||||
func CatchPanicCallback(location string, cb func(err any)) {
|
func CatchPanicCallback(location string, cb func(err any)) {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
// callstack := debug.Stack()
|
callstack := debug.Stack()
|
||||||
|
|
||||||
if Logger != nil {
|
if Logger != nil {
|
||||||
Error("Panic in %s : %v", location, err)
|
Error("Panic in %s : %v", location, err)
|
||||||
@@ -93,14 +135,13 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
|||||||
debug.PrintStack()
|
debug.PrintStack()
|
||||||
}
|
}
|
||||||
|
|
||||||
// push to sentry
|
// Send to error tracker
|
||||||
// hub := sentry.CurrentHub()
|
if errorTracker != nil {
|
||||||
// if hub != nil {
|
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
|
||||||
// evtID := hub.Recover(err)
|
"location": location,
|
||||||
// if evtID != nil {
|
"process_id": os.Getpid(),
|
||||||
// sentry.Flush(time.Second * 2)
|
})
|
||||||
// }
|
}
|
||||||
// }
|
|
||||||
|
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
cb(err)
|
cb(err)
|
||||||
@@ -125,5 +166,14 @@ func CatchPanic(location string) {
|
|||||||
func HandlePanic(methodName string, r any) error {
|
func HandlePanic(methodName string, r any) error {
|
||||||
stack := debug.Stack()
|
stack := debug.Stack()
|
||||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||||
|
|
||||||
|
// Send to error tracker
|
||||||
|
if errorTracker != nil {
|
||||||
|
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
|
||||||
|
"method": methodName,
|
||||||
|
"process_id": os.Getpid(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Errorf("panic in %s: %v", methodName, r)
|
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Production-Ready Authenticators
|
// Production-Ready Authenticators
|
||||||
@@ -169,69 +170,98 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
|||||||
// Extract session token from header or cookie
|
// Extract session token from header or cookie
|
||||||
sessionToken := r.Header.Get("Authorization")
|
sessionToken := r.Header.Get("Authorization")
|
||||||
reference := "authenticate"
|
reference := "authenticate"
|
||||||
|
var tokens []string
|
||||||
|
|
||||||
if sessionToken == "" {
|
if sessionToken == "" {
|
||||||
// Try cookie
|
// Try cookie
|
||||||
cookie, err := r.Cookie("session_token")
|
cookie, err := r.Cookie("session_token")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
sessionToken = cookie.Value
|
tokens = []string{cookie.Value}
|
||||||
reference = "cookie"
|
reference = "cookie"
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Remove "Bearer " prefix if present
|
// Parse Authorization header which may contain multiple comma-separated tokens
|
||||||
sessionToken = strings.TrimPrefix(sessionToken, "Bearer ")
|
// Format: "Token abc, Token def" or "Bearer abc" or just "abc"
|
||||||
// Remove "Token " prefix if present
|
rawTokens := strings.Split(sessionToken, ",")
|
||||||
sessionToken = strings.TrimPrefix(sessionToken, "Token ")
|
for _, token := range rawTokens {
|
||||||
|
token = strings.TrimSpace(token)
|
||||||
|
// Remove "Bearer " prefix if present
|
||||||
|
token = strings.TrimPrefix(token, "Bearer ")
|
||||||
|
// Remove "Token " prefix if present
|
||||||
|
token = strings.TrimPrefix(token, "Token ")
|
||||||
|
token = strings.TrimSpace(token)
|
||||||
|
if token != "" {
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sessionToken == "" {
|
if len(tokens) == 0 {
|
||||||
return nil, fmt.Errorf("session token required")
|
return nil, fmt.Errorf("session token required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build cache key
|
// Log warning if multiple tokens are provided
|
||||||
cacheKey := fmt.Sprintf("auth:session:%s", sessionToken)
|
if len(tokens) > 1 {
|
||||||
|
logger.Warn("Multiple authentication tokens provided in Authorization header (%d tokens). This is unusual and may indicate a misconfigured client. Header: %s", len(tokens), sessionToken)
|
||||||
// Use cache.GetOrSet to get from cache or load from database
|
|
||||||
var userCtx UserContext
|
|
||||||
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (interface{}, error) {
|
|
||||||
// This function is called only if cache miss
|
|
||||||
var success bool
|
|
||||||
var errorMsg sql.NullString
|
|
||||||
var userJSON sql.NullString
|
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
|
||||||
err := a.db.QueryRowContext(r.Context(), query, sessionToken, reference).Scan(&success, &errorMsg, &userJSON)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("session query failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !success {
|
|
||||||
if errorMsg.Valid {
|
|
||||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("invalid or expired session")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userJSON.Valid {
|
|
||||||
return nil, fmt.Errorf("no user data in session")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse UserContext
|
|
||||||
var user UserContext
|
|
||||||
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &user, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update last activity timestamp asynchronously
|
// Try each token until one succeeds
|
||||||
go a.updateSessionActivity(r.Context(), sessionToken, &userCtx)
|
var lastErr error
|
||||||
|
for _, token := range tokens {
|
||||||
|
// Build cache key
|
||||||
|
cacheKey := fmt.Sprintf("auth:session:%s", token)
|
||||||
|
|
||||||
return &userCtx, nil
|
// Use cache.GetOrSet to get from cache or load from database
|
||||||
|
var userCtx UserContext
|
||||||
|
err := a.cache.GetOrSet(r.Context(), cacheKey, &userCtx, a.cacheTTL, func() (any, error) {
|
||||||
|
// This function is called only if cache miss
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var userJSON sql.NullString
|
||||||
|
|
||||||
|
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||||
|
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("session query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
if errorMsg.Valid {
|
||||||
|
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("invalid or expired session")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userJSON.Valid {
|
||||||
|
return nil, fmt.Errorf("no user data in session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse UserContext
|
||||||
|
var user UserContext
|
||||||
|
if err := json.Unmarshal([]byte(userJSON.String), &user); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue // Try next token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authentication succeeded with this token
|
||||||
|
// Update last activity timestamp asynchronously
|
||||||
|
go a.updateSessionActivity(r.Context(), token, &userCtx)
|
||||||
|
|
||||||
|
return &userCtx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// All tokens failed
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("authentication failed for all provided tokens")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearCache removes a specific token from the cache or clears all cache if token is empty
|
// ClearCache removes a specific token from the cache or clears all cache if token is empty
|
||||||
|
|||||||
@@ -545,6 +545,96 @@ func TestDatabaseAuthenticator(t *testing.T) {
|
|||||||
t.Fatal("expected error when token is missing")
|
t.Fatal("expected error when token is missing")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("authenticate with multiple comma-separated tokens", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token invalid-token, Token valid-token-123")
|
||||||
|
|
||||||
|
// First token fails
|
||||||
|
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(false, "Invalid token", nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs("invalid-token", "authenticate").
|
||||||
|
WillReturnRows(rows1)
|
||||||
|
|
||||||
|
// Second token succeeds
|
||||||
|
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(true, nil, `{"user_id":3,"user_name":"multitoken","session_id":"valid-token-123"}`)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs("valid-token-123", "authenticate").
|
||||||
|
WillReturnRows(rows2)
|
||||||
|
|
||||||
|
userCtx, err := auth.Authenticate(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userCtx.UserID != 3 {
|
||||||
|
t.Errorf("expected UserID 3, got %d", userCtx.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("authenticate with duplicate tokens", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A, Token 968CA5AE-4F83-4D55-A3C6-51AE4410E03A")
|
||||||
|
|
||||||
|
// First token succeeds
|
||||||
|
rows := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(true, nil, `{"user_id":4,"user_name":"duplicateuser","session_id":"968CA5AE-4F83-4D55-A3C6-51AE4410E03A"}`)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs("968CA5AE-4F83-4D55-A3C6-51AE4410E03A", "authenticate").
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
userCtx, err := auth.Authenticate(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userCtx.UserID != 4 {
|
||||||
|
t.Errorf("expected UserID 4, got %d", userCtx.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("authenticate with all tokens failing", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token bad-token-1, Token bad-token-2")
|
||||||
|
|
||||||
|
// First token fails
|
||||||
|
rows1 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(false, "Invalid token", nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs("bad-token-1", "authenticate").
|
||||||
|
WillReturnRows(rows1)
|
||||||
|
|
||||||
|
// Second token also fails
|
||||||
|
rows2 := sqlmock.NewRows([]string{"p_success", "p_error", "p_user"}).
|
||||||
|
AddRow(false, "Invalid token", nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT p_success, p_error, p_user::text FROM resolvespec_session`).
|
||||||
|
WithArgs("bad-token-2", "authenticate").
|
||||||
|
WillReturnRows(rows2)
|
||||||
|
|
||||||
|
_, err := auth.Authenticate(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when all tokens fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("unfulfilled expectations: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test DatabaseAuthenticator RefreshToken
|
// Test DatabaseAuthenticator RefreshToken
|
||||||
|
|||||||
Reference in New Issue
Block a user