Compare commits

..

6 Commits

Author SHA1 Message Date
aa095d6bfd fix(tests): replace panic with log.Fatal for better error handling
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m52s
Build , Vet Test, and Lint / Build (push) Successful in -29m52s
Tests / Integration Tests (push) Failing after -30m46s
Tests / Unit Tests (push) Successful in -28m51s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m17s
Build , Vet Test, and Lint / Lint Code (push) Failing after -29m23s
2026-04-07 20:38:22 +02:00
ea5bb38ee4 feat(handler): update to use static base path for SSE server 2026-04-07 20:03:43 +02:00
c2e2c9b873 feat(transport): add streamable HTTP transport for MCP 2026-04-07 19:52:38 +02:00
4adf94fe37 feat(go.mod): add mcp-go dependency for enhanced functionality 2026-04-07 19:09:51 +02:00
Hein
405a04a192 feat(security): integrate security hooks for access control
Some checks failed
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m6s
Tests / Unit Tests (push) Successful in -30m22s
Tests / Integration Tests (push) Failing after -30m41s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m3s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m36s
Build , Vet Test, and Lint / Build (push) Successful in -29m58s
* Add security hooks for per-entity operation rules and row/column-level security.
* Implement annotation tool for storing and retrieving freeform annotations.
* Enhance handler to support model registration with access rules.
2026-04-07 15:53:12 +02:00
Hein
c1b16d363a feat(db): add DB method to sqlConnection and mongoConnection
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m22s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m59s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m11s
Build , Vet Test, and Lint / Build (push) Successful in -30m12s
Tests / Unit Tests (push) Successful in -30m49s
Tests / Integration Tests (push) Failing after -30m59s
2026-04-01 15:34:09 +02:00
9 changed files with 529 additions and 75 deletions

2
go.mod
View File

@@ -15,6 +15,7 @@ require (
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.8.0 github.com/jackc/pgx/v5 v5.8.0
github.com/klauspost/compress v1.18.2 github.com/klauspost/compress v1.18.2
github.com/mark3labs/mcp-go v0.46.0
github.com/mattn/go-sqlite3 v1.14.33 github.com/mattn/go-sqlite3 v1.14.33
github.com/microsoft/go-mssqldb v1.9.5 github.com/microsoft/go-mssqldb v1.9.5
github.com/mochi-mqtt/server/v2 v2.7.9 github.com/mochi-mqtt/server/v2 v2.7.9
@@ -88,7 +89,6 @@ require (
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.10 // indirect github.com/magiconair/properties v1.8.10 // indirect
github.com/mark3labs/mcp-go v0.46.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/go-archive v0.1.0 // indirect github.com/moby/go-archive v0.1.0 // indirect

View File

@@ -26,6 +26,7 @@ type Connection interface {
Bun() (*bun.DB, error) Bun() (*bun.DB, error)
GORM() (*gorm.DB, error) GORM() (*gorm.DB, error)
Native() (*sql.DB, error) Native() (*sql.DB, error)
DB() (*sql.DB, error)
// Common Database interface (for SQL databases) // Common Database interface (for SQL databases)
Database() (common.Database, error) Database() (common.Database, error)
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
return c.nativeDB, nil return c.nativeDB, nil
} }
// DB returns the underlying *sql.DB connection
func (c *sqlConnection) DB() (*sql.DB, error) {
return c.Native()
}
// Bun returns a Bun ORM instance wrapping the native connection // Bun returns a Bun ORM instance wrapping the native connection
func (c *sqlConnection) Bun() (*bun.DB, error) { func (c *sqlConnection) Bun() (*bun.DB, error) {
if c == nil { if c == nil {
@@ -645,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
return nil, ErrNotSQLDatabase return nil, ErrNotSQLDatabase
} }
// DB returns an error for MongoDB connections
func (c *mongoConnection) DB() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// Database returns an error for MongoDB connections // Database returns an error for MongoDB connections
func (c *mongoConnection) Database() (common.Database, error) { func (c *mongoConnection) Database() (common.Database, error) {
return nil, ErrNotSQLDatabase return nil, ErrNotSQLDatabase

View File

@@ -3,6 +3,7 @@ package providers_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager" "github.com/bitechdev/ResolveSpec/pkg/dbmanager"
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Subscribe to a channel with a handler // Subscribe to a channel with a handler
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
fmt.Printf("Received notification on %s: %s\n", channel, payload) fmt.Printf("Received notification on %s: %s\n", channel, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen: %v", err)) log.Fatalf("Failed to listen: %v", err)
} }
// Send a notification // Send a notification
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`) err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to notify: %v", err)) log.Fatalf("Failed to notify: %v", err)
} }
// Wait for notification to be processed // Wait for notification to be processed
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
// Unsubscribe from the channel // Unsubscribe from the channel
if err := listener.Unlisten("user_events"); err != nil { if err := listener.Unlisten("user_events"); err != nil {
panic(fmt.Sprintf("Failed to unlisten: %v", err)) log.Fatalf("Failed to unlisten: %v", err)
} }
} }
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Listen to multiple channels // Listen to multiple channels
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
fmt.Printf("[%s] %s\n", ch, payload) fmt.Printf("[%s] %s\n", ch, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err)) log.Fatalf("Failed to listen on %s: %v", channel, err)
} }
} }
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
provider := providers.NewPostgresProvider() provider := providers.NewPostgresProvider()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(err) log.Fatal(err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Subscribe to application events // Subscribe to application events
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// The listener automatically reconnects if the connection is lost // The listener automatically reconnects if the connection is lost

View File

@@ -67,58 +67,164 @@ Each call immediately creates four MCP **tools** and one MCP **resource** for th
--- ---
## HTTP / SSE Transport ## HTTP Transports
The `*server.SSEServer` returned by any of the helpers below implements `http.Handler`, so it works with every Go HTTP framework.
`Config.BasePath` is required and used for all route registration. `Config.BasePath` is required and used for all route registration.
`Config.BaseURL` is optional — when empty it is detected from each request. `Config.BaseURL` is optional — when empty it is detected from each request.
### Gorilla Mux Two transports are supported: **SSE** (legacy, two-endpoint) and **Streamable HTTP** (recommended, single-endpoint).
---
### SSE Transport
Two endpoints: `GET {BasePath}/sse` (subscribe) + `POST {BasePath}/message` (send).
#### Gorilla Mux
```go ```go
resolvemcp.SetupMuxRoutes(r, handler) resolvemcp.SetupMuxRoutes(r, handler)
``` ```
Registers:
| Route | Method | Description | | Route | Method | Description |
|---|---|---| |---|---|---|
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here | | `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here | | `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
| `{BasePath}/*` | any | Full SSE server (convenience prefix) |
### bunrouter #### bunrouter
```go ```go
resolvemcp.SetupBunRouterRoutes(router, handler) resolvemcp.SetupBunRouterRoutes(router, handler)
``` ```
Registers `GET {BasePath}/sse` and `POST {BasePath}/message` on the provided `*bunrouter.Router`. #### Gin / net/http / Echo
### Gin (or any `http.Handler`-compatible framework)
Use `handler.SSEServer()` to get an `http.Handler` and wrap it with the framework's adapter:
```go ```go
sse := handler.SSEServer() sse := handler.SSEServer()
// Gin engine.Any("/mcp/*path", gin.WrapH(sse)) // Gin
engine.Any("/mcp/*path", gin.WrapH(sse)) http.Handle("/mcp/", sse) // net/http
e.Any("/mcp/*", echo.WrapHandler(sse)) // Echo
// net/http
http.Handle("/mcp/", sse)
// Echo
e.Any("/mcp/*", echo.WrapHandler(sse))
``` ```
---
### Streamable HTTP Transport
Single endpoint at `{BasePath}`. Handles POST (client→server) and GET (server→client streaming). Preferred for new integrations.
#### Gorilla Mux
```go
resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler)
```
Mounts the handler at `{BasePath}` (all methods).
#### bunrouter
```go
resolvemcp.SetupBunRouterStreamableHTTPRoutes(router, handler)
```
Registers GET, POST, DELETE on `{BasePath}`.
#### Gin / net/http / Echo
```go
h := handler.StreamableHTTPServer()
// or: h := resolvemcp.NewStreamableHTTPHandler(handler)
engine.Any("/mcp", gin.WrapH(h)) // Gin
http.Handle("/mcp", h) // net/http
e.Any("/mcp", echo.WrapHandler(h)) // Echo
```
---
### Authentication ### Authentication
Add middleware before the MCP routes. The handler itself has no auth layer. Add middleware before the MCP routes. The handler itself has no auth layer.
--- ---
## Security
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
### Wiring security hooks
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
securityList := security.NewSecurityList(mySecurityProvider)
resolvemcp.RegisterSecurityHooks(handler, securityList)
```
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
| Hook | Effect |
|---|---|
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
### Per-entity operation rules
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
```go
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
// Read-only entity
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
CanRead: true,
CanCreate: false,
CanUpdate: false,
CanDelete: false,
})
// Public read, authenticated write
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
CanPublicRead: true,
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
To update rules for an already-registered model:
```go
handler.SetModelRules("public", "users", modelregistry.ModelRules{
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
### ModelRules fields
| Field | Default | Description |
|---|---|---|
| `CanPublicRead` | `false` | Allow unauthenticated reads |
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
| `CanRead` | `true` | Allow authenticated reads |
| `CanCreate` | `true` | Allow authenticated creates |
| `CanUpdate` | `true` | Allow authenticated updates |
| `CanDelete` | `true` | Allow authenticated deletes |
| `SecurityDisabled` | `false` | Skip all security checks for this model |
---
## MCP Tools ## MCP Tools
### Tool Naming ### Tool Naming
@@ -204,6 +310,35 @@ Delete a record by primary key. **Irreversible.**
{ "success": true, "data": { ...deleted record... } } { "success": true, "data": { ...deleted record... } }
``` ```
### Annotation Tool — `resolvespec_annotate`
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
| Argument | Type | Description |
|---|---|---|
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
```json
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "set" }
```
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
```json
{ "tool_name": "read_public_users" }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
```
---
### Resource — `{schema}.{entity}` ### Resource — `{schema}.{entity}`
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`. Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.

View File

@@ -0,0 +1,107 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
)
const annotationToolName = "resolvespec_annotate"
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
// The tool lets models/entities store and retrieve freeform annotation records
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
func registerAnnotationTool(h *Handler) {
tool := mcp.NewTool(annotationToolName,
mcp.WithDescription(
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
"To set annotations: provide both 'tool_name' and 'annotations'. "+
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
"To get annotations: provide only 'tool_name'. "+
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
"a model/entity name (e.g. 'public.users'), or any other key.",
),
mcp.WithString("tool_name",
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
mcp.Required(),
),
mcp.WithObject("annotations",
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
toolName, ok := args["tool_name"].(string)
if !ok || toolName == "" {
return mcp.NewToolResultError("missing required argument: tool_name"), nil
}
annotations, hasAnnotations := args["annotations"]
if hasAnnotations && annotations != nil {
return executeSetAnnotation(ctx, h, toolName, annotations)
}
return executeGetAnnotation(ctx, h, toolName)
})
}
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
jsonBytes, err := json.Marshal(annotations)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
}
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "set",
})
}
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
var rows []map[string]interface{}
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
}
var annotations interface{}
if len(rows) > 0 {
// The procedure returns a single value; extract the first column of the first row.
for _, v := range rows[0] {
annotations = v
break
}
}
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
switch v := annotations.(type) {
case []byte:
var decoded interface{}
if json.Unmarshal(v, &decoded) == nil {
annotations = decoded
}
case string:
var decoded interface{}
if json.Unmarshal([]byte(v), &decoded) == nil {
annotations = decoded
}
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "get",
"annotations": annotations,
})
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
@@ -30,7 +31,7 @@ type Handler struct {
// NewHandler creates a Handler with the given database, model registry, and config. // NewHandler creates a Handler with the given database, model registry, and config.
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler { func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
return &Handler{ h := &Handler{
db: db, db: db,
registry: registry, registry: registry,
hooks: NewHookRegistry(), hooks: NewHookRegistry(),
@@ -39,6 +40,8 @@ func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *
name: "resolvemcp", name: "resolvemcp",
version: "1.0.0", version: "1.0.0",
} }
registerAnnotationTool(h)
return h
} }
// Hooks returns the hook registry. // Hooks returns the hook registry.
@@ -66,12 +69,20 @@ func (h *Handler) SSEServer() http.Handler {
return &dynamicSSEHandler{h: h} return &dynamicSSEHandler{h: h}
} }
// StreamableHTTPServer returns an http.Handler that serves MCP over the streamable HTTP transport.
// Unlike SSE (which requires two endpoints), streamable HTTP uses a single endpoint for all
// client-server communication (POST for requests, GET for server-initiated messages).
// Mount the returned handler at the desired path; the path itself becomes the MCP endpoint.
func (h *Handler) StreamableHTTPServer() http.Handler {
return server.NewStreamableHTTPServer(h.mcpServer)
}
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values. // newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer { func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer( return server.NewSSEServer(
h.mcpServer, h.mcpServer,
server.WithBaseURL(baseURL), server.WithBaseURL(baseURL),
server.WithBasePath(basePath), server.WithStaticBasePath(basePath),
) )
} }
@@ -123,6 +134,32 @@ func (h *Handler) RegisterModel(schema, entity string, model interface{}) error
return nil return nil
} }
// RegisterModelWithRules registers a model and sets per-entity operation rules
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
fullName := buildModelName(schema, entity)
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// SetModelRules updates the operation rules for an already-registered model.
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
return reg.SetModelRules(buildModelName(schema, entity), rules)
}
// buildModelName builds the registry key for a model (same format as resolvespec). // buildModelName builds the registry key for a model (same format as resolvespec).
func buildModelName(schema, entity string) string { func buildModelName(schema, entity string) string {
if schema == "" { if schema == "" {
@@ -160,8 +197,19 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return defaultSchema, entity return defaultSchema, entity
} }
// recoverPanic catches a panic from the current goroutine and returns it as an error.
// Usage: defer recoverPanic(&returnedErr)
func recoverPanic(err *error) {
if r := recover(); r != nil {
msg := fmt.Sprintf("%v", r)
logger.Error("[resolvemcp] panic recovered: %s", msg)
*err = fmt.Errorf("internal error: %s", msg)
}
}
// executeRead reads records from the database and returns raw data + metadata. // executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) { func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err) return nil, nil, fmt.Errorf("model not found: %w", err)
@@ -217,15 +265,6 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name)) query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
} }
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters // Filters
query = h.applyFilters(query, options.Filters) query = h.applyFilters(query, options.Filters)
@@ -267,7 +306,7 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
} }
// Count // Count — must happen before preloads are applied; Bun panics when counting with relations.
total, err := query.Count(ctx) total, err := query.Count(ctx)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err) return nil, nil, fmt.Errorf("error counting records: %w", err)
@@ -281,6 +320,15 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.Offset(*options.Offset) query = query.Offset(*options.Offset)
} }
// Preloads — applied after count to avoid Bun panic when counting with relations.
if len(options.Preload) > 0 {
var preloadErr error
query, preloadErr = h.applyPreloads(model, query, options.Preload)
if preloadErr != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
}
}
// BeforeRead hook // BeforeRead hook
hookCtx.Query = query hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil { if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
@@ -341,7 +389,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
// executeCreate inserts one or more records. // executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) { func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -425,7 +474,8 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
} }
// executeUpdate updates a record by ID. // executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) { func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -535,7 +585,8 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
} }
// executeDelete deletes a record by ID. // executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) { func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
if id == "" { if id == "" {
return nil, fmt.Errorf("delete requires an ID") return nil, fmt.Errorf("delete requires an ID")
} }

View File

@@ -98,3 +98,36 @@ func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
func NewSSEServer(handler *Handler) http.Handler { func NewSSEServer(handler *Handler) http.Handler {
return handler.SSEServer() return handler.SSEServer()
} }
// SetupMuxStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on the given Gorilla Mux router.
// The streamable HTTP transport uses a single endpoint (Config.BasePath) for all communication:
// POST for client→server messages, GET for server→client streaming.
//
// Example:
//
// resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler) // mounts at Config.BasePath
func SetupMuxStreamableHTTPRoutes(muxRouter *mux.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.StreamableHTTPServer()
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupBunRouterStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on a bunrouter router.
// The streamable HTTP transport uses a single endpoint (Config.BasePath).
func SetupBunRouterStreamableHTTPRoutes(router *bunrouter.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.StreamableHTTPServer()
router.GET(basePath, bunrouter.HTTPHandler(h))
router.POST(basePath, bunrouter.HTTPHandler(h))
router.DELETE(basePath, bunrouter.HTTPHandler(h))
}
// NewStreamableHTTPHandler returns an http.Handler that serves MCP over the streamable HTTP transport.
// Mount it at the desired path; that path becomes the MCP endpoint.
//
// h := resolvemcp.NewStreamableHTTPHandler(handler)
// http.Handle("/mcp", h)
// engine.Any("/mcp", gin.WrapH(h))
func NewStreamableHTTPHandler(handler *Handler) http.Handler {
return handler.StreamableHTTPServer()
}

View File

@@ -0,0 +1,115 @@
package resolvemcp
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks wires the security package's access-control layer into the
// resolvemcp handler. Call it once after creating the handler, before registering models.
//
// The following controls are applied:
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
// stored via RegisterModelWithRules / SetModelRules.
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
// - Column-level security: sensitive columns masked/hidden in read results.
// - Audit logging after each read.
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// BeforeHandle: enforce model-level operation rules (auth check).
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
})
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (2nd): audit log.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.LogDataAccess(newSecurityContext(hookCtx))
})
// BeforeUpdate: enforce CanUpdate rule.
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
})
// BeforeDelete: enforce CanDelete rule.
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
})
logger.Info("Security hooks registered for resolvemcp handler")
}
// --------------------------------------------------------------------------
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
// --------------------------------------------------------------------------
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
if q, ok := query.(common.SelectQuery); ok {
s.ctx.Query = q
}
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

View File

@@ -3,6 +3,7 @@ package server_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"time" "time"
@@ -29,18 +30,18 @@ func ExampleManager_basic() {
GZIP: true, // Enable GZIP compression GZIP: true, // Enable GZIP compression
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Server is now running... // Server is now running...
// When done, stop gracefully // When done, stop gracefully
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -61,7 +62,7 @@ func ExampleManager_https() {
SSLKey: "/path/to/key.pem", SSLKey: "/path/to/key.pem",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 2: Self-signed certificate (for development) // Option 2: Self-signed certificate (for development)
@@ -73,27 +74,27 @@ func ExampleManager_https() {
SelfSignedSSL: true, SelfSignedSSL: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 3: Let's Encrypt / AutoTLS (for production) // Option 3: Let's Encrypt / AutoTLS (for production)
_, err = mgr.Add(server.Config{ _, err = mgr.Add(server.Config{
Name: "https-server-letsencrypt", Name: "https-server-letsencrypt",
Host: "0.0.0.0", Host: "0.0.0.0",
Port: 443, Port: 443,
Handler: handler, Handler: handler,
AutoTLS: true, AutoTLS: true,
AutoTLSDomains: []string{"example.com", "www.example.com"}, AutoTLSDomains: []string{"example.com", "www.example.com"},
AutoTLSEmail: "admin@example.com", AutoTLSEmail: "admin@example.com",
AutoTLSCacheDir: "./certs-cache", AutoTLSCacheDir: "./certs-cache",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Cleanup // Cleanup
@@ -136,7 +137,7 @@ func ExampleManager_gracefulShutdown() {
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start servers and block until shutdown signal (SIGINT/SIGTERM) // Start servers and block until shutdown signal (SIGINT/SIGTERM)
@@ -164,7 +165,7 @@ func ExampleManager_healthChecks() {
Handler: mux, Handler: mux,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Add health and readiness endpoints // Add health and readiness endpoints
@@ -173,7 +174,7 @@ func ExampleManager_healthChecks() {
// Start the server // Start the server
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Health check returns: // Health check returns:
@@ -204,7 +205,7 @@ func ExampleManager_multipleServers() {
GZIP: true, GZIP: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Admin API server (different port) // Admin API server (different port)
@@ -218,7 +219,7 @@ func ExampleManager_multipleServers() {
Handler: adminHandler, Handler: adminHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Metrics server (internal only) // Metrics server (internal only)
@@ -232,18 +233,18 @@ func ExampleManager_multipleServers() {
Handler: metricsHandler, Handler: metricsHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers at once // Start all servers at once
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Get specific server instance // Get specific server instance
publicInstance, err := mgr.Get("public-api") publicInstance, err := mgr.Get("public-api")
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
fmt.Printf("Public API running on: %s\n", publicInstance.Addr()) fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
@@ -253,7 +254,7 @@ func ExampleManager_multipleServers() {
// Stop all servers gracefully (in parallel) // Stop all servers gracefully (in parallel)
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -273,11 +274,11 @@ func ExampleManager_monitoring() {
Handler: handler, Handler: handler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Check server status // Check server status