From 250fcf686c2b0df4f1d2da394c7d234440bcab45 Mon Sep 17 00:00:00 2001 From: Hein Date: Sat, 3 Jan 2026 01:48:42 +0200 Subject: [PATCH] feat(config): add multiple server instances support - Add ServersConfig and ServerInstanceConfig structs - Support configuring multiple named server instances - Add global timeout defaults with per-instance overrides - Add TLS configuration options (SSL cert/key, self-signed, AutoTLS) - Add validation for server configurations - Add helper methods for applying defaults and getting default server - Add conversion helper to avoid import cycles --- cmd/testserver/main.go | 50 ++-- pkg/config/config.go | 79 +++++- pkg/config/manager.go | 36 ++- pkg/config/manager_test.go | 466 +++++++++++++++++++++++++++++++++++- pkg/config/paths.go | 117 +++++++++ pkg/config/server.go | 106 ++++++++ pkg/server/config_helper.go | 47 ++++ pkg/server/manager.go | 4 +- 8 files changed, 837 insertions(+), 68 deletions(-) create mode 100644 pkg/config/paths.go create mode 100644 pkg/config/server.go create mode 100644 pkg/server/config_helper.go diff --git a/cmd/testserver/main.go b/cmd/testserver/main.go index 9eeca2e..b9fea07 100644 --- a/cmd/testserver/main.go +++ b/cmd/testserver/main.go @@ -39,7 +39,6 @@ func main() { logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev) } logger.Info("ResolveSpec test server starting") - logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr) // Initialize database manager ctx := context.Background() @@ -73,47 +72,30 @@ func main() { // Create server manager mgr := server.NewManager() - // Parse host and port from addr - host := "" - port := 8080 - if cfg.Server.Addr != "" { - // Parse addr (format: ":8080" or "localhost:8080") - if cfg.Server.Addr[0] == ':' { - // Just port - _, err := fmt.Sscanf(cfg.Server.Addr, ":%d", &port) - if err != nil { - logger.Error("Invalid server address: %s", cfg.Server.Addr) - os.Exit(1) - } - } else { - // Host and port - _, err := fmt.Sscanf(cfg.Server.Addr, "%[^:]:%d", &host, &port) - if err != nil { - logger.Error("Invalid server address: %s", cfg.Server.Addr) - os.Exit(1) - } - } + // Get default server configuration + defaultServerCfg, err := cfg.Servers.GetDefault() + if err != nil { + logger.Error("Failed to get default server config: %v", err) + os.Exit(1) } - // Add server instance - _, err = mgr.Add(server.Config{ - Name: "api", - Host: host, - Port: port, - Handler: r, - ShutdownTimeout: cfg.Server.ShutdownTimeout, - DrainTimeout: cfg.Server.DrainTimeout, - ReadTimeout: cfg.Server.ReadTimeout, - WriteTimeout: cfg.Server.WriteTimeout, - IdleTimeout: cfg.Server.IdleTimeout, - }) + // Apply global defaults + defaultServerCfg.ApplyGlobalDefaults(cfg.Servers) + + // Convert to server.Config and add instance + serverCfg := server.FromConfigInstanceToServerConfig(defaultServerCfg, r) + + logger.Info("Configuration loaded - Server '%s' will listen on %s:%d", + serverCfg.Name, serverCfg.Host, serverCfg.Port) + + _, err = mgr.Add(serverCfg) if err != nil { logger.Error("Failed to add server: %v", err) os.Exit(1) } // Start server with graceful shutdown - logger.Info("Starting server on %s", cfg.Server.Addr) + logger.Info("Starting server '%s' on %s:%d", serverCfg.Name, serverCfg.Host, serverCfg.Port) if err := mgr.ServeWithGracefulShutdown(); err != nil { logger.Error("Server failed: %v", err) os.Exit(1) diff --git a/pkg/config/config.go b/pkg/config/config.go index 6798101..faa8387 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,20 +4,28 @@ import "time" // Config represents the complete application configuration type Config struct { - Server ServerConfig `mapstructure:"server"` - Tracing TracingConfig `mapstructure:"tracing"` - Cache CacheConfig `mapstructure:"cache"` - Logger LoggerConfig `mapstructure:"logger"` - ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"` - Middleware MiddlewareConfig `mapstructure:"middleware"` - CORS CORSConfig `mapstructure:"cors"` - EventBroker EventBrokerConfig `mapstructure:"event_broker"` - DBManager DBManagerConfig `mapstructure:"dbmanager"` + Servers ServersConfig `mapstructure:"servers"` + Tracing TracingConfig `mapstructure:"tracing"` + Cache CacheConfig `mapstructure:"cache"` + Logger LoggerConfig `mapstructure:"logger"` + ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"` + Middleware MiddlewareConfig `mapstructure:"middleware"` + CORS CORSConfig `mapstructure:"cors"` + EventBroker EventBrokerConfig `mapstructure:"event_broker"` + DBManager DBManagerConfig `mapstructure:"dbmanager"` + Paths PathsConfig `mapstructure:"paths"` + Extensions map[string]interface{} `mapstructure:"extensions"` } -// ServerConfig holds server-related configuration -type ServerConfig struct { - Addr string `mapstructure:"addr"` +// ServersConfig contains configuration for the server manager +type ServersConfig struct { + // DefaultServer is the name of the default server to use + DefaultServer string `mapstructure:"default_server"` + + // Instances is a map of server name to server configuration + Instances map[string]ServerInstanceConfig `mapstructure:"instances"` + + // Global timeout defaults (can be overridden per instance) ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"` DrainTimeout time.Duration `mapstructure:"drain_timeout"` ReadTimeout time.Duration `mapstructure:"read_timeout"` @@ -25,6 +33,48 @@ type ServerConfig struct { IdleTimeout time.Duration `mapstructure:"idle_timeout"` } +// ServerInstanceConfig defines configuration for a single server instance +type ServerInstanceConfig struct { + // Name is the unique name of this server instance + Name string `mapstructure:"name"` + + // Host is the host to bind to (e.g., "localhost", "0.0.0.0", "") + Host string `mapstructure:"host"` + + // Port is the port number to listen on + Port int `mapstructure:"port"` + + // Description is a human-readable description of this server + Description string `mapstructure:"description"` + + // GZIP enables GZIP compression middleware + GZIP bool `mapstructure:"gzip"` + + // TLS/HTTPS configuration options (mutually exclusive) + // Option 1: Provide certificate and key files directly + SSLCert string `mapstructure:"ssl_cert"` + SSLKey string `mapstructure:"ssl_key"` + + // Option 2: Use self-signed certificate (for development/testing) + SelfSignedSSL bool `mapstructure:"self_signed_ssl"` + + // Option 3: Use Let's Encrypt / AutoTLS + AutoTLS bool `mapstructure:"auto_tls"` + AutoTLSDomains []string `mapstructure:"auto_tls_domains"` + AutoTLSCacheDir string `mapstructure:"auto_tls_cache_dir"` + AutoTLSEmail string `mapstructure:"auto_tls_email"` + + // Timeout configurations (overrides global defaults) + ShutdownTimeout *time.Duration `mapstructure:"shutdown_timeout"` + DrainTimeout *time.Duration `mapstructure:"drain_timeout"` + ReadTimeout *time.Duration `mapstructure:"read_timeout"` + WriteTimeout *time.Duration `mapstructure:"write_timeout"` + IdleTimeout *time.Duration `mapstructure:"idle_timeout"` + + // Tags for organization and filtering + Tags map[string]string `mapstructure:"tags"` +} + // TracingConfig holds OpenTelemetry tracing configuration type TracingConfig struct { Enabled bool `mapstructure:"enabled"` @@ -136,3 +186,8 @@ type EventBrokerRetryPolicyConfig struct { MaxDelay time.Duration `mapstructure:"max_delay"` BackoffFactor float64 `mapstructure:"backoff_factor"` } + +// PathsConfig contains configuration for named file system paths +// This is a map of path name to file system path +// Example: "data_dir": "/var/lib/myapp/data" +type PathsConfig map[string]string diff --git a/pkg/config/manager.go b/pkg/config/manager.go index 91edd63..ec6351c 100644 --- a/pkg/config/manager.go +++ b/pkg/config/manager.go @@ -107,7 +107,7 @@ func (m *Manager) SetConfig(cfg *Config) error { } // Use viper's merge to apply the config - m.v.Set("server", cfg.Server) + m.v.Set("servers", cfg.Servers) m.v.Set("tracing", cfg.Tracing) m.v.Set("cache", cfg.Cache) m.v.Set("logger", cfg.Logger) @@ -116,6 +116,8 @@ func (m *Manager) SetConfig(cfg *Config) error { m.v.Set("cors", cfg.CORS) m.v.Set("event_broker", cfg.EventBroker) m.v.Set("dbmanager", cfg.DBManager) + m.v.Set("paths", cfg.Paths) + m.v.Set("extensions", cfg.Extensions) return nil } @@ -155,13 +157,22 @@ func (m *Manager) SaveConfig(path string) error { // setDefaults sets default configuration values func setDefaults(v *viper.Viper) { - // Server defaults - v.SetDefault("server.addr", ":8080") - v.SetDefault("server.shutdown_timeout", "30s") - v.SetDefault("server.drain_timeout", "25s") - v.SetDefault("server.read_timeout", "10s") - v.SetDefault("server.write_timeout", "10s") - v.SetDefault("server.idle_timeout", "120s") + // Server defaults - new structure + v.SetDefault("servers.default_server", "default") + + // Global server timeout defaults + v.SetDefault("servers.shutdown_timeout", "30s") + v.SetDefault("servers.drain_timeout", "25s") + v.SetDefault("servers.read_timeout", "10s") + v.SetDefault("servers.write_timeout", "10s") + v.SetDefault("servers.idle_timeout", "120s") + + // Default server instance + v.SetDefault("servers.instances.default.name", "default") + v.SetDefault("servers.instances.default.host", "") + v.SetDefault("servers.instances.default.port", 8080) + v.SetDefault("servers.instances.default.description", "Default HTTP server") + v.SetDefault("servers.instances.default.gzip", false) // Tracing defaults v.SetDefault("tracing.enabled", false) @@ -259,4 +270,13 @@ func setDefaults(v *viper.Viper) { v.SetDefault("event_broker.retry_policy.initial_delay", "1s") v.SetDefault("event_broker.retry_policy.max_delay", "30s") v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0) + + // Paths defaults (common directory paths) + v.SetDefault("paths.data_dir", "./data") + v.SetDefault("paths.config_dir", "./config") + v.SetDefault("paths.logs_dir", "./logs") + v.SetDefault("paths.temp_dir", "./tmp") + + // Extensions defaults (empty map) + v.SetDefault("extensions", map[string]interface{}{}) } diff --git a/pkg/config/manager_test.go b/pkg/config/manager_test.go index ec83378..314fd3c 100644 --- a/pkg/config/manager_test.go +++ b/pkg/config/manager_test.go @@ -34,8 +34,8 @@ func TestDefaultValues(t *testing.T) { got interface{} expected interface{} }{ - {"server.addr", cfg.Server.Addr, ":8080"}, - {"server.shutdown_timeout", cfg.Server.ShutdownTimeout, 30 * time.Second}, + {"servers.default_server", cfg.Servers.DefaultServer, "default"}, + {"servers.shutdown_timeout", cfg.Servers.ShutdownTimeout, 30 * time.Second}, {"tracing.enabled", cfg.Tracing.Enabled, false}, {"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"}, {"cache.provider", cfg.Cache.Provider, "memory"}, @@ -46,6 +46,18 @@ func TestDefaultValues(t *testing.T) { {"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200}, } + // Test default server instance + defaultServer, ok := cfg.Servers.Instances["default"] + if !ok { + t.Fatal("Default server instance not found") + } + if defaultServer.Port != 8080 { + t.Errorf("default server port: got %d, want 8080", defaultServer.Port) + } + if defaultServer.Name != "default" { + t.Errorf("default server name: got %s, want default", defaultServer.Name) + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.got != tt.expected { @@ -57,12 +69,12 @@ func TestDefaultValues(t *testing.T) { func TestEnvironmentVariableOverrides(t *testing.T) { // Set environment variables - os.Setenv("RESOLVESPEC_SERVER_ADDR", ":9090") + os.Setenv("RESOLVESPEC_SERVERS_INSTANCES_DEFAULT_PORT", "9090") os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true") os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis") os.Setenv("RESOLVESPEC_LOGGER_DEV", "true") defer func() { - os.Unsetenv("RESOLVESPEC_SERVER_ADDR") + os.Unsetenv("RESOLVESPEC_SERVERS_INSTANCES_DEFAULT_PORT") os.Unsetenv("RESOLVESPEC_TRACING_ENABLED") os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER") os.Unsetenv("RESOLVESPEC_LOGGER_DEV") @@ -84,7 +96,6 @@ func TestEnvironmentVariableOverrides(t *testing.T) { got interface{} expected interface{} }{ - {"server.addr", cfg.Server.Addr, ":9090"}, {"tracing.enabled", cfg.Tracing.Enabled, true}, {"cache.provider", cfg.Cache.Provider, "redis"}, {"logger.dev", cfg.Logger.Dev, true}, @@ -97,11 +108,17 @@ func TestEnvironmentVariableOverrides(t *testing.T) { } }) } + + // Test server port override + defaultServer := cfg.Servers.Instances["default"] + if defaultServer.Port != 9090 { + t.Errorf("server port: got %d, want 9090", defaultServer.Port) + } } func TestProgrammaticConfiguration(t *testing.T) { mgr := NewManager() - mgr.Set("server.addr", ":7070") + mgr.Set("servers.instances.default.port", 7070) mgr.Set("tracing.service_name", "test-service") cfg, err := mgr.GetConfig() @@ -109,8 +126,8 @@ func TestProgrammaticConfiguration(t *testing.T) { t.Fatalf("Failed to get config: %v", err) } - if cfg.Server.Addr != ":7070" { - t.Errorf("server.addr: got %s, want :7070", cfg.Server.Addr) + if cfg.Servers.Instances["default"].Port != 7070 { + t.Errorf("server port: got %d, want 7070", cfg.Servers.Instances["default"].Port) } if cfg.Tracing.ServiceName != "test-service" { @@ -148,8 +165,8 @@ func TestWithOptions(t *testing.T) { } // Set environment variable with custom prefix - os.Setenv("MYAPP_SERVER_ADDR", ":5000") - defer os.Unsetenv("MYAPP_SERVER_ADDR") + os.Setenv("MYAPP_SERVERS_INSTANCES_DEFAULT_PORT", "5000") + defer os.Unsetenv("MYAPP_SERVERS_INSTANCES_DEFAULT_PORT") if err := mgr.Load(); err != nil { t.Fatalf("Failed to load config: %v", err) @@ -160,7 +177,432 @@ func TestWithOptions(t *testing.T) { t.Fatalf("Failed to get config: %v", err) } - if cfg.Server.Addr != ":5000" { - t.Errorf("server.addr: got %s, want :5000", cfg.Server.Addr) + if cfg.Servers.Instances["default"].Port != 5000 { + t.Errorf("server port: got %d, want 5000", cfg.Servers.Instances["default"].Port) + } +} + +func TestServersConfig(t *testing.T) { + mgr := NewManager() + if err := mgr.Load(); err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + cfg, err := mgr.GetConfig() + if err != nil { + t.Fatalf("Failed to get config: %v", err) + } + + // Test default server exists + if cfg.Servers.DefaultServer != "default" { + t.Errorf("Expected default_server to be 'default', got %s", cfg.Servers.DefaultServer) + } + + // Test default instance + defaultServer, ok := cfg.Servers.Instances["default"] + if !ok { + t.Fatal("Default server instance not found") + } + + if defaultServer.Port != 8080 { + t.Errorf("Expected default port 8080, got %d", defaultServer.Port) + } + + if defaultServer.Name != "default" { + t.Errorf("Expected default name 'default', got %s", defaultServer.Name) + } + + if defaultServer.Description != "Default HTTP server" { + t.Errorf("Expected description 'Default HTTP server', got %s", defaultServer.Description) + } +} + +func TestMultipleServerInstances(t *testing.T) { + mgr := NewManager() + + // Add additional server instances (default instance exists from defaults) + mgr.Set("servers.default_server", "api") + mgr.Set("servers.instances.api.name", "api") + mgr.Set("servers.instances.api.host", "0.0.0.0") + mgr.Set("servers.instances.api.port", 8080) + mgr.Set("servers.instances.admin.name", "admin") + mgr.Set("servers.instances.admin.host", "localhost") + mgr.Set("servers.instances.admin.port", 8081) + + cfg, err := mgr.GetConfig() + if err != nil { + t.Fatalf("Failed to get config: %v", err) + } + + // Should have default + api + admin = 3 instances + if len(cfg.Servers.Instances) < 2 { + t.Errorf("Expected at least 2 server instances, got %d", len(cfg.Servers.Instances)) + } + + // Verify api instance + apiServer, ok := cfg.Servers.Instances["api"] + if !ok { + t.Fatal("API server instance not found") + } + if apiServer.Port != 8080 { + t.Errorf("Expected API port 8080, got %d", apiServer.Port) + } + + // Verify admin instance + adminServer, ok := cfg.Servers.Instances["admin"] + if !ok { + t.Fatal("Admin server instance not found") + } + if adminServer.Port != 8081 { + t.Errorf("Expected admin port 8081, got %d", adminServer.Port) + } + + // Validate default server + if err := cfg.Servers.Validate(); err != nil { + t.Errorf("Server config validation failed: %v", err) + } + + // Get default + defaultSrv, err := cfg.Servers.GetDefault() + if err != nil { + t.Fatalf("Failed to get default server: %v", err) + } + if defaultSrv.Name != "api" { + t.Errorf("Expected default server 'api', got '%s'", defaultSrv.Name) + } +} + +func TestExtensionsField(t *testing.T) { + mgr := NewManager() + + // Set custom extensions + mgr.Set("extensions.custom_feature.enabled", true) + mgr.Set("extensions.custom_feature.api_key", "test-key") + mgr.Set("extensions.another_extension.timeout", "5s") + + cfg, err := mgr.GetConfig() + if err != nil { + t.Fatalf("Failed to get config: %v", err) + } + + if cfg.Extensions == nil { + t.Fatal("Extensions should not be nil") + } + + // Verify extensions are accessible + customFeature := mgr.Get("extensions.custom_feature") + if customFeature == nil { + t.Error("custom_feature extension not found") + } + + // Verify via config manager methods + if !mgr.GetBool("extensions.custom_feature.enabled") { + t.Error("Expected custom_feature.enabled to be true") + } + + if mgr.GetString("extensions.custom_feature.api_key") != "test-key" { + t.Error("Expected api_key to be 'test-key'") + } +} + +func TestServerInstanceValidation(t *testing.T) { + tests := []struct { + name string + instance ServerInstanceConfig + expectErr bool + }{ + { + name: "valid basic config", + instance: ServerInstanceConfig{ + Name: "test", + Port: 8080, + }, + expectErr: false, + }, + { + name: "invalid port - too high", + instance: ServerInstanceConfig{ + Name: "test", + Port: 99999, + }, + expectErr: true, + }, + { + name: "invalid port - zero", + instance: ServerInstanceConfig{ + Name: "test", + Port: 0, + }, + expectErr: true, + }, + { + name: "empty name", + instance: ServerInstanceConfig{ + Name: "", + Port: 8080, + }, + expectErr: true, + }, + { + name: "conflicting TLS options", + instance: ServerInstanceConfig{ + Name: "test", + Port: 8080, + SelfSignedSSL: true, + AutoTLS: true, + }, + expectErr: true, + }, + { + name: "incomplete SSL cert config", + instance: ServerInstanceConfig{ + Name: "test", + Port: 8080, + SSLCert: "/path/to/cert.pem", + // Missing SSLKey + }, + expectErr: true, + }, + { + name: "AutoTLS without domains", + instance: ServerInstanceConfig{ + Name: "test", + Port: 8080, + AutoTLS: true, + // Missing AutoTLSDomains + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.instance.Validate() + if tt.expectErr && err == nil { + t.Error("Expected validation error, got nil") + } + if !tt.expectErr && err != nil { + t.Errorf("Expected no error, got: %v", err) + } + }) + } +} + +func TestApplyGlobalDefaults(t *testing.T) { + globals := ServersConfig{ + ShutdownTimeout: 30 * time.Second, + DrainTimeout: 25 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + } + + instance := ServerInstanceConfig{ + Name: "test", + Port: 8080, + } + + // Apply global defaults + instance.ApplyGlobalDefaults(globals) + + // Check that defaults were applied + if instance.ShutdownTimeout == nil || *instance.ShutdownTimeout != 30*time.Second { + t.Error("ShutdownTimeout not applied correctly") + } + if instance.DrainTimeout == nil || *instance.DrainTimeout != 25*time.Second { + t.Error("DrainTimeout not applied correctly") + } + if instance.ReadTimeout == nil || *instance.ReadTimeout != 10*time.Second { + t.Error("ReadTimeout not applied correctly") + } + if instance.WriteTimeout == nil || *instance.WriteTimeout != 10*time.Second { + t.Error("WriteTimeout not applied correctly") + } + if instance.IdleTimeout == nil || *instance.IdleTimeout != 120*time.Second { + t.Error("IdleTimeout not applied correctly") + } + + // Test that explicit overrides are not replaced + customTimeout := 60 * time.Second + instance2 := ServerInstanceConfig{ + Name: "test2", + Port: 8081, + ShutdownTimeout: &customTimeout, + } + + instance2.ApplyGlobalDefaults(globals) + + if instance2.ShutdownTimeout == nil || *instance2.ShutdownTimeout != 60*time.Second { + t.Error("Custom ShutdownTimeout was overridden") + } +} + +func TestPathsConfig(t *testing.T) { + mgr := NewManager() + if err := mgr.Load(); err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + cfg, err := mgr.GetConfig() + if err != nil { + t.Fatalf("Failed to get config: %v", err) + } + + // Test default paths exist + if !cfg.Paths.Has("data_dir") { + t.Error("Expected data_dir path to exist") + } + if !cfg.Paths.Has("config_dir") { + t.Error("Expected config_dir path to exist") + } + if !cfg.Paths.Has("logs_dir") { + t.Error("Expected logs_dir path to exist") + } + if !cfg.Paths.Has("temp_dir") { + t.Error("Expected temp_dir path to exist") + } + + // Test Get method + dataDir, err := cfg.Paths.Get("data_dir") + if err != nil { + t.Errorf("Failed to get data_dir: %v", err) + } + if dataDir != "./data" { + t.Errorf("Expected data_dir to be './data', got '%s'", dataDir) + } + + // Test GetOrDefault + existing := cfg.Paths.GetOrDefault("data_dir", "/default/path") + if existing != "./data" { + t.Errorf("Expected existing path, got '%s'", existing) + } + + nonExisting := cfg.Paths.GetOrDefault("nonexistent", "/default/path") + if nonExisting != "/default/path" { + t.Errorf("Expected default path, got '%s'", nonExisting) + } +} + +func TestPathsConfigMethods(t *testing.T) { + pc := PathsConfig{ + "base": "/var/myapp", + "data": "/var/myapp/data", + } + + // Test Get + path, err := pc.Get("base") + if err != nil { + t.Errorf("Failed to get path: %v", err) + } + if path != "/var/myapp" { + t.Errorf("Expected '/var/myapp', got '%s'", path) + } + + // Test Get non-existent + _, err = pc.Get("nonexistent") + if err == nil { + t.Error("Expected error for non-existent path") + } + + // Test Set + pc.Set("new_path", "/new/location") + newPath, err := pc.Get("new_path") + if err != nil { + t.Errorf("Failed to get newly set path: %v", err) + } + if newPath != "/new/location" { + t.Errorf("Expected '/new/location', got '%s'", newPath) + } + + // Test Has + if !pc.Has("base") { + t.Error("Expected 'base' path to exist") + } + if pc.Has("nonexistent") { + t.Error("Expected 'nonexistent' path to not exist") + } + + // Test List + names := pc.List() + if len(names) != 3 { + t.Errorf("Expected 3 paths, got %d", len(names)) + } + + // Test Join + joined, err := pc.Join("base", "subdir", "file.txt") + if err != nil { + t.Errorf("Failed to join paths: %v", err) + } + expected := "/var/myapp/subdir/file.txt" + if joined != expected { + t.Errorf("Expected '%s', got '%s'", expected, joined) + } +} + +func TestPathsConfigEnvironmentVariables(t *testing.T) { + // Set environment variables for paths + os.Setenv("RESOLVESPEC_PATHS_DATA_DIR", "/custom/data") + os.Setenv("RESOLVESPEC_PATHS_LOGS_DIR", "/custom/logs") + defer func() { + os.Unsetenv("RESOLVESPEC_PATHS_DATA_DIR") + os.Unsetenv("RESOLVESPEC_PATHS_LOGS_DIR") + }() + + mgr := NewManager() + if err := mgr.Load(); err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + cfg, err := mgr.GetConfig() + if err != nil { + t.Fatalf("Failed to get config: %v", err) + } + + // Test environment variable override of existing default path + dataDir, err := cfg.Paths.Get("data_dir") + if err != nil { + t.Errorf("Failed to get data_dir: %v", err) + } + if dataDir != "/custom/data" { + t.Errorf("Expected '/custom/data', got '%s'", dataDir) + } + + // Test another environment variable override + logsDir, err := cfg.Paths.Get("logs_dir") + if err != nil { + t.Errorf("Failed to get logs_dir: %v", err) + } + if logsDir != "/custom/logs" { + t.Errorf("Expected '/custom/logs', got '%s'", logsDir) + } +} + +func TestPathsConfigProgrammatic(t *testing.T) { + mgr := NewManager() + + // Set custom paths programmatically + mgr.Set("paths.custom_dir", "/my/custom/dir") + mgr.Set("paths.cache_dir", "/var/cache/myapp") + + cfg, err := mgr.GetConfig() + if err != nil { + t.Fatalf("Failed to get config: %v", err) + } + + // Verify custom paths + customDir, err := cfg.Paths.Get("custom_dir") + if err != nil { + t.Errorf("Failed to get custom_dir: %v", err) + } + if customDir != "/my/custom/dir" { + t.Errorf("Expected '/my/custom/dir', got '%s'", customDir) + } + + cacheDir, err := cfg.Paths.Get("cache_dir") + if err != nil { + t.Errorf("Failed to get cache_dir: %v", err) + } + if cacheDir != "/var/cache/myapp" { + t.Errorf("Expected '/var/cache/myapp', got '%s'", cacheDir) } } diff --git a/pkg/config/paths.go b/pkg/config/paths.go new file mode 100644 index 0000000..95d9ca0 --- /dev/null +++ b/pkg/config/paths.go @@ -0,0 +1,117 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +// Get retrieves a path by name +func (pc PathsConfig) Get(name string) (string, error) { + if pc == nil { + return "", fmt.Errorf("paths not initialized") + } + + path, ok := pc[name] + if !ok { + return "", fmt.Errorf("path '%s' not found", name) + } + + return path, nil +} + +// GetOrDefault retrieves a path by name, returning defaultPath if not found +func (pc PathsConfig) GetOrDefault(name, defaultPath string) string { + if pc == nil { + return defaultPath + } + + path, ok := pc[name] + if !ok { + return defaultPath + } + + return path +} + +// Set sets a path by name +func (pc PathsConfig) Set(name, path string) { + pc[name] = path +} + +// Has checks if a path exists by name +func (pc PathsConfig) Has(name string) bool { + if pc == nil { + return false + } + _, ok := pc[name] + return ok +} + +// EnsureDir ensures a directory exists at the specified path name +// Creates the directory if it doesn't exist with the given permissions +func (pc PathsConfig) EnsureDir(name string, perm os.FileMode) error { + path, err := pc.Get(name) + if err != nil { + return err + } + + // Check if directory exists + info, err := os.Stat(path) + if err == nil { + // Path exists, check if it's a directory + if !info.IsDir() { + return fmt.Errorf("path '%s' exists but is not a directory: %s", name, path) + } + return nil + } + + // Directory doesn't exist, create it + if os.IsNotExist(err) { + if err := os.MkdirAll(path, perm); err != nil { + return fmt.Errorf("failed to create directory for '%s' at %s: %w", name, path, err) + } + return nil + } + + return fmt.Errorf("failed to stat path '%s' at %s: %w", name, path, err) +} + +// AbsPath returns the absolute path for a named path +func (pc PathsConfig) AbsPath(name string) (string, error) { + path, err := pc.Get(name) + if err != nil { + return "", err + } + + absPath, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("failed to get absolute path for '%s': %w", name, err) + } + + return absPath, nil +} + +// Join joins path segments with a named base path +func (pc PathsConfig) Join(name string, elem ...string) (string, error) { + base, err := pc.Get(name) + if err != nil { + return "", err + } + + parts := append([]string{base}, elem...) + return filepath.Join(parts...), nil +} + +// List returns all configured path names +func (pc PathsConfig) List() []string { + if pc == nil { + return []string{} + } + + names := make([]string, 0, len(pc)) + for name := range pc { + names = append(names, name) + } + return names +} diff --git a/pkg/config/server.go b/pkg/config/server.go new file mode 100644 index 0000000..3c20da9 --- /dev/null +++ b/pkg/config/server.go @@ -0,0 +1,106 @@ +package config + +import ( + "fmt" +) + +// ApplyGlobalDefaults applies global server defaults to this instance +// Called for instances that don't specify their own timeout values +func (sic *ServerInstanceConfig) ApplyGlobalDefaults(globals ServersConfig) { + if sic.ShutdownTimeout == nil && globals.ShutdownTimeout > 0 { + t := globals.ShutdownTimeout + sic.ShutdownTimeout = &t + } + if sic.DrainTimeout == nil && globals.DrainTimeout > 0 { + t := globals.DrainTimeout + sic.DrainTimeout = &t + } + if sic.ReadTimeout == nil && globals.ReadTimeout > 0 { + t := globals.ReadTimeout + sic.ReadTimeout = &t + } + if sic.WriteTimeout == nil && globals.WriteTimeout > 0 { + t := globals.WriteTimeout + sic.WriteTimeout = &t + } + if sic.IdleTimeout == nil && globals.IdleTimeout > 0 { + t := globals.IdleTimeout + sic.IdleTimeout = &t + } +} + +// Validate validates the ServerInstanceConfig +func (sic *ServerInstanceConfig) Validate() error { + if sic.Name == "" { + return fmt.Errorf("server instance name cannot be empty") + } + if sic.Port <= 0 || sic.Port > 65535 { + return fmt.Errorf("invalid port: %d (must be 1-65535)", sic.Port) + } + + // Validate TLS options are mutually exclusive + tlsCount := 0 + if sic.SSLCert != "" || sic.SSLKey != "" { + tlsCount++ + } + if sic.SelfSignedSSL { + tlsCount++ + } + if sic.AutoTLS { + tlsCount++ + } + if tlsCount > 1 { + return fmt.Errorf("server '%s': only one TLS option can be enabled", sic.Name) + } + + // If using certificate files, both must be provided + if (sic.SSLCert != "" && sic.SSLKey == "") || (sic.SSLCert == "" && sic.SSLKey != "") { + return fmt.Errorf("server '%s': both ssl_cert and ssl_key must be provided", sic.Name) + } + + // If using AutoTLS, domains must be specified + if sic.AutoTLS && len(sic.AutoTLSDomains) == 0 { + return fmt.Errorf("server '%s': auto_tls_domains must be specified when auto_tls is enabled", sic.Name) + } + + return nil +} + +// Validate validates the ServersConfig +func (sc *ServersConfig) Validate() error { + if len(sc.Instances) == 0 { + return fmt.Errorf("at least one server instance must be configured") + } + + if sc.DefaultServer != "" { + if _, ok := sc.Instances[sc.DefaultServer]; !ok { + return fmt.Errorf("default server '%s' not found in instances", sc.DefaultServer) + } + } + + // Validate each instance + for name, instance := range sc.Instances { + if instance.Name != name { + return fmt.Errorf("server instance name mismatch: key='%s', name='%s'", name, instance.Name) + } + if err := instance.Validate(); err != nil { + return err + } + } + + return nil +} + +// GetDefault returns the default server instance configuration +func (sc *ServersConfig) GetDefault() (*ServerInstanceConfig, error) { + if sc.DefaultServer == "" { + return nil, fmt.Errorf("no default server configured") + } + + instance, ok := sc.Instances[sc.DefaultServer] + if !ok { + return nil, fmt.Errorf("default server '%s' not found", sc.DefaultServer) + } + + return &instance, nil +} diff --git a/pkg/server/config_helper.go b/pkg/server/config_helper.go new file mode 100644 index 0000000..cf42688 --- /dev/null +++ b/pkg/server/config_helper.go @@ -0,0 +1,47 @@ +package server + +import ( + "net/http" + + "github.com/bitechdev/ResolveSpec/pkg/config" +) + +// FromConfigInstanceToServerConfig converts a config.ServerInstanceConfig to server.Config +// The handler must be provided separately as it cannot be serialized +func FromConfigInstanceToServerConfig(sic *config.ServerInstanceConfig, handler http.Handler) Config { + cfg := Config{ + Name: sic.Name, + Host: sic.Host, + Port: sic.Port, + Description: sic.Description, + Handler: handler, + GZIP: sic.GZIP, + + SSLCert: sic.SSLCert, + SSLKey: sic.SSLKey, + SelfSignedSSL: sic.SelfSignedSSL, + AutoTLS: sic.AutoTLS, + AutoTLSDomains: sic.AutoTLSDomains, + AutoTLSCacheDir: sic.AutoTLSCacheDir, + AutoTLSEmail: sic.AutoTLSEmail, + } + + // Apply timeouts (use pointers to override, or use zero values for defaults) + if sic.ShutdownTimeout != nil { + cfg.ShutdownTimeout = *sic.ShutdownTimeout + } + if sic.DrainTimeout != nil { + cfg.DrainTimeout = *sic.DrainTimeout + } + if sic.ReadTimeout != nil { + cfg.ReadTimeout = *sic.ReadTimeout + } + if sic.WriteTimeout != nil { + cfg.WriteTimeout = *sic.WriteTimeout + } + if sic.IdleTimeout != nil { + cfg.IdleTimeout = *sic.IdleTimeout + } + + return cfg +} diff --git a/pkg/server/manager.go b/pkg/server/manager.go index 451fbc3..f1b7877 100644 --- a/pkg/server/manager.go +++ b/pkg/server/manager.go @@ -501,7 +501,7 @@ func (s *serverInstance) Start() error { if useTLS { protocol = "HTTPS" - logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr()) + logger.Info("Starting %s server - Name: '%s', Address: %s, Port: %d", protocol, s.cfg.Name, s.cfg.Host, s.cfg.Port) // For AutoTLS, we need to use a TLS listener if s.cfg.AutoTLS { @@ -519,7 +519,7 @@ func (s *serverInstance) Start() error { err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile) } } else { - logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr()) + logger.Info("Starting %s server - Name: '%s', Address: %s, Port: %d", protocol, s.cfg.Name, s.cfg.Host, s.cfg.Port) err = s.gracefulServer.server.ListenAndServe() }