Skip to content

Commit

Permalink
Check for overflow on uint16 sizes in pgproto3
Browse files (browse the repository at this point in the history)
  • Loading branch information
jackc committed Mar 4, 2024
1 parent adbb38f commit 20344df
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pgproto3/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -116,11 +118,17 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)

if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}

if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
Expand All @@ -132,6 +140,9 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst = append(dst, p...)
}

if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
Expand Down
4 changes: 4 additions & 0 deletions pgproto3/copy_both_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -47,6 +48,9 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'W')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
Expand Down
4 changes: 4 additions & 0 deletions pgproto3/copy_in_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -48,6 +49,9 @@ func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'G')

dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
Expand Down
4 changes: 4 additions & 0 deletions pgproto3/copy_out_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -48,6 +49,9 @@ func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {

dst = append(dst, src.OverallFormat)

if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
Expand Down
5 changes: 5 additions & 0 deletions pgproto3/data_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -66,6 +68,9 @@ func (dst *DataRow) Decode(src []byte) error {
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')

if len(src.Values) > math.MaxUint16 {
return nil, errors.New("too many values")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
Expand Down
10 changes: 10 additions & 0 deletions pgproto3/function_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package pgproto3

import (
"encoding/binary"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -74,10 +76,18 @@ func (dst *FunctionCall) Decode(src []byte) error {
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'F')
dst = pgio.AppendUint32(dst, src.Function)

if len(src.ArgFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many arg format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}

if len(src.Arguments) > math.MaxUint16 {
return nil, errors.New("too many arguments")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
Expand Down
5 changes: 5 additions & 0 deletions pgproto3/parameter_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -42,6 +44,9 @@ func (dst *ParameterDescription) Decode(src []byte) error {
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 't')

if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
Expand Down
5 changes: 5 additions & 0 deletions pgproto3/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -60,6 +62,9 @@ func (src *Parse) Encode(dst []byte) ([]byte, error) {
dst = append(dst, src.Query...)
dst = append(dst, 0)

if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
Expand Down
5 changes: 5 additions & 0 deletions pgproto3/row_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"

"github.com/jackc/pgx/v5/internal/pgio"
)
Expand Down Expand Up @@ -102,6 +104,9 @@ func (dst *RowDescription) Decode(src []byte) error {
func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'T')

if len(src.Fields) > math.MaxUint16 {
return nil, errors.New("too many fields")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
for _, fd := range src.Fields {
dst = append(dst, fd.Name...)
Expand Down

0 comments on commit 20344df

Please sign in to comment.