diff --git a/pkg/config/dbmanager.go b/pkg/config/dbmanager.go index 681f93c..c56b8af 100644 --- a/pkg/config/dbmanager.go +++ b/pkg/config/dbmanager.go @@ -2,6 +2,9 @@ package config import ( "fmt" + "net/url" + "strconv" + "strings" "time" ) @@ -91,6 +94,160 @@ func (c *DBManagerConfig) ToManagerConfig() interface{} { return c } +// PopulateFromDSN parses a DSN and populates the connection fields +func (cc *DBConnectionConfig) PopulateFromDSN() error { + if cc.DSN == "" { + return nil // Nothing to populate + } + + switch cc.Type { + case "postgres": + return cc.populatePostgresDSN() + case "mongodb": + return cc.populateMongoDSN() + case "mssql": + return cc.populateMSSQLDSN() + case "sqlite": + return cc.populateSQLiteDSN() + default: + return fmt.Errorf("cannot parse DSN for unsupported database type: %s", cc.Type) + } +} + +// populatePostgresDSN parses PostgreSQL DSN format +// Example: host=localhost port=5432 user=postgres password=secret dbname=mydb sslmode=disable +func (cc *DBConnectionConfig) populatePostgresDSN() error { + parts := strings.Fields(cc.DSN) + for _, part := range parts { + kv := strings.SplitN(part, "=", 2) + if len(kv) != 2 { + continue + } + key, value := kv[0], kv[1] + + switch key { + case "host": + cc.Host = value + case "port": + port, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid port in DSN: %w", err) + } + cc.Port = port + case "user": + cc.User = value + case "password": + cc.Password = value + case "dbname": + cc.Database = value + case "sslmode": + cc.SSLMode = value + case "search_path": + cc.Schema = value + } + } + return nil +} + +// populateMongoDSN parses MongoDB DSN format +// Example: mongodb://user:password@host:port/database?authSource=admin&replicaSet=rs0 +func (cc *DBConnectionConfig) populateMongoDSN() error { + u, err := url.Parse(cc.DSN) + if err != nil { + return fmt.Errorf("invalid MongoDB DSN: %w", err) + } + + // Extract user and password + if u.User != nil { + cc.User = u.User.Username() + if password, ok := u.User.Password(); ok { + cc.Password = password + } + } + + // Extract host and port + if u.Host != "" { + host := u.Host + if strings.Contains(host, ":") { + hostPort := strings.SplitN(host, ":", 2) + cc.Host = hostPort[0] + if port, err := strconv.Atoi(hostPort[1]); err == nil { + cc.Port = port + } + } else { + cc.Host = host + } + } + + // Extract database + if u.Path != "" { + cc.Database = strings.TrimPrefix(u.Path, "/") + } + + // Extract query parameters + params := u.Query() + if authSource := params.Get("authSource"); authSource != "" { + cc.AuthSource = authSource + } + if replicaSet := params.Get("replicaSet"); replicaSet != "" { + cc.ReplicaSet = replicaSet + } + if readPref := params.Get("readPreference"); readPref != "" { + cc.ReadPreference = readPref + } + + return nil +} + +// populateMSSQLDSN parses MSSQL DSN format +// Example: sqlserver://username:password@host:port?database=dbname&schema=dbo +func (cc *DBConnectionConfig) populateMSSQLDSN() error { + u, err := url.Parse(cc.DSN) + if err != nil { + return fmt.Errorf("invalid MSSQL DSN: %w", err) + } + + // Extract user and password + if u.User != nil { + cc.User = u.User.Username() + if password, ok := u.User.Password(); ok { + cc.Password = password + } + } + + // Extract host and port + if u.Host != "" { + host := u.Host + if strings.Contains(host, ":") { + hostPort := strings.SplitN(host, ":", 2) + cc.Host = hostPort[0] + if port, err := strconv.Atoi(hostPort[1]); err == nil { + cc.Port = port + } + } else { + cc.Host = host + } + } + + // Extract query parameters + params := u.Query() + if database := params.Get("database"); database != "" { + cc.Database = database + } + if schema := params.Get("schema"); schema != "" { + cc.Schema = schema + } + + return nil +} + +// populateSQLiteDSN parses SQLite DSN format +// Example: /path/to/database.db or :memory: +func (cc *DBConnectionConfig) populateSQLiteDSN() error { + cc.FilePath = cc.DSN + return nil +} + // Validate validates the DBManager configuration func (c *DBManagerConfig) Validate() error { if len(c.Connections) == 0 {