mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-03 10:24:26 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c864aa4d90 | |||
| 250fcf686c | |||
| 47cfc4b3da | |||
| 0e8ae75daf | |||
| ce092d1c62 | |||
| 871dd2e374 | |||
|
|
ebd03d10ad |
@@ -39,7 +39,6 @@ func main() {
|
||||
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
|
||||
}
|
||||
logger.Info("ResolveSpec test server starting")
|
||||
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
|
||||
|
||||
// Initialize database manager
|
||||
ctx := context.Background()
|
||||
@@ -73,47 +72,30 @@ func main() {
|
||||
// Create server manager
|
||||
mgr := server.NewManager()
|
||||
|
||||
// Parse host and port from addr
|
||||
host := ""
|
||||
port := 8080
|
||||
if cfg.Server.Addr != "" {
|
||||
// Parse addr (format: ":8080" or "localhost:8080")
|
||||
if cfg.Server.Addr[0] == ':' {
|
||||
// Just port
|
||||
_, err := fmt.Sscanf(cfg.Server.Addr, ":%d", &port)
|
||||
if err != nil {
|
||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
// Host and port
|
||||
_, err := fmt.Sscanf(cfg.Server.Addr, "%[^:]:%d", &host, &port)
|
||||
if err != nil {
|
||||
logger.Error("Invalid server address: %s", cfg.Server.Addr)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
// Get default server configuration
|
||||
defaultServerCfg, err := cfg.Servers.GetDefault()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get default server config: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Add server instance
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "api",
|
||||
Host: host,
|
||||
Port: port,
|
||||
Handler: r,
|
||||
ShutdownTimeout: cfg.Server.ShutdownTimeout,
|
||||
DrainTimeout: cfg.Server.DrainTimeout,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
})
|
||||
// Apply global defaults
|
||||
defaultServerCfg.ApplyGlobalDefaults(cfg.Servers)
|
||||
|
||||
// Convert to server.Config and add instance
|
||||
serverCfg := server.FromConfigInstanceToServerConfig(defaultServerCfg, r)
|
||||
|
||||
logger.Info("Configuration loaded - Server '%s' will listen on %s:%d",
|
||||
serverCfg.Name, serverCfg.Host, serverCfg.Port)
|
||||
|
||||
_, err = mgr.Add(serverCfg)
|
||||
if err != nil {
|
||||
logger.Error("Failed to add server: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Start server with graceful shutdown
|
||||
logger.Info("Starting server on %s", cfg.Server.Addr)
|
||||
logger.Info("Starting server '%s' on %s:%d", serverCfg.Name, serverCfg.Host, serverCfg.Port)
|
||||
if err := mgr.ServeWithGracefulShutdown(); err != nil {
|
||||
logger.Error("Server failed: %v", err)
|
||||
os.Exit(1)
|
||||
|
||||
@@ -4,20 +4,28 @@ import "time"
|
||||
|
||||
// Config represents the complete application configuration
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Tracing TracingConfig `mapstructure:"tracing"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
Logger LoggerConfig `mapstructure:"logger"`
|
||||
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
|
||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
||||
DBManager DBManagerConfig `mapstructure:"dbmanager"`
|
||||
Servers ServersConfig `mapstructure:"servers"`
|
||||
Tracing TracingConfig `mapstructure:"tracing"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
Logger LoggerConfig `mapstructure:"logger"`
|
||||
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
|
||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
||||
DBManager DBManagerConfig `mapstructure:"dbmanager"`
|
||||
Paths PathsConfig `mapstructure:"paths"`
|
||||
Extensions map[string]interface{} `mapstructure:"extensions"`
|
||||
}
|
||||
|
||||
// ServerConfig holds server-related configuration
|
||||
type ServerConfig struct {
|
||||
Addr string `mapstructure:"addr"`
|
||||
// ServersConfig contains configuration for the server manager
|
||||
type ServersConfig struct {
|
||||
// DefaultServer is the name of the default server to use
|
||||
DefaultServer string `mapstructure:"default_server"`
|
||||
|
||||
// Instances is a map of server name to server configuration
|
||||
Instances map[string]ServerInstanceConfig `mapstructure:"instances"`
|
||||
|
||||
// Global timeout defaults (can be overridden per instance)
|
||||
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
|
||||
DrainTimeout time.Duration `mapstructure:"drain_timeout"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||
@@ -25,6 +33,48 @@ type ServerConfig struct {
|
||||
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
|
||||
}
|
||||
|
||||
// ServerInstanceConfig defines configuration for a single server instance
|
||||
type ServerInstanceConfig struct {
|
||||
// Name is the unique name of this server instance
|
||||
Name string `mapstructure:"name"`
|
||||
|
||||
// Host is the host to bind to (e.g., "localhost", "0.0.0.0", "")
|
||||
Host string `mapstructure:"host"`
|
||||
|
||||
// Port is the port number to listen on
|
||||
Port int `mapstructure:"port"`
|
||||
|
||||
// Description is a human-readable description of this server
|
||||
Description string `mapstructure:"description"`
|
||||
|
||||
// GZIP enables GZIP compression middleware
|
||||
GZIP bool `mapstructure:"gzip"`
|
||||
|
||||
// TLS/HTTPS configuration options (mutually exclusive)
|
||||
// Option 1: Provide certificate and key files directly
|
||||
SSLCert string `mapstructure:"ssl_cert"`
|
||||
SSLKey string `mapstructure:"ssl_key"`
|
||||
|
||||
// Option 2: Use self-signed certificate (for development/testing)
|
||||
SelfSignedSSL bool `mapstructure:"self_signed_ssl"`
|
||||
|
||||
// Option 3: Use Let's Encrypt / AutoTLS
|
||||
AutoTLS bool `mapstructure:"auto_tls"`
|
||||
AutoTLSDomains []string `mapstructure:"auto_tls_domains"`
|
||||
AutoTLSCacheDir string `mapstructure:"auto_tls_cache_dir"`
|
||||
AutoTLSEmail string `mapstructure:"auto_tls_email"`
|
||||
|
||||
// Timeout configurations (overrides global defaults)
|
||||
ShutdownTimeout *time.Duration `mapstructure:"shutdown_timeout"`
|
||||
DrainTimeout *time.Duration `mapstructure:"drain_timeout"`
|
||||
ReadTimeout *time.Duration `mapstructure:"read_timeout"`
|
||||
WriteTimeout *time.Duration `mapstructure:"write_timeout"`
|
||||
IdleTimeout *time.Duration `mapstructure:"idle_timeout"`
|
||||
|
||||
// Tags for organization and filtering
|
||||
Tags map[string]string `mapstructure:"tags"`
|
||||
}
|
||||
|
||||
// TracingConfig holds OpenTelemetry tracing configuration
|
||||
type TracingConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
@@ -136,3 +186,8 @@ type EventBrokerRetryPolicyConfig struct {
|
||||
MaxDelay time.Duration `mapstructure:"max_delay"`
|
||||
BackoffFactor float64 `mapstructure:"backoff_factor"`
|
||||
}
|
||||
|
||||
// PathsConfig contains configuration for named file system paths
|
||||
// This is a map of path name to file system path
|
||||
// Example: "data_dir": "/var/lib/myapp/data"
|
||||
type PathsConfig map[string]string
|
||||
|
||||
@@ -2,6 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -91,6 +94,160 @@ func (c *DBManagerConfig) ToManagerConfig() interface{} {
|
||||
return c
|
||||
}
|
||||
|
||||
// PopulateFromDSN parses a DSN and populates the connection fields
|
||||
func (cc *DBConnectionConfig) PopulateFromDSN() error {
|
||||
if cc.DSN == "" {
|
||||
return nil // Nothing to populate
|
||||
}
|
||||
|
||||
switch cc.Type {
|
||||
case "postgres":
|
||||
return cc.populatePostgresDSN()
|
||||
case "mongodb":
|
||||
return cc.populateMongoDSN()
|
||||
case "mssql":
|
||||
return cc.populateMSSQLDSN()
|
||||
case "sqlite":
|
||||
return cc.populateSQLiteDSN()
|
||||
default:
|
||||
return fmt.Errorf("cannot parse DSN for unsupported database type: %s", cc.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// populatePostgresDSN parses PostgreSQL DSN format
|
||||
// Example: host=localhost port=5432 user=postgres password=secret dbname=mydb sslmode=disable
|
||||
func (cc *DBConnectionConfig) populatePostgresDSN() error {
|
||||
parts := strings.Fields(cc.DSN)
|
||||
for _, part := range parts {
|
||||
kv := strings.SplitN(part, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
key, value := kv[0], kv[1]
|
||||
|
||||
switch key {
|
||||
case "host":
|
||||
cc.Host = value
|
||||
case "port":
|
||||
port, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port in DSN: %w", err)
|
||||
}
|
||||
cc.Port = port
|
||||
case "user":
|
||||
cc.User = value
|
||||
case "password":
|
||||
cc.Password = value
|
||||
case "dbname":
|
||||
cc.Database = value
|
||||
case "sslmode":
|
||||
cc.SSLMode = value
|
||||
case "search_path":
|
||||
cc.Schema = value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateMongoDSN parses MongoDB DSN format
|
||||
// Example: mongodb://user:password@host:port/database?authSource=admin&replicaSet=rs0
|
||||
func (cc *DBConnectionConfig) populateMongoDSN() error {
|
||||
u, err := url.Parse(cc.DSN)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid MongoDB DSN: %w", err)
|
||||
}
|
||||
|
||||
// Extract user and password
|
||||
if u.User != nil {
|
||||
cc.User = u.User.Username()
|
||||
if password, ok := u.User.Password(); ok {
|
||||
cc.Password = password
|
||||
}
|
||||
}
|
||||
|
||||
// Extract host and port
|
||||
if u.Host != "" {
|
||||
host := u.Host
|
||||
if strings.Contains(host, ":") {
|
||||
hostPort := strings.SplitN(host, ":", 2)
|
||||
cc.Host = hostPort[0]
|
||||
if port, err := strconv.Atoi(hostPort[1]); err == nil {
|
||||
cc.Port = port
|
||||
}
|
||||
} else {
|
||||
cc.Host = host
|
||||
}
|
||||
}
|
||||
|
||||
// Extract database
|
||||
if u.Path != "" {
|
||||
cc.Database = strings.TrimPrefix(u.Path, "/")
|
||||
}
|
||||
|
||||
// Extract query parameters
|
||||
params := u.Query()
|
||||
if authSource := params.Get("authSource"); authSource != "" {
|
||||
cc.AuthSource = authSource
|
||||
}
|
||||
if replicaSet := params.Get("replicaSet"); replicaSet != "" {
|
||||
cc.ReplicaSet = replicaSet
|
||||
}
|
||||
if readPref := params.Get("readPreference"); readPref != "" {
|
||||
cc.ReadPreference = readPref
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateMSSQLDSN parses MSSQL DSN format
|
||||
// Example: sqlserver://username:password@host:port?database=dbname&schema=dbo
|
||||
func (cc *DBConnectionConfig) populateMSSQLDSN() error {
|
||||
u, err := url.Parse(cc.DSN)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid MSSQL DSN: %w", err)
|
||||
}
|
||||
|
||||
// Extract user and password
|
||||
if u.User != nil {
|
||||
cc.User = u.User.Username()
|
||||
if password, ok := u.User.Password(); ok {
|
||||
cc.Password = password
|
||||
}
|
||||
}
|
||||
|
||||
// Extract host and port
|
||||
if u.Host != "" {
|
||||
host := u.Host
|
||||
if strings.Contains(host, ":") {
|
||||
hostPort := strings.SplitN(host, ":", 2)
|
||||
cc.Host = hostPort[0]
|
||||
if port, err := strconv.Atoi(hostPort[1]); err == nil {
|
||||
cc.Port = port
|
||||
}
|
||||
} else {
|
||||
cc.Host = host
|
||||
}
|
||||
}
|
||||
|
||||
// Extract query parameters
|
||||
params := u.Query()
|
||||
if database := params.Get("database"); database != "" {
|
||||
cc.Database = database
|
||||
}
|
||||
if schema := params.Get("schema"); schema != "" {
|
||||
cc.Schema = schema
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateSQLiteDSN parses SQLite DSN format
|
||||
// Example: /path/to/database.db or :memory:
|
||||
func (cc *DBConnectionConfig) populateSQLiteDSN() error {
|
||||
cc.FilePath = cc.DSN
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates the DBManager configuration
|
||||
func (c *DBManagerConfig) Validate() error {
|
||||
if len(c.Connections) == 0 {
|
||||
|
||||
@@ -97,6 +97,31 @@ func (m *Manager) GetConfig() (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// SetConfig sets the complete configuration
|
||||
func (m *Manager) SetConfig(cfg *Config) error {
|
||||
configMap := make(map[string]interface{})
|
||||
|
||||
// Marshal the config to a map structure that viper can use
|
||||
if err := m.v.Unmarshal(&configMap); err != nil {
|
||||
return fmt.Errorf("failed to prepare config map: %w", err)
|
||||
}
|
||||
|
||||
// Use viper's merge to apply the config
|
||||
m.v.Set("servers", cfg.Servers)
|
||||
m.v.Set("tracing", cfg.Tracing)
|
||||
m.v.Set("cache", cfg.Cache)
|
||||
m.v.Set("logger", cfg.Logger)
|
||||
m.v.Set("error_tracking", cfg.ErrorTracking)
|
||||
m.v.Set("middleware", cfg.Middleware)
|
||||
m.v.Set("cors", cfg.CORS)
|
||||
m.v.Set("event_broker", cfg.EventBroker)
|
||||
m.v.Set("dbmanager", cfg.DBManager)
|
||||
m.v.Set("paths", cfg.Paths)
|
||||
m.v.Set("extensions", cfg.Extensions)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a configuration value by key
|
||||
func (m *Manager) Get(key string) interface{} {
|
||||
return m.v.Get(key)
|
||||
@@ -122,15 +147,32 @@ func (m *Manager) Set(key string, value interface{}) {
|
||||
m.v.Set(key, value)
|
||||
}
|
||||
|
||||
// SaveConfig writes the current configuration to the specified path
|
||||
func (m *Manager) SaveConfig(path string) error {
|
||||
if err := m.v.WriteConfigAs(path); err != nil {
|
||||
return fmt.Errorf("failed to save config to %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDefaults sets default configuration values
|
||||
func setDefaults(v *viper.Viper) {
|
||||
// Server defaults
|
||||
v.SetDefault("server.addr", ":8080")
|
||||
v.SetDefault("server.shutdown_timeout", "30s")
|
||||
v.SetDefault("server.drain_timeout", "25s")
|
||||
v.SetDefault("server.read_timeout", "10s")
|
||||
v.SetDefault("server.write_timeout", "10s")
|
||||
v.SetDefault("server.idle_timeout", "120s")
|
||||
// Server defaults - new structure
|
||||
v.SetDefault("servers.default_server", "default")
|
||||
|
||||
// Global server timeout defaults
|
||||
v.SetDefault("servers.shutdown_timeout", "30s")
|
||||
v.SetDefault("servers.drain_timeout", "25s")
|
||||
v.SetDefault("servers.read_timeout", "10s")
|
||||
v.SetDefault("servers.write_timeout", "10s")
|
||||
v.SetDefault("servers.idle_timeout", "120s")
|
||||
|
||||
// Default server instance
|
||||
v.SetDefault("servers.instances.default.name", "default")
|
||||
v.SetDefault("servers.instances.default.host", "")
|
||||
v.SetDefault("servers.instances.default.port", 8080)
|
||||
v.SetDefault("servers.instances.default.description", "Default HTTP server")
|
||||
v.SetDefault("servers.instances.default.gzip", false)
|
||||
|
||||
// Tracing defaults
|
||||
v.SetDefault("tracing.enabled", false)
|
||||
@@ -166,6 +208,34 @@ func setDefaults(v *viper.Viper) {
|
||||
// Database defaults
|
||||
v.SetDefault("database.url", "")
|
||||
|
||||
// Database Manager defaults
|
||||
v.SetDefault("dbmanager.default_connection", "default")
|
||||
v.SetDefault("dbmanager.max_open_conns", 25)
|
||||
v.SetDefault("dbmanager.max_idle_conns", 5)
|
||||
v.SetDefault("dbmanager.conn_max_lifetime", "30m")
|
||||
v.SetDefault("dbmanager.conn_max_idle_time", "5m")
|
||||
v.SetDefault("dbmanager.retry_attempts", 3)
|
||||
v.SetDefault("dbmanager.retry_delay", "1s")
|
||||
v.SetDefault("dbmanager.retry_max_delay", "10s")
|
||||
v.SetDefault("dbmanager.health_check_interval", "30s")
|
||||
v.SetDefault("dbmanager.enable_auto_reconnect", true)
|
||||
|
||||
// Default PostgreSQL connection
|
||||
v.SetDefault("dbmanager.connections.default.name", "default")
|
||||
v.SetDefault("dbmanager.connections.default.type", "postgres")
|
||||
v.SetDefault("dbmanager.connections.default.host", "localhost")
|
||||
v.SetDefault("dbmanager.connections.default.port", 5432)
|
||||
v.SetDefault("dbmanager.connections.default.user", "postgres")
|
||||
v.SetDefault("dbmanager.connections.default.password", "")
|
||||
v.SetDefault("dbmanager.connections.default.database", "resolvespec")
|
||||
v.SetDefault("dbmanager.connections.default.sslmode", "disable")
|
||||
v.SetDefault("dbmanager.connections.default.connect_timeout", "10s")
|
||||
v.SetDefault("dbmanager.connections.default.query_timeout", "30s")
|
||||
v.SetDefault("dbmanager.connections.default.enable_tracing", false)
|
||||
v.SetDefault("dbmanager.connections.default.enable_metrics", false)
|
||||
v.SetDefault("dbmanager.connections.default.enable_logging", false)
|
||||
v.SetDefault("dbmanager.connections.default.default_orm", "bun")
|
||||
|
||||
// Event Broker defaults
|
||||
v.SetDefault("event_broker.enabled", false)
|
||||
v.SetDefault("event_broker.provider", "memory")
|
||||
@@ -200,4 +270,13 @@ func setDefaults(v *viper.Viper) {
|
||||
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
|
||||
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
|
||||
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
|
||||
|
||||
// Paths defaults (common directory paths)
|
||||
v.SetDefault("paths.data_dir", "./data")
|
||||
v.SetDefault("paths.config_dir", "./config")
|
||||
v.SetDefault("paths.logs_dir", "./logs")
|
||||
v.SetDefault("paths.temp_dir", "./tmp")
|
||||
|
||||
// Extensions defaults (empty map)
|
||||
v.SetDefault("extensions", map[string]interface{}{})
|
||||
}
|
||||
|
||||
@@ -34,8 +34,8 @@ func TestDefaultValues(t *testing.T) {
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"server.addr", cfg.Server.Addr, ":8080"},
|
||||
{"server.shutdown_timeout", cfg.Server.ShutdownTimeout, 30 * time.Second},
|
||||
{"servers.default_server", cfg.Servers.DefaultServer, "default"},
|
||||
{"servers.shutdown_timeout", cfg.Servers.ShutdownTimeout, 30 * time.Second},
|
||||
{"tracing.enabled", cfg.Tracing.Enabled, false},
|
||||
{"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"},
|
||||
{"cache.provider", cfg.Cache.Provider, "memory"},
|
||||
@@ -46,6 +46,18 @@ func TestDefaultValues(t *testing.T) {
|
||||
{"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200},
|
||||
}
|
||||
|
||||
// Test default server instance
|
||||
defaultServer, ok := cfg.Servers.Instances["default"]
|
||||
if !ok {
|
||||
t.Fatal("Default server instance not found")
|
||||
}
|
||||
if defaultServer.Port != 8080 {
|
||||
t.Errorf("default server port: got %d, want 8080", defaultServer.Port)
|
||||
}
|
||||
if defaultServer.Name != "default" {
|
||||
t.Errorf("default server name: got %s, want default", defaultServer.Name)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.expected {
|
||||
@@ -57,12 +69,12 @@ func TestDefaultValues(t *testing.T) {
|
||||
|
||||
func TestEnvironmentVariableOverrides(t *testing.T) {
|
||||
// Set environment variables
|
||||
os.Setenv("RESOLVESPEC_SERVER_ADDR", ":9090")
|
||||
os.Setenv("RESOLVESPEC_SERVERS_INSTANCES_DEFAULT_PORT", "9090")
|
||||
os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true")
|
||||
os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis")
|
||||
os.Setenv("RESOLVESPEC_LOGGER_DEV", "true")
|
||||
defer func() {
|
||||
os.Unsetenv("RESOLVESPEC_SERVER_ADDR")
|
||||
os.Unsetenv("RESOLVESPEC_SERVERS_INSTANCES_DEFAULT_PORT")
|
||||
os.Unsetenv("RESOLVESPEC_TRACING_ENABLED")
|
||||
os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER")
|
||||
os.Unsetenv("RESOLVESPEC_LOGGER_DEV")
|
||||
@@ -84,7 +96,6 @@ func TestEnvironmentVariableOverrides(t *testing.T) {
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"server.addr", cfg.Server.Addr, ":9090"},
|
||||
{"tracing.enabled", cfg.Tracing.Enabled, true},
|
||||
{"cache.provider", cfg.Cache.Provider, "redis"},
|
||||
{"logger.dev", cfg.Logger.Dev, true},
|
||||
@@ -97,11 +108,17 @@ func TestEnvironmentVariableOverrides(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test server port override
|
||||
defaultServer := cfg.Servers.Instances["default"]
|
||||
if defaultServer.Port != 9090 {
|
||||
t.Errorf("server port: got %d, want 9090", defaultServer.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgrammaticConfiguration(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.Set("server.addr", ":7070")
|
||||
mgr.Set("servers.instances.default.port", 7070)
|
||||
mgr.Set("tracing.service_name", "test-service")
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
@@ -109,8 +126,8 @@ func TestProgrammaticConfiguration(t *testing.T) {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Addr != ":7070" {
|
||||
t.Errorf("server.addr: got %s, want :7070", cfg.Server.Addr)
|
||||
if cfg.Servers.Instances["default"].Port != 7070 {
|
||||
t.Errorf("server port: got %d, want 7070", cfg.Servers.Instances["default"].Port)
|
||||
}
|
||||
|
||||
if cfg.Tracing.ServiceName != "test-service" {
|
||||
@@ -148,8 +165,8 @@ func TestWithOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set environment variable with custom prefix
|
||||
os.Setenv("MYAPP_SERVER_ADDR", ":5000")
|
||||
defer os.Unsetenv("MYAPP_SERVER_ADDR")
|
||||
os.Setenv("MYAPP_SERVERS_INSTANCES_DEFAULT_PORT", "5000")
|
||||
defer os.Unsetenv("MYAPP_SERVERS_INSTANCES_DEFAULT_PORT")
|
||||
|
||||
if err := mgr.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
@@ -160,7 +177,432 @@ func TestWithOptions(t *testing.T) {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Addr != ":5000" {
|
||||
t.Errorf("server.addr: got %s, want :5000", cfg.Server.Addr)
|
||||
if cfg.Servers.Instances["default"].Port != 5000 {
|
||||
t.Errorf("server port: got %d, want 5000", cfg.Servers.Instances["default"].Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServersConfig(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
if err := mgr.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Test default server exists
|
||||
if cfg.Servers.DefaultServer != "default" {
|
||||
t.Errorf("Expected default_server to be 'default', got %s", cfg.Servers.DefaultServer)
|
||||
}
|
||||
|
||||
// Test default instance
|
||||
defaultServer, ok := cfg.Servers.Instances["default"]
|
||||
if !ok {
|
||||
t.Fatal("Default server instance not found")
|
||||
}
|
||||
|
||||
if defaultServer.Port != 8080 {
|
||||
t.Errorf("Expected default port 8080, got %d", defaultServer.Port)
|
||||
}
|
||||
|
||||
if defaultServer.Name != "default" {
|
||||
t.Errorf("Expected default name 'default', got %s", defaultServer.Name)
|
||||
}
|
||||
|
||||
if defaultServer.Description != "Default HTTP server" {
|
||||
t.Errorf("Expected description 'Default HTTP server', got %s", defaultServer.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleServerInstances(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// Add additional server instances (default instance exists from defaults)
|
||||
mgr.Set("servers.default_server", "api")
|
||||
mgr.Set("servers.instances.api.name", "api")
|
||||
mgr.Set("servers.instances.api.host", "0.0.0.0")
|
||||
mgr.Set("servers.instances.api.port", 8080)
|
||||
mgr.Set("servers.instances.admin.name", "admin")
|
||||
mgr.Set("servers.instances.admin.host", "localhost")
|
||||
mgr.Set("servers.instances.admin.port", 8081)
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Should have default + api + admin = 3 instances
|
||||
if len(cfg.Servers.Instances) < 2 {
|
||||
t.Errorf("Expected at least 2 server instances, got %d", len(cfg.Servers.Instances))
|
||||
}
|
||||
|
||||
// Verify api instance
|
||||
apiServer, ok := cfg.Servers.Instances["api"]
|
||||
if !ok {
|
||||
t.Fatal("API server instance not found")
|
||||
}
|
||||
if apiServer.Port != 8080 {
|
||||
t.Errorf("Expected API port 8080, got %d", apiServer.Port)
|
||||
}
|
||||
|
||||
// Verify admin instance
|
||||
adminServer, ok := cfg.Servers.Instances["admin"]
|
||||
if !ok {
|
||||
t.Fatal("Admin server instance not found")
|
||||
}
|
||||
if adminServer.Port != 8081 {
|
||||
t.Errorf("Expected admin port 8081, got %d", adminServer.Port)
|
||||
}
|
||||
|
||||
// Validate default server
|
||||
if err := cfg.Servers.Validate(); err != nil {
|
||||
t.Errorf("Server config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Get default
|
||||
defaultSrv, err := cfg.Servers.GetDefault()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get default server: %v", err)
|
||||
}
|
||||
if defaultSrv.Name != "api" {
|
||||
t.Errorf("Expected default server 'api', got '%s'", defaultSrv.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionsField(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// Set custom extensions
|
||||
mgr.Set("extensions.custom_feature.enabled", true)
|
||||
mgr.Set("extensions.custom_feature.api_key", "test-key")
|
||||
mgr.Set("extensions.another_extension.timeout", "5s")
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Extensions == nil {
|
||||
t.Fatal("Extensions should not be nil")
|
||||
}
|
||||
|
||||
// Verify extensions are accessible
|
||||
customFeature := mgr.Get("extensions.custom_feature")
|
||||
if customFeature == nil {
|
||||
t.Error("custom_feature extension not found")
|
||||
}
|
||||
|
||||
// Verify via config manager methods
|
||||
if !mgr.GetBool("extensions.custom_feature.enabled") {
|
||||
t.Error("Expected custom_feature.enabled to be true")
|
||||
}
|
||||
|
||||
if mgr.GetString("extensions.custom_feature.api_key") != "test-key" {
|
||||
t.Error("Expected api_key to be 'test-key'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerInstanceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
instance ServerInstanceConfig
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid basic config",
|
||||
instance: ServerInstanceConfig{
|
||||
Name: "test",
|
||||
Port: 8080,
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid port - too high",
|
||||
instance: ServerInstanceConfig{
|
||||
Name: "test",
|
||||
Port: 99999,
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid port - zero",
|
||||
instance: ServerInstanceConfig{
|
||||
Name: "test",
|
||||
Port: 0,
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty name",
|
||||
instance: ServerInstanceConfig{
|
||||
Name: "",
|
||||
Port: 8080,
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "conflicting TLS options",
|
||||
instance: ServerInstanceConfig{
|
||||
Name: "test",
|
||||
Port: 8080,
|
||||
SelfSignedSSL: true,
|
||||
AutoTLS: true,
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete SSL cert config",
|
||||
instance: ServerInstanceConfig{
|
||||
Name: "test",
|
||||
Port: 8080,
|
||||
SSLCert: "/path/to/cert.pem",
|
||||
// Missing SSLKey
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "AutoTLS without domains",
|
||||
instance: ServerInstanceConfig{
|
||||
Name: "test",
|
||||
Port: 8080,
|
||||
AutoTLS: true,
|
||||
// Missing AutoTLSDomains
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.instance.Validate()
|
||||
if tt.expectErr && err == nil {
|
||||
t.Error("Expected validation error, got nil")
|
||||
}
|
||||
if !tt.expectErr && err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGlobalDefaults(t *testing.T) {
|
||||
globals := ServersConfig{
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
DrainTimeout: 25 * time.Second,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
instance := ServerInstanceConfig{
|
||||
Name: "test",
|
||||
Port: 8080,
|
||||
}
|
||||
|
||||
// Apply global defaults
|
||||
instance.ApplyGlobalDefaults(globals)
|
||||
|
||||
// Check that defaults were applied
|
||||
if instance.ShutdownTimeout == nil || *instance.ShutdownTimeout != 30*time.Second {
|
||||
t.Error("ShutdownTimeout not applied correctly")
|
||||
}
|
||||
if instance.DrainTimeout == nil || *instance.DrainTimeout != 25*time.Second {
|
||||
t.Error("DrainTimeout not applied correctly")
|
||||
}
|
||||
if instance.ReadTimeout == nil || *instance.ReadTimeout != 10*time.Second {
|
||||
t.Error("ReadTimeout not applied correctly")
|
||||
}
|
||||
if instance.WriteTimeout == nil || *instance.WriteTimeout != 10*time.Second {
|
||||
t.Error("WriteTimeout not applied correctly")
|
||||
}
|
||||
if instance.IdleTimeout == nil || *instance.IdleTimeout != 120*time.Second {
|
||||
t.Error("IdleTimeout not applied correctly")
|
||||
}
|
||||
|
||||
// Test that explicit overrides are not replaced
|
||||
customTimeout := 60 * time.Second
|
||||
instance2 := ServerInstanceConfig{
|
||||
Name: "test2",
|
||||
Port: 8081,
|
||||
ShutdownTimeout: &customTimeout,
|
||||
}
|
||||
|
||||
instance2.ApplyGlobalDefaults(globals)
|
||||
|
||||
if instance2.ShutdownTimeout == nil || *instance2.ShutdownTimeout != 60*time.Second {
|
||||
t.Error("Custom ShutdownTimeout was overridden")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathsConfig(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
if err := mgr.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Test default paths exist
|
||||
if !cfg.Paths.Has("data_dir") {
|
||||
t.Error("Expected data_dir path to exist")
|
||||
}
|
||||
if !cfg.Paths.Has("config_dir") {
|
||||
t.Error("Expected config_dir path to exist")
|
||||
}
|
||||
if !cfg.Paths.Has("logs_dir") {
|
||||
t.Error("Expected logs_dir path to exist")
|
||||
}
|
||||
if !cfg.Paths.Has("temp_dir") {
|
||||
t.Error("Expected temp_dir path to exist")
|
||||
}
|
||||
|
||||
// Test Get method
|
||||
dataDir, err := cfg.Paths.Get("data_dir")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get data_dir: %v", err)
|
||||
}
|
||||
if dataDir != "./data" {
|
||||
t.Errorf("Expected data_dir to be './data', got '%s'", dataDir)
|
||||
}
|
||||
|
||||
// Test GetOrDefault
|
||||
existing := cfg.Paths.GetOrDefault("data_dir", "/default/path")
|
||||
if existing != "./data" {
|
||||
t.Errorf("Expected existing path, got '%s'", existing)
|
||||
}
|
||||
|
||||
nonExisting := cfg.Paths.GetOrDefault("nonexistent", "/default/path")
|
||||
if nonExisting != "/default/path" {
|
||||
t.Errorf("Expected default path, got '%s'", nonExisting)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathsConfigMethods(t *testing.T) {
|
||||
pc := PathsConfig{
|
||||
"base": "/var/myapp",
|
||||
"data": "/var/myapp/data",
|
||||
}
|
||||
|
||||
// Test Get
|
||||
path, err := pc.Get("base")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get path: %v", err)
|
||||
}
|
||||
if path != "/var/myapp" {
|
||||
t.Errorf("Expected '/var/myapp', got '%s'", path)
|
||||
}
|
||||
|
||||
// Test Get non-existent
|
||||
_, err = pc.Get("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent path")
|
||||
}
|
||||
|
||||
// Test Set
|
||||
pc.Set("new_path", "/new/location")
|
||||
newPath, err := pc.Get("new_path")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get newly set path: %v", err)
|
||||
}
|
||||
if newPath != "/new/location" {
|
||||
t.Errorf("Expected '/new/location', got '%s'", newPath)
|
||||
}
|
||||
|
||||
// Test Has
|
||||
if !pc.Has("base") {
|
||||
t.Error("Expected 'base' path to exist")
|
||||
}
|
||||
if pc.Has("nonexistent") {
|
||||
t.Error("Expected 'nonexistent' path to not exist")
|
||||
}
|
||||
|
||||
// Test List
|
||||
names := pc.List()
|
||||
if len(names) != 3 {
|
||||
t.Errorf("Expected 3 paths, got %d", len(names))
|
||||
}
|
||||
|
||||
// Test Join
|
||||
joined, err := pc.Join("base", "subdir", "file.txt")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to join paths: %v", err)
|
||||
}
|
||||
expected := "/var/myapp/subdir/file.txt"
|
||||
if joined != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, joined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathsConfigEnvironmentVariables(t *testing.T) {
|
||||
// Set environment variables for paths
|
||||
os.Setenv("RESOLVESPEC_PATHS_DATA_DIR", "/custom/data")
|
||||
os.Setenv("RESOLVESPEC_PATHS_LOGS_DIR", "/custom/logs")
|
||||
defer func() {
|
||||
os.Unsetenv("RESOLVESPEC_PATHS_DATA_DIR")
|
||||
os.Unsetenv("RESOLVESPEC_PATHS_LOGS_DIR")
|
||||
}()
|
||||
|
||||
mgr := NewManager()
|
||||
if err := mgr.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Test environment variable override of existing default path
|
||||
dataDir, err := cfg.Paths.Get("data_dir")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get data_dir: %v", err)
|
||||
}
|
||||
if dataDir != "/custom/data" {
|
||||
t.Errorf("Expected '/custom/data', got '%s'", dataDir)
|
||||
}
|
||||
|
||||
// Test another environment variable override
|
||||
logsDir, err := cfg.Paths.Get("logs_dir")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get logs_dir: %v", err)
|
||||
}
|
||||
if logsDir != "/custom/logs" {
|
||||
t.Errorf("Expected '/custom/logs', got '%s'", logsDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathsConfigProgrammatic(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
// Set custom paths programmatically
|
||||
mgr.Set("paths.custom_dir", "/my/custom/dir")
|
||||
mgr.Set("paths.cache_dir", "/var/cache/myapp")
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Verify custom paths
|
||||
customDir, err := cfg.Paths.Get("custom_dir")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get custom_dir: %v", err)
|
||||
}
|
||||
if customDir != "/my/custom/dir" {
|
||||
t.Errorf("Expected '/my/custom/dir', got '%s'", customDir)
|
||||
}
|
||||
|
||||
cacheDir, err := cfg.Paths.Get("cache_dir")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get cache_dir: %v", err)
|
||||
}
|
||||
if cacheDir != "/var/cache/myapp" {
|
||||
t.Errorf("Expected '/var/cache/myapp', got '%s'", cacheDir)
|
||||
}
|
||||
}
|
||||
|
||||
117
pkg/config/paths.go
Normal file
117
pkg/config/paths.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// Get retrieves a path by name
|
||||
func (pc PathsConfig) Get(name string) (string, error) {
|
||||
if pc == nil {
|
||||
return "", fmt.Errorf("paths not initialized")
|
||||
}
|
||||
|
||||
path, ok := pc[name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path '%s' not found", name)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// GetOrDefault retrieves a path by name, returning defaultPath if not found
|
||||
func (pc PathsConfig) GetOrDefault(name, defaultPath string) string {
|
||||
if pc == nil {
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
path, ok := pc[name]
|
||||
if !ok {
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// Set sets a path by name
|
||||
func (pc PathsConfig) Set(name, path string) {
|
||||
pc[name] = path
|
||||
}
|
||||
|
||||
// Has checks if a path exists by name
|
||||
func (pc PathsConfig) Has(name string) bool {
|
||||
if pc == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := pc[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// EnsureDir ensures a directory exists at the specified path name
|
||||
// Creates the directory if it doesn't exist with the given permissions
|
||||
func (pc PathsConfig) EnsureDir(name string, perm os.FileMode) error {
|
||||
path, err := pc.Get(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if directory exists
|
||||
info, err := os.Stat(path)
|
||||
if err == nil {
|
||||
// Path exists, check if it's a directory
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("path '%s' exists but is not a directory: %s", name, path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Directory doesn't exist, create it
|
||||
if os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(path, perm); err != nil {
|
||||
return fmt.Errorf("failed to create directory for '%s' at %s: %w", name, path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to stat path '%s' at %s: %w", name, path, err)
|
||||
}
|
||||
|
||||
// AbsPath returns the absolute path for a named path
|
||||
func (pc PathsConfig) AbsPath(name string) (string, error) {
|
||||
path, err := pc.Get(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get absolute path for '%s': %w", name, err)
|
||||
}
|
||||
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
// Join joins path segments with a named base path
|
||||
func (pc PathsConfig) Join(name string, elem ...string) (string, error) {
|
||||
base, err := pc.Get(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
parts := append([]string{base}, elem...)
|
||||
return filepath.Join(parts...), nil
|
||||
}
|
||||
|
||||
// List returns all configured path names
|
||||
func (pc PathsConfig) List() []string {
|
||||
if pc == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(pc))
|
||||
for name := range pc {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
107
pkg/config/server.go
Normal file
107
pkg/config/server.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ApplyGlobalDefaults applies global server defaults to this instance
|
||||
// Called for instances that don't specify their own timeout values
|
||||
func (sic *ServerInstanceConfig) ApplyGlobalDefaults(globals ServersConfig) {
|
||||
if sic.ShutdownTimeout == nil && globals.ShutdownTimeout > 0 {
|
||||
t := globals.ShutdownTimeout
|
||||
sic.ShutdownTimeout = &t
|
||||
}
|
||||
if sic.DrainTimeout == nil && globals.DrainTimeout > 0 {
|
||||
t := globals.DrainTimeout
|
||||
sic.DrainTimeout = &t
|
||||
}
|
||||
if sic.ReadTimeout == nil && globals.ReadTimeout > 0 {
|
||||
t := globals.ReadTimeout
|
||||
sic.ReadTimeout = &t
|
||||
}
|
||||
if sic.WriteTimeout == nil && globals.WriteTimeout > 0 {
|
||||
t := globals.WriteTimeout
|
||||
sic.WriteTimeout = &t
|
||||
}
|
||||
if sic.IdleTimeout == nil && globals.IdleTimeout > 0 {
|
||||
t := globals.IdleTimeout
|
||||
sic.IdleTimeout = &t
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the ServerInstanceConfig
|
||||
func (sic *ServerInstanceConfig) Validate() error {
|
||||
if sic.Name == "" {
|
||||
return fmt.Errorf("server instance name cannot be empty")
|
||||
}
|
||||
if sic.Port <= 0 || sic.Port > 65535 {
|
||||
return fmt.Errorf("invalid port: %d (must be 1-65535)", sic.Port)
|
||||
}
|
||||
|
||||
// Validate TLS options are mutually exclusive
|
||||
tlsCount := 0
|
||||
if sic.SSLCert != "" || sic.SSLKey != "" {
|
||||
tlsCount++
|
||||
}
|
||||
if sic.SelfSignedSSL {
|
||||
tlsCount++
|
||||
}
|
||||
if sic.AutoTLS {
|
||||
tlsCount++
|
||||
}
|
||||
if tlsCount > 1 {
|
||||
return fmt.Errorf("server '%s': only one TLS option can be enabled", sic.Name)
|
||||
}
|
||||
|
||||
// If using certificate files, both must be provided
|
||||
if (sic.SSLCert != "" && sic.SSLKey == "") || (sic.SSLCert == "" && sic.SSLKey != "") {
|
||||
return fmt.Errorf("server '%s': both ssl_cert and ssl_key must be provided", sic.Name)
|
||||
}
|
||||
|
||||
// If using AutoTLS, domains must be specified
|
||||
if sic.AutoTLS && len(sic.AutoTLSDomains) == 0 {
|
||||
return fmt.Errorf("server '%s': auto_tls_domains must be specified when auto_tls is enabled", sic.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates the ServersConfig
|
||||
func (sc *ServersConfig) Validate() error {
|
||||
if len(sc.Instances) == 0 {
|
||||
return fmt.Errorf("at least one server instance must be configured")
|
||||
}
|
||||
|
||||
if sc.DefaultServer != "" {
|
||||
if _, ok := sc.Instances[sc.DefaultServer]; !ok {
|
||||
return fmt.Errorf("default server '%s' not found in instances", sc.DefaultServer)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate each instance
|
||||
for name := range sc.Instances {
|
||||
instance := sc.Instances[name]
|
||||
if instance.Name != name {
|
||||
return fmt.Errorf("server instance name mismatch: key='%s', name='%s'", name, instance.Name)
|
||||
}
|
||||
if err := instance.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefault returns the default server instance configuration
|
||||
func (sc *ServersConfig) GetDefault() (*ServerInstanceConfig, error) {
|
||||
if sc.DefaultServer == "" {
|
||||
return nil, fmt.Errorf("no default server configured")
|
||||
}
|
||||
|
||||
instance, ok := sc.Instances[sc.DefaultServer]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("default server '%s' not found", sc.DefaultServer)
|
||||
}
|
||||
|
||||
return &instance, nil
|
||||
}
|
||||
@@ -50,6 +50,59 @@ type connectionManager struct {
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
var (
|
||||
// singleton instance of the manager
|
||||
instance Manager
|
||||
// instanceMu protects the singleton instance
|
||||
instanceMu sync.RWMutex
|
||||
)
|
||||
|
||||
// SetupManager initializes the singleton database manager with the provided configuration.
|
||||
// This function must be called before GetInstance().
|
||||
// Returns an error if the manager is already initialized or if configuration is invalid.
|
||||
func SetupManager(cfg ManagerConfig) error {
|
||||
instanceMu.Lock()
|
||||
defer instanceMu.Unlock()
|
||||
|
||||
if instance != nil {
|
||||
return fmt.Errorf("manager already initialized")
|
||||
}
|
||||
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create manager: %w", err)
|
||||
}
|
||||
|
||||
instance = mgr
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInstance returns the singleton instance of the database manager.
|
||||
// Returns an error if SetupManager has not been called yet.
|
||||
func GetInstance() (Manager, error) {
|
||||
instanceMu.RLock()
|
||||
defer instanceMu.RUnlock()
|
||||
|
||||
if instance == nil {
|
||||
return nil, fmt.Errorf("manager not initialized: call SetupManager first")
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// ResetInstance resets the singleton instance (primarily for testing purposes).
|
||||
// WARNING: This should only be used in tests. Calling this in production code
|
||||
// while the manager is in use can lead to undefined behavior.
|
||||
func ResetInstance() {
|
||||
instanceMu.Lock()
|
||||
defer instanceMu.Unlock()
|
||||
|
||||
if instance != nil {
|
||||
_ = instance.Close()
|
||||
}
|
||||
instance = nil
|
||||
}
|
||||
|
||||
// NewManager creates a new database connection manager
|
||||
func NewManager(cfg ManagerConfig) (Manager, error) {
|
||||
// Apply defaults and validate configuration
|
||||
|
||||
47
pkg/server/config_helper.go
Normal file
47
pkg/server/config_helper.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
)
|
||||
|
||||
// FromConfigInstanceToServerConfig converts a config.ServerInstanceConfig to server.Config
|
||||
// The handler must be provided separately as it cannot be serialized
|
||||
func FromConfigInstanceToServerConfig(sic *config.ServerInstanceConfig, handler http.Handler) Config {
|
||||
cfg := Config{
|
||||
Name: sic.Name,
|
||||
Host: sic.Host,
|
||||
Port: sic.Port,
|
||||
Description: sic.Description,
|
||||
Handler: handler,
|
||||
GZIP: sic.GZIP,
|
||||
|
||||
SSLCert: sic.SSLCert,
|
||||
SSLKey: sic.SSLKey,
|
||||
SelfSignedSSL: sic.SelfSignedSSL,
|
||||
AutoTLS: sic.AutoTLS,
|
||||
AutoTLSDomains: sic.AutoTLSDomains,
|
||||
AutoTLSCacheDir: sic.AutoTLSCacheDir,
|
||||
AutoTLSEmail: sic.AutoTLSEmail,
|
||||
}
|
||||
|
||||
// Apply timeouts (use pointers to override, or use zero values for defaults)
|
||||
if sic.ShutdownTimeout != nil {
|
||||
cfg.ShutdownTimeout = *sic.ShutdownTimeout
|
||||
}
|
||||
if sic.DrainTimeout != nil {
|
||||
cfg.DrainTimeout = *sic.DrainTimeout
|
||||
}
|
||||
if sic.ReadTimeout != nil {
|
||||
cfg.ReadTimeout = *sic.ReadTimeout
|
||||
}
|
||||
if sic.WriteTimeout != nil {
|
||||
cfg.WriteTimeout = *sic.WriteTimeout
|
||||
}
|
||||
if sic.IdleTimeout != nil {
|
||||
cfg.IdleTimeout = *sic.IdleTimeout
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
@@ -501,7 +501,7 @@ func (s *serverInstance) Start() error {
|
||||
|
||||
if useTLS {
|
||||
protocol = "HTTPS"
|
||||
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||
logger.Info("Starting %s server - Name: '%s', Address: %s, Port: %d", protocol, s.cfg.Name, s.cfg.Host, s.cfg.Port)
|
||||
|
||||
// For AutoTLS, we need to use a TLS listener
|
||||
if s.cfg.AutoTLS {
|
||||
@@ -519,7 +519,7 @@ func (s *serverInstance) Start() error {
|
||||
err = s.gracefulServer.server.ListenAndServeTLS(s.certFile, s.keyFile)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Starting %s server '%s' on %s", protocol, s.cfg.Name, s.Addr())
|
||||
logger.Info("Starting %s server - Name: '%s', Address: %s, Port: %d", protocol, s.cfg.Name, s.cfg.Host, s.cfg.Port)
|
||||
err = s.gracefulServer.server.ListenAndServe()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user