diff --git a/pkg/server/manager.go b/pkg/server/manager.go index 3fe739b..11a55c1 100644 --- a/pkg/server/manager.go +++ b/pkg/server/manager.go @@ -369,9 +369,8 @@ func (sm *serverManager) ServeWithGracefulShutdown() error { type serverInstance struct { cfg Config gracefulServer *gracefulServer - certFile string // Path to certificate file (may be temporary for self-signed) - keyFile string // Path to key file (may be temporary for self-signed) - tempCertDir string // Path to temporary certificate directory (for cleanup) + certFile string // Path to certificate file (may be persistent for self-signed) + keyFile string // Path to key file (may be persistent for self-signed) mu sync.RWMutex running bool serverErr chan error @@ -416,7 +415,7 @@ func newInstance(cfg Config) (*serverInstance, error) { handler = middleware.PanicRecovery(handler) // Configure TLS if any TLS option is enabled - tlsConfig, certFile, keyFile, tempCertDir, err := configureTLS(cfg) + tlsConfig, certFile, keyFile, err := configureTLS(cfg) if err != nil { return nil, fmt.Errorf("failed to configure TLS: %w", err) } @@ -441,7 +440,6 @@ func newInstance(cfg Config) (*serverInstance, error) { gracefulServer: gracefulSrv, certFile: certFile, keyFile: keyFile, - tempCertDir: tempCertDir, serverErr: make(chan error, 1), }, nil } @@ -535,20 +533,6 @@ func (s *serverInstance) Stop(ctx context.Context) error { if err == nil { s.running = false } - - // Clean up temporary certificate directory if it exists - if s.tempCertDir != "" { - if cleanupErr := os.RemoveAll(s.tempCertDir); cleanupErr != nil { - logger.Error("Failed to clean up temporary certificate directory '%s': %v", s.tempCertDir, cleanupErr) - // Don't override the shutdown error with cleanup error - if err == nil { - err = fmt.Errorf("failed to clean up temporary certificates: %w", cleanupErr) - } - } else { - logger.Info("Cleaned up temporary certificate directory for server '%s'", s.cfg.Name) - } - } - return err } diff --git a/pkg/server/manager_test.go b/pkg/server/manager_test.go index fe8d7c1..b9785a9 100644 --- a/pkg/server/manager_test.go +++ b/pkg/server/manager_test.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "os" + "path/filepath" "sync" "testing" "time" @@ -328,41 +329,70 @@ func TestShutdownCallbacks(t *testing.T) { assert.True(t, executed, "Shutdown callback should have been executed") } -func TestSelfSignedSSLCleanup(t *testing.T) { +func TestSelfSignedSSLCertificateReuse(t *testing.T) { logger.Init(true) - sm := NewManager() - - testPort := getFreePort(t) - instance, err := sm.Add(Config{ - Name: "SSLTestServer", - Host: "localhost", - Port: testPort, + + // Get cert directory to verify file creation + certDir, err := getCertDirectory() + require.NoError(t, err) + + host := "localhost" + certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host)) + keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host)) + + // Clean up any existing cert files from previous tests + os.Remove(certFile) + os.Remove(keyFile) + + // First server creation - should generate new certificates + sm1 := NewManager() + testPort1 := getFreePort(t) + _, err = sm1.Add(Config{ + Name: "SSLTestServer1", + Host: host, + Port: testPort1, Handler: http.NewServeMux(), SelfSignedSSL: true, ShutdownTimeout: 5 * time.Second, }) require.NoError(t, err) - - // Get the serverInstance to access the tempCertDir - si, ok := instance.(*serverInstance) - require.True(t, ok, "instance should be of type *serverInstance") - require.NotEmpty(t, si.tempCertDir, "temporary certificate directory should be set") - - // Verify the temp directory exists - _, err = os.Stat(si.tempCertDir) - require.NoError(t, err, "temporary certificate directory should exist") - - // Start the server - err = sm.StartAll() + + // Verify certificates were created + _, err = os.Stat(certFile) + require.NoError(t, err, "certificate file should exist after first creation") + _, err = os.Stat(keyFile) + require.NoError(t, err, "key file should exist after first creation") + + // Get modification time of cert file + info1, err := os.Stat(certFile) require.NoError(t, err) - + modTime1 := info1.ModTime() + + // Wait a bit to ensure different modification times time.Sleep(100 * time.Millisecond) - - // Stop the server - err = sm.StopAll() + + // Second server creation - should reuse existing certificates + sm2 := NewManager() + testPort2 := getFreePort(t) + _, err = sm2.Add(Config{ + Name: "SSLTestServer2", + Host: host, + Port: testPort2, + Handler: http.NewServeMux(), + SelfSignedSSL: true, + ShutdownTimeout: 5 * time.Second, + }) require.NoError(t, err) - - // Verify the temp directory has been cleaned up - _, err = os.Stat(si.tempCertDir) - assert.True(t, os.IsNotExist(err), "temporary certificate directory should be cleaned up after shutdown") + + // Get modification time of cert file after second creation + info2, err := os.Stat(certFile) + require.NoError(t, err) + modTime2 := info2.ModTime() + + // Verify the certificate was reused (same modification time) + assert.Equal(t, modTime1, modTime2, "certificate should be reused, not regenerated") + + // Clean up + sm1.StopAll() + sm2.StopAll() } diff --git a/pkg/server/tls.go b/pkg/server/tls.go index 30b35ae..fdadcc7 100644 --- a/pkg/server/tls.go +++ b/pkg/server/tls.go @@ -75,31 +75,85 @@ func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) { return certPEM, keyPEM, nil } -// saveCertToTempFiles saves certificate and key PEM data to temporary files. -// Returns the file paths for the certificate and key, and the temporary directory path. -func saveCertToTempFiles(certPEM, keyPEM []byte) (certFile, keyFile, tmpDir string, err error) { - // Create temporary directory - tmpDir, err = os.MkdirTemp("", "resolvespec-certs-*") +// getCertDirectory returns the directory path for storing self-signed certificates. +// Creates the directory if it doesn't exist. +func getCertDirectory() (string, error) { + // Use a consistent directory in the user's cache directory + cacheDir, err := os.UserCacheDir() if err != nil { - return "", "", "", fmt.Errorf("failed to create temp directory: %w", err) + // Fallback to current directory if cache dir is not available + cacheDir = "." } + + certDir := filepath.Join(cacheDir, "resolvespec", "certs") + + // Create directory if it doesn't exist + if err := os.MkdirAll(certDir, 0700); err != nil { + return "", fmt.Errorf("failed to create certificate directory: %w", err) + } + + return certDir, nil +} - certFile = filepath.Join(tmpDir, "cert.pem") - keyFile = filepath.Join(tmpDir, "key.pem") +// isCertificateValid checks if a certificate file exists and is not expired. +func isCertificateValid(certFile string) bool { + // Check if file exists + certData, err := os.ReadFile(certFile) + if err != nil { + return false + } + + // Parse certificate + block, _ := pem.Decode(certData) + if block == nil { + return false + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false + } + + // Check if certificate is expired or will expire in the next 30 days + now := time.Now() + expiryThreshold := now.Add(30 * 24 * time.Hour) + + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return false + } + + // Renew if expiring soon + if expiryThreshold.After(cert.NotAfter) { + return false + } + + return true +} +// saveCertToFiles saves certificate and key PEM data to persistent files. +// Returns the file paths for the certificate and key. +func saveCertToFiles(certPEM, keyPEM []byte, host string) (certFile, keyFile string, err error) { + // Get certificate directory + certDir, err := getCertDirectory() + if err != nil { + return "", "", err + } + + // Use consistent file names based on host + certFile = filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host)) + keyFile = filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host)) + // Write certificate if err := os.WriteFile(certFile, certPEM, 0600); err != nil { - os.RemoveAll(tmpDir) - return "", "", "", fmt.Errorf("failed to write certificate: %w", err) + return "", "", fmt.Errorf("failed to write certificate: %w", err) } - + // Write key if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { - os.RemoveAll(tmpDir) - return "", "", "", fmt.Errorf("failed to write private key: %w", err) + return "", "", fmt.Errorf("failed to write private key: %w", err) } - - return certFile, keyFile, tmpDir, nil + + return certFile, keyFile, nil } // setupAutoTLS configures automatic TLS certificate management using Let's Encrypt. @@ -135,32 +189,32 @@ func setupAutoTLS(domains []string, email, cacheDir string) (*tls.Config, error) } // configureTLS configures TLS for the server based on the provided configuration. -// Returns the TLS config, certificate/key file paths (if applicable), and temp directory path (if applicable). -func configureTLS(cfg Config) (*tls.Config, string, string, string, error) { +// Returns the TLS config and certificate/key file paths (if applicable). +func configureTLS(cfg Config) (*tls.Config, string, string, error) { // Option 1: Certificate files provided if cfg.SSLCert != "" && cfg.SSLKey != "" { // Validate that files exist if _, err := os.Stat(cfg.SSLCert); os.IsNotExist(err) { - return nil, "", "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert) + return nil, "", "", fmt.Errorf("SSL certificate file not found: %s", cfg.SSLCert) } if _, err := os.Stat(cfg.SSLKey); os.IsNotExist(err) { - return nil, "", "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey) + return nil, "", "", fmt.Errorf("SSL key file not found: %s", cfg.SSLKey) } // Return basic TLS config - cert/key will be loaded by ListenAndServeTLS tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, } - return tlsConfig, cfg.SSLCert, cfg.SSLKey, "", nil + return tlsConfig, cfg.SSLCert, cfg.SSLKey, nil } // Option 2: Auto TLS (Let's Encrypt) if cfg.AutoTLS { tlsConfig, err := setupAutoTLS(cfg.AutoTLSDomains, cfg.AutoTLSEmail, cfg.AutoTLSCacheDir) if err != nil { - return nil, "", "", "", fmt.Errorf("failed to setup AutoTLS: %w", err) + return nil, "", "", fmt.Errorf("failed to setup AutoTLS: %w", err) } - return tlsConfig, "", "", "", nil + return tlsConfig, "", "", nil } // Option 3: Self-signed certificate @@ -170,21 +224,43 @@ func configureTLS(cfg Config) (*tls.Config, string, string, string, error) { host = "localhost" } - certPEM, keyPEM, err := generateSelfSignedCert(host) + // Get certificate directory + certDir, err := getCertDirectory() if err != nil { - return nil, "", "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err) + return nil, "", "", fmt.Errorf("failed to get certificate directory: %w", err) } - certFile, keyFile, tmpDir, err := saveCertToTempFiles(certPEM, keyPEM) + // Check for existing valid certificates + certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host)) + keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host)) + + // If valid certificates exist, reuse them + if isCertificateValid(certFile) { + // Verify key file also exists + if _, err := os.Stat(keyFile); err == nil { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + return tlsConfig, certFile, keyFile, nil + } + } + + // Generate new certificates + certPEM, keyPEM, err := generateSelfSignedCert(host) if err != nil { - return nil, "", "", "", fmt.Errorf("failed to save self-signed certificate: %w", err) + return nil, "", "", fmt.Errorf("failed to generate self-signed certificate: %w", err) + } + + certFile, keyFile, err = saveCertToFiles(certPEM, keyPEM, host) + if err != nil { + return nil, "", "", fmt.Errorf("failed to save self-signed certificate: %w", err) } tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, } - return tlsConfig, certFile, keyFile, tmpDir, nil + return tlsConfig, certFile, keyFile, nil } - return nil, "", "", "", nil + return nil, "", "", nil }