diff --git a/pkg/server/manager_test.go b/pkg/server/manager_test.go index b9785a9..e2b35de 100644 --- a/pkg/server/manager_test.go +++ b/pkg/server/manager_test.go @@ -332,9 +332,10 @@ func TestShutdownCallbacks(t *testing.T) { func TestSelfSignedSSLCertificateReuse(t *testing.T) { logger.Init(true) - // Get cert directory to verify file creation - certDir, err := getCertDirectory() + // Get expected cert directory location + cacheDir, err := os.UserCacheDir() require.NoError(t, err) + certDir := filepath.Join(cacheDir, "resolvespec", "certs") host := "localhost" certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host)) diff --git a/pkg/server/tls.go b/pkg/server/tls.go index fdadcc7..22ec935 100644 --- a/pkg/server/tls.go +++ b/pkg/server/tls.go @@ -13,11 +13,15 @@ import ( "net" "os" "path/filepath" + "sync" "time" "golang.org/x/crypto/acme/autocert" ) +// certGenerationMutex protects concurrent certificate generation for the same host +var certGenerationMutex sync.Mutex + // generateSelfSignedCert generates a self-signed certificate for the given host. // Returns the certificate and private key in PEM format. func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) { @@ -75,6 +79,20 @@ func generateSelfSignedCert(host string) (certPEM, keyPEM []byte, err error) { return certPEM, keyPEM, nil } +// sanitizeHostname converts a hostname to a safe filename by replacing invalid characters. +func sanitizeHostname(host string) string { + // Replace any character that's not alphanumeric, dot, or dash with underscore + safe := "" + for _, r := range host { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '-' { + safe += string(r) + } else { + safe += "_" + } + } + return safe +} + // getCertDirectory returns the directory path for storing self-signed certificates. // Creates the directory if it doesn't exist. func getCertDirectory() (string, error) { @@ -139,9 +157,12 @@ func saveCertToFiles(certPEM, keyPEM []byte, host string) (certFile, keyFile str return "", "", err } + // Sanitize hostname for safe file naming + safeHost := sanitizeHostname(host) + // Use consistent file names based on host - certFile = filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host)) - keyFile = filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host)) + certFile = filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", safeHost)) + keyFile = filepath.Join(certDir, fmt.Sprintf("%s-key.pem", safeHost)) // Write certificate if err := os.WriteFile(certFile, certPEM, 0600); err != nil { @@ -224,6 +245,13 @@ func configureTLS(cfg Config) (*tls.Config, string, string, error) { host = "localhost" } + // Sanitize hostname for safe file naming + safeHost := sanitizeHostname(host) + + // Lock to prevent concurrent certificate generation for the same host + certGenerationMutex.Lock() + defer certGenerationMutex.Unlock() + // Get certificate directory certDir, err := getCertDirectory() if err != nil { @@ -231,8 +259,8 @@ func configureTLS(cfg Config) (*tls.Config, string, string, error) { } // Check for existing valid certificates - certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", host)) - keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", host)) + certFile := filepath.Join(certDir, fmt.Sprintf("%s-cert.pem", safeHost)) + keyFile := filepath.Join(certDir, fmt.Sprintf("%s-key.pem", safeHost)) // If valid certificates exist, reuse them if isCertificateValid(certFile) {