chore: ⬆️ updated deps
This commit is contained in:
+67
@@ -0,0 +1,67 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
func (c *PgConn) oauthAuth(ctx context.Context) error {
|
||||
if c.config.OAuthTokenProvider == nil {
|
||||
return errors.New("OAuth authentication required but no token provider configured")
|
||||
}
|
||||
|
||||
token, err := c.config.OAuthTokenProvider(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to obtain OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1
|
||||
initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01")
|
||||
|
||||
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||
AuthMechanism: "OAUTHBEARER",
|
||||
Data: initialResponse,
|
||||
}
|
||||
c.frontend.Send(saslInitialResponse)
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationOk:
|
||||
return nil
|
||||
case *pgproto3.AuthenticationSASLContinue:
|
||||
// Server sent error response in SASL continue
|
||||
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2
|
||||
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3
|
||||
errResponse := struct {
|
||||
Status string `json:"status"`
|
||||
Scope string `json:"scope"`
|
||||
OpenIDConfiguration string `json:"openid-configuration"`
|
||||
}{}
|
||||
err := json.Unmarshal(m.Data, &errResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid OAuth error response from server: %w", err)
|
||||
}
|
||||
|
||||
// Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01.
|
||||
// However, since the connection will be closed anyway, we can skip this
|
||||
return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status)
|
||||
|
||||
case *pgproto3.ErrorResponse:
|
||||
return ErrorResponseToPgError(m)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected message type during OAuth auth: %T", msg)
|
||||
}
|
||||
}
|
||||
+148
-21
@@ -1,7 +1,8 @@
|
||||
// SCRAM-SHA-256 authentication
|
||||
// SCRAM-SHA-256 and SCRAM-SHA-256-PLUS authentication
|
||||
//
|
||||
// Resources:
|
||||
// https://tools.ietf.org/html/rfc5802
|
||||
// https://tools.ietf.org/html/rfc5929
|
||||
// https://tools.ietf.org/html/rfc8265
|
||||
// https://www.postgresql.org/docs/current/sasl-authentication.html
|
||||
//
|
||||
@@ -15,19 +16,28 @@ package pgconn
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/pbkdf2"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/text/secure/precis"
|
||||
)
|
||||
|
||||
const clientNonceLen = 18
|
||||
const (
|
||||
clientNonceLen = 18
|
||||
scramSHA256Name = "SCRAM-SHA-256"
|
||||
scramSHA256PlusName = "SCRAM-SHA-256-PLUS"
|
||||
)
|
||||
|
||||
// Perform SCRAM authentication.
|
||||
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
@@ -36,9 +46,35 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
serverHasPlus := slices.Contains(sc.serverAuthMechanisms, scramSHA256PlusName)
|
||||
if c.config.ChannelBinding == "require" && !serverHasPlus {
|
||||
return errors.New("channel binding required but server does not support SCRAM-SHA-256-PLUS")
|
||||
}
|
||||
|
||||
// If we have a TLS connection and channel binding is not disabled, attempt to
|
||||
// extract the server certificate hash for tls-server-end-point channel binding.
|
||||
if tlsConn, ok := c.conn.(*tls.Conn); ok && c.config.ChannelBinding != "disable" {
|
||||
certHash, err := getTLSCertificateHash(tlsConn)
|
||||
if err != nil && c.config.ChannelBinding == "require" {
|
||||
return fmt.Errorf("channel binding required but failed to get server certificate hash: %w", err)
|
||||
}
|
||||
|
||||
// Upgrade to SCRAM-SHA-256-PLUS if we have binding data and the server supports it.
|
||||
if certHash != nil && serverHasPlus {
|
||||
sc.authMechanism = scramSHA256PlusName
|
||||
}
|
||||
|
||||
sc.channelBindingData = certHash
|
||||
sc.hasTLS = true
|
||||
}
|
||||
|
||||
if c.config.ChannelBinding == "require" && sc.channelBindingData == nil {
|
||||
return errors.New("channel binding required but channel binding data is not available")
|
||||
}
|
||||
|
||||
// Send client-first-message in a SASLInitialResponse
|
||||
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||
AuthMechanism: "SCRAM-SHA-256",
|
||||
AuthMechanism: sc.authMechanism,
|
||||
Data: sc.clientFirstMessage(),
|
||||
}
|
||||
c.frontend.Send(saslInitialResponse)
|
||||
@@ -107,10 +143,31 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||
|
||||
type scramClient struct {
|
||||
serverAuthMechanisms []string
|
||||
password []byte
|
||||
password string
|
||||
clientNonce []byte
|
||||
|
||||
// authMechanism is the selected SASL mechanism for the client. Must be
|
||||
// either SCRAM-SHA-256 (default) or SCRAM-SHA-256-PLUS.
|
||||
//
|
||||
// Upgraded to SCRAM-SHA-256-PLUS during authentication when channel binding
|
||||
// is not disabled, channel binding data is available (TLS connection with
|
||||
// an obtainable server certificate hash) and the server advertises
|
||||
// SCRAM-SHA-256-PLUS.
|
||||
authMechanism string
|
||||
|
||||
// hasTLS indicates whether the connection is using TLS. This is
|
||||
// needed because the GS2 header must distinguish between a client that
|
||||
// supports channel binding but the server does not ("y,,") versus one
|
||||
// that does not support it at all ("n,,").
|
||||
hasTLS bool
|
||||
|
||||
// channelBindingData is the hash of the server's TLS certificate, computed
|
||||
// per the tls-server-end-point channel binding type (RFC 5929). Used as
|
||||
// the binding input in SCRAM-SHA-256-PLUS. nil when not in use.
|
||||
channelBindingData []byte
|
||||
|
||||
clientFirstMessageBare []byte
|
||||
clientGS2Header []byte
|
||||
|
||||
serverFirstMessage []byte
|
||||
clientAndServerNonce []byte
|
||||
@@ -124,26 +181,23 @@ type scramClient struct {
|
||||
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
|
||||
sc := &scramClient{
|
||||
serverAuthMechanisms: serverAuthMechanisms,
|
||||
authMechanism: scramSHA256Name,
|
||||
}
|
||||
|
||||
// Ensure server supports SCRAM-SHA-256
|
||||
hasScramSHA256 := false
|
||||
for _, mech := range sc.serverAuthMechanisms {
|
||||
if mech == "SCRAM-SHA-256" {
|
||||
hasScramSHA256 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasScramSHA256 {
|
||||
// Ensure the server supports SCRAM-SHA-256. SCRAM-SHA-256-PLUS is the
|
||||
// channel binding variant and is only advertised when the server supports
|
||||
// SSL. PostgreSQL always advertises the base SCRAM-SHA-256 mechanism
|
||||
// regardless of SSL.
|
||||
if !slices.Contains(sc.serverAuthMechanisms, scramSHA256Name) {
|
||||
return nil, errors.New("server does not support SCRAM-SHA-256")
|
||||
}
|
||||
|
||||
// precis.OpaqueString is equivalent to SASLprep for password.
|
||||
var err error
|
||||
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
|
||||
sc.password, err = precis.OpaqueString.String(password)
|
||||
if err != nil {
|
||||
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
|
||||
sc.password = []byte(password)
|
||||
sc.password = password
|
||||
}
|
||||
|
||||
buf := make([]byte, clientNonceLen)
|
||||
@@ -158,8 +212,32 @@ func newScramClient(serverAuthMechanisms []string, password string) (*scramClien
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFirstMessage() []byte {
|
||||
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
|
||||
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
|
||||
// The client-first-message is the GS2 header concatenated with the bare
|
||||
// message (username + client nonce). The GS2 header communicates the
|
||||
// client's channel binding capability to the server:
|
||||
//
|
||||
// "n,," - client is not using TLS (channel binding not possible)
|
||||
// "y,," - client is using TLS but channel binding is not
|
||||
// in use (e.g., server did not advertise SCRAM-SHA-256-PLUS
|
||||
// or the server certificate hash was not obtainable)
|
||||
// "p=tls-server-end-point,," - channel binding is active via SCRAM-SHA-256-PLUS
|
||||
//
|
||||
// See:
|
||||
// https://www.rfc-editor.org/rfc/rfc5802#section-6
|
||||
// https://www.rfc-editor.org/rfc/rfc5929#section-4
|
||||
// https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256
|
||||
|
||||
sc.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", sc.clientNonce)
|
||||
|
||||
if sc.authMechanism == scramSHA256PlusName {
|
||||
sc.clientGS2Header = []byte("p=tls-server-end-point,,")
|
||||
} else if sc.hasTLS {
|
||||
sc.clientGS2Header = []byte("y,,")
|
||||
} else {
|
||||
sc.clientGS2Header = []byte("n,,")
|
||||
}
|
||||
|
||||
return append(sc.clientGS2Header, sc.clientFirstMessageBare...)
|
||||
}
|
||||
|
||||
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||
@@ -218,9 +296,25 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFinalMessage() string {
|
||||
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
|
||||
// The c= attribute carries the base64-encoded channel binding input.
|
||||
//
|
||||
// Without channel binding this is just the GS2 header alone ("biws" for
|
||||
// "n,," or "eSws" for "y,,").
|
||||
//
|
||||
// With channel binding, this is the GS2 header with the channel binding data
|
||||
// (certificate hash) appended.
|
||||
channelBindInput := sc.clientGS2Header
|
||||
if sc.authMechanism == scramSHA256PlusName {
|
||||
channelBindInput = slices.Concat(sc.clientGS2Header, sc.channelBindingData)
|
||||
}
|
||||
channelBindingEncoded := base64.StdEncoding.EncodeToString(channelBindInput)
|
||||
clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=%s,r=%s", channelBindingEncoded, sc.clientAndServerNonce)
|
||||
|
||||
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
|
||||
var err error
|
||||
sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32)
|
||||
if err != nil {
|
||||
panic(err) // This should never happen.
|
||||
}
|
||||
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
|
||||
|
||||
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
|
||||
@@ -254,7 +348,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||
clientSignature := computeHMAC(storedKey[:], authMessage)
|
||||
|
||||
clientProof := make([]byte, len(clientSignature))
|
||||
for i := 0; i < len(clientSignature); i++ {
|
||||
for i := range clientSignature {
|
||||
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||
}
|
||||
|
||||
@@ -270,3 +364,36 @@ func computeServerSignature(saltedPassword, authMessage []byte) []byte {
|
||||
base64.StdEncoding.Encode(buf, serverSignature)
|
||||
return buf
|
||||
}
|
||||
|
||||
// Get the server certificate hash for SCRAM channel binding type
|
||||
// tls-server-end-point.
|
||||
func getTLSCertificateHash(conn *tls.Conn) ([]byte, error) {
|
||||
state := conn.ConnectionState()
|
||||
if len(state.PeerCertificates) == 0 {
|
||||
return nil, errors.New("no peer certificates for channel binding")
|
||||
}
|
||||
|
||||
cert := state.PeerCertificates[0]
|
||||
|
||||
// Per RFC 5929 section 4.1: If the certificate's signatureAlgorithm uses
|
||||
// MD5 or SHA-1, use SHA-256. Otherwise use the hash from the signature
|
||||
// algorithm.
|
||||
//
|
||||
// See: https://www.rfc-editor.org/rfc/rfc5929.html#section-4.1
|
||||
var h hash.Hash
|
||||
switch cert.SignatureAlgorithm {
|
||||
case x509.MD5WithRSA, x509.SHA1WithRSA, x509.ECDSAWithSHA1:
|
||||
h = sha256.New()
|
||||
case x509.SHA256WithRSA, x509.SHA256WithRSAPSS, x509.ECDSAWithSHA256:
|
||||
h = sha256.New()
|
||||
case x509.SHA384WithRSA, x509.SHA384WithRSAPSS, x509.ECDSAWithSHA384:
|
||||
h = sha512.New384()
|
||||
case x509.SHA512WithRSA, x509.SHA512WithRSAPSS, x509.ECDSAWithSHA512:
|
||||
h = sha512.New()
|
||||
default:
|
||||
return nil, fmt.Errorf("tls-server-end-point channel binding is undefined for certificate signature algorithm %v", cert.SignatureAlgorithm)
|
||||
}
|
||||
|
||||
h.Write(cert.Raw)
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
+100
-10
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -55,6 +56,13 @@ type Config struct {
|
||||
|
||||
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||
|
||||
// AfterNetConnect is called after the network connection, including TLS if applicable, is established but before any
|
||||
// PostgreSQL protocol communication. It takes the established net.Conn and returns a net.Conn that will be used in
|
||||
// its place. It can be used to wrap the net.Conn (e.g. for logging, diagnostics, or testing). Its functionality has
|
||||
// some overlap with DialFunc. However, DialFunc takes place before TLS is established and cannot be used to control
|
||||
// the final net.Conn used for PostgreSQL protocol communication while AfterNetConnect can.
|
||||
AfterNetConnect func(ctx context.Context, config *Config, conn net.Conn) (net.Conn, error)
|
||||
|
||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
@@ -75,6 +83,23 @@ type Config struct {
|
||||
// that you close on FATAL errors by returning false.
|
||||
OnPgError PgErrorHandler
|
||||
|
||||
// OAuthTokenProvider is a function that returns an OAuth token for authentication. If set, it will be used for
|
||||
// OAUTHBEARER SASL authentication when the server requests it.
|
||||
OAuthTokenProvider func(context.Context) (string, error)
|
||||
|
||||
// MinProtocolVersion is the minimum acceptable PostgreSQL protocol version.
|
||||
// If the server does not support at least this version, the connection will fail.
|
||||
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0".
|
||||
MinProtocolVersion string
|
||||
|
||||
// MaxProtocolVersion is the maximum PostgreSQL protocol version to request from the server.
|
||||
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility.
|
||||
MaxProtocolVersion string
|
||||
|
||||
// ChannelBinding is the channel_binding parameter for SCRAM-SHA-256-PLUS authentication.
|
||||
// Valid values: "disable", "prefer", "require". Defaults to "prefer".
|
||||
ChannelBinding string
|
||||
|
||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||
}
|
||||
|
||||
@@ -96,9 +121,7 @@ func (c *Config) Copy() *Config {
|
||||
}
|
||||
if newConf.RuntimeParams != nil {
|
||||
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||
for k, v := range c.RuntimeParams {
|
||||
newConf.RuntimeParams[k] = v
|
||||
}
|
||||
maps.Copy(newConf.RuntimeParams, c.RuntimeParams)
|
||||
}
|
||||
if newConf.Fallbacks != nil {
|
||||
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||
@@ -207,6 +230,8 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
// PGTZ
|
||||
// PGMINPROTOCOLVERSION
|
||||
// PGMAXPROTOCOLVERSION
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables.
|
||||
//
|
||||
@@ -332,6 +357,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
"target_session_attrs": {},
|
||||
"service": {},
|
||||
"servicefile": {},
|
||||
"min_protocol_version": {},
|
||||
"max_protocol_version": {},
|
||||
"channel_binding": {},
|
||||
}
|
||||
|
||||
// Adding kerberos configuration
|
||||
@@ -424,6 +452,52 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||
}
|
||||
|
||||
minProto, err := parseProtocolVersion(settings["min_protocol_version"])
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("invalid min_protocol_version: %q", settings["min_protocol_version"]), err: err}
|
||||
}
|
||||
maxProto, err := parseProtocolVersion(settings["max_protocol_version"])
|
||||
if err != nil {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("invalid max_protocol_version: %q", settings["max_protocol_version"]), err: err}
|
||||
}
|
||||
|
||||
config.MinProtocolVersion = settings["min_protocol_version"]
|
||||
config.MaxProtocolVersion = settings["max_protocol_version"]
|
||||
|
||||
if config.MinProtocolVersion == "" {
|
||||
config.MinProtocolVersion = "3.0"
|
||||
}
|
||||
|
||||
// When max_protocol_version is not explicitly set, default based on
|
||||
// min_protocol_version. This matches libpq behavior: if min > 3.0,
|
||||
// default max to latest; otherwise default to 3.0 for compatibility
|
||||
// with older servers/poolers that don't support NegotiateProtocolVersion.
|
||||
if config.MaxProtocolVersion == "" {
|
||||
if minProto > pgproto3.ProtocolVersion30 {
|
||||
config.MaxProtocolVersion = "latest"
|
||||
} else {
|
||||
config.MaxProtocolVersion = "3.0"
|
||||
}
|
||||
}
|
||||
|
||||
// Only error when max_protocol_version was explicitly set and conflicts
|
||||
// with min_protocol_version. When max_protocol_version is not explicitly
|
||||
// set, the auto-raise logic above already ensures a valid default.
|
||||
if minProto > maxProto && settings["max_protocol_version"] != "" {
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: "min_protocol_version cannot be greater than max_protocol_version"}
|
||||
}
|
||||
|
||||
switch channelBinding := settings["channel_binding"]; channelBinding {
|
||||
case "", "prefer":
|
||||
config.ChannelBinding = "prefer"
|
||||
case "disable":
|
||||
config.ChannelBinding = "disable"
|
||||
case "require":
|
||||
config.ChannelBinding = "require"
|
||||
default:
|
||||
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown channel_binding value: %v", channelBinding)}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -431,9 +505,7 @@ func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
for _, s2 := range settingSets {
|
||||
for k, v := range s2 {
|
||||
settings[k] = v
|
||||
}
|
||||
maps.Copy(settings, s2)
|
||||
}
|
||||
|
||||
return settings
|
||||
@@ -463,6 +535,8 @@ func parseEnvSettings() map[string]string {
|
||||
"PGSERVICEFILE": "servicefile",
|
||||
"PGTZ": "timezone",
|
||||
"PGOPTIONS": "options",
|
||||
"PGMINPROTOCOLVERSION": "min_protocol_version",
|
||||
"PGMAXPROTOCOLVERSION": "max_protocol_version",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
@@ -487,7 +561,9 @@ func parseURLSettings(connString string) (map[string]string, error) {
|
||||
}
|
||||
|
||||
if parsedURL.User != nil {
|
||||
settings["user"] = parsedURL.User.Username()
|
||||
if u := parsedURL.User.Username(); u != "" {
|
||||
settings["user"] = u
|
||||
}
|
||||
if password, present := parsedURL.User.Password(); present {
|
||||
settings["password"] = password
|
||||
}
|
||||
@@ -496,7 +572,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
|
||||
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||
var hosts []string
|
||||
var ports []string
|
||||
for _, host := range strings.Split(parsedURL.Host, ",") {
|
||||
for host := range strings.SplitSeq(parsedURL.Host, ",") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
@@ -614,6 +690,9 @@ func parseKeywordValueSettings(s string) (map[string]string, error) {
|
||||
return nil, errors.New("invalid keyword/value")
|
||||
}
|
||||
|
||||
if key == "user" && val == "" {
|
||||
continue
|
||||
}
|
||||
settings[key] = val
|
||||
}
|
||||
|
||||
@@ -784,7 +863,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||
// Attempt decryption with pass phrase
|
||||
// NOTE: only supports RSA (PKCS#1)
|
||||
if sslpassword != "" {
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) //nolint:ineffassign
|
||||
}
|
||||
// if sslpassword not provided or has decryption error when use it
|
||||
// try to find sslpassword with callback function
|
||||
@@ -799,7 +878,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
// Should we also provide warning for PKCS#1 needed?
|
||||
if decryptedError != nil {
|
||||
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||
return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError)
|
||||
}
|
||||
|
||||
pemBytes := pem.Block{
|
||||
@@ -951,3 +1030,14 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseProtocolVersion(s string) (uint32, error) {
|
||||
switch s {
|
||||
case "", "3.0":
|
||||
return pgproto3.ProtocolVersion30, nil
|
||||
case "3.2", "latest":
|
||||
return pgproto3.ProtocolVersion32, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid protocol version: %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
+19
-27
@@ -8,12 +8,13 @@ import (
|
||||
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||
// time.
|
||||
type ContextWatcher struct {
|
||||
handler Handler
|
||||
unwatchChan chan struct{}
|
||||
handler Handler
|
||||
|
||||
lock sync.Mutex
|
||||
watchInProgress bool
|
||||
onCancelWasCalled bool
|
||||
// Lock protects the members below.
|
||||
lock sync.Mutex
|
||||
// Stop is the handle for an "after func". See [context.AfterFunc].
|
||||
stop func() bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||
@@ -21,8 +22,7 @@ type ContextWatcher struct {
|
||||
// onCancel called.
|
||||
func NewContextWatcher(handler Handler) *ContextWatcher {
|
||||
cw := &ContextWatcher{
|
||||
handler: handler,
|
||||
unwatchChan: make(chan struct{}),
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
return cw
|
||||
@@ -33,25 +33,16 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
panic("Watch already in progress")
|
||||
if cw.stop != nil {
|
||||
panic("watch already in progress")
|
||||
}
|
||||
|
||||
cw.onCancelWasCalled = false
|
||||
|
||||
if ctx.Done() != nil {
|
||||
cw.watchInProgress = true
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cw.handler.HandleCancel(ctx)
|
||||
cw.onCancelWasCalled = true
|
||||
<-cw.unwatchChan
|
||||
case <-cw.unwatchChan:
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
cw.watchInProgress = false
|
||||
cw.done = make(chan struct{})
|
||||
cw.stop = context.AfterFunc(ctx, func() {
|
||||
cw.handler.HandleCancel(ctx)
|
||||
close(cw.done)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,12 +52,13 @@ func (cw *ContextWatcher) Unwatch() {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
cw.unwatchChan <- struct{}{}
|
||||
if cw.onCancelWasCalled {
|
||||
if cw.stop != nil {
|
||||
if !cw.stop() {
|
||||
<-cw.done
|
||||
cw.handler.HandleUnwatchAfterCancel()
|
||||
}
|
||||
cw.watchInProgress = false
|
||||
cw.stop = nil
|
||||
cw.done = nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+17
@@ -254,3 +254,20 @@ func (e *NotPreferredError) SafeToRetry() bool {
|
||||
func (e *NotPreferredError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type PrepareError struct {
|
||||
err error
|
||||
|
||||
ParseComplete bool // Indicates whether the error occurred after a ParseComplete message was received.
|
||||
}
|
||||
|
||||
func (e *PrepareError) Error() string {
|
||||
if e.ParseComplete {
|
||||
return fmt.Sprintf("prepare failed after ParseComplete: %s", e.err.Error())
|
||||
}
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *PrepareError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
+618
-150
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user