package pgx import ( "context" "errors" "fmt" "reflect" "strings" "time" "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" ) // Rows is the result set returned from *Conn.Query. Rows must be closed before // the *Conn can be used again. Rows are closed by explicitly calling Close(), // calling Next() until it returns false, or when a fatal error occurs. // // Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag(). // // Rows is an interface instead of a struct to allow tests to mock Query. However, // adding a method to an interface is technically a breaking change. Because of this // the Rows interface is partially excluded from semantic version requirements. // Methods will not be removed or changed, but new methods may be added. type Rows interface { // Close closes the rows, making the connection ready for use again. It is safe // to call Close after rows is already closed. Close() // Err returns any error that occurred while reading. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. CommandTag() pgconn.CommandTag FieldDescriptions() []pgconn.FieldDescription // Next prepares the next row for reading. It returns true if there is another // row and false if no more rows are available. It automatically closes rows // when all rows are read. Next() bool // Scan reads the values from the current row into dest values positionally. // dest can include pointers to core types, values implementing the Scanner // interface, and nil. nil will skip the value entirely. It is an error to // call Scan without first calling Next() and checking that it returned true. Scan(dest ...any) error // Values returns the decoded row values. As with Scan(), it is an error to // call Values without first calling Next() and checking that it returned // true. Values() ([]any, error) // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next // call or the Rows is closed. RawValues() [][]byte // Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a // *Conn (e.g. if it was created by RowsFromResultReader) Conn() *Conn } // Row is a convenience wrapper over Rows that is returned by QueryRow. // // Row is an interface instead of a struct to allow tests to mock QueryRow. However, // adding a method to an interface is technically a breaking change. Because of this // the Row interface is partially excluded from semantic version requirements. // Methods will not be removed or changed, but new methods may be added. type Row interface { // Scan works the same as Rows. with the following exceptions. If no // rows were found it returns ErrNoRows. If multiple rows are returned it // ignores all but the first. Scan(dest ...any) error } // RowScanner scans an entire row at a time into the RowScanner. type RowScanner interface { // ScanRows scans the row. ScanRow(rows Rows) error } // connRow implements the Row interface for Conn.QueryRow. type connRow baseRows func (r *connRow) Scan(dest ...any) (err error) { rows := (*baseRows)(r) if rows.Err() != nil { return rows.Err() } for _, d := range dest { if _, ok := d.(*pgtype.DriverBytes); ok { rows.Close() return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow") } } if !rows.Next() { if rows.Err() == nil { return ErrNoRows } return rows.Err() } rows.Scan(dest...) rows.Close() return rows.Err() } // baseRows implements the Rows interface for Conn.Query. type baseRows struct { typeMap *pgtype.Map resultReader *pgconn.ResultReader values [][]byte commandTag pgconn.CommandTag err error closed bool scanPlans []pgtype.ScanPlan scanTypes []reflect.Type conn *Conn multiResultReader *pgconn.MultiResultReader queryTracer QueryTracer batchTracer BatchTracer ctx context.Context startTime time.Time sql string args []any rowCount int } func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription { return rows.resultReader.FieldDescriptions() } func (rows *baseRows) Close() { if rows.closed { return } rows.closed = true if rows.resultReader != nil { var closeErr error rows.commandTag, closeErr = rows.resultReader.Close() if rows.err == nil { rows.err = closeErr } } if rows.multiResultReader != nil { closeErr := rows.multiResultReader.Close() if rows.err == nil { rows.err = closeErr } } if rows.err != nil && rows.conn != nil && rows.sql != "" { if stmtcache.IsStatementInvalid(rows.err) { if sc := rows.conn.statementCache; sc != nil { sc.Invalidate(rows.sql) } if sc := rows.conn.descriptionCache; sc != nil { sc.Invalidate(rows.sql) } } } if rows.batchTracer != nil { rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err}) } else if rows.queryTracer != nil { rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) } } func (rows *baseRows) CommandTag() pgconn.CommandTag { return rows.commandTag } func (rows *baseRows) Err() error { return rows.err } // fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. func (rows *baseRows) fatal(err error) { if rows.err != nil { return } rows.err = err rows.Close() } func (rows *baseRows) Next() bool { if rows.closed { return false } if rows.resultReader.NextRow() { rows.rowCount++ rows.values = rows.resultReader.Values() return true } else { rows.Close() return false } } func (rows *baseRows) Scan(dest ...any) error { m := rows.typeMap fieldDescriptions := rows.FieldDescriptions() values := rows.values if len(fieldDescriptions) != len(values) { err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) rows.fatal(err) return err } if len(dest) == 1 { if rc, ok := dest[0].(RowScanner); ok { return rc.ScanRow(rows) } } if len(fieldDescriptions) != len(dest) { err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) rows.fatal(err) return err } if rows.scanPlans == nil { rows.scanPlans = make([]pgtype.ScanPlan, len(values)) rows.scanTypes = make([]reflect.Type, len(values)) for i := range dest { rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) rows.scanTypes[i] = reflect.TypeOf(dest[i]) } } for i, dst := range dest { if dst == nil { continue } if rows.scanTypes[i] != reflect.TypeOf(dst) { rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) rows.scanTypes[i] = reflect.TypeOf(dest[i]) } err := rows.scanPlans[i].Scan(values[i], dst) if err != nil { err = ScanArgError{ColumnIndex: i, Err: err} rows.fatal(err) return err } } return nil } func (rows *baseRows) Values() ([]any, error) { if rows.closed { return nil, errors.New("rows is closed") } values := make([]any, 0, len(rows.FieldDescriptions())) for i := range rows.FieldDescriptions() { buf := rows.values[i] fd := &rows.FieldDescriptions()[i] if buf == nil { values = append(values, nil) continue } if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok { value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf) if err != nil { rows.fatal(err) } values = append(values, value) } else { switch fd.Format { case TextFormatCode: values = append(values, string(buf)) case BinaryFormatCode: newBuf := make([]byte, len(buf)) copy(newBuf, buf) values = append(values, newBuf) default: rows.fatal(errors.New("Unknown format code")) } } if rows.Err() != nil { return nil, rows.Err() } } return values, rows.Err() } func (rows *baseRows) RawValues() [][]byte { return rows.values } func (rows *baseRows) Conn() *Conn { return rows.conn } type ScanArgError struct { ColumnIndex int Err error } func (e ScanArgError) Error() string { return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) } func (e ScanArgError) Unwrap() error { return e.Err } // ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. // // typeMap - OID to Go type mapping. // fieldDescriptions - OID and format of values // values - the raw data as returned from the PostgreSQL server // dest - the destination that values will be decoded into func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error { if len(fieldDescriptions) != len(values) { return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) } if len(fieldDescriptions) != len(dest) { return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) } for i, d := range dest { if d == nil { continue } err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) if err != nil { return ScanArgError{ColumnIndex: i, Err: err} } } return nil } // RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used // to read from the lower level pgconn interface. func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { return &baseRows{ typeMap: typeMap, resultReader: resultReader, } } // ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row // fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed // when ForEachRow returns. func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { defer rows.Close() for rows.Next() { err := rows.Scan(scans...) if err != nil { return pgconn.CommandTag{}, err } err = fn() if err != nil { return pgconn.CommandTag{}, err } } if err := rows.Err(); err != nil { return pgconn.CommandTag{}, err } return rows.CommandTag(), nil } // CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. type CollectableRow interface { FieldDescriptions() []pgconn.FieldDescription Scan(dest ...any) error Values() ([]any, error) RawValues() [][]byte } // RowToFunc is a function that scans or otherwise converts row to a T. type RowToFunc[T any] func(row CollectableRow) (T, error) // CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { defer rows.Close() slice := []T{} for rows.Next() { value, err := fn(rows) if err != nil { return nil, err } slice = append(slice, value) } if err := rows.Err(); err != nil { return nil, err } return slice, nil } // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. // CollectOneRow is to CollectRows as QueryRow is to Query. func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { defer rows.Close() var value T var err error if !rows.Next() { if err = rows.Err(); err != nil { return value, err } return value, ErrNoRows } value, err = fn(rows) if err != nil { return value, err } rows.Close() return value, rows.Err() } // RowTo returns a T scanned from row. func RowTo[T any](row CollectableRow) (T, error) { var value T err := row.Scan(&value) return value, err } // RowTo returns a the address of a T scanned from row. func RowToAddrOf[T any](row CollectableRow) (*T, error) { var value T err := row.Scan(&value) return &value, err } // RowToMap returns a map scanned from row. func RowToMap(row CollectableRow) (map[string]any, error) { var value map[string]any err := row.Scan((*mapRowScanner)(&value)) return value, err } type mapRowScanner map[string]any func (rs *mapRowScanner) ScanRow(rows Rows) error { values, err := rows.Values() if err != nil { return err } *rs = make(mapRowScanner, len(values)) for i := range values { (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i] } return nil } // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row // has fields. The row and T fields will by matched by position. func RowToStructByPos[T any](row CollectableRow) (T, error) { var value T err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) return value, err } // RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a // public fields as row has fields. The row and T fields will by matched by position. func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { var value T err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) return &value, err } type positionalStructRowScanner struct { ptrToStruct any } func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { dst := rs.ptrToStruct dstValue := reflect.ValueOf(dst) if dstValue.Kind() != reflect.Ptr { return fmt.Errorf("dst not a pointer") } dstElemValue := dstValue.Elem() scanTargets := rs.appendScanTargets(dstElemValue, nil) if len(rows.RawValues()) > len(scanTargets) { return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) } return rows.Scan(scanTargets...) } func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { dstElemType := dstElemValue.Type() if scanTargets == nil { scanTargets = make([]any, 0, dstElemType.NumField()) } for i := 0; i < dstElemType.NumField(); i++ { sf := dstElemType.Field(i) if sf.PkgPath == "" { // Handle anonymous struct embedding, but do not try to handle embedded pointers. if sf.Anonymous && sf.Type.Kind() == reflect.Struct { scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) } else { scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) } } } return scanTargets } // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public // fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. func RowToStructByName[T any](row CollectableRow) (T, error) { var value T err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) return value, err } // RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number // of named public fields as row has fields. The row and T fields will by matched by name. The match is // case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" // then the field will be ignored. func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { var value T err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) return &value, err } type namedStructRowScanner struct { ptrToStruct any } func (rs *namedStructRowScanner) ScanRow(rows Rows) error { dst := rs.ptrToStruct dstValue := reflect.ValueOf(dst) if dstValue.Kind() != reflect.Ptr { return fmt.Errorf("dst not a pointer") } dstElemValue := dstValue.Elem() scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) if err != nil { return err } for i, t := range scanTargets { if t == nil { return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) } } return rows.Scan(scanTargets...) } const structTagKey = "db" func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { i = -1 for i, desc := range fldDescs { if strings.EqualFold(desc.Name, field) { return i } } return } func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { var err error dstElemType := dstElemValue.Type() if scanTargets == nil { scanTargets = make([]any, len(fldDescs)) } for i := 0; i < dstElemType.NumField(); i++ { sf := dstElemType.Field(i) if sf.PkgPath != "" && !sf.Anonymous { // Field is unexported, skip it. continue } // Handle anoymous struct embedding, but do not try to handle embedded pointers. if sf.Anonymous && sf.Type.Kind() == reflect.Struct { scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) if err != nil { return nil, err } } else { dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) if dbTagPresent { dbTag = strings.Split(dbTag, ",")[0] } if dbTag == "-" { // Field is ignored, skip it. continue } colName := dbTag if !dbTagPresent { colName = sf.Name } fpos := fieldPosByName(fldDescs, colName) if fpos == -1 || fpos >= len(scanTargets) { return nil, fmt.Errorf("cannot find field %s in returned row", colName) } scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() } } return scanTargets, err }