mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-11-13 09:53:53 +00:00
Updated logging, added getRowNumber and a few more
This commit is contained in:
parent
faafe5abea
commit
ceaa251301
@ -1,138 +0,0 @@
|
||||
# Schema and Table Name Handling
|
||||
|
||||
This document explains how the handlers properly separate and handle schema and table names.
|
||||
|
||||
## Implementation
|
||||
|
||||
Both `resolvespec` and `restheadspec` handlers now properly handle schema and table name separation through the following functions:
|
||||
|
||||
- `parseTableName(fullTableName)` - Splits "schema.table" into separate components
|
||||
- `getSchemaAndTable(defaultSchema, entity, model)` - Returns schema and table separately
|
||||
- `getTableName(schema, entity, model)` - Returns the full "schema.table" format
|
||||
|
||||
## Priority Order
|
||||
|
||||
When determining the schema and table name, the following priority is used:
|
||||
|
||||
1. **If `TableName()` contains a schema** (e.g., "myschema.mytable"), that schema takes precedence
|
||||
2. **If model implements `SchemaProvider`**, use that schema
|
||||
3. **Otherwise**, use the `defaultSchema` parameter from the URL/request
|
||||
|
||||
## Scenarios
|
||||
|
||||
### Scenario 1: Simple table name, default schema
|
||||
```go
|
||||
type User struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
```
|
||||
- Request URL: `/api/public/users`
|
||||
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
|
||||
|
||||
### Scenario 2: Table name includes schema
|
||||
```go
|
||||
type User struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "auth.users" // Schema included!
|
||||
}
|
||||
```
|
||||
- Request URL: `/api/public/users` (public is ignored)
|
||||
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
|
||||
- **Note**: The schema from `TableName()` takes precedence over the URL schema
|
||||
|
||||
### Scenario 3: Using SchemaProvider
|
||||
```go
|
||||
type User struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
func (User) SchemaName() string {
|
||||
return "auth"
|
||||
}
|
||||
```
|
||||
- Request URL: `/api/public/users` (public is ignored)
|
||||
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
|
||||
|
||||
### Scenario 4: Table name includes schema AND SchemaProvider
|
||||
```go
|
||||
type User struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "core.users" // This wins!
|
||||
}
|
||||
|
||||
func (User) SchemaName() string {
|
||||
return "auth" // This is ignored
|
||||
}
|
||||
```
|
||||
- Request URL: `/api/public/users`
|
||||
- Result: `schema="core"`, `table="users"`, `fullName="core.users"`
|
||||
- **Note**: Schema from `TableName()` takes highest precedence
|
||||
|
||||
### Scenario 5: No providers at all
|
||||
```go
|
||||
type User struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
// No TableName() or SchemaName()
|
||||
```
|
||||
- Request URL: `/api/public/users`
|
||||
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
|
||||
- Uses URL schema and entity name
|
||||
|
||||
## Key Features
|
||||
|
||||
1. **Automatic detection**: The code automatically detects if `TableName()` includes a schema by checking for "."
|
||||
2. **Backward compatible**: Existing code continues to work
|
||||
3. **Flexible**: Supports multiple ways to specify schema and table
|
||||
4. **Debug logging**: Logs when schema is detected in `TableName()` for debugging
|
||||
|
||||
## Code Locations
|
||||
|
||||
### Handlers
|
||||
- `/pkg/resolvespec/handler.go:472-531`
|
||||
- `/pkg/restheadspec/handler.go:534-593`
|
||||
|
||||
### Database Adapters
|
||||
- `/pkg/common/adapters/database/utils.go` - Shared `parseTableName()` function
|
||||
- `/pkg/common/adapters/database/bun.go` - Bun adapter with separated schema/table
|
||||
- `/pkg/common/adapters/database/gorm.go` - GORM adapter with separated schema/table
|
||||
|
||||
## Adapter Implementation
|
||||
|
||||
Both Bun and GORM adapters now properly separate schema and table name:
|
||||
|
||||
```go
|
||||
// BunSelectQuery/GormSelectQuery now have separated fields:
|
||||
type BunSelectQuery struct {
|
||||
query *bun.SelectQuery
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
}
|
||||
```
|
||||
|
||||
When `Model()` or `Table()` is called:
|
||||
1. The full table name (which may include schema) is parsed
|
||||
2. Schema and table name are stored separately
|
||||
3. When building joins, the already-separated table name is used directly
|
||||
|
||||
This ensures consistent handling of schema-qualified table names throughout the codebase.
|
||||
@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -21,8 +20,8 @@ import (
|
||||
|
||||
func main() {
|
||||
// Initialize logger
|
||||
fmt.Println("ResolveSpec test server starting")
|
||||
logger.Init(true)
|
||||
logger.Info("ResolveSpec test server starting")
|
||||
|
||||
// Initialize database
|
||||
db, err := initDB()
|
||||
|
||||
4
go.mod
4
go.mod
@ -24,6 +24,10 @@ require (
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/uptrace/bunrouter v1.0.23 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
|
||||
9
go.sum
9
go.sum
@ -37,6 +37,15 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
@ -17,157 +14,3 @@ func parseTableName(fullTableName string) (schema, table string) {
|
||||
}
|
||||
return "", fullTableName
|
||||
}
|
||||
|
||||
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||
func GetPrimaryKeyName(model any) string {
|
||||
// Check if model implements PrimaryKeyNameProvider
|
||||
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
||||
return provider.GetIDName()
|
||||
}
|
||||
|
||||
// Try Bun tag first
|
||||
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
return getPrimaryKeyFromReflection(model, "gorm")
|
||||
}
|
||||
|
||||
// GetModelColumns extracts all column names from a model using reflection
|
||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||
func GetModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return columns
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Get column name using the same logic as primary key extraction
|
||||
columnName := getColumnNameFromField(field)
|
||||
|
||||
if columnName != "" {
|
||||
columns = append(columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// getColumnNameFromField extracts the column name from a struct field
|
||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||
func getColumnNameFromField(field reflect.StructField) string {
|
||||
// Try bun tag first
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Try gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to json tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract just the field name before any options
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: use field name in lowercase
|
||||
return strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
switch ormType {
|
||||
case "gorm":
|
||||
// Check for gorm tag with primaryKey
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") {
|
||||
// Try to extract column name from gorm tag
|
||||
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
case "bun":
|
||||
// Check for bun tag with pk flag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") {
|
||||
// Extract column name from bun tag
|
||||
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractColumnFromGormTag extracts the column name from a gorm tag
|
||||
// Example: "column:id;primaryKey" -> "id"
|
||||
func extractColumnFromGormTag(tag string) string {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if colName, found := strings.CutPrefix(part, "column:"); found {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractColumnFromBunTag extracts the column name from a bun tag
|
||||
// Example: "id,pk" -> "id"
|
||||
// Example: ",pk" -> "" (will fall back to json tag)
|
||||
func extractColumnFromBunTag(tag string) string {
|
||||
parts := strings.Split(tag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@ -72,11 +72,12 @@ type Response struct {
|
||||
}
|
||||
|
||||
type Metadata struct {
|
||||
Total int64 `json:"total"`
|
||||
Count int64 `json:"count"`
|
||||
Filtered int64 `json:"filtered"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Total int64 `json:"total"`
|
||||
Count int64 `json:"count"`
|
||||
Filtered int64 `json:"filtered"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
RowNumber *int64 `json:"row_number,omitempty"`
|
||||
}
|
||||
|
||||
type APIError struct {
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@ -70,3 +71,35 @@ func Debug(template string, args ...interface{}) {
|
||||
}
|
||||
Logger.Debugw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanicCallback(location string, cb func(err any)) {
|
||||
if err := recover(); err != nil {
|
||||
//callstack := debug.Stack()
|
||||
|
||||
if Logger != nil {
|
||||
Error("Panic in %s : %v", location, err)
|
||||
} else {
|
||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||
debug.PrintStack()
|
||||
}
|
||||
|
||||
//push to sentry
|
||||
// hub := sentry.CurrentHub()
|
||||
// if hub != nil {
|
||||
// evtID := hub.Recover(err)
|
||||
// if evtID != nil {
|
||||
// sentry.Flush(time.Second * 2)
|
||||
// }
|
||||
// }
|
||||
|
||||
if cb != nil {
|
||||
cb(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanic(location string) {
|
||||
CatchPanicCallback(location, nil)
|
||||
}
|
||||
|
||||
100
pkg/reflection/generic_model.go
Normal file
100
pkg/reflection/generic_model.go
Normal file
@ -0,0 +1,100 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
type ModelFieldDetail struct {
|
||||
Name string `json:"name"`
|
||||
DataType string `json:"datatype"`
|
||||
SQLName string `json:"sqlname"`
|
||||
SQLDataType string `json:"sqldatatype"`
|
||||
SQLKey string `json:"sqlkey"`
|
||||
Nullable bool `json:"nullable"`
|
||||
FieldValue reflect.Value `json:"-"`
|
||||
}
|
||||
|
||||
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Panic in GetModelColumnDetail : %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var lst []ModelFieldDetail
|
||||
lst = make([]ModelFieldDetail, 0)
|
||||
|
||||
if !record.IsValid() {
|
||||
return lst
|
||||
}
|
||||
if record.Kind() == reflect.Pointer || record.Kind() == reflect.Interface {
|
||||
record = record.Elem()
|
||||
}
|
||||
if record.Kind() != reflect.Struct {
|
||||
return lst
|
||||
}
|
||||
modeltype := record.Type()
|
||||
|
||||
for i := 0; i < modeltype.NumField(); i++ {
|
||||
fieldtype := modeltype.Field(i)
|
||||
gormdetail := fieldtype.Tag.Get("gorm")
|
||||
gormdetail = strings.Trim(gormdetail, " ")
|
||||
fielddetail := ModelFieldDetail{}
|
||||
fielddetail.FieldValue = record.Field(i)
|
||||
fielddetail.Name = fieldtype.Name
|
||||
fielddetail.DataType = fieldtype.Type.Name()
|
||||
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||
fielddetail.SQLDataType = fnFindKeyVal(gormdetail, "type:")
|
||||
if strings.Index(strings.ToLower(gormdetail), "identity") > 0 ||
|
||||
strings.Index(strings.ToLower(gormdetail), "primary_key") > 0 {
|
||||
fielddetail.SQLKey = "primary_key"
|
||||
} else if strings.Contains(strings.ToLower(gormdetail), "unique") {
|
||||
fielddetail.SQLKey = "unique"
|
||||
} else if strings.Contains(strings.ToLower(gormdetail), "uniqueindex") {
|
||||
fielddetail.SQLKey = "uniqueindex"
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(gormdetail), "nullable") {
|
||||
fielddetail.Nullable = true
|
||||
} else if strings.Contains(strings.ToLower(gormdetail), "null") {
|
||||
fielddetail.Nullable = true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(gormdetail), "not null") {
|
||||
fielddetail.Nullable = false
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(gormdetail), "foreignkey:") {
|
||||
fielddetail.SQLKey = "foreign_key"
|
||||
ik := strings.Index(strings.ToLower(gormdetail), "foreignkey:")
|
||||
ie := strings.Index(gormdetail[ik:], ";")
|
||||
if ie > ik && ik > 0 {
|
||||
fielddetail.SQLName = strings.ToLower(gormdetail)[ik+11 : ik+ie]
|
||||
//fmt.Printf("\r\nforeignkey: %v", fielddetail)
|
||||
}
|
||||
|
||||
}
|
||||
//";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||
|
||||
lst = append(lst, fielddetail)
|
||||
|
||||
}
|
||||
return lst
|
||||
}
|
||||
|
||||
func fnFindKeyVal(src, key string) string {
|
||||
icolStart := strings.Index(strings.ToLower(src), strings.ToLower(key))
|
||||
val := ""
|
||||
if icolStart >= 0 {
|
||||
val = src[icolStart+len(key):]
|
||||
icolend := strings.Index(val, ";")
|
||||
if icolend > 0 {
|
||||
val = val[:icolend]
|
||||
}
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
162
pkg/reflection/model_utils.go
Normal file
162
pkg/reflection/model_utils.go
Normal file
@ -0,0 +1,162 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||
func GetPrimaryKeyName(model any) string {
|
||||
// Check if model implements PrimaryKeyNameProvider
|
||||
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
||||
return provider.GetIDName()
|
||||
}
|
||||
|
||||
// Try Bun tag first
|
||||
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
return getPrimaryKeyFromReflection(model, "gorm")
|
||||
}
|
||||
|
||||
// GetModelColumns extracts all column names from a model using reflection
|
||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||
func GetModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return columns
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Get column name using the same logic as primary key extraction
|
||||
columnName := getColumnNameFromField(field)
|
||||
|
||||
if columnName != "" {
|
||||
columns = append(columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// getColumnNameFromField extracts the column name from a struct field
|
||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||
func getColumnNameFromField(field reflect.StructField) string {
|
||||
// Try bun tag first
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Try gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to json tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract just the field name before any options
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: use field name in lowercase
|
||||
return strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
switch ormType {
|
||||
case "gorm":
|
||||
// Check for gorm tag with primaryKey
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") {
|
||||
// Try to extract column name from gorm tag
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
case "bun":
|
||||
// Check for bun tag with pk flag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") {
|
||||
// Extract column name from bun tag
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractColumnFromGormTag extracts the column name from a gorm tag
|
||||
// Example: "column:id;primaryKey" -> "id"
|
||||
func ExtractColumnFromGormTag(tag string) string {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if colName, found := strings.CutPrefix(part, "column:"); found {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractColumnFromBunTag extracts the column name from a bun tag
|
||||
// Example: "id,pk" -> "id"
|
||||
// Example: ",pk" -> "" (will fall back to json tag)
|
||||
func ExtractColumnFromBunTag(tag string) string {
|
||||
parts := strings.Split(tag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
package database
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -137,9 +137,9 @@ func TestExtractColumnFromGormTag(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractColumnFromGormTag(tt.tag)
|
||||
result := ExtractColumnFromGormTag(tt.tag)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected)
|
||||
t.Errorf("ExtractColumnFromGormTag() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -170,9 +170,9 @@ func TestExtractColumnFromBunTag(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractColumnFromBunTag(tt.tag)
|
||||
result := ExtractColumnFromBunTag(tt.tag)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected)
|
||||
t.Errorf("ExtractColumnFromBunTag() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// CursorDirection defines pagination direction
|
||||
@ -85,7 +86,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err)
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@ -10,8 +10,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Handler handles API requests using database and model abstractions
|
||||
@ -343,10 +343,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
logger.Debug("Applying cursor pagination")
|
||||
|
||||
// Get primary key name
|
||||
pkName := database.GetPrimaryKeyName(model)
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Extract model columns for validation using the generic database function
|
||||
modelColumns := database.GetModelColumns(model)
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Build expand joins map (if needed in future)
|
||||
var expandJoins map[string]string
|
||||
@ -371,6 +371,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
}
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||
logger.Error("BeforeScan hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified query from hook context
|
||||
if modifiedQuery, ok := hookCtx.Query.(common.SelectQuery); ok {
|
||||
query = modifiedQuery
|
||||
}
|
||||
|
||||
// Execute query - modelPtr was already created earlier
|
||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||
logger.Error("Error executing query: %v", err)
|
||||
@ -387,6 +400,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
offset = *options.Offset
|
||||
}
|
||||
|
||||
// Set row numbers on each record if the model has a RowNumber field
|
||||
h.setRowNumbersOnRecords(modelPtr, offset)
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
Count: int64(common.Len(modelPtr)),
|
||||
@ -395,6 +411,23 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// Fetch row number for a specific record if requested
|
||||
if options.RequestOptions.FetchRowNumber != nil && *options.RequestOptions.FetchRowNumber != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
pkValue := *options.RequestOptions.FetchRowNumber
|
||||
|
||||
logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)
|
||||
|
||||
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to fetch row number: %v", err)
|
||||
// Don't fail the entire request, just log the warning
|
||||
} else {
|
||||
metadata.RowNumber = &rowNum
|
||||
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute AfterRead hooks
|
||||
hookCtx.Result = modelPtr
|
||||
hookCtx.Error = nil
|
||||
@ -466,6 +499,29 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
query := tx.NewInsert().Model(modelValue).Table(tableName)
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
batchHookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Model: model,
|
||||
Options: options,
|
||||
Data: modelValue,
|
||||
Writer: w,
|
||||
Query: query,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
||||
}
|
||||
|
||||
// Use potentially modified query from hook context
|
||||
if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok {
|
||||
query = modifiedQuery
|
||||
}
|
||||
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("failed to insert record: %w", err)
|
||||
}
|
||||
@ -508,6 +564,21 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
query := h.db.NewInsert().Model(modelValue).Table(tableName)
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
hookCtx.Data = modelValue
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||
logger.Error("BeforeScan hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified query from hook context
|
||||
if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok {
|
||||
query = modifiedQuery
|
||||
}
|
||||
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
logger.Error("Error creating record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||
@ -593,6 +664,19 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
return
|
||||
}
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||
logger.Error("BeforeScan hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified query from hook context
|
||||
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
|
||||
query = modifiedQuery
|
||||
}
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Error updating record: %v", err)
|
||||
@ -658,6 +742,19 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
query = query.Where("id = ?", id)
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||
logger.Error("BeforeScan hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified query from hook context
|
||||
if modifiedQuery, ok := hookCtx.Query.(common.DeleteQuery); ok {
|
||||
query = modifiedQuery
|
||||
}
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Error deleting record: %v", err)
|
||||
@ -999,6 +1096,191 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
|
||||
w.WriteJSON(response)
|
||||
}
|
||||
|
||||
// FetchRowNumber calculates the row number of a specific record based on sorting and filtering
|
||||
// Returns the 1-based row number of the record with the given primary key value
|
||||
func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName string, pkValue string, options ExtendedRequestOptions, model any) (int64, error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Panic during FetchRowNumber: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Build the sort order SQL
|
||||
sortSQL := ""
|
||||
if len(options.Sort) > 0 {
|
||||
sortParts := make([]string, 0, len(options.Sort))
|
||||
for _, sort := range options.Sort {
|
||||
direction := "ASC"
|
||||
if strings.ToLower(sort.Direction) == "desc" {
|
||||
direction = "DESC"
|
||||
}
|
||||
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
||||
}
|
||||
sortSQL = strings.Join(sortParts, ", ")
|
||||
} else {
|
||||
// Default sort by primary key
|
||||
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
|
||||
}
|
||||
|
||||
// Build WHERE clauses from filters
|
||||
whereClauses := make([]string, 0)
|
||||
for i := range options.Filters {
|
||||
filter := &options.Filters[i]
|
||||
whereClause := h.buildFilterSQL(filter, tableName)
|
||||
if whereClause != "" {
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
|
||||
}
|
||||
}
|
||||
|
||||
// Combine WHERE clauses
|
||||
whereSQL := ""
|
||||
if len(whereClauses) > 0 {
|
||||
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
|
||||
}
|
||||
|
||||
// Add custom SQL WHERE if provided
|
||||
if options.CustomSQLWhere != "" {
|
||||
if whereSQL == "" {
|
||||
whereSQL = "WHERE " + options.CustomSQLWhere
|
||||
} else {
|
||||
whereSQL += " AND (" + options.CustomSQLWhere + ")"
|
||||
}
|
||||
}
|
||||
|
||||
// Build JOIN clauses from Expand options
|
||||
joinSQL := ""
|
||||
if len(options.Expand) > 0 {
|
||||
joinParts := make([]string, 0, len(options.Expand))
|
||||
for _, expand := range options.Expand {
|
||||
// Note: This is a simplified join - in production you'd need proper FK mapping
|
||||
joinParts = append(joinParts, fmt.Sprintf("LEFT JOIN %s ON %s.%s_id = %s.id",
|
||||
expand.Relation, tableName, expand.Relation, expand.Relation))
|
||||
}
|
||||
joinSQL = strings.Join(joinParts, "\n")
|
||||
}
|
||||
|
||||
// Build the final query with parameterized PK value
|
||||
queryStr := fmt.Sprintf(`
|
||||
SELECT search.rn
|
||||
FROM (
|
||||
SELECT %[1]s.%[2]s,
|
||||
ROW_NUMBER() OVER(ORDER BY %[3]s) AS rn
|
||||
FROM %[1]s
|
||||
%[5]s
|
||||
%[4]s
|
||||
) search
|
||||
WHERE search.%[2]s = ?
|
||||
`,
|
||||
tableName, // [1] - table name
|
||||
pkName, // [2] - primary key column name
|
||||
sortSQL, // [3] - sort order SQL
|
||||
whereSQL, // [4] - WHERE clause
|
||||
joinSQL, // [5] - JOIN clauses
|
||||
)
|
||||
|
||||
logger.Debug("FetchRowNumber query: %s, pkValue: %s", queryStr, pkValue)
|
||||
|
||||
// Execute the raw query with parameterized PK value
|
||||
var result []struct {
|
||||
RN int64 `bun:"rn"`
|
||||
}
|
||||
err := h.db.Query(ctx, &result, queryStr, pkValue)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to fetch row number: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return 0, fmt.Errorf("no row found for primary key %s", pkValue)
|
||||
}
|
||||
|
||||
return result[0].RN, nil
|
||||
}
|
||||
|
||||
// buildFilterSQL converts a filter to SQL WHERE clause string
|
||||
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
|
||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
|
||||
switch strings.ToLower(filter.Operator) {
|
||||
case "eq", "equals":
|
||||
return fmt.Sprintf("%s = '%v'", qualifiedColumn, filter.Value)
|
||||
case "neq", "not_equals", "ne":
|
||||
return fmt.Sprintf("%s != '%v'", qualifiedColumn, filter.Value)
|
||||
case "gt", "greater_than":
|
||||
return fmt.Sprintf("%s > '%v'", qualifiedColumn, filter.Value)
|
||||
case "gte", "greater_than_equals", "ge":
|
||||
return fmt.Sprintf("%s >= '%v'", qualifiedColumn, filter.Value)
|
||||
case "lt", "less_than":
|
||||
return fmt.Sprintf("%s < '%v'", qualifiedColumn, filter.Value)
|
||||
case "lte", "less_than_equals", "le":
|
||||
return fmt.Sprintf("%s <= '%v'", qualifiedColumn, filter.Value)
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE '%v'", qualifiedColumn, filter.Value)
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE '%v'", qualifiedColumn, filter.Value)
|
||||
case "in":
|
||||
if values, ok := filter.Value.([]any); ok {
|
||||
valueStrs := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
valueStrs[i] = fmt.Sprintf("'%v'", v)
|
||||
}
|
||||
return fmt.Sprintf("%s IN (%s)", qualifiedColumn, strings.Join(valueStrs, ", "))
|
||||
}
|
||||
return ""
|
||||
case "is_null", "isnull":
|
||||
return fmt.Sprintf("(%s IS NULL OR %s = '')", qualifiedColumn, qualifiedColumn)
|
||||
case "is_not_null", "isnotnull":
|
||||
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", qualifiedColumn, qualifiedColumn)
|
||||
default:
|
||||
logger.Warn("Unknown filter operator in buildFilterSQL: %s", filter.Operator)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||
// The row number is calculated as offset + index + 1 (1-based)
|
||||
func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
||||
// Get the reflect value of the records
|
||||
recordsValue := reflect.ValueOf(records)
|
||||
if recordsValue.Kind() == reflect.Ptr {
|
||||
recordsValue = recordsValue.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a slice
|
||||
if recordsValue.Kind() != reflect.Slice {
|
||||
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
// Iterate through each record
|
||||
for i := 0; i < recordsValue.Len(); i++ {
|
||||
record := recordsValue.Index(i)
|
||||
|
||||
// Dereference if it's a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.IsNil() {
|
||||
continue
|
||||
}
|
||||
record = record.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a struct
|
||||
if record.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to find and set the RowNumber field
|
||||
rowNumberField := record.FieldByName("RowNumber")
|
||||
if rowNumberField.IsValid() && rowNumberField.CanSet() {
|
||||
// Check if the field is of type int64
|
||||
if rowNumberField.Kind() == reflect.Int64 {
|
||||
rowNum := int64(offset + i + 1)
|
||||
rowNumberField.SetInt(rowNum)
|
||||
logger.Debug("Set RowNumber=%d on record %d", rowNum, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
||||
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
||||
filtered := options
|
||||
|
||||
@ -27,6 +27,9 @@ const (
|
||||
// Delete operation hooks
|
||||
BeforeDelete HookType = "before_delete"
|
||||
AfterDelete HookType = "after_delete"
|
||||
|
||||
// Scan/Execute operation hooks
|
||||
BeforeScan HookType = "before_scan"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
@ -46,6 +49,10 @@ type HookContext struct {
|
||||
Error error // For after hooks
|
||||
QueryFilter string // For read operations
|
||||
|
||||
// Query chain - allows hooks to modify the query before execution
|
||||
// Can be SelectQuery, InsertQuery, UpdateQuery, or DeleteQuery
|
||||
Query interface{}
|
||||
|
||||
// Response writer - allows hooks to modify response
|
||||
Writer common.ResponseWriter
|
||||
}
|
||||
|
||||
203
pkg/restheadspec/rownumber_test.go
Normal file
203
pkg/restheadspec/rownumber_test.go
Normal file
@ -0,0 +1,203 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestModel represents a typical model with RowNumber field (like DBAdhocBuffer)
|
||||
type TestModel struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name" bun:"name"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
records any
|
||||
offset int
|
||||
expected []int64
|
||||
}{
|
||||
{
|
||||
name: "Set row numbers on slice of pointers",
|
||||
records: []*TestModel{
|
||||
{ID: 1, Name: "First"},
|
||||
{ID: 2, Name: "Second"},
|
||||
{ID: 3, Name: "Third"},
|
||||
},
|
||||
offset: 0,
|
||||
expected: []int64{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "Set row numbers with offset",
|
||||
records: []*TestModel{
|
||||
{ID: 11, Name: "Eleventh"},
|
||||
{ID: 12, Name: "Twelfth"},
|
||||
},
|
||||
offset: 10,
|
||||
expected: []int64{11, 12},
|
||||
},
|
||||
{
|
||||
name: "Set row numbers on slice of values",
|
||||
records: []TestModel{
|
||||
{ID: 1, Name: "First"},
|
||||
{ID: 2, Name: "Second"},
|
||||
},
|
||||
offset: 5,
|
||||
expected: []int64{6, 7},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler.setRowNumbersOnRecords(tt.records, tt.offset)
|
||||
|
||||
// Verify row numbers were set correctly
|
||||
switch records := tt.records.(type) {
|
||||
case []*TestModel:
|
||||
assert.Equal(t, len(tt.expected), len(records))
|
||||
for i, record := range records {
|
||||
assert.Equal(t, tt.expected[i], record.RowNumber,
|
||||
"Record %d should have RowNumber=%d", i, tt.expected[i])
|
||||
}
|
||||
case []TestModel:
|
||||
assert.Equal(t, len(tt.expected), len(records))
|
||||
for i, record := range records {
|
||||
assert.Equal(t, tt.expected[i], record.RowNumber,
|
||||
"Record %d should have RowNumber=%d", i, tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords_NoRowNumberField(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Model without RowNumber field
|
||||
type SimpleModel struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
records := []*SimpleModel{
|
||||
{ID: 1, Name: "First"},
|
||||
{ID: 2, Name: "Second"},
|
||||
}
|
||||
|
||||
// Should not panic when model doesn't have RowNumber field
|
||||
assert.NotPanics(t, func() {
|
||||
handler.setRowNumbersOnRecords(records, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords_NilRecords(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
records := []*TestModel{
|
||||
{ID: 1, Name: "First"},
|
||||
nil, // Nil record
|
||||
{ID: 3, Name: "Third"},
|
||||
}
|
||||
|
||||
// Should not panic with nil records
|
||||
assert.NotPanics(t, func() {
|
||||
handler.setRowNumbersOnRecords(records, 0)
|
||||
})
|
||||
|
||||
// Verify non-nil records were set
|
||||
assert.Equal(t, int64(1), records[0].RowNumber)
|
||||
assert.Equal(t, int64(3), records[2].RowNumber)
|
||||
}
|
||||
|
||||
// DBAdhocBuffer simulates the actual DBAdhocBuffer from db package
|
||||
type DBAdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:"-"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||
}
|
||||
|
||||
// ModelWithEmbeddedBuffer simulates a real model like ModelPublicConsultant
|
||||
type ModelWithEmbeddedBuffer struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name" bun:"name"`
|
||||
|
||||
DBAdhocBuffer `json:",omitempty"` // Embedded struct containing RowNumber
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords_EmbeddedBuffer(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Test with embedded DBAdhocBuffer (like real models)
|
||||
records := []*ModelWithEmbeddedBuffer{
|
||||
{ID: 1, Name: "First"},
|
||||
{ID: 2, Name: "Second"},
|
||||
{ID: 3, Name: "Third"},
|
||||
}
|
||||
|
||||
handler.setRowNumbersOnRecords(records, 10)
|
||||
|
||||
// Verify row numbers were set on embedded field
|
||||
assert.Equal(t, int64(11), records[0].RowNumber, "First record should have RowNumber=11")
|
||||
assert.Equal(t, int64(12), records[1].RowNumber, "Second record should have RowNumber=12")
|
||||
assert.Equal(t, int64(13), records[2].RowNumber, "Third record should have RowNumber=13")
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords_EmbeddedBuffer_SliceOfValues(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Test with slice of values (not pointers)
|
||||
records := []ModelWithEmbeddedBuffer{
|
||||
{ID: 1, Name: "First"},
|
||||
{ID: 2, Name: "Second"},
|
||||
}
|
||||
|
||||
handler.setRowNumbersOnRecords(records, 0)
|
||||
|
||||
// Verify row numbers were set on embedded field
|
||||
assert.Equal(t, int64(1), records[0].RowNumber, "First record should have RowNumber=1")
|
||||
assert.Equal(t, int64(2), records[1].RowNumber, "Second record should have RowNumber=2")
|
||||
}
|
||||
|
||||
// Simulate the exact structure from user's code
|
||||
type MockDBAdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:"-"`
|
||||
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:"-"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||
Request string `json:"_request,omitempty" gorm:"-" bun:"-"`
|
||||
}
|
||||
|
||||
// Exact structure like ModelPublicConsultant
|
||||
type ModelPublicConsultant struct {
|
||||
Consultant string `json:"consultant" bun:"consultant,type:citext,pk"`
|
||||
Ridconsultant int32 `json:"rid_consultant" bun:"rid_consultant,type:integer,pk"`
|
||||
Updatecnt int64 `json:"updatecnt" bun:"updatecnt,type:integer,default:0"`
|
||||
|
||||
MockDBAdhocBuffer `json:",omitempty"` // Embedded - RowNumber is here!
|
||||
}
|
||||
|
||||
func TestSetRowNumbersOnRecords_RealModelStructure(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Test with exact structure from user's ModelPublicConsultant
|
||||
records := []*ModelPublicConsultant{
|
||||
{Consultant: "John Doe", Ridconsultant: 1, Updatecnt: 0},
|
||||
{Consultant: "Jane Smith", Ridconsultant: 2, Updatecnt: 0},
|
||||
{Consultant: "Bob Johnson", Ridconsultant: 3, Updatecnt: 0},
|
||||
}
|
||||
|
||||
handler.setRowNumbersOnRecords(records, 100)
|
||||
|
||||
// Verify row numbers were set correctly in the embedded DBAdhocBuffer
|
||||
assert.Equal(t, int64(101), records[0].RowNumber, "First consultant should have RowNumber=101")
|
||||
assert.Equal(t, int64(102), records[1].RowNumber, "Second consultant should have RowNumber=102")
|
||||
assert.Equal(t, int64(103), records[2].RowNumber, "Third consultant should have RowNumber=103")
|
||||
|
||||
t.Logf("✓ RowNumber correctly set in embedded MockDBAdhocBuffer")
|
||||
t.Logf(" Record 0: Consultant=%s, RowNumber=%d", records[0].Consultant, records[0].RowNumber)
|
||||
t.Logf(" Record 1: Consultant=%s, RowNumber=%d", records[1].Consultant, records[1].RowNumber)
|
||||
t.Logf(" Record 2: Consultant=%s, RowNumber=%d", records[2].Consultant, records[2].RowNumber)
|
||||
}
|
||||
662
pkg/security/CALLBACKS_GUIDE.md
Normal file
662
pkg/security/CALLBACKS_GUIDE.md
Normal file
@ -0,0 +1,662 @@
|
||||
# Security Provider Callbacks Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The ResolveSpec security provider uses a **callback-based architecture** that requires you to implement three functions:
|
||||
|
||||
1. **AuthenticateCallback** - Extract user credentials from HTTP requests
|
||||
2. **LoadColumnSecurityCallback** - Load column security rules for masking/hiding
|
||||
3. **LoadRowSecurityCallback** - Load row security filters (WHERE clauses)
|
||||
|
||||
This design allows you to integrate the security provider with **any** authentication system and database schema.
|
||||
|
||||
---
|
||||
|
||||
## Why Callbacks?
|
||||
|
||||
The callback-based design provides:
|
||||
|
||||
✅ **Flexibility** - Works with any auth system (JWT, session, OAuth, custom)
|
||||
✅ **Database Agnostic** - No assumptions about your security table schema
|
||||
✅ **Testability** - Easy to mock for unit tests
|
||||
✅ **Extensibility** - Add custom logic without modifying core code
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Step 1: Implement the Three Callbacks
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// 1. Authentication: Extract user from request
|
||||
func myAuthFunction(r *http.Request) (userID int, roles string, err error) {
|
||||
// Your auth logic here (JWT, session, header, etc.)
|
||||
token := r.Header.Get("Authorization")
|
||||
userID, roles, err = validateToken(token)
|
||||
return userID, roles, err
|
||||
}
|
||||
|
||||
// 2. Column Security: Load column masking rules
|
||||
func myLoadColumnSecurity(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
|
||||
// Your database query or config lookup here
|
||||
return loadColumnRulesFromDatabase(userID, schema, tablename)
|
||||
}
|
||||
|
||||
// 3. Row Security: Load row filtering rules
|
||||
func myLoadRowSecurity(userID int, schema, tablename string) (security.RowSecurity, error) {
|
||||
// Your database query or config lookup here
|
||||
return loadRowRulesFromDatabase(userID, schema, tablename)
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Configure the Callbacks
|
||||
|
||||
```go
|
||||
func main() {
|
||||
db := setupDatabase()
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// Configure callbacks BEFORE SetupSecurityProvider
|
||||
security.GlobalSecurity.AuthenticateCallback = myAuthFunction
|
||||
security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurity
|
||||
security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurity
|
||||
|
||||
// Setup security provider (validates callbacks are set)
|
||||
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
|
||||
log.Fatal(err) // Fails if callbacks not configured
|
||||
}
|
||||
|
||||
// Apply middleware
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler)
|
||||
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Callback 1: AuthenticateCallback
|
||||
|
||||
### Function Signature
|
||||
|
||||
```go
|
||||
func(r *http.Request) (userID int, roles string, err error)
|
||||
```
|
||||
|
||||
### Parameters
|
||||
- `r *http.Request` - The incoming HTTP request
|
||||
|
||||
### Returns
|
||||
- `userID int` - The authenticated user's ID
|
||||
- `roles string` - User's roles (comma-separated, e.g., "admin,manager")
|
||||
- `err error` - Return error to reject the request (HTTP 401)
|
||||
|
||||
### Example Implementations
|
||||
|
||||
#### Simple Header-Based Auth
|
||||
```go
|
||||
func authenticateFromHeader(r *http.Request) (int, string, error) {
|
||||
userIDStr := r.Header.Get("X-User-ID")
|
||||
if userIDStr == "" {
|
||||
return 0, "", fmt.Errorf("X-User-ID header required")
|
||||
}
|
||||
|
||||
userID, err := strconv.Atoi(userIDStr)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("invalid user ID")
|
||||
}
|
||||
|
||||
roles := r.Header.Get("X-User-Roles") // Optional
|
||||
return userID, roles, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### JWT Token Auth
|
||||
```go
|
||||
import "github.com/golang-jwt/jwt/v5"
|
||||
|
||||
func authenticateFromJWT(r *http.Request) (int, string, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(os.Getenv("JWT_SECRET")), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return 0, "", fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
userID := int(claims["user_id"].(float64))
|
||||
roles := claims["roles"].(string)
|
||||
|
||||
return userID, roles, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### Session Cookie Auth
|
||||
```go
|
||||
func authenticateFromSession(r *http.Request) (int, string, error) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("no session cookie")
|
||||
}
|
||||
|
||||
session, err := sessionStore.Get(cookie.Value)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("invalid session")
|
||||
}
|
||||
|
||||
return session.UserID, session.Roles, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Callback 2: LoadColumnSecurityCallback
|
||||
|
||||
### Function Signature
|
||||
|
||||
```go
|
||||
func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error)
|
||||
```
|
||||
|
||||
### Parameters
|
||||
- `pUserID int` - The authenticated user's ID
|
||||
- `pSchema string` - Database schema (e.g., "public")
|
||||
- `pTablename string` - Table name (e.g., "employees")
|
||||
|
||||
### Returns
|
||||
- `[]ColumnSecurity` - List of column security rules
|
||||
- `error` - Return error if loading fails
|
||||
|
||||
### ColumnSecurity Structure
|
||||
|
||||
```go
|
||||
type ColumnSecurity struct {
|
||||
Schema string // "public"
|
||||
Tablename string // "employees"
|
||||
Path []string // ["ssn"] or ["address", "street"]
|
||||
Accesstype string // "mask" or "hide"
|
||||
|
||||
// Masking configuration (for Accesstype = "mask")
|
||||
MaskStart int // Mask first N characters
|
||||
MaskEnd int // Mask last N characters
|
||||
MaskInvert bool // true = mask middle, false = mask edges
|
||||
MaskChar string // Character to use for masking (default "*")
|
||||
|
||||
// Optional fields
|
||||
ExtraFilters map[string]string
|
||||
Control string
|
||||
ID int
|
||||
UserID int
|
||||
}
|
||||
```
|
||||
|
||||
### Example Implementations
|
||||
|
||||
#### Load from Database
|
||||
```go
|
||||
func loadColumnSecurityFromDB(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
|
||||
var rules []security.ColumnSecurity
|
||||
|
||||
query := `
|
||||
SELECT control, accesstype, jsonvalue
|
||||
FROM core.secacces
|
||||
WHERE rid_hub IN (
|
||||
SELECT rid_hub_parent FROM core.hub_link
|
||||
WHERE rid_hub_child = ? AND parent_hubtype = 'secgroup'
|
||||
)
|
||||
AND control ILIKE ?
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, userID, fmt.Sprintf("%s.%s%%", schema, tablename))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var control, accesstype, jsonValue string
|
||||
rows.Scan(&control, &accesstype, &jsonValue)
|
||||
|
||||
// Parse control: "schema.table.column"
|
||||
parts := strings.Split(control, ".")
|
||||
if len(parts) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
rule := security.ColumnSecurity{
|
||||
Schema: schema,
|
||||
Tablename: tablename,
|
||||
Path: parts[2:],
|
||||
Accesstype: accesstype,
|
||||
}
|
||||
|
||||
// Parse JSON configuration
|
||||
var config map[string]interface{}
|
||||
json.Unmarshal([]byte(jsonValue), &config)
|
||||
if start, ok := config["start"].(float64); ok {
|
||||
rule.MaskStart = int(start)
|
||||
}
|
||||
if end, ok := config["end"].(float64); ok {
|
||||
rule.MaskEnd = int(end)
|
||||
}
|
||||
if char, ok := config["char"].(string); ok {
|
||||
rule.MaskChar = char
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### Load from Static Config
|
||||
```go
|
||||
func loadColumnSecurityFromConfig(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
|
||||
// Define security rules in code
|
||||
allRules := map[string][]security.ColumnSecurity{
|
||||
"public.employees": {
|
||||
{
|
||||
Schema: "public",
|
||||
Tablename: "employees",
|
||||
Path: []string{"ssn"},
|
||||
Accesstype: "mask",
|
||||
MaskStart: 5,
|
||||
MaskChar: "*",
|
||||
},
|
||||
{
|
||||
Schema: "public",
|
||||
Tablename: "employees",
|
||||
Path: []string{"salary"},
|
||||
Accesstype: "hide",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s.%s", schema, tablename)
|
||||
rules, ok := allRules[key]
|
||||
if !ok {
|
||||
return []security.ColumnSecurity{}, nil // No rules
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Column Security Examples
|
||||
|
||||
**Mask SSN (show last 4 digits):**
|
||||
```go
|
||||
ColumnSecurity{
|
||||
Path: []string{"ssn"},
|
||||
Accesstype: "mask",
|
||||
MaskStart: 5, // Mask first 5 characters
|
||||
MaskEnd: 0, // Keep last 4 visible
|
||||
MaskChar: "*",
|
||||
}
|
||||
// Result: "123-45-6789" → "*****6789"
|
||||
```
|
||||
|
||||
**Hide entire field:**
|
||||
```go
|
||||
ColumnSecurity{
|
||||
Path: []string{"salary"},
|
||||
Accesstype: "hide",
|
||||
}
|
||||
// Result: salary field returns 0 or empty
|
||||
```
|
||||
|
||||
**Mask credit card (show last 4 digits):**
|
||||
```go
|
||||
ColumnSecurity{
|
||||
Path: []string{"credit_card"},
|
||||
Accesstype: "mask",
|
||||
MaskStart: 12,
|
||||
MaskChar: "*",
|
||||
}
|
||||
// Result: "1234-5678-9012-3456" → "************3456"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Callback 3: LoadRowSecurityCallback
|
||||
|
||||
### Function Signature
|
||||
|
||||
```go
|
||||
func(pUserID int, pSchema, pTablename string) (RowSecurity, error)
|
||||
```
|
||||
|
||||
### Parameters
|
||||
- `pUserID int` - The authenticated user's ID
|
||||
- `pSchema string` - Database schema
|
||||
- `pTablename string` - Table name
|
||||
|
||||
### Returns
|
||||
- `RowSecurity` - Row security configuration
|
||||
- `error` - Return error if loading fails
|
||||
|
||||
### RowSecurity Structure
|
||||
|
||||
```go
|
||||
type RowSecurity struct {
|
||||
Schema string // "public"
|
||||
Tablename string // "orders"
|
||||
UserID int // Current user ID
|
||||
Template string // WHERE clause template (e.g., "user_id = {UserID}")
|
||||
HasBlock bool // If true, block ALL access to this table
|
||||
}
|
||||
```
|
||||
|
||||
### Template Variables
|
||||
|
||||
You can use these placeholders in the `Template` string:
|
||||
- `{UserID}` - Current user's ID
|
||||
- `{PrimaryKeyName}` - Primary key column name
|
||||
- `{TableName}` - Table name
|
||||
- `{SchemaName}` - Schema name
|
||||
|
||||
### Example Implementations
|
||||
|
||||
#### Load from Database Function
|
||||
```go
|
||||
func loadRowSecurityFromDB(userID int, schema, tablename string) (security.RowSecurity, error) {
|
||||
var record security.RowSecurity
|
||||
|
||||
query := `
|
||||
SELECT p_template, p_block
|
||||
FROM core.api_sec_rowtemplate(?, ?, ?)
|
||||
`
|
||||
|
||||
row := db.QueryRow(query, schema, tablename, userID)
|
||||
err := row.Scan(&record.Template, &record.HasBlock)
|
||||
if err != nil {
|
||||
return security.RowSecurity{}, err
|
||||
}
|
||||
|
||||
record.Schema = schema
|
||||
record.Tablename = tablename
|
||||
record.UserID = userID
|
||||
|
||||
return record, nil
|
||||
}
|
||||
```
|
||||
|
||||
#### Load from Static Config
|
||||
```go
|
||||
func loadRowSecurityFromConfig(userID int, schema, tablename string) (security.RowSecurity, error) {
|
||||
key := fmt.Sprintf("%s.%s", schema, tablename)
|
||||
|
||||
// Define templates for each table
|
||||
templates := map[string]string{
|
||||
"public.orders": "user_id = {UserID}",
|
||||
"public.documents": "user_id = {UserID} OR is_public = true",
|
||||
}
|
||||
|
||||
// Define blocked tables
|
||||
blocked := map[string]bool{
|
||||
"public.admin_logs": true,
|
||||
}
|
||||
|
||||
if blocked[key] {
|
||||
return security.RowSecurity{
|
||||
Schema: schema,
|
||||
Tablename: tablename,
|
||||
UserID: userID,
|
||||
HasBlock: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
template, ok := templates[key]
|
||||
if !ok {
|
||||
// No row security - allow all rows
|
||||
return security.RowSecurity{
|
||||
Schema: schema,
|
||||
Tablename: tablename,
|
||||
UserID: userID,
|
||||
Template: "",
|
||||
HasBlock: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return security.RowSecurity{
|
||||
Schema: schema,
|
||||
Tablename: tablename,
|
||||
UserID: userID,
|
||||
Template: template,
|
||||
HasBlock: false,
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Row Security Examples
|
||||
|
||||
**Users see only their own records:**
|
||||
```go
|
||||
RowSecurity{
|
||||
Template: "user_id = {UserID}",
|
||||
}
|
||||
// Query: SELECT * FROM orders WHERE user_id = 123
|
||||
```
|
||||
|
||||
**Users see their records OR public records:**
|
||||
```go
|
||||
RowSecurity{
|
||||
Template: "user_id = {UserID} OR is_public = true",
|
||||
}
|
||||
```
|
||||
|
||||
**Complex filter with subquery:**
|
||||
```go
|
||||
RowSecurity{
|
||||
Template: "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})",
|
||||
}
|
||||
```
|
||||
|
||||
**Block all access:**
|
||||
```go
|
||||
RowSecurity{
|
||||
HasBlock: true,
|
||||
}
|
||||
// All queries to this table will be rejected
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Integration Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db := setupDatabase()
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "orders", Order{})
|
||||
|
||||
// ===== CONFIGURE CALLBACKS =====
|
||||
security.GlobalSecurity.AuthenticateCallback = authenticateUser
|
||||
security.GlobalSecurity.LoadColumnSecurityCallback = loadColumnSec
|
||||
security.GlobalSecurity.LoadRowSecurityCallback = loadRowSec
|
||||
|
||||
// ===== SETUP SECURITY =====
|
||||
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
|
||||
log.Fatal("Security setup failed:", err)
|
||||
}
|
||||
|
||||
// ===== SETUP ROUTES =====
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler)
|
||||
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||
|
||||
log.Println("Server starting on :8080")
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Callback implementations
|
||||
func authenticateUser(r *http.Request) (int, string, error) {
|
||||
userIDStr := r.Header.Get("X-User-ID")
|
||||
if userIDStr == "" {
|
||||
return 0, "", fmt.Errorf("authentication required")
|
||||
}
|
||||
userID, err := strconv.Atoi(userIDStr)
|
||||
return userID, "", err
|
||||
}
|
||||
|
||||
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
// Your implementation here
|
||||
return []security.ColumnSecurity{}, nil
|
||||
}
|
||||
|
||||
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||
return security.RowSecurity{
|
||||
Schema: schema,
|
||||
Tablename: table,
|
||||
UserID: userID,
|
||||
Template: "user_id = " + strconv.Itoa(userID),
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing Your Callbacks
|
||||
|
||||
### Unit Test Example
|
||||
|
||||
```go
|
||||
func TestAuthCallback(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/orders", nil)
|
||||
req.Header.Set("X-User-ID", "123")
|
||||
|
||||
userID, roles, err := myAuthFunction(req)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 123, userID)
|
||||
}
|
||||
|
||||
func TestColumnSecurityCallback(t *testing.T) {
|
||||
rules, err := myLoadColumnSecurity(123, "public", "employees")
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.Greater(t, len(rules), 0)
|
||||
assert.Equal(t, "mask", rules[0].Accesstype)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Role-Based Security
|
||||
|
||||
```go
|
||||
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
roles := getUserRoles(userID)
|
||||
|
||||
if contains(roles, "admin") {
|
||||
// Admins see everything
|
||||
return []security.ColumnSecurity{}, nil
|
||||
}
|
||||
|
||||
// Non-admins have restrictions
|
||||
return []security.ColumnSecurity{
|
||||
{Path: []string{"ssn"}, Accesstype: "mask"},
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Pattern 2: Tenant Isolation
|
||||
|
||||
```go
|
||||
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||
tenantID := getUserTenant(userID)
|
||||
|
||||
return security.RowSecurity{
|
||||
Template: fmt.Sprintf("tenant_id = %d", tenantID),
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Pattern 3: Caching Security Rules
|
||||
|
||||
```go
|
||||
var securityCache = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
cacheKey := fmt.Sprintf("%d:%s.%s", userID, schema, table)
|
||||
|
||||
if cached, found := securityCache.Get(cacheKey); found {
|
||||
return cached.([]security.ColumnSecurity), nil
|
||||
}
|
||||
|
||||
rules := loadFromDatabase(userID, schema, table)
|
||||
securityCache.Set(cacheKey, rules, cache.DefaultExpiration)
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Error: "AuthenticateCallback not set"
|
||||
**Solution:** Configure all three callbacks before calling `SetupSecurityProvider`:
|
||||
```go
|
||||
security.GlobalSecurity.AuthenticateCallback = myAuthFunc
|
||||
security.GlobalSecurity.LoadColumnSecurityCallback = myColSecFunc
|
||||
security.GlobalSecurity.LoadRowSecurityCallback = myRowSecFunc
|
||||
```
|
||||
|
||||
### Error: "Authentication failed"
|
||||
**Solution:** Check your `AuthenticateCallback` implementation. Ensure it returns valid user ID or proper error.
|
||||
|
||||
### Security rules not applying
|
||||
**Solution:**
|
||||
1. Check callbacks are returning data
|
||||
2. Enable debug logging
|
||||
3. Verify database queries return results
|
||||
4. Check user has security groups assigned
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. ✅ Implement the three callbacks for your system
|
||||
2. ✅ Configure `GlobalSecurity` with your callbacks
|
||||
3. ✅ Call `SetupSecurityProvider`
|
||||
4. ✅ Test with different users and verify isolation
|
||||
5. ✅ Review `callbacks_example.go` for more examples
|
||||
|
||||
For complete working examples, see:
|
||||
- `pkg/security/callbacks_example.go` - 7 example implementations
|
||||
- `examples/secure_server/main.go` - Full server example
|
||||
- `pkg/security/README.md` - Comprehensive documentation
|
||||
402
pkg/security/QUICK_REFERENCE.md
Normal file
402
pkg/security/QUICK_REFERENCE.md
Normal file
@ -0,0 +1,402 @@
|
||||
# Security Provider - Quick Reference
|
||||
|
||||
## 3-Step Setup
|
||||
|
||||
```go
|
||||
// Step 1: Implement callbacks
|
||||
func myAuth(r *http.Request) (int, string, error) { /* ... */ }
|
||||
func myColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { /* ... */ }
|
||||
func myRowSec(userID int, schema, table string) (security.RowSecurity, error) { /* ... */ }
|
||||
|
||||
// Step 2: Configure callbacks
|
||||
security.GlobalSecurity.AuthenticateCallback = myAuth
|
||||
security.GlobalSecurity.LoadColumnSecurityCallback = myColSec
|
||||
security.GlobalSecurity.LoadRowSecurityCallback = myRowSec
|
||||
|
||||
// Step 3: Setup and apply middleware
|
||||
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Callback Signatures
|
||||
|
||||
```go
|
||||
// 1. Authentication
|
||||
func(r *http.Request) (userID int, roles string, err error)
|
||||
|
||||
// 2. Column Security
|
||||
func(userID int, schema, tablename string) ([]ColumnSecurity, error)
|
||||
|
||||
// 3. Row Security
|
||||
func(userID int, schema, tablename string) (RowSecurity, error)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ColumnSecurity Structure
|
||||
|
||||
```go
|
||||
security.ColumnSecurity{
|
||||
Path: []string{"column_name"}, // ["ssn"] or ["address", "street"]
|
||||
Accesstype: "mask", // "mask" or "hide"
|
||||
MaskStart: 5, // Mask first N chars
|
||||
MaskEnd: 0, // Mask last N chars
|
||||
MaskChar: "*", // Masking character
|
||||
MaskInvert: false, // true = mask middle
|
||||
}
|
||||
```
|
||||
|
||||
### Common Examples
|
||||
|
||||
```go
|
||||
// Hide entire field
|
||||
{Path: []string{"salary"}, Accesstype: "hide"}
|
||||
|
||||
// Mask SSN (show last 4)
|
||||
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}
|
||||
|
||||
// Mask credit card (show last 4)
|
||||
{Path: []string{"credit_card"}, Accesstype: "mask", MaskStart: 12}
|
||||
|
||||
// Mask email (j***@example.com)
|
||||
{Path: []string{"email"}, Accesstype: "mask", MaskStart: 1, MaskEnd: 0}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## RowSecurity Structure
|
||||
|
||||
```go
|
||||
security.RowSecurity{
|
||||
Schema: "public",
|
||||
Tablename: "orders",
|
||||
UserID: 123,
|
||||
Template: "user_id = {UserID}", // WHERE clause
|
||||
HasBlock: false, // true = block all access
|
||||
}
|
||||
```
|
||||
|
||||
### Template Variables
|
||||
|
||||
- `{UserID}` - Current user ID
|
||||
- `{PrimaryKeyName}` - Primary key column
|
||||
- `{TableName}` - Table name
|
||||
- `{SchemaName}` - Schema name
|
||||
|
||||
### Common Examples
|
||||
|
||||
```go
|
||||
// Users see only their records
|
||||
Template: "user_id = {UserID}"
|
||||
|
||||
// Users see their records OR public ones
|
||||
Template: "user_id = {UserID} OR is_public = true"
|
||||
|
||||
// Tenant isolation
|
||||
Template: "tenant_id = 5 AND user_id = {UserID}"
|
||||
|
||||
// Complex with subquery
|
||||
Template: "dept_id IN (SELECT dept_id FROM user_depts WHERE user_id = {UserID})"
|
||||
|
||||
// Block all access
|
||||
HasBlock: true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Example Implementations
|
||||
|
||||
### Simple Header Auth
|
||||
|
||||
```go
|
||||
func authFromHeader(r *http.Request) (int, string, error) {
|
||||
userIDStr := r.Header.Get("X-User-ID")
|
||||
if userIDStr == "" {
|
||||
return 0, "", fmt.Errorf("X-User-ID required")
|
||||
}
|
||||
userID, err := strconv.Atoi(userIDStr)
|
||||
return userID, "", err
|
||||
}
|
||||
```
|
||||
|
||||
### JWT Auth
|
||||
|
||||
```go
|
||||
func authFromJWT(r *http.Request) (int, string, error) {
|
||||
token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
claims, err := jwt.Parse(token, secret)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
return claims.UserID, claims.Roles, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Static Column Security
|
||||
|
||||
```go
|
||||
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
if table == "employees" {
|
||||
return []security.ColumnSecurity{
|
||||
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5},
|
||||
{Path: []string{"salary"}, Accesstype: "hide"},
|
||||
}, nil
|
||||
}
|
||||
return []security.ColumnSecurity{}, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Database Column Security
|
||||
|
||||
```go
|
||||
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT control, accesstype, jsonvalue
|
||||
FROM core.secacces
|
||||
WHERE rid_hub IN (...)
|
||||
AND control ILIKE ?
|
||||
`, fmt.Sprintf("%s.%s%%", schema, table))
|
||||
// ... parse and return
|
||||
}
|
||||
```
|
||||
|
||||
### Static Row Security
|
||||
|
||||
```go
|
||||
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||
templates := map[string]string{
|
||||
"orders": "user_id = {UserID}",
|
||||
"documents": "user_id = {UserID} OR is_public = true",
|
||||
}
|
||||
return security.RowSecurity{
|
||||
Template: templates[table],
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
```go
|
||||
// Test auth callback
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("X-User-ID", "123")
|
||||
userID, roles, err := myAuth(req)
|
||||
assert.Equal(t, 123, userID)
|
||||
|
||||
// Test column security callback
|
||||
rules, err := myColSec(123, "public", "employees")
|
||||
assert.Equal(t, "mask", rules[0].Accesstype)
|
||||
|
||||
// Test row security callback
|
||||
rowSec, err := myRowSec(123, "public", "orders")
|
||||
assert.Equal(t, "user_id = {UserID}", rowSec.Template)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Request Flow
|
||||
|
||||
```
|
||||
HTTP Request
|
||||
↓
|
||||
AuthMiddleware → calls AuthenticateCallback
|
||||
↓ (adds userID to context)
|
||||
SetSecurityMiddleware → adds GlobalSecurity to context
|
||||
↓
|
||||
Handler.Handle()
|
||||
↓
|
||||
BeforeRead Hook → calls LoadColumnSecurityCallback + LoadRowSecurityCallback
|
||||
↓
|
||||
BeforeScan Hook → applies row security (WHERE clause)
|
||||
↓
|
||||
Database Query
|
||||
↓
|
||||
AfterRead Hook → applies column security (masking)
|
||||
↓
|
||||
HTTP Response
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Role-Based Security
|
||||
|
||||
```go
|
||||
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
if isAdmin(userID) {
|
||||
return []security.ColumnSecurity{}, nil // No restrictions
|
||||
}
|
||||
return loadRestrictions(userID, schema, table), nil
|
||||
}
|
||||
```
|
||||
|
||||
### Tenant Isolation
|
||||
|
||||
```go
|
||||
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||
tenantID := getUserTenant(userID)
|
||||
return security.RowSecurity{
|
||||
Template: fmt.Sprintf("tenant_id = %d", tenantID),
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```go
|
||||
var cache = make(map[string][]security.ColumnSecurity)
|
||||
|
||||
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
key := fmt.Sprintf("%d:%s.%s", userID, schema, table)
|
||||
if cached, ok := cache[key]; ok {
|
||||
return cached, nil
|
||||
}
|
||||
rules := loadFromDB(userID, schema, table)
|
||||
cache[key] = rules
|
||||
return rules, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Error Handling
|
||||
|
||||
```go
|
||||
// Setup will fail if callbacks not configured
|
||||
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
|
||||
log.Fatal("Security setup failed:", err)
|
||||
}
|
||||
|
||||
// Auth middleware rejects if callback returns error
|
||||
func myAuth(r *http.Request) (int, string, error) {
|
||||
if invalid {
|
||||
return 0, "", fmt.Errorf("invalid credentials") // Returns HTTP 401
|
||||
}
|
||||
return userID, roles, nil
|
||||
}
|
||||
|
||||
// Security loading can fail gracefully
|
||||
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
rules, err := db.Load(...)
|
||||
if err != nil {
|
||||
log.Printf("Failed to load security: %v", err)
|
||||
return []security.ColumnSecurity{}, nil // No rules = no restrictions
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Debugging
|
||||
|
||||
```go
|
||||
// Enable debug logging
|
||||
import "github.com/bitechdev/GoCore/pkg/cfg"
|
||||
cfg.SetLogLevel("DEBUG")
|
||||
|
||||
// Log in callbacks
|
||||
func myAuth(r *http.Request) (int, string, error) {
|
||||
token := r.Header.Get("Authorization")
|
||||
log.Printf("Auth: token=%s", token)
|
||||
// ...
|
||||
}
|
||||
|
||||
// Check if callbacks are called
|
||||
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||
log.Printf("Loading column security: user=%d, schema=%s, table=%s", userID, schema, table)
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Minimal Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// Configure callbacks
|
||||
security.GlobalSecurity.AuthenticateCallback = func(r *http.Request) (int, string, error) {
|
||||
id, _ := strconv.Atoi(r.Header.Get("X-User-ID"))
|
||||
return id, "", nil
|
||||
}
|
||||
security.GlobalSecurity.LoadColumnSecurityCallback = func(u int, s, t string) ([]security.ColumnSecurity, error) {
|
||||
return []security.ColumnSecurity{}, nil
|
||||
}
|
||||
security.GlobalSecurity.LoadRowSecurityCallback = func(u int, s, t string) (security.RowSecurity, error) {
|
||||
return security.RowSecurity{Template: fmt.Sprintf("user_id = %d", u)}, nil
|
||||
}
|
||||
|
||||
// Setup
|
||||
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||
|
||||
// Middleware
|
||||
router := mux.NewRouter()
|
||||
restheadspec.SetupMuxRoutes(router, handler)
|
||||
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `CALLBACKS_GUIDE.md` | **Start here** - Complete implementation guide |
|
||||
| `callbacks_example.go` | 7 working examples to copy |
|
||||
| `CALLBACKS_SUMMARY.md` | Architecture overview |
|
||||
| `README.md` | Full documentation |
|
||||
| `setup_example.go` | Integration examples |
|
||||
|
||||
---
|
||||
|
||||
## Cheat Sheet
|
||||
|
||||
```go
|
||||
// ===== REQUIRED SETUP =====
|
||||
security.GlobalSecurity.AuthenticateCallback = myAuthFunc
|
||||
security.GlobalSecurity.LoadColumnSecurityCallback = myColFunc
|
||||
security.GlobalSecurity.LoadRowSecurityCallback = myRowFunc
|
||||
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||
|
||||
// ===== CALLBACK SIGNATURES =====
|
||||
func(r *http.Request) (int, string, error) // Auth
|
||||
func(int, string, string) ([]security.ColumnSecurity, error) // Column
|
||||
func(int, string, string) (security.RowSecurity, error) // Row
|
||||
|
||||
// ===== QUICK EXAMPLES =====
|
||||
// Header auth
|
||||
func(r *http.Request) (int, string, error) {
|
||||
id, _ := strconv.Atoi(r.Header.Get("X-User-ID"))
|
||||
return id, "", nil
|
||||
}
|
||||
|
||||
// Mask SSN
|
||||
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}
|
||||
|
||||
// User isolation
|
||||
{Template: "user_id = {UserID}"}
|
||||
```
|
||||
418
pkg/security/callbacks_example.go
Normal file
418
pkg/security/callbacks_example.go
Normal file
@ -0,0 +1,418 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
DBM "github.com/bitechdev/GoCore/pkg/models"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// This file provides example implementations of the required security callbacks.
|
||||
// Copy these functions and modify them to match your authentication and database schema.
|
||||
|
||||
// =============================================================================
|
||||
// EXAMPLE 1: Simple Header-Based Authentication
|
||||
// =============================================================================
|
||||
|
||||
// ExampleAuthenticateFromHeader extracts user ID from X-User-ID header
|
||||
func ExampleAuthenticateFromHeader(r *http.Request) (userID int, roles string, err error) {
|
||||
userIDStr := r.Header.Get("X-User-ID")
|
||||
if userIDStr == "" {
|
||||
return 0, "", fmt.Errorf("X-User-ID header not provided")
|
||||
}
|
||||
|
||||
userID, err = strconv.Atoi(userIDStr)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("invalid user ID format: %v", err)
|
||||
}
|
||||
|
||||
// Optionally extract roles
|
||||
roles = r.Header.Get("X-User-Roles") // comma-separated: "admin,manager"
|
||||
|
||||
return userID, roles, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EXAMPLE 2: JWT Token Authentication
|
||||
// =============================================================================
|
||||
|
||||
// ExampleAuthenticateFromJWT parses a JWT token and extracts user info
|
||||
// You'll need to import a JWT library like github.com/golang-jwt/jwt/v5
|
||||
func ExampleAuthenticateFromJWT(r *http.Request) (userID int, roles string, err error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return 0, "", fmt.Errorf("authorization header not provided")
|
||||
}
|
||||
|
||||
// Extract Bearer token
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == authHeader {
|
||||
return 0, "", fmt.Errorf("invalid authorization header format")
|
||||
}
|
||||
|
||||
// TODO: Parse and validate JWT token
|
||||
// Example using github.com/golang-jwt/jwt/v5:
|
||||
//
|
||||
// token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// return []byte(os.Getenv("JWT_SECRET")), nil
|
||||
// })
|
||||
//
|
||||
// if err != nil || !token.Valid {
|
||||
// return 0, "", fmt.Errorf("invalid token: %v", err)
|
||||
// }
|
||||
//
|
||||
// claims := token.Claims.(jwt.MapClaims)
|
||||
// userID = int(claims["user_id"].(float64))
|
||||
// roles = claims["roles"].(string)
|
||||
|
||||
return 0, "", fmt.Errorf("JWT parsing not implemented - see example above")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EXAMPLE 3: Session Cookie Authentication
|
||||
// =============================================================================
|
||||
|
||||
// ExampleAuthenticateFromSession validates a session cookie
|
||||
func ExampleAuthenticateFromSession(r *http.Request) (userID int, roles string, err error) {
|
||||
sessionCookie, err := r.Cookie("session_id")
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("session cookie not found")
|
||||
}
|
||||
|
||||
// TODO: Validate session against your session store (Redis, database, etc.)
|
||||
// Example:
|
||||
//
|
||||
// session, err := sessionStore.Get(sessionCookie.Value)
|
||||
// if err != nil {
|
||||
// return 0, "", fmt.Errorf("invalid session")
|
||||
// }
|
||||
//
|
||||
// userID = session.UserID
|
||||
// roles = session.Roles
|
||||
|
||||
_ = sessionCookie // Suppress unused warning until implemented
|
||||
return 0, "", fmt.Errorf("session validation not implemented - see example above")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EXAMPLE 4: Column Security - Database Implementation
|
||||
// =============================================================================
|
||||
|
||||
// ExampleLoadColumnSecurityFromDatabase loads column security rules from database
|
||||
// This implementation assumes the following database schema:
|
||||
//
|
||||
// CREATE TABLE core.secacces (
|
||||
// rid_secacces SERIAL PRIMARY KEY,
|
||||
// rid_hub INTEGER,
|
||||
// control TEXT, -- Format: "schema.table.column"
|
||||
// accesstype TEXT, -- "mask" or "hide"
|
||||
// jsonvalue JSONB -- Masking configuration
|
||||
// );
|
||||
//
|
||||
// CREATE TABLE core.hub_link (
|
||||
// rid_hub_parent INTEGER, -- Security group ID
|
||||
// rid_hub_child INTEGER, -- User ID
|
||||
// parent_hubtype TEXT -- 'secgroup'
|
||||
// );
|
||||
func ExampleLoadColumnSecurityFromDatabase(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) {
|
||||
colSecList := make([]ColumnSecurity, 0)
|
||||
|
||||
getExtraFilters := func(pStr string) map[string]string {
|
||||
mp := make(map[string]string, 0)
|
||||
for i, val := range strings.Split(pStr, ",") {
|
||||
if i <= 1 {
|
||||
continue
|
||||
}
|
||||
vals := strings.Split(val, ":")
|
||||
if len(vals) > 1 {
|
||||
mp[vals[0]] = vals[1]
|
||||
}
|
||||
}
|
||||
return mp
|
||||
}
|
||||
|
||||
rows, err := DBM.DBConn.Raw(fmt.Sprintf(`
|
||||
SELECT a.rid_secacces, a.control, a.accesstype, a.jsonvalue
|
||||
FROM core.secacces a
|
||||
WHERE a.rid_hub IN (
|
||||
SELECT l.rid_hub_parent
|
||||
FROM core.hub_link l
|
||||
WHERE l.parent_hubtype = 'secgroup'
|
||||
AND l.rid_hub_child = ?
|
||||
)
|
||||
AND control ILIKE '%s.%s%%'
|
||||
`, pSchema, pTablename), pUserID).Rows()
|
||||
|
||||
defer func() {
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return colSecList, fmt.Errorf("failed to fetch column security from SQL: %v", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var rid int
|
||||
var jsondata []byte
|
||||
var control, accesstype string
|
||||
|
||||
err = rows.Scan(&rid, &control, &accesstype, &jsondata)
|
||||
if err != nil {
|
||||
return colSecList, fmt.Errorf("failed to scan column security: %v", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(control, ",")
|
||||
ids := strings.Split(parts[0], ".")
|
||||
if len(ids) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonvalue := make(map[string]interface{})
|
||||
if len(jsondata) > 1 {
|
||||
err = json.Unmarshal(jsondata, &jsonvalue)
|
||||
if err != nil {
|
||||
logger.Error("Failed to parse json: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
colsec := ColumnSecurity{
|
||||
Schema: pSchema,
|
||||
Tablename: pTablename,
|
||||
UserID: pUserID,
|
||||
Path: ids[2:],
|
||||
ExtraFilters: getExtraFilters(control),
|
||||
Accesstype: accesstype,
|
||||
Control: control,
|
||||
ID: int(rid),
|
||||
}
|
||||
|
||||
// Parse masking configuration from JSON
|
||||
if v, ok := jsonvalue["start"]; ok {
|
||||
if value, ok := v.(float64); ok {
|
||||
colsec.MaskStart = int(value)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := jsonvalue["end"]; ok {
|
||||
if value, ok := v.(float64); ok {
|
||||
colsec.MaskEnd = int(value)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := jsonvalue["invert"]; ok {
|
||||
if value, ok := v.(bool); ok {
|
||||
colsec.MaskInvert = value
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := jsonvalue["char"]; ok {
|
||||
if value, ok := v.(string); ok {
|
||||
colsec.MaskChar = value
|
||||
}
|
||||
}
|
||||
|
||||
colSecList = append(colSecList, colsec)
|
||||
}
|
||||
|
||||
return colSecList, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EXAMPLE 5: Column Security - In-Memory/Static Configuration
|
||||
// =============================================================================
|
||||
|
||||
// ExampleLoadColumnSecurityFromConfig loads column security from static config
|
||||
func ExampleLoadColumnSecurityFromConfig(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) {
|
||||
// Example: Define security rules in code or load from config file
|
||||
securityRules := map[string][]ColumnSecurity{
|
||||
"public.employees": {
|
||||
{
|
||||
Schema: "public",
|
||||
Tablename: "employees",
|
||||
Path: []string{"ssn"},
|
||||
Accesstype: "mask",
|
||||
MaskStart: 5,
|
||||
MaskEnd: 0,
|
||||
MaskChar: "*",
|
||||
},
|
||||
{
|
||||
Schema: "public",
|
||||
Tablename: "employees",
|
||||
Path: []string{"salary"},
|
||||
Accesstype: "hide",
|
||||
},
|
||||
},
|
||||
"public.customers": {
|
||||
{
|
||||
Schema: "public",
|
||||
Tablename: "customers",
|
||||
Path: []string{"credit_card"},
|
||||
Accesstype: "mask",
|
||||
MaskStart: 12,
|
||||
MaskEnd: 0,
|
||||
MaskChar: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s.%s", pSchema, pTablename)
|
||||
rules, ok := securityRules[key]
|
||||
if !ok {
|
||||
return []ColumnSecurity{}, nil // No rules for this table
|
||||
}
|
||||
|
||||
// Filter by user ID if needed
|
||||
// For this example, all rules apply to all users
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EXAMPLE 6: Row Security - Database Implementation
|
||||
// =============================================================================
|
||||
|
||||
// ExampleLoadRowSecurityFromDatabase loads row security rules from database
|
||||
// This implementation assumes a PostgreSQL function:
|
||||
//
|
||||
// CREATE FUNCTION core.api_sec_rowtemplate(
|
||||
// p_schema TEXT,
|
||||
// p_table TEXT,
|
||||
// p_userid INTEGER
|
||||
// ) RETURNS TABLE (
|
||||
// p_retval INTEGER,
|
||||
// p_errmsg TEXT,
|
||||
// p_template TEXT,
|
||||
// p_block BOOLEAN
|
||||
// );
|
||||
func ExampleLoadRowSecurityFromDatabase(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
|
||||
record := RowSecurity{
|
||||
Schema: pSchema,
|
||||
Tablename: pTablename,
|
||||
UserID: pUserID,
|
||||
}
|
||||
|
||||
rows, err := DBM.DBConn.Raw(`
|
||||
SELECT r.p_retval, r.p_errmsg, r.p_template, r.p_block
|
||||
FROM core.api_sec_rowtemplate(?, ?, ?) r
|
||||
`, pSchema, pTablename, pUserID).Rows()
|
||||
|
||||
defer func() {
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return record, fmt.Errorf("failed to fetch row security from SQL: %v", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var retval int
|
||||
var errmsg string
|
||||
|
||||
err = rows.Scan(&retval, &errmsg, &record.Template, &record.HasBlock)
|
||||
if err != nil {
|
||||
return record, fmt.Errorf("failed to scan row security: %v", err)
|
||||
}
|
||||
|
||||
if retval != 0 {
|
||||
return RowSecurity{}, fmt.Errorf("api_sec_rowtemplate error: %s", errmsg)
|
||||
}
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EXAMPLE 7: Row Security - Static Configuration
|
||||
// =============================================================================
|
||||
|
||||
// ExampleLoadRowSecurityFromConfig loads row security from static config
|
||||
func ExampleLoadRowSecurityFromConfig(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
|
||||
// Define row security templates based on entity
|
||||
templates := map[string]string{
|
||||
"public.orders": "user_id = {UserID}", // Users see only their orders
|
||||
"public.documents": "user_id = {UserID} OR is_public = true", // Users see their docs + public docs
|
||||
"public.employees": "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})", // Complex filter
|
||||
}
|
||||
|
||||
// Define blocked entities (no access at all)
|
||||
blockedEntities := map[string][]int{
|
||||
"public.admin_logs": {}, // All users blocked (empty list = block all)
|
||||
"public.audit_logs": {1, 2, 3}, // Block users 1, 2, 3
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s.%s", pSchema, pTablename)
|
||||
|
||||
// Check if entity is blocked for this user
|
||||
if blockedUsers, ok := blockedEntities[key]; ok {
|
||||
if len(blockedUsers) == 0 {
|
||||
// Block all users
|
||||
return RowSecurity{
|
||||
Schema: pSchema,
|
||||
Tablename: pTablename,
|
||||
UserID: pUserID,
|
||||
HasBlock: true,
|
||||
}, nil
|
||||
}
|
||||
// Check if specific user is blocked
|
||||
for _, blockedUserID := range blockedUsers {
|
||||
if blockedUserID == pUserID {
|
||||
return RowSecurity{
|
||||
Schema: pSchema,
|
||||
Tablename: pTablename,
|
||||
UserID: pUserID,
|
||||
HasBlock: true,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get template for this entity
|
||||
template, ok := templates[key]
|
||||
if !ok {
|
||||
// No row security defined - allow all rows
|
||||
return RowSecurity{
|
||||
Schema: pSchema,
|
||||
Tablename: pTablename,
|
||||
UserID: pUserID,
|
||||
Template: "",
|
||||
HasBlock: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return RowSecurity{
|
||||
Schema: pSchema,
|
||||
Tablename: pTablename,
|
||||
UserID: pUserID,
|
||||
Template: template,
|
||||
HasBlock: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SETUP HELPER: Configure All Callbacks
|
||||
// =============================================================================
|
||||
|
||||
// SetupCallbacksExample shows how to configure all callbacks
|
||||
func SetupCallbacksExample() {
|
||||
// Option 1: Use database-backed security (production)
|
||||
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT
|
||||
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase
|
||||
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
|
||||
|
||||
// Option 2: Use static configuration (development/testing)
|
||||
// GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
|
||||
// GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
|
||||
// GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig
|
||||
|
||||
// Option 3: Mix and match
|
||||
// GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT
|
||||
// GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
|
||||
// GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
|
||||
}
|
||||
244
pkg/security/hooks.go
Normal file
244
pkg/security/hooks.go
Normal file
@ -0,0 +1,244 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *restheadspec.Handler, securityList *SecurityList) {
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(restheadspec.BeforeRead, func(hookCtx *restheadspec.HookContext) error {
|
||||
return loadSecurityRules(hookCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 2: BeforeScan - Apply row-level security filters
|
||||
handler.Hooks().Register(restheadspec.BeforeScan, func(hookCtx *restheadspec.HookContext) error {
|
||||
return applyRowSecurity(hookCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 3: AfterRead - Apply column-level security (masking)
|
||||
handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error {
|
||||
return applyColumnSecurity(hookCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 4 (Optional): Audit logging
|
||||
handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error {
|
||||
return logDataAccess(hookCtx)
|
||||
})
|
||||
}
|
||||
|
||||
// loadSecurityRules loads security configuration for the user and entity
|
||||
func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
||||
// Extract user ID from context
|
||||
userID, ok := GetUserID(hookCtx.Context)
|
||||
if !ok {
|
||||
logger.Warn("No user ID in context for security check")
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
|
||||
schema := hookCtx.Schema
|
||||
tablename := hookCtx.Entity
|
||||
|
||||
logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
||||
|
||||
// Load column security rules from database
|
||||
err := securityList.LoadColumnSecurity(userID, schema, tablename, false)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to load column security: %v", err)
|
||||
// Don't fail the request if no security rules exist
|
||||
// return err
|
||||
}
|
||||
|
||||
// Load row security rules from database
|
||||
_, err = securityList.LoadRowSecurity(userID, schema, tablename, false)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to load row security: %v", err)
|
||||
// Don't fail the request if no security rules exist
|
||||
// return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyRowSecurity applies row-level security filters to the query
|
||||
func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
||||
userID, ok := GetUserID(hookCtx.Context)
|
||||
if !ok {
|
||||
return nil // No user context, skip
|
||||
}
|
||||
|
||||
schema := hookCtx.Schema
|
||||
tablename := hookCtx.Entity
|
||||
|
||||
// Get row security template
|
||||
rowSec, err := securityList.GetRowSecurityTemplate(userID, schema, tablename)
|
||||
if err != nil {
|
||||
// No row security defined, allow query to proceed
|
||||
logger.Debug("No row security for %s.%s@%d: %v", schema, tablename, userID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if user has a blocking rule
|
||||
if rowSec.HasBlock {
|
||||
logger.Warn("User %d blocked from accessing %s.%s", userID, schema, tablename)
|
||||
return fmt.Errorf("access denied to %s", tablename)
|
||||
}
|
||||
|
||||
// If there's a security template, apply it as a WHERE clause
|
||||
if rowSec.Template != "" {
|
||||
// Get primary key name from model
|
||||
modelType := reflect.TypeOf(hookCtx.Model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Find primary key field
|
||||
pkName := "id" // default
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
if tag := field.Tag.Get("bun"); tag != "" {
|
||||
// Check for primary key tag
|
||||
if contains(tag, "pk") || contains(tag, "primary_key") {
|
||||
if sqlName := extractSQLName(tag); sqlName != "" {
|
||||
pkName = sqlName
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the WHERE clause from template
|
||||
whereClause := rowSec.GetTemplate(pkName, modelType)
|
||||
|
||||
logger.Info("Applying row security filter for user %d on %s.%s: %s",
|
||||
userID, schema, tablename, whereClause)
|
||||
|
||||
// Apply the WHERE clause to the query
|
||||
// The query is in hookCtx.Query
|
||||
if selectQuery, ok := hookCtx.Query.(interface {
|
||||
Where(string, ...interface{}) interface{}
|
||||
}); ok {
|
||||
hookCtx.Query = selectQuery.Where(whereClause)
|
||||
} else {
|
||||
logger.Error("Unable to apply WHERE clause - query doesn't support Where method")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyColumnSecurity applies column-level security (masking/hiding) to results
|
||||
func applyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
||||
userID, ok := GetUserID(hookCtx.Context)
|
||||
if !ok {
|
||||
return nil // No user context, skip
|
||||
}
|
||||
|
||||
schema := hookCtx.Schema
|
||||
tablename := hookCtx.Entity
|
||||
|
||||
// Get result data
|
||||
result := hookCtx.Result
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Applying column security for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
||||
|
||||
// Get model type
|
||||
modelType := reflect.TypeOf(hookCtx.Model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Apply column security masking
|
||||
resultValue := reflect.ValueOf(result)
|
||||
if resultValue.Kind() == reflect.Ptr {
|
||||
resultValue = resultValue.Elem()
|
||||
}
|
||||
|
||||
err, maskedResult := securityList.ApplyColumnSecurity(resultValue, modelType, userID, schema, tablename)
|
||||
if err != nil {
|
||||
logger.Warn("Column security error: %v", err)
|
||||
// Don't fail the request, just log the issue
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update the result with masked data
|
||||
if maskedResult.IsValid() && maskedResult.CanInterface() {
|
||||
hookCtx.Result = maskedResult.Interface()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logDataAccess logs all data access for audit purposes
|
||||
func logDataAccess(hookCtx *restheadspec.HookContext) error {
|
||||
userID, _ := GetUserID(hookCtx.Context)
|
||||
|
||||
logger.Info("AUDIT: User %d accessed %s.%s with filters: %+v",
|
||||
userID,
|
||||
hookCtx.Schema,
|
||||
hookCtx.Entity,
|
||||
hookCtx.Options.Filters,
|
||||
)
|
||||
|
||||
// TODO: Write to audit log table or external audit service
|
||||
// auditLog := AuditLog{
|
||||
// UserID: userID,
|
||||
// Schema: hookCtx.Schema,
|
||||
// Entity: hookCtx.Entity,
|
||||
// Action: "READ",
|
||||
// Timestamp: time.Now(),
|
||||
// Filters: hookCtx.Options.Filters,
|
||||
// }
|
||||
// db.Create(&auditLog)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && s[:len(substr)] == substr ||
|
||||
len(s) > len(substr) && s[len(s)-len(substr):] == substr
|
||||
}
|
||||
|
||||
func extractSQLName(tag string) string {
|
||||
// Simple parser for "column:name" or just "name"
|
||||
// This is a simplified version
|
||||
parts := splitTag(tag, ',')
|
||||
for _, part := range parts {
|
||||
if part != "" && !contains(part, ":") {
|
||||
return part
|
||||
}
|
||||
if contains(part, "column:") {
|
||||
return part[7:] // Skip "column:"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func splitTag(tag string, sep rune) []string {
|
||||
var parts []string
|
||||
var current string
|
||||
for _, ch := range tag {
|
||||
if ch == sep {
|
||||
if current != "" {
|
||||
parts = append(parts, current)
|
||||
current = ""
|
||||
}
|
||||
} else {
|
||||
current += string(ch)
|
||||
}
|
||||
}
|
||||
if current != "" {
|
||||
parts = append(parts, current)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
54
pkg/security/middleware.go
Normal file
54
pkg/security/middleware.go
Normal file
@ -0,0 +1,54 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// Context keys for user information
|
||||
UserIDKey = "user_id"
|
||||
UserRolesKey = "user_roles"
|
||||
UserTokenKey = "user_token"
|
||||
)
|
||||
|
||||
// AuthMiddleware extracts user authentication from request and adds to context
|
||||
// This should be applied before the ResolveSpec handler
|
||||
// Uses GlobalSecurity.AuthenticateCallback if set, otherwise returns error
|
||||
func AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if callback is set
|
||||
if GlobalSecurity.AuthenticateCallback == nil {
|
||||
http.Error(w, "AuthenticateCallback not set - you must provide an authentication callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Call the user-provided authentication callback
|
||||
userID, roles, err := GlobalSecurity.AuthenticateCallback(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user information to context
|
||||
ctx := context.WithValue(r.Context(), UserIDKey, userID)
|
||||
if roles != "" {
|
||||
ctx = context.WithValue(ctx, UserRolesKey, roles)
|
||||
}
|
||||
|
||||
// Continue with authenticated context
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserID extracts the user ID from context
|
||||
func GetUserID(ctx context.Context) (int, bool) {
|
||||
userID, ok := ctx.Value(UserIDKey).(int)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// GetUserRoles extracts user roles from context
|
||||
func GetUserRoles(ctx context.Context) (string, bool) {
|
||||
roles, ok := ctx.Value(UserRolesKey).(string)
|
||||
return roles, ok
|
||||
}
|
||||
460
pkg/security/provider.go
Normal file
460
pkg/security/provider.go
Normal file
@ -0,0 +1,460 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type ColumnSecurity struct {
|
||||
Schema string
|
||||
Tablename string
|
||||
Path []string
|
||||
ExtraFilters map[string]string
|
||||
UserID int
|
||||
Accesstype string `json:"accesstype"`
|
||||
MaskStart int
|
||||
MaskEnd int
|
||||
MaskInvert bool
|
||||
MaskChar string
|
||||
Control string `json:"control"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
type RowSecurity struct {
|
||||
Schema string
|
||||
Tablename string
|
||||
Template string
|
||||
HasBlock bool
|
||||
UserID int
|
||||
}
|
||||
|
||||
func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Type) string {
|
||||
str := m.Template
|
||||
str = strings.ReplaceAll(str, "{PrimaryKeyName}", pPrimaryKeyName)
|
||||
str = strings.ReplaceAll(str, "{TableName}", m.Tablename)
|
||||
str = strings.ReplaceAll(str, "{SchemaName}", m.Schema)
|
||||
str = strings.ReplaceAll(str, "{UserID}", fmt.Sprintf("%d", m.UserID))
|
||||
return str
|
||||
}
|
||||
|
||||
// Callback function types for customizing security behavior
|
||||
type (
|
||||
// AuthenticateFunc extracts user ID and roles from HTTP request
|
||||
// Return userID, roles, error. If error is not nil, request will be rejected.
|
||||
AuthenticateFunc func(r *http.Request) (userID int, roles string, err error)
|
||||
|
||||
// LoadColumnSecurityFunc loads column security rules for a user and entity
|
||||
// Override this to customize how column security is loaded from your data source
|
||||
LoadColumnSecurityFunc func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error)
|
||||
|
||||
// LoadRowSecurityFunc loads row security rules for a user and entity
|
||||
// Override this to customize how row security is loaded from your data source
|
||||
LoadRowSecurityFunc func(pUserID int, pSchema, pTablename string) (RowSecurity, error)
|
||||
)
|
||||
|
||||
type SecurityList struct {
|
||||
ColumnSecurityMutex sync.RWMutex
|
||||
ColumnSecurity map[string][]ColumnSecurity
|
||||
RowSecurityMutex sync.RWMutex
|
||||
RowSecurity map[string]RowSecurity
|
||||
|
||||
// Overridable callbacks
|
||||
AuthenticateCallback AuthenticateFunc
|
||||
LoadColumnSecurityCallback LoadColumnSecurityFunc
|
||||
LoadRowSecurityCallback LoadRowSecurityFunc
|
||||
}
|
||||
|
||||
const SECURITY_CONTEXT_KEY = "SecurityList"
|
||||
|
||||
var GlobalSecurity SecurityList
|
||||
|
||||
// SetSecurityMiddleware adds security context to requests
|
||||
func SetSecurityMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, &GlobalSecurity)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func maskString(pString string, maskStart, maskEnd int, maskChar string, invert bool) string {
|
||||
strLen := len(pString)
|
||||
middleIndex := (strLen / 2)
|
||||
newStr := ""
|
||||
if maskStart == 0 && maskEnd == 0 {
|
||||
maskStart = strLen
|
||||
maskEnd = strLen
|
||||
}
|
||||
if maskEnd > strLen {
|
||||
maskEnd = strLen
|
||||
}
|
||||
if maskStart > strLen {
|
||||
maskStart = strLen
|
||||
}
|
||||
if maskChar == "" {
|
||||
maskChar = "*"
|
||||
}
|
||||
for index, char := range pString {
|
||||
if invert && index >= middleIndex-maskStart && index <= middleIndex {
|
||||
newStr = newStr + maskChar
|
||||
continue
|
||||
}
|
||||
if invert && index <= middleIndex+maskEnd && index >= middleIndex {
|
||||
newStr = newStr + maskChar
|
||||
continue
|
||||
}
|
||||
if !invert && index <= maskStart {
|
||||
newStr = newStr + maskChar
|
||||
continue
|
||||
}
|
||||
if !invert && index >= strLen-1-maskEnd {
|
||||
newStr = newStr + maskChar
|
||||
continue
|
||||
}
|
||||
newStr = newStr + string(char)
|
||||
}
|
||||
|
||||
return newStr
|
||||
}
|
||||
|
||||
func (m *SecurityList) ColumSecurityApplyOnRecord(prevRecord reflect.Value, newRecord reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) ([]string, error) {
|
||||
cols := make([]string, 0)
|
||||
if m.ColumnSecurity == nil {
|
||||
return cols, fmt.Errorf("security not initialized")
|
||||
}
|
||||
|
||||
if prevRecord.Type() != newRecord.Type() {
|
||||
logger.Error("prev:%s and new:%s record type mismatch", prevRecord.Type(), newRecord.Type())
|
||||
return cols, fmt.Errorf("prev and new record type mismatch")
|
||||
}
|
||||
|
||||
m.ColumnSecurityMutex.RLock()
|
||||
defer m.ColumnSecurityMutex.RUnlock()
|
||||
|
||||
colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
|
||||
if !ok || colsecList == nil {
|
||||
return cols, fmt.Errorf("no security data")
|
||||
}
|
||||
|
||||
for _, colsec := range colsecList {
|
||||
if !(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||
continue
|
||||
}
|
||||
lastRecords := interateStruct(prevRecord)
|
||||
newRecords := interateStruct(newRecord)
|
||||
var lastLoopField, lastLoopNewField reflect.Value
|
||||
pathLen := len(colsec.Path)
|
||||
for i, path := range colsec.Path {
|
||||
var nameType, fieldName string
|
||||
if len(newRecords) == 0 {
|
||||
if lastLoopNewField.IsValid() && lastLoopField.IsValid() && i < pathLen-1 {
|
||||
lastLoopNewField.Set(lastLoopField)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
for ri := range newRecords {
|
||||
if !newRecords[ri].IsValid() || !lastRecords[ri].IsValid() {
|
||||
break
|
||||
}
|
||||
var field, oldField reflect.Value
|
||||
|
||||
columnData := reflection.GetModelColumnDetail(newRecords[ri])
|
||||
lastColumnData := reflection.GetModelColumnDetail(lastRecords[ri])
|
||||
for i, cols := range columnData {
|
||||
if cols.SQLName != "" && strings.EqualFold(cols.SQLName, path) {
|
||||
nameType = "sql"
|
||||
fieldName = cols.SQLName
|
||||
field = cols.FieldValue
|
||||
oldField = lastColumnData[i].FieldValue
|
||||
break
|
||||
}
|
||||
if cols.Name != "" && strings.EqualFold(cols.Name, path) {
|
||||
nameType = "struct"
|
||||
fieldName = cols.Name
|
||||
field = cols.FieldValue
|
||||
oldField = lastColumnData[i].FieldValue
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !field.IsValid() || !oldField.IsValid() {
|
||||
break
|
||||
}
|
||||
lastLoopField = oldField
|
||||
lastLoopNewField = field
|
||||
|
||||
if i == pathLen-1 {
|
||||
if strings.Contains(strings.ToLower(fieldName), "json") {
|
||||
prevSrc := oldField.Bytes()
|
||||
newSrc := field.Bytes()
|
||||
pathstr := strings.Join(colsec.Path, ".")
|
||||
prevPathValue := gjson.GetBytes(prevSrc, pathstr)
|
||||
newBytes, err := sjson.SetBytes(newSrc, pathstr, prevPathValue.Str)
|
||||
if err == nil {
|
||||
if field.CanSet() {
|
||||
field.SetBytes(newBytes)
|
||||
} else {
|
||||
logger.Warn("Value not settable: %v", field)
|
||||
cols = append(cols, pathstr)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if nameType == "sql" {
|
||||
if strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide") {
|
||||
field.Set(oldField)
|
||||
cols = append(cols, strings.Join(colsec.Path, "."))
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
lastRecords = interateStruct(field)
|
||||
newRecords = interateStruct(oldField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func interateStruct(val reflect.Value) []reflect.Value {
|
||||
list := make([]reflect.Value, 0)
|
||||
|
||||
switch val.Kind() {
|
||||
case reflect.Pointer, reflect.Interface:
|
||||
elem := val.Elem()
|
||||
if elem.IsValid() {
|
||||
list = append(list, interateStruct(elem)...)
|
||||
}
|
||||
return list
|
||||
case reflect.Array, reflect.Slice:
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
elem := val.Index(i)
|
||||
if !elem.IsValid() {
|
||||
continue
|
||||
}
|
||||
list = append(list, interateStruct(elem)...)
|
||||
}
|
||||
return list
|
||||
case reflect.Struct:
|
||||
list = append(list, val)
|
||||
return list
|
||||
default:
|
||||
return list
|
||||
}
|
||||
}
|
||||
|
||||
func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName string) (int, reflect.Value) {
|
||||
fieldval := fieldsrc
|
||||
if fieldsrc.Kind() == reflect.Pointer || fieldsrc.Kind() == reflect.Interface {
|
||||
fieldval = fieldval.Elem()
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(fieldval.Kind().String()), "int") &&
|
||||
(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||
if fieldval.CanInt() && fieldval.CanSet() {
|
||||
fieldval.SetInt(0)
|
||||
}
|
||||
} else if (strings.Contains(strings.ToLower(fieldval.Kind().String()), "time") ||
|
||||
strings.Contains(strings.ToLower(fieldval.Kind().String()), "date")) &&
|
||||
(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||
fieldval.SetZero()
|
||||
} else if strings.Contains(strings.ToLower(fieldval.Kind().String()), "string") {
|
||||
strVal := fieldval.String()
|
||||
if strings.EqualFold(colsec.Accesstype, "mask") {
|
||||
fieldval.SetString(maskString(strVal, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert))
|
||||
} else if strings.EqualFold(colsec.Accesstype, "hide") {
|
||||
fieldval.SetString("")
|
||||
}
|
||||
} else if strings.Contains(fieldTypeName, "json") &&
|
||||
(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||
if len(colsec.Path) < 2 {
|
||||
return 1, fieldval
|
||||
}
|
||||
pathstr := strings.Join(colsec.Path, ".")
|
||||
src := fieldval.Bytes()
|
||||
pathValue := gjson.GetBytes(src, pathstr)
|
||||
strValue := pathValue.String()
|
||||
if strings.EqualFold(colsec.Accesstype, "mask") {
|
||||
strValue = maskString(strValue, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert)
|
||||
} else if strings.EqualFold(colsec.Accesstype, "hide") {
|
||||
strValue = ""
|
||||
}
|
||||
newBytes, err := sjson.SetBytes(src, pathstr, strValue)
|
||||
if err == nil {
|
||||
fieldval.SetBytes(newBytes)
|
||||
}
|
||||
}
|
||||
return 0, fieldsrc
|
||||
}
|
||||
|
||||
func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (error, reflect.Value) {
|
||||
defer logger.CatchPanic("ApplyColumnSecurity")
|
||||
|
||||
if m.ColumnSecurity == nil {
|
||||
return fmt.Errorf("security not initialized"), records
|
||||
}
|
||||
|
||||
m.ColumnSecurityMutex.RLock()
|
||||
defer m.ColumnSecurityMutex.RUnlock()
|
||||
|
||||
colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
|
||||
if !ok || colsecList == nil {
|
||||
return fmt.Errorf("no security data"), records
|
||||
}
|
||||
|
||||
for _, colsec := range colsecList {
|
||||
if !(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||
continue
|
||||
}
|
||||
|
||||
if records.Kind() == reflect.Array || records.Kind() == reflect.Slice {
|
||||
for i := 0; i < records.Len(); i++ {
|
||||
record := records.Index(i)
|
||||
if !record.IsValid() {
|
||||
continue
|
||||
}
|
||||
|
||||
lastRecord := interateStruct(record)
|
||||
pathLen := len(colsec.Path)
|
||||
for i, path := range colsec.Path {
|
||||
var field reflect.Value
|
||||
var nameType, fieldName string
|
||||
if len(lastRecord) == 0 {
|
||||
break
|
||||
}
|
||||
columnData := reflection.GetModelColumnDetail(lastRecord[0])
|
||||
for _, cols := range columnData {
|
||||
if cols.SQLName != "" && strings.EqualFold(cols.SQLName, path) {
|
||||
nameType = "sql"
|
||||
fieldName = cols.SQLName
|
||||
field = cols.FieldValue
|
||||
break
|
||||
}
|
||||
if cols.Name != "" && strings.EqualFold(cols.Name, path) {
|
||||
nameType = "struct"
|
||||
fieldName = cols.Name
|
||||
field = cols.FieldValue
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if i == pathLen-1 {
|
||||
if nameType == "sql" || nameType == "struct" {
|
||||
setColSecValue(field, colsec, fieldName)
|
||||
}
|
||||
break
|
||||
}
|
||||
if field.IsValid() {
|
||||
lastRecord = interateStruct(field)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, records
|
||||
}
|
||||
|
||||
func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) error {
|
||||
// Use the callback if provided
|
||||
if m.LoadColumnSecurityCallback == nil {
|
||||
return fmt.Errorf("LoadColumnSecurityCallback not set - you must provide a callback function")
|
||||
}
|
||||
|
||||
m.ColumnSecurityMutex.Lock()
|
||||
defer m.ColumnSecurityMutex.Unlock()
|
||||
|
||||
if m.ColumnSecurity == nil {
|
||||
m.ColumnSecurity = make(map[string][]ColumnSecurity, 0)
|
||||
}
|
||||
secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
|
||||
|
||||
if pOverwrite || m.ColumnSecurity[secKey] == nil {
|
||||
m.ColumnSecurity[secKey] = make([]ColumnSecurity, 0)
|
||||
}
|
||||
|
||||
// Call the user-provided callback to load security rules
|
||||
colSecList, err := m.LoadColumnSecurityCallback(pUserID, pSchema, pTablename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("LoadColumnSecurityCallback failed: %v", err)
|
||||
}
|
||||
|
||||
m.ColumnSecurity[secKey] = colSecList
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SecurityList) ClearSecurity(pUserID int, pSchema, pTablename string) error {
|
||||
var filtered []ColumnSecurity
|
||||
m.ColumnSecurityMutex.Lock()
|
||||
defer m.ColumnSecurityMutex.Unlock()
|
||||
|
||||
secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
|
||||
list, ok := m.ColumnSecurity[secKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, cs := range list {
|
||||
if !(cs.Schema == pSchema && cs.Tablename == pTablename && cs.UserID == pUserID) {
|
||||
filtered = append(filtered, cs)
|
||||
}
|
||||
}
|
||||
|
||||
m.ColumnSecurity[secKey] = filtered
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) {
|
||||
// Use the callback if provided
|
||||
if m.LoadRowSecurityCallback == nil {
|
||||
return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback not set - you must provide a callback function")
|
||||
}
|
||||
|
||||
m.RowSecurityMutex.Lock()
|
||||
defer m.RowSecurityMutex.Unlock()
|
||||
|
||||
if m.RowSecurity == nil {
|
||||
m.RowSecurity = make(map[string]RowSecurity, 0)
|
||||
}
|
||||
secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
|
||||
|
||||
// Call the user-provided callback to load security rules
|
||||
record, err := m.LoadRowSecurityCallback(pUserID, pSchema, pTablename)
|
||||
if err != nil {
|
||||
return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback failed: %v", err)
|
||||
}
|
||||
|
||||
m.RowSecurity[secKey] = record
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
|
||||
defer logger.CatchPanic("GetRowSecurityTemplate")
|
||||
|
||||
if m.RowSecurity == nil {
|
||||
return RowSecurity{}, fmt.Errorf("security not initialized")
|
||||
}
|
||||
|
||||
m.RowSecurityMutex.RLock()
|
||||
defer m.RowSecurityMutex.RUnlock()
|
||||
|
||||
rowSec, ok := m.RowSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
|
||||
if !ok {
|
||||
return RowSecurity{}, fmt.Errorf("no security data")
|
||||
}
|
||||
|
||||
return rowSec, nil
|
||||
}
|
||||
155
pkg/security/setup_example.go
Normal file
155
pkg/security/setup_example.go
Normal file
@ -0,0 +1,155 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SetupSecurityProvider initializes and configures the security provider
|
||||
// This should be called when setting up your HTTP server
|
||||
//
|
||||
// IMPORTANT: You MUST configure the callbacks before calling this function:
|
||||
// - GlobalSecurity.AuthenticateCallback
|
||||
// - GlobalSecurity.LoadColumnSecurityCallback
|
||||
// - GlobalSecurity.LoadRowSecurityCallback
|
||||
//
|
||||
// Example usage in your main.go or server setup:
|
||||
//
|
||||
// // Step 1: Configure callbacks (REQUIRED)
|
||||
// security.GlobalSecurity.AuthenticateCallback = myAuthFunction
|
||||
// security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurityFunction
|
||||
// security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurityFunction
|
||||
//
|
||||
// // Step 2: Setup security provider
|
||||
// handler := restheadspec.NewHandlerWithGORM(db)
|
||||
// security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||
//
|
||||
// // Step 3: Apply middleware
|
||||
// router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||
// router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||
//
|
||||
func SetupSecurityProvider(handler *restheadspec.Handler, securityList *SecurityList) error {
|
||||
// Validate that required callbacks are configured
|
||||
if securityList.AuthenticateCallback == nil {
|
||||
return fmt.Errorf("AuthenticateCallback must be set before calling SetupSecurityProvider")
|
||||
}
|
||||
if securityList.LoadColumnSecurityCallback == nil {
|
||||
return fmt.Errorf("LoadColumnSecurityCallback must be set before calling SetupSecurityProvider")
|
||||
}
|
||||
if securityList.LoadRowSecurityCallback == nil {
|
||||
return fmt.Errorf("LoadRowSecurityCallback must be set before calling SetupSecurityProvider")
|
||||
}
|
||||
|
||||
// Initialize security maps if needed
|
||||
if securityList.ColumnSecurity == nil {
|
||||
securityList.ColumnSecurity = make(map[string][]ColumnSecurity)
|
||||
}
|
||||
if securityList.RowSecurity == nil {
|
||||
securityList.RowSecurity = make(map[string]RowSecurity)
|
||||
}
|
||||
|
||||
// Register all security hooks
|
||||
RegisterSecurityHooks(handler, securityList)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chain creates a middleware chain
|
||||
func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
|
||||
return func(final http.Handler) http.Handler {
|
||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||
final = middlewares[i](final)
|
||||
}
|
||||
return final
|
||||
}
|
||||
}
|
||||
|
||||
// CompleteExample shows a full integration example with Gorilla Mux
|
||||
func CompleteExample(db *gorm.DB) (http.Handler, error) {
|
||||
// Step 1: Create the ResolveSpec handler
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// Step 2: Register your models
|
||||
// handler.RegisterModel("public", "users", User{})
|
||||
// handler.RegisterModel("public", "orders", Order{})
|
||||
|
||||
// Step 3: Configure security callbacks (REQUIRED!)
|
||||
// See callbacks_example.go for example implementations
|
||||
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
|
||||
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase
|
||||
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
|
||||
|
||||
// Step 4: Setup security provider
|
||||
if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil {
|
||||
return nil, fmt.Errorf("failed to setup security: %v", err)
|
||||
}
|
||||
|
||||
// Step 5: Create Mux router and setup routes
|
||||
router := mux.NewRouter()
|
||||
|
||||
// The routes are set up by restheadspec, which handles the conversion
|
||||
// from http.Request to the internal request format
|
||||
restheadspec.SetupMuxRoutes(router, handler)
|
||||
|
||||
// Step 6: Apply middleware to the entire router
|
||||
secureRouter := Chain(
|
||||
AuthMiddleware, // Extract user from token
|
||||
SetSecurityMiddleware, // Add security context
|
||||
)(router)
|
||||
|
||||
return secureRouter, nil
|
||||
}
|
||||
|
||||
// ExampleWithMux shows a simpler integration with Mux
|
||||
func ExampleWithMux(db *gorm.DB) (*mux.Router, error) {
|
||||
handler := restheadspec.NewHandlerWithGORM(db)
|
||||
|
||||
// IMPORTANT: Configure callbacks BEFORE SetupSecurityProvider
|
||||
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
|
||||
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
|
||||
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig
|
||||
|
||||
if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil {
|
||||
return nil, fmt.Errorf("failed to setup security: %v", err)
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Setup API routes
|
||||
restheadspec.SetupMuxRoutes(router, handler)
|
||||
|
||||
// Apply middleware to router
|
||||
router.Use(mux.MiddlewareFunc(AuthMiddleware))
|
||||
router.Use(mux.MiddlewareFunc(SetSecurityMiddleware))
|
||||
|
||||
return router, nil
|
||||
}
|
||||
|
||||
// Example with Gin
|
||||
// import "github.com/gin-gonic/gin"
|
||||
//
|
||||
// func ExampleWithGin(db *gorm.DB) *gin.Engine {
|
||||
// handler := restheadspec.NewHandlerWithGORM(db)
|
||||
// SetupSecurityProvider(handler, &GlobalSecurity)
|
||||
//
|
||||
// router := gin.Default()
|
||||
//
|
||||
// // Convert middleware to Gin middleware
|
||||
// router.Use(func(c *gin.Context) {
|
||||
// AuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// c.Request = r
|
||||
// c.Next()
|
||||
// })).ServeHTTP(c.Writer, c.Request)
|
||||
// })
|
||||
//
|
||||
// // Setup routes
|
||||
// api := router.Group("/api")
|
||||
// api.Any("/:schema/:entity", gin.WrapH(http.HandlerFunc(handler.Handle)))
|
||||
// api.Any("/:schema/:entity/:id", gin.WrapH(http.HandlerFunc(handler.Handle)))
|
||||
//
|
||||
// return router
|
||||
// }
|
||||
@ -83,7 +83,7 @@ func TestSetup(m *testing.M) int {
|
||||
router := setupTestRouter(testDB)
|
||||
testServer = httptest.NewServer(router)
|
||||
|
||||
fmt.Printf("ResolveSpec test server starting on %s\n", testServer.URL)
|
||||
logger.Info("ResolveSpec test server starting on %s", testServer.URL)
|
||||
testServerURL = testServer.URL
|
||||
|
||||
defer testServer.Close()
|
||||
|
||||
2
todo.md
2
todo.md
@ -120,6 +120,8 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
- When making changes, we can have the trigger fire with the correct user.
|
||||
- Maybe wrap the handleRead,Update,Create,Delete handlers in a transaction with context that can abort when the request is cancelled or a configurable timeout is reached.
|
||||
|
||||
### 7.
|
||||
|
||||
## Additional Considerations
|
||||
|
||||
### Documentation
|
||||
|
||||
Loading…
Reference in New Issue
Block a user