[chore]: Bump github.com/jackc/pgx/v5 from 5.5.5 to 5.6.0 (#2929)

This commit is contained in:
dependabot[bot] 2024-05-27 09:35:41 +00:00 committed by GitHub
parent 3d3e99ae52
commit 0a18c0d802
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 969 additions and 561 deletions

2
go.mod
View file

@ -39,7 +39,7 @@ require (
github.com/gorilla/feeds v1.1.2
github.com/gorilla/websocket v1.5.1
github.com/h2non/filetype v1.1.3
github.com/jackc/pgx/v5 v5.5.5
github.com/jackc/pgx/v5 v5.6.0
github.com/microcosm-cc/bluemonday v1.0.26
github.com/miekg/dns v1.1.59
github.com/minio/minio-go/v7 v7.0.70

4
go.sum
View file

@ -374,8 +374,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=

View file

@ -1,3 +1,22 @@
# 5.6.0 (May 25, 2024)
* Add StrictNamedArgs (Tomas Zahradnicek)
* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
* Add SeverityUnlocalized field to PgError / Notice
* Performance optimization of RowToStructByPos/Name (Zach Olstein)
* Allow customizing context canceled behavior for pgconn
* Add ScanLocation to pgtype.Timestamp[tz]Codec
* Add custom data to pgconn.PgConn
* Fix ResultReader.Read() to handle nil values
* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
* pgconn.SafeToRetry checks for wrapped errors (tjasko)
* Failed connection attempts include all errors
* Optimize LargeObject.Read (Mitar)
* Add tracing for connection acquire and release from pool (ngavinsir)
* Fix encode driver.Valuer not called when nil
* Add support for custom JSON marshal and unmarshal (Mitar)
* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
# 5.5.5 (March 9, 2024)
Use spaces instead of parentheses for SQL sanitization.

View file

@ -29,6 +29,7 @@ Create and setup a test database:
export PGDATABASE=pgx_test
createdb
psql -c 'create extension hstore;'
psql -c 'create extension ltree;'
psql -c 'create domain uint64 as numeric(20,0);'
```

View file

@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
## Version Policy

View file

@ -12,7 +12,7 @@
type QueuedQuery struct {
SQL string
Arguments []any
fn batchItemFunc
Fn batchItemFunc
sd *pgconn.StatementDescription
}
@ -20,7 +20,7 @@ type QueuedQuery struct {
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
qq.fn = func(br BatchResults) error {
qq.Fn = func(br BatchResults) error {
rows, _ := br.Query()
defer rows.Close()
@ -36,7 +36,7 @@ func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
qq.fn = func(br BatchResults) error {
qq.Fn = func(br BatchResults) error {
row := br.QueryRow()
return fn(row)
}
@ -44,7 +44,7 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
// Exec sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
qq.fn = func(br BatchResults) error {
qq.Fn = func(br BatchResults) error {
ct, err := br.Exec()
if err != nil {
return err
@ -228,8 +228,8 @@ func (br *batchResults) Close() error {
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.QueuedQueries[br.qqIdx].fn(br)
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
br.err = err
}
@ -397,8 +397,8 @@ func (br *pipelineBatchResults) Close() error {
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.QueuedQueries[br.qqIdx].fn(br)
if br.b.QueuedQueries[br.qqIdx].Fn != nil {
err := br.b.QueuedQueries[br.qqIdx].Fn(br)
if err != nil {
br.err = err
}

View file

@ -10,7 +10,6 @@
"strings"
"time"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/sanitize"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
@ -624,7 +623,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows {
// to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the
// statement description on the first round trip and then uses it to execute the query on the second round trip. This
// may cause problems with connection poolers that switch the underlying connection between round trips. It is safe
// even when the the database schema is modified concurrently.
// even when the database schema is modified concurrently.
QueryExecModeDescribeExec
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
@ -755,7 +754,6 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error)
}
c.eqb.reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
var err error

View file

@ -11,9 +11,10 @@
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the connection with
[ConnectConfig] to configure settings such as tracing that cannot be configured with a connection string.
The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
string.
Connection Pool
@ -23,8 +24,8 @@
Query Interface
pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and
rows.Err().
ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
rows.Scan, and rows.Err().
CollectRows can be used collect all returned rows into a slice.

View file

@ -1,10 +1,8 @@
package pgx
import (
"database/sql/driver"
"fmt"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
@ -23,10 +21,15 @@ type ExtendedQueryBuilder struct {
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
eqb.reset()
anynil.NormalizeSlice(args)
if sd == nil {
return eqb.appendParamsForQueryExecModeExec(m, args)
for i := range args {
err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
if err != nil {
err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
return err
}
}
return nil
}
if len(sd.ParamOIDs) != len(args) {
@ -113,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() {
}
func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
if anynil.Is(arg) {
return nil, nil
}
if eqb.paramValueBytes == nil {
eqb.paramValueBytes = make([]byte, 0, 128)
}
@ -145,74 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui
return m.FormatCodeForOID(oid)
}
// appendParamsForQueryExecModeExec appends the args to eqb.
//
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
// type conversion it takes the date directly and ignores time zone (i.e. it works).
//
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
// no way to safely use binary or to specify the parameter OIDs.
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
for _, arg := range args {
if arg == nil {
err := eqb.appendParam(m, 0, TextFormatCode, arg)
if err != nil {
return err
}
} else {
dt, ok := m.TypeForValue(arg)
if !ok {
var tv pgtype.TextValuer
if tv, ok = arg.(pgtype.TextValuer); ok {
t, err := tv.TextValue()
if err != nil {
return err
}
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = t
}
}
}
if !ok {
var dv driver.Valuer
if dv, ok = arg.(driver.Valuer); ok {
v, err := dv.Value()
if err != nil {
return err
}
dt, ok = m.TypeForValue(v)
if ok {
arg = v
}
}
}
if !ok {
var str fmt.Stringer
if str, ok = arg.(fmt.Stringer); ok {
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = str.String()
}
}
}
if !ok {
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
}
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
if err != nil {
return err
}
}
}
return nil
}

View file

@ -1,36 +0,0 @@
package anynil
import "reflect"
// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
func Is(value any) bool {
if value == nil {
return true
}
refVal := reflect.ValueOf(value)
switch refVal.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return refVal.IsNil()
default:
return false
}
}
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
func Normalize(v any) any {
if Is(v) {
return nil
}
return v
}
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
// mutated in place.
func NormalizeSlice(s []any) {
for i := range s {
if Is(s[i]) {
s[i] = nil
}
}
}

View file

@ -4,6 +4,8 @@
"context"
"errors"
"io"
"github.com/jackc/pgx/v5/pgtype"
)
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
@ -115,9 +117,10 @@ func (o *LargeObject) Read(p []byte) (int, error) {
expected = maxLargeObjectMessageLength
}
var res []byte
res := pgtype.PreallocBytes(p[nTotal:])
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
copy(p[nTotal:], res)
// We compute expected so that it always fits into p, so it should never happen
// that PreallocBytes's ScanBytes had to allocate a new slice.
nTotal += len(res)
if err != nil {
return nTotal, err

View file

@ -2,6 +2,7 @@
import (
"context"
"fmt"
"strconv"
"strings"
"unicode/utf8"
@ -21,6 +22,34 @@
// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(na, sql, false)
}
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
// named arguments that the sql query uses, and no extra arguments.
type StrictNamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(sna, sql, true)
}
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
newArgs[ordinal-1] = na[string(name)]
var found bool
newArgs[ordinal-1], found = na[string(name)]
if isStrict && !found {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
}
}
if isStrict {
for name := range na {
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
}
}
}
return sb.String(), newArgs, nil
}
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])

View file

@ -19,6 +19,7 @@
"github.com/jackc/pgpassfile"
"github.com/jackc/pgservicefile"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
@ -39,7 +40,12 @@ type Config struct {
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string
KerberosSpn string
@ -70,7 +76,7 @@ type Config struct {
// ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
// PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc
}
@ -112,6 +118,14 @@ type FallbackConfig struct {
TLSConfig *tls.Config // nil disables TLS
}
// connectOneConfig is the configuration for a single attempt to connect to a single host.
type connectOneConfig struct {
network string
address string
originalHostname string // original hostname before resolving
tlsConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
@ -146,11 +160,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example DSN
// # Example Keyword/Value
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
@ -169,7 +183,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
// via database URL or DSN:
// via database URL or keyword/value:
//
// PGHOST
// PGPORT
@ -233,16 +247,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or a DSN
// connString may be a database URL or in PostgreSQL keyword/value format
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseDSNSettings(connString)
connStringSettings, err = parseKeywordValueSettings(connString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
}
}
}
@ -266,6 +280,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
@ -517,7 +534,7 @@ func isIPOnly(host string) bool {
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseDSNSettings(s string) (map[string]string, error) {
func parseKeywordValueSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
@ -528,7 +545,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid dsn")
return nil, errors.New("invalid keyword/value")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
@ -580,7 +597,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
}
if key == "" {
return nil, errors.New("invalid dsn")
return nil, errors.New("invalid keyword/value")
}
settings[key] = val
@ -800,7 +817,8 @@ func parsePort(s string) (uint16, error) {
}
func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
// rely on GOLANG KeepAlive settings
return &net.Dialer{}
}
func makeDefaultResolver() *net.Resolver {

View file

@ -8,9 +8,8 @@
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
onCancel func()
onUnwatchAfterCancel func()
unwatchChan chan struct{}
handler Handler
unwatchChan chan struct{}
lock sync.Mutex
watchInProgress bool
@ -20,11 +19,10 @@ type ContextWatcher struct {
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
func NewContextWatcher(handler Handler) *ContextWatcher {
cw := &ContextWatcher{
onCancel: onCancel,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}),
handler: handler,
unwatchChan: make(chan struct{}),
}
return cw
@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
go func() {
select {
case <-ctx.Done():
cw.onCancel()
cw.handler.HandleCancel(ctx)
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() {
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.onUnwatchAfterCancel()
cw.handler.HandleUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}
type Handler interface {
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
// context that was canceled.
HandleCancel(canceledCtx context.Context)
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
HandleUnwatchAfterCancel()
}

View file

@ -5,8 +5,8 @@
Establishing a Connection
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
libpq style environment variables.
Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
environment for libpq style environment variables.
Executing a Query
@ -20,13 +20,17 @@
Pipeline Mode
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
control of exactly how many and when network round trips occur.
Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of
exactly how many and when network round trips occur.
Context Support
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is
a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
interrupting the query in such a way as to close the connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.

View file

@ -12,13 +12,14 @@
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
var retryableErr interface{ SafeToRetry() bool }
if errors.As(err, &retryableErr) {
return retryableErr.SafeToRetry()
}
return false
}
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
var timeoutErr *errTimeout
@ -29,23 +30,24 @@ func Timeout(err error) bool {
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
Severity string
SeverityUnlocalized string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
@ -60,23 +62,37 @@ func (pe *PgError) SQLState() string {
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string
err error
}
func (e *ConnectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
details := e.err.Error()
if strings.Contains(details, "\n") {
return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
} else {
return prefix + " " + details
}
return sb.String()
}
func (e *ConnectError) Unwrap() error {
return e.err
}
type perDialConnectError struct {
address string
originalHostname string
err error
}
func (e *perDialConnectError) Error() string {
return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
}
func (e *perDialConnectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
@ -195,10 +211,10 @@ func redactPW(connString string) string {
return redactURL(u)
}
}
quotedDSN := regexp.MustCompile(`password='[^']*'`)
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
plainDSN := regexp.MustCompile(`password=[^ ]*`)
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
quotedKV := regexp.MustCompile(`password='[^']*'`)
connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
plainKV := regexp.MustCompile(`password=[^ ]*`)
connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
return connString

View file

@ -18,8 +18,8 @@
"github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3"
)
@ -82,6 +82,8 @@ type PgConn struct {
slowWriteTimer *time.Timer
bgReaderStarted chan struct{}
customData map[string]any
config *Config
status byte // One of connStatus* constants
@ -103,8 +105,9 @@ type PgConn struct {
cleanupDone chan struct{}
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
// to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt.
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value
// format) to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a
// connect attempt.
func Connect(ctx context.Context, connString string) (*PgConn, error) {
config, err := ParseConfig(connString)
if err != nil {
@ -114,9 +117,9 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
return ConnectConfig(ctx, config)
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
// and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be
// used to cancel a connect attempt.
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value
// format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details.
// ctx can be used to cancel a connect attempt.
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
config, err := ParseConfigWithOptions(connString, parseConfigOptions)
if err != nil {
@ -131,15 +134,46 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio
//
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
// authentication error will terminate the chain of attempts (like libpq:
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
// if all attempts fail the last error is returned.
func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) {
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error.
func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
// zero values.
if !config.createdByParseConfig {
panic("config must be created by ParseConfig")
}
var allErrors []error
connectConfigs, errs := buildConnectOneConfigs(ctx, config)
if len(errs) > 0 {
allErrors = append(allErrors, errs...)
}
if len(connectConfigs) == 0 {
return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))}
}
pgConn, errs := connectPreferred(ctx, config, connectConfigs)
if len(errs) > 0 {
allErrors = append(allErrors, errs...)
return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)}
}
if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)}
}
}
return pgConn, nil
}
// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a
// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain
// values if some hosts were successfully resolved and others were not.
func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) {
// Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{
{
@ -149,95 +183,28 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
},
}
fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
}
if len(fallbackConfigs) == 0 {
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
var configs []*connectOneConfig
foundBestServer := false
var fallbackConfig *FallbackConfig
for i, fc := range fallbackConfigs {
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
// create new context first time or when previous host was different
if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
defer cancel()
}
} else {
ctx = octx
}
pgConn, err = connect(ctx, config, fc, false)
if err == nil {
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
if pgerr.Code == ERRCODE_INVALID_PASSWORD ||
pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil ||
pgerr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
}
}
var allErrors []error
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
}
}
if err != nil {
return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
}
if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
}
}
return pgConn, nil
}
func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
var configs []*FallbackConfig
var lookupErrors []error
for _, fb := range fallbacks {
for _, fb := range fallbackConfigs {
// skip resolve for unix sockets
if isAbsolutePath(fb.Host) {
configs = append(configs, &FallbackConfig{
Host: fb.Host,
Port: fb.Port,
TLSConfig: fb.TLSConfig,
network, address := NetworkAddress(fb.Host, fb.Port)
configs = append(configs, &connectOneConfig{
network: network,
address: address,
originalHostname: fb.Host,
tlsConfig: fb.TLSConfig,
})
continue
}
ips, err := lookupFn(ctx, fb.Host)
ips, err := config.LookupFunc(ctx, fb.Host)
if err != nil {
lookupErrors = append(lookupErrors, err)
allErrors = append(allErrors, err)
continue
}
@ -246,63 +213,126 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
if err == nil {
port, err := strconv.ParseUint(splitPort, 10, 16)
if err != nil {
return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)
return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)}
}
configs = append(configs, &FallbackConfig{
Host: splitIP,
Port: uint16(port),
TLSConfig: fb.TLSConfig,
network, address := NetworkAddress(splitIP, uint16(port))
configs = append(configs, &connectOneConfig{
network: network,
address: address,
originalHostname: fb.Host,
tlsConfig: fb.TLSConfig,
})
} else {
configs = append(configs, &FallbackConfig{
Host: ip,
Port: fb.Port,
TLSConfig: fb.TLSConfig,
network, address := NetworkAddress(ip, fb.Port)
configs = append(configs, &connectOneConfig{
network: network,
address: address,
originalHostname: fb.Host,
tlsConfig: fb.TLSConfig,
})
}
}
}
// See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all
// errors are reported.
if len(configs) == 0 && len(lookupErrors) > 0 {
return nil, lookupErrors[0]
}
return configs, nil
return configs, allErrors
}
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in
// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If
// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful.
func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) {
octx := ctx
var allErrors []error
var fallbackConnectOneConfig *connectOneConfig
for i, c := range connectOneConfigs {
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
// create new context first time or when previous host was different
if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
defer cancel()
}
} else {
ctx = octx
}
pgConn, err := connectOne(ctx, config, c, false)
if pgConn != nil {
return pgConn, nil
}
allErrors = append(allErrors, err)
var pgErr *PgError
if errors.As(err, &pgErr) {
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
return nil, allErrors
}
}
var npErr *NotPreferredError
if errors.As(err, &npErr) {
fallbackConnectOneConfig = c
}
}
if fallbackConnectOneConfig != nil {
pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true)
if err == nil {
return pgConn, nil
}
allErrors = append(allErrors, err)
}
return nil, allErrors
}
// connectOne makes one connection attempt to a single host.
func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig,
ignoreNotPreferredErr bool,
) (*PgConn, error) {
pgConn := new(PgConn)
pgConn.config = config
pgConn.cleanupDone = make(chan struct{})
pgConn.customData = make(map[string]any)
var err error
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address)
if err != nil {
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
newPerDialConnectError := func(msg string, err error) *perDialConnectError {
err = normalizeTimeoutError(ctx, err)
e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)}
return e
}
pgConn.conn = netConn
pgConn.contextWatcher = newContextWatcher(netConn)
pgConn.contextWatcher.Watch(ctx)
pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address)
if err != nil {
return nil, newPerDialConnectError("dial error", err)
}
if fallbackConfig.TLSConfig != nil {
nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
if connectConfig.tlsConfig != nil {
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
pgConn.contextWatcher.Watch(ctx)
tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
pgConn.conn.Close()
return nil, newPerDialConnectError("tls error", err)
}
pgConn.conn = nbTLSConn
pgConn.contextWatcher = newContextWatcher(nbTLSConn)
pgConn.contextWatcher.Watch(ctx)
pgConn.conn = tlsConn
}
pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn))
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
pgConn.parameterStatuses = make(map[string]string)
@ -336,7 +366,7 @@ func() {
pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
return nil, newPerDialConnectError("failed to write startup message", err)
}
for {
@ -344,9 +374,9 @@ func() {
if err != nil {
pgConn.conn.Close()
if err, ok := err.(*PgError); ok {
return nil, err
return nil, newPerDialConnectError("server error", err)
}
return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
return nil, newPerDialConnectError("failed to receive message", err)
}
switch msg := msg.(type) {
@ -359,26 +389,26 @@ func() {
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
return nil, newPerDialConnectError("failed to write password message", err)
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
return nil, newPerDialConnectError("failed to write password message", err)
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
return nil, newPerDialConnectError("failed SASL auth", err)
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
return nil, newPerDialConnectError("failed GSS auth", err)
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
@ -396,7 +426,7 @@ func() {
return pgConn, nil
}
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
return nil, newPerDialConnectError("ValidateConnect failed", err)
}
}
return pgConn, nil
@ -404,21 +434,14 @@ func() {
// handled by ReceiveMessage
case *pgproto3.ErrorResponse:
pgConn.conn.Close()
return nil, ErrorResponseToPgError(msg)
return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg))
default:
pgConn.conn.Close()
return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
return nil, newPerDialConnectError("received unexpected message", err)
}
}
}
func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
return ctxwatch.NewContextWatcher(
func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { conn.SetDeadline(time.Time{}) },
)
}
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
@ -928,23 +951,24 @@ func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{
Severity: msg.Severity,
Code: string(msg.Code),
Message: string(msg.Message),
Detail: string(msg.Detail),
Hint: msg.Hint,
Position: msg.Position,
InternalPosition: msg.InternalPosition,
InternalQuery: string(msg.InternalQuery),
Where: string(msg.Where),
SchemaName: string(msg.SchemaName),
TableName: string(msg.TableName),
ColumnName: string(msg.ColumnName),
DataTypeName: string(msg.DataTypeName),
ConstraintName: msg.ConstraintName,
File: string(msg.File),
Line: msg.Line,
Routine: string(msg.Routine),
Severity: msg.Severity,
SeverityUnlocalized: msg.SeverityUnlocalized,
Code: string(msg.Code),
Message: string(msg.Message),
Detail: string(msg.Detail),
Hint: msg.Hint,
Position: msg.Position,
InternalPosition: msg.InternalPosition,
InternalQuery: string(msg.InternalQuery),
Where: string(msg.Where),
SchemaName: string(msg.SchemaName),
TableName: string(msg.TableName),
ColumnName: string(msg.ColumnName),
DataTypeName: string(msg.DataTypeName),
ConstraintName: msg.ConstraintName,
File: string(msg.File),
Line: msg.Line,
Routine: string(msg.Routine),
}
}
@ -987,10 +1011,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
defer cancelConn.Close()
if ctx != context.Background() {
contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) },
)
contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn})
contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch()
}
@ -1523,8 +1544,10 @@ func (rr *ResultReader) Read() *Result {
values := rr.Values()
row := make([][]byte, len(values))
for i := range row {
row[i] = make([]byte, len(values[i]))
copy(row[i], values[i])
if values[i] != nil {
row[i] = make([]byte, len(values[i]))
copy(row[i], values[i])
}
}
br.Rows = append(br.Rows, row)
}
@ -1879,6 +1902,11 @@ func (pgConn *PgConn) SyncConn(ctx context.Context) error {
return errors.New("SyncConn: conn never synchronized")
}
// CustomData returns a map that can be used to associate custom data with the connection.
func (pgConn *PgConn) CustomData() map[string]any {
return pgConn.customData
}
// HijackedConn is the result of hijacking a connection.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
@ -1891,6 +1919,7 @@ type HijackedConn struct {
TxStatus byte
Frontend *pgproto3.Frontend
Config *Config
CustomData map[string]any
}
// Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately
@ -1913,6 +1942,7 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
TxStatus: pgConn.txStatus,
Frontend: pgConn.frontend,
Config: pgConn.config,
CustomData: pgConn.customData,
}, nil
}
@ -1932,13 +1962,14 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
txStatus: hc.TxStatus,
frontend: hc.Frontend,
config: hc.Config,
customData: hc.CustomData,
status: connStatusIdle,
cleanupDone: make(chan struct{}),
}
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn))
pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() {
@ -2245,3 +2276,71 @@ func (p *Pipeline) Close() error {
return p.err
}
// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn.
type DeadlineContextWatcherHandler struct {
Conn net.Conn
// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
DeadlineDelay time.Duration
}
func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay))
}
func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() {
h.Conn.SetDeadline(time.Time{})
}
// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets
// a deadline on a net.Conn as a fallback.
type CancelRequestContextWatcherHandler struct {
Conn *PgConn
// CancelRequestDelay is the delay before sending the cancel request to the server.
CancelRequestDelay time.Duration
// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
DeadlineDelay time.Duration
cancelFinishedChan chan struct{}
handleUnwatchAfterCancelCalled func()
}
func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) {
h.cancelFinishedChan = make(chan struct{})
var handleUnwatchedAfterCancelCalledCtx context.Context
handleUnwatchedAfterCancelCalledCtx, h.handleUnwatchAfterCancelCalled = context.WithCancel(context.Background())
deadline := time.Now().Add(h.DeadlineDelay)
h.Conn.conn.SetDeadline(deadline)
go func() {
defer close(h.cancelFinishedChan)
select {
case <-handleUnwatchedAfterCancelCalledCtx.Done():
return
case <-time.After(h.CancelRequestDelay):
}
cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline)
defer cancel()
h.Conn.CancelRequest(cancelRequestCtx)
// CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point,
// it hasn't necessarily been delivered to the other connection. If we immediately return and the connection is
// immediately used then it is possible the CancelRequest will actually cancel our next query. The
// TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time
// is arbitrary, but should be sufficient to prevent this error case.
time.Sleep(100 * time.Millisecond)
}()
}
func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() {
h.handleUnwatchAfterCancelCalled()
<-h.cancelFinishedChan
h.Conn.conn.SetDeadline(time.Time{})
}

View file

@ -99,7 +99,7 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
return nil, errors.New("unknown protocol representation")
}
// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
// beginMessage begins a new message of type t. It appends the message type and a placeholder for the message length to
// dst. It returns the new buffer and the position of the message length placeholder.
func beginMessage(dst []byte, t byte) ([]byte, int) {
dst = append(dst, t)

View file

@ -6,7 +6,6 @@
"fmt"
"reflect"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan
// target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the
// scan of the elements.
if anynil.Is(target) {
if isNil, _ := isNilDriverValuer(target); isNil {
arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter)
}

View file

@ -139,6 +139,16 @@ func RegisterDataTypes(ctx context.Context, conn *pgx.Conn) error {
pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer
interfaces.
Encoding Typed Nils
pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec
system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil).
However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore,
driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See
https://github.com/golang/go/issues/8415 and
https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870.
Child Records
pgtype's support for arrays and composite records can be used to load records and their children in a single query. See

View file

@ -132,22 +132,31 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte,
if interval.Days != 0 {
buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...)
buf = append(buf, " day "...)
buf = append(buf, " day"...)
}
absMicroseconds := interval.Microseconds
if absMicroseconds < 0 {
absMicroseconds = -absMicroseconds
buf = append(buf, '-')
if interval.Microseconds != 0 {
buf = append(buf, " "...)
absMicroseconds := interval.Microseconds
if absMicroseconds < 0 {
absMicroseconds = -absMicroseconds
buf = append(buf, '-')
}
hours := absMicroseconds / microsecondsPerHour
minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond
timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds)
buf = append(buf, timeStr...)
microseconds := absMicroseconds % microsecondsPerSecond
if microseconds != 0 {
buf = append(buf, fmt.Sprintf(".%06d", microseconds)...)
}
}
hours := absMicroseconds / microsecondsPerHour
minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute
seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond
microseconds := absMicroseconds % microsecondsPerSecond
timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds)
buf = append(buf, timeStr...)
return buf, nil
}

View file

@ -8,17 +8,20 @@
"reflect"
)
type JSONCodec struct{}
type JSONCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
func (JSONCodec) FormatSupported(format int16) bool {
func (*JSONCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (JSONCodec) PreferredFormat() int16 {
func (*JSONCodec) PreferredFormat() int16 {
return TextFormatCode
}
func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch value.(type) {
case string:
return encodePlanJSONCodecEitherFormatString{}
@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
//
// https://github.com/jackc/pgx/issues/1681
case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
}
}
return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}
type encodePlanJSONCodecEitherFormatString struct{}
@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
return buf, nil
}
type encodePlanJSONCodecEitherFormatMarshal struct{}
type encodePlanJSONCodecEitherFormatMarshal struct {
marshal func(v any) ([]byte, error)
}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := json.Marshal(value)
func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := e.marshal(value)
if err != nil {
return nil, err
}
@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
return buf, nil
}
func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch target.(type) {
case *string:
return scanPlanAnyToString{}
@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
return &scanPlanSQLScanner{formatCode: format}
}
return scanPlanJSONToJSONUnmarshal{}
return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
}
}
type scanPlanAnyToString struct{}
@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
return scanner.ScanBytes(src)
}
type scanPlanJSONToJSONUnmarshal struct{}
type scanPlanJSONToJSONUnmarshal struct {
unmarshal func(data []byte, v any) error
}
func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if src == nil {
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr {
@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
elem := reflect.ValueOf(dst).Elem()
elem.Set(reflect.Zero(elem.Type()))
return json.Unmarshal(src, dst)
return s.unmarshal(src, dst)
}
func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
return dstBuf, nil
}
func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}

View file

@ -2,29 +2,31 @@
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
type JSONBCodec struct{}
type JSONBCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}
func (JSONBCodec) FormatSupported(format int16) bool {
func (*JSONBCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (JSONBCodec) PreferredFormat() int16 {
func (*JSONBCodec) PreferredFormat() int16 {
return TextFormatCode
}
func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
if plan != nil {
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanEncode(m, oid, format, value)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
}
return nil
@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
return plan.textPlan.Encode(value, buf)
}
func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
if plan != nil {
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanScan(m, oid, format, target)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
}
return nil
@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
return plan.textPlan.Scan(src[1:], dst)
}
func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
}
}
func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
}
var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}

View file

@ -41,6 +41,7 @@
CircleOID = 718
CircleArrayOID = 719
UnknownOID = 705
Macaddr8OID = 774
MacaddrOID = 829
InetOID = 869
BoolArrayOID = 1000
@ -1330,7 +1331,7 @@ func (plan *derefPointerEncodePlan) Encode(value any, buf []byte) (newBuf []byte
}
// TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan
// would be returned that derefences the value.
// would be returned that dereferences the value.
func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
if _, ok := value.(driver.Valuer); ok {
return nil, nil, false
@ -1911,8 +1912,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error)
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
// written.
func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) {
if value == nil {
return nil, nil
if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil {
if callNilDriverValuer {
newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf)
if err != nil {
return nil, newEncodeError(value, m, oid, formatCode, err)
}
return newBuf, nil
} else {
return nil, nil
}
}
plan := m.PlanEncode(oid, formatCode, value)
@ -1967,3 +1977,55 @@ func (w *sqlScannerWrapper) Scan(src any) error {
return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v)
}
// canBeNil returns true if value can be nil.
func canBeNil(value any) bool {
refVal := reflect.ValueOf(value)
kind := refVal.Kind()
switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return true
default:
return false
}
}
// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil
// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual
// type. Yuck.
//
// This can be simplified in Go 1.22 with reflect.TypeFor.
//
// var valuerReflectType = reflect.TypeFor[driver.Valuer]()
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
// driver.Valuer if it is only implemented by T.
func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) {
if value == nil {
return true, false
}
refVal := reflect.ValueOf(value)
kind := refVal.Kind()
switch kind {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
if !refVal.IsNil() {
return false, false
}
if _, ok := value.(driver.Valuer); ok {
if kind == reflect.Ptr {
// The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T
// by checking if it is not implemented on *T.
return true, !refVal.Type().Elem().Implements(valuerReflectType)
} else {
return true, true
}
}
return true, false
default:
return false, false
}
}

View file

@ -65,11 +65,12 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}})
defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
@ -81,8 +82,8 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}})
defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}})
defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})

View file

@ -45,7 +45,12 @@ func (t *Time) Scan(src any) error {
switch src := src.(type) {
case string:
return scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t)
err := scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t)
if err != nil {
t.Microseconds = 0
t.Valid = false
}
return err
}
return fmt.Errorf("cannot scan %T", src)
@ -136,6 +141,8 @@ func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
switch target.(type) {
case TimeScanner:
return scanPlanBinaryTimeToTimeScanner{}
case TextScanner:
return scanPlanBinaryTimeToTextScanner{}
}
case TextFormatCode:
switch target.(type) {
@ -165,6 +172,34 @@ func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst any) error {
return scanner.ScanTime(Time{Microseconds: usec, Valid: true})
}
type scanPlanBinaryTimeToTextScanner struct{}
func (scanPlanBinaryTimeToTextScanner) Scan(src []byte, dst any) error {
ts, ok := (dst).(TextScanner)
if !ok {
return ErrScanTargetTypeChanged
}
if src == nil {
return ts.ScanText(Text{})
}
if len(src) != 8 {
return fmt.Errorf("invalid length for time: %v", len(src))
}
usec := int64(binary.BigEndian.Uint64(src))
tim := Time{Microseconds: usec, Valid: true}
buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, tim).Encode(tim, nil)
if err != nil {
return err
}
return ts.ScanText(Text{String: string(buf), Valid: true})
}
type scanPlanTextAnyToTimeScanner struct{}
func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error {
@ -176,7 +211,7 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error {
s := string(src)
if len(s) < 8 {
if len(s) < 8 || s[2] != ':' || s[5] != ':' {
return fmt.Errorf("cannot decode %v into Time", s)
}
@ -199,6 +234,10 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error {
usec += seconds * microsecondsPerSecond
if len(s) > 9 {
if s[8] != '.' || len(s) > 15 {
return fmt.Errorf("cannot decode %v into Time", s)
}
fraction := s[9:]
n, err := strconv.ParseInt(fraction, 10, 64)
if err != nil {

View file

@ -46,7 +46,7 @@ func (ts *Timestamp) Scan(src any) error {
switch src := src.(type) {
case string:
return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts)
return (&scanPlanTextTimestampToTimestampScanner{}).Scan([]byte(src), ts)
case time.Time:
*ts = Timestamp{Time: src, Valid: true}
return nil
@ -116,17 +116,21 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error {
return nil
}
type TimestampCodec struct{}
type TimestampCodec struct {
// ScanLocation is the location that the time is assumed to be in for scanning. This is different from
// TimestamptzCodec.ScanLocation in that this setting does change the instant in time that the timestamp represents.
ScanLocation *time.Location
}
func (TimestampCodec) FormatSupported(format int16) bool {
func (*TimestampCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (TimestampCodec) PreferredFormat() int16 {
func (*TimestampCodec) PreferredFormat() int16 {
return BinaryFormatCode
}
func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (*TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(TimestampValuer); !ok {
return nil
}
@ -220,27 +224,27 @@ func discardTimeZone(t time.Time) time.Time {
return t
}
func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
switch target.(type) {
case TimestampScanner:
return scanPlanBinaryTimestampToTimestampScanner{}
return &scanPlanBinaryTimestampToTimestampScanner{location: c.ScanLocation}
}
case TextFormatCode:
switch target.(type) {
case TimestampScanner:
return scanPlanTextTimestampToTimestampScanner{}
return &scanPlanTextTimestampToTimestampScanner{location: c.ScanLocation}
}
}
return nil
}
type scanPlanBinaryTimestampToTimestampScanner struct{}
type scanPlanBinaryTimestampToTimestampScanner struct{ location *time.Location }
func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error {
func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestampScanner)
if src == nil {
@ -264,15 +268,18 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error
microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000,
(microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000),
).UTC()
if plan.location != nil {
tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location)
}
ts = Timestamp{Time: tim, Valid: true}
}
return scanner.ScanTimestamp(ts)
}
type scanPlanTextTimestampToTimestampScanner struct{}
type scanPlanTextTimestampToTimestampScanner struct{ location *time.Location }
func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error {
func (plan *scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestampScanner)
if src == nil {
@ -302,13 +309,17 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error {
tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location())
}
if plan.location != nil {
tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location)
}
ts = Timestamp{Time: tim, Valid: true}
}
return scanner.ScanTimestamp(ts)
}
func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
@ -326,7 +337,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16,
return ts.Time, nil
}
func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}

View file

@ -54,7 +54,7 @@ func (tstz *Timestamptz) Scan(src any) error {
switch src := src.(type) {
case string:
return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz)
return (&scanPlanTextTimestamptzToTimestamptzScanner{}).Scan([]byte(src), tstz)
case time.Time:
*tstz = Timestamptz{Time: src, Valid: true}
return nil
@ -124,17 +124,21 @@ func (tstz *Timestamptz) UnmarshalJSON(b []byte) error {
return nil
}
type TimestamptzCodec struct{}
type TimestamptzCodec struct {
// ScanLocation is the location to return scanned timestamptz values in. This does not change the instant in time that
// the timestamptz represents.
ScanLocation *time.Location
}
func (TimestamptzCodec) FormatSupported(format int16) bool {
func (*TimestamptzCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
func (TimestamptzCodec) PreferredFormat() int16 {
func (*TimestamptzCodec) PreferredFormat() int16 {
return BinaryFormatCode
}
func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (*TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
if _, ok := value.(TimestamptzValuer); !ok {
return nil
}
@ -220,27 +224,27 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by
return buf, nil
}
func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
switch target.(type) {
case TimestamptzScanner:
return scanPlanBinaryTimestamptzToTimestamptzScanner{}
return &scanPlanBinaryTimestamptzToTimestamptzScanner{location: c.ScanLocation}
}
case TextFormatCode:
switch target.(type) {
case TimestamptzScanner:
return scanPlanTextTimestamptzToTimestamptzScanner{}
return &scanPlanTextTimestamptzToTimestamptzScanner{location: c.ScanLocation}
}
}
return nil
}
type scanPlanBinaryTimestamptzToTimestamptzScanner struct{}
type scanPlanBinaryTimestamptzToTimestamptzScanner struct{ location *time.Location }
func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestamptzScanner)
if src == nil {
@ -264,15 +268,18 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) e
microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000,
(microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000),
)
if plan.location != nil {
tim = tim.In(plan.location)
}
tstz = Timestamptz{Time: tim, Valid: true}
}
return scanner.ScanTimestamptz(tstz)
}
type scanPlanTextTimestamptzToTimestamptzScanner struct{}
type scanPlanTextTimestamptzToTimestamptzScanner struct{ location *time.Location }
func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
func (plan *scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
scanner := (dst).(TimestamptzScanner)
if src == nil {
@ -312,13 +319,17 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) err
tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location())
}
if plan.location != nil {
tim = tim.In(plan.location)
}
tstz = Timestamptz{Time: tim, Valid: true}
}
return scanner.ScanTimestamptz(tstz)
}
func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
@ -336,7 +347,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1
return tstz.Time, nil
}
func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}

View file

@ -26,6 +26,10 @@ func (c *Conn) Release() {
res := c.res
c.res = nil
if c.p.releaseTracer != nil {
c.p.releaseTracer.TraceRelease(c.p, TraceReleaseData{Conn: conn})
}
if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' {
res.Destroy()
// Signal to the health check to run since we just destroyed a connections

View file

@ -8,7 +8,7 @@
pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be
The database connection string can be in URL or keyword/value format. PostgreSQL settings, pgx settings, and pool settings can be
specified here. In addition, a config struct can be created by [ParseConfig].
config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL"))

View file

@ -95,6 +95,9 @@ type Pool struct {
healthCheckChan chan struct{}
acquireTracer AcquireTracer
releaseTracer ReleaseTracer
closeOnce sync.Once
closeChan chan struct{}
}
@ -195,6 +198,14 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
closeChan: make(chan struct{}),
}
if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok {
p.acquireTracer = t
}
if t, ok := config.ConnConfig.Tracer.(ReleaseTracer); ok {
p.releaseTracer = t
}
var err error
p.p, err = puddle.NewPool(
&puddle.Config[*connResource]{
@ -279,7 +290,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
//
// See Config for definitions of these arguments.
//
// # Example DSN
// # Example Keyword/Value
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10
//
// # Example URL
@ -498,7 +509,18 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
}
// Acquire returns a connection (*Conn) from the Pool
func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
if p.acquireTracer != nil {
ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{})
defer func() {
var conn *pgx.Conn
if c != nil {
conn = c.Conn()
}
p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err})
}()
}
for {
res, err := p.p.Acquire(ctx)
if err != nil {

33
vendor/github.com/jackc/pgx/v5/pgxpool/tracer.go generated vendored Normal file
View file

@ -0,0 +1,33 @@
package pgxpool
import (
"context"
"github.com/jackc/pgx/v5"
)
// AcquireTracer traces Acquire.
type AcquireTracer interface {
// TraceAcquireStart is called at the beginning of Acquire.
// The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd.
TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context
// TraceAcquireEnd is called when a connection has been acquired.
TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData)
}
type TraceAcquireStartData struct{}
type TraceAcquireEndData struct {
Conn *pgx.Conn
Err error
}
// ReleaseTracer traces Release.
type ReleaseTracer interface {
// TraceRelease is called at the beginning of Release.
TraceRelease(pool *Pool, data TraceReleaseData)
}
type TraceReleaseData struct {
Conn *pgx.Conn
}

View file

@ -6,6 +6,7 @@
"fmt"
"reflect"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgconn"
@ -418,6 +419,8 @@ type CollectableRow interface {
type RowToFunc[T any] func(row CollectableRow) (T, error)
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
//
// This function closes the rows automatically on return.
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
defer rows.Close()
@ -437,12 +440,16 @@ func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
}
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
//
// This function closes the rows automatically on return.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return AppendRows([]T{}, rows, fn)
}
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// CollectOneRow is to CollectRows as QueryRow is to Query.
//
// This function closes the rows automatically on return.
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close()
@ -468,6 +475,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
// CollectExactlyOneRow calls fn for the first row in rows and returns the result.
// - If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true.
//
// This function closes the rows automatically on return.
func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
defer rows.Close()
@ -541,7 +550,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
// ignored.
func RowToStructByPos[T any](row CollectableRow) (T, error) {
var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err
}
@ -550,7 +559,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
// the field will be ignored.
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err
}
@ -558,46 +567,60 @@ type positionalStructRowScanner struct {
ptrToStruct any
}
func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
dst := rs.ptrToStruct
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() != reflect.Ptr {
return fmt.Errorf("dst not a pointer")
func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
fields := lookupStructFields(typ)
if len(rows.RawValues()) > len(fields) {
return fmt.Errorf(
"got %d values, but dst struct has only %d fields",
len(rows.RawValues()),
len(fields),
)
}
dstElemValue := dstValue.Elem()
scanTargets := rs.appendScanTargets(dstElemValue, nil)
if len(rows.RawValues()) > len(scanTargets) {
return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
}
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
return rows.Scan(scanTargets...)
}
func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
dstElemType := dstElemValue.Type()
// Map from reflect.Type -> []structRowField
var positionalStructFieldMap sync.Map
if scanTargets == nil {
scanTargets = make([]any, 0, dstElemType.NumField())
func lookupStructFields(t reflect.Type) []structRowField {
if cached, ok := positionalStructFieldMap.Load(t); ok {
return cached.([]structRowField)
}
for i := 0; i < dstElemType.NumField(); i++ {
sf := dstElemType.Field(i)
fieldStack := make([]int, 0, 1)
fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
return fieldsIface.([]structRowField)
}
func computeStructFields(
t reflect.Type,
fields []structRowField,
fieldStack *[]int,
) []structRowField {
tail := len(*fieldStack)
*fieldStack = append(*fieldStack, 0)
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
(*fieldStack)[tail] = i
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
fields = computeStructFields(sf.Type, fields, fieldStack)
} else if sf.PkgPath == "" {
dbTag, _ := sf.Tag.Lookup(structTagKey)
if dbTag == "-" {
// Field is ignored, skip it.
continue
}
scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
fields = append(fields, structRowField{
path: append([]int(nil), *fieldStack...),
})
}
}
return scanTargets
*fieldStack = (*fieldStack)[:tail]
return fields
}
// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
@ -605,7 +628,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByName[T any](row CollectableRow) (T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return value, err
}
@ -615,7 +638,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) {
// then the field will be ignored.
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
return &value, err
}
@ -624,7 +647,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return value, err
}
@ -634,7 +657,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
// then the field will be ignored.
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
return &value, err
}
@ -643,26 +666,152 @@ type namedStructRowScanner struct {
lax bool
}
func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
dst := rs.ptrToStruct
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() != reflect.Ptr {
return fmt.Errorf("dst not a pointer")
}
dstElemValue := dstValue.Elem()
scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
fldDescs := rows.FieldDescriptions()
namedStructFields, err := lookupNamedStructFields(typ, fldDescs)
if err != nil {
return err
}
if !rs.lax && namedStructFields.missingField != "" {
return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
}
fields := namedStructFields.fields
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
return rows.Scan(scanTargets...)
}
for i, t := range scanTargets {
if t == nil {
return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
// Map from namedStructFieldMap -> *namedStructFields
var namedStructFieldMap sync.Map
type namedStructFieldsKey struct {
t reflect.Type
colNames string
}
type namedStructFields struct {
fields []structRowField
// missingField is the first field from the struct without a corresponding row field.
// This is used to construct the correct error message for non-lax queries.
missingField string
}
func lookupNamedStructFields(
t reflect.Type,
fldDescs []pgconn.FieldDescription,
) (*namedStructFields, error) {
key := namedStructFieldsKey{
t: t,
colNames: joinFieldNames(fldDescs),
}
if cached, ok := namedStructFieldMap.Load(key); ok {
return cached.(*namedStructFields), nil
}
// We could probably do two-levels of caching, where we compute the key -> fields mapping
// for a type only once, cache it by type, then use that to compute the column -> fields
// mapping for a given set of columns.
fieldStack := make([]int, 0, 1)
fields, missingField := computeNamedStructFields(
fldDescs,
t,
make([]structRowField, len(fldDescs)),
&fieldStack,
)
for i, f := range fields {
if f.path == nil {
return nil, fmt.Errorf(
"struct doesn't have corresponding row field %s",
fldDescs[i].Name,
)
}
}
return rows.Scan(scanTargets...)
fieldsIface, _ := namedStructFieldMap.LoadOrStore(
key,
&namedStructFields{fields: fields, missingField: missingField},
)
return fieldsIface.(*namedStructFields), nil
}
func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
switch len(fldDescs) {
case 0:
return ""
case 1:
return fldDescs[0].Name
}
totalSize := len(fldDescs) - 1 // Space for separator bytes.
for _, d := range fldDescs {
totalSize += len(d.Name)
}
var b strings.Builder
b.Grow(totalSize)
b.WriteString(fldDescs[0].Name)
for _, d := range fldDescs[1:] {
b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
b.WriteString(d.Name)
}
return b.String()
}
func computeNamedStructFields(
fldDescs []pgconn.FieldDescription,
t reflect.Type,
fields []structRowField,
fieldStack *[]int,
) ([]structRowField, string) {
var missingField string
tail := len(*fieldStack)
*fieldStack = append(*fieldStack, 0)
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
(*fieldStack)[tail] = i
if sf.PkgPath != "" && !sf.Anonymous {
// Field is unexported, skip it.
continue
}
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
var missingSubField string
fields, missingSubField = computeNamedStructFields(
fldDescs,
sf.Type,
fields,
fieldStack,
)
if missingField == "" {
missingField = missingSubField
}
} else {
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
if dbTagPresent {
dbTag, _, _ = strings.Cut(dbTag, ",")
}
if dbTag == "-" {
// Field is ignored, skip it.
continue
}
colName := dbTag
if !dbTagPresent {
colName = sf.Name
}
fpos := fieldPosByName(fldDescs, colName)
if fpos == -1 {
if missingField == "" {
missingField = colName
}
continue
}
fields[fpos] = structRowField{
path: append([]int(nil), *fieldStack...),
}
}
}
*fieldStack = (*fieldStack)[:tail]
return fields, missingField
}
const structTagKey = "db"
@ -682,52 +831,21 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
return
}
func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
var err error
dstElemType := dstElemValue.Type()
if scanTargets == nil {
scanTargets = make([]any, len(fldDescs))
}
for i := 0; i < dstElemType.NumField(); i++ {
sf := dstElemType.Field(i)
if sf.PkgPath != "" && !sf.Anonymous {
// Field is unexported, skip it.
continue
}
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
if err != nil {
return nil, err
}
} else {
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
if dbTagPresent {
dbTag, _, _ = strings.Cut(dbTag, ",")
}
if dbTag == "-" {
// Field is ignored, skip it.
continue
}
colName := dbTag
if !dbTagPresent {
colName = sf.Name
}
fpos := fieldPosByName(fldDescs, colName)
if fpos == -1 {
if rs.lax {
continue
}
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
}
if fpos >= len(scanTargets) && !rs.lax {
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
}
scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
}
}
return scanTargets, err
// structRowField describes a field of a struct.
//
// TODO: It would be a bit more efficient to track the path using the pointer
// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
// construct references when scanning rows. However, it's not clear it's worth
// using unsafe for this.
type structRowField struct {
path []int
}
func setupStructScanTargets(receiver any, fields []structRowField) []any {
scanTargets := make([]any, len(fields))
v := reflect.ValueOf(receiver).Elem()
for i, f := range fields {
scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
}
return scanTargets
}

View file

@ -7,7 +7,7 @@
// return err
// }
//
// Or from a DSN string.
// Or from a keyword/value string.
//
// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
// if err != nil {

View file

@ -3,7 +3,6 @@
import (
"errors"
"github.com/jackc/pgx/v5/internal/anynil"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgtype"
)
@ -15,10 +14,6 @@
)
func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
if anynil.Is(arg) {
return nil, nil
}
buf, err := m.Encode(0, TextFormatCode, arg, []byte{})
if err != nil {
return nil, err
@ -30,10 +25,6 @@ func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) {
}
func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) {
if anynil.Is(arg) {
return pgio.AppendInt32(buf, -1), nil
}
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)

7
vendor/modules.txt vendored
View file

@ -421,17 +421,16 @@ github.com/jackc/pgpassfile
# github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a
## explicit; go 1.14
github.com/jackc/pgservicefile
# github.com/jackc/pgx/v5 v5.5.5
## explicit; go 1.19
# github.com/jackc/pgx/v5 v5.6.0
## explicit; go 1.20
github.com/jackc/pgx/v5
github.com/jackc/pgx/v5/internal/anynil
github.com/jackc/pgx/v5/internal/iobufpool
github.com/jackc/pgx/v5/internal/pgio
github.com/jackc/pgx/v5/internal/sanitize
github.com/jackc/pgx/v5/internal/stmtcache
github.com/jackc/pgx/v5/pgconn
github.com/jackc/pgx/v5/pgconn/ctxwatch
github.com/jackc/pgx/v5/pgconn/internal/bgreader
github.com/jackc/pgx/v5/pgconn/internal/ctxwatch
github.com/jackc/pgx/v5/pgproto3
github.com/jackc/pgx/v5/pgtype
github.com/jackc/pgx/v5/pgxpool