mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-11-13 09:53:53 +00:00
Updated logging, added getRowNumber and a few more
This commit is contained in:
parent
faafe5abea
commit
ceaa251301
@ -1,138 +0,0 @@
|
|||||||
# Schema and Table Name Handling
|
|
||||||
|
|
||||||
This document explains how the handlers properly separate and handle schema and table names.
|
|
||||||
|
|
||||||
## Implementation
|
|
||||||
|
|
||||||
Both `resolvespec` and `restheadspec` handlers now properly handle schema and table name separation through the following functions:
|
|
||||||
|
|
||||||
- `parseTableName(fullTableName)` - Splits "schema.table" into separate components
|
|
||||||
- `getSchemaAndTable(defaultSchema, entity, model)` - Returns schema and table separately
|
|
||||||
- `getTableName(schema, entity, model)` - Returns the full "schema.table" format
|
|
||||||
|
|
||||||
## Priority Order
|
|
||||||
|
|
||||||
When determining the schema and table name, the following priority is used:
|
|
||||||
|
|
||||||
1. **If `TableName()` contains a schema** (e.g., "myschema.mytable"), that schema takes precedence
|
|
||||||
2. **If model implements `SchemaProvider`**, use that schema
|
|
||||||
3. **Otherwise**, use the `defaultSchema` parameter from the URL/request
|
|
||||||
|
|
||||||
## Scenarios
|
|
||||||
|
|
||||||
### Scenario 1: Simple table name, default schema
|
|
||||||
```go
|
|
||||||
type User struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (User) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- Request URL: `/api/public/users`
|
|
||||||
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
|
|
||||||
|
|
||||||
### Scenario 2: Table name includes schema
|
|
||||||
```go
|
|
||||||
type User struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (User) TableName() string {
|
|
||||||
return "auth.users" // Schema included!
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- Request URL: `/api/public/users` (public is ignored)
|
|
||||||
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
|
|
||||||
- **Note**: The schema from `TableName()` takes precedence over the URL schema
|
|
||||||
|
|
||||||
### Scenario 3: Using SchemaProvider
|
|
||||||
```go
|
|
||||||
type User struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (User) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (User) SchemaName() string {
|
|
||||||
return "auth"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- Request URL: `/api/public/users` (public is ignored)
|
|
||||||
- Result: `schema="auth"`, `table="users"`, `fullName="auth.users"`
|
|
||||||
|
|
||||||
### Scenario 4: Table name includes schema AND SchemaProvider
|
|
||||||
```go
|
|
||||||
type User struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (User) TableName() string {
|
|
||||||
return "core.users" // This wins!
|
|
||||||
}
|
|
||||||
|
|
||||||
func (User) SchemaName() string {
|
|
||||||
return "auth" // This is ignored
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- Request URL: `/api/public/users`
|
|
||||||
- Result: `schema="core"`, `table="users"`, `fullName="core.users"`
|
|
||||||
- **Note**: Schema from `TableName()` takes highest precedence
|
|
||||||
|
|
||||||
### Scenario 5: No providers at all
|
|
||||||
```go
|
|
||||||
type User struct {
|
|
||||||
ID string
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
// No TableName() or SchemaName()
|
|
||||||
```
|
|
||||||
- Request URL: `/api/public/users`
|
|
||||||
- Result: `schema="public"`, `table="users"`, `fullName="public.users"`
|
|
||||||
- Uses URL schema and entity name
|
|
||||||
|
|
||||||
## Key Features
|
|
||||||
|
|
||||||
1. **Automatic detection**: The code automatically detects if `TableName()` includes a schema by checking for "."
|
|
||||||
2. **Backward compatible**: Existing code continues to work
|
|
||||||
3. **Flexible**: Supports multiple ways to specify schema and table
|
|
||||||
4. **Debug logging**: Logs when schema is detected in `TableName()` for debugging
|
|
||||||
|
|
||||||
## Code Locations
|
|
||||||
|
|
||||||
### Handlers
|
|
||||||
- `/pkg/resolvespec/handler.go:472-531`
|
|
||||||
- `/pkg/restheadspec/handler.go:534-593`
|
|
||||||
|
|
||||||
### Database Adapters
|
|
||||||
- `/pkg/common/adapters/database/utils.go` - Shared `parseTableName()` function
|
|
||||||
- `/pkg/common/adapters/database/bun.go` - Bun adapter with separated schema/table
|
|
||||||
- `/pkg/common/adapters/database/gorm.go` - GORM adapter with separated schema/table
|
|
||||||
|
|
||||||
## Adapter Implementation
|
|
||||||
|
|
||||||
Both Bun and GORM adapters now properly separate schema and table name:
|
|
||||||
|
|
||||||
```go
|
|
||||||
// BunSelectQuery/GormSelectQuery now have separated fields:
|
|
||||||
type BunSelectQuery struct {
|
|
||||||
query *bun.SelectQuery
|
|
||||||
schema string // Separated schema name
|
|
||||||
tableName string // Just the table name, without schema
|
|
||||||
tableAlias string
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
When `Model()` or `Table()` is called:
|
|
||||||
1. The full table name (which may include schema) is parsed
|
|
||||||
2. Schema and table name are stored separately
|
|
||||||
3. When building joins, the already-separated table name is used directly
|
|
||||||
|
|
||||||
This ensures consistent handling of schema-qualified table names throughout the codebase.
|
|
||||||
@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@ -21,8 +20,8 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Initialize logger
|
// Initialize logger
|
||||||
fmt.Println("ResolveSpec test server starting")
|
|
||||||
logger.Init(true)
|
logger.Init(true)
|
||||||
|
logger.Info("ResolveSpec test server starting")
|
||||||
|
|
||||||
// Initialize database
|
// Initialize database
|
||||||
db, err := initDB()
|
db, err := initDB()
|
||||||
|
|||||||
4
go.mod
4
go.mod
@ -24,6 +24,10 @@ require (
|
|||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||||
github.com/uptrace/bunrouter v1.0.23 // indirect
|
github.com/uptrace/bunrouter v1.0.23 // indirect
|
||||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||||
|
|||||||
9
go.sum
9
go.sum
@ -37,6 +37,15 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||||
|
|||||||
@ -1,10 +1,7 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
@ -17,157 +14,3 @@ func parseTableName(fullTableName string) (schema, table string) {
|
|||||||
}
|
}
|
||||||
return "", fullTableName
|
return "", fullTableName
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPrimaryKeyName extracts the primary key column name from a model
|
|
||||||
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
|
||||||
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
|
||||||
func GetPrimaryKeyName(model any) string {
|
|
||||||
// Check if model implements PrimaryKeyNameProvider
|
|
||||||
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
|
||||||
return provider.GetIDName()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try Bun tag first
|
|
||||||
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
|
|
||||||
return pkName
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to GORM tag
|
|
||||||
return getPrimaryKeyFromReflection(model, "gorm")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModelColumns extracts all column names from a model using reflection
|
|
||||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
|
||||||
func GetModelColumns(model any) []string {
|
|
||||||
var columns []string
|
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
|
||||||
|
|
||||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
|
||||||
modelType = modelType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate that we have a struct type
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
|
||||||
return columns
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
|
||||||
field := modelType.Field(i)
|
|
||||||
|
|
||||||
// Get column name using the same logic as primary key extraction
|
|
||||||
columnName := getColumnNameFromField(field)
|
|
||||||
|
|
||||||
if columnName != "" {
|
|
||||||
columns = append(columns, columnName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return columns
|
|
||||||
}
|
|
||||||
|
|
||||||
// getColumnNameFromField extracts the column name from a struct field
|
|
||||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
|
||||||
func getColumnNameFromField(field reflect.StructField) string {
|
|
||||||
// Try bun tag first
|
|
||||||
bunTag := field.Tag.Get("bun")
|
|
||||||
if bunTag != "" && bunTag != "-" {
|
|
||||||
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
|
||||||
return colName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try gorm tag
|
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
if gormTag != "" && gormTag != "-" {
|
|
||||||
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
|
||||||
return colName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to json tag
|
|
||||||
jsonTag := field.Tag.Get("json")
|
|
||||||
if jsonTag != "" && jsonTag != "-" {
|
|
||||||
// Extract just the field name before any options
|
|
||||||
parts := strings.Split(jsonTag, ",")
|
|
||||||
if len(parts) > 0 && parts[0] != "" {
|
|
||||||
return parts[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Last resort: use field name in lowercase
|
|
||||||
return strings.ToLower(field.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
|
||||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
|
||||||
val := reflect.ValueOf(model)
|
|
||||||
if val.Kind() == reflect.Pointer {
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if val.Kind() != reflect.Struct {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
typ := val.Type()
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
|
||||||
field := typ.Field(i)
|
|
||||||
|
|
||||||
switch ormType {
|
|
||||||
case "gorm":
|
|
||||||
// Check for gorm tag with primaryKey
|
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
if strings.Contains(gormTag, "primaryKey") {
|
|
||||||
// Try to extract column name from gorm tag
|
|
||||||
if colName := extractColumnFromGormTag(gormTag); colName != "" {
|
|
||||||
return colName
|
|
||||||
}
|
|
||||||
// Fall back to json tag
|
|
||||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
|
||||||
return strings.Split(jsonTag, ",")[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "bun":
|
|
||||||
// Check for bun tag with pk flag
|
|
||||||
bunTag := field.Tag.Get("bun")
|
|
||||||
if strings.Contains(bunTag, "pk") {
|
|
||||||
// Extract column name from bun tag
|
|
||||||
if colName := extractColumnFromBunTag(bunTag); colName != "" {
|
|
||||||
return colName
|
|
||||||
}
|
|
||||||
// Fall back to json tag
|
|
||||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
|
||||||
return strings.Split(jsonTag, ",")[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractColumnFromGormTag extracts the column name from a gorm tag
|
|
||||||
// Example: "column:id;primaryKey" -> "id"
|
|
||||||
func extractColumnFromGormTag(tag string) string {
|
|
||||||
parts := strings.Split(tag, ";")
|
|
||||||
for _, part := range parts {
|
|
||||||
part = strings.TrimSpace(part)
|
|
||||||
if colName, found := strings.CutPrefix(part, "column:"); found {
|
|
||||||
return colName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractColumnFromBunTag extracts the column name from a bun tag
|
|
||||||
// Example: "id,pk" -> "id"
|
|
||||||
// Example: ",pk" -> "" (will fall back to json tag)
|
|
||||||
func extractColumnFromBunTag(tag string) string {
|
|
||||||
parts := strings.Split(tag, ",")
|
|
||||||
if len(parts) > 0 && parts[0] != "" {
|
|
||||||
return parts[0]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|||||||
@ -72,11 +72,12 @@ type Response struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Metadata struct {
|
type Metadata struct {
|
||||||
Total int64 `json:"total"`
|
Total int64 `json:"total"`
|
||||||
Count int64 `json:"count"`
|
Count int64 `json:"count"`
|
||||||
Filtered int64 `json:"filtered"`
|
Filtered int64 `json:"filtered"`
|
||||||
Limit int `json:"limit"`
|
Limit int `json:"limit"`
|
||||||
Offset int `json:"offset"`
|
Offset int `json:"offset"`
|
||||||
|
RowNumber *int64 `json:"row_number,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIError struct {
|
type APIError struct {
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime/debug"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@ -70,3 +71,35 @@ func Debug(template string, args ...interface{}) {
|
|||||||
}
|
}
|
||||||
Logger.Debugw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
Logger.Debugw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CatchPanic - Handle panic
|
||||||
|
func CatchPanicCallback(location string, cb func(err any)) {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
//callstack := debug.Stack()
|
||||||
|
|
||||||
|
if Logger != nil {
|
||||||
|
Error("Panic in %s : %v", location, err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||||
|
debug.PrintStack()
|
||||||
|
}
|
||||||
|
|
||||||
|
//push to sentry
|
||||||
|
// hub := sentry.CurrentHub()
|
||||||
|
// if hub != nil {
|
||||||
|
// evtID := hub.Recover(err)
|
||||||
|
// if evtID != nil {
|
||||||
|
// sentry.Flush(time.Second * 2)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
if cb != nil {
|
||||||
|
cb(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CatchPanic - Handle panic
|
||||||
|
func CatchPanic(location string) {
|
||||||
|
CatchPanicCallback(location, nil)
|
||||||
|
}
|
||||||
|
|||||||
100
pkg/reflection/generic_model.go
Normal file
100
pkg/reflection/generic_model.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package reflection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelFieldDetail struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
DataType string `json:"datatype"`
|
||||||
|
SQLName string `json:"sqlname"`
|
||||||
|
SQLDataType string `json:"sqldatatype"`
|
||||||
|
SQLKey string `json:"sqlkey"`
|
||||||
|
Nullable bool `json:"nullable"`
|
||||||
|
FieldValue reflect.Value `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||||
|
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Error("Panic in GetModelColumnDetail : %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var lst []ModelFieldDetail
|
||||||
|
lst = make([]ModelFieldDetail, 0)
|
||||||
|
|
||||||
|
if !record.IsValid() {
|
||||||
|
return lst
|
||||||
|
}
|
||||||
|
if record.Kind() == reflect.Pointer || record.Kind() == reflect.Interface {
|
||||||
|
record = record.Elem()
|
||||||
|
}
|
||||||
|
if record.Kind() != reflect.Struct {
|
||||||
|
return lst
|
||||||
|
}
|
||||||
|
modeltype := record.Type()
|
||||||
|
|
||||||
|
for i := 0; i < modeltype.NumField(); i++ {
|
||||||
|
fieldtype := modeltype.Field(i)
|
||||||
|
gormdetail := fieldtype.Tag.Get("gorm")
|
||||||
|
gormdetail = strings.Trim(gormdetail, " ")
|
||||||
|
fielddetail := ModelFieldDetail{}
|
||||||
|
fielddetail.FieldValue = record.Field(i)
|
||||||
|
fielddetail.Name = fieldtype.Name
|
||||||
|
fielddetail.DataType = fieldtype.Type.Name()
|
||||||
|
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||||
|
fielddetail.SQLDataType = fnFindKeyVal(gormdetail, "type:")
|
||||||
|
if strings.Index(strings.ToLower(gormdetail), "identity") > 0 ||
|
||||||
|
strings.Index(strings.ToLower(gormdetail), "primary_key") > 0 {
|
||||||
|
fielddetail.SQLKey = "primary_key"
|
||||||
|
} else if strings.Contains(strings.ToLower(gormdetail), "unique") {
|
||||||
|
fielddetail.SQLKey = "unique"
|
||||||
|
} else if strings.Contains(strings.ToLower(gormdetail), "uniqueindex") {
|
||||||
|
fielddetail.SQLKey = "uniqueindex"
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(strings.ToLower(gormdetail), "nullable") {
|
||||||
|
fielddetail.Nullable = true
|
||||||
|
} else if strings.Contains(strings.ToLower(gormdetail), "null") {
|
||||||
|
fielddetail.Nullable = true
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(gormdetail), "not null") {
|
||||||
|
fielddetail.Nullable = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(strings.ToLower(gormdetail), "foreignkey:") {
|
||||||
|
fielddetail.SQLKey = "foreign_key"
|
||||||
|
ik := strings.Index(strings.ToLower(gormdetail), "foreignkey:")
|
||||||
|
ie := strings.Index(gormdetail[ik:], ";")
|
||||||
|
if ie > ik && ik > 0 {
|
||||||
|
fielddetail.SQLName = strings.ToLower(gormdetail)[ik+11 : ik+ie]
|
||||||
|
//fmt.Printf("\r\nforeignkey: %v", fielddetail)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
//";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||||
|
|
||||||
|
lst = append(lst, fielddetail)
|
||||||
|
|
||||||
|
}
|
||||||
|
return lst
|
||||||
|
}
|
||||||
|
|
||||||
|
func fnFindKeyVal(src, key string) string {
|
||||||
|
icolStart := strings.Index(strings.ToLower(src), strings.ToLower(key))
|
||||||
|
val := ""
|
||||||
|
if icolStart >= 0 {
|
||||||
|
val = src[icolStart+len(key):]
|
||||||
|
icolend := strings.Index(val, ";")
|
||||||
|
if icolend > 0 {
|
||||||
|
val = val[:icolend]
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
162
pkg/reflection/model_utils.go
Normal file
162
pkg/reflection/model_utils.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package reflection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||||
|
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||||
|
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||||
|
func GetPrimaryKeyName(model any) string {
|
||||||
|
// Check if model implements PrimaryKeyNameProvider
|
||||||
|
if provider, ok := model.(common.PrimaryKeyNameProvider); ok {
|
||||||
|
return provider.GetIDName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try Bun tag first
|
||||||
|
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
|
||||||
|
return pkName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to GORM tag
|
||||||
|
return getPrimaryKeyFromReflection(model, "gorm")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelColumns extracts all column names from a model using reflection
|
||||||
|
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||||
|
func GetModelColumns(model any) []string {
|
||||||
|
var columns []string
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Get column name using the same logic as primary key extraction
|
||||||
|
columnName := getColumnNameFromField(field)
|
||||||
|
|
||||||
|
if columnName != "" {
|
||||||
|
columns = append(columns, columnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// getColumnNameFromField extracts the column name from a struct field
|
||||||
|
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||||
|
func getColumnNameFromField(field reflect.StructField) string {
|
||||||
|
// Try bun tag first
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && bunTag != "-" {
|
||||||
|
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try gorm tag
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if gormTag != "" && gormTag != "-" {
|
||||||
|
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to json tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" && jsonTag != "-" {
|
||||||
|
// Extract just the field name before any options
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last resort: use field name in lowercase
|
||||||
|
return strings.ToLower(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||||
|
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||||
|
val := reflect.ValueOf(model)
|
||||||
|
if val.Kind() == reflect.Pointer {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if val.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := val.Type()
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
switch ormType {
|
||||||
|
case "gorm":
|
||||||
|
// Check for gorm tag with primaryKey
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "primaryKey") {
|
||||||
|
// Try to extract column name from gorm tag
|
||||||
|
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
// Fall back to json tag
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||||
|
return strings.Split(jsonTag, ",")[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "bun":
|
||||||
|
// Check for bun tag with pk flag
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.Contains(bunTag, "pk") {
|
||||||
|
// Extract column name from bun tag
|
||||||
|
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
// Fall back to json tag
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||||
|
return strings.Split(jsonTag, ",")[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractColumnFromGormTag extracts the column name from a gorm tag
|
||||||
|
// Example: "column:id;primaryKey" -> "id"
|
||||||
|
func ExtractColumnFromGormTag(tag string) string {
|
||||||
|
parts := strings.Split(tag, ";")
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if colName, found := strings.CutPrefix(part, "column:"); found {
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractColumnFromBunTag extracts the column name from a bun tag
|
||||||
|
// Example: "id,pk" -> "id"
|
||||||
|
// Example: ",pk" -> "" (will fall back to json tag)
|
||||||
|
func ExtractColumnFromBunTag(tag string) string {
|
||||||
|
parts := strings.Split(tag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
package database
|
package reflection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@ -137,9 +137,9 @@ func TestExtractColumnFromGormTag(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := extractColumnFromGormTag(tt.tag)
|
result := ExtractColumnFromGormTag(tt.tag)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("extractColumnFromGormTag() = %v, want %v", result, tt.expected)
|
t.Errorf("ExtractColumnFromGormTag() = %v, want %v", result, tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -170,9 +170,9 @@ func TestExtractColumnFromBunTag(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := extractColumnFromBunTag(tt.tag)
|
result := ExtractColumnFromBunTag(tt.tag)
|
||||||
if result != tt.expected {
|
if result != tt.expected {
|
||||||
t.Errorf("extractColumnFromBunTag() = %v, want %v", result, tt.expected)
|
t.Errorf("ExtractColumnFromBunTag() = %v, want %v", result, tt.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CursorDirection defines pagination direction
|
// CursorDirection defines pagination direction
|
||||||
@ -85,7 +86,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
field, prefix, tableName, modelColumns,
|
field, prefix, tableName, modelColumns,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("WARN: Skipping invalid sort column %q: %v\n", col, err)
|
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,8 +10,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler handles API requests using database and model abstractions
|
// Handler handles API requests using database and model abstractions
|
||||||
@ -343,10 +343,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
logger.Debug("Applying cursor pagination")
|
logger.Debug("Applying cursor pagination")
|
||||||
|
|
||||||
// Get primary key name
|
// Get primary key name
|
||||||
pkName := database.GetPrimaryKeyName(model)
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
// Extract model columns for validation using the generic database function
|
// Extract model columns for validation using the generic database function
|
||||||
modelColumns := database.GetModelColumns(model)
|
modelColumns := reflection.GetModelColumns(model)
|
||||||
|
|
||||||
// Build expand joins map (if needed in future)
|
// Build expand joins map (if needed in future)
|
||||||
var expandJoins map[string]string
|
var expandJoins map[string]string
|
||||||
@ -371,6 +371,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeScan hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := hookCtx.Query.(common.SelectQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
// Execute query - modelPtr was already created earlier
|
// Execute query - modelPtr was already created earlier
|
||||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||||
logger.Error("Error executing query: %v", err)
|
logger.Error("Error executing query: %v", err)
|
||||||
@ -387,6 +400,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
offset = *options.Offset
|
offset = *options.Offset
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set row numbers on each record if the model has a RowNumber field
|
||||||
|
h.setRowNumbersOnRecords(modelPtr, offset)
|
||||||
|
|
||||||
metadata := &common.Metadata{
|
metadata := &common.Metadata{
|
||||||
Total: int64(total),
|
Total: int64(total),
|
||||||
Count: int64(common.Len(modelPtr)),
|
Count: int64(common.Len(modelPtr)),
|
||||||
@ -395,6 +411,23 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
Offset: offset,
|
Offset: offset,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch row number for a specific record if requested
|
||||||
|
if options.RequestOptions.FetchRowNumber != nil && *options.RequestOptions.FetchRowNumber != "" {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
pkValue := *options.RequestOptions.FetchRowNumber
|
||||||
|
|
||||||
|
logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)
|
||||||
|
|
||||||
|
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to fetch row number: %v", err)
|
||||||
|
// Don't fail the entire request, just log the warning
|
||||||
|
} else {
|
||||||
|
metadata.RowNumber = &rowNum
|
||||||
|
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Execute AfterRead hooks
|
// Execute AfterRead hooks
|
||||||
hookCtx.Result = modelPtr
|
hookCtx.Result = modelPtr
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
@ -466,6 +499,29 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewInsert().Model(modelValue).Table(tableName)
|
query := tx.NewInsert().Model(modelValue).Table(tableName)
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
batchHookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
TableName: tableName,
|
||||||
|
Model: model,
|
||||||
|
Options: options,
|
||||||
|
Data: modelValue,
|
||||||
|
Writer: w,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeScan, batchHookCtx); err != nil {
|
||||||
|
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := batchHookCtx.Query.(common.InsertQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to insert record: %w", err)
|
return fmt.Errorf("failed to insert record: %w", err)
|
||||||
}
|
}
|
||||||
@ -508,6 +564,21 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewInsert().Model(modelValue).Table(tableName)
|
query := h.db.NewInsert().Model(modelValue).Table(tableName)
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
hookCtx.Data = modelValue
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeScan hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := hookCtx.Query.(common.InsertQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := query.Exec(ctx); err != nil {
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
logger.Error("Error creating record: %v", err)
|
logger.Error("Error creating record: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
h.sendError(w, http.StatusInternalServerError, "create_error", "Error creating record", err)
|
||||||
@ -593,6 +664,19 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeScan hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := hookCtx.Query.(common.UpdateQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error updating record: %v", err)
|
logger.Error("Error updating record: %v", err)
|
||||||
@ -658,6 +742,19 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
query = query.Where("id = ?", id)
|
query = query.Where("id = ?", id)
|
||||||
|
|
||||||
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeScan hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified query from hook context
|
||||||
|
if modifiedQuery, ok := hookCtx.Query.(common.DeleteQuery); ok {
|
||||||
|
query = modifiedQuery
|
||||||
|
}
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error deleting record: %v", err)
|
logger.Error("Error deleting record: %v", err)
|
||||||
@ -999,6 +1096,191 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
|
|||||||
w.WriteJSON(response)
|
w.WriteJSON(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FetchRowNumber calculates the row number of a specific record based on sorting and filtering
|
||||||
|
// Returns the 1-based row number of the record with the given primary key value
|
||||||
|
func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName string, pkValue string, options ExtendedRequestOptions, model any) (int64, error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Error("Panic during FetchRowNumber: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Build the sort order SQL
|
||||||
|
sortSQL := ""
|
||||||
|
if len(options.Sort) > 0 {
|
||||||
|
sortParts := make([]string, 0, len(options.Sort))
|
||||||
|
for _, sort := range options.Sort {
|
||||||
|
direction := "ASC"
|
||||||
|
if strings.ToLower(sort.Direction) == "desc" {
|
||||||
|
direction = "DESC"
|
||||||
|
}
|
||||||
|
sortParts = append(sortParts, fmt.Sprintf("%s.%s %s", tableName, sort.Column, direction))
|
||||||
|
}
|
||||||
|
sortSQL = strings.Join(sortParts, ", ")
|
||||||
|
} else {
|
||||||
|
// Default sort by primary key
|
||||||
|
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build WHERE clauses from filters
|
||||||
|
whereClauses := make([]string, 0)
|
||||||
|
for i := range options.Filters {
|
||||||
|
filter := &options.Filters[i]
|
||||||
|
whereClause := h.buildFilterSQL(filter, tableName)
|
||||||
|
if whereClause != "" {
|
||||||
|
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine WHERE clauses
|
||||||
|
whereSQL := ""
|
||||||
|
if len(whereClauses) > 0 {
|
||||||
|
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add custom SQL WHERE if provided
|
||||||
|
if options.CustomSQLWhere != "" {
|
||||||
|
if whereSQL == "" {
|
||||||
|
whereSQL = "WHERE " + options.CustomSQLWhere
|
||||||
|
} else {
|
||||||
|
whereSQL += " AND (" + options.CustomSQLWhere + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build JOIN clauses from Expand options
|
||||||
|
joinSQL := ""
|
||||||
|
if len(options.Expand) > 0 {
|
||||||
|
joinParts := make([]string, 0, len(options.Expand))
|
||||||
|
for _, expand := range options.Expand {
|
||||||
|
// Note: This is a simplified join - in production you'd need proper FK mapping
|
||||||
|
joinParts = append(joinParts, fmt.Sprintf("LEFT JOIN %s ON %s.%s_id = %s.id",
|
||||||
|
expand.Relation, tableName, expand.Relation, expand.Relation))
|
||||||
|
}
|
||||||
|
joinSQL = strings.Join(joinParts, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the final query with parameterized PK value
|
||||||
|
queryStr := fmt.Sprintf(`
|
||||||
|
SELECT search.rn
|
||||||
|
FROM (
|
||||||
|
SELECT %[1]s.%[2]s,
|
||||||
|
ROW_NUMBER() OVER(ORDER BY %[3]s) AS rn
|
||||||
|
FROM %[1]s
|
||||||
|
%[5]s
|
||||||
|
%[4]s
|
||||||
|
) search
|
||||||
|
WHERE search.%[2]s = ?
|
||||||
|
`,
|
||||||
|
tableName, // [1] - table name
|
||||||
|
pkName, // [2] - primary key column name
|
||||||
|
sortSQL, // [3] - sort order SQL
|
||||||
|
whereSQL, // [4] - WHERE clause
|
||||||
|
joinSQL, // [5] - JOIN clauses
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.Debug("FetchRowNumber query: %s, pkValue: %s", queryStr, pkValue)
|
||||||
|
|
||||||
|
// Execute the raw query with parameterized PK value
|
||||||
|
var result []struct {
|
||||||
|
RN int64 `bun:"rn"`
|
||||||
|
}
|
||||||
|
err := h.db.Query(ctx, &result, queryStr, pkValue)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to fetch row number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) == 0 {
|
||||||
|
return 0, fmt.Errorf("no row found for primary key %s", pkValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result[0].RN, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildFilterSQL converts a filter to SQL WHERE clause string
|
||||||
|
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
|
||||||
|
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
|
||||||
|
switch strings.ToLower(filter.Operator) {
|
||||||
|
case "eq", "equals":
|
||||||
|
return fmt.Sprintf("%s = '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "neq", "not_equals", "ne":
|
||||||
|
return fmt.Sprintf("%s != '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "gt", "greater_than":
|
||||||
|
return fmt.Sprintf("%s > '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "gte", "greater_than_equals", "ge":
|
||||||
|
return fmt.Sprintf("%s >= '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "lt", "less_than":
|
||||||
|
return fmt.Sprintf("%s < '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "lte", "less_than_equals", "le":
|
||||||
|
return fmt.Sprintf("%s <= '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "like":
|
||||||
|
return fmt.Sprintf("%s LIKE '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "ilike":
|
||||||
|
return fmt.Sprintf("%s ILIKE '%v'", qualifiedColumn, filter.Value)
|
||||||
|
case "in":
|
||||||
|
if values, ok := filter.Value.([]any); ok {
|
||||||
|
valueStrs := make([]string, len(values))
|
||||||
|
for i, v := range values {
|
||||||
|
valueStrs[i] = fmt.Sprintf("'%v'", v)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s IN (%s)", qualifiedColumn, strings.Join(valueStrs, ", "))
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
case "is_null", "isnull":
|
||||||
|
return fmt.Sprintf("(%s IS NULL OR %s = '')", qualifiedColumn, qualifiedColumn)
|
||||||
|
case "is_not_null", "isnotnull":
|
||||||
|
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", qualifiedColumn, qualifiedColumn)
|
||||||
|
default:
|
||||||
|
logger.Warn("Unknown filter operator in buildFilterSQL: %s", filter.Operator)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||||
|
// The row number is calculated as offset + index + 1 (1-based)
|
||||||
|
func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
||||||
|
// Get the reflect value of the records
|
||||||
|
recordsValue := reflect.ValueOf(records)
|
||||||
|
if recordsValue.Kind() == reflect.Ptr {
|
||||||
|
recordsValue = recordsValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a slice
|
||||||
|
if recordsValue.Kind() != reflect.Slice {
|
||||||
|
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate through each record
|
||||||
|
for i := 0; i < recordsValue.Len(); i++ {
|
||||||
|
record := recordsValue.Index(i)
|
||||||
|
|
||||||
|
// Dereference if it's a pointer
|
||||||
|
if record.Kind() == reflect.Ptr {
|
||||||
|
if record.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
record = record.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a struct
|
||||||
|
if record.Kind() != reflect.Struct {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to find and set the RowNumber field
|
||||||
|
rowNumberField := record.FieldByName("RowNumber")
|
||||||
|
if rowNumberField.IsValid() && rowNumberField.CanSet() {
|
||||||
|
// Check if the field is of type int64
|
||||||
|
if rowNumberField.Kind() == reflect.Int64 {
|
||||||
|
rowNum := int64(offset + i + 1)
|
||||||
|
rowNumberField.SetInt(rowNum)
|
||||||
|
logger.Debug("Set RowNumber=%d on record %d", rowNum, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
||||||
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
||||||
filtered := options
|
filtered := options
|
||||||
|
|||||||
@ -27,6 +27,9 @@ const (
|
|||||||
// Delete operation hooks
|
// Delete operation hooks
|
||||||
BeforeDelete HookType = "before_delete"
|
BeforeDelete HookType = "before_delete"
|
||||||
AfterDelete HookType = "after_delete"
|
AfterDelete HookType = "after_delete"
|
||||||
|
|
||||||
|
// Scan/Execute operation hooks
|
||||||
|
BeforeScan HookType = "before_scan"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HookContext contains all the data available to a hook
|
// HookContext contains all the data available to a hook
|
||||||
@ -46,6 +49,10 @@ type HookContext struct {
|
|||||||
Error error // For after hooks
|
Error error // For after hooks
|
||||||
QueryFilter string // For read operations
|
QueryFilter string // For read operations
|
||||||
|
|
||||||
|
// Query chain - allows hooks to modify the query before execution
|
||||||
|
// Can be SelectQuery, InsertQuery, UpdateQuery, or DeleteQuery
|
||||||
|
Query interface{}
|
||||||
|
|
||||||
// Response writer - allows hooks to modify response
|
// Response writer - allows hooks to modify response
|
||||||
Writer common.ResponseWriter
|
Writer common.ResponseWriter
|
||||||
}
|
}
|
||||||
|
|||||||
203
pkg/restheadspec/rownumber_test.go
Normal file
203
pkg/restheadspec/rownumber_test.go
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestModel represents a typical model with RowNumber field (like DBAdhocBuffer)
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
|
Name string `json:"name" bun:"name"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRowNumbersOnRecords(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
records any
|
||||||
|
offset int
|
||||||
|
expected []int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Set row numbers on slice of pointers",
|
||||||
|
records: []*TestModel{
|
||||||
|
{ID: 1, Name: "First"},
|
||||||
|
{ID: 2, Name: "Second"},
|
||||||
|
{ID: 3, Name: "Third"},
|
||||||
|
},
|
||||||
|
offset: 0,
|
||||||
|
expected: []int64{1, 2, 3},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Set row numbers with offset",
|
||||||
|
records: []*TestModel{
|
||||||
|
{ID: 11, Name: "Eleventh"},
|
||||||
|
{ID: 12, Name: "Twelfth"},
|
||||||
|
},
|
||||||
|
offset: 10,
|
||||||
|
expected: []int64{11, 12},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Set row numbers on slice of values",
|
||||||
|
records: []TestModel{
|
||||||
|
{ID: 1, Name: "First"},
|
||||||
|
{ID: 2, Name: "Second"},
|
||||||
|
},
|
||||||
|
offset: 5,
|
||||||
|
expected: []int64{6, 7},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
handler.setRowNumbersOnRecords(tt.records, tt.offset)
|
||||||
|
|
||||||
|
// Verify row numbers were set correctly
|
||||||
|
switch records := tt.records.(type) {
|
||||||
|
case []*TestModel:
|
||||||
|
assert.Equal(t, len(tt.expected), len(records))
|
||||||
|
for i, record := range records {
|
||||||
|
assert.Equal(t, tt.expected[i], record.RowNumber,
|
||||||
|
"Record %d should have RowNumber=%d", i, tt.expected[i])
|
||||||
|
}
|
||||||
|
case []TestModel:
|
||||||
|
assert.Equal(t, len(tt.expected), len(records))
|
||||||
|
for i, record := range records {
|
||||||
|
assert.Equal(t, tt.expected[i], record.RowNumber,
|
||||||
|
"Record %d should have RowNumber=%d", i, tt.expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRowNumbersOnRecords_NoRowNumberField(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Model without RowNumber field
|
||||||
|
type SimpleModel struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
records := []*SimpleModel{
|
||||||
|
{ID: 1, Name: "First"},
|
||||||
|
{ID: 2, Name: "Second"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not panic when model doesn't have RowNumber field
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
handler.setRowNumbersOnRecords(records, 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRowNumbersOnRecords_NilRecords(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
records := []*TestModel{
|
||||||
|
{ID: 1, Name: "First"},
|
||||||
|
nil, // Nil record
|
||||||
|
{ID: 3, Name: "Third"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not panic with nil records
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
handler.setRowNumbersOnRecords(records, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Verify non-nil records were set
|
||||||
|
assert.Equal(t, int64(1), records[0].RowNumber)
|
||||||
|
assert.Equal(t, int64(3), records[2].RowNumber)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBAdhocBuffer simulates the actual DBAdhocBuffer from db package
|
||||||
|
type DBAdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:"-"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelWithEmbeddedBuffer simulates a real model like ModelPublicConsultant
|
||||||
|
type ModelWithEmbeddedBuffer struct {
|
||||||
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
|
Name string `json:"name" bun:"name"`
|
||||||
|
|
||||||
|
DBAdhocBuffer `json:",omitempty"` // Embedded struct containing RowNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRowNumbersOnRecords_EmbeddedBuffer(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Test with embedded DBAdhocBuffer (like real models)
|
||||||
|
records := []*ModelWithEmbeddedBuffer{
|
||||||
|
{ID: 1, Name: "First"},
|
||||||
|
{ID: 2, Name: "Second"},
|
||||||
|
{ID: 3, Name: "Third"},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.setRowNumbersOnRecords(records, 10)
|
||||||
|
|
||||||
|
// Verify row numbers were set on embedded field
|
||||||
|
assert.Equal(t, int64(11), records[0].RowNumber, "First record should have RowNumber=11")
|
||||||
|
assert.Equal(t, int64(12), records[1].RowNumber, "Second record should have RowNumber=12")
|
||||||
|
assert.Equal(t, int64(13), records[2].RowNumber, "Third record should have RowNumber=13")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRowNumbersOnRecords_EmbeddedBuffer_SliceOfValues(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Test with slice of values (not pointers)
|
||||||
|
records := []ModelWithEmbeddedBuffer{
|
||||||
|
{ID: 1, Name: "First"},
|
||||||
|
{ID: 2, Name: "Second"},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.setRowNumbersOnRecords(records, 0)
|
||||||
|
|
||||||
|
// Verify row numbers were set on embedded field
|
||||||
|
assert.Equal(t, int64(1), records[0].RowNumber, "First record should have RowNumber=1")
|
||||||
|
assert.Equal(t, int64(2), records[1].RowNumber, "Second record should have RowNumber=2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate the exact structure from user's code
|
||||||
|
type MockDBAdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:"-"`
|
||||||
|
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:"-"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
||||||
|
Request string `json:"_request,omitempty" gorm:"-" bun:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact structure like ModelPublicConsultant
|
||||||
|
type ModelPublicConsultant struct {
|
||||||
|
Consultant string `json:"consultant" bun:"consultant,type:citext,pk"`
|
||||||
|
Ridconsultant int32 `json:"rid_consultant" bun:"rid_consultant,type:integer,pk"`
|
||||||
|
Updatecnt int64 `json:"updatecnt" bun:"updatecnt,type:integer,default:0"`
|
||||||
|
|
||||||
|
MockDBAdhocBuffer `json:",omitempty"` // Embedded - RowNumber is here!
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRowNumbersOnRecords_RealModelStructure(t *testing.T) {
|
||||||
|
handler := &Handler{}
|
||||||
|
|
||||||
|
// Test with exact structure from user's ModelPublicConsultant
|
||||||
|
records := []*ModelPublicConsultant{
|
||||||
|
{Consultant: "John Doe", Ridconsultant: 1, Updatecnt: 0},
|
||||||
|
{Consultant: "Jane Smith", Ridconsultant: 2, Updatecnt: 0},
|
||||||
|
{Consultant: "Bob Johnson", Ridconsultant: 3, Updatecnt: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.setRowNumbersOnRecords(records, 100)
|
||||||
|
|
||||||
|
// Verify row numbers were set correctly in the embedded DBAdhocBuffer
|
||||||
|
assert.Equal(t, int64(101), records[0].RowNumber, "First consultant should have RowNumber=101")
|
||||||
|
assert.Equal(t, int64(102), records[1].RowNumber, "Second consultant should have RowNumber=102")
|
||||||
|
assert.Equal(t, int64(103), records[2].RowNumber, "Third consultant should have RowNumber=103")
|
||||||
|
|
||||||
|
t.Logf("✓ RowNumber correctly set in embedded MockDBAdhocBuffer")
|
||||||
|
t.Logf(" Record 0: Consultant=%s, RowNumber=%d", records[0].Consultant, records[0].RowNumber)
|
||||||
|
t.Logf(" Record 1: Consultant=%s, RowNumber=%d", records[1].Consultant, records[1].RowNumber)
|
||||||
|
t.Logf(" Record 2: Consultant=%s, RowNumber=%d", records[2].Consultant, records[2].RowNumber)
|
||||||
|
}
|
||||||
662
pkg/security/CALLBACKS_GUIDE.md
Normal file
662
pkg/security/CALLBACKS_GUIDE.md
Normal file
@ -0,0 +1,662 @@
|
|||||||
|
# Security Provider Callbacks Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The ResolveSpec security provider uses a **callback-based architecture** that requires you to implement three functions:
|
||||||
|
|
||||||
|
1. **AuthenticateCallback** - Extract user credentials from HTTP requests
|
||||||
|
2. **LoadColumnSecurityCallback** - Load column security rules for masking/hiding
|
||||||
|
3. **LoadRowSecurityCallback** - Load row security filters (WHERE clauses)
|
||||||
|
|
||||||
|
This design allows you to integrate the security provider with **any** authentication system and database schema.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Why Callbacks?
|
||||||
|
|
||||||
|
The callback-based design provides:
|
||||||
|
|
||||||
|
✅ **Flexibility** - Works with any auth system (JWT, session, OAuth, custom)
|
||||||
|
✅ **Database Agnostic** - No assumptions about your security table schema
|
||||||
|
✅ **Testability** - Easy to mock for unit tests
|
||||||
|
✅ **Extensibility** - Add custom logic without modifying core code
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Step 1: Implement the Three Callbacks
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 1. Authentication: Extract user from request
|
||||||
|
func myAuthFunction(r *http.Request) (userID int, roles string, err error) {
|
||||||
|
// Your auth logic here (JWT, session, header, etc.)
|
||||||
|
token := r.Header.Get("Authorization")
|
||||||
|
userID, roles, err = validateToken(token)
|
||||||
|
return userID, roles, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Column Security: Load column masking rules
|
||||||
|
func myLoadColumnSecurity(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
|
||||||
|
// Your database query or config lookup here
|
||||||
|
return loadColumnRulesFromDatabase(userID, schema, tablename)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Row Security: Load row filtering rules
|
||||||
|
func myLoadRowSecurity(userID int, schema, tablename string) (security.RowSecurity, error) {
|
||||||
|
// Your database query or config lookup here
|
||||||
|
return loadRowRulesFromDatabase(userID, schema, tablename)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Configure the Callbacks
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
db := setupDatabase()
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Configure callbacks BEFORE SetupSecurityProvider
|
||||||
|
security.GlobalSecurity.AuthenticateCallback = myAuthFunction
|
||||||
|
security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurity
|
||||||
|
security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurity
|
||||||
|
|
||||||
|
// Setup security provider (validates callbacks are set)
|
||||||
|
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
|
||||||
|
log.Fatal(err) // Fails if callbacks not configured
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply middleware
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler)
|
||||||
|
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||||
|
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Callback 1: AuthenticateCallback
|
||||||
|
|
||||||
|
### Function Signature
|
||||||
|
|
||||||
|
```go
|
||||||
|
func(r *http.Request) (userID int, roles string, err error)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
- `r *http.Request` - The incoming HTTP request
|
||||||
|
|
||||||
|
### Returns
|
||||||
|
- `userID int` - The authenticated user's ID
|
||||||
|
- `roles string` - User's roles (comma-separated, e.g., "admin,manager")
|
||||||
|
- `err error` - Return error to reject the request (HTTP 401)
|
||||||
|
|
||||||
|
### Example Implementations
|
||||||
|
|
||||||
|
#### Simple Header-Based Auth
|
||||||
|
```go
|
||||||
|
func authenticateFromHeader(r *http.Request) (int, string, error) {
|
||||||
|
userIDStr := r.Header.Get("X-User-ID")
|
||||||
|
if userIDStr == "" {
|
||||||
|
return 0, "", fmt.Errorf("X-User-ID header required")
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := strconv.Atoi(userIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", fmt.Errorf("invalid user ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
roles := r.Header.Get("X-User-Roles") // Optional
|
||||||
|
return userID, roles, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### JWT Token Auth
|
||||||
|
```go
|
||||||
|
import "github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
func authenticateFromJWT(r *http.Request) (int, string, error) {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(os.Getenv("JWT_SECRET")), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil || !token.Valid {
|
||||||
|
return 0, "", fmt.Errorf("invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
userID := int(claims["user_id"].(float64))
|
||||||
|
roles := claims["roles"].(string)
|
||||||
|
|
||||||
|
return userID, roles, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Session Cookie Auth
|
||||||
|
```go
|
||||||
|
func authenticateFromSession(r *http.Request) (int, string, error) {
|
||||||
|
cookie, err := r.Cookie("session_id")
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", fmt.Errorf("no session cookie")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := sessionStore.Get(cookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", fmt.Errorf("invalid session")
|
||||||
|
}
|
||||||
|
|
||||||
|
return session.UserID, session.Roles, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Callback 2: LoadColumnSecurityCallback
|
||||||
|
|
||||||
|
### Function Signature
|
||||||
|
|
||||||
|
```go
|
||||||
|
func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
- `pUserID int` - The authenticated user's ID
|
||||||
|
- `pSchema string` - Database schema (e.g., "public")
|
||||||
|
- `pTablename string` - Table name (e.g., "employees")
|
||||||
|
|
||||||
|
### Returns
|
||||||
|
- `[]ColumnSecurity` - List of column security rules
|
||||||
|
- `error` - Return error if loading fails
|
||||||
|
|
||||||
|
### ColumnSecurity Structure
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ColumnSecurity struct {
|
||||||
|
Schema string // "public"
|
||||||
|
Tablename string // "employees"
|
||||||
|
Path []string // ["ssn"] or ["address", "street"]
|
||||||
|
Accesstype string // "mask" or "hide"
|
||||||
|
|
||||||
|
// Masking configuration (for Accesstype = "mask")
|
||||||
|
MaskStart int // Mask first N characters
|
||||||
|
MaskEnd int // Mask last N characters
|
||||||
|
MaskInvert bool // true = mask middle, false = mask edges
|
||||||
|
MaskChar string // Character to use for masking (default "*")
|
||||||
|
|
||||||
|
// Optional fields
|
||||||
|
ExtraFilters map[string]string
|
||||||
|
Control string
|
||||||
|
ID int
|
||||||
|
UserID int
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example Implementations
|
||||||
|
|
||||||
|
#### Load from Database
|
||||||
|
```go
|
||||||
|
func loadColumnSecurityFromDB(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
|
||||||
|
var rules []security.ColumnSecurity
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT control, accesstype, jsonvalue
|
||||||
|
FROM core.secacces
|
||||||
|
WHERE rid_hub IN (
|
||||||
|
SELECT rid_hub_parent FROM core.hub_link
|
||||||
|
WHERE rid_hub_child = ? AND parent_hubtype = 'secgroup'
|
||||||
|
)
|
||||||
|
AND control ILIKE ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := db.Query(query, userID, fmt.Sprintf("%s.%s%%", schema, tablename))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var control, accesstype, jsonValue string
|
||||||
|
rows.Scan(&control, &accesstype, &jsonValue)
|
||||||
|
|
||||||
|
// Parse control: "schema.table.column"
|
||||||
|
parts := strings.Split(control, ".")
|
||||||
|
if len(parts) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := security.ColumnSecurity{
|
||||||
|
Schema: schema,
|
||||||
|
Tablename: tablename,
|
||||||
|
Path: parts[2:],
|
||||||
|
Accesstype: accesstype,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON configuration
|
||||||
|
var config map[string]interface{}
|
||||||
|
json.Unmarshal([]byte(jsonValue), &config)
|
||||||
|
if start, ok := config["start"].(float64); ok {
|
||||||
|
rule.MaskStart = int(start)
|
||||||
|
}
|
||||||
|
if end, ok := config["end"].(float64); ok {
|
||||||
|
rule.MaskEnd = int(end)
|
||||||
|
}
|
||||||
|
if char, ok := config["char"].(string); ok {
|
||||||
|
rule.MaskChar = char
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = append(rules, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Load from Static Config
|
||||||
|
```go
|
||||||
|
func loadColumnSecurityFromConfig(userID int, schema, tablename string) ([]security.ColumnSecurity, error) {
|
||||||
|
// Define security rules in code
|
||||||
|
allRules := map[string][]security.ColumnSecurity{
|
||||||
|
"public.employees": {
|
||||||
|
{
|
||||||
|
Schema: "public",
|
||||||
|
Tablename: "employees",
|
||||||
|
Path: []string{"ssn"},
|
||||||
|
Accesstype: "mask",
|
||||||
|
MaskStart: 5,
|
||||||
|
MaskChar: "*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Schema: "public",
|
||||||
|
Tablename: "employees",
|
||||||
|
Path: []string{"salary"},
|
||||||
|
Accesstype: "hide",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s.%s", schema, tablename)
|
||||||
|
rules, ok := allRules[key]
|
||||||
|
if !ok {
|
||||||
|
return []security.ColumnSecurity{}, nil // No rules
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Column Security Examples
|
||||||
|
|
||||||
|
**Mask SSN (show last 4 digits):**
|
||||||
|
```go
|
||||||
|
ColumnSecurity{
|
||||||
|
Path: []string{"ssn"},
|
||||||
|
Accesstype: "mask",
|
||||||
|
MaskStart: 5, // Mask first 5 characters
|
||||||
|
MaskEnd: 0, // Keep last 4 visible
|
||||||
|
MaskChar: "*",
|
||||||
|
}
|
||||||
|
// Result: "123-45-6789" → "*****6789"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Hide entire field:**
|
||||||
|
```go
|
||||||
|
ColumnSecurity{
|
||||||
|
Path: []string{"salary"},
|
||||||
|
Accesstype: "hide",
|
||||||
|
}
|
||||||
|
// Result: salary field returns 0 or empty
|
||||||
|
```
|
||||||
|
|
||||||
|
**Mask credit card (show last 4 digits):**
|
||||||
|
```go
|
||||||
|
ColumnSecurity{
|
||||||
|
Path: []string{"credit_card"},
|
||||||
|
Accesstype: "mask",
|
||||||
|
MaskStart: 12,
|
||||||
|
MaskChar: "*",
|
||||||
|
}
|
||||||
|
// Result: "1234-5678-9012-3456" → "************3456"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Callback 3: LoadRowSecurityCallback
|
||||||
|
|
||||||
|
### Function Signature
|
||||||
|
|
||||||
|
```go
|
||||||
|
func(pUserID int, pSchema, pTablename string) (RowSecurity, error)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
- `pUserID int` - The authenticated user's ID
|
||||||
|
- `pSchema string` - Database schema
|
||||||
|
- `pTablename string` - Table name
|
||||||
|
|
||||||
|
### Returns
|
||||||
|
- `RowSecurity` - Row security configuration
|
||||||
|
- `error` - Return error if loading fails
|
||||||
|
|
||||||
|
### RowSecurity Structure
|
||||||
|
|
||||||
|
```go
|
||||||
|
type RowSecurity struct {
|
||||||
|
Schema string // "public"
|
||||||
|
Tablename string // "orders"
|
||||||
|
UserID int // Current user ID
|
||||||
|
Template string // WHERE clause template (e.g., "user_id = {UserID}")
|
||||||
|
HasBlock bool // If true, block ALL access to this table
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Template Variables
|
||||||
|
|
||||||
|
You can use these placeholders in the `Template` string:
|
||||||
|
- `{UserID}` - Current user's ID
|
||||||
|
- `{PrimaryKeyName}` - Primary key column name
|
||||||
|
- `{TableName}` - Table name
|
||||||
|
- `{SchemaName}` - Schema name
|
||||||
|
|
||||||
|
### Example Implementations
|
||||||
|
|
||||||
|
#### Load from Database Function
|
||||||
|
```go
|
||||||
|
func loadRowSecurityFromDB(userID int, schema, tablename string) (security.RowSecurity, error) {
|
||||||
|
var record security.RowSecurity
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT p_template, p_block
|
||||||
|
FROM core.api_sec_rowtemplate(?, ?, ?)
|
||||||
|
`
|
||||||
|
|
||||||
|
row := db.QueryRow(query, schema, tablename, userID)
|
||||||
|
err := row.Scan(&record.Template, &record.HasBlock)
|
||||||
|
if err != nil {
|
||||||
|
return security.RowSecurity{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
record.Schema = schema
|
||||||
|
record.Tablename = tablename
|
||||||
|
record.UserID = userID
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Load from Static Config
|
||||||
|
```go
|
||||||
|
func loadRowSecurityFromConfig(userID int, schema, tablename string) (security.RowSecurity, error) {
|
||||||
|
key := fmt.Sprintf("%s.%s", schema, tablename)
|
||||||
|
|
||||||
|
// Define templates for each table
|
||||||
|
templates := map[string]string{
|
||||||
|
"public.orders": "user_id = {UserID}",
|
||||||
|
"public.documents": "user_id = {UserID} OR is_public = true",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define blocked tables
|
||||||
|
blocked := map[string]bool{
|
||||||
|
"public.admin_logs": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if blocked[key] {
|
||||||
|
return security.RowSecurity{
|
||||||
|
Schema: schema,
|
||||||
|
Tablename: tablename,
|
||||||
|
UserID: userID,
|
||||||
|
HasBlock: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
template, ok := templates[key]
|
||||||
|
if !ok {
|
||||||
|
// No row security - allow all rows
|
||||||
|
return security.RowSecurity{
|
||||||
|
Schema: schema,
|
||||||
|
Tablename: tablename,
|
||||||
|
UserID: userID,
|
||||||
|
Template: "",
|
||||||
|
HasBlock: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return security.RowSecurity{
|
||||||
|
Schema: schema,
|
||||||
|
Tablename: tablename,
|
||||||
|
UserID: userID,
|
||||||
|
Template: template,
|
||||||
|
HasBlock: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Row Security Examples
|
||||||
|
|
||||||
|
**Users see only their own records:**
|
||||||
|
```go
|
||||||
|
RowSecurity{
|
||||||
|
Template: "user_id = {UserID}",
|
||||||
|
}
|
||||||
|
// Query: SELECT * FROM orders WHERE user_id = 123
|
||||||
|
```
|
||||||
|
|
||||||
|
**Users see their records OR public records:**
|
||||||
|
```go
|
||||||
|
RowSecurity{
|
||||||
|
Template: "user_id = {UserID} OR is_public = true",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Complex filter with subquery:**
|
||||||
|
```go
|
||||||
|
RowSecurity{
|
||||||
|
Template: "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Block all access:**
|
||||||
|
```go
|
||||||
|
RowSecurity{
|
||||||
|
HasBlock: true,
|
||||||
|
}
|
||||||
|
// All queries to this table will be rejected
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Complete Integration Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
db := setupDatabase()
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
handler.RegisterModel("public", "orders", Order{})
|
||||||
|
|
||||||
|
// ===== CONFIGURE CALLBACKS =====
|
||||||
|
security.GlobalSecurity.AuthenticateCallback = authenticateUser
|
||||||
|
security.GlobalSecurity.LoadColumnSecurityCallback = loadColumnSec
|
||||||
|
security.GlobalSecurity.LoadRowSecurityCallback = loadRowSec
|
||||||
|
|
||||||
|
// ===== SETUP SECURITY =====
|
||||||
|
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
|
||||||
|
log.Fatal("Security setup failed:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== SETUP ROUTES =====
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler)
|
||||||
|
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||||
|
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||||
|
|
||||||
|
log.Println("Server starting on :8080")
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callback implementations
|
||||||
|
func authenticateUser(r *http.Request) (int, string, error) {
|
||||||
|
userIDStr := r.Header.Get("X-User-ID")
|
||||||
|
if userIDStr == "" {
|
||||||
|
return 0, "", fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
userID, err := strconv.Atoi(userIDStr)
|
||||||
|
return userID, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
// Your implementation here
|
||||||
|
return []security.ColumnSecurity{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||||
|
return security.RowSecurity{
|
||||||
|
Schema: schema,
|
||||||
|
Tablename: table,
|
||||||
|
UserID: userID,
|
||||||
|
Template: "user_id = " + strconv.Itoa(userID),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing Your Callbacks
|
||||||
|
|
||||||
|
### Unit Test Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
func TestAuthCallback(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/orders", nil)
|
||||||
|
req.Header.Set("X-User-ID", "123")
|
||||||
|
|
||||||
|
userID, roles, err := myAuthFunction(req)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 123, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestColumnSecurityCallback(t *testing.T) {
|
||||||
|
rules, err := myLoadColumnSecurity(123, "public", "employees")
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Greater(t, len(rules), 0)
|
||||||
|
assert.Equal(t, "mask", rules[0].Accesstype)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Patterns
|
||||||
|
|
||||||
|
### Pattern 1: Role-Based Security
|
||||||
|
|
||||||
|
```go
|
||||||
|
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
roles := getUserRoles(userID)
|
||||||
|
|
||||||
|
if contains(roles, "admin") {
|
||||||
|
// Admins see everything
|
||||||
|
return []security.ColumnSecurity{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-admins have restrictions
|
||||||
|
return []security.ColumnSecurity{
|
||||||
|
{Path: []string{"ssn"}, Accesstype: "mask"},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 2: Tenant Isolation
|
||||||
|
|
||||||
|
```go
|
||||||
|
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||||
|
tenantID := getUserTenant(userID)
|
||||||
|
|
||||||
|
return security.RowSecurity{
|
||||||
|
Template: fmt.Sprintf("tenant_id = %d", tenantID),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern 3: Caching Security Rules
|
||||||
|
|
||||||
|
```go
|
||||||
|
var securityCache = cache.New(5*time.Minute, 10*time.Minute)
|
||||||
|
|
||||||
|
func loadColumnSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
cacheKey := fmt.Sprintf("%d:%s.%s", userID, schema, table)
|
||||||
|
|
||||||
|
if cached, found := securityCache.Get(cacheKey); found {
|
||||||
|
return cached.([]security.ColumnSecurity), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := loadFromDatabase(userID, schema, table)
|
||||||
|
securityCache.Set(cacheKey, rules, cache.DefaultExpiration)
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Error: "AuthenticateCallback not set"
|
||||||
|
**Solution:** Configure all three callbacks before calling `SetupSecurityProvider`:
|
||||||
|
```go
|
||||||
|
security.GlobalSecurity.AuthenticateCallback = myAuthFunc
|
||||||
|
security.GlobalSecurity.LoadColumnSecurityCallback = myColSecFunc
|
||||||
|
security.GlobalSecurity.LoadRowSecurityCallback = myRowSecFunc
|
||||||
|
```
|
||||||
|
|
||||||
|
### Error: "Authentication failed"
|
||||||
|
**Solution:** Check your `AuthenticateCallback` implementation. Ensure it returns valid user ID or proper error.
|
||||||
|
|
||||||
|
### Security rules not applying
|
||||||
|
**Solution:**
|
||||||
|
1. Check callbacks are returning data
|
||||||
|
2. Enable debug logging
|
||||||
|
3. Verify database queries return results
|
||||||
|
4. Check user has security groups assigned
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. ✅ Implement the three callbacks for your system
|
||||||
|
2. ✅ Configure `GlobalSecurity` with your callbacks
|
||||||
|
3. ✅ Call `SetupSecurityProvider`
|
||||||
|
4. ✅ Test with different users and verify isolation
|
||||||
|
5. ✅ Review `callbacks_example.go` for more examples
|
||||||
|
|
||||||
|
For complete working examples, see:
|
||||||
|
- `pkg/security/callbacks_example.go` - 7 example implementations
|
||||||
|
- `examples/secure_server/main.go` - Full server example
|
||||||
|
- `pkg/security/README.md` - Comprehensive documentation
|
||||||
402
pkg/security/QUICK_REFERENCE.md
Normal file
402
pkg/security/QUICK_REFERENCE.md
Normal file
@ -0,0 +1,402 @@
|
|||||||
|
# Security Provider - Quick Reference
|
||||||
|
|
||||||
|
## 3-Step Setup
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Step 1: Implement callbacks
|
||||||
|
func myAuth(r *http.Request) (int, string, error) { /* ... */ }
|
||||||
|
func myColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) { /* ... */ }
|
||||||
|
func myRowSec(userID int, schema, table string) (security.RowSecurity, error) { /* ... */ }
|
||||||
|
|
||||||
|
// Step 2: Configure callbacks
|
||||||
|
security.GlobalSecurity.AuthenticateCallback = myAuth
|
||||||
|
security.GlobalSecurity.LoadColumnSecurityCallback = myColSec
|
||||||
|
security.GlobalSecurity.LoadRowSecurityCallback = myRowSec
|
||||||
|
|
||||||
|
// Step 3: Setup and apply middleware
|
||||||
|
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||||
|
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||||
|
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Callback Signatures
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 1. Authentication
|
||||||
|
func(r *http.Request) (userID int, roles string, err error)
|
||||||
|
|
||||||
|
// 2. Column Security
|
||||||
|
func(userID int, schema, tablename string) ([]ColumnSecurity, error)
|
||||||
|
|
||||||
|
// 3. Row Security
|
||||||
|
func(userID int, schema, tablename string) (RowSecurity, error)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## ColumnSecurity Structure
|
||||||
|
|
||||||
|
```go
|
||||||
|
security.ColumnSecurity{
|
||||||
|
Path: []string{"column_name"}, // ["ssn"] or ["address", "street"]
|
||||||
|
Accesstype: "mask", // "mask" or "hide"
|
||||||
|
MaskStart: 5, // Mask first N chars
|
||||||
|
MaskEnd: 0, // Mask last N chars
|
||||||
|
MaskChar: "*", // Masking character
|
||||||
|
MaskInvert: false, // true = mask middle
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common Examples
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Hide entire field
|
||||||
|
{Path: []string{"salary"}, Accesstype: "hide"}
|
||||||
|
|
||||||
|
// Mask SSN (show last 4)
|
||||||
|
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}
|
||||||
|
|
||||||
|
// Mask credit card (show last 4)
|
||||||
|
{Path: []string{"credit_card"}, Accesstype: "mask", MaskStart: 12}
|
||||||
|
|
||||||
|
// Mask email (j***@example.com)
|
||||||
|
{Path: []string{"email"}, Accesstype: "mask", MaskStart: 1, MaskEnd: 0}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## RowSecurity Structure
|
||||||
|
|
||||||
|
```go
|
||||||
|
security.RowSecurity{
|
||||||
|
Schema: "public",
|
||||||
|
Tablename: "orders",
|
||||||
|
UserID: 123,
|
||||||
|
Template: "user_id = {UserID}", // WHERE clause
|
||||||
|
HasBlock: false, // true = block all access
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Template Variables
|
||||||
|
|
||||||
|
- `{UserID}` - Current user ID
|
||||||
|
- `{PrimaryKeyName}` - Primary key column
|
||||||
|
- `{TableName}` - Table name
|
||||||
|
- `{SchemaName}` - Schema name
|
||||||
|
|
||||||
|
### Common Examples
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Users see only their records
|
||||||
|
Template: "user_id = {UserID}"
|
||||||
|
|
||||||
|
// Users see their records OR public ones
|
||||||
|
Template: "user_id = {UserID} OR is_public = true"
|
||||||
|
|
||||||
|
// Tenant isolation
|
||||||
|
Template: "tenant_id = 5 AND user_id = {UserID}"
|
||||||
|
|
||||||
|
// Complex with subquery
|
||||||
|
Template: "dept_id IN (SELECT dept_id FROM user_depts WHERE user_id = {UserID})"
|
||||||
|
|
||||||
|
// Block all access
|
||||||
|
HasBlock: true
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Example Implementations
|
||||||
|
|
||||||
|
### Simple Header Auth
|
||||||
|
|
||||||
|
```go
|
||||||
|
func authFromHeader(r *http.Request) (int, string, error) {
|
||||||
|
userIDStr := r.Header.Get("X-User-ID")
|
||||||
|
if userIDStr == "" {
|
||||||
|
return 0, "", fmt.Errorf("X-User-ID required")
|
||||||
|
}
|
||||||
|
userID, err := strconv.Atoi(userIDStr)
|
||||||
|
return userID, "", err
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### JWT Auth
|
||||||
|
|
||||||
|
```go
|
||||||
|
func authFromJWT(r *http.Request) (int, string, error) {
|
||||||
|
token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||||
|
claims, err := jwt.Parse(token, secret)
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", err
|
||||||
|
}
|
||||||
|
return claims.UserID, claims.Roles, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Static Column Security
|
||||||
|
|
||||||
|
```go
|
||||||
|
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
if table == "employees" {
|
||||||
|
return []security.ColumnSecurity{
|
||||||
|
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5},
|
||||||
|
{Path: []string{"salary"}, Accesstype: "hide"},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return []security.ColumnSecurity{}, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Column Security
|
||||||
|
|
||||||
|
```go
|
||||||
|
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
rows, err := db.Query(`
|
||||||
|
SELECT control, accesstype, jsonvalue
|
||||||
|
FROM core.secacces
|
||||||
|
WHERE rid_hub IN (...)
|
||||||
|
AND control ILIKE ?
|
||||||
|
`, fmt.Sprintf("%s.%s%%", schema, table))
|
||||||
|
// ... parse and return
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Static Row Security
|
||||||
|
|
||||||
|
```go
|
||||||
|
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||||
|
templates := map[string]string{
|
||||||
|
"orders": "user_id = {UserID}",
|
||||||
|
"documents": "user_id = {UserID} OR is_public = true",
|
||||||
|
}
|
||||||
|
return security.RowSecurity{
|
||||||
|
Template: templates[table],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Test auth callback
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set("X-User-ID", "123")
|
||||||
|
userID, roles, err := myAuth(req)
|
||||||
|
assert.Equal(t, 123, userID)
|
||||||
|
|
||||||
|
// Test column security callback
|
||||||
|
rules, err := myColSec(123, "public", "employees")
|
||||||
|
assert.Equal(t, "mask", rules[0].Accesstype)
|
||||||
|
|
||||||
|
// Test row security callback
|
||||||
|
rowSec, err := myRowSec(123, "public", "orders")
|
||||||
|
assert.Equal(t, "user_id = {UserID}", rowSec.Template)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Request Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
HTTP Request
|
||||||
|
↓
|
||||||
|
AuthMiddleware → calls AuthenticateCallback
|
||||||
|
↓ (adds userID to context)
|
||||||
|
SetSecurityMiddleware → adds GlobalSecurity to context
|
||||||
|
↓
|
||||||
|
Handler.Handle()
|
||||||
|
↓
|
||||||
|
BeforeRead Hook → calls LoadColumnSecurityCallback + LoadRowSecurityCallback
|
||||||
|
↓
|
||||||
|
BeforeScan Hook → applies row security (WHERE clause)
|
||||||
|
↓
|
||||||
|
Database Query
|
||||||
|
↓
|
||||||
|
AfterRead Hook → applies column security (masking)
|
||||||
|
↓
|
||||||
|
HTTP Response
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Patterns
|
||||||
|
|
||||||
|
### Role-Based Security
|
||||||
|
|
||||||
|
```go
|
||||||
|
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
if isAdmin(userID) {
|
||||||
|
return []security.ColumnSecurity{}, nil // No restrictions
|
||||||
|
}
|
||||||
|
return loadRestrictions(userID, schema, table), nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tenant Isolation
|
||||||
|
|
||||||
|
```go
|
||||||
|
func loadRowSec(userID int, schema, table string) (security.RowSecurity, error) {
|
||||||
|
tenantID := getUserTenant(userID)
|
||||||
|
return security.RowSecurity{
|
||||||
|
Template: fmt.Sprintf("tenant_id = %d", tenantID),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Caching
|
||||||
|
|
||||||
|
```go
|
||||||
|
var cache = make(map[string][]security.ColumnSecurity)
|
||||||
|
|
||||||
|
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
key := fmt.Sprintf("%d:%s.%s", userID, schema, table)
|
||||||
|
if cached, ok := cache[key]; ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
rules := loadFromDB(userID, schema, table)
|
||||||
|
cache[key] = rules
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Setup will fail if callbacks not configured
|
||||||
|
if err := security.SetupSecurityProvider(handler, &security.GlobalSecurity); err != nil {
|
||||||
|
log.Fatal("Security setup failed:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth middleware rejects if callback returns error
|
||||||
|
func myAuth(r *http.Request) (int, string, error) {
|
||||||
|
if invalid {
|
||||||
|
return 0, "", fmt.Errorf("invalid credentials") // Returns HTTP 401
|
||||||
|
}
|
||||||
|
return userID, roles, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security loading can fail gracefully
|
||||||
|
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
rules, err := db.Load(...)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to load security: %v", err)
|
||||||
|
return []security.ColumnSecurity{}, nil // No rules = no restrictions
|
||||||
|
}
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Debugging
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Enable debug logging
|
||||||
|
import "github.com/bitechdev/GoCore/pkg/cfg"
|
||||||
|
cfg.SetLogLevel("DEBUG")
|
||||||
|
|
||||||
|
// Log in callbacks
|
||||||
|
func myAuth(r *http.Request) (int, string, error) {
|
||||||
|
token := r.Header.Get("Authorization")
|
||||||
|
log.Printf("Auth: token=%s", token)
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if callbacks are called
|
||||||
|
func loadColSec(userID int, schema, table string) ([]security.ColumnSecurity, error) {
|
||||||
|
log.Printf("Loading column security: user=%d, schema=%s, table=%s", userID, schema, table)
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Complete Minimal Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Configure callbacks
|
||||||
|
security.GlobalSecurity.AuthenticateCallback = func(r *http.Request) (int, string, error) {
|
||||||
|
id, _ := strconv.Atoi(r.Header.Get("X-User-ID"))
|
||||||
|
return id, "", nil
|
||||||
|
}
|
||||||
|
security.GlobalSecurity.LoadColumnSecurityCallback = func(u int, s, t string) ([]security.ColumnSecurity, error) {
|
||||||
|
return []security.ColumnSecurity{}, nil
|
||||||
|
}
|
||||||
|
security.GlobalSecurity.LoadRowSecurityCallback = func(u int, s, t string) (security.RowSecurity, error) {
|
||||||
|
return security.RowSecurity{Template: fmt.Sprintf("user_id = %d", u)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup
|
||||||
|
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||||
|
|
||||||
|
// Middleware
|
||||||
|
router := mux.NewRouter()
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler)
|
||||||
|
router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||||
|
router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", router)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
| File | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `CALLBACKS_GUIDE.md` | **Start here** - Complete implementation guide |
|
||||||
|
| `callbacks_example.go` | 7 working examples to copy |
|
||||||
|
| `CALLBACKS_SUMMARY.md` | Architecture overview |
|
||||||
|
| `README.md` | Full documentation |
|
||||||
|
| `setup_example.go` | Integration examples |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Cheat Sheet
|
||||||
|
|
||||||
|
```go
|
||||||
|
// ===== REQUIRED SETUP =====
|
||||||
|
security.GlobalSecurity.AuthenticateCallback = myAuthFunc
|
||||||
|
security.GlobalSecurity.LoadColumnSecurityCallback = myColFunc
|
||||||
|
security.GlobalSecurity.LoadRowSecurityCallback = myRowFunc
|
||||||
|
security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||||
|
|
||||||
|
// ===== CALLBACK SIGNATURES =====
|
||||||
|
func(r *http.Request) (int, string, error) // Auth
|
||||||
|
func(int, string, string) ([]security.ColumnSecurity, error) // Column
|
||||||
|
func(int, string, string) (security.RowSecurity, error) // Row
|
||||||
|
|
||||||
|
// ===== QUICK EXAMPLES =====
|
||||||
|
// Header auth
|
||||||
|
func(r *http.Request) (int, string, error) {
|
||||||
|
id, _ := strconv.Atoi(r.Header.Get("X-User-ID"))
|
||||||
|
return id, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask SSN
|
||||||
|
{Path: []string{"ssn"}, Accesstype: "mask", MaskStart: 5}
|
||||||
|
|
||||||
|
// User isolation
|
||||||
|
{Template: "user_id = {UserID}"}
|
||||||
|
```
|
||||||
418
pkg/security/callbacks_example.go
Normal file
418
pkg/security/callbacks_example.go
Normal file
@ -0,0 +1,418 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
DBM "github.com/bitechdev/GoCore/pkg/models"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This file provides example implementations of the required security callbacks.
|
||||||
|
// Copy these functions and modify them to match your authentication and database schema.
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EXAMPLE 1: Simple Header-Based Authentication
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// ExampleAuthenticateFromHeader extracts user ID from X-User-ID header
|
||||||
|
func ExampleAuthenticateFromHeader(r *http.Request) (userID int, roles string, err error) {
|
||||||
|
userIDStr := r.Header.Get("X-User-ID")
|
||||||
|
if userIDStr == "" {
|
||||||
|
return 0, "", fmt.Errorf("X-User-ID header not provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err = strconv.Atoi(userIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", fmt.Errorf("invalid user ID format: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optionally extract roles
|
||||||
|
roles = r.Header.Get("X-User-Roles") // comma-separated: "admin,manager"
|
||||||
|
|
||||||
|
return userID, roles, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EXAMPLE 2: JWT Token Authentication
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// ExampleAuthenticateFromJWT parses a JWT token and extracts user info
|
||||||
|
// You'll need to import a JWT library like github.com/golang-jwt/jwt/v5
|
||||||
|
func ExampleAuthenticateFromJWT(r *http.Request) (userID int, roles string, err error) {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
return 0, "", fmt.Errorf("authorization header not provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract Bearer token
|
||||||
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
if tokenString == authHeader {
|
||||||
|
return 0, "", fmt.Errorf("invalid authorization header format")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Parse and validate JWT token
|
||||||
|
// Example using github.com/golang-jwt/jwt/v5:
|
||||||
|
//
|
||||||
|
// token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// return []byte(os.Getenv("JWT_SECRET")), nil
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// if err != nil || !token.Valid {
|
||||||
|
// return 0, "", fmt.Errorf("invalid token: %v", err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// claims := token.Claims.(jwt.MapClaims)
|
||||||
|
// userID = int(claims["user_id"].(float64))
|
||||||
|
// roles = claims["roles"].(string)
|
||||||
|
|
||||||
|
return 0, "", fmt.Errorf("JWT parsing not implemented - see example above")
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EXAMPLE 3: Session Cookie Authentication
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// ExampleAuthenticateFromSession validates a session cookie
|
||||||
|
func ExampleAuthenticateFromSession(r *http.Request) (userID int, roles string, err error) {
|
||||||
|
sessionCookie, err := r.Cookie("session_id")
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", fmt.Errorf("session cookie not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Validate session against your session store (Redis, database, etc.)
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// session, err := sessionStore.Get(sessionCookie.Value)
|
||||||
|
// if err != nil {
|
||||||
|
// return 0, "", fmt.Errorf("invalid session")
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// userID = session.UserID
|
||||||
|
// roles = session.Roles
|
||||||
|
|
||||||
|
_ = sessionCookie // Suppress unused warning until implemented
|
||||||
|
return 0, "", fmt.Errorf("session validation not implemented - see example above")
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EXAMPLE 4: Column Security - Database Implementation
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// ExampleLoadColumnSecurityFromDatabase loads column security rules from database
|
||||||
|
// This implementation assumes the following database schema:
|
||||||
|
//
|
||||||
|
// CREATE TABLE core.secacces (
|
||||||
|
// rid_secacces SERIAL PRIMARY KEY,
|
||||||
|
// rid_hub INTEGER,
|
||||||
|
// control TEXT, -- Format: "schema.table.column"
|
||||||
|
// accesstype TEXT, -- "mask" or "hide"
|
||||||
|
// jsonvalue JSONB -- Masking configuration
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// CREATE TABLE core.hub_link (
|
||||||
|
// rid_hub_parent INTEGER, -- Security group ID
|
||||||
|
// rid_hub_child INTEGER, -- User ID
|
||||||
|
// parent_hubtype TEXT -- 'secgroup'
|
||||||
|
// );
|
||||||
|
func ExampleLoadColumnSecurityFromDatabase(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) {
|
||||||
|
colSecList := make([]ColumnSecurity, 0)
|
||||||
|
|
||||||
|
getExtraFilters := func(pStr string) map[string]string {
|
||||||
|
mp := make(map[string]string, 0)
|
||||||
|
for i, val := range strings.Split(pStr, ",") {
|
||||||
|
if i <= 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
vals := strings.Split(val, ":")
|
||||||
|
if len(vals) > 1 {
|
||||||
|
mp[vals[0]] = vals[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := DBM.DBConn.Raw(fmt.Sprintf(`
|
||||||
|
SELECT a.rid_secacces, a.control, a.accesstype, a.jsonvalue
|
||||||
|
FROM core.secacces a
|
||||||
|
WHERE a.rid_hub IN (
|
||||||
|
SELECT l.rid_hub_parent
|
||||||
|
FROM core.hub_link l
|
||||||
|
WHERE l.parent_hubtype = 'secgroup'
|
||||||
|
AND l.rid_hub_child = ?
|
||||||
|
)
|
||||||
|
AND control ILIKE '%s.%s%%'
|
||||||
|
`, pSchema, pTablename), pUserID).Rows()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if rows != nil {
|
||||||
|
rows.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return colSecList, fmt.Errorf("failed to fetch column security from SQL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var rid int
|
||||||
|
var jsondata []byte
|
||||||
|
var control, accesstype string
|
||||||
|
|
||||||
|
err = rows.Scan(&rid, &control, &accesstype, &jsondata)
|
||||||
|
if err != nil {
|
||||||
|
return colSecList, fmt.Errorf("failed to scan column security: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(control, ",")
|
||||||
|
ids := strings.Split(parts[0], ".")
|
||||||
|
if len(ids) < 3 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonvalue := make(map[string]interface{})
|
||||||
|
if len(jsondata) > 1 {
|
||||||
|
err = json.Unmarshal(jsondata, &jsonvalue)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to parse json: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
colsec := ColumnSecurity{
|
||||||
|
Schema: pSchema,
|
||||||
|
Tablename: pTablename,
|
||||||
|
UserID: pUserID,
|
||||||
|
Path: ids[2:],
|
||||||
|
ExtraFilters: getExtraFilters(control),
|
||||||
|
Accesstype: accesstype,
|
||||||
|
Control: control,
|
||||||
|
ID: int(rid),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse masking configuration from JSON
|
||||||
|
if v, ok := jsonvalue["start"]; ok {
|
||||||
|
if value, ok := v.(float64); ok {
|
||||||
|
colsec.MaskStart = int(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := jsonvalue["end"]; ok {
|
||||||
|
if value, ok := v.(float64); ok {
|
||||||
|
colsec.MaskEnd = int(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := jsonvalue["invert"]; ok {
|
||||||
|
if value, ok := v.(bool); ok {
|
||||||
|
colsec.MaskInvert = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := jsonvalue["char"]; ok {
|
||||||
|
if value, ok := v.(string); ok {
|
||||||
|
colsec.MaskChar = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
colSecList = append(colSecList, colsec)
|
||||||
|
}
|
||||||
|
|
||||||
|
return colSecList, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EXAMPLE 5: Column Security - In-Memory/Static Configuration
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// ExampleLoadColumnSecurityFromConfig loads column security from static config
|
||||||
|
func ExampleLoadColumnSecurityFromConfig(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error) {
|
||||||
|
// Example: Define security rules in code or load from config file
|
||||||
|
securityRules := map[string][]ColumnSecurity{
|
||||||
|
"public.employees": {
|
||||||
|
{
|
||||||
|
Schema: "public",
|
||||||
|
Tablename: "employees",
|
||||||
|
Path: []string{"ssn"},
|
||||||
|
Accesstype: "mask",
|
||||||
|
MaskStart: 5,
|
||||||
|
MaskEnd: 0,
|
||||||
|
MaskChar: "*",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Schema: "public",
|
||||||
|
Tablename: "employees",
|
||||||
|
Path: []string{"salary"},
|
||||||
|
Accesstype: "hide",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"public.customers": {
|
||||||
|
{
|
||||||
|
Schema: "public",
|
||||||
|
Tablename: "customers",
|
||||||
|
Path: []string{"credit_card"},
|
||||||
|
Accesstype: "mask",
|
||||||
|
MaskStart: 12,
|
||||||
|
MaskEnd: 0,
|
||||||
|
MaskChar: "*",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s.%s", pSchema, pTablename)
|
||||||
|
rules, ok := securityRules[key]
|
||||||
|
if !ok {
|
||||||
|
return []ColumnSecurity{}, nil // No rules for this table
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by user ID if needed
|
||||||
|
// For this example, all rules apply to all users
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EXAMPLE 6: Row Security - Database Implementation
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// ExampleLoadRowSecurityFromDatabase loads row security rules from database
|
||||||
|
// This implementation assumes a PostgreSQL function:
|
||||||
|
//
|
||||||
|
// CREATE FUNCTION core.api_sec_rowtemplate(
|
||||||
|
// p_schema TEXT,
|
||||||
|
// p_table TEXT,
|
||||||
|
// p_userid INTEGER
|
||||||
|
// ) RETURNS TABLE (
|
||||||
|
// p_retval INTEGER,
|
||||||
|
// p_errmsg TEXT,
|
||||||
|
// p_template TEXT,
|
||||||
|
// p_block BOOLEAN
|
||||||
|
// );
|
||||||
|
func ExampleLoadRowSecurityFromDatabase(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
|
||||||
|
record := RowSecurity{
|
||||||
|
Schema: pSchema,
|
||||||
|
Tablename: pTablename,
|
||||||
|
UserID: pUserID,
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := DBM.DBConn.Raw(`
|
||||||
|
SELECT r.p_retval, r.p_errmsg, r.p_template, r.p_block
|
||||||
|
FROM core.api_sec_rowtemplate(?, ?, ?) r
|
||||||
|
`, pSchema, pTablename, pUserID).Rows()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if rows != nil {
|
||||||
|
rows.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return record, fmt.Errorf("failed to fetch row security from SQL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var retval int
|
||||||
|
var errmsg string
|
||||||
|
|
||||||
|
err = rows.Scan(&retval, &errmsg, &record.Template, &record.HasBlock)
|
||||||
|
if err != nil {
|
||||||
|
return record, fmt.Errorf("failed to scan row security: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retval != 0 {
|
||||||
|
return RowSecurity{}, fmt.Errorf("api_sec_rowtemplate error: %s", errmsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// EXAMPLE 7: Row Security - Static Configuration
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// ExampleLoadRowSecurityFromConfig loads row security from static config
|
||||||
|
func ExampleLoadRowSecurityFromConfig(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
|
||||||
|
// Define row security templates based on entity
|
||||||
|
templates := map[string]string{
|
||||||
|
"public.orders": "user_id = {UserID}", // Users see only their orders
|
||||||
|
"public.documents": "user_id = {UserID} OR is_public = true", // Users see their docs + public docs
|
||||||
|
"public.employees": "department_id IN (SELECT department_id FROM user_departments WHERE user_id = {UserID})", // Complex filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define blocked entities (no access at all)
|
||||||
|
blockedEntities := map[string][]int{
|
||||||
|
"public.admin_logs": {}, // All users blocked (empty list = block all)
|
||||||
|
"public.audit_logs": {1, 2, 3}, // Block users 1, 2, 3
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s.%s", pSchema, pTablename)
|
||||||
|
|
||||||
|
// Check if entity is blocked for this user
|
||||||
|
if blockedUsers, ok := blockedEntities[key]; ok {
|
||||||
|
if len(blockedUsers) == 0 {
|
||||||
|
// Block all users
|
||||||
|
return RowSecurity{
|
||||||
|
Schema: pSchema,
|
||||||
|
Tablename: pTablename,
|
||||||
|
UserID: pUserID,
|
||||||
|
HasBlock: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// Check if specific user is blocked
|
||||||
|
for _, blockedUserID := range blockedUsers {
|
||||||
|
if blockedUserID == pUserID {
|
||||||
|
return RowSecurity{
|
||||||
|
Schema: pSchema,
|
||||||
|
Tablename: pTablename,
|
||||||
|
UserID: pUserID,
|
||||||
|
HasBlock: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get template for this entity
|
||||||
|
template, ok := templates[key]
|
||||||
|
if !ok {
|
||||||
|
// No row security defined - allow all rows
|
||||||
|
return RowSecurity{
|
||||||
|
Schema: pSchema,
|
||||||
|
Tablename: pTablename,
|
||||||
|
UserID: pUserID,
|
||||||
|
Template: "",
|
||||||
|
HasBlock: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return RowSecurity{
|
||||||
|
Schema: pSchema,
|
||||||
|
Tablename: pTablename,
|
||||||
|
UserID: pUserID,
|
||||||
|
Template: template,
|
||||||
|
HasBlock: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// SETUP HELPER: Configure All Callbacks
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// SetupCallbacksExample shows how to configure all callbacks
|
||||||
|
func SetupCallbacksExample() {
|
||||||
|
// Option 1: Use database-backed security (production)
|
||||||
|
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT
|
||||||
|
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase
|
||||||
|
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
|
||||||
|
|
||||||
|
// Option 2: Use static configuration (development/testing)
|
||||||
|
// GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
|
||||||
|
// GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
|
||||||
|
// GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig
|
||||||
|
|
||||||
|
// Option 3: Mix and match
|
||||||
|
// GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromJWT
|
||||||
|
// GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
|
||||||
|
// GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
|
||||||
|
}
|
||||||
244
pkg/security/hooks.go
Normal file
244
pkg/security/hooks.go
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
|
func RegisterSecurityHooks(handler *restheadspec.Handler, securityList *SecurityList) {
|
||||||
|
|
||||||
|
// Hook 1: BeforeRead - Load security rules
|
||||||
|
handler.Hooks().Register(restheadspec.BeforeRead, func(hookCtx *restheadspec.HookContext) error {
|
||||||
|
return loadSecurityRules(hookCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 2: BeforeScan - Apply row-level security filters
|
||||||
|
handler.Hooks().Register(restheadspec.BeforeScan, func(hookCtx *restheadspec.HookContext) error {
|
||||||
|
return applyRowSecurity(hookCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 3: AfterRead - Apply column-level security (masking)
|
||||||
|
handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error {
|
||||||
|
return applyColumnSecurity(hookCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 4 (Optional): Audit logging
|
||||||
|
handler.Hooks().Register(restheadspec.AfterRead, func(hookCtx *restheadspec.HookContext) error {
|
||||||
|
return logDataAccess(hookCtx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadSecurityRules loads security configuration for the user and entity
|
||||||
|
func loadSecurityRules(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
||||||
|
// Extract user ID from context
|
||||||
|
userID, ok := GetUserID(hookCtx.Context)
|
||||||
|
if !ok {
|
||||||
|
logger.Warn("No user ID in context for security check")
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := hookCtx.Schema
|
||||||
|
tablename := hookCtx.Entity
|
||||||
|
|
||||||
|
logger.Debug("Loading security rules for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
||||||
|
|
||||||
|
// Load column security rules from database
|
||||||
|
err := securityList.LoadColumnSecurity(userID, schema, tablename, false)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to load column security: %v", err)
|
||||||
|
// Don't fail the request if no security rules exist
|
||||||
|
// return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load row security rules from database
|
||||||
|
_, err = securityList.LoadRowSecurity(userID, schema, tablename, false)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to load row security: %v", err)
|
||||||
|
// Don't fail the request if no security rules exist
|
||||||
|
// return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyRowSecurity applies row-level security filters to the query
|
||||||
|
func applyRowSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
||||||
|
userID, ok := GetUserID(hookCtx.Context)
|
||||||
|
if !ok {
|
||||||
|
return nil // No user context, skip
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := hookCtx.Schema
|
||||||
|
tablename := hookCtx.Entity
|
||||||
|
|
||||||
|
// Get row security template
|
||||||
|
rowSec, err := securityList.GetRowSecurityTemplate(userID, schema, tablename)
|
||||||
|
if err != nil {
|
||||||
|
// No row security defined, allow query to proceed
|
||||||
|
logger.Debug("No row security for %s.%s@%d: %v", schema, tablename, userID, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user has a blocking rule
|
||||||
|
if rowSec.HasBlock {
|
||||||
|
logger.Warn("User %d blocked from accessing %s.%s", userID, schema, tablename)
|
||||||
|
return fmt.Errorf("access denied to %s", tablename)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's a security template, apply it as a WHERE clause
|
||||||
|
if rowSec.Template != "" {
|
||||||
|
// Get primary key name from model
|
||||||
|
modelType := reflect.TypeOf(hookCtx.Model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find primary key field
|
||||||
|
pkName := "id" // default
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if tag := field.Tag.Get("bun"); tag != "" {
|
||||||
|
// Check for primary key tag
|
||||||
|
if contains(tag, "pk") || contains(tag, "primary_key") {
|
||||||
|
if sqlName := extractSQLName(tag); sqlName != "" {
|
||||||
|
pkName = sqlName
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the WHERE clause from template
|
||||||
|
whereClause := rowSec.GetTemplate(pkName, modelType)
|
||||||
|
|
||||||
|
logger.Info("Applying row security filter for user %d on %s.%s: %s",
|
||||||
|
userID, schema, tablename, whereClause)
|
||||||
|
|
||||||
|
// Apply the WHERE clause to the query
|
||||||
|
// The query is in hookCtx.Query
|
||||||
|
if selectQuery, ok := hookCtx.Query.(interface {
|
||||||
|
Where(string, ...interface{}) interface{}
|
||||||
|
}); ok {
|
||||||
|
hookCtx.Query = selectQuery.Where(whereClause)
|
||||||
|
} else {
|
||||||
|
logger.Error("Unable to apply WHERE clause - query doesn't support Where method")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyColumnSecurity applies column-level security (masking/hiding) to results
|
||||||
|
func applyColumnSecurity(hookCtx *restheadspec.HookContext, securityList *SecurityList) error {
|
||||||
|
userID, ok := GetUserID(hookCtx.Context)
|
||||||
|
if !ok {
|
||||||
|
return nil // No user context, skip
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := hookCtx.Schema
|
||||||
|
tablename := hookCtx.Entity
|
||||||
|
|
||||||
|
// Get result data
|
||||||
|
result := hookCtx.Result
|
||||||
|
if result == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Applying column security for user=%d, schema=%s, table=%s", userID, schema, tablename)
|
||||||
|
|
||||||
|
// Get model type
|
||||||
|
modelType := reflect.TypeOf(hookCtx.Model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply column security masking
|
||||||
|
resultValue := reflect.ValueOf(result)
|
||||||
|
if resultValue.Kind() == reflect.Ptr {
|
||||||
|
resultValue = resultValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
err, maskedResult := securityList.ApplyColumnSecurity(resultValue, modelType, userID, schema, tablename)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Column security error: %v", err)
|
||||||
|
// Don't fail the request, just log the issue
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the result with masked data
|
||||||
|
if maskedResult.IsValid() && maskedResult.CanInterface() {
|
||||||
|
hookCtx.Result = maskedResult.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// logDataAccess logs all data access for audit purposes
|
||||||
|
func logDataAccess(hookCtx *restheadspec.HookContext) error {
|
||||||
|
userID, _ := GetUserID(hookCtx.Context)
|
||||||
|
|
||||||
|
logger.Info("AUDIT: User %d accessed %s.%s with filters: %+v",
|
||||||
|
userID,
|
||||||
|
hookCtx.Schema,
|
||||||
|
hookCtx.Entity,
|
||||||
|
hookCtx.Options.Filters,
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Write to audit log table or external audit service
|
||||||
|
// auditLog := AuditLog{
|
||||||
|
// UserID: userID,
|
||||||
|
// Schema: hookCtx.Schema,
|
||||||
|
// Entity: hookCtx.Entity,
|
||||||
|
// Action: "READ",
|
||||||
|
// Timestamp: time.Now(),
|
||||||
|
// Filters: hookCtx.Options.Filters,
|
||||||
|
// }
|
||||||
|
// db.Create(&auditLog)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && s[:len(substr)] == substr ||
|
||||||
|
len(s) > len(substr) && s[len(s)-len(substr):] == substr
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSQLName(tag string) string {
|
||||||
|
// Simple parser for "column:name" or just "name"
|
||||||
|
// This is a simplified version
|
||||||
|
parts := splitTag(tag, ',')
|
||||||
|
for _, part := range parts {
|
||||||
|
if part != "" && !contains(part, ":") {
|
||||||
|
return part
|
||||||
|
}
|
||||||
|
if contains(part, "column:") {
|
||||||
|
return part[7:] // Skip "column:"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitTag(tag string, sep rune) []string {
|
||||||
|
var parts []string
|
||||||
|
var current string
|
||||||
|
for _, ch := range tag {
|
||||||
|
if ch == sep {
|
||||||
|
if current != "" {
|
||||||
|
parts = append(parts, current)
|
||||||
|
current = ""
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
current += string(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if current != "" {
|
||||||
|
parts = append(parts, current)
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
54
pkg/security/middleware.go
Normal file
54
pkg/security/middleware.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Context keys for user information
|
||||||
|
UserIDKey = "user_id"
|
||||||
|
UserRolesKey = "user_roles"
|
||||||
|
UserTokenKey = "user_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthMiddleware extracts user authentication from request and adds to context
|
||||||
|
// This should be applied before the ResolveSpec handler
|
||||||
|
// Uses GlobalSecurity.AuthenticateCallback if set, otherwise returns error
|
||||||
|
func AuthMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check if callback is set
|
||||||
|
if GlobalSecurity.AuthenticateCallback == nil {
|
||||||
|
http.Error(w, "AuthenticateCallback not set - you must provide an authentication callback", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the user-provided authentication callback
|
||||||
|
userID, roles, err := GlobalSecurity.AuthenticateCallback(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add user information to context
|
||||||
|
ctx := context.WithValue(r.Context(), UserIDKey, userID)
|
||||||
|
if roles != "" {
|
||||||
|
ctx = context.WithValue(ctx, UserRolesKey, roles)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue with authenticated context
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserID extracts the user ID from context
|
||||||
|
func GetUserID(ctx context.Context) (int, bool) {
|
||||||
|
userID, ok := ctx.Value(UserIDKey).(int)
|
||||||
|
return userID, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserRoles extracts user roles from context
|
||||||
|
func GetUserRoles(ctx context.Context) (string, bool) {
|
||||||
|
roles, ok := ctx.Value(UserRolesKey).(string)
|
||||||
|
return roles, ok
|
||||||
|
}
|
||||||
460
pkg/security/provider.go
Normal file
460
pkg/security/provider.go
Normal file
@ -0,0 +1,460 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ColumnSecurity struct {
|
||||||
|
Schema string
|
||||||
|
Tablename string
|
||||||
|
Path []string
|
||||||
|
ExtraFilters map[string]string
|
||||||
|
UserID int
|
||||||
|
Accesstype string `json:"accesstype"`
|
||||||
|
MaskStart int
|
||||||
|
MaskEnd int
|
||||||
|
MaskInvert bool
|
||||||
|
MaskChar string
|
||||||
|
Control string `json:"control"`
|
||||||
|
ID int `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RowSecurity struct {
|
||||||
|
Schema string
|
||||||
|
Tablename string
|
||||||
|
Template string
|
||||||
|
HasBlock bool
|
||||||
|
UserID int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RowSecurity) GetTemplate(pPrimaryKeyName string, pModelType reflect.Type) string {
|
||||||
|
str := m.Template
|
||||||
|
str = strings.ReplaceAll(str, "{PrimaryKeyName}", pPrimaryKeyName)
|
||||||
|
str = strings.ReplaceAll(str, "{TableName}", m.Tablename)
|
||||||
|
str = strings.ReplaceAll(str, "{SchemaName}", m.Schema)
|
||||||
|
str = strings.ReplaceAll(str, "{UserID}", fmt.Sprintf("%d", m.UserID))
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callback function types for customizing security behavior
|
||||||
|
type (
|
||||||
|
// AuthenticateFunc extracts user ID and roles from HTTP request
|
||||||
|
// Return userID, roles, error. If error is not nil, request will be rejected.
|
||||||
|
AuthenticateFunc func(r *http.Request) (userID int, roles string, err error)
|
||||||
|
|
||||||
|
// LoadColumnSecurityFunc loads column security rules for a user and entity
|
||||||
|
// Override this to customize how column security is loaded from your data source
|
||||||
|
LoadColumnSecurityFunc func(pUserID int, pSchema, pTablename string) ([]ColumnSecurity, error)
|
||||||
|
|
||||||
|
// LoadRowSecurityFunc loads row security rules for a user and entity
|
||||||
|
// Override this to customize how row security is loaded from your data source
|
||||||
|
LoadRowSecurityFunc func(pUserID int, pSchema, pTablename string) (RowSecurity, error)
|
||||||
|
)
|
||||||
|
|
||||||
|
type SecurityList struct {
|
||||||
|
ColumnSecurityMutex sync.RWMutex
|
||||||
|
ColumnSecurity map[string][]ColumnSecurity
|
||||||
|
RowSecurityMutex sync.RWMutex
|
||||||
|
RowSecurity map[string]RowSecurity
|
||||||
|
|
||||||
|
// Overridable callbacks
|
||||||
|
AuthenticateCallback AuthenticateFunc
|
||||||
|
LoadColumnSecurityCallback LoadColumnSecurityFunc
|
||||||
|
LoadRowSecurityCallback LoadRowSecurityFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
const SECURITY_CONTEXT_KEY = "SecurityList"
|
||||||
|
|
||||||
|
var GlobalSecurity SecurityList
|
||||||
|
|
||||||
|
// SetSecurityMiddleware adds security context to requests
|
||||||
|
func SetSecurityMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), SECURITY_CONTEXT_KEY, &GlobalSecurity)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func maskString(pString string, maskStart, maskEnd int, maskChar string, invert bool) string {
|
||||||
|
strLen := len(pString)
|
||||||
|
middleIndex := (strLen / 2)
|
||||||
|
newStr := ""
|
||||||
|
if maskStart == 0 && maskEnd == 0 {
|
||||||
|
maskStart = strLen
|
||||||
|
maskEnd = strLen
|
||||||
|
}
|
||||||
|
if maskEnd > strLen {
|
||||||
|
maskEnd = strLen
|
||||||
|
}
|
||||||
|
if maskStart > strLen {
|
||||||
|
maskStart = strLen
|
||||||
|
}
|
||||||
|
if maskChar == "" {
|
||||||
|
maskChar = "*"
|
||||||
|
}
|
||||||
|
for index, char := range pString {
|
||||||
|
if invert && index >= middleIndex-maskStart && index <= middleIndex {
|
||||||
|
newStr = newStr + maskChar
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if invert && index <= middleIndex+maskEnd && index >= middleIndex {
|
||||||
|
newStr = newStr + maskChar
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !invert && index <= maskStart {
|
||||||
|
newStr = newStr + maskChar
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !invert && index >= strLen-1-maskEnd {
|
||||||
|
newStr = newStr + maskChar
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newStr = newStr + string(char)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SecurityList) ColumSecurityApplyOnRecord(prevRecord reflect.Value, newRecord reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) ([]string, error) {
|
||||||
|
cols := make([]string, 0)
|
||||||
|
if m.ColumnSecurity == nil {
|
||||||
|
return cols, fmt.Errorf("security not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if prevRecord.Type() != newRecord.Type() {
|
||||||
|
logger.Error("prev:%s and new:%s record type mismatch", prevRecord.Type(), newRecord.Type())
|
||||||
|
return cols, fmt.Errorf("prev and new record type mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ColumnSecurityMutex.RLock()
|
||||||
|
defer m.ColumnSecurityMutex.RUnlock()
|
||||||
|
|
||||||
|
colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
|
||||||
|
if !ok || colsecList == nil {
|
||||||
|
return cols, fmt.Errorf("no security data")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, colsec := range colsecList {
|
||||||
|
if !(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lastRecords := interateStruct(prevRecord)
|
||||||
|
newRecords := interateStruct(newRecord)
|
||||||
|
var lastLoopField, lastLoopNewField reflect.Value
|
||||||
|
pathLen := len(colsec.Path)
|
||||||
|
for i, path := range colsec.Path {
|
||||||
|
var nameType, fieldName string
|
||||||
|
if len(newRecords) == 0 {
|
||||||
|
if lastLoopNewField.IsValid() && lastLoopField.IsValid() && i < pathLen-1 {
|
||||||
|
lastLoopNewField.Set(lastLoopField)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for ri := range newRecords {
|
||||||
|
if !newRecords[ri].IsValid() || !lastRecords[ri].IsValid() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var field, oldField reflect.Value
|
||||||
|
|
||||||
|
columnData := reflection.GetModelColumnDetail(newRecords[ri])
|
||||||
|
lastColumnData := reflection.GetModelColumnDetail(lastRecords[ri])
|
||||||
|
for i, cols := range columnData {
|
||||||
|
if cols.SQLName != "" && strings.EqualFold(cols.SQLName, path) {
|
||||||
|
nameType = "sql"
|
||||||
|
fieldName = cols.SQLName
|
||||||
|
field = cols.FieldValue
|
||||||
|
oldField = lastColumnData[i].FieldValue
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if cols.Name != "" && strings.EqualFold(cols.Name, path) {
|
||||||
|
nameType = "struct"
|
||||||
|
fieldName = cols.Name
|
||||||
|
field = cols.FieldValue
|
||||||
|
oldField = lastColumnData[i].FieldValue
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !field.IsValid() || !oldField.IsValid() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
lastLoopField = oldField
|
||||||
|
lastLoopNewField = field
|
||||||
|
|
||||||
|
if i == pathLen-1 {
|
||||||
|
if strings.Contains(strings.ToLower(fieldName), "json") {
|
||||||
|
prevSrc := oldField.Bytes()
|
||||||
|
newSrc := field.Bytes()
|
||||||
|
pathstr := strings.Join(colsec.Path, ".")
|
||||||
|
prevPathValue := gjson.GetBytes(prevSrc, pathstr)
|
||||||
|
newBytes, err := sjson.SetBytes(newSrc, pathstr, prevPathValue.Str)
|
||||||
|
if err == nil {
|
||||||
|
if field.CanSet() {
|
||||||
|
field.SetBytes(newBytes)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Value not settable: %v", field)
|
||||||
|
cols = append(cols, pathstr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if nameType == "sql" {
|
||||||
|
if strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide") {
|
||||||
|
field.Set(oldField)
|
||||||
|
cols = append(cols, strings.Join(colsec.Path, "."))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
lastRecords = interateStruct(field)
|
||||||
|
newRecords = interateStruct(oldField)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cols, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func interateStruct(val reflect.Value) []reflect.Value {
|
||||||
|
list := make([]reflect.Value, 0)
|
||||||
|
|
||||||
|
switch val.Kind() {
|
||||||
|
case reflect.Pointer, reflect.Interface:
|
||||||
|
elem := val.Elem()
|
||||||
|
if elem.IsValid() {
|
||||||
|
list = append(list, interateStruct(elem)...)
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
for i := 0; i < val.Len(); i++ {
|
||||||
|
elem := val.Index(i)
|
||||||
|
if !elem.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
list = append(list, interateStruct(elem)...)
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
case reflect.Struct:
|
||||||
|
list = append(list, val)
|
||||||
|
return list
|
||||||
|
default:
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setColSecValue(fieldsrc reflect.Value, colsec ColumnSecurity, fieldTypeName string) (int, reflect.Value) {
|
||||||
|
fieldval := fieldsrc
|
||||||
|
if fieldsrc.Kind() == reflect.Pointer || fieldsrc.Kind() == reflect.Interface {
|
||||||
|
fieldval = fieldval.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(strings.ToLower(fieldval.Kind().String()), "int") &&
|
||||||
|
(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||||
|
if fieldval.CanInt() && fieldval.CanSet() {
|
||||||
|
fieldval.SetInt(0)
|
||||||
|
}
|
||||||
|
} else if (strings.Contains(strings.ToLower(fieldval.Kind().String()), "time") ||
|
||||||
|
strings.Contains(strings.ToLower(fieldval.Kind().String()), "date")) &&
|
||||||
|
(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||||
|
fieldval.SetZero()
|
||||||
|
} else if strings.Contains(strings.ToLower(fieldval.Kind().String()), "string") {
|
||||||
|
strVal := fieldval.String()
|
||||||
|
if strings.EqualFold(colsec.Accesstype, "mask") {
|
||||||
|
fieldval.SetString(maskString(strVal, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert))
|
||||||
|
} else if strings.EqualFold(colsec.Accesstype, "hide") {
|
||||||
|
fieldval.SetString("")
|
||||||
|
}
|
||||||
|
} else if strings.Contains(fieldTypeName, "json") &&
|
||||||
|
(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||||
|
if len(colsec.Path) < 2 {
|
||||||
|
return 1, fieldval
|
||||||
|
}
|
||||||
|
pathstr := strings.Join(colsec.Path, ".")
|
||||||
|
src := fieldval.Bytes()
|
||||||
|
pathValue := gjson.GetBytes(src, pathstr)
|
||||||
|
strValue := pathValue.String()
|
||||||
|
if strings.EqualFold(colsec.Accesstype, "mask") {
|
||||||
|
strValue = maskString(strValue, colsec.MaskStart, colsec.MaskEnd, colsec.MaskChar, colsec.MaskInvert)
|
||||||
|
} else if strings.EqualFold(colsec.Accesstype, "hide") {
|
||||||
|
strValue = ""
|
||||||
|
}
|
||||||
|
newBytes, err := sjson.SetBytes(src, pathstr, strValue)
|
||||||
|
if err == nil {
|
||||||
|
fieldval.SetBytes(newBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, fieldsrc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SecurityList) ApplyColumnSecurity(records reflect.Value, modelType reflect.Type, pUserID int, pSchema, pTablename string) (error, reflect.Value) {
|
||||||
|
defer logger.CatchPanic("ApplyColumnSecurity")
|
||||||
|
|
||||||
|
if m.ColumnSecurity == nil {
|
||||||
|
return fmt.Errorf("security not initialized"), records
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ColumnSecurityMutex.RLock()
|
||||||
|
defer m.ColumnSecurityMutex.RUnlock()
|
||||||
|
|
||||||
|
colsecList, ok := m.ColumnSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
|
||||||
|
if !ok || colsecList == nil {
|
||||||
|
return fmt.Errorf("no security data"), records
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, colsec := range colsecList {
|
||||||
|
if !(strings.EqualFold(colsec.Accesstype, "mask") || strings.EqualFold(colsec.Accesstype, "hide")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if records.Kind() == reflect.Array || records.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < records.Len(); i++ {
|
||||||
|
record := records.Index(i)
|
||||||
|
if !record.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lastRecord := interateStruct(record)
|
||||||
|
pathLen := len(colsec.Path)
|
||||||
|
for i, path := range colsec.Path {
|
||||||
|
var field reflect.Value
|
||||||
|
var nameType, fieldName string
|
||||||
|
if len(lastRecord) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
columnData := reflection.GetModelColumnDetail(lastRecord[0])
|
||||||
|
for _, cols := range columnData {
|
||||||
|
if cols.SQLName != "" && strings.EqualFold(cols.SQLName, path) {
|
||||||
|
nameType = "sql"
|
||||||
|
fieldName = cols.SQLName
|
||||||
|
field = cols.FieldValue
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if cols.Name != "" && strings.EqualFold(cols.Name, path) {
|
||||||
|
nameType = "struct"
|
||||||
|
fieldName = cols.Name
|
||||||
|
field = cols.FieldValue
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == pathLen-1 {
|
||||||
|
if nameType == "sql" || nameType == "struct" {
|
||||||
|
setColSecValue(field, colsec, fieldName)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if field.IsValid() {
|
||||||
|
lastRecord = interateStruct(field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, records
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SecurityList) LoadColumnSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) error {
|
||||||
|
// Use the callback if provided
|
||||||
|
if m.LoadColumnSecurityCallback == nil {
|
||||||
|
return fmt.Errorf("LoadColumnSecurityCallback not set - you must provide a callback function")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ColumnSecurityMutex.Lock()
|
||||||
|
defer m.ColumnSecurityMutex.Unlock()
|
||||||
|
|
||||||
|
if m.ColumnSecurity == nil {
|
||||||
|
m.ColumnSecurity = make(map[string][]ColumnSecurity, 0)
|
||||||
|
}
|
||||||
|
secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
|
||||||
|
|
||||||
|
if pOverwrite || m.ColumnSecurity[secKey] == nil {
|
||||||
|
m.ColumnSecurity[secKey] = make([]ColumnSecurity, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the user-provided callback to load security rules
|
||||||
|
colSecList, err := m.LoadColumnSecurityCallback(pUserID, pSchema, pTablename)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("LoadColumnSecurityCallback failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ColumnSecurity[secKey] = colSecList
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SecurityList) ClearSecurity(pUserID int, pSchema, pTablename string) error {
|
||||||
|
var filtered []ColumnSecurity
|
||||||
|
m.ColumnSecurityMutex.Lock()
|
||||||
|
defer m.ColumnSecurityMutex.Unlock()
|
||||||
|
|
||||||
|
secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
|
||||||
|
list, ok := m.ColumnSecurity[secKey]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cs := range list {
|
||||||
|
if !(cs.Schema == pSchema && cs.Tablename == pTablename && cs.UserID == pUserID) {
|
||||||
|
filtered = append(filtered, cs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ColumnSecurity[secKey] = filtered
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SecurityList) LoadRowSecurity(pUserID int, pSchema, pTablename string, pOverwrite bool) (RowSecurity, error) {
|
||||||
|
// Use the callback if provided
|
||||||
|
if m.LoadRowSecurityCallback == nil {
|
||||||
|
return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback not set - you must provide a callback function")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RowSecurityMutex.Lock()
|
||||||
|
defer m.RowSecurityMutex.Unlock()
|
||||||
|
|
||||||
|
if m.RowSecurity == nil {
|
||||||
|
m.RowSecurity = make(map[string]RowSecurity, 0)
|
||||||
|
}
|
||||||
|
secKey := fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)
|
||||||
|
|
||||||
|
// Call the user-provided callback to load security rules
|
||||||
|
record, err := m.LoadRowSecurityCallback(pUserID, pSchema, pTablename)
|
||||||
|
if err != nil {
|
||||||
|
return RowSecurity{}, fmt.Errorf("LoadRowSecurityCallback failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RowSecurity[secKey] = record
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SecurityList) GetRowSecurityTemplate(pUserID int, pSchema, pTablename string) (RowSecurity, error) {
|
||||||
|
defer logger.CatchPanic("GetRowSecurityTemplate")
|
||||||
|
|
||||||
|
if m.RowSecurity == nil {
|
||||||
|
return RowSecurity{}, fmt.Errorf("security not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RowSecurityMutex.RLock()
|
||||||
|
defer m.RowSecurityMutex.RUnlock()
|
||||||
|
|
||||||
|
rowSec, ok := m.RowSecurity[fmt.Sprintf("%s.%s@%d", pSchema, pTablename, pUserID)]
|
||||||
|
if !ok {
|
||||||
|
return RowSecurity{}, fmt.Errorf("no security data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return rowSec, nil
|
||||||
|
}
|
||||||
155
pkg/security/setup_example.go
Normal file
155
pkg/security/setup_example.go
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupSecurityProvider initializes and configures the security provider
|
||||||
|
// This should be called when setting up your HTTP server
|
||||||
|
//
|
||||||
|
// IMPORTANT: You MUST configure the callbacks before calling this function:
|
||||||
|
// - GlobalSecurity.AuthenticateCallback
|
||||||
|
// - GlobalSecurity.LoadColumnSecurityCallback
|
||||||
|
// - GlobalSecurity.LoadRowSecurityCallback
|
||||||
|
//
|
||||||
|
// Example usage in your main.go or server setup:
|
||||||
|
//
|
||||||
|
// // Step 1: Configure callbacks (REQUIRED)
|
||||||
|
// security.GlobalSecurity.AuthenticateCallback = myAuthFunction
|
||||||
|
// security.GlobalSecurity.LoadColumnSecurityCallback = myLoadColumnSecurityFunction
|
||||||
|
// security.GlobalSecurity.LoadRowSecurityCallback = myLoadRowSecurityFunction
|
||||||
|
//
|
||||||
|
// // Step 2: Setup security provider
|
||||||
|
// handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
// security.SetupSecurityProvider(handler, &security.GlobalSecurity)
|
||||||
|
//
|
||||||
|
// // Step 3: Apply middleware
|
||||||
|
// router.Use(mux.MiddlewareFunc(security.AuthMiddleware))
|
||||||
|
// router.Use(mux.MiddlewareFunc(security.SetSecurityMiddleware))
|
||||||
|
//
|
||||||
|
func SetupSecurityProvider(handler *restheadspec.Handler, securityList *SecurityList) error {
|
||||||
|
// Validate that required callbacks are configured
|
||||||
|
if securityList.AuthenticateCallback == nil {
|
||||||
|
return fmt.Errorf("AuthenticateCallback must be set before calling SetupSecurityProvider")
|
||||||
|
}
|
||||||
|
if securityList.LoadColumnSecurityCallback == nil {
|
||||||
|
return fmt.Errorf("LoadColumnSecurityCallback must be set before calling SetupSecurityProvider")
|
||||||
|
}
|
||||||
|
if securityList.LoadRowSecurityCallback == nil {
|
||||||
|
return fmt.Errorf("LoadRowSecurityCallback must be set before calling SetupSecurityProvider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize security maps if needed
|
||||||
|
if securityList.ColumnSecurity == nil {
|
||||||
|
securityList.ColumnSecurity = make(map[string][]ColumnSecurity)
|
||||||
|
}
|
||||||
|
if securityList.RowSecurity == nil {
|
||||||
|
securityList.RowSecurity = make(map[string]RowSecurity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register all security hooks
|
||||||
|
RegisterSecurityHooks(handler, securityList)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chain creates a middleware chain
|
||||||
|
func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
|
||||||
|
return func(final http.Handler) http.Handler {
|
||||||
|
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||||
|
final = middlewares[i](final)
|
||||||
|
}
|
||||||
|
return final
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteExample shows a full integration example with Gorilla Mux
|
||||||
|
func CompleteExample(db *gorm.DB) (http.Handler, error) {
|
||||||
|
// Step 1: Create the ResolveSpec handler
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Step 2: Register your models
|
||||||
|
// handler.RegisterModel("public", "users", User{})
|
||||||
|
// handler.RegisterModel("public", "orders", Order{})
|
||||||
|
|
||||||
|
// Step 3: Configure security callbacks (REQUIRED!)
|
||||||
|
// See callbacks_example.go for example implementations
|
||||||
|
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
|
||||||
|
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromDatabase
|
||||||
|
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromDatabase
|
||||||
|
|
||||||
|
// Step 4: Setup security provider
|
||||||
|
if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to setup security: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Create Mux router and setup routes
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// The routes are set up by restheadspec, which handles the conversion
|
||||||
|
// from http.Request to the internal request format
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler)
|
||||||
|
|
||||||
|
// Step 6: Apply middleware to the entire router
|
||||||
|
secureRouter := Chain(
|
||||||
|
AuthMiddleware, // Extract user from token
|
||||||
|
SetSecurityMiddleware, // Add security context
|
||||||
|
)(router)
|
||||||
|
|
||||||
|
return secureRouter, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleWithMux shows a simpler integration with Mux
|
||||||
|
func ExampleWithMux(db *gorm.DB) (*mux.Router, error) {
|
||||||
|
handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// IMPORTANT: Configure callbacks BEFORE SetupSecurityProvider
|
||||||
|
GlobalSecurity.AuthenticateCallback = ExampleAuthenticateFromHeader
|
||||||
|
GlobalSecurity.LoadColumnSecurityCallback = ExampleLoadColumnSecurityFromConfig
|
||||||
|
GlobalSecurity.LoadRowSecurityCallback = ExampleLoadRowSecurityFromConfig
|
||||||
|
|
||||||
|
if err := SetupSecurityProvider(handler, &GlobalSecurity); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to setup security: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
|
||||||
|
// Setup API routes
|
||||||
|
restheadspec.SetupMuxRoutes(router, handler)
|
||||||
|
|
||||||
|
// Apply middleware to router
|
||||||
|
router.Use(mux.MiddlewareFunc(AuthMiddleware))
|
||||||
|
router.Use(mux.MiddlewareFunc(SetSecurityMiddleware))
|
||||||
|
|
||||||
|
return router, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example with Gin
|
||||||
|
// import "github.com/gin-gonic/gin"
|
||||||
|
//
|
||||||
|
// func ExampleWithGin(db *gorm.DB) *gin.Engine {
|
||||||
|
// handler := restheadspec.NewHandlerWithGORM(db)
|
||||||
|
// SetupSecurityProvider(handler, &GlobalSecurity)
|
||||||
|
//
|
||||||
|
// router := gin.Default()
|
||||||
|
//
|
||||||
|
// // Convert middleware to Gin middleware
|
||||||
|
// router.Use(func(c *gin.Context) {
|
||||||
|
// AuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// c.Request = r
|
||||||
|
// c.Next()
|
||||||
|
// })).ServeHTTP(c.Writer, c.Request)
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// // Setup routes
|
||||||
|
// api := router.Group("/api")
|
||||||
|
// api.Any("/:schema/:entity", gin.WrapH(http.HandlerFunc(handler.Handle)))
|
||||||
|
// api.Any("/:schema/:entity/:id", gin.WrapH(http.HandlerFunc(handler.Handle)))
|
||||||
|
//
|
||||||
|
// return router
|
||||||
|
// }
|
||||||
@ -83,7 +83,7 @@ func TestSetup(m *testing.M) int {
|
|||||||
router := setupTestRouter(testDB)
|
router := setupTestRouter(testDB)
|
||||||
testServer = httptest.NewServer(router)
|
testServer = httptest.NewServer(router)
|
||||||
|
|
||||||
fmt.Printf("ResolveSpec test server starting on %s\n", testServer.URL)
|
logger.Info("ResolveSpec test server starting on %s", testServer.URL)
|
||||||
testServerURL = testServer.URL
|
testServerURL = testServer.URL
|
||||||
|
|
||||||
defer testServer.Close()
|
defer testServer.Close()
|
||||||
|
|||||||
2
todo.md
2
todo.md
@ -120,6 +120,8 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
- When making changes, we can have the trigger fire with the correct user.
|
- When making changes, we can have the trigger fire with the correct user.
|
||||||
- Maybe wrap the handleRead,Update,Create,Delete handlers in a transaction with context that can abort when the request is cancelled or a configurable timeout is reached.
|
- Maybe wrap the handleRead,Update,Create,Delete handlers in a transaction with context that can abort when the request is cancelled or a configurable timeout is reached.
|
||||||
|
|
||||||
|
### 7.
|
||||||
|
|
||||||
## Additional Considerations
|
## Additional Considerations
|
||||||
|
|
||||||
### Documentation
|
### Documentation
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user