chore: ⬆️ updated deps
This commit is contained in:
+1
@@ -33,6 +33,7 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
dst.AuthMechanisms = dst.AuthMechanisms[:0]
|
||||
authMechanisms := src[4:]
|
||||
for len(authMechanisms) > 1 {
|
||||
idx := bytes.IndexByte(authMechanisms, 0)
|
||||
|
||||
+5
-5
@@ -46,8 +46,8 @@ type Backend struct {
|
||||
}
|
||||
|
||||
const (
|
||||
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||
maxStartupPacketLen = 10_000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||
)
|
||||
|
||||
// NewBackend creates a new Backend.
|
||||
@@ -123,7 +123,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(buf)) - 4)
|
||||
|
||||
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||
@@ -137,7 +137,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||
code := binary.BigEndian.Uint32(buf)
|
||||
|
||||
switch code {
|
||||
case ProtocolVersionNumber:
|
||||
case ProtocolVersion30, ProtocolVersion32:
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -176,7 +176,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
|
||||
b.msgType = header[0]
|
||||
|
||||
msgLength := int(binary.BigEndian.Uint32(header[1:]))
|
||||
msgLength := int(int32(binary.BigEndian.Uint32(header[1:])))
|
||||
if msgLength < 4 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLength)
|
||||
}
|
||||
|
||||
+27
-6
@@ -2,6 +2,7 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
|
||||
type BackendKeyData struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
SecretKey []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
@@ -18,12 +19,13 @@ func (*BackendKeyData) Backend() {}
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
if len(src) < 8 {
|
||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||
dst.SecretKey = make([]byte, len(src)-4)
|
||||
copy(dst.SecretKey, src[4:])
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -32,7 +34,7 @@ func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
dst = append(dst, src.SecretKey...)
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
@@ -41,10 +43,29 @@ func (src BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
SecretKey string
|
||||
}{
|
||||
Type: "BackendKeyData",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
SecretKey: hex.EncodeToString(src.SecretKey),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *BackendKeyData) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
ProcessID uint32
|
||||
SecretKey string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.ProcessID = msg.ProcessID
|
||||
secretKey, err := hex.DecodeString(msg.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.SecretKey = secretKey
|
||||
return nil
|
||||
}
|
||||
|
||||
+4
-4
@@ -54,7 +54,7 @@ func (dst *Bind) Decode(src []byte) error {
|
||||
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||
for i := range parameterFormatCodeCount {
|
||||
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func (dst *Bind) Decode(src []byte) error {
|
||||
if parameterCount > 0 {
|
||||
dst.Parameters = make([][]byte, parameterCount)
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
for i := range parameterCount {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func (dst *Bind) Decode(src []byte) error {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(src[rp:]) < msgSize {
|
||||
if msgSize < 0 || len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
@@ -101,7 +101,7 @@ func (dst *Bind) Decode(src []byte) error {
|
||||
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < resultFormatCodeCount; i++ {
|
||||
for i := range resultFormatCodeCount {
|
||||
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
|
||||
+36
-9
@@ -2,6 +2,7 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
@@ -12,35 +13,42 @@ const cancelRequestCode = 80877102
|
||||
|
||||
type CancelRequest struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
SecretKey []byte
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CancelRequest) Frontend() {}
|
||||
|
||||
func (dst *CancelRequest) Decode(src []byte) error {
|
||||
if len(src) != 12 {
|
||||
return errors.New("bad cancel request size")
|
||||
if len(src) < 12 {
|
||||
return errors.New("cancel request too short")
|
||||
}
|
||||
if len(src) > 264 {
|
||||
return errors.New("cancel request too long")
|
||||
}
|
||||
|
||||
requestCode := binary.BigEndian.Uint32(src)
|
||||
|
||||
if requestCode != cancelRequestCode {
|
||||
return errors.New("bad cancel request code")
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||
dst.SecretKey = make([]byte, len(src)-8)
|
||||
copy(dst.SecretKey, src[8:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
|
||||
dst = pgio.AppendInt32(dst, 16)
|
||||
if len(src.SecretKey) > 256 {
|
||||
return nil, errors.New("secret key too long")
|
||||
}
|
||||
msgLen := int32(12 + len(src.SecretKey))
|
||||
dst = pgio.AppendInt32(dst, msgLen)
|
||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
dst = append(dst, src.SecretKey...)
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
@@ -49,10 +57,29 @@ func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
SecretKey string
|
||||
}{
|
||||
Type: "CancelRequest",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
SecretKey: hex.EncodeToString(src.SecretKey),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CancelRequest) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
ProcessID uint32
|
||||
SecretKey string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.ProcessID = msg.ProcessID
|
||||
secretKey, err := hex.DecodeString(msg.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.SecretKey = secretKey
|
||||
return nil
|
||||
}
|
||||
|
||||
+1
-1
@@ -35,7 +35,7 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
for i := range columnCount {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
|
||||
+4
@@ -15,6 +15,10 @@ func (*CopyFail) Frontend() {}
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyFail) Decode(src []byte) error {
|
||||
if len(src) == 0 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
||||
}
|
||||
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
||||
|
||||
+1
-1
@@ -35,7 +35,7 @@ func (dst *CopyInResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
for i := range columnCount {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -34,7 +34,7 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
for i := range columnCount {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
|
||||
+2
-5
@@ -31,16 +31,13 @@ func (dst *DataRow) Decode(src []byte) error {
|
||||
// large reallocate. This is too avoid one row with many columns from
|
||||
// permanently allocating memory.
|
||||
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
||||
newCap := 32
|
||||
if newCap < fieldCount {
|
||||
newCap = fieldCount
|
||||
}
|
||||
newCap := max(32, fieldCount)
|
||||
dst.Values = make([][]byte, fieldCount, newCap)
|
||||
} else {
|
||||
dst.Values = dst.Values[:fieldCount]
|
||||
}
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
for i := range fieldCount {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
+5
-2
@@ -52,6 +52,7 @@ type Frontend struct {
|
||||
readyForQuery ReadyForQuery
|
||||
rowDescription RowDescription
|
||||
portalSuspended PortalSuspended
|
||||
negotiateProtocolVersion NegotiateProtocolVersion
|
||||
|
||||
bodyLen int
|
||||
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
|
||||
@@ -230,7 +231,7 @@ func (f *Frontend) SendExecute(msg *Execute) {
|
||||
f.wbuf = newBuf
|
||||
|
||||
if f.tracer != nil {
|
||||
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +313,7 @@ func (f *Frontend) Receive() (BackendMessage, error) {
|
||||
|
||||
f.msgType = header[0]
|
||||
|
||||
msgLength := int(binary.BigEndian.Uint32(header[1:]))
|
||||
msgLength := int(int32(binary.BigEndian.Uint32(header[1:])))
|
||||
if msgLength < 4 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLength)
|
||||
}
|
||||
@@ -383,6 +384,8 @@ func (f *Frontend) Receive() (BackendMessage, error) {
|
||||
msg = &f.copyBothResponse
|
||||
case 'Z':
|
||||
msg = &f.readyForQuery
|
||||
case 'v':
|
||||
msg = &f.negotiateProtocolVersion
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
|
||||
}
|
||||
|
||||
+24
-3
@@ -23,6 +23,11 @@ func (*FunctionCall) Frontend() {}
|
||||
func (dst *FunctionCall) Decode(src []byte) error {
|
||||
*dst = FunctionCall{}
|
||||
rp := 0
|
||||
|
||||
if len(src) < 8 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
}
|
||||
|
||||
// Specifies the object ID of the function to call.
|
||||
dst.Function = binary.BigEndian.Uint32(src[rp:])
|
||||
rp += 4
|
||||
@@ -32,8 +37,13 @@ func (dst *FunctionCall) Decode(src []byte) error {
|
||||
// or it can equal the actual number of arguments.
|
||||
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if len(src[rp:]) < nArgumentCodes*2+2 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
}
|
||||
|
||||
argumentCodes := make([]uint16, nArgumentCodes)
|
||||
for i := 0; i < nArgumentCodes; i++ {
|
||||
for i := range nArgumentCodes {
|
||||
// The argument format codes. Each must presently be zero (text) or one (binary).
|
||||
ac := binary.BigEndian.Uint16(src[rp:])
|
||||
if ac != 0 && ac != 1 {
|
||||
@@ -48,14 +58,22 @@ func (dst *FunctionCall) Decode(src []byte) error {
|
||||
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
arguments := make([][]byte, nArguments)
|
||||
for i := 0; i < nArguments; i++ {
|
||||
for i := range nArguments {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
}
|
||||
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
|
||||
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
|
||||
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
argumentLength := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
if argumentLength == -1 {
|
||||
arguments[i] = nil
|
||||
} else if argumentLength < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
} else {
|
||||
if len(src[rp:]) < argumentLength {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
}
|
||||
// The value of the argument, in the format indicated by the associated format code. n is the above length.
|
||||
argumentValue := src[rp : rp+argumentLength]
|
||||
rp += argumentLength
|
||||
@@ -64,6 +82,9 @@ func (dst *FunctionCall) Decode(src []byte) error {
|
||||
}
|
||||
dst.Arguments = arguments
|
||||
// The format code for the function result. Must presently be zero (text) or one (binary).
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
}
|
||||
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
|
||||
if resultFormatCode != 0 && resultFormatCode != 1 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
|
||||
+2
-2
@@ -22,7 +22,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
rp := 0
|
||||
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
resultSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
if resultSize == -1 {
|
||||
@@ -30,7 +30,7 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src[rp:]) != resultSize {
|
||||
if resultSize < 0 || len(src[rp:]) != resultSize {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
|
||||
|
||||
+93
@@ -0,0 +1,93 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
type NegotiateProtocolVersion struct {
|
||||
NewestMinorProtocol uint32
|
||||
UnrecognizedOptions []string
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*NegotiateProtocolVersion) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *NegotiateProtocolVersion) Decode(src []byte) error {
|
||||
if len(src) < 8 {
|
||||
return &invalidMessageLenErr{messageType: "NegotiateProtocolVersion", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.NewestMinorProtocol = binary.BigEndian.Uint32(src[:4])
|
||||
optionCount := int(binary.BigEndian.Uint32(src[4:8]))
|
||||
|
||||
rp := 8
|
||||
|
||||
// Use the remaining message size as an upper bound for capacity to prevent
|
||||
// malicious optionCount values from causing excessive memory allocation.
|
||||
capHint := optionCount
|
||||
if remaining := len(src) - rp; capHint > remaining {
|
||||
capHint = remaining
|
||||
}
|
||||
dst.UnrecognizedOptions = make([]string, 0, capHint)
|
||||
for i := 0; i < optionCount; i++ {
|
||||
if rp >= len(src) {
|
||||
return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
|
||||
}
|
||||
end := rp
|
||||
for end < len(src) && src[end] != 0 {
|
||||
end++
|
||||
}
|
||||
if end >= len(src) {
|
||||
return &invalidMessageFormatErr{messageType: "NegotiateProtocolVersion"}
|
||||
}
|
||||
dst.UnrecognizedOptions = append(dst.UnrecognizedOptions, string(src[rp:end]))
|
||||
rp = end + 1
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NegotiateProtocolVersion) Encode(dst []byte) ([]byte, error) {
|
||||
dst, sp := beginMessage(dst, 'v')
|
||||
dst = pgio.AppendUint32(dst, src.NewestMinorProtocol)
|
||||
dst = pgio.AppendUint32(dst, uint32(len(src.UnrecognizedOptions)))
|
||||
for _, option := range src.UnrecognizedOptions {
|
||||
dst = append(dst, option...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
return finishMessage(dst, sp)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src NegotiateProtocolVersion) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
NewestMinorProtocol uint32
|
||||
UnrecognizedOptions []string
|
||||
}{
|
||||
Type: "NegotiateProtocolVersion",
|
||||
NewestMinorProtocol: src.NewestMinorProtocol,
|
||||
UnrecognizedOptions: src.UnrecognizedOptions,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *NegotiateProtocolVersion) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
NewestMinorProtocol uint32
|
||||
UnrecognizedOptions []string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.NewestMinorProtocol = msg.NewestMinorProtocol
|
||||
dst.UnrecognizedOptions = msg.UnrecognizedOptions
|
||||
return nil
|
||||
}
|
||||
+1
-1
@@ -33,7 +33,7 @@ func (dst *ParameterDescription) Decode(src []byte) error {
|
||||
|
||||
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
for i := range parameterCount {
|
||||
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -43,7 +43,7 @@ func (dst *Parse) Decode(src []byte) error {
|
||||
}
|
||||
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
for i := 0; i < parameterOIDCount; i++ {
|
||||
for range parameterOIDCount {
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||
}
|
||||
|
||||
+4
@@ -15,6 +15,10 @@ func (*Query) Frontend() {}
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Query) Decode(src []byte) error {
|
||||
if len(src) == 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Query"}
|
||||
}
|
||||
|
||||
i := bytes.IndexByte(src, 0)
|
||||
if i != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Query"}
|
||||
|
||||
+1
-1
@@ -64,7 +64,7 @@ func (dst *RowDescription) Decode(src []byte) error {
|
||||
|
||||
dst.Fields = dst.Fields[0:0]
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
for range fieldCount {
|
||||
var fd FieldDescription
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
|
||||
+3
@@ -32,6 +32,9 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
|
||||
dst.AuthMechanism = string(src[rp:idx])
|
||||
rp = idx + 1
|
||||
|
||||
if len(src[rp:]) < 4 {
|
||||
return errors.New("invalid SASLInitialResponse")
|
||||
}
|
||||
rp += 4 // The rest of the message is data so we can just skip the size
|
||||
dst.Data = src[rp:]
|
||||
|
||||
|
||||
+8
-3
@@ -10,7 +10,12 @@ import (
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
)
|
||||
|
||||
const ProtocolVersionNumber = 196608 // 3.0
|
||||
const (
|
||||
ProtocolVersion30 = 196608 // 3.0
|
||||
ProtocolVersion32 = 196610 // 3.2
|
||||
ProtocolVersionLatest = ProtocolVersion32 // Latest is 3.2
|
||||
ProtocolVersionNumber = ProtocolVersion30 // Default is still 3.0
|
||||
)
|
||||
|
||||
type StartupMessage struct {
|
||||
ProtocolVersion uint32
|
||||
@@ -30,8 +35,8 @@ func (dst *StartupMessage) Decode(src []byte) error {
|
||||
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||
rp := 4
|
||||
|
||||
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||
if dst.ProtocolVersion != ProtocolVersion30 && dst.ProtocolVersion != ProtocolVersion32 {
|
||||
return fmt.Errorf("Bad startup message version number. Expected %d or %d, got %d", ProtocolVersion30, ProtocolVersion32, dst.ProtocolVersion)
|
||||
}
|
||||
|
||||
dst.Parameters = make(map[string]string)
|
||||
|
||||
+2
-2
@@ -82,7 +82,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) {
|
||||
case *ErrorResponse:
|
||||
t.traceErrorResponse(sender, encodedLen, msg)
|
||||
case *Execute:
|
||||
t.TraceQueryute(sender, encodedLen, msg)
|
||||
t.traceExecute(sender, encodedLen, msg)
|
||||
case *Flush:
|
||||
t.traceFlush(sender, encodedLen, msg)
|
||||
case *FunctionCall:
|
||||
@@ -260,7 +260,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes
|
||||
t.writeTrace(sender, encodedLen, "ErrorResponse", nil)
|
||||
}
|
||||
|
||||
func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) {
|
||||
func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) {
|
||||
t.writeTrace(sender, encodedLen, "Execute", func() {
|
||||
fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user