mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 17:36:23 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f10bb0827e | ||
|
|
22a4ab345a | ||
|
|
e289c2ed8f | ||
|
|
0d50bcfee6 | ||
| 4df626ea71 | |||
|
|
7dd630dec2 | ||
|
|
613bf22cbd | ||
| d1ae4fe64e | |||
| 254102bfac | |||
| 6c27419dbc | |||
| 377336caf4 | |||
| 79720d5421 | |||
| e7ab0a20d6 | |||
| e4087104a9 | |||
|
|
17e580a9d3 | ||
|
|
337a007d57 |
27
LICENSE
27
LICENSE
@@ -1,3 +1,18 @@
|
|||||||
|
Project Notice
|
||||||
|
|
||||||
|
This project was independently developed.
|
||||||
|
|
||||||
|
The contents of this repository were prepared and published outside any time
|
||||||
|
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
|
||||||
|
or rely upon any proprietary or confidential information, trade secrets,
|
||||||
|
protected designs, or other intellectual property of Bitech Systems CC.
|
||||||
|
|
||||||
|
No portion of this repository reproduces any Bitech Systems CC-specific
|
||||||
|
implementation, design asset, confidential workflow, or non-public technical material.
|
||||||
|
|
||||||
|
This notice is provided for clarification only and does not modify the terms of
|
||||||
|
the Apache License, Version 2.0.
|
||||||
|
|
||||||
Apache License
|
Apache License
|
||||||
Version 2.0, January 2004
|
Version 2.0, January 2004
|
||||||
http://www.apache.org/licenses/
|
http://www.apache.org/licenses/
|
||||||
@@ -32,15 +47,15 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||||
|
|
||||||
@@ -56,7 +71,7 @@ END OF TERMS AND CONDITIONS
|
|||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
||||||
|
|
||||||
Copyright 2025 wdevs
|
Copyright 2025 wdevs
|
||||||
|
|
||||||
|
|||||||
@@ -394,12 +394,12 @@ func (p *PgSQLSelectQuery) buildSQL() string {
|
|||||||
|
|
||||||
// LIMIT clause
|
// LIMIT clause
|
||||||
if p.limit > 0 {
|
if p.limit > 0 {
|
||||||
sb.WriteString(fmt.Sprintf(" LIMIT %d", p.limit))
|
fmt.Fprintf(&sb, " LIMIT %d", p.limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OFFSET clause
|
// OFFSET clause
|
||||||
if p.offset > 0 {
|
if p.offset > 0 {
|
||||||
sb.WriteString(fmt.Sprintf(" OFFSET %d", p.offset))
|
fmt.Fprintf(&sb, " OFFSET %d", p.offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String()
|
return sb.String()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -167,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||||
|
// Keys are stored lowercase for case-insensitive matching
|
||||||
allowedPrefixes := make(map[string]bool)
|
allowedPrefixes := make(map[string]bool)
|
||||||
if tableName != "" {
|
if tableName != "" {
|
||||||
allowedPrefixes[tableName] = true
|
allowedPrefixes[strings.ToLower(tableName)] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add preload relation names as allowed prefixes
|
// Add preload relation names as allowed prefixes
|
||||||
if len(options) > 0 && options[0] != nil {
|
if len(options) > 0 && options[0] != nil {
|
||||||
for pi := range options[0].Preload {
|
for pi := range options[0].Preload {
|
||||||
if options[0].Preload[pi].Relation != "" {
|
if options[0].Preload[pi].Relation != "" {
|
||||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
|
||||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
// Add join aliases as allowed prefixes
|
// Add join aliases as allowed prefixes
|
||||||
for _, alias := range options[0].JoinAliases {
|
for _, alias := range options[0].JoinAliases {
|
||||||
if alias != "" {
|
if alias != "" {
|
||||||
allowedPrefixes[alias] = true
|
allowedPrefixes[strings.ToLower(alias)] = true
|
||||||
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -216,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||||
|
|
||||||
if currentPrefix != "" && columnName != "" {
|
if currentPrefix != "" && columnName != "" {
|
||||||
// Check if the prefix is allowed (main table or preload relation)
|
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
|
||||||
if !allowedPrefixes[currentPrefix] {
|
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
|
||||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||||
// Replace the incorrect prefix with the correct main table name
|
// Replace the incorrect prefix with the correct main table name
|
||||||
@@ -925,3 +927,36 @@ func extractLeftSideOfComparison(cond string) string {
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FilterValueToSlice converts a filter value to []interface{} for use with IN operators.
|
||||||
|
// JSON-decoded arrays arrive as []interface{}, but typed slices (e.g. []string) also work.
|
||||||
|
// Returns a single-element slice if the value is not a slice type.
|
||||||
|
func FilterValueToSlice(v interface{}) []interface{} {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if rv.Kind() == reflect.Slice {
|
||||||
|
result := make([]interface{}, rv.Len())
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
result[i] = rv.Index(i).Interface()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
return []interface{}{v}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildInCondition builds a parameterized IN condition from a filter value.
|
||||||
|
// Returns the condition string (e.g. "col IN (?,?)") and the individual values as args.
|
||||||
|
// Returns ("", nil) if the value is empty or not a slice.
|
||||||
|
func BuildInCondition(column string, v interface{}) (query string, args []interface{}) {
|
||||||
|
values := FilterValueToSlice(v)
|
||||||
|
if len(values) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
placeholders := make([]string, len(values))
|
||||||
|
for i := range values {
|
||||||
|
placeholders[i] = "?"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), values
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,14 +2,38 @@ package funcspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
||||||
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
||||||
// We provide audit logging for data access tracking
|
// We provide auth enforcement and audit logging for data access tracking
|
||||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeQueryList - Auth check before list query execution
|
||||||
|
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||||
|
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = "authentication required"
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 0: BeforeQuery - Auth check before single query execution
|
||||||
|
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
|
||||||
|
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = "authentication required"
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
||||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||||
secCtx := newFuncSpecSecurityContext(hookCtx)
|
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ import (
|
|||||||
|
|
||||||
// ModelRules defines the permissions and security settings for a model
|
// ModelRules defines the permissions and security settings for a model
|
||||||
type ModelRules struct {
|
type ModelRules struct {
|
||||||
|
CanPublicRead bool // Whether the model can be read (GET operations)
|
||||||
|
CanPublicUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||||
|
CanPublicCreate bool // Whether the model can be created (POST operations)
|
||||||
|
CanPublicDelete bool // Whether the model can be deleted (DELETE operations)
|
||||||
CanRead bool // Whether the model can be read (GET operations)
|
CanRead bool // Whether the model can be read (GET operations)
|
||||||
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||||
CanCreate bool // Whether the model can be created (POST operations)
|
CanCreate bool // Whether the model can be created (POST operations)
|
||||||
@@ -22,6 +26,10 @@ func DefaultModelRules() ModelRules {
|
|||||||
CanUpdate: true,
|
CanUpdate: true,
|
||||||
CanCreate: true,
|
CanCreate: true,
|
||||||
CanDelete: true,
|
CanDelete: true,
|
||||||
|
CanPublicRead: false,
|
||||||
|
CanPublicUpdate: false,
|
||||||
|
CanPublicCreate: false,
|
||||||
|
CanPublicDelete: false,
|
||||||
SecurityDisabled: false,
|
SecurityDisabled: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ MQTTSpec is an MQTT-based database query framework that enables real-time databa
|
|||||||
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
||||||
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
||||||
- **Database Agnostic**: GORM and Bun ORM support
|
- **Database Agnostic**: GORM and Bun ORM support
|
||||||
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
|
- **Lifecycle Hooks**: 13 hooks for authentication, authorization, validation, and auditing
|
||||||
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
||||||
- **Thread-safe**: Proper concurrency handling throughout
|
- **Thread-safe**: Proper concurrency handling throughout
|
||||||
|
|
||||||
@@ -326,10 +326,11 @@ When any client creates/updates/deletes a user matching the subscription filters
|
|||||||
|
|
||||||
## Lifecycle Hooks
|
## Lifecycle Hooks
|
||||||
|
|
||||||
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
MQTTSpec provides 13 lifecycle hooks for implementing cross-cutting concerns:
|
||||||
|
|
||||||
### Hook Types
|
### Hook Types
|
||||||
|
|
||||||
|
- `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||||
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
||||||
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
||||||
- `BeforeRead` / `AfterRead` - Read operations
|
- `BeforeRead` / `AfterRead` - Read operations
|
||||||
@@ -339,6 +340,20 @@ MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
|||||||
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
||||||
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
||||||
|
|
||||||
|
### Security Hooks (Recommended)
|
||||||
|
|
||||||
|
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
|
||||||
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
mqttspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||||
|
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||||
|
```
|
||||||
|
|
||||||
### Authentication Example (JWT)
|
### Authentication Example (JWT)
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@@ -657,7 +672,7 @@ handler, err := mqttspec.NewHandlerWithGORM(db,
|
|||||||
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
||||||
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
||||||
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
||||||
| **Hooks** | Same 12 hooks | Same 12 hooks |
|
| **Hooks** | Same 13 hooks | Same 13 hooks |
|
||||||
| **CRUD Operations** | Identical | Identical |
|
| **CRUD Operations** | Identical | Identical |
|
||||||
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
||||||
|
|
||||||
|
|||||||
@@ -284,6 +284,15 @@ func (h *Handler) handleRequest(client *Client, msg *Message) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||||
|
hookCtx.Operation = string(msg.Operation)
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
if hookCtx.Abort {
|
||||||
|
h.sendError(client.ID, msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Route to operation handler
|
// Route to operation handler
|
||||||
switch msg.Operation {
|
switch msg.Operation {
|
||||||
case OperationRead:
|
case OperationRead:
|
||||||
|
|||||||
@@ -20,8 +20,11 @@ type (
|
|||||||
HookRegistry = websocketspec.HookRegistry
|
HookRegistry = websocketspec.HookRegistry
|
||||||
)
|
)
|
||||||
|
|
||||||
// Hook type constants - all 12 lifecycle hooks
|
// Hook type constants - all lifecycle hooks
|
||||||
const (
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch
|
||||||
|
BeforeHandle = websocketspec.BeforeHandle
|
||||||
|
|
||||||
// CRUD operation hooks
|
// CRUD operation hooks
|
||||||
BeforeRead = websocketspec.BeforeRead
|
BeforeRead = websocketspec.BeforeRead
|
||||||
AfterRead = websocketspec.AfterRead
|
AfterRead = websocketspec.AfterRead
|
||||||
|
|||||||
108
pkg/mqttspec/security_hooks.go
Normal file
108
pkg/mqttspec/security_hooks.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package mqttspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks registers all security-related hooks with the MQTT handler
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 1: BeforeRead - Load security rules
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LoadSecurityRules(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 3 (Optional): Audit logging
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LogDataAccess(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelUpdateAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelDeleteAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("Security hooks registered for mqttspec handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// securityContext adapts mqttspec.HookContext to security.SecurityContext interface
|
||||||
|
type securityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &securityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetContext() context.Context {
|
||||||
|
return s.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetUserID() (int, bool) {
|
||||||
|
return security.GetUserID(s.ctx.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetSchema() string {
|
||||||
|
return s.ctx.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetEntity() string {
|
||||||
|
return s.ctx.Entity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetModel() interface{} {
|
||||||
|
return s.ctx.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuery retrieves a stored query from hook metadata
|
||||||
|
func (s *securityContext) GetQuery() interface{} {
|
||||||
|
if s.ctx.Metadata == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.ctx.Metadata["query"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuery stores the query in hook metadata
|
||||||
|
func (s *securityContext) SetQuery(query interface{}) {
|
||||||
|
if s.ctx.Metadata == nil {
|
||||||
|
s.ctx.Metadata = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
s.ctx.Metadata["query"] = query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetResult() interface{} {
|
||||||
|
return s.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetResult(result interface{}) {
|
||||||
|
s.ctx.Result = result
|
||||||
|
}
|
||||||
@@ -644,6 +644,7 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
|||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
|
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Schema.Table format
|
// Schema.Table format
|
||||||
handler.registry.RegisterModel("core.users", &User{})
|
handler.registry.RegisterModel("core.users", &User{})
|
||||||
handler.registry.RegisterModel("core.posts", &Post{})
|
handler.registry.RegisterModel("core.posts", &Post{})
|
||||||
@@ -654,11 +655,13 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
|||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ func GetCursorFilter(
|
|||||||
modelColumns []string,
|
modelColumns []string,
|
||||||
options common.RequestOptions,
|
options common.RequestOptions,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
// Remove schema prefix if present
|
// Separate schema prefix from bare table name
|
||||||
|
fullTableName := tableName
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||||
}
|
}
|
||||||
@@ -115,7 +116,7 @@ func GetCursorFilter(
|
|||||||
WHERE cursor_select.%s = %s
|
WHERE cursor_select.%s = %s
|
||||||
AND (%s)
|
AND (%s)
|
||||||
)`,
|
)`,
|
||||||
tableName,
|
fullTableName,
|
||||||
pkName,
|
pkName,
|
||||||
cursorID,
|
cursorID,
|
||||||
orSQL,
|
orSQL,
|
||||||
|
|||||||
@@ -175,9 +175,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
|||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should handle schema prefix properly
|
// Should include full schema-qualified name in FROM clause
|
||||||
if !strings.Contains(filter, "users") {
|
if !strings.Contains(filter, "public.users") {
|
||||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ func TestBuildFilterCondition(t *testing.T) {
|
|||||||
Operator: "in",
|
Operator: "in",
|
||||||
Value: []string{"active", "pending"},
|
Value: []string{"active", "pending"},
|
||||||
},
|
},
|
||||||
expectedCondition: "status IN (?)",
|
expectedCondition: "status IN (?,?)",
|
||||||
expectedArgsCount: 1,
|
expectedArgsCount: 2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "LIKE operator",
|
name: "LIKE operator",
|
||||||
|
|||||||
@@ -138,6 +138,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
validator := common.NewColumnValidator(model)
|
validator := common.NewColumnValidator(model)
|
||||||
req.Options = validator.FilterRequestOptions(req.Options)
|
req.Options = validator.FilterRequestOptions(req.Options)
|
||||||
|
|
||||||
|
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||||
|
beforeCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Writer: w,
|
||||||
|
Request: r,
|
||||||
|
Operation: req.Operation,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||||
|
code := http.StatusUnauthorized
|
||||||
|
if beforeCtx.AbortCode != 0 {
|
||||||
|
code = beforeCtx.AbortCode
|
||||||
|
}
|
||||||
|
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch req.Operation {
|
switch req.Operation {
|
||||||
case "read":
|
case "read":
|
||||||
h.handleRead(ctx, w, id, req.Options)
|
h.handleRead(ctx, w, id, req.Options)
|
||||||
@@ -309,6 +329,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Extract model columns for validation
|
// Extract model columns for validation
|
||||||
modelColumns := reflection.GetModelColumns(model)
|
modelColumns := reflection.GetModelColumns(model)
|
||||||
|
|
||||||
|
// Default sort to primary key when none provided
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
// Get cursor filter SQL
|
// Get cursor filter SQL
|
||||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1236,6 +1261,24 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
logger.Info("Deleting records from %s.%s", schema, entity)
|
logger.Info("Deleting records from %s.%s", schema, entity)
|
||||||
|
|
||||||
|
// Execute BeforeDelete hooks (covers model-rule checks before any deletion)
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
ID: id,
|
||||||
|
Data: data,
|
||||||
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeDelete hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusForbidden, "delete_forbidden", "Delete operation not allowed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Handle batch delete from request data
|
// Handle batch delete from request data
|
||||||
if data != nil {
|
if data != nil {
|
||||||
switch v := data.(type) {
|
switch v := data.(type) {
|
||||||
@@ -1483,22 +1526,22 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
|||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
switch filter.Operator {
|
switch filter.Operator {
|
||||||
case "eq":
|
case "eq", "=":
|
||||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "neq":
|
case "neq", "!=", "<>":
|
||||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "gt":
|
case "gt", ">":
|
||||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "gte":
|
case "gte", ">=":
|
||||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "lt":
|
case "lt", "<":
|
||||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "lte":
|
case "lte", "<=":
|
||||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "like":
|
case "like":
|
||||||
@@ -1508,8 +1551,10 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
|||||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "in":
|
case "in":
|
||||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||||
args = []interface{}{filter.Value}
|
if condition == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@@ -1525,22 +1570,22 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
|||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
|
||||||
switch filter.Operator {
|
switch filter.Operator {
|
||||||
case "eq":
|
case "eq", "=":
|
||||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "neq":
|
case "neq", "!=", "<>":
|
||||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "gt":
|
case "gt", ">":
|
||||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "gte":
|
case "gte", ">=":
|
||||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "lt":
|
case "lt", "<":
|
||||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "lte":
|
case "lte", "<=":
|
||||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "like":
|
case "like":
|
||||||
@@ -1550,8 +1595,10 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
|||||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||||
args = []interface{}{filter.Value}
|
args = []interface{}{filter.Value}
|
||||||
case "in":
|
case "in":
|
||||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||||
args = []interface{}{filter.Value}
|
if condition == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import (
|
|||||||
type HookType string
|
type HookType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||||
|
// Use this for auth checks that need model rules and user context simultaneously.
|
||||||
|
BeforeHandle HookType = "before_handle"
|
||||||
|
|
||||||
// Read operation hooks
|
// Read operation hooks
|
||||||
BeforeRead HookType = "before_read"
|
BeforeRead HookType = "before_read"
|
||||||
AfterRead HookType = "after_read"
|
AfterRead HookType = "after_read"
|
||||||
@@ -43,6 +47,9 @@ type HookContext struct {
|
|||||||
Writer common.ResponseWriter
|
Writer common.ResponseWriter
|
||||||
Request common.Request
|
Request common.Request
|
||||||
|
|
||||||
|
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||||
|
Operation string
|
||||||
|
|
||||||
// Operation-specific fields
|
// Operation-specific fields
|
||||||
ID string
|
ID string
|
||||||
Data interface{} // For create/update operations
|
Data interface{} // For create/update operations
|
||||||
|
|||||||
@@ -70,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
|||||||
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||||
|
|
||||||
// Create handler functions for this specific entity
|
// Create handler functions for this specific entity
|
||||||
postEntityHandler := createMuxHandler(handler, schema, entity, "")
|
var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||||
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||||
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
|
var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
||||||
|
|
||||||
// Apply authentication middleware if provided
|
// Apply authentication middleware if provided
|
||||||
if authMiddleware != nil {
|
if authMiddleware != nil {
|
||||||
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
|
postEntityHandler = authMiddleware(postEntityHandler)
|
||||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
|
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
|
||||||
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
|
getEntityHandler = authMiddleware(getEntityHandler)
|
||||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,7 +225,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
|
|||||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
// Create an http.Handler that calls the bunrouter handler
|
// Create an http.Handler that calls the bunrouter handler
|
||||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_ = handler(w, req)
|
// Replace the embedded *http.Request with the middleware-enriched one
|
||||||
|
// so that auth context (user ID, etc.) is visible to the handler.
|
||||||
|
enrichedReq := req
|
||||||
|
enrichedReq.Request = r
|
||||||
|
_ = handler(w, enrichedReq)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Wrap with auth middleware and execute
|
// Wrap with auth middleware and execute
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package resolvespec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -10,6 +11,17 @@ import (
|
|||||||
|
|
||||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Hook 1: BeforeRead - Load security rules
|
// Hook 1: BeforeRead - Load security rules
|
||||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
secCtx := newSecurityContext(hookCtx)
|
secCtx := newSecurityContext(hookCtx)
|
||||||
@@ -34,6 +46,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
|||||||
return security.LogDataAccess(secCtx)
|
return security.LogDataAccess(secCtx)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelUpdateAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelDeleteAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
logger.Info("Security hooks registered for resolvespec handler")
|
logger.Info("Security hooks registered for resolvespec handler")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Available Hook Types**:
|
**Available Hook Types**:
|
||||||
|
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||||
* `BeforeRead`, `AfterRead`
|
* `BeforeRead`, `AfterRead`
|
||||||
* `BeforeCreate`, `AfterCreate`
|
* `BeforeCreate`, `AfterCreate`
|
||||||
* `BeforeUpdate`, `AfterUpdate`
|
* `BeforeUpdate`, `AfterUpdate`
|
||||||
@@ -157,11 +158,13 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
|||||||
* `Handler`: Access to handler, database, and registry
|
* `Handler`: Access to handler, database, and registry
|
||||||
* `Schema`, `Entity`, `TableName`: Request info
|
* `Schema`, `Entity`, `TableName`: Request info
|
||||||
* `Model`: The registered model type
|
* `Model`: The registered model type
|
||||||
|
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
|
||||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||||
* `ID`: Record ID (for single-record operations)
|
* `ID`: Record ID (for single-record operations)
|
||||||
* `Data`: Request data (for create/update)
|
* `Data`: Request data (for create/update)
|
||||||
* `Result`: Operation result (for after hooks)
|
* `Result`: Operation result (for after hooks)
|
||||||
* `Writer`: Response writer (allows hooks to modify response)
|
* `Writer`: Response writer (allows hooks to modify response)
|
||||||
|
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
|
||||||
|
|
||||||
## Cursor Pagination
|
## Cursor Pagination
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
modelColumns []string, // optional: for validation
|
modelColumns []string, // optional: for validation
|
||||||
expandJoins map[string]string, // optional: alias → JOIN SQL
|
expandJoins map[string]string, // optional: alias → JOIN SQL
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
// Separate schema prefix from bare table name
|
||||||
|
fullTableName := tableName
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||||
}
|
}
|
||||||
@@ -127,7 +129,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
WHERE cursor_select.%s = %s
|
WHERE cursor_select.%s = %s
|
||||||
AND (%s)
|
AND (%s)
|
||||||
)`,
|
)`,
|
||||||
tableName,
|
fullTableName,
|
||||||
joinSQL,
|
joinSQL,
|
||||||
pkName,
|
pkName,
|
||||||
cursorID,
|
cursorID,
|
||||||
|
|||||||
@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
|||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should handle schema prefix properly
|
// Should include full schema-qualified name in FROM clause
|
||||||
if !strings.Contains(filter, "users") {
|
if !strings.Contains(filter, "public.users") {
|
||||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||||
|
|||||||
@@ -133,6 +133,41 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
// Add request-scoped data to context (including options)
|
// Add request-scoped data to context (including options)
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||||
|
|
||||||
|
// Derive operation for auth check
|
||||||
|
var operation string
|
||||||
|
switch method {
|
||||||
|
case "GET":
|
||||||
|
operation = "read"
|
||||||
|
case "POST":
|
||||||
|
operation = "create"
|
||||||
|
case "PUT", "PATCH":
|
||||||
|
operation = "update"
|
||||||
|
case "DELETE":
|
||||||
|
operation = "delete"
|
||||||
|
default:
|
||||||
|
operation = "read"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||||
|
beforeCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Writer: w,
|
||||||
|
Request: r,
|
||||||
|
Operation: operation,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||||
|
code := http.StatusUnauthorized
|
||||||
|
if beforeCtx.AbortCode != 0 {
|
||||||
|
code = beforeCtx.AbortCode
|
||||||
|
}
|
||||||
|
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
if id != "" {
|
if id != "" {
|
||||||
@@ -696,6 +731,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// For now, pass empty map as joins are handled via Preload
|
// For now, pass empty map as joins are handled via Preload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default sort to primary key when none provided
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
// Get cursor filter SQL
|
// Get cursor filter SQL
|
||||||
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
cursorFilter, err := options.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1498,8 +1538,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
logger.Warn("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
logger.Error("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||||
continue
|
return fmt.Errorf("delete not allowed for ID %s: %w", itemID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
@@ -1572,8 +1612,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||||
continue
|
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
@@ -1630,8 +1670,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||||
continue
|
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
@@ -2111,7 +2151,11 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
|||||||
// Column is already cast to TEXT if needed
|
// Column is already cast to TEXT if needed
|
||||||
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
||||||
case "in":
|
case "in":
|
||||||
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
|
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||||
|
if cond == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return applyWhere(cond, inArgs...)
|
||||||
case "between":
|
case "between":
|
||||||
// Handle between operator - exclusive (> val1 AND < val2)
|
// Handle between operator - exclusive (> val1 AND < val2)
|
||||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
@@ -2187,24 +2231,25 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
|
|||||||
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
||||||
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
||||||
switch strings.ToLower(filter.Operator) {
|
switch strings.ToLower(filter.Operator) {
|
||||||
case "eq", "equals":
|
case "eq", "equals", "=":
|
||||||
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "neq", "not_equals", "ne":
|
case "neq", "not_equals", "ne", "!=", "<>":
|
||||||
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "gt", "greater_than":
|
case "gt", "greater_than", ">":
|
||||||
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "gte", "greater_than_equals", "ge":
|
case "gte", "greater_than_equals", "ge", ">=":
|
||||||
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "lt", "less_than":
|
case "lt", "less_than", "<":
|
||||||
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "lte", "less_than_equals", "le":
|
case "lte", "less_than_equals", "le", "<=":
|
||||||
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "like":
|
case "like":
|
||||||
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "ilike":
|
case "ilike":
|
||||||
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "in":
|
case "in":
|
||||||
return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value}
|
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||||
|
return cond, inArgs
|
||||||
case "between":
|
case "between":
|
||||||
// Handle between operator - exclusive (> val1 AND < val2)
|
// Handle between operator - exclusive (> val1 AND < val2)
|
||||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
@@ -2839,6 +2884,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
|||||||
|
|
||||||
// Filter base RequestOptions
|
// Filter base RequestOptions
|
||||||
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
||||||
|
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
|
||||||
|
filtered.RequestOptions.JoinAliases = options.JoinAliases
|
||||||
|
|
||||||
// Filter SearchColumns
|
// Filter SearchColumns
|
||||||
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
||||||
|
|||||||
@@ -1061,15 +1061,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
|
||||||
|
if len(xfile.SqlJoins) > 0 {
|
||||||
|
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
|
||||||
|
for _, joinClause := range xfile.SqlJoins {
|
||||||
|
// Sanitize the join clause
|
||||||
|
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||||
|
if sanitizedJoin == "" {
|
||||||
|
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||||
|
|
||||||
|
// Extract join alias for validation
|
||||||
|
alias := extractJoinAlias(sanitizedJoin)
|
||||||
|
if alias != "" {
|
||||||
|
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||||
|
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||||
|
}
|
||||||
|
|
||||||
// Add WHERE clause if SQL conditions specified
|
// Add WHERE clause if SQL conditions specified
|
||||||
|
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
|
||||||
whereConditions := make([]string, 0)
|
whereConditions := make([]string, 0)
|
||||||
if len(xfile.SqlAnd) > 0 {
|
if len(xfile.SqlAnd) > 0 {
|
||||||
// Process each SQL condition
|
var sqlAndOpts *common.RequestOptions
|
||||||
// Note: We don't add table prefixes here because they're only needed for JOINs
|
if len(preloadOpt.JoinAliases) > 0 {
|
||||||
// The handler will add prefixes later if SqlJoins are present
|
sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
|
||||||
|
}
|
||||||
for _, sqlCond := range xfile.SqlAnd {
|
for _, sqlCond := range xfile.SqlAnd {
|
||||||
// Sanitize the condition without adding prefixes
|
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
|
||||||
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
|
|
||||||
if sanitizedCond != "" {
|
if sanitizedCond != "" {
|
||||||
whereConditions = append(whereConditions, sanitizedCond)
|
whereConditions = append(whereConditions, sanitizedCond)
|
||||||
}
|
}
|
||||||
@@ -1114,32 +1141,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transfer SqlJoins from XFiles to PreloadOption
|
|
||||||
if len(xfile.SqlJoins) > 0 {
|
|
||||||
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
|
||||||
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
|
||||||
|
|
||||||
for _, joinClause := range xfile.SqlJoins {
|
|
||||||
// Sanitize the join clause
|
|
||||||
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
|
||||||
if sanitizedJoin == "" {
|
|
||||||
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
|
||||||
|
|
||||||
// Extract join alias for validation
|
|
||||||
alias := extractJoinAlias(sanitizedJoin)
|
|
||||||
if alias != "" {
|
|
||||||
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
|
||||||
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
||||||
// and store the recursive child's RelatedKey for recursion generation
|
// and store the recursive child's RelatedKey for recursion generation
|
||||||
hasRecursiveChild := false
|
hasRecursiveChild := false
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import (
|
|||||||
type HookType string
|
type HookType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||||
|
// Use this for auth checks that need model rules and user context simultaneously.
|
||||||
|
BeforeHandle HookType = "before_handle"
|
||||||
|
|
||||||
// Read operation hooks
|
// Read operation hooks
|
||||||
BeforeRead HookType = "before_read"
|
BeforeRead HookType = "before_read"
|
||||||
AfterRead HookType = "after_read"
|
AfterRead HookType = "after_read"
|
||||||
@@ -42,6 +46,9 @@ type HookContext struct {
|
|||||||
Model interface{}
|
Model interface{}
|
||||||
Options ExtendedRequestOptions
|
Options ExtendedRequestOptions
|
||||||
|
|
||||||
|
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||||
|
Operation string
|
||||||
|
|
||||||
// Operation-specific fields
|
// Operation-specific fields
|
||||||
ID string
|
ID string
|
||||||
Data interface{} // For create/update operations
|
Data interface{} // For create/update operations
|
||||||
@@ -56,6 +63,14 @@ type HookContext struct {
|
|||||||
// Response writer - allows hooks to modify response
|
// Response writer - allows hooks to modify response
|
||||||
Writer common.ResponseWriter
|
Writer common.ResponseWriter
|
||||||
|
|
||||||
|
// Request - the original HTTP request
|
||||||
|
Request common.Request
|
||||||
|
|
||||||
|
// Allow hooks to abort the operation
|
||||||
|
Abort bool // If set to true, the operation will be aborted
|
||||||
|
AbortMessage string // Message to return if aborted
|
||||||
|
AbortCode int // HTTP status code if aborted
|
||||||
|
|
||||||
// Tx provides access to the database/transaction for executing additional SQL
|
// Tx provides access to the database/transaction for executing additional SQL
|
||||||
// This allows hooks to run custom queries in addition to the main Query chain
|
// This allows hooks to run custom queries in addition to the main Query chain
|
||||||
Tx common.Database
|
Tx common.Database
|
||||||
@@ -110,6 +125,12 @@ func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
|||||||
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
||||||
return fmt.Errorf("hook execution failed: %w", err)
|
return fmt.Errorf("hook execution failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if hook requested abort
|
||||||
|
if ctx.Abort {
|
||||||
|
logger.Warn("Hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||||
|
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// logger.Debug("All hooks for %s executed successfully", hookType)
|
// logger.Debug("All hooks for %s executed successfully", hookType)
|
||||||
|
|||||||
@@ -125,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
|||||||
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
||||||
|
|
||||||
// Create handler functions for this specific entity
|
// Create handler functions for this specific entity
|
||||||
entityHandler := createMuxHandler(handler, schema, entity, "")
|
var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||||
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||||
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
|
var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
||||||
|
|
||||||
// Apply authentication middleware if provided
|
// Apply authentication middleware if provided
|
||||||
if authMiddleware != nil {
|
if authMiddleware != nil {
|
||||||
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
|
entityHandler = authMiddleware(entityHandler)
|
||||||
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
|
entityWithIDHandler = authMiddleware(entityWithIDHandler)
|
||||||
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
|
metadataHandler = authMiddleware(metadataHandler)
|
||||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -289,7 +289,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
|
|||||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
// Create an http.Handler that calls the bunrouter handler
|
// Create an http.Handler that calls the bunrouter handler
|
||||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_ = handler(w, req)
|
// Replace the embedded *http.Request with the middleware-enriched one
|
||||||
|
// so that auth context (user ID, etc.) is visible to the handler.
|
||||||
|
enrichedReq := req
|
||||||
|
enrichedReq.Request = r
|
||||||
|
_ = handler(w, enrichedReq)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Wrap with auth middleware and execute
|
// Wrap with auth middleware and execute
|
||||||
@@ -313,6 +317,14 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware M
|
|||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
@@ -9,6 +10,17 @@ import (
|
|||||||
|
|
||||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Hook 1: BeforeRead - Load security rules
|
// Hook 1: BeforeRead - Load security rules
|
||||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
secCtx := newSecurityContext(hookCtx)
|
secCtx := newSecurityContext(hookCtx)
|
||||||
@@ -33,6 +45,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
|||||||
return security.LogDataAccess(secCtx)
|
return security.LogDataAccess(secCtx)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelUpdateAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelDeleteAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
logger.Info("Security hooks registered for restheadspec handler")
|
logger.Info("Security hooks registered for restheadspec handler")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -405,11 +405,16 @@ assert.Equal(t, "user_id = {UserID}", row.Template)
|
|||||||
```
|
```
|
||||||
HTTP Request
|
HTTP Request
|
||||||
↓
|
↓
|
||||||
NewAuthMiddleware → calls provider.Authenticate()
|
NewOptionalAuthMiddleware → calls provider.Authenticate()
|
||||||
↓ (adds UserContext to context)
|
↓ (adds UserContext or guest context; never 401)
|
||||||
SetSecurityMiddleware → adds SecurityList to context
|
SetSecurityMiddleware → adds SecurityList to context
|
||||||
↓
|
↓
|
||||||
Handler.Handle()
|
Handler.Handle() → resolves model
|
||||||
|
↓
|
||||||
|
BeforeHandle Hook → CheckModelAuthAllowed(secCtx, operation)
|
||||||
|
├─ SecurityDisabled → allow
|
||||||
|
├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||||
|
└─ UserID == 0 → abort 401
|
||||||
↓
|
↓
|
||||||
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
|
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
|
||||||
↓
|
↓
|
||||||
@@ -693,15 +698,30 @@ http.Handle("/api/protected", authHandler)
|
|||||||
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
||||||
http.Handle("/home", optionalHandler)
|
http.Handle("/home", optionalHandler)
|
||||||
|
|
||||||
// Example handler
|
// NewOptionalAuthMiddleware - For spec routes; auth enforcement deferred to BeforeHandle
|
||||||
func myHandler(w http.ResponseWriter, r *http.Request) {
|
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||||
userCtx, _ := security.GetUserContext(r.Context())
|
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||||
if userCtx.UserID == 0 {
|
restheadspec.RegisterSecurityHooks(handler, securityList) // includes BeforeHandle
|
||||||
// Guest user
|
```
|
||||||
} else {
|
|
||||||
// Authenticated user
|
---
|
||||||
}
|
|
||||||
}
|
## Model-Level Access Control
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register model with rules (pkg/modelregistry)
|
||||||
|
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||||
|
SecurityDisabled: false, // skip all auth when true
|
||||||
|
CanPublicRead: true, // unauthenticated reads allowed
|
||||||
|
CanPublicCreate: false, // requires auth
|
||||||
|
CanPublicUpdate: false, // requires auth
|
||||||
|
CanPublicDelete: false, // requires auth
|
||||||
|
CanUpdate: true, // authenticated can update
|
||||||
|
CanDelete: false, // authenticated cannot delete (enforced in BeforeDelete)
|
||||||
|
})
|
||||||
|
|
||||||
|
// CheckModelAuthAllowed used automatically in BeforeHandle hook
|
||||||
|
// No code needed — call RegisterSecurityHooks and it's applied
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -751,14 +751,25 @@ resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
|
|||||||
```
|
```
|
||||||
HTTP Request
|
HTTP Request
|
||||||
↓
|
↓
|
||||||
NewAuthMiddleware (security package)
|
NewOptionalAuthMiddleware (security package) ← recommended for spec routes
|
||||||
├─ Calls provider.Authenticate(request)
|
├─ Calls provider.Authenticate(request)
|
||||||
└─ Adds UserContext to context
|
├─ On success: adds authenticated UserContext to context
|
||||||
|
└─ On failure: adds guest UserContext (UserID=0) to context
|
||||||
↓
|
↓
|
||||||
SetSecurityMiddleware (security package)
|
SetSecurityMiddleware (security package)
|
||||||
└─ Adds SecurityList to context
|
└─ Adds SecurityList to context
|
||||||
↓
|
↓
|
||||||
Spec Handler (restheadspec/funcspec/resolvespec)
|
Spec Handler (restheadspec/funcspec/resolvespec/websocketspec/mqttspec)
|
||||||
|
└─ Resolves schema + entity + model from request
|
||||||
|
↓
|
||||||
|
BeforeHandle Hook (registered by spec via RegisterSecurityHooks)
|
||||||
|
├─ Adapts spec's HookContext → SecurityContext
|
||||||
|
├─ Calls security.CheckModelAuthAllowed(secCtx, operation)
|
||||||
|
│ ├─ Loads model rules from context or registry
|
||||||
|
│ ├─ SecurityDisabled → allow
|
||||||
|
│ ├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||||
|
│ └─ UserID == 0 → 401 unauthorized
|
||||||
|
└─ On error: aborts with 401
|
||||||
↓
|
↓
|
||||||
BeforeRead Hook (registered by spec)
|
BeforeRead Hook (registered by spec)
|
||||||
├─ Adapts spec's HookContext → SecurityContext
|
├─ Adapts spec's HookContext → SecurityContext
|
||||||
@@ -784,7 +795,8 @@ HTTP Response (secured data)
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Key Points:**
|
**Key Points:**
|
||||||
- Security package is spec-agnostic and provides core logic
|
- `NewOptionalAuthMiddleware` never rejects — it sets guest context on auth failure; `BeforeHandle` enforces auth after model resolution
|
||||||
|
- `BeforeHandle` fires after model resolution, giving access to model rules and user context simultaneously
|
||||||
- Each spec registers its own hooks that adapt to SecurityContext
|
- Each spec registers its own hooks that adapt to SecurityContext
|
||||||
- Security rules are loaded once and cached for the request
|
- Security rules are loaded once and cached for the request
|
||||||
- Row security is applied to the query (database level)
|
- Row security is applied to the query (database level)
|
||||||
@@ -1002,15 +1014,49 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Model-Level Access Control
|
||||||
|
|
||||||
|
Use `ModelRules` (from `pkg/modelregistry`) to control per-entity auth behavior:
|
||||||
|
|
||||||
|
```go
|
||||||
|
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||||
|
SecurityDisabled: false, // true = skip all auth checks
|
||||||
|
CanPublicRead: true, // unauthenticated GET allowed
|
||||||
|
CanPublicCreate: false, // requires auth
|
||||||
|
CanPublicUpdate: false, // requires auth
|
||||||
|
CanPublicDelete: false, // requires auth
|
||||||
|
CanUpdate: true, // authenticated users can update
|
||||||
|
CanDelete: false, // authenticated users cannot delete
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
`CheckModelAuthAllowed(secCtx, operation)` applies these rules in `BeforeHandle`:
|
||||||
|
1. `SecurityDisabled` → allow all
|
||||||
|
2. `CanPublicRead/Create/Update/Delete` → allow unauthenticated for that operation
|
||||||
|
3. Guest (UserID == 0) → return 401
|
||||||
|
4. Authenticated → allow (operation-specific `CanUpdate`/`CanDelete` checked in `BeforeUpdate`/`BeforeDelete`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Middleware and Handler API
|
## Middleware and Handler API
|
||||||
|
|
||||||
### NewAuthMiddleware
|
### NewAuthMiddleware
|
||||||
Standard middleware that authenticates all requests:
|
Standard middleware that authenticates all requests and returns 401 on failure:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
router.Use(security.NewAuthMiddleware(securityList))
|
router.Use(security.NewAuthMiddleware(securityList))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### NewOptionalAuthMiddleware
|
||||||
|
Middleware for spec routes — always continues; sets guest context on auth failure:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Use with RegisterSecurityHooks — auth enforcement is deferred to BeforeHandle
|
||||||
|
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||||
|
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||||
|
restheadspec.RegisterSecurityHooks(handler, securityList) // registers BeforeHandle
|
||||||
|
```
|
||||||
|
|
||||||
Routes can skip authentication using the `SkipAuth` helper:
|
Routes can skip authentication using the `SkipAuth` helper:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
||||||
@@ -226,6 +227,122 @@ func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
|
|||||||
return applyColumnSecurity(secCtx, securityList)
|
return applyColumnSecurity(secCtx, securityList)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkModelUpdateAllowed returns an error if CanUpdate is false for the model.
|
||||||
|
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||||
|
func checkModelUpdateAllowed(secCtx SecurityContext) error {
|
||||||
|
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||||
|
if !ok {
|
||||||
|
schema := secCtx.GetSchema()
|
||||||
|
entity := secCtx.GetEntity()
|
||||||
|
var err error
|
||||||
|
if schema != "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||||
|
}
|
||||||
|
if err != nil || schema == "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil // model not registered, allow by default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !rules.CanUpdate {
|
||||||
|
return fmt.Errorf("update not allowed for %s", secCtx.GetEntity())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkModelDeleteAllowed returns an error if CanDelete is false for the model.
|
||||||
|
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||||
|
func checkModelDeleteAllowed(secCtx SecurityContext) error {
|
||||||
|
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||||
|
if !ok {
|
||||||
|
schema := secCtx.GetSchema()
|
||||||
|
entity := secCtx.GetEntity()
|
||||||
|
var err error
|
||||||
|
if schema != "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||||
|
}
|
||||||
|
if err != nil || schema == "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil // model not registered, allow by default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !rules.CanDelete {
|
||||||
|
return fmt.Errorf("delete not allowed for %s", secCtx.GetEntity())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckModelAuthAllowed checks whether the requested operation is permitted based on
|
||||||
|
// model rules and the current user's authentication state. It is intended for use in
|
||||||
|
// a BeforeHandle hook, fired after model resolution.
|
||||||
|
//
|
||||||
|
// Logic:
|
||||||
|
// 1. Load model rules from context (set by NewModelAuthMiddleware) or fall back to registry.
|
||||||
|
// 2. SecurityDisabled → allow.
|
||||||
|
// 3. operation == "read" && CanPublicRead → allow.
|
||||||
|
// 4. operation == "create" && CanPublicCreate → allow.
|
||||||
|
// 5. operation == "update" && CanPublicUpdate → allow.
|
||||||
|
// 6. operation == "delete" && CanPublicDelete → allow.
|
||||||
|
// 7. Guest (UserID == 0) → return "authentication required".
|
||||||
|
// 8. Authenticated user → allow (operation-specific checks remain in BeforeUpdate/BeforeDelete).
|
||||||
|
func CheckModelAuthAllowed(secCtx SecurityContext, operation string) error {
|
||||||
|
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||||
|
if !ok {
|
||||||
|
schema := secCtx.GetSchema()
|
||||||
|
entity := secCtx.GetEntity()
|
||||||
|
var err error
|
||||||
|
if schema != "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||||
|
}
|
||||||
|
if err != nil || schema == "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// Model not registered - fall through to auth check
|
||||||
|
userID, _ := secCtx.GetUserID()
|
||||||
|
if userID == 0 {
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules.SecurityDisabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "read" && rules.CanPublicRead {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "create" && rules.CanPublicCreate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "update" && rules.CanPublicUpdate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "delete" && rules.CanPublicDelete {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, _ := secCtx.GetUserID()
|
||||||
|
if userID == 0 {
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckModelUpdateAllowed is the public wrapper for checkModelUpdateAllowed.
|
||||||
|
func CheckModelUpdateAllowed(secCtx SecurityContext) error {
|
||||||
|
return checkModelUpdateAllowed(secCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckModelDeleteAllowed is the public wrapper for checkModelDeleteAllowed.
|
||||||
|
func CheckModelDeleteAllowed(secCtx SecurityContext) error {
|
||||||
|
return checkModelDeleteAllowed(secCtx)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
|
|
||||||
func contains(s, substr string) bool {
|
func contains(s, substr string) bool {
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// contextKey is a custom type for context keys to avoid collisions
|
// contextKey is a custom type for context keys to avoid collisions
|
||||||
@@ -23,6 +25,7 @@ const (
|
|||||||
UserMetaKey contextKey = "user_meta"
|
UserMetaKey contextKey = "user_meta"
|
||||||
SkipAuthKey contextKey = "skip_auth"
|
SkipAuthKey contextKey = "skip_auth"
|
||||||
OptionalAuthKey contextKey = "optional_auth"
|
OptionalAuthKey contextKey = "optional_auth"
|
||||||
|
ModelRulesKey contextKey = "model_rules"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SkipAuth returns a context with skip auth flag set to true
|
// SkipAuth returns a context with skip auth flag set to true
|
||||||
@@ -136,6 +139,31 @@ func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewOptionalAuthMiddleware creates authentication middleware that always continues.
|
||||||
|
// On auth failure, a guest user context is set instead of returning 401.
|
||||||
|
// Intended for spec routes where auth enforcement is deferred to a BeforeHandle hook
|
||||||
|
// after model resolution.
|
||||||
|
func NewOptionalAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
provider := securityList.Provider()
|
||||||
|
if provider == nil {
|
||||||
|
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCtx, err := provider.Authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware creates an authentication middleware with the given security list
|
// NewAuthMiddleware creates an authentication middleware with the given security list
|
||||||
// This middleware extracts user authentication from the request and adds it to context
|
// This middleware extracts user authentication from the request and adds it to context
|
||||||
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
|
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
|
||||||
@@ -182,6 +210,68 @@ func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewModelAuthMiddleware creates authentication middleware that respects ModelRules for the given model name.
|
||||||
|
// It first checks if ModelRules are set for the model:
|
||||||
|
// - If SecurityDisabled is true, authentication is skipped and a guest context is set.
|
||||||
|
// - Otherwise, all checks from NewAuthMiddleware apply (SkipAuthKey, provider check, OptionalAuthKey, Authenticate).
|
||||||
|
//
|
||||||
|
// If the model is not found in any registry, the middleware falls back to standard NewAuthMiddleware behaviour.
|
||||||
|
func NewModelAuthMiddleware(securityList *SecurityList, modelName string) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check ModelRules first
|
||||||
|
if rules, err := modelregistry.GetModelRulesByName(modelName); err == nil {
|
||||||
|
// Store rules in context for downstream use (e.g., security hooks)
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), ModelRulesKey, rules))
|
||||||
|
|
||||||
|
if rules.SecurityDisabled {
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
isRead := r.Method == http.MethodGet || r.Method == http.MethodHead
|
||||||
|
isUpdate := r.Method == http.MethodPut || r.Method == http.MethodPatch
|
||||||
|
if (isRead && rules.CanPublicRead) || (isUpdate && rules.CanPublicUpdate) {
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this route should skip authentication
|
||||||
|
if skip, ok := r.Context().Value(SkipAuthKey).(bool); ok && skip {
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the security provider
|
||||||
|
provider := securityList.Provider()
|
||||||
|
if provider == nil {
|
||||||
|
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this route has optional authentication
|
||||||
|
optional, _ := r.Context().Value(OptionalAuthKey).(bool)
|
||||||
|
|
||||||
|
// Try to authenticate
|
||||||
|
userCtx, err := provider.Authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
if optional {
|
||||||
|
guestCtx := createGuestContext(r)
|
||||||
|
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetSecurityMiddleware adds security context to requests
|
// SetSecurityMiddleware adds security context to requests
|
||||||
// This middleware should be applied after AuthMiddleware
|
// This middleware should be applied after AuthMiddleware
|
||||||
func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||||
@@ -366,6 +456,131 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
|
|||||||
return meta, ok
|
return meta, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
|
||||||
|
// All fields are optional; sensible secure defaults are applied when omitted.
|
||||||
|
type SessionCookieOptions struct {
|
||||||
|
// Name is the cookie name. Defaults to "session_token".
|
||||||
|
Name string
|
||||||
|
// Path is the cookie path. Defaults to "/".
|
||||||
|
Path string
|
||||||
|
// Domain restricts the cookie to a specific domain. Empty means current host.
|
||||||
|
Domain string
|
||||||
|
// Secure sets the Secure flag. Defaults to true.
|
||||||
|
// Set to false only in local development over HTTP.
|
||||||
|
Secure *bool
|
||||||
|
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
|
||||||
|
SameSite http.SameSite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o SessionCookieOptions) name() string {
|
||||||
|
if o.Name != "" {
|
||||||
|
return o.Name
|
||||||
|
}
|
||||||
|
return "session_token"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o SessionCookieOptions) path() string {
|
||||||
|
if o.Path != "" {
|
||||||
|
return o.Path
|
||||||
|
}
|
||||||
|
return "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o SessionCookieOptions) secure() bool {
|
||||||
|
if o.Secure != nil {
|
||||||
|
return *o.Secure
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o SessionCookieOptions) sameSite() http.SameSite {
|
||||||
|
if o.SameSite != 0 {
|
||||||
|
return o.SameSite
|
||||||
|
}
|
||||||
|
return http.SameSiteLaxMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSessionCookie writes the session_token cookie to the response after a successful login.
|
||||||
|
// Call this immediately after a successful Authenticator.Login() call.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// resp, err := auth.Login(r.Context(), req)
|
||||||
|
// if err != nil { ... }
|
||||||
|
// security.SetSessionCookie(w, resp)
|
||||||
|
// json.NewEncoder(w).Encode(resp)
|
||||||
|
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
|
||||||
|
var o SessionCookieOptions
|
||||||
|
if len(opts) > 0 {
|
||||||
|
o = opts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAge := 0
|
||||||
|
if loginResp.ExpiresIn > 0 {
|
||||||
|
maxAge = int(loginResp.ExpiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: o.name(),
|
||||||
|
Value: loginResp.Token,
|
||||||
|
Path: o.path(),
|
||||||
|
Domain: o.Domain,
|
||||||
|
MaxAge: maxAge,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: o.secure(),
|
||||||
|
SameSite: o.sameSite(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// token := security.GetSessionCookie(r)
|
||||||
|
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
|
||||||
|
var o SessionCookieOptions
|
||||||
|
if len(opts) > 0 {
|
||||||
|
o = opts[0]
|
||||||
|
}
|
||||||
|
cookie, err := r.Cookie(o.name())
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cookie.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
|
||||||
|
// Call this after a successful Authenticator.Logout() call.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// err := auth.Logout(r.Context(), req)
|
||||||
|
// if err != nil { ... }
|
||||||
|
// security.ClearSessionCookie(w)
|
||||||
|
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
|
||||||
|
var o SessionCookieOptions
|
||||||
|
if len(opts) > 0 {
|
||||||
|
o = opts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: o.name(),
|
||||||
|
Value: "",
|
||||||
|
Path: o.path(),
|
||||||
|
Domain: o.Domain,
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: o.secure(),
|
||||||
|
SameSite: o.sameSite(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
|
||||||
|
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
|
||||||
|
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)
|
||||||
|
return rules, ok
|
||||||
|
}
|
||||||
|
|
||||||
// // Handler adapters for resolvespec/restheadspec compatibility
|
// // Handler adapters for resolvespec/restheadspec compatibility
|
||||||
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions
|
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions
|
||||||
|
|
||||||
|
|||||||
@@ -222,9 +222,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
|||||||
|
|
||||||
if sessionToken == "" {
|
if sessionToken == "" {
|
||||||
// Try cookie
|
// Try cookie
|
||||||
cookie, err := r.Cookie("session_token")
|
if token := GetSessionCookie(r); token != "" {
|
||||||
if err == nil {
|
tokens = []string{token}
|
||||||
tokens = []string{cookie.Value}
|
|
||||||
reference = "cookie"
|
reference = "cookie"
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
|
|||||||
|
|
||||||
// Apply prefix stripping by prepending the prefix to the requested path
|
// Apply prefix stripping by prepending the prefix to the requested path
|
||||||
actualPath := name
|
actualPath := name
|
||||||
|
alternatePath := ""
|
||||||
if p.stripPrefix != "" {
|
if p.stripPrefix != "" {
|
||||||
// Clean the paths to handle leading/trailing slashes
|
// Clean the paths to handle leading/trailing slashes
|
||||||
prefix := strings.Trim(p.stripPrefix, "/")
|
prefix := strings.Trim(p.stripPrefix, "/")
|
||||||
@@ -105,12 +106,25 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
|
|||||||
|
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
actualPath = path.Join(prefix, cleanName)
|
actualPath = path.Join(prefix, cleanName)
|
||||||
|
alternatePath = cleanName
|
||||||
} else {
|
} else {
|
||||||
actualPath = cleanName
|
actualPath = cleanName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// First try the actual path with prefix
|
||||||
|
if file, err := p.fs.Open(actualPath); err == nil {
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
|
||||||
return p.fs.Open(actualPath)
|
// If alternate path is different, try it as well
|
||||||
|
if alternatePath != "" && alternatePath != actualPath {
|
||||||
|
if file, err := p.fs.Open(alternatePath); err == nil {
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If both attempts fail, return the error from the first attempt
|
||||||
|
return nil, fmt.Errorf("file not found: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close releases any resources held by the provider.
|
// Close releases any resources held by the provider.
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
|
|||||||
|
|
||||||
// Apply prefix stripping by prepending the prefix to the requested path
|
// Apply prefix stripping by prepending the prefix to the requested path
|
||||||
actualPath := name
|
actualPath := name
|
||||||
|
alternatePath := ""
|
||||||
if p.stripPrefix != "" {
|
if p.stripPrefix != "" {
|
||||||
// Clean the paths to handle leading/trailing slashes
|
// Clean the paths to handle leading/trailing slashes
|
||||||
prefix := strings.Trim(p.stripPrefix, "/")
|
prefix := strings.Trim(p.stripPrefix, "/")
|
||||||
@@ -60,12 +61,26 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
|
|||||||
|
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
actualPath = path.Join(prefix, cleanName)
|
actualPath = path.Join(prefix, cleanName)
|
||||||
|
alternatePath = cleanName
|
||||||
} else {
|
} else {
|
||||||
actualPath = cleanName
|
actualPath = cleanName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.fs.Open(actualPath)
|
// First try the actual path with prefix
|
||||||
|
if file, err := p.fs.Open(actualPath); err == nil {
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If alternate path is different, try it as well
|
||||||
|
if alternatePath != "" && alternatePath != actualPath {
|
||||||
|
if file, err := p.fs.Open(alternatePath); err == nil {
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If both attempts fail, return the error from the first attempt
|
||||||
|
return nil, fmt.Errorf("file not found: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close releases any resources held by the provider.
|
// Close releases any resources held by the provider.
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
|
|||||||
|
|
||||||
// Apply prefix stripping by prepending the prefix to the requested path
|
// Apply prefix stripping by prepending the prefix to the requested path
|
||||||
actualPath := name
|
actualPath := name
|
||||||
|
alternatePath := ""
|
||||||
if p.stripPrefix != "" {
|
if p.stripPrefix != "" {
|
||||||
// Clean the paths to handle leading/trailing slashes
|
// Clean the paths to handle leading/trailing slashes
|
||||||
prefix := strings.Trim(p.stripPrefix, "/")
|
prefix := strings.Trim(p.stripPrefix, "/")
|
||||||
@@ -63,12 +64,26 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
|
|||||||
|
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
actualPath = path.Join(prefix, cleanName)
|
actualPath = path.Join(prefix, cleanName)
|
||||||
|
alternatePath = cleanName
|
||||||
} else {
|
} else {
|
||||||
actualPath = cleanName
|
actualPath = cleanName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.zipFS.Open(actualPath)
|
// First try the actual path with prefix
|
||||||
|
if file, err := p.zipFS.Open(actualPath); err == nil {
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If alternate path is different, try it as well
|
||||||
|
if alternatePath != "" && alternatePath != actualPath {
|
||||||
|
if file, err := p.zipFS.Open(alternatePath); err == nil {
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If both attempts fail, return the error from the first attempt
|
||||||
|
return nil, fmt.Errorf("file not found: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close releases resources held by the zip reader.
|
// Close releases resources held by the zip reader.
|
||||||
|
|||||||
@@ -330,6 +330,7 @@ Hooks allow you to intercept and modify operations at various points in the life
|
|||||||
|
|
||||||
### Available Hook Types
|
### Available Hook Types
|
||||||
|
|
||||||
|
- **BeforeHandle** — fires after model resolution, before operation dispatch (auth checks)
|
||||||
- **BeforeRead** / **AfterRead**
|
- **BeforeRead** / **AfterRead**
|
||||||
- **BeforeCreate** / **AfterCreate**
|
- **BeforeCreate** / **AfterCreate**
|
||||||
- **BeforeUpdate** / **AfterUpdate**
|
- **BeforeUpdate** / **AfterUpdate**
|
||||||
@@ -337,6 +338,8 @@ Hooks allow you to intercept and modify operations at various points in the life
|
|||||||
- **BeforeSubscribe** / **AfterSubscribe**
|
- **BeforeSubscribe** / **AfterSubscribe**
|
||||||
- **BeforeConnect** / **AfterConnect**
|
- **BeforeConnect** / **AfterConnect**
|
||||||
|
|
||||||
|
`HookContext` includes `Operation string` (`"read"`, `"create"`, `"update"`, `"delete"`) and `Abort bool`, `AbortMessage string`, `AbortCode int` for abort signaling.
|
||||||
|
|
||||||
### Hook Example
|
### Hook Example
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@@ -599,7 +602,19 @@ asyncio.run(main())
|
|||||||
|
|
||||||
## Authentication
|
## Authentication
|
||||||
|
|
||||||
Implement authentication using hooks:
|
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
|
||||||
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
websocketspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||||
|
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||||
|
```
|
||||||
|
|
||||||
|
Or implement custom authentication using hooks directly:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
handler := websocketspec.NewHandlerWithGORM(db)
|
handler := websocketspec.NewHandlerWithGORM(db)
|
||||||
|
|||||||
@@ -177,6 +177,16 @@ func (h *Handler) handleRequest(conn *Connection, msg *Message) {
|
|||||||
Metadata: make(map[string]interface{}),
|
Metadata: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||||
|
hookCtx.Operation = string(msg.Operation)
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
if hookCtx.Abort {
|
||||||
|
errResp := NewErrorResponse(msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||||
|
_ = conn.SendJSON(errResp)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Route to operation handler
|
// Route to operation handler
|
||||||
switch msg.Operation {
|
switch msg.Operation {
|
||||||
case OperationRead:
|
case OperationRead:
|
||||||
@@ -618,7 +628,10 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
|||||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||||
if hookCtx.Options != nil {
|
if hookCtx.Options != nil {
|
||||||
for _, filter := range hookCtx.Options.Filters {
|
for _, filter := range hookCtx.Options.Filters {
|
||||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
cond, args := h.buildFilterCondition(filter)
|
||||||
|
if cond != "" {
|
||||||
|
countQuery = countQuery.Where(cond, args...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
count, _ := countQuery.Count(hookCtx.Context)
|
count, _ := countQuery.Count(hookCtx.Context)
|
||||||
@@ -790,14 +803,12 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
|
|||||||
|
|
||||||
// buildFilterCondition builds a filter condition and returns it with args
|
// buildFilterCondition builds a filter condition and returns it with args
|
||||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
||||||
var condition string
|
if strings.EqualFold(filter.Operator, "in") {
|
||||||
var args []interface{}
|
cond, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||||
|
return cond, args
|
||||||
|
}
|
||||||
operatorSQL := h.getOperatorSQL(filter.Operator)
|
operatorSQL := h.getOperatorSQL(filter.Operator)
|
||||||
condition = fmt.Sprintf("%s %s ?", filter.Column, operatorSQL)
|
return fmt.Sprintf("%s %s ?", filter.Column, operatorSQL), []interface{}{filter.Value}
|
||||||
args = []interface{}{filter.Value}
|
|
||||||
|
|
||||||
return condition, args
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package websocketspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
)
|
)
|
||||||
@@ -10,6 +11,10 @@ import (
|
|||||||
type HookType string
|
type HookType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||||
|
// Use this for auth checks that need model rules and user context simultaneously.
|
||||||
|
BeforeHandle HookType = "before_handle"
|
||||||
|
|
||||||
// BeforeRead is called before a read operation
|
// BeforeRead is called before a read operation
|
||||||
BeforeRead HookType = "before_read"
|
BeforeRead HookType = "before_read"
|
||||||
// AfterRead is called after a read operation
|
// AfterRead is called after a read operation
|
||||||
@@ -83,6 +88,9 @@ type HookContext struct {
|
|||||||
// Options contains the parsed request options
|
// Options contains the parsed request options
|
||||||
Options *common.RequestOptions
|
Options *common.RequestOptions
|
||||||
|
|
||||||
|
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||||
|
Operation string
|
||||||
|
|
||||||
// ID is the record ID for single-record operations
|
// ID is the record ID for single-record operations
|
||||||
ID string
|
ID string
|
||||||
|
|
||||||
@@ -98,6 +106,11 @@ type HookContext struct {
|
|||||||
// Error is any error that occurred (for after hooks)
|
// Error is any error that occurred (for after hooks)
|
||||||
Error error
|
Error error
|
||||||
|
|
||||||
|
// Allow hooks to abort the operation
|
||||||
|
Abort bool // If set to true, the operation will be aborted
|
||||||
|
AbortMessage string // Message to return if aborted
|
||||||
|
AbortCode int // HTTP status code if aborted
|
||||||
|
|
||||||
// Metadata is additional context data
|
// Metadata is additional context data
|
||||||
Metadata map[string]interface{}
|
Metadata map[string]interface{}
|
||||||
}
|
}
|
||||||
@@ -171,6 +184,11 @@ func (hr *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
|||||||
if err := hook(ctx); err != nil {
|
if err := hook(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if hook requested abort
|
||||||
|
if ctx.Abort {
|
||||||
|
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
108
pkg/websocketspec/security_hooks.go
Normal file
108
pkg/websocketspec/security_hooks.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package websocketspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 1: BeforeRead - Load security rules
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LoadSecurityRules(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 3 (Optional): Audit logging
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LogDataAccess(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelUpdateAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelDeleteAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("Security hooks registered for websocketspec handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// securityContext adapts websocketspec.HookContext to security.SecurityContext interface
|
||||||
|
type securityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &securityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetContext() context.Context {
|
||||||
|
return s.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetUserID() (int, bool) {
|
||||||
|
return security.GetUserID(s.ctx.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetSchema() string {
|
||||||
|
return s.ctx.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetEntity() string {
|
||||||
|
return s.ctx.Entity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetModel() interface{} {
|
||||||
|
return s.ctx.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuery retrieves a stored query from hook metadata (websocketspec has no Query field)
|
||||||
|
func (s *securityContext) GetQuery() interface{} {
|
||||||
|
if s.ctx.Metadata == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.ctx.Metadata["query"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuery stores the query in hook metadata
|
||||||
|
func (s *securityContext) SetQuery(query interface{}) {
|
||||||
|
if s.ctx.Metadata == nil {
|
||||||
|
s.ctx.Metadata = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
s.ctx.Metadata["query"] = query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetResult() interface{} {
|
||||||
|
return s.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetResult(result interface{}) {
|
||||||
|
s.ctx.Result = result
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user