Compare commits

...

1 Commits

Author SHA1 Message Date
aa095d6bfd fix(tests): replace panic with log.Fatal for better error handling
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m52s
Build , Vet Test, and Lint / Build (push) Successful in -29m52s
Tests / Integration Tests (push) Failing after -30m46s
Tests / Unit Tests (push) Successful in -28m51s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m17s
Build , Vet Test, and Lint / Lint Code (push) Failing after -29m23s
2026-04-07 20:38:22 +02:00
3 changed files with 67 additions and 51 deletions

View File

@@ -3,6 +3,7 @@ package providers_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/dbmanager" "github.com/bitechdev/ResolveSpec/pkg/dbmanager"
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Subscribe to a channel with a handler // Subscribe to a channel with a handler
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
fmt.Printf("Received notification on %s: %s\n", channel, payload) fmt.Printf("Received notification on %s: %s\n", channel, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen: %v", err)) log.Fatalf("Failed to listen: %v", err)
} }
// Send a notification // Send a notification
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`) err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to notify: %v", err)) log.Fatalf("Failed to notify: %v", err)
} }
// Wait for notification to be processed // Wait for notification to be processed
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
// Unsubscribe from the channel // Unsubscribe from the channel
if err := listener.Unlisten("user_events"); err != nil { if err := listener.Unlisten("user_events"); err != nil {
panic(fmt.Sprintf("Failed to unlisten: %v", err)) log.Fatalf("Failed to unlisten: %v", err)
} }
} }
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// Listen to multiple channels // Listen to multiple channels
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
fmt.Printf("[%s] %s\n", ch, payload) fmt.Printf("[%s] %s\n", ch, payload)
}) })
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err)) log.Fatalf("Failed to listen on %s: %v", channel, err)
} }
} }
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
provider := providers.NewPostgresProvider() provider := providers.NewPostgresProvider()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(err) log.Fatal(err)
} }
defer provider.Close() defer provider.Close()
// Get listener // Get listener
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Subscribe to application events // Subscribe to application events
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
ctx := context.Background() ctx := context.Background()
if err := provider.Connect(ctx, cfg); err != nil { if err := provider.Connect(ctx, cfg); err != nil {
panic(fmt.Sprintf("Failed to connect: %v", err)) log.Fatalf("Failed to connect: %v", err)
} }
defer provider.Close() defer provider.Close()
listener, err := provider.GetListener(ctx) listener, err := provider.GetListener(ctx)
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to get listener: %v", err)) log.Fatalf("Failed to get listener: %v", err)
} }
// The listener automatically reconnects if the connection is lost // The listener automatically reconnects if the connection is lost

View File

@@ -197,8 +197,19 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
return defaultSchema, entity return defaultSchema, entity
} }
// recoverPanic catches a panic from the current goroutine and returns it as an error.
// Usage: defer recoverPanic(&returnedErr)
func recoverPanic(err *error) {
if r := recover(); r != nil {
msg := fmt.Sprintf("%v", r)
logger.Error("[resolvemcp] panic recovered: %s", msg)
*err = fmt.Errorf("internal error: %s", msg)
}
}
// executeRead reads records from the database and returns raw data + metadata. // executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) { func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err) return nil, nil, fmt.Errorf("model not found: %w", err)
@@ -254,15 +265,6 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name)) query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
} }
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters // Filters
query = h.applyFilters(query, options.Filters) query = h.applyFilters(query, options.Filters)
@@ -304,7 +306,7 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
} }
// Count // Count — must happen before preloads are applied; Bun panics when counting with relations.
total, err := query.Count(ctx) total, err := query.Count(ctx)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err) return nil, nil, fmt.Errorf("error counting records: %w", err)
@@ -318,6 +320,15 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
query = query.Offset(*options.Offset) query = query.Offset(*options.Offset)
} }
// Preloads — applied after count to avoid Bun panic when counting with relations.
if len(options.Preload) > 0 {
var preloadErr error
query, preloadErr = h.applyPreloads(model, query, options.Preload)
if preloadErr != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
}
}
// BeforeRead hook // BeforeRead hook
hookCtx.Query = query hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil { if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
@@ -378,7 +389,8 @@ func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, op
} }
// executeCreate inserts one or more records. // executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) { func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -462,7 +474,8 @@ func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data
} }
// executeUpdate updates a record by ID. // executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) { func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
return nil, fmt.Errorf("model not found: %w", err) return nil, fmt.Errorf("model not found: %w", err)
@@ -572,7 +585,8 @@ func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string,
} }
// executeDelete deletes a record by ID. // executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) { func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
defer recoverPanic(&retErr)
if id == "" { if id == "" {
return nil, fmt.Errorf("delete requires an ID") return nil, fmt.Errorf("delete requires an ID")
} }

View File

@@ -3,6 +3,7 @@ package server_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"time" "time"
@@ -29,18 +30,18 @@ func ExampleManager_basic() {
GZIP: true, // Enable GZIP compression GZIP: true, // Enable GZIP compression
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Server is now running... // Server is now running...
// When done, stop gracefully // When done, stop gracefully
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -61,7 +62,7 @@ func ExampleManager_https() {
SSLKey: "/path/to/key.pem", SSLKey: "/path/to/key.pem",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 2: Self-signed certificate (for development) // Option 2: Self-signed certificate (for development)
@@ -73,7 +74,7 @@ func ExampleManager_https() {
SelfSignedSSL: true, SelfSignedSSL: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Option 3: Let's Encrypt / AutoTLS (for production) // Option 3: Let's Encrypt / AutoTLS (for production)
@@ -88,12 +89,12 @@ func ExampleManager_https() {
AutoTLSCacheDir: "./certs-cache", AutoTLSCacheDir: "./certs-cache",
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers // Start all servers
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Cleanup // Cleanup
@@ -136,7 +137,7 @@ func ExampleManager_gracefulShutdown() {
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start servers and block until shutdown signal (SIGINT/SIGTERM) // Start servers and block until shutdown signal (SIGINT/SIGTERM)
@@ -164,7 +165,7 @@ func ExampleManager_healthChecks() {
Handler: mux, Handler: mux,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Add health and readiness endpoints // Add health and readiness endpoints
@@ -173,7 +174,7 @@ func ExampleManager_healthChecks() {
// Start the server // Start the server
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Health check returns: // Health check returns:
@@ -204,7 +205,7 @@ func ExampleManager_multipleServers() {
GZIP: true, GZIP: true,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Admin API server (different port) // Admin API server (different port)
@@ -218,7 +219,7 @@ func ExampleManager_multipleServers() {
Handler: adminHandler, Handler: adminHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Metrics server (internal only) // Metrics server (internal only)
@@ -232,18 +233,18 @@ func ExampleManager_multipleServers() {
Handler: metricsHandler, Handler: metricsHandler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
// Start all servers at once // Start all servers at once
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Get specific server instance // Get specific server instance
publicInstance, err := mgr.Get("public-api") publicInstance, err := mgr.Get("public-api")
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
fmt.Printf("Public API running on: %s\n", publicInstance.Addr()) fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
@@ -253,7 +254,7 @@ func ExampleManager_multipleServers() {
// Stop all servers gracefully (in parallel) // Stop all servers gracefully (in parallel)
if err := mgr.StopAll(); err != nil { if err := mgr.StopAll(); err != nil {
panic(err) log.Fatal(err)
} }
} }
@@ -273,11 +274,11 @@ func ExampleManager_monitoring() {
Handler: handler, Handler: handler,
}) })
if err != nil { if err != nil {
panic(err) log.Fatal(err)
} }
if err := mgr.StartAll(); err != nil { if err := mgr.StartAll(); err != nil {
panic(err) log.Fatal(err)
} }
// Check server status // Check server status