From ceaa251301e7e1ce8864cc0c04f65d1572c514db Mon Sep 17 00:00:00 2001 From: Hein Date: Mon, 10 Nov 2025 17:02:37 +0200 Subject: [PATCH] Updated logging, added getRowNumber and a few more --- SCHEMA_TABLE_HANDLING.md | 138 ---- cmd/testserver/main.go | 3 +- go.mod | 4 + go.sum | 9 + pkg/common/adapters/database/utils.go | 157 ----- pkg/common/types.go | 11 +- pkg/logger/logger.go | 33 + pkg/reflection/generic_model.go | 100 +++ pkg/reflection/model_utils.go | 162 +++++ .../model_utils_test.go} | 10 +- pkg/restheadspec/cursor.go | 3 +- pkg/restheadspec/handler.go | 288 +++++++- pkg/restheadspec/hooks.go | 7 + pkg/restheadspec/rownumber_test.go | 203 ++++++ pkg/security/CALLBACKS_GUIDE.md | 662 ++++++++++++++++++ pkg/security/QUICK_REFERENCE.md | 402 +++++++++++ pkg/security/callbacks_example.go | 418 +++++++++++ pkg/security/hooks.go | 244 +++++++ pkg/security/middleware.go | 54 ++ pkg/security/provider.go | 460 ++++++++++++ pkg/security/setup_example.go | 155 ++++ tests/test_helpers.go | 2 +- todo.md | 2 + 23 files changed, 3215 insertions(+), 312 deletions(-) delete mode 100644 SCHEMA_TABLE_HANDLING.md create mode 100644 pkg/reflection/generic_model.go create mode 100644 pkg/reflection/model_utils.go rename pkg/{common/adapters/database/utils_test.go => reflection/model_utils_test.go} (95%) create mode 100644 pkg/restheadspec/rownumber_test.go create mode 100644 pkg/security/CALLBACKS_GUIDE.md create mode 100644 pkg/security/QUICK_REFERENCE.md create mode 100644 pkg/security/callbacks_example.go create mode 100644 pkg/security/hooks.go create mode 100644 pkg/security/middleware.go create mode 100644 pkg/security/provider.go create mode 100644 pkg/security/setup_example.go diff --git a/SCHEMA_TABLE_HANDLING.md b/SCHEMA_TABLE_HANDLING.md deleted file mode 100644 index 999fa34..0000000 --- a/SCHEMA_TABLE_HANDLING.md +++ /dev/null @@ -1,138 +0,0 @@ -# Schema and Table Name Handling - -This document explains how the handlers properly separate and handle schema and table names. - -## Implementation - -Both `resolvespec` and `restheadspec` handlers now properly handle schema and table name separation through the following functions: - -- `parseTableName(fullTableName)` - Splits "schema.table" into separate components -- `getSchemaAndTable(defaultSchema, entity, model)` - Returns schema and table separately -- `getTableName(schema, entity, model)` - Returns the full "schema.table" format - -## Priority Order - -When determining the schema and table name, the following priority is used: - -1. **If `TableName()` contains a schema** (e.g., "myschema.mytable"), that schema takes precedence -2. **If model implements `SchemaProvider`**, use that schema -3. **Otherwise**, use the `defaultSchema` parameter from the URL/request - -## Scenarios - -### Scenario 1: Simple table name, default schema -```go -type User struct { - ID string - Name string -} - -func (User) TableName() string { - return "users" -} -``` -- Request URL: `/api/public/users` -- Result: `schema="public"`, `table="users"`, `fullName="public.users"` - -### Scenario 2: Table name includes schema -```go -type User struct { - ID string - Name string -} - -func (User) TableName() string { - return "auth.users" // Schema included! -} -``` -- Request URL: `/api/public/users` (public is ignored) -- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"` -- **Note**: The schema from `TableName()` takes precedence over the URL schema - -### Scenario 3: Using SchemaProvider -```go -type User struct { - ID string - Name string -} - -func (User) TableName() string { - return "users" -} - -func (User) SchemaName() string { - return "auth" -} -``` -- Request URL: `/api/public/users` (public is ignored) -- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"` - -### Scenario 4: Table name includes schema AND SchemaProvider -```go -type User struct { - ID string - Name string -} - -func (User) TableName() string { - return "core.users" // This wins! -} - -func (User) SchemaName() string { - return "auth" // This is ignored -} -``` -- Request URL: `/api/public/users` -- Result: `schema="core"`, `table="users"`, `fullName="core.users"` -- **Note**: Schema from `TableName()` takes highest precedence - -### Scenario 5: No providers at all -```go -type User struct { - ID string - Name string -} -// No TableName() or SchemaName() -``` -- Request URL: `/api/public/users` -- Result: `schema="public"`, `table="users"`, `fullName="public.users"` -- Uses URL schema and entity name - -## Key Features - -1. **Automatic detection**: The code automatically detects if `TableName()` includes a schema by checking for "." -2. **Backward compatible**: Existing code continues to work -3. **Flexible**: Supports multiple ways to specify schema and table -4. **Debug logging**: Logs when schema is detected in `TableName()` for debugging - -## Code Locations - -### Handlers -- `/pkg/resolvespec/handler.go:472-531` -- `/pkg/restheadspec/handler.go:534-593` - -### Database Adapters -- `/pkg/common/adapters/database/utils.go` - Shared `parseTableName()` function -- `/pkg/common/adapters/database/bun.go` - Bun adapter with separated schema/table -- `/pkg/common/adapters/database/gorm.go` - GORM adapter with separated schema/table - -## Adapter Implementation - -Both Bun and GORM adapters now properly separate schema and table name: - -```go -// BunSelectQuery/GormSelectQuery now have separated fields: -type BunSelectQuery struct { - query *bun.SelectQuery - schema string // Separated schema name - tableName string // Just the table name, without schema - tableAlias string -} -``` - -When `Model()` or `Table()` is called: -1. The full table name (which may include schema) is parsed -2. Schema and table name are stored separately -3. When building joins, the already-separated table name is used directly - -This ensures consistent handling of schema-qualified table names throughout the codebase. diff --git a/cmd/testserver/main.go b/cmd/testserver/main.go index 7aba9d9..2309fa9 100644 --- a/cmd/testserver/main.go +++ b/cmd/testserver/main.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "log" "net/http" "os" @@ -21,8 +20,8 @@ import ( func main() { // Initialize logger - fmt.Println("ResolveSpec test server starting") logger.Init(true) + logger.Info("ResolveSpec test server starting") // Initialize database db, err := initDB() diff --git a/go.mod b/go.mod index 0288004..fa9d621 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,10 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/uptrace/bunrouter v1.0.23 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect diff --git a/go.sum b/go.sum index 5c98243..a40ec50 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,15 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE= diff --git a/pkg/common/adapters/database/utils.go b/pkg/common/adapters/database/utils.go index 9031df5..4caec07 100644 --- a/pkg/common/adapters/database/utils.go +++ b/pkg/common/adapters/database/utils.go @@ -1,10 +1,7 @@ package database import ( - "reflect" "strings" - - "github.com/bitechdev/ResolveSpec/pkg/common" ) // parseTableName splits a table name that may contain schema into separate schema and table @@ -17,157 +14,3 @@ func parseTableName(fullTableName string) (schema, table string) { } return "", fullTableName } - -// GetPrimaryKeyName extracts the primary key column name from a model -// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method) -// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag -func GetPrimaryKeyName(model any) string { - // Check if model implements PrimaryKeyNameProvider - if provider, ok := model.(common.PrimaryKeyNameProvider); ok { - return provider.GetIDName() - } - - // Try Bun tag first - if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" { - return pkName - } - - // Fall back to GORM tag - return getPrimaryKeyFromReflection(model, "gorm") -} - -// GetModelColumns extracts all column names from a model using reflection -// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names -func GetModelColumns(model any) []string { - var columns []string - - modelType := reflect.TypeOf(model) - - // Unwrap pointers, slices, and arrays to get to the base struct type - for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { - modelType = modelType.Elem() - } - - // Validate that we have a struct type - if modelType == nil || modelType.Kind() != reflect.Struct { - return columns - } - - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - - // Get column name using the same logic as primary key extraction - columnName := getColumnNameFromField(field) - - if columnName != "" { - columns = append(columns, columnName) - } - } - - return columns -} - -// getColumnNameFromField extracts the column name from a struct field -// Priority: bun tag -> gorm tag -> json tag -> lowercase field name -func getColumnNameFromField(field reflect.StructField) string { - // Try bun tag first - bunTag := field.Tag.Get("bun") - if bunTag != "" && bunTag != "-" { - if colName := extractColumnFromBunTag(bunTag); colName != "" { - return colName - } - } - - // Try gorm tag - gormTag := field.Tag.Get("gorm") - if gormTag != "" && gormTag != "-" { - if colName := extractColumnFromGormTag(gormTag); colName != "" { - return colName - } - } - - // Fall back to json tag - jsonTag := field.Tag.Get("json") - if jsonTag != "" && jsonTag != "-" { - // Extract just the field name before any options - parts := strings.Split(jsonTag, ",") - if len(parts) > 0 && parts[0] != "" { - return parts[0] - } - } - - // Last resort: use field name in lowercase - return strings.ToLower(field.Name) -} - -// getPrimaryKeyFromReflection uses reflection to find the primary key field -func getPrimaryKeyFromReflection(model any, ormType string) string { - val := reflect.ValueOf(model) - if val.Kind() == reflect.Pointer { - val = val.Elem() - } - - if val.Kind() != reflect.Struct { - return "" - } - - typ := val.Type() - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - - switch ormType { - case "gorm": - // Check for gorm tag with primaryKey - gormTag := field.Tag.Get("gorm") - if strings.Contains(gormTag, "primaryKey") { - // Try to extract column name from gorm tag - if colName := extractColumnFromGormTag(gormTag); colName != "" { - return colName - } - // Fall back to json tag - if jsonTag := field.Tag.Get("json"); jsonTag != "" { - return strings.Split(jsonTag, ",")[0] - } - } - case "bun": - // Check for bun tag with pk flag - bunTag := field.Tag.Get("bun") - if strings.Contains(bunTag, "pk") { - // Extract column name from bun tag - if colName := extractColumnFromBunTag(bunTag); colName != "" { - return colName - } - // Fall back to json tag - if jsonTag := field.Tag.Get("json"); jsonTag != "" { - return strings.Split(jsonTag, ",")[0] - } - } - } - } - - return "" -} - -// extractColumnFromGormTag extracts the column name from a gorm tag -// Example: "column:id;primaryKey" -> "id" -func extractColumnFromGormTag(tag string) string { - parts := strings.Split(tag, ";") - for _, part := range parts { - part = strings.TrimSpace(part) - if colName, found := strings.CutPrefix(part, "column:"); found { - return colName - } - } - return "" -} - -// extractColumnFromBunTag extracts the column name from a bun tag -// Example: "id,pk" -> "id" -// Example: ",pk" -> "" (will fall back to json tag) -func extractColumnFromBunTag(tag string) string { - parts := strings.Split(tag, ",") - if len(parts) > 0 && parts[0] != "" { - return parts[0] - } - return "" -} diff --git a/pkg/common/types.go b/pkg/common/types.go index 2642291..3a9eaaa 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -72,11 +72,12 @@ type Response struct { } type Metadata struct { - Total int64 `json:"total"` - Count int64 `json:"count"` - Filtered int64 `json:"filtered"` - Limit int `json:"limit"` - Offset int `json:"offset"` + Total int64 `json:"total"` + Count int64 `json:"count"` + Filtered int64 `json:"filtered"` + Limit int `json:"limit"` + Offset int `json:"offset"` + RowNumber *int64 `json:"row_number,omitempty"` } type APIError struct { diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 1f78aa4..ba8e4e1 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "runtime/debug" "go.uber.org/zap" ) @@ -70,3 +71,35 @@ func Debug(template string, args ...interface{}) { } Logger.Debugw(fmt.Sprintf(template, args...), "process_id", os.Getpid()) } + +// CatchPanic - Handle panic +func CatchPanicCallback(location string, cb func(err any)) { + if err := recover(); err != nil { + //callstack := debug.Stack() + + if Logger != nil { + Error("Panic in %s : %v", location, err) + } else { + fmt.Printf("%s:PANIC->%+v", location, err) + debug.PrintStack() + } + + //push to sentry + // hub := sentry.CurrentHub() + // if hub != nil { + // evtID := hub.Recover(err) + // if evtID != nil { + // sentry.Flush(time.Second * 2) + // } + // } + + if cb != nil { + cb(err) + } + } +} + +// CatchPanic - Handle panic +func CatchPanic(location string) { + CatchPanicCallback(location, nil) +} diff --git a/pkg/reflection/generic_model.go b/pkg/reflection/generic_model.go new file mode 100644 index 0000000..1463958 --- /dev/null +++ b/pkg/reflection/generic_model.go @@ -0,0 +1,100 @@ +package reflection + +import ( + "reflect" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +type ModelFieldDetail struct { + Name string `json:"name"` + DataType string `json:"datatype"` + SQLName string `json:"sqlname"` + SQLDataType string `json:"sqldatatype"` + SQLKey string `json:"sqlkey"` + Nullable bool `json:"nullable"` + FieldValue reflect.Value `json:"-"` +} + +// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model +func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail { + defer func() { + if r := recover(); r != nil { + logger.Error("Panic in GetModelColumnDetail : %v", r) + } + }() + + var lst []ModelFieldDetail + lst = make([]ModelFieldDetail, 0) + + if !record.IsValid() { + return lst + } + if record.Kind() == reflect.Pointer || record.Kind() == reflect.Interface { + record = record.Elem() + } + if record.Kind() != reflect.Struct { + return lst + } + modeltype := record.Type() + + for i := 0; i < modeltype.NumField(); i++ { + fieldtype := modeltype.Field(i) + gormdetail := fieldtype.Tag.Get("gorm") + gormdetail = strings.Trim(gormdetail, " ") + fielddetail := ModelFieldDetail{} + fielddetail.FieldValue = record.Field(i) + fielddetail.Name = fieldtype.Name + fielddetail.DataType = fieldtype.Type.Name() + fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:") + fielddetail.SQLDataType = fnFindKeyVal(gormdetail, "type:") + if strings.Index(strings.ToLower(gormdetail), "identity") > 0 || + strings.Index(strings.ToLower(gormdetail), "primary_key") > 0 { + fielddetail.SQLKey = "primary_key" + } else if strings.Contains(strings.ToLower(gormdetail), "unique") { + fielddetail.SQLKey = "unique" + } else if strings.Contains(strings.ToLower(gormdetail), "uniqueindex") { + fielddetail.SQLKey = "uniqueindex" + } + + if strings.Contains(strings.ToLower(gormdetail), "nullable") { + fielddetail.Nullable = true + } else if strings.Contains(strings.ToLower(gormdetail), "null") { + fielddetail.Nullable = true + } + if strings.Contains(strings.ToLower(gormdetail), "not null") { + fielddetail.Nullable = false + } + + if strings.Contains(strings.ToLower(gormdetail), "foreignkey:") { + fielddetail.SQLKey = "foreign_key" + ik := strings.Index(strings.ToLower(gormdetail), "foreignkey:") + ie := strings.Index(gormdetail[ik:], ";") + if ie > ik && ik > 0 { + fielddetail.SQLName = strings.ToLower(gormdetail)[ik+11 : ik+ie] + //fmt.Printf("\r\nforeignkey: %v", fielddetail) + } + + } + //";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;" + + lst = append(lst, fielddetail) + + } + return lst +} + +func fnFindKeyVal(src, key string) string { + icolStart := strings.Index(strings.ToLower(src), strings.ToLower(key)) + val := "" + if icolStart >= 0 { + val = src[icolStart+len(key):] + icolend := strings.Index(val, ";") + if icolend > 0 { + val = val[:icolend] + } + return val + } + return "" +} diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go new file mode 100644 index 0000000..4a28e60 --- /dev/null +++ b/pkg/reflection/model_utils.go @@ -0,0 +1,162 @@ +package reflection + +import ( + "reflect" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +// GetPrimaryKeyName extracts the primary key column name from a model +// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method) +// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag +func GetPrimaryKeyName(model any) string { + // Check if model implements PrimaryKeyNameProvider + if provider, ok := model.(common.PrimaryKeyNameProvider); ok { + return provider.GetIDName() + } + + // Try Bun tag first + if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" { + return pkName + } + + // Fall back to GORM tag + return getPrimaryKeyFromReflection(model, "gorm") +} + +// GetModelColumns extracts all column names from a model using reflection +// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names +func GetModelColumns(model any) []string { + var columns []string + + modelType := reflect.TypeOf(model) + + // Unwrap pointers, slices, and arrays to get to the base struct type + for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + // Validate that we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + return columns + } + + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + // Get column name using the same logic as primary key extraction + columnName := getColumnNameFromField(field) + + if columnName != "" { + columns = append(columns, columnName) + } + } + + return columns +} + +// getColumnNameFromField extracts the column name from a struct field +// Priority: bun tag -> gorm tag -> json tag -> lowercase field name +func getColumnNameFromField(field reflect.StructField) string { + // Try bun tag first + bunTag := field.Tag.Get("bun") + if bunTag != "" && bunTag != "-" { + if colName := ExtractColumnFromBunTag(bunTag); colName != "" { + return colName + } + } + + // Try gorm tag + gormTag := field.Tag.Get("gorm") + if gormTag != "" && gormTag != "-" { + if colName := ExtractColumnFromGormTag(gormTag); colName != "" { + return colName + } + } + + // Fall back to json tag + jsonTag := field.Tag.Get("json") + if jsonTag != "" && jsonTag != "-" { + // Extract just the field name before any options + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + } + + // Last resort: use field name in lowercase + return strings.ToLower(field.Name) +} + +// getPrimaryKeyFromReflection uses reflection to find the primary key field +func getPrimaryKeyFromReflection(model any, ormType string) string { + val := reflect.ValueOf(model) + if val.Kind() == reflect.Pointer { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return "" + } + + typ := val.Type() + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + switch ormType { + case "gorm": + // Check for gorm tag with primaryKey + gormTag := field.Tag.Get("gorm") + if strings.Contains(gormTag, "primaryKey") { + // Try to extract column name from gorm tag + if colName := ExtractColumnFromGormTag(gormTag); colName != "" { + return colName + } + // Fall back to json tag + if jsonTag := field.Tag.Get("json"); jsonTag != "" { + return strings.Split(jsonTag, ",")[0] + } + } + case "bun": + // Check for bun tag with pk flag + bunTag := field.Tag.Get("bun") + if strings.Contains(bunTag, "pk") { + // Extract column name from bun tag + if colName := ExtractColumnFromBunTag(bunTag); colName != "" { + return colName + } + // Fall back to json tag + if jsonTag := field.Tag.Get("json"); jsonTag != "" { + return strings.Split(jsonTag, ",")[0] + } + } + } + } + + return "" +} + +// ExtractColumnFromGormTag extracts the column name from a gorm tag +// Example: "column:id;primaryKey" -> "id" +func ExtractColumnFromGormTag(tag string) string { + parts := strings.Split(tag, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if colName, found := strings.CutPrefix(part, "column:"); found { + return colName + } + } + return "" +} + +// ExtractColumnFromBunTag extracts the column name from a bun tag +// Example: "id,pk" -> "id" +// Example: ",pk" -> "" (will fall back to json tag) +func ExtractColumnFromBunTag(tag string) string { + parts := strings.Split(tag, ",") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + return "" +} diff --git a/pkg/common/adapters/database/utils_test.go b/pkg/reflection/model_utils_test.go similarity index 95% rename from pkg/common/adapters/database/utils_test.go rename to pkg/reflection/model_utils_test.go index 0be46bf..dd2f020 100644 --- a/pkg/common/adapters/database/utils_test.go +++ b/pkg/reflection/model_utils_test.go @@ -1,4 +1,4 @@ -package database +package reflection import ( "testing" @@ -137,9 +137,9 @@ func TestExtractColumnFromGormTag(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := extractColumnFromGormTag(tt.tag) + result := ExtractColumnFromGormTag(tt.tag) if result != tt.expected { - t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected) + t.Errorf("ExtractColumnFromGormTag() = %v, want %v", result, tt.expected) } }) } @@ -170,9 +170,9 @@ func TestExtractColumnFromBunTag(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := extractColumnFromBunTag(tt.tag) + result := ExtractColumnFromBunTag(tt.tag) if result != tt.expected { - t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected) + t.Errorf("ExtractColumnFromBunTag() = %v, want %v", result, tt.expected) } }) } diff --git a/pkg/restheadspec/cursor.go b/pkg/restheadspec/cursor.go index 7e4b5c0..ba78436 100644 --- a/pkg/restheadspec/cursor.go +++ b/pkg/restheadspec/cursor.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" ) // CursorDirection defines pagination direction @@ -85,7 +86,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter( field, prefix, tableName, modelColumns, ) if err != nil { - fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err) + logger.Warn("Skipping invalid sort column %q: %v", col, err) continue } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index d45f3de..4a73ec1 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -10,8 +10,8 @@ import ( "strings" "github.com/bitechdev/ResolveSpec/pkg/common" - "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database" "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // Handler handles API requests using database and model abstractions @@ -343,10 +343,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st logger.Debug("Applying cursor pagination") // Get primary key name - pkName := database.GetPrimaryKeyName(model) + pkName := reflection.GetPrimaryKeyName(model) // Extract model columns for validation using the generic database function - modelColumns := database.GetModelColumns(model) + modelColumns := reflection.GetModelColumns(model) // Build expand joins map (if needed in future) var expandJoins map[string]string @@ -371,6 +371,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st } } + // Execute BeforeScan hooks - pass query chain so hooks can modify it + hookCtx.Query = query + if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { + logger.Error("BeforeScan hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Use potentially modified query from hook context + if modifiedQuery, ok := hookCtx.Query.(common.SelectQuery); ok { + query = modifiedQuery + } + // Execute query - modelPtr was already created earlier if err := query.Scan(ctx, modelPtr); err != nil { logger.Error("Error executing query: %v", err) @@ -387,6 +400,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st offset = *options.Offset } + // Set row numbers on each record if the model has a RowNumber field + h.setRowNumbersOnRecords(modelPtr, offset) + metadata := &common.Metadata{ Total: int64(total), Count: int64(common.Len(modelPtr)), @@ -395,6 +411,23 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st Offset: offset, } + // Fetch row number for a specific record if requested + if options.RequestOptions.FetchRowNumber != nil && *options.RequestOptions.FetchRowNumber != "" { + pkName := reflection.GetPrimaryKeyName(model) + pkValue := *options.RequestOptions.FetchRowNumber + + logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue) + + rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model) + if err != nil { + logger.Warn("Failed to fetch row number: %v", err) + // Don't fail the entire request, just log the warning + } else { + metadata.RowNumber = &rowNum + logger.Debug("Row number for PK %s: %d", pkValue, rowNum) + } + } + // Execute AfterRead hooks hookCtx.Result = modelPtr hookCtx.Error = nil @@ -466,6 +499,29 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat } query := tx.NewInsert().Model(modelValue).Table(tableName) + + // Execute BeforeScan hooks - pass query chain so hooks can modify it + batchHookCtx := &HookContext{ + Context: ctx, + Handler: h, + Schema: schema, + Entity: entity, + TableName: tableName, + Model: model, + Options: options, + Data: modelValue, + Writer: w, + Query: query, + } + if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil { + return fmt.Errorf("BeforeScan hook failed: %w", err) + } + + // Use potentially modified query from hook context + if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok { + query = modifiedQuery + } + if _, err := query.Exec(ctx); err != nil { return fmt.Errorf("failed to insert record: %w", err) } @@ -508,6 +564,21 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat } query := h.db.NewInsert().Model(modelValue).Table(tableName) + + // Execute BeforeScan hooks - pass query chain so hooks can modify it + hookCtx.Data = modelValue + hookCtx.Query = query + if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { + logger.Error("BeforeScan hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Use potentially modified query from hook context + if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok { + query = modifiedQuery + } + if _, err := query.Exec(ctx); err != nil { logger.Error("Error creating record: %v", err) h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err) @@ -593,6 +664,19 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id return } + // Execute BeforeScan hooks - pass query chain so hooks can modify it + hookCtx.Query = query + if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { + logger.Error("BeforeScan hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Use potentially modified query from hook context + if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok { + query = modifiedQuery + } + result, err := query.Exec(ctx) if err != nil { logger.Error("Error updating record: %v", err) @@ -658,6 +742,19 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id query = query.Where("id = ?", id) + // Execute BeforeScan hooks - pass query chain so hooks can modify it + hookCtx.Query = query + if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil { + logger.Error("BeforeScan hook failed: %v", err) + h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err) + return + } + + // Use potentially modified query from hook context + if modifiedQuery, ok := hookCtx.Query.(common.DeleteQuery); ok { + query = modifiedQuery + } + result, err := query.Exec(ctx) if err != nil { logger.Error("Error deleting record: %v", err) @@ -999,6 +1096,191 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa w.WriteJSON(response) } +// FetchRowNumber calculates the row number of a specific record based on sorting and filtering +// Returns the 1-based row number of the record with the given primary key value +func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName string, pkValue string, options ExtendedRequestOptions, model any) (int64, error) { + defer func() { + if r := recover(); r != nil { + logger.Error("Panic during FetchRowNumber: %v", r) + } + }() + + // Build the sort order SQL + sortSQL := "" + if len(options.Sort) > 0 { + sortParts := make([]string, 0, len(options.Sort)) + for _, sort := range options.Sort { + direction := "ASC" + if strings.ToLower(sort.Direction) == "desc" { + direction = "DESC" + } + sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction)) + } + sortSQL = strings.Join(sortParts, ", ") + } else { + // Default sort by primary key + sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName) + } + + // Build WHERE clauses from filters + whereClauses := make([]string, 0) + for i := range options.Filters { + filter := &options.Filters[i] + whereClause := h.buildFilterSQL(filter, tableName) + if whereClause != "" { + whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause)) + } + } + + // Combine WHERE clauses + whereSQL := "" + if len(whereClauses) > 0 { + whereSQL = "WHERE " + strings.Join(whereClauses, " AND ") + } + + // Add custom SQL WHERE if provided + if options.CustomSQLWhere != "" { + if whereSQL == "" { + whereSQL = "WHERE " + options.CustomSQLWhere + } else { + whereSQL += " AND (" + options.CustomSQLWhere + ")" + } + } + + // Build JOIN clauses from Expand options + joinSQL := "" + if len(options.Expand) > 0 { + joinParts := make([]string, 0, len(options.Expand)) + for _, expand := range options.Expand { + // Note: This is a simplified join - in production you'd need proper FK mapping + joinParts = append(joinParts, fmt.Sprintf("LEFT JOIN %s ON %s.%s_id = %s.id", + expand.Relation, tableName, expand.Relation, expand.Relation)) + } + joinSQL = strings.Join(joinParts, "\n") + } + + // Build the final query with parameterized PK value + queryStr := fmt.Sprintf(` + SELECT search.rn + FROM ( + SELECT %[1]s.%[2]s, + ROW_NUMBER() OVER(ORDER BY %[3]s) AS rn + FROM %[1]s + %[5]s + %[4]s + ) search + WHERE search.%[2]s = ? + `, + tableName, // [1] - table name + pkName, // [2] - primary key column name + sortSQL, // [3] - sort order SQL + whereSQL, // [4] - WHERE clause + joinSQL, // [5] - JOIN clauses + ) + + logger.Debug("FetchRowNumber query: %s, pkValue: %s", queryStr, pkValue) + + // Execute the raw query with parameterized PK value + var result []struct { + RN int64 `bun:"rn"` + } + err := h.db.Query(ctx, &result, queryStr, pkValue) + if err != nil { + return 0, fmt.Errorf("failed to fetch row number: %w", err) + } + + if len(result) == 0 { + return 0, fmt.Errorf("no row found for primary key %s", pkValue) + } + + return result[0].RN, nil +} + +// buildFilterSQL converts a filter to SQL WHERE clause string +func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string { + qualifiedColumn := h.qualifyColumnName(filter.Column, tableName) + + switch strings.ToLower(filter.Operator) { + case "eq", "equals": + return fmt.Sprintf("%s = '%v'", qualifiedColumn, filter.Value) + case "neq", "not_equals", "ne": + return fmt.Sprintf("%s != '%v'", qualifiedColumn, filter.Value) + case "gt", "greater_than": + return fmt.Sprintf("%s > '%v'", qualifiedColumn, filter.Value) + case "gte", "greater_than_equals", "ge": + return fmt.Sprintf("%s >= '%v'", qualifiedColumn, filter.Value) + case "lt", "less_than": + return fmt.Sprintf("%s < '%v'", qualifiedColumn, filter.Value) + case "lte", "less_than_equals", "le": + return fmt.Sprintf("%s <= '%v'", qualifiedColumn, filter.Value) + case "like": + return fmt.Sprintf("%s LIKE '%v'", qualifiedColumn, filter.Value) + case "ilike": + return fmt.Sprintf("%s ILIKE '%v'", qualifiedColumn, filter.Value) + case "in": + if values, ok := filter.Value.([]any); ok { + valueStrs := make([]string, len(values)) + for i, v := range values { + valueStrs[i] = fmt.Sprintf("'%v'", v) + } + return fmt.Sprintf("%s IN (%s)", qualifiedColumn, strings.Join(valueStrs, ", ")) + } + return "" + case "is_null", "isnull": + return fmt.Sprintf("(%s IS NULL OR %s = '')", qualifiedColumn, qualifiedColumn) + case "is_not_null", "isnotnull": + return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", qualifiedColumn, qualifiedColumn) + default: + logger.Warn("Unknown filter operator in buildFilterSQL: %s", filter.Operator) + return "" + } +} + +// setRowNumbersOnRecords sets the RowNumber field on each record if it exists +// The row number is calculated as offset + index + 1 (1-based) +func (h *Handler) setRowNumbersOnRecords(records any, offset int) { + // Get the reflect value of the records + recordsValue := reflect.ValueOf(records) + if recordsValue.Kind() == reflect.Ptr { + recordsValue = recordsValue.Elem() + } + + // Ensure it's a slice + if recordsValue.Kind() != reflect.Slice { + logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping") + return + } + + // Iterate through each record + for i := 0; i < recordsValue.Len(); i++ { + record := recordsValue.Index(i) + + // Dereference if it's a pointer + if record.Kind() == reflect.Ptr { + if record.IsNil() { + continue + } + record = record.Elem() + } + + // Ensure it's a struct + if record.Kind() != reflect.Struct { + continue + } + + // Try to find and set the RowNumber field + rowNumberField := record.FieldByName("RowNumber") + if rowNumberField.IsValid() && rowNumberField.CanSet() { + // Check if the field is of type int64 + if rowNumberField.Kind() == reflect.Int64 { + rowNum := int64(offset + i + 1) + rowNumberField.SetInt(rowNum) + logger.Debug("Set RowNumber=%d on record %d", rowNum, i) + } + } + } +} + // filterExtendedOptions filters all column references, removing invalid ones and logging warnings func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions { filtered := options diff --git a/pkg/restheadspec/hooks.go b/pkg/restheadspec/hooks.go index 99033b4..a3f2b56 100644 --- a/pkg/restheadspec/hooks.go +++ b/pkg/restheadspec/hooks.go @@ -27,6 +27,9 @@ const ( // Delete operation hooks BeforeDelete HookType = "before_delete" AfterDelete HookType = "after_delete" + + // Scan/Execute operation hooks + BeforeScan HookType = "before_scan" ) // HookContext contains all the data available to a hook @@ -46,6 +49,10 @@ type HookContext struct { Error error // For after hooks QueryFilter string // For read operations + // Query chain - allows hooks to modify the query before execution + // Can be SelectQuery, InsertQuery, UpdateQuery, or DeleteQuery + Query interface{} + // Response writer - allows hooks to modify response Writer common.ResponseWriter } diff --git a/pkg/restheadspec/rownumber_test.go b/pkg/restheadspec/rownumber_test.go new file mode 100644 index 0000000..5424eec --- /dev/null +++ b/pkg/restheadspec/rownumber_test.go @@ -0,0 +1,203 @@ +package restheadspec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestModel represents a typical model with RowNumber field (like DBAdhocBuffer) +type TestModel struct { + ID int64 `json:"id" bun:"id,pk"` + Name string `json:"name" bun:"name"` + RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"` +} + +func TestSetRowNumbersOnRecords(t *testing.T) { + handler := &Handler{} + + tests := []struct { + name string + records any + offset int + expected []int64 + }{ + { + name: "Set row numbers on slice of pointers", + records: []*TestModel{ + {ID: 1, Name: "First"}, + {ID: 2, Name: "Second"}, + {ID: 3, Name: "Third"}, + }, + offset: 0, + expected: []int64{1, 2, 3}, + }, + { + name: "Set row numbers with offset", + records: []*TestModel{ + {ID: 11, Name: "Eleventh"}, + {ID: 12, Name: "Twelfth"}, + }, + offset: 10, + expected: []int64{11, 12}, + }, + { + name: "Set row numbers on slice of values", + records: []TestModel{ + {ID: 1, Name: "First"}, + {ID: 2, Name: "Second"}, + }, + offset: 5, + expected: []int64{6, 7}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler.setRowNumbersOnRecords(tt.records, tt.offset) + + // Verify row numbers were set correctly + switch records := tt.records.(type) { + case []*TestModel: + assert.Equal(t, len(tt.expected), len(records)) + for i, record := range records { + assert.Equal(t, tt.expected[i], record.RowNumber, + "Record %d should have RowNumber=%d", i, tt.expected[i]) + } + case []TestModel: + assert.Equal(t, len(tt.expected), len(records)) + for i, record := range records { + assert.Equal(t, tt.expected[i], record.RowNumber, + "Record %d should have RowNumber=%d", i, tt.expected[i]) + } + } + }) + } +} + +func TestSetRowNumbersOnRecords_NoRowNumberField(t *testing.T) { + handler := &Handler{} + + // Model without RowNumber field + type SimpleModel struct { + ID int64 `json:"id"` + Name string `json:"name"` + } + + records := []*SimpleModel{ + {ID: 1, Name: "First"}, + {ID: 2, Name: "Second"}, + } + + // Should not panic when model doesn't have RowNumber field + assert.NotPanics(t, func() { + handler.setRowNumbersOnRecords(records, 0) + }) +} + +func TestSetRowNumbersOnRecords_NilRecords(t *testing.T) { + handler := &Handler{} + + records := []*TestModel{ + {ID: 1, Name: "First"}, + nil, // Nil record + {ID: 3, Name: "Third"}, + } + + // Should not panic with nil records + assert.NotPanics(t, func() { + handler.setRowNumbersOnRecords(records, 0) + }) + + // Verify non-nil records were set + assert.Equal(t, int64(1), records[0].RowNumber) + assert.Equal(t, int64(3), records[2].RowNumber) +} + +// DBAdhocBuffer simulates the actual DBAdhocBuffer from db package +type DBAdhocBuffer struct { + CQL1 string `json:"cql1,omitempty" gorm:"->" bun:"-"` + RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"` +} + +// ModelWithEmbeddedBuffer simulates a real model like ModelPublicConsultant +type ModelWithEmbeddedBuffer struct { + ID int64 `json:"id" bun:"id,pk"` + Name string `json:"name" bun:"name"` + + DBAdhocBuffer `json:",omitempty"` // Embedded struct containing RowNumber +} + +func TestSetRowNumbersOnRecords_EmbeddedBuffer(t *testing.T) { + handler := &Handler{} + + // Test with embedded DBAdhocBuffer (like real models) + records := []*ModelWithEmbeddedBuffer{ + {ID: 1, Name: "First"}, + {ID: 2, Name: "Second"}, + {ID: 3, Name: "Third"}, + } + + handler.setRowNumbersOnRecords(records, 10) + + // Verify row numbers were set on embedded field + assert.Equal(t, int64(11), records[0].RowNumber, "First record should have RowNumber=11") + assert.Equal(t, int64(12), records[1].RowNumber, "Second record should have RowNumber=12") + assert.Equal(t, int64(13), records[2].RowNumber, "Third record should have RowNumber=13") +} + +func TestSetRowNumbersOnRecords_EmbeddedBuffer_SliceOfValues(t *testing.T) { + handler := &Handler{} + + // Test with slice of values (not pointers) + records := []ModelWithEmbeddedBuffer{ + {ID: 1, Name: "First"}, + {ID: 2, Name: "Second"}, + } + + handler.setRowNumbersOnRecords(records, 0) + + // Verify row numbers were set on embedded field + assert.Equal(t, int64(1), records[0].RowNumber, "First record should have RowNumber=1") + assert.Equal(t, int64(2), records[1].RowNumber, "Second record should have RowNumber=2") +} + +// Simulate the exact structure from user's code +type MockDBAdhocBuffer struct { + CQL1 string `json:"cql1,omitempty" gorm:"->" bun:"-"` + CQL2 string `json:"cql2,omitempty" gorm:"->" bun:"-"` + RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"` + Request string `json:"_request,omitempty" gorm:"-" bun:"-"` +} + +// Exact structure like ModelPublicConsultant +type ModelPublicConsultant struct { + Consultant string `json:"consultant" bun:"consultant,type:citext,pk"` + Ridconsultant int32 `json:"rid_consultant" bun:"rid_consultant,type:integer,pk"` + Updatecnt int64 `json:"updatecnt" bun:"updatecnt,type:integer,default:0"` + + MockDBAdhocBuffer `json:",omitempty"` // Embedded - RowNumber is here! +} + +func TestSetRowNumbersOnRecords_RealModelStructure(t *testing.T) { + handler := &Handler{} + + // Test with exact structure from user's ModelPublicConsultant + records := []*ModelPublicConsultant{ + {Consultant: "John Doe", Ridconsultant: 1, Updatecnt: 0}, + {Consultant: "Jane Smith", Ridconsultant: 2, Updatecnt: 0}, + {Consultant: "Bob Johnson", Ridconsultant: 3, Updatecnt: 0}, + } + + handler.setRowNumbersOnRecords(records, 100) + + // Verify row numbers were set correctly in the embedded DBAdhocBuffer + assert.Equal(t, int64(101), records[0].RowNumber, "First consultant should have RowNumber=101") + assert.Equal(t, int64(102), records[1].RowNumber, "Second consultant should have RowNumber=102") + assert.Equal(t, int64(103), records[2].RowNumber, "Third consultant should have RowNumber=103") + + t.Logf("✓ RowNumber correctly set in embedded MockDBAdhocBuffer") + t.Logf(" Record 0: Consultant=%s, RowNumber=%d", records[0].Consultant, records[0].RowNumber) + t.Logf(" Record 1: Consultant=%s, RowNumber=%d", records[1].Consultant, records[1].RowNumber) + t.Logf(" Record 2: Consultant=%s, RowNumber=%d", records[2].Consultant, records[2].RowNumber) +} diff --git a/pkg/security/CALLBACKS_GUIDE.md b/pkg/security/CALLBACKS_GUIDE.md new file mode 100644 index 0000000..2170277 --- /dev/null +++ b/pkg/security/CALLBACKS_GUIDE.md @@ -0,0 +1,662 @@ +# Security Provider Callbacks Guide + +## Overview + +The ResolveSpec security provider uses a **callback-based architecture** that requires you to implement three functions: + +1. **AuthenticateCallback** - Extract user credentials from HTTP requests +2. **LoadColumnSecurityCallback** - Load column security rules for masking/hiding +3. **LoadRowSecurityCallback** - Load row security filters (WHERE clauses) + +This design allows you to integrate the security provider with **any** authentication system and database schema. + +--- + +## Why Callbacks? + +The callback-based design provides: + +✅ **Flexibility** - Works with any auth system (JWT, session, OAuth, custom) +✅ **Database Agnostic** - No assumptions about your security table schema +✅ **Testability** - Easy to mock for unit tests +✅ **Extensibility** - Add custom logic without modifying core code + +--- + +## Quick Start + +### Step 1: Implement the Three Callbacks + +```go +package main + +import ( + "fmt" + "net/http" + "github.com/bitechdev/ResolveSpec/pkg/security" +) + +// 1. Authentication: Extract user from request +func myAuthFunction(r *http.Request) (userID int, roles string, err error) { + // Your auth logic here (JWT, session, header, etc.) + token := r.Header.Get("Authorization") + userID, roles, err = validateToken(token) + return userID, roles, err +} + +// 2. Column Security: Load column masking rules +func myLoadColumnSecurity(userID int, schema, tablename string) ([]security.ColumnSecurity, error) { + // Your database query or config lookup here + return loadColumnRulesFromDatabase(userID, schema, tablename) +} + +// 3. Row Security: Load row filtering rules +func myLoadRowSecurity(userID int, schema, tablename string) (security.RowSecurity, error) { + // Your database query or config lookup here + return loadRowRulesFromDatabase(userID, schema, tablename) +} +``` + +### Step 2: Configure the Callbacks + +```go +func main() { + db := setupDatabase() + handler := restheadspec.NewHandlerWithGORM(db) + + // Configure callbacks BEFORE SetupSecurityProvider + security.GlobalSecurity.AuthenticateCallback = myAuthFunction + security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurity + security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurity + + // Setup security provider (validates callbacks are set) + if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil { + log.Fatal(err) // Fails if callbacks not configured + } + + // Apply middleware + router := mux.NewRouter() + restheadspec.SetupMuxRoutes(router, handler) + router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) + router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) + + http.ListenAndServe(":8080", router) +} +``` + +--- + +## Callback 1: AuthenticateCallback + +### Function Signature + +```go +func(r *http.Request) (userID int, roles string, err error) +``` + +### Parameters +- `r *http.Request` - The incoming HTTP request + +### Returns +- `userID int` - The authenticated user's ID +- `roles string` - User's roles (comma-separated, e.g., "admin,manager") +- `err error` - Return error to reject the request (HTTP 401) + +### Example Implementations + +#### Simple Header-Based Auth +```go +func authenticateFromHeader(r *http.Request) (int, string, error) { + userIDStr := r.Header.Get("X-User-ID") + if userIDStr == "" { + return 0, "", fmt.Errorf("X-User-ID header required") + } + + userID, err := strconv.Atoi(userIDStr) + if err != nil { + return 0, "", fmt.Errorf("invalid user ID") + } + + roles := r.Header.Get("X-User-Roles") // Optional + return userID, roles, nil +} +``` + +#### JWT Token Auth +```go +import "github.com/golang-jwt/jwt/v5" + +func authenticateFromJWT(r *http.Request) (int, string, error) { + authHeader := r.Header.Get("Authorization") + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte(os.Getenv("JWT_SECRET")), nil + }) + + if err != nil || !token.Valid { + return 0, "", fmt.Errorf("invalid token") + } + + claims := token.Claims.(jwt.MapClaims) + userID := int(claims["user_id"].(float64)) + roles := claims["roles"].(string) + + return userID, roles, nil +} +``` + +#### Session Cookie Auth +```go +func authenticateFromSession(r *http.Request) (int, string, error) { + cookie, err := r.Cookie("session_id") + if err != nil { + return 0, "", fmt.Errorf("no session cookie") + } + + session, err := sessionStore.Get(cookie.Value) + if err != nil { + return 0, "", fmt.Errorf("invalid session") + } + + return session.UserID, session.Roles, nil +} +``` + +--- + +## Callback 2: LoadColumnSecurityCallback + +### Function Signature + +```go +func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) +``` + +### Parameters +- `pUserID int` - The authenticated user's ID +- `pSchema string` - Database schema (e.g., "public") +- `pTablename string` - Table name (e.g., "employees") + +### Returns +- `[]ColumnSecurity` - List of column security rules +- `error` - Return error if loading fails + +### ColumnSecurity Structure + +```go +type ColumnSecurity struct { + Schema string // "public" + Tablename string // "employees" + Path []string // ["ssn"] or ["address", "street"] + Accesstype string // "mask" or "hide" + + // Masking configuration (for Accesstype = "mask") + MaskStart int // Mask first N characters + MaskEnd int // Mask last N characters + MaskInvert bool // true = mask middle, false = mask edges + MaskChar string // Character to use for masking (default "*") + + // Optional fields + ExtraFilters map[string]string + Control string + ID int + UserID int +} +``` + +### Example Implementations + +#### Load from Database +```go +func loadColumnSecurityFromDB(userID int, schema, tablename string) ([]security.ColumnSecurity, error) { + var rules []security.ColumnSecurity + + query := ` + SELECT control, accesstype, jsonvalue + FROM core.secacces + WHERE rid_hub IN ( + SELECT rid_hub_parent FROM core.hub_link + WHERE rid_hub_child = ? AND parent_hubtype = 'secgroup' + ) + AND control ILIKE ? + ` + + rows, err := db.Query(query, userID, fmt.Sprintf("%s.%s%%", schema, tablename)) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var control, accesstype, jsonValue string + rows.Scan(&control, &accesstype, &jsonValue) + + // Parse control: "schema.table.column" + parts := strings.Split(control, ".") + if len(parts) < 3 { + continue + } + + rule := security.ColumnSecurity{ + Schema: schema, + Tablename: tablename, + Path: parts[2:], + Accesstype: accesstype, + } + + // Parse JSON configuration + var config map[string]interface{} + json.Unmarshal([]byte(jsonValue), &config) + if start, ok := config["start"].(float64); ok { + rule.MaskStart = int(start) + } + if end, ok := config["end"].(float64); ok { + rule.MaskEnd = int(end) + } + if char, ok := config["char"].(string); ok { + rule.MaskChar = char + } + + rules = append(rules, rule) + } + + return rules, nil +} +``` + +#### Load from Static Config +```go +func loadColumnSecurityFromConfig(userID int, schema, tablename string) ([]security.ColumnSecurity, error) { + // Define security rules in code + allRules := map[string][]security.ColumnSecurity{ + "public.employees": { + { + Schema: "public", + Tablename: "employees", + Path: []string{"ssn"}, + Accesstype: "mask", + MaskStart: 5, + MaskChar: "*", + }, + { + Schema: "public", + Tablename: "employees", + Path: []string{"salary"}, + Accesstype: "hide", + }, + }, + } + + key := fmt.Sprintf("%s.%s", schema, tablename) + rules, ok := allRules[key] + if !ok { + return []security.ColumnSecurity{}, nil // No rules + } + + return rules, nil +} +``` + +### Column Security Examples + +**Mask SSN (show last 4 digits):** +```go +ColumnSecurity{ + Path: []string{"ssn"}, + Accesstype: "mask", + MaskStart: 5, // Mask first 5 characters + MaskEnd: 0, // Keep last 4 visible + MaskChar: "*", +} +// Result: "123-45-6789" → "*****6789" +``` + +**Hide entire field:** +```go +ColumnSecurity{ + Path: []string{"salary"}, + Accesstype: "hide", +} +// Result: salary field returns 0 or empty +``` + +**Mask credit card (show last 4 digits):** +```go +ColumnSecurity{ + Path: []string{"credit_card"}, + Accesstype: "mask", + MaskStart: 12, + MaskChar: "*", +} +// Result: "1234-5678-9012-3456" → "************3456" +``` + +--- + +## Callback 3: LoadRowSecurityCallback + +### Function Signature + +```go +func(pUserID int, pSchema, pTablename string) (RowSecurity, error) +``` + +### Parameters +- `pUserID int` - The authenticated user's ID +- `pSchema string` - Database schema +- `pTablename string` - Table name + +### Returns +- `RowSecurity` - Row security configuration +- `error` - Return error if loading fails + +### RowSecurity Structure + +```go +type RowSecurity struct { + Schema string // "public" + Tablename string // "orders" + UserID int // Current user ID + Template string // WHERE clause template (e.g., "user_id = {UserID}") + HasBlock bool // If true, block ALL access to this table +} +``` + +### Template Variables + +You can use these placeholders in the `Template` string: +- `{UserID}` - Current user's ID +- `{PrimaryKeyName}` - Primary key column name +- `{TableName}` - Table name +- `{SchemaName}` - Schema name + +### Example Implementations + +#### Load from Database Function +```go +func loadRowSecurityFromDB(userID int, schema, tablename string) (security.RowSecurity, error) { + var record security.RowSecurity + + query := ` + SELECT p_template, p_block + FROM core.api_sec_rowtemplate(?, ?, ?) + ` + + row := db.QueryRow(query, schema, tablename, userID) + err := row.Scan(&record.Template, &record.HasBlock) + if err != nil { + return security.RowSecurity{}, err + } + + record.Schema = schema + record.Tablename = tablename + record.UserID = userID + + return record, nil +} +``` + +#### Load from Static Config +```go +func loadRowSecurityFromConfig(userID int, schema, tablename string) (security.RowSecurity, error) { + key := fmt.Sprintf("%s.%s", schema, tablename) + + // Define templates for each table + templates := map[string]string{ + "public.orders": "user_id = {UserID}", + "public.documents": "user_id = {UserID} OR is_public = true", + } + + // Define blocked tables + blocked := map[string]bool{ + "public.admin_logs": true, + } + + if blocked[key] { + return security.RowSecurity{ + Schema: schema, + Tablename: tablename, + UserID: userID, + HasBlock: true, + }, nil + } + + template, ok := templates[key] + if !ok { + // No row security - allow all rows + return security.RowSecurity{ + Schema: schema, + Tablename: tablename, + UserID: userID, + Template: "", + HasBlock: false, + }, nil + } + + return security.RowSecurity{ + Schema: schema, + Tablename: tablename, + UserID: userID, + Template: template, + HasBlock: false, + }, nil +} +``` + +### Row Security Examples + +**Users see only their own records:** +```go +RowSecurity{ + Template: "user_id = {UserID}", +} +// Query: SELECT * FROM orders WHERE user_id = 123 +``` + +**Users see their records OR public records:** +```go +RowSecurity{ + Template: "user_id = {UserID} OR is_public = true", +} +``` + +**Complex filter with subquery:** +```go +RowSecurity{ + Template: "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})", +} +``` + +**Block all access:** +```go +RowSecurity{ + HasBlock: true, +} +// All queries to this table will be rejected +``` + +--- + +## Complete Integration Example + +```go +package main + +import ( + "fmt" + "log" + "net/http" + "strconv" + + "github.com/bitechdev/ResolveSpec/pkg/restheadspec" + "github.com/bitechdev/ResolveSpec/pkg/security" + "github.com/gorilla/mux" + "gorm.io/gorm" +) + +func main() { + db := setupDatabase() + handler := restheadspec.NewHandlerWithGORM(db) + handler.RegisterModel("public", "orders", Order{}) + + // ===== CONFIGURE CALLBACKS ===== + security.GlobalSecurity.AuthenticateCallback = authenticateUser + security.GlobalSecurity.LoadColumnSecurityCallback = loadColumnSec + security.GlobalSecurity.LoadRowSecurityCallback = loadRowSec + + // ===== SETUP SECURITY ===== + if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil { + log.Fatal("Security setup failed:", err) + } + + // ===== SETUP ROUTES ===== + router := mux.NewRouter() + restheadspec.SetupMuxRoutes(router, handler) + router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) + router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) + + log.Println("Server starting on :8080") + http.ListenAndServe(":8080", router) +} + +// Callback implementations +func authenticateUser(r *http.Request) (int, string, error) { + userIDStr := r.Header.Get("X-User-ID") + if userIDStr == "" { + return 0, "", fmt.Errorf("authentication required") + } + userID, err := strconv.Atoi(userIDStr) + return userID, "", err +} + +func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + // Your implementation here + return []security.ColumnSecurity{}, nil +} + +func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { + return security.RowSecurity{ + Schema: schema, + Tablename: table, + UserID: userID, + Template: "user_id = " + strconv.Itoa(userID), + }, nil +} +``` + +--- + +## Testing Your Callbacks + +### Unit Test Example + +```go +func TestAuthCallback(t *testing.T) { + req := httptest.NewRequest("GET", "/api/orders", nil) + req.Header.Set("X-User-ID", "123") + + userID, roles, err := myAuthFunction(req) + + assert.Nil(t, err) + assert.Equal(t, 123, userID) +} + +func TestColumnSecurityCallback(t *testing.T) { + rules, err := myLoadColumnSecurity(123, "public", "employees") + + assert.Nil(t, err) + assert.Greater(t, len(rules), 0) + assert.Equal(t, "mask", rules[0].Accesstype) +} +``` + +--- + +## Common Patterns + +### Pattern 1: Role-Based Security + +```go +func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + roles := getUserRoles(userID) + + if contains(roles, "admin") { + // Admins see everything + return []security.ColumnSecurity{}, nil + } + + // Non-admins have restrictions + return []security.ColumnSecurity{ + {Path: []string{"ssn"}, Accesstype: "mask"}, + }, nil +} +``` + +### Pattern 2: Tenant Isolation + +```go +func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { + tenantID := getUserTenant(userID) + + return security.RowSecurity{ + Template: fmt.Sprintf("tenant_id = %d", tenantID), + }, nil +} +``` + +### Pattern 3: Caching Security Rules + +```go +var securityCache = cache.New(5*time.Minute, 10*time.Minute) + +func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + cacheKey := fmt.Sprintf("%d:%s.%s", userID, schema, table) + + if cached, found := securityCache.Get(cacheKey); found { + return cached.([]security.ColumnSecurity), nil + } + + rules := loadFromDatabase(userID, schema, table) + securityCache.Set(cacheKey, rules, cache.DefaultExpiration) + + return rules, nil +} +``` + +--- + +## Troubleshooting + +### Error: "AuthenticateCallback not set" +**Solution:** Configure all three callbacks before calling `SetupSecurityProvider`: +```go +security.GlobalSecurity.AuthenticateCallback = myAuthFunc +security.GlobalSecurity.LoadColumnSecurityCallback = myColSecFunc +security.GlobalSecurity.LoadRowSecurityCallback = myRowSecFunc +``` + +### Error: "Authentication failed" +**Solution:** Check your `AuthenticateCallback` implementation. Ensure it returns valid user ID or proper error. + +### Security rules not applying +**Solution:** +1. Check callbacks are returning data +2. Enable debug logging +3. Verify database queries return results +4. Check user has security groups assigned + +--- + +## Next Steps + +1. ✅ Implement the three callbacks for your system +2. ✅ Configure `GlobalSecurity` with your callbacks +3. ✅ Call `SetupSecurityProvider` +4. ✅ Test with different users and verify isolation +5. ✅ Review `callbacks_example.go` for more examples + +For complete working examples, see: +- `pkg/security/callbacks_example.go` - 7 example implementations +- `examples/secure_server/main.go` - Full server example +- `pkg/security/README.md` - Comprehensive documentation diff --git a/pkg/security/QUICK_REFERENCE.md b/pkg/security/QUICK_REFERENCE.md new file mode 100644 index 0000000..9530d85 --- /dev/null +++ b/pkg/security/QUICK_REFERENCE.md @@ -0,0 +1,402 @@ +# Security Provider - Quick Reference + +## 3-Step Setup + +```go +// Step 1: Implement callbacks +func myAuth(r *http.Request) (int, string, error) { /* ... */ } +func myColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { /* ... */ } +func myRowSec(userID int, schema, table string) (security.RowSecurity, error) { /* ... */ } + +// Step 2: Configure callbacks +security.GlobalSecurity.AuthenticateCallback = myAuth +security.GlobalSecurity.LoadColumnSecurityCallback = myColSec +security.GlobalSecurity.LoadRowSecurityCallback = myRowSec + +// Step 3: Setup and apply middleware +security.SetupSecurityProvider(handler, &security.GlobalSecurity) +router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) +router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) +``` + +--- + +## Callback Signatures + +```go +// 1. Authentication +func(r *http.Request) (userID int, roles string, err error) + +// 2. Column Security +func(userID int, schema, tablename string) ([]ColumnSecurity, error) + +// 3. Row Security +func(userID int, schema, tablename string) (RowSecurity, error) +``` + +--- + +## ColumnSecurity Structure + +```go +security.ColumnSecurity{ + Path: []string{"column_name"}, // ["ssn"] or ["address", "street"] + Accesstype: "mask", // "mask" or "hide" + MaskStart: 5, // Mask first N chars + MaskEnd: 0, // Mask last N chars + MaskChar: "*", // Masking character + MaskInvert: false, // true = mask middle +} +``` + +### Common Examples + +```go +// Hide entire field +{Path: []string{"salary"}, Accesstype: "hide"} + +// Mask SSN (show last 4) +{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5} + +// Mask credit card (show last 4) +{Path: []string{"credit_card"}, Accesstype: "mask", MaskStart: 12} + +// Mask email (j***@example.com) +{Path: []string{"email"}, Accesstype: "mask", MaskStart: 1, MaskEnd: 0} +``` + +--- + +## RowSecurity Structure + +```go +security.RowSecurity{ + Schema: "public", + Tablename: "orders", + UserID: 123, + Template: "user_id = {UserID}", // WHERE clause + HasBlock: false, // true = block all access +} +``` + +### Template Variables + +- `{UserID}` - Current user ID +- `{PrimaryKeyName}` - Primary key column +- `{TableName}` - Table name +- `{SchemaName}` - Schema name + +### Common Examples + +```go +// Users see only their records +Template: "user_id = {UserID}" + +// Users see their records OR public ones +Template: "user_id = {UserID} OR is_public = true" + +// Tenant isolation +Template: "tenant_id = 5 AND user_id = {UserID}" + +// Complex with subquery +Template: "dept_id IN (SELECT dept_id FROM user_depts WHERE user_id = {UserID})" + +// Block all access +HasBlock: true +``` + +--- + +## Example Implementations + +### Simple Header Auth + +```go +func authFromHeader(r *http.Request) (int, string, error) { + userIDStr := r.Header.Get("X-User-ID") + if userIDStr == "" { + return 0, "", fmt.Errorf("X-User-ID required") + } + userID, err := strconv.Atoi(userIDStr) + return userID, "", err +} +``` + +### JWT Auth + +```go +func authFromJWT(r *http.Request) (int, string, error) { + token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + claims, err := jwt.Parse(token, secret) + if err != nil { + return 0, "", err + } + return claims.UserID, claims.Roles, nil +} +``` + +### Static Column Security + +```go +func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + if table == "employees" { + return []security.ColumnSecurity{ + {Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}, + {Path: []string{"salary"}, Accesstype: "hide"}, + }, nil + } + return []security.ColumnSecurity{}, nil +} +``` + +### Database Column Security + +```go +func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + rows, err := db.Query(` + SELECT control, accesstype, jsonvalue + FROM core.secacces + WHERE rid_hub IN (...) + AND control ILIKE ? + `, fmt.Sprintf("%s.%s%%", schema, table)) + // ... parse and return +} +``` + +### Static Row Security + +```go +func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { + templates := map[string]string{ + "orders": "user_id = {UserID}", + "documents": "user_id = {UserID} OR is_public = true", + } + return security.RowSecurity{ + Template: templates[table], + }, nil +} +``` + +--- + +## Testing + +```go +// Test auth callback +req := httptest.NewRequest("GET", "/", nil) +req.Header.Set("X-User-ID", "123") +userID, roles, err := myAuth(req) +assert.Equal(t, 123, userID) + +// Test column security callback +rules, err := myColSec(123, "public", "employees") +assert.Equal(t, "mask", rules[0].Accesstype) + +// Test row security callback +rowSec, err := myRowSec(123, "public", "orders") +assert.Equal(t, "user_id = {UserID}", rowSec.Template) +``` + +--- + +## Request Flow + +``` +HTTP Request + ↓ +AuthMiddleware → calls AuthenticateCallback + ↓ (adds userID to context) +SetSecurityMiddleware → adds GlobalSecurity to context + ↓ +Handler.Handle() + ↓ +BeforeRead Hook → calls LoadColumnSecurityCallback + LoadRowSecurityCallback + ↓ +BeforeScan Hook → applies row security (WHERE clause) + ↓ +Database Query + ↓ +AfterRead Hook → applies column security (masking) + ↓ +HTTP Response +``` + +--- + +## Common Patterns + +### Role-Based Security + +```go +func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + if isAdmin(userID) { + return []security.ColumnSecurity{}, nil // No restrictions + } + return loadRestrictions(userID, schema, table), nil +} +``` + +### Tenant Isolation + +```go +func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) { + tenantID := getUserTenant(userID) + return security.RowSecurity{ + Template: fmt.Sprintf("tenant_id = %d", tenantID), + }, nil +} +``` + +### Caching + +```go +var cache = make(map[string][]security.ColumnSecurity) + +func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + key := fmt.Sprintf("%d:%s.%s", userID, schema, table) + if cached, ok := cache[key]; ok { + return cached, nil + } + rules := loadFromDB(userID, schema, table) + cache[key] = rules + return rules, nil +} +``` + +--- + +## Error Handling + +```go +// Setup will fail if callbacks not configured +if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil { + log.Fatal("Security setup failed:", err) +} + +// Auth middleware rejects if callback returns error +func myAuth(r *http.Request) (int, string, error) { + if invalid { + return 0, "", fmt.Errorf("invalid credentials") // Returns HTTP 401 + } + return userID, roles, nil +} + +// Security loading can fail gracefully +func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + rules, err := db.Load(...) + if err != nil { + log.Printf("Failed to load security: %v", err) + return []security.ColumnSecurity{}, nil // No rules = no restrictions + } + return rules, nil +} +``` + +--- + +## Debugging + +```go +// Enable debug logging +import "github.com/bitechdev/GoCore/pkg/cfg" +cfg.SetLogLevel("DEBUG") + +// Log in callbacks +func myAuth(r *http.Request) (int, string, error) { + token := r.Header.Get("Authorization") + log.Printf("Auth: token=%s", token) + // ... +} + +// Check if callbacks are called +func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { + log.Printf("Loading column security: user=%d, schema=%s, table=%s", userID, schema, table) + // ... +} +``` + +--- + +## Complete Minimal Example + +```go +package main + +import ( + "fmt" + "net/http" + "strconv" + "github.com/bitechdev/ResolveSpec/pkg/restheadspec" + "github.com/bitechdev/ResolveSpec/pkg/security" + "github.com/gorilla/mux" +) + +func main() { + handler := restheadspec.NewHandlerWithGORM(db) + + // Configure callbacks + security.GlobalSecurity.AuthenticateCallback = func(r *http.Request) (int, string, error) { + id, _ := strconv.Atoi(r.Header.Get("X-User-ID")) + return id, "", nil + } + security.GlobalSecurity.LoadColumnSecurityCallback = func(u int, s, t string) ([]security.ColumnSecurity, error) { + return []security.ColumnSecurity{}, nil + } + security.GlobalSecurity.LoadRowSecurityCallback = func(u int, s, t string) (security.RowSecurity, error) { + return security.RowSecurity{Template: fmt.Sprintf("user_id = %d", u)}, nil + } + + // Setup + security.SetupSecurityProvider(handler, &security.GlobalSecurity) + + // Middleware + router := mux.NewRouter() + restheadspec.SetupMuxRoutes(router, handler) + router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) + router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) + + http.ListenAndServe(":8080", router) +} +``` + +--- + +## Resources + +| File | Description | +|------|-------------| +| `CALLBACKS_GUIDE.md` | **Start here** - Complete implementation guide | +| `callbacks_example.go` | 7 working examples to copy | +| `CALLBACKS_SUMMARY.md` | Architecture overview | +| `README.md` | Full documentation | +| `setup_example.go` | Integration examples | + +--- + +## Cheat Sheet + +```go +// ===== REQUIRED SETUP ===== +security.GlobalSecurity.AuthenticateCallback = myAuthFunc +security.GlobalSecurity.LoadColumnSecurityCallback = myColFunc +security.GlobalSecurity.LoadRowSecurityCallback = myRowFunc +security.SetupSecurityProvider(handler, &security.GlobalSecurity) + +// ===== CALLBACK SIGNATURES ===== +func(r *http.Request) (int, string, error) // Auth +func(int, string, string) ([]security.ColumnSecurity, error) // Column +func(int, string, string) (security.RowSecurity, error) // Row + +// ===== QUICK EXAMPLES ===== +// Header auth +func(r *http.Request) (int, string, error) { + id, _ := strconv.Atoi(r.Header.Get("X-User-ID")) + return id, "", nil +} + +// Mask SSN +{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5} + +// User isolation +{Template: "user_id = {UserID}"} +``` diff --git a/pkg/security/callbacks_example.go b/pkg/security/callbacks_example.go new file mode 100644 index 0000000..5e0d3ba --- /dev/null +++ b/pkg/security/callbacks_example.go @@ -0,0 +1,418 @@ +package security + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + DBM "github.com/bitechdev/GoCore/pkg/models" + "github.com/bitechdev/ResolveSpec/pkg/logger" +) + +// This file provides example implementations of the required security callbacks. +// Copy these functions and modify them to match your authentication and database schema. + +// ============================================================================= +// EXAMPLE 1: Simple Header-Based Authentication +// ============================================================================= + +// ExampleAuthenticateFromHeader extracts user ID from X-User-ID header +func ExampleAuthenticateFromHeader(r *http.Request) (userID int, roles string, err error) { + userIDStr := r.Header.Get("X-User-ID") + if userIDStr == "" { + return 0, "", fmt.Errorf("X-User-ID header not provided") + } + + userID, err = strconv.Atoi(userIDStr) + if err != nil { + return 0, "", fmt.Errorf("invalid user ID format: %v", err) + } + + // Optionally extract roles + roles = r.Header.Get("X-User-Roles") // comma-separated: "admin,manager" + + return userID, roles, nil +} + +// ============================================================================= +// EXAMPLE 2: JWT Token Authentication +// ============================================================================= + +// ExampleAuthenticateFromJWT parses a JWT token and extracts user info +// You'll need to import a JWT library like github.com/golang-jwt/jwt/v5 +func ExampleAuthenticateFromJWT(r *http.Request) (userID int, roles string, err error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return 0, "", fmt.Errorf("authorization header not provided") + } + + // Extract Bearer token + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == authHeader { + return 0, "", fmt.Errorf("invalid authorization header format") + } + + // TODO: Parse and validate JWT token + // Example using github.com/golang-jwt/jwt/v5: + // + // token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // return []byte(os.Getenv("JWT_SECRET")), nil + // }) + // + // if err != nil || !token.Valid { + // return 0, "", fmt.Errorf("invalid token: %v", err) + // } + // + // claims := token.Claims.(jwt.MapClaims) + // userID = int(claims["user_id"].(float64)) + // roles = claims["roles"].(string) + + return 0, "", fmt.Errorf("JWT parsing not implemented - see example above") +} + +// ============================================================================= +// EXAMPLE 3: Session Cookie Authentication +// ============================================================================= + +// ExampleAuthenticateFromSession validates a session cookie +func ExampleAuthenticateFromSession(r *http.Request) (userID int, roles string, err error) { + sessionCookie, err := r.Cookie("session_id") + if err != nil { + return 0, "", fmt.Errorf("session cookie not found") + } + + // TODO: Validate session against your session store (Redis, database, etc.) + // Example: + // + // session, err := sessionStore.Get(sessionCookie.Value) + // if err != nil { + // return 0, "", fmt.Errorf("invalid session") + // } + // + // userID = session.UserID + // roles = session.Roles + + _ = sessionCookie // Suppress unused warning until implemented + return 0, "", fmt.Errorf("session validation not implemented - see example above") +} + +// ============================================================================= +// EXAMPLE 4: Column Security - Database Implementation +// ============================================================================= + +// ExampleLoadColumnSecurityFromDatabase loads column security rules from database +// This implementation assumes the following database schema: +// +// CREATE TABLE core.secacces ( +// rid_secacces SERIAL PRIMARY KEY, +// rid_hub INTEGER, +// control TEXT, -- Format: "schema.table.column" +// accesstype TEXT, -- "mask" or "hide" +// jsonvalue JSONB -- Masking configuration +// ); +// +// CREATE TABLE core.hub_link ( +// rid_hub_parent INTEGER, -- Security group ID +// rid_hub_child INTEGER, -- User ID +// parent_hubtype TEXT -- 'secgroup' +// ); +func ExampleLoadColumnSecurityFromDatabase(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) { + colSecList := make([]ColumnSecurity, 0) + + getExtraFilters := func(pStr string) map[string]string { + mp := make(map[string]string, 0) + for i, val := range strings.Split(pStr, ",") { + if i <= 1 { + continue + } + vals := strings.Split(val, ":") + if len(vals) > 1 { + mp[vals[0]] = vals[1] + } + } + return mp + } + + rows, err := DBM.DBConn.Raw(fmt.Sprintf(` + SELECT a.rid_secacces, a.control, a.accesstype, a.jsonvalue + FROM core.secacces a + WHERE a.rid_hub IN ( + SELECT l.rid_hub_parent + FROM core.hub_link l + WHERE l.parent_hubtype = 'secgroup' + AND l.rid_hub_child = ? + ) + AND control ILIKE '%s.%s%%' + `, pSchema, pTablename), pUserID).Rows() + + defer func() { + if rows != nil { + rows.Close() + } + }() + + if err != nil { + return colSecList, fmt.Errorf("failed to fetch column security from SQL: %v", err) + } + + for rows.Next() { + var rid int + var jsondata []byte + var control, accesstype string + + err = rows.Scan(&rid, &control, &accesstype, &jsondata) + if err != nil { + return colSecList, fmt.Errorf("failed to scan column security: %v", err) + } + + parts := strings.Split(control, ",") + ids := strings.Split(parts[0], ".") + if len(ids) < 3 { + continue + } + + jsonvalue := make(map[string]interface{}) + if len(jsondata) > 1 { + err = json.Unmarshal(jsondata, &jsonvalue) + if err != nil { + logger.Error("Failed to parse json: %v", err) + } + } + + colsec := ColumnSecurity{ + Schema: pSchema, + Tablename: pTablename, + UserID: pUserID, + Path: ids[2:], + ExtraFilters: getExtraFilters(control), + Accesstype: accesstype, + Control: control, + ID: int(rid), + } + + // Parse masking configuration from JSON + if v, ok := jsonvalue["start"]; ok { + if value, ok := v.(float64); ok { + colsec.MaskStart = int(value) + } + } + + if v, ok := jsonvalue["end"]; ok { + if value, ok := v.(float64); ok { + colsec.MaskEnd = int(value) + } + } + + if v, ok := jsonvalue["invert"]; ok { + if value, ok := v.(bool); ok { + colsec.MaskInvert = value + } + } + + if v, ok := jsonvalue["char"]; ok { + if value, ok := v.(string); ok { + colsec.MaskChar = value + } + } + + colSecList = append(colSecList, colsec) + } + + return colSecList, nil +} + +// ============================================================================= +// EXAMPLE 5: Column Security - In-Memory/Static Configuration +// ============================================================================= + +// ExampleLoadColumnSecurityFromConfig loads column security from static config +func ExampleLoadColumnSecurityFromConfig(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) { + // Example: Define security rules in code or load from config file + securityRules := map[string][]ColumnSecurity{ + "public.employees": { + { + Schema: "public", + Tablename: "employees", + Path: []string{"ssn"}, + Accesstype: "mask", + MaskStart: 5, + MaskEnd: 0, + MaskChar: "*", + }, + { + Schema: "public", + Tablename: "employees", + Path: []string{"salary"}, + Accesstype: "hide", + }, + }, + "public.customers": { + { + Schema: "public", + Tablename: "customers", + Path: []string{"credit_card"}, + Accesstype: "mask", + MaskStart: 12, + MaskEnd: 0, + MaskChar: "*", + }, + }, + } + + key := fmt.Sprintf("%s.%s", pSchema, pTablename) + rules, ok := securityRules[key] + if !ok { + return []ColumnSecurity{}, nil // No rules for this table + } + + // Filter by user ID if needed + // For this example, all rules apply to all users + return rules, nil +} + +// ============================================================================= +// EXAMPLE 6: Row Security - Database Implementation +// ============================================================================= + +// ExampleLoadRowSecurityFromDatabase loads row security rules from database +// This implementation assumes a PostgreSQL function: +// +// CREATE FUNCTION core.api_sec_rowtemplate( +// p_schema TEXT, +// p_table TEXT, +// p_userid INTEGER +// ) RETURNS TABLE ( +// p_retval INTEGER, +// p_errmsg TEXT, +// p_template TEXT, +// p_block BOOLEAN +// ); +func ExampleLoadRowSecurityFromDatabase(pUserID int, pSchema, pTablename string) (RowSecurity, error) { + record := RowSecurity{ + Schema: pSchema, + Tablename: pTablename, + UserID: pUserID, + } + + rows, err := DBM.DBConn.Raw(` + SELECT r.p_retval, r.p_errmsg, r.p_template, r.p_block + FROM core.api_sec_rowtemplate(?, ?, ?) r + `, pSchema, pTablename, pUserID).Rows() + + defer func() { + if rows != nil { + rows.Close() + } + }() + + if err != nil { + return record, fmt.Errorf("failed to fetch row security from SQL: %v", err) + } + + for rows.Next() { + var retval int + var errmsg string + + err = rows.Scan(&retval, &errmsg, &record.Template, &record.HasBlock) + if err != nil { + return record, fmt.Errorf("failed to scan row security: %v", err) + } + + if retval != 0 { + return RowSecurity{}, fmt.Errorf("api_sec_rowtemplate error: %s", errmsg) + } + } + + return record, nil +} + +// ============================================================================= +// EXAMPLE 7: Row Security - Static Configuration +// ============================================================================= + +// ExampleLoadRowSecurityFromConfig loads row security from static config +func ExampleLoadRowSecurityFromConfig(pUserID int, pSchema, pTablename string) (RowSecurity, error) { + // Define row security templates based on entity + templates := map[string]string{ + "public.orders": "user_id = {UserID}", // Users see only their orders + "public.documents": "user_id = {UserID} OR is_public = true", // Users see their docs + public docs + "public.employees": "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})", // Complex filter + } + + // Define blocked entities (no access at all) + blockedEntities := map[string][]int{ + "public.admin_logs": {}, // All users blocked (empty list = block all) + "public.audit_logs": {1, 2, 3}, // Block users 1, 2, 3 + } + + key := fmt.Sprintf("%s.%s", pSchema, pTablename) + + // Check if entity is blocked for this user + if blockedUsers, ok := blockedEntities[key]; ok { + if len(blockedUsers) == 0 { + // Block all users + return RowSecurity{ + Schema: pSchema, + Tablename: pTablename, + UserID: pUserID, + HasBlock: true, + }, nil + } + // Check if specific user is blocked + for _, blockedUserID := range blockedUsers { + if blockedUserID == pUserID { + return RowSecurity{ + Schema: pSchema, + Tablename: pTablename, + UserID: pUserID, + HasBlock: true, + }, nil + } + } + } + + // Get template for this entity + template, ok := templates[key] + if !ok { + // No row security defined - allow all rows + return RowSecurity{ + Schema: pSchema, + Tablename: pTablename, + UserID: pUserID, + Template: "", + HasBlock: false, + }, nil + } + + return RowSecurity{ + Schema: pSchema, + Tablename: pTablename, + UserID: pUserID, + Template: template, + HasBlock: false, + }, nil +} + +// ============================================================================= +// SETUP HELPER: Configure All Callbacks +// ============================================================================= + +// SetupCallbacksExample shows how to configure all callbacks +func SetupCallbacksExample() { + // Option 1: Use database-backed security (production) + GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT + GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase + GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase + + // Option 2: Use static configuration (development/testing) + // GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader + // GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig + // GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig + + // Option 3: Mix and match + // GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT + // GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig + // GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase +} diff --git a/pkg/security/hooks.go b/pkg/security/hooks.go new file mode 100644 index 0000000..be7e0d3 --- /dev/null +++ b/pkg/security/hooks.go @@ -0,0 +1,244 @@ +package security + +import ( + "fmt" + "reflect" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/restheadspec" +) + +// RegisterSecurityHooks registers all security-related hooks with the handler +func RegisterSecurityHooks(handler *restheadspec.Handler, securityList *SecurityList) { + + // Hook 1: BeforeRead - Load security rules + handler.Hooks().Register(restheadspec.BeforeRead, func(hookCtx *restheadspec.HookContext) error { + return loadSecurityRules(hookCtx, securityList) + }) + + // Hook 2: BeforeScan - Apply row-level security filters + handler.Hooks().Register(restheadspec.BeforeScan, func(hookCtx *restheadspec.HookContext) error { + return applyRowSecurity(hookCtx, securityList) + }) + + // Hook 3: AfterRead - Apply column-level security (masking) + handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error { + return applyColumnSecurity(hookCtx, securityList) + }) + + // Hook 4 (Optional): Audit logging + handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error { + return logDataAccess(hookCtx) + }) +} + +// loadSecurityRules loads security configuration for the user and entity +func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { + // Extract user ID from context + userID, ok := GetUserID(hookCtx.Context) + if !ok { + logger.Warn("No user ID in context for security check") + return fmt.Errorf("authentication required") + } + + schema := hookCtx.Schema + tablename := hookCtx.Entity + + logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename) + + // Load column security rules from database + err := securityList.LoadColumnSecurity(userID, schema, tablename, false) + if err != nil { + logger.Warn("Failed to load column security: %v", err) + // Don't fail the request if no security rules exist + // return err + } + + // Load row security rules from database + _, err = securityList.LoadRowSecurity(userID, schema, tablename, false) + if err != nil { + logger.Warn("Failed to load row security: %v", err) + // Don't fail the request if no security rules exist + // return err + } + + return nil +} + +// applyRowSecurity applies row-level security filters to the query +func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { + userID, ok := GetUserID(hookCtx.Context) + if !ok { + return nil // No user context, skip + } + + schema := hookCtx.Schema + tablename := hookCtx.Entity + + // Get row security template + rowSec, err := securityList.GetRowSecurityTemplate(userID, schema, tablename) + if err != nil { + // No row security defined, allow query to proceed + logger.Debug("No row security for %s.%s@%d: %v", schema, tablename, userID, err) + return nil + } + + // Check if user has a blocking rule + if rowSec.HasBlock { + logger.Warn("User %d blocked from accessing %s.%s", userID, schema, tablename) + return fmt.Errorf("access denied to %s", tablename) + } + + // If there's a security template, apply it as a WHERE clause + if rowSec.Template != "" { + // Get primary key name from model + modelType := reflect.TypeOf(hookCtx.Model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + // Find primary key field + pkName := "id" // default + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + if tag := field.Tag.Get("bun"); tag != "" { + // Check for primary key tag + if contains(tag, "pk") || contains(tag, "primary_key") { + if sqlName := extractSQLName(tag); sqlName != "" { + pkName = sqlName + } + break + } + } + } + + // Generate the WHERE clause from template + whereClause := rowSec.GetTemplate(pkName, modelType) + + logger.Info("Applying row security filter for user %d on %s.%s: %s", + userID, schema, tablename, whereClause) + + // Apply the WHERE clause to the query + // The query is in hookCtx.Query + if selectQuery, ok := hookCtx.Query.(interface { + Where(string, ...interface{}) interface{} + }); ok { + hookCtx.Query = selectQuery.Where(whereClause) + } else { + logger.Error("Unable to apply WHERE clause - query doesn't support Where method") + } + } + + return nil +} + +// applyColumnSecurity applies column-level security (masking/hiding) to results +func applyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error { + userID, ok := GetUserID(hookCtx.Context) + if !ok { + return nil // No user context, skip + } + + schema := hookCtx.Schema + tablename := hookCtx.Entity + + // Get result data + result := hookCtx.Result + if result == nil { + return nil + } + + logger.Debug("Applying column security for user=%d, schema=%s, table=%s", userID, schema, tablename) + + // Get model type + modelType := reflect.TypeOf(hookCtx.Model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + // Apply column security masking + resultValue := reflect.ValueOf(result) + if resultValue.Kind() == reflect.Ptr { + resultValue = resultValue.Elem() + } + + err, maskedResult := securityList.ApplyColumnSecurity(resultValue, modelType, userID, schema, tablename) + if err != nil { + logger.Warn("Column security error: %v", err) + // Don't fail the request, just log the issue + return nil + } + + // Update the result with masked data + if maskedResult.IsValid() && maskedResult.CanInterface() { + hookCtx.Result = maskedResult.Interface() + } + + return nil +} + +// logDataAccess logs all data access for audit purposes +func logDataAccess(hookCtx *restheadspec.HookContext) error { + userID, _ := GetUserID(hookCtx.Context) + + logger.Info("AUDIT: User %d accessed %s.%s with filters: %+v", + userID, + hookCtx.Schema, + hookCtx.Entity, + hookCtx.Options.Filters, + ) + + // TODO: Write to audit log table or external audit service + // auditLog := AuditLog{ + // UserID: userID, + // Schema: hookCtx.Schema, + // Entity: hookCtx.Entity, + // Action: "READ", + // Timestamp: time.Now(), + // Filters: hookCtx.Options.Filters, + // } + // db.Create(&auditLog) + + return nil +} + +// Helper functions + +func contains(s, substr string) bool { + return len(s) >= len(substr) && s[:len(substr)] == substr || + len(s) > len(substr) && s[len(s)-len(substr):] == substr +} + +func extractSQLName(tag string) string { + // Simple parser for "column:name" or just "name" + // This is a simplified version + parts := splitTag(tag, ',') + for _, part := range parts { + if part != "" && !contains(part, ":") { + return part + } + if contains(part, "column:") { + return part[7:] // Skip "column:" + } + } + return "" +} + +func splitTag(tag string, sep rune) []string { + var parts []string + var current string + for _, ch := range tag { + if ch == sep { + if current != "" { + parts = append(parts, current) + current = "" + } + } else { + current += string(ch) + } + } + if current != "" { + parts = append(parts, current) + } + return parts +} diff --git a/pkg/security/middleware.go b/pkg/security/middleware.go new file mode 100644 index 0000000..b9caa59 --- /dev/null +++ b/pkg/security/middleware.go @@ -0,0 +1,54 @@ +package security + +import ( + "context" + "net/http" +) + +const ( + // Context keys for user information + UserIDKey = "user_id" + UserRolesKey = "user_roles" + UserTokenKey = "user_token" +) + +// AuthMiddleware extracts user authentication from request and adds to context +// This should be applied before the ResolveSpec handler +// Uses GlobalSecurity.AuthenticateCallback if set, otherwise returns error +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if callback is set + if GlobalSecurity.AuthenticateCallback == nil { + http.Error(w, "AuthenticateCallback not set - you must provide an authentication callback", http.StatusInternalServerError) + return + } + + // Call the user-provided authentication callback + userID, roles, err := GlobalSecurity.AuthenticateCallback(r) + if err != nil { + http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized) + return + } + + // Add user information to context + ctx := context.WithValue(r.Context(), UserIDKey, userID) + if roles != "" { + ctx = context.WithValue(ctx, UserRolesKey, roles) + } + + // Continue with authenticated context + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetUserID extracts the user ID from context +func GetUserID(ctx context.Context) (int, bool) { + userID, ok := ctx.Value(UserIDKey).(int) + return userID, ok +} + +// GetUserRoles extracts user roles from context +func GetUserRoles(ctx context.Context) (string, bool) { + roles, ok := ctx.Value(UserRolesKey).(string) + return roles, ok +} diff --git a/pkg/security/provider.go b/pkg/security/provider.go new file mode 100644 index 0000000..d5fabf8 --- /dev/null +++ b/pkg/security/provider.go @@ -0,0 +1,460 @@ +package security + +import ( + "context" + "fmt" + "net/http" + "reflect" + "strings" + "sync" + + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type ColumnSecurity struct { + Schema string + Tablename string + Path []string + ExtraFilters map[string]string + UserID int + Accesstype string `json:"accesstype"` + MaskStart int + MaskEnd int + MaskInvert bool + MaskChar string + Control string `json:"control"` + ID int `json:"id"` +} + +type RowSecurity struct { + Schema string + Tablename string + Template string + HasBlock bool + UserID int +} + +func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Type) string { + str := m.Template + str = strings.ReplaceAll(str, "{PrimaryKeyName}", pPrimaryKeyName) + str = strings.ReplaceAll(str, "{TableName}", m.Tablename) + str = strings.ReplaceAll(str, "{SchemaName}", m.Schema) + str = strings.ReplaceAll(str, "{UserID}", fmt.Sprintf("%d", m.UserID)) + return str +} + +// Callback function types for customizing security behavior +type ( + // AuthenticateFunc extracts user ID and roles from HTTP request + // Return userID, roles, error. If error is not nil, request will be rejected. + AuthenticateFunc func(r *http.Request) (userID int, roles string, err error) + + // LoadColumnSecurityFunc loads column security rules for a user and entity + // Override this to customize how column security is loaded from your data source + LoadColumnSecurityFunc func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) + + // LoadRowSecurityFunc loads row security rules for a user and entity + // Override this to customize how row security is loaded from your data source + LoadRowSecurityFunc func(pUserID int, pSchema, pTablename string) (RowSecurity, error) +) + +type SecurityList struct { + ColumnSecurityMutex sync.RWMutex + ColumnSecurity map[string][]ColumnSecurity + RowSecurityMutex sync.RWMutex + RowSecurity map[string]RowSecurity + + // Overridable callbacks + AuthenticateCallback AuthenticateFunc + LoadColumnSecurityCallback LoadColumnSecurityFunc + LoadRowSecurityCallback LoadRowSecurityFunc +} + +const SECURITY_CONTEXT_KEY = "SecurityList" + +var GlobalSecurity SecurityList + +// SetSecurityMiddleware adds security context to requests +func SetSecurityMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, &GlobalSecurity) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func maskString(pString string, maskStart, maskEnd int, maskChar string, invert bool) string { + strLen := len(pString) + middleIndex := (strLen / 2) + newStr := "" + if maskStart == 0 && maskEnd == 0 { + maskStart = strLen + maskEnd = strLen + } + if maskEnd > strLen { + maskEnd = strLen + } + if maskStart > strLen { + maskStart = strLen + } + if maskChar == "" { + maskChar = "*" + } + for index, char := range pString { + if invert && index >= middleIndex-maskStart && index <= middleIndex { + newStr = newStr + maskChar + continue + } + if invert && index <= middleIndex+maskEnd && index >= middleIndex { + newStr = newStr + maskChar + continue + } + if !invert && index <= maskStart { + newStr = newStr + maskChar + continue + } + if !invert && index >= strLen-1-maskEnd { + newStr = newStr + maskChar + continue + } + newStr = newStr + string(char) + } + + return newStr +} + +func (m *SecurityList) ColumSecurityApplyOnRecord(prevRecord reflect.Value, newRecord reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) ([]string, error) { + cols := make([]string, 0) + if m.ColumnSecurity == nil { + return cols, fmt.Errorf("security not initialized") + } + + if prevRecord.Type() != newRecord.Type() { + logger.Error("prev:%s and new:%s record type mismatch", prevRecord.Type(), newRecord.Type()) + return cols, fmt.Errorf("prev and new record type mismatch") + } + + m.ColumnSecurityMutex.RLock() + defer m.ColumnSecurityMutex.RUnlock() + + colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)] + if !ok || colsecList == nil { + return cols, fmt.Errorf("no security data") + } + + for _, colsec := range colsecList { + if !(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) { + continue + } + lastRecords := interateStruct(prevRecord) + newRecords := interateStruct(newRecord) + var lastLoopField, lastLoopNewField reflect.Value + pathLen := len(colsec.Path) + for i, path := range colsec.Path { + var nameType, fieldName string + if len(newRecords) == 0 { + if lastLoopNewField.IsValid() && lastLoopField.IsValid() && i < pathLen-1 { + lastLoopNewField.Set(lastLoopField) + } + break + } + + for ri := range newRecords { + if !newRecords[ri].IsValid() || !lastRecords[ri].IsValid() { + break + } + var field, oldField reflect.Value + + columnData := reflection.GetModelColumnDetail(newRecords[ri]) + lastColumnData := reflection.GetModelColumnDetail(lastRecords[ri]) + for i, cols := range columnData { + if cols.SQLName != "" && strings.EqualFold(cols.SQLName, path) { + nameType = "sql" + fieldName = cols.SQLName + field = cols.FieldValue + oldField = lastColumnData[i].FieldValue + break + } + if cols.Name != "" && strings.EqualFold(cols.Name, path) { + nameType = "struct" + fieldName = cols.Name + field = cols.FieldValue + oldField = lastColumnData[i].FieldValue + break + } + } + + if !field.IsValid() || !oldField.IsValid() { + break + } + lastLoopField = oldField + lastLoopNewField = field + + if i == pathLen-1 { + if strings.Contains(strings.ToLower(fieldName), "json") { + prevSrc := oldField.Bytes() + newSrc := field.Bytes() + pathstr := strings.Join(colsec.Path, ".") + prevPathValue := gjson.GetBytes(prevSrc, pathstr) + newBytes, err := sjson.SetBytes(newSrc, pathstr, prevPathValue.Str) + if err == nil { + if field.CanSet() { + field.SetBytes(newBytes) + } else { + logger.Warn("Value not settable: %v", field) + cols = append(cols, pathstr) + } + } + break + } + + if nameType == "sql" { + if strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide") { + field.Set(oldField) + cols = append(cols, strings.Join(colsec.Path, ".")) + } + } + break + } + + lastRecords = interateStruct(field) + newRecords = interateStruct(oldField) + } + } + } + + return cols, nil +} + +func interateStruct(val reflect.Value) []reflect.Value { + list := make([]reflect.Value, 0) + + switch val.Kind() { + case reflect.Pointer, reflect.Interface: + elem := val.Elem() + if elem.IsValid() { + list = append(list, interateStruct(elem)...) + } + return list + case reflect.Array, reflect.Slice: + for i := 0; i < val.Len(); i++ { + elem := val.Index(i) + if !elem.IsValid() { + continue + } + list = append(list, interateStruct(elem)...) + } + return list + case reflect.Struct: + list = append(list, val) + return list + default: + return list + } +} + +func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName string) (int, reflect.Value) { + fieldval := fieldsrc + if fieldsrc.Kind() == reflect.Pointer || fieldsrc.Kind() == reflect.Interface { + fieldval = fieldval.Elem() + } + + if strings.Contains(strings.ToLower(fieldval.Kind().String()), "int") && + (strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) { + if fieldval.CanInt() && fieldval.CanSet() { + fieldval.SetInt(0) + } + } else if (strings.Contains(strings.ToLower(fieldval.Kind().String()), "time") || + strings.Contains(strings.ToLower(fieldval.Kind().String()), "date")) && + (strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) { + fieldval.SetZero() + } else if strings.Contains(strings.ToLower(fieldval.Kind().String()), "string") { + strVal := fieldval.String() + if strings.EqualFold(colsec.Accesstype, "mask") { + fieldval.SetString(maskString(strVal, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert)) + } else if strings.EqualFold(colsec.Accesstype, "hide") { + fieldval.SetString("") + } + } else if strings.Contains(fieldTypeName, "json") && + (strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) { + if len(colsec.Path) < 2 { + return 1, fieldval + } + pathstr := strings.Join(colsec.Path, ".") + src := fieldval.Bytes() + pathValue := gjson.GetBytes(src, pathstr) + strValue := pathValue.String() + if strings.EqualFold(colsec.Accesstype, "mask") { + strValue = maskString(strValue, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert) + } else if strings.EqualFold(colsec.Accesstype, "hide") { + strValue = "" + } + newBytes, err := sjson.SetBytes(src, pathstr, strValue) + if err == nil { + fieldval.SetBytes(newBytes) + } + } + return 0, fieldsrc +} + +func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (error, reflect.Value) { + defer logger.CatchPanic("ApplyColumnSecurity") + + if m.ColumnSecurity == nil { + return fmt.Errorf("security not initialized"), records + } + + m.ColumnSecurityMutex.RLock() + defer m.ColumnSecurityMutex.RUnlock() + + colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)] + if !ok || colsecList == nil { + return fmt.Errorf("no security data"), records + } + + for _, colsec := range colsecList { + if !(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) { + continue + } + + if records.Kind() == reflect.Array || records.Kind() == reflect.Slice { + for i := 0; i < records.Len(); i++ { + record := records.Index(i) + if !record.IsValid() { + continue + } + + lastRecord := interateStruct(record) + pathLen := len(colsec.Path) + for i, path := range colsec.Path { + var field reflect.Value + var nameType, fieldName string + if len(lastRecord) == 0 { + break + } + columnData := reflection.GetModelColumnDetail(lastRecord[0]) + for _, cols := range columnData { + if cols.SQLName != "" && strings.EqualFold(cols.SQLName, path) { + nameType = "sql" + fieldName = cols.SQLName + field = cols.FieldValue + break + } + if cols.Name != "" && strings.EqualFold(cols.Name, path) { + nameType = "struct" + fieldName = cols.Name + field = cols.FieldValue + break + } + } + + if i == pathLen-1 { + if nameType == "sql" || nameType == "struct" { + setColSecValue(field, colsec, fieldName) + } + break + } + if field.IsValid() { + lastRecord = interateStruct(field) + } + } + } + } + } + + return nil, records +} + +func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) error { + // Use the callback if provided + if m.LoadColumnSecurityCallback == nil { + return fmt.Errorf("LoadColumnSecurityCallback not set - you must provide a callback function") + } + + m.ColumnSecurityMutex.Lock() + defer m.ColumnSecurityMutex.Unlock() + + if m.ColumnSecurity == nil { + m.ColumnSecurity = make(map[string][]ColumnSecurity, 0) + } + secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID) + + if pOverwrite || m.ColumnSecurity[secKey] == nil { + m.ColumnSecurity[secKey] = make([]ColumnSecurity, 0) + } + + // Call the user-provided callback to load security rules + colSecList, err := m.LoadColumnSecurityCallback(pUserID, pSchema, pTablename) + if err != nil { + return fmt.Errorf("LoadColumnSecurityCallback failed: %v", err) + } + + m.ColumnSecurity[secKey] = colSecList + return nil +} + +func (m *SecurityList) ClearSecurity(pUserID int, pSchema, pTablename string) error { + var filtered []ColumnSecurity + m.ColumnSecurityMutex.Lock() + defer m.ColumnSecurityMutex.Unlock() + + secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID) + list, ok := m.ColumnSecurity[secKey] + if !ok { + return nil + } + + for _, cs := range list { + if !(cs.Schema == pSchema && cs.Tablename == pTablename && cs.UserID == pUserID) { + filtered = append(filtered, cs) + } + } + + m.ColumnSecurity[secKey] = filtered + return nil +} + +func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) { + // Use the callback if provided + if m.LoadRowSecurityCallback == nil { + return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback not set - you must provide a callback function") + } + + m.RowSecurityMutex.Lock() + defer m.RowSecurityMutex.Unlock() + + if m.RowSecurity == nil { + m.RowSecurity = make(map[string]RowSecurity, 0) + } + secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID) + + // Call the user-provided callback to load security rules + record, err := m.LoadRowSecurityCallback(pUserID, pSchema, pTablename) + if err != nil { + return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback failed: %v", err) + } + + m.RowSecurity[secKey] = record + return record, nil +} + +func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) { + defer logger.CatchPanic("GetRowSecurityTemplate") + + if m.RowSecurity == nil { + return RowSecurity{}, fmt.Errorf("security not initialized") + } + + m.RowSecurityMutex.RLock() + defer m.RowSecurityMutex.RUnlock() + + rowSec, ok := m.RowSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)] + if !ok { + return RowSecurity{}, fmt.Errorf("no security data") + } + + return rowSec, nil +} diff --git a/pkg/security/setup_example.go b/pkg/security/setup_example.go new file mode 100644 index 0000000..08cc56d --- /dev/null +++ b/pkg/security/setup_example.go @@ -0,0 +1,155 @@ +package security + +import ( + "fmt" + "net/http" + + "github.com/bitechdev/ResolveSpec/pkg/restheadspec" + "github.com/gorilla/mux" + "gorm.io/gorm" +) + +// SetupSecurityProvider initializes and configures the security provider +// This should be called when setting up your HTTP server +// +// IMPORTANT: You MUST configure the callbacks before calling this function: +// - GlobalSecurity.AuthenticateCallback +// - GlobalSecurity.LoadColumnSecurityCallback +// - GlobalSecurity.LoadRowSecurityCallback +// +// Example usage in your main.go or server setup: +// +// // Step 1: Configure callbacks (REQUIRED) +// security.GlobalSecurity.AuthenticateCallback = myAuthFunction +// security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurityFunction +// security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurityFunction +// +// // Step 2: Setup security provider +// handler := restheadspec.NewHandlerWithGORM(db) +// security.SetupSecurityProvider(handler, &security.GlobalSecurity) +// +// // Step 3: Apply middleware +// router.Use(mux.MiddlewareFunc(security.AuthMiddleware)) +// router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware)) +// +func SetupSecurityProvider(handler *restheadspec.Handler, securityList *SecurityList) error { + // Validate that required callbacks are configured + if securityList.AuthenticateCallback == nil { + return fmt.Errorf("AuthenticateCallback must be set before calling SetupSecurityProvider") + } + if securityList.LoadColumnSecurityCallback == nil { + return fmt.Errorf("LoadColumnSecurityCallback must be set before calling SetupSecurityProvider") + } + if securityList.LoadRowSecurityCallback == nil { + return fmt.Errorf("LoadRowSecurityCallback must be set before calling SetupSecurityProvider") + } + + // Initialize security maps if needed + if securityList.ColumnSecurity == nil { + securityList.ColumnSecurity = make(map[string][]ColumnSecurity) + } + if securityList.RowSecurity == nil { + securityList.RowSecurity = make(map[string]RowSecurity) + } + + // Register all security hooks + RegisterSecurityHooks(handler, securityList) + + return nil +} + +// Chain creates a middleware chain +func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler { + return func(final http.Handler) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + final = middlewares[i](final) + } + return final + } +} + +// CompleteExample shows a full integration example with Gorilla Mux +func CompleteExample(db *gorm.DB) (http.Handler, error) { + // Step 1: Create the ResolveSpec handler + handler := restheadspec.NewHandlerWithGORM(db) + + // Step 2: Register your models + // handler.RegisterModel("public", "users", User{}) + // handler.RegisterModel("public", "orders", Order{}) + + // Step 3: Configure security callbacks (REQUIRED!) + // See callbacks_example.go for example implementations + GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader + GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase + GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase + + // Step 4: Setup security provider + if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil { + return nil, fmt.Errorf("failed to setup security: %v", err) + } + + // Step 5: Create Mux router and setup routes + router := mux.NewRouter() + + // The routes are set up by restheadspec, which handles the conversion + // from http.Request to the internal request format + restheadspec.SetupMuxRoutes(router, handler) + + // Step 6: Apply middleware to the entire router + secureRouter := Chain( + AuthMiddleware, // Extract user from token + SetSecurityMiddleware, // Add security context + )(router) + + return secureRouter, nil +} + +// ExampleWithMux shows a simpler integration with Mux +func ExampleWithMux(db *gorm.DB) (*mux.Router, error) { + handler := restheadspec.NewHandlerWithGORM(db) + + // IMPORTANT: Configure callbacks BEFORE SetupSecurityProvider + GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader + GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig + GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig + + if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil { + return nil, fmt.Errorf("failed to setup security: %v", err) + } + + router := mux.NewRouter() + + // Setup API routes + restheadspec.SetupMuxRoutes(router, handler) + + // Apply middleware to router + router.Use(mux.MiddlewareFunc(AuthMiddleware)) + router.Use(mux.MiddlewareFunc(SetSecurityMiddleware)) + + return router, nil +} + +// Example with Gin +// import "github.com/gin-gonic/gin" +// +// func ExampleWithGin(db *gorm.DB) *gin.Engine { +// handler := restheadspec.NewHandlerWithGORM(db) +// SetupSecurityProvider(handler, &GlobalSecurity) +// +// router := gin.Default() +// +// // Convert middleware to Gin middleware +// router.Use(func(c *gin.Context) { +// AuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// c.Request = r +// c.Next() +// })).ServeHTTP(c.Writer, c.Request) +// }) +// +// // Setup routes +// api := router.Group("/api") +// api.Any("/:schema/:entity", gin.WrapH(http.HandlerFunc(handler.Handle))) +// api.Any("/:schema/:entity/:id", gin.WrapH(http.HandlerFunc(handler.Handle))) +// +// return router +// } diff --git a/tests/test_helpers.go b/tests/test_helpers.go index 5f4d81a..2b58a99 100644 --- a/tests/test_helpers.go +++ b/tests/test_helpers.go @@ -83,7 +83,7 @@ func TestSetup(m *testing.M) int { router := setupTestRouter(testDB) testServer = httptest.NewServer(router) - fmt.Printf("ResolveSpec test server starting on %s\n", testServer.URL) + logger.Info("ResolveSpec test server starting on %s", testServer.URL) testServerURL = testServer.URL defer testServer.Close() diff --git a/todo.md b/todo.md index 642b08d..df037e6 100644 --- a/todo.md +++ b/todo.md @@ -120,6 +120,8 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com - When making changes, we can have the trigger fire with the correct user. - Maybe wrap the handleRead,Update,Create,Delete handlers in a transaction with context that can abort when the request is cancelled or a configurable timeout is reached. +### 7. + ## Additional Considerations ### Documentation