diff --git a/pkg/server/manager.go b/pkg/server/manager.go index 02f0c3a..3fe739b 100644 --- a/pkg/server/manager.go +++ b/pkg/server/manager.go @@ -371,6 +371,7 @@ type serverInstance struct { gracefulServer *gracefulServer certFile string // Path to certificate file (may be temporary for self-signed) keyFile string // Path to key file (may be temporary for self-signed) + tempCertDir string // Path to temporary certificate directory (for cleanup) mu sync.RWMutex running bool serverErr chan error @@ -415,7 +416,7 @@ func newInstance(cfg Config) (*serverInstance, error) { handler = middleware.PanicRecovery(handler) // Configure TLS if any TLS option is enabled - tlsConfig, certFile, keyFile, err := configureTLS(cfg) + tlsConfig, certFile, keyFile, tempCertDir, err := configureTLS(cfg) if err != nil { return nil, fmt.Errorf("failed to configure TLS: %w", err) } @@ -440,6 +441,7 @@ func newInstance(cfg Config) (*serverInstance, error) { gracefulServer: gracefulSrv, certFile: certFile, keyFile: keyFile, + tempCertDir: tempCertDir, serverErr: make(chan error, 1), }, nil } @@ -533,6 +535,20 @@ func (s *serverInstance) Stop(ctx context.Context) error { if err == nil { s.running = false } + + // Clean up temporary certificate directory if it exists + if s.tempCertDir != "" { + if cleanupErr := os.RemoveAll(s.tempCertDir); cleanupErr != nil { + logger.Error("Failed to clean up temporary certificate directory '%s': %v", s.tempCertDir, cleanupErr) + // Don't override the shutdown error with cleanup error + if err == nil { + err = fmt.Errorf("failed to clean up temporary certificates: %w", cleanupErr) + } + } else { + logger.Info("Cleaned up temporary certificate directory for server '%s'", s.cfg.Name) + } + } + return err } diff --git a/pkg/server/manager_test.go b/pkg/server/manager_test.go index 8ad123b..fe8d7c1 100644 --- a/pkg/server/manager_test.go +++ b/pkg/server/manager_test.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/http" + "os" "sync" "testing" "time" @@ -326,3 +327,42 @@ func TestShutdownCallbacks(t *testing.T) { assert.True(t, executed, "Shutdown callback should have been executed") } + +func TestSelfSignedSSLCleanup(t *testing.T) { + logger.Init(true) + sm := NewManager() + + testPort := getFreePort(t) + instance, err := sm.Add(Config{ + Name: "SSLTestServer", + Host: "localhost", + Port: testPort, + Handler: http.NewServeMux(), + SelfSignedSSL: true, + ShutdownTimeout: 5 * time.Second, + }) + require.NoError(t, err) + + // Get the serverInstance to access the tempCertDir + si, ok := instance.(*serverInstance) + require.True(t, ok, "instance should be of type *serverInstance") + require.NotEmpty(t, si.tempCertDir, "temporary certificate directory should be set") + + // Verify the temp directory exists + _, err = os.Stat(si.tempCertDir) + require.NoError(t, err, "temporary certificate directory should exist") + + // Start the server + err = sm.StartAll() + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Stop the server + err = sm.StopAll() + require.NoError(t, err) + + // Verify the temp directory has been cleaned up + _, err = os.Stat(si.tempCertDir) + assert.True(t, os.IsNotExist(err), "temporary certificate directory should be cleaned up after shutdown") +} diff --git a/pkg/server/tls.go b/pkg/server/tls.go index 5961291..30b35ae 100644 --- a/pkg/server/tls.go +++ b/pkg/server/tls.go @@ -76,12 +76,12 @@ func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) { } // saveCertToTempFiles saves certificate and key PEM data to temporary files. -// Returns the file paths for the certificate and key. -func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile string, err error) { +// Returns the file paths for the certificate and key, and the temporary directory path. +func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile, tmpDir string, err error) { // Create temporary directory - tmpDir, err := os.MkdirTemp("", "resolvespec-certs-*") + tmpDir, err = os.MkdirTemp("", "resolvespec-certs-*") if err != nil { - return "", "", fmt.Errorf("failed to create temp directory: %w", err) + return "", "", "", fmt.Errorf("failed to create temp directory: %w", err) } certFile = filepath.Join(tmpDir, "cert.pem") @@ -90,16 +90,16 @@ func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile string, err // Write certificate if err := os.WriteFile(certFile, certPEM, 0600); err != nil { os.RemoveAll(tmpDir) - return "", "", fmt.Errorf("failed to write certificate: %w", err) + return "", "", "", fmt.Errorf("failed to write certificate: %w", err) } // Write key if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { os.RemoveAll(tmpDir) - return "", "", fmt.Errorf("failed to write private key: %w", err) + return "", "", "", fmt.Errorf("failed to write private key: %w", err) } - return certFile, keyFile, nil + return certFile, keyFile, tmpDir, nil } // setupAutoTLS configures automatic TLS certificate management using Let's Encrypt. @@ -135,32 +135,32 @@ func setupAutoTLS(domains []string, email, cacheDir string) (*tls.Config, error) } // configureTLS configures TLS for the server based on the provided configuration. -// Returns the TLS config and certificate/key file paths (if applicable). -func configureTLS(cfg Config) (*tls.Config, string, string, error) { +// Returns the TLS config, certificate/key file paths (if applicable), and temp directory path (if applicable). +func configureTLS(cfg Config) (*tls.Config, string, string, string, error) { // Option 1: Certificate files provided if cfg.SSLCert != "" && cfg.SSLKey != "" { // Validate that files exist if _, err := os.Stat(cfg.SSLCert); os.IsNotExist(err) { - return nil, "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert) + return nil, "", "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert) } if _, err := os.Stat(cfg.SSLKey); os.IsNotExist(err) { - return nil, "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey) + return nil, "", "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey) } // Return basic TLS config - cert/key will be loaded by ListenAndServeTLS tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, } - return tlsConfig, cfg.SSLCert, cfg.SSLKey, nil + return tlsConfig, cfg.SSLCert, cfg.SSLKey, "", nil } // Option 2: Auto TLS (Let's Encrypt) if cfg.AutoTLS { tlsConfig, err := setupAutoTLS(cfg.AutoTLSDomains, cfg.AutoTLSEmail, cfg.AutoTLSCacheDir) if err != nil { - return nil, "", "", fmt.Errorf("failed to setup AutoTLS: %w", err) + return nil, "", "", "", fmt.Errorf("failed to setup AutoTLS: %w", err) } - return tlsConfig, "", "", nil + return tlsConfig, "", "", "", nil } // Option 3: Self-signed certificate @@ -172,19 +172,19 @@ func configureTLS(cfg Config) (*tls.Config, string, string, error) { certPEM, keyPEM, err := generateSelfSignedCert(host) if err != nil { - return nil, "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err) + return nil, "", "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err) } - certFile, keyFile, err := saveCertToTempFiles(certPEM, keyPEM) + certFile, keyFile, tmpDir, err := saveCertToTempFiles(certPEM, keyPEM) if err != nil { - return nil, "", "", fmt.Errorf("failed to save self-signed certificate: %w", err) + return nil, "", "", "", fmt.Errorf("failed to save self-signed certificate: %w", err) } tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, } - return tlsConfig, certFile, keyFile, nil + return tlsConfig, certFile, keyFile, tmpDir, nil } - return nil, "", "", nil + return nil, "", "", "", nil }