2021-08-25 14:34:33 +01:00
|
|
|
package bun
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2022-08-15 11:35:05 +01:00
|
|
|
"crypto/rand"
|
2021-08-25 14:34:33 +01:00
|
|
|
"database/sql"
|
2022-08-15 11:35:05 +01:00
|
|
|
"encoding/hex"
|
2021-08-25 14:34:33 +01:00
|
|
|
"fmt"
|
|
|
|
"reflect"
|
|
|
|
"strings"
|
|
|
|
"sync/atomic"
|
|
|
|
|
|
|
|
"github.com/uptrace/bun/dialect/feature"
|
|
|
|
"github.com/uptrace/bun/internal"
|
|
|
|
"github.com/uptrace/bun/schema"
|
|
|
|
)
|
|
|
|
|
|
|
|
const (
|
|
|
|
discardUnknownColumns internal.Flag = 1 << iota
|
|
|
|
)
|
|
|
|
|
|
|
|
// DBStats holds query counters for a DB. DB.DBStats returns a snapshot in
// which each field has been loaded atomically.
type DBStats struct {
	// Queries counts queries; presumably incremented by the query machinery
	// elsewhere in the package.
	Queries uint32
	// Errors counts failed queries; presumably incremented alongside Queries.
	Errors uint32
}
|
|
|
|
|
|
|
|
type DBOption func(db *DB)
|
|
|
|
|
|
|
|
func WithDiscardUnknownColumns() DBOption {
|
|
|
|
return func(db *DB) {
|
|
|
|
db.flags = db.flags.Set(discardUnknownColumns)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
type DB struct {
|
|
|
|
*sql.DB
|
2022-03-07 10:08:26 +00:00
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
dialect schema.Dialect
|
|
|
|
features feature.Feature
|
|
|
|
|
|
|
|
queryHooks []QueryHook
|
|
|
|
|
|
|
|
fmter schema.Formatter
|
|
|
|
flags internal.Flag
|
|
|
|
|
|
|
|
stats DBStats
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
|
|
|
|
dialect.Init(sqldb)
|
|
|
|
|
|
|
|
db := &DB{
|
|
|
|
DB: sqldb,
|
|
|
|
dialect: dialect,
|
|
|
|
features: dialect.Features(),
|
|
|
|
fmter: schema.NewFormatter(dialect),
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, opt := range opts {
|
|
|
|
opt(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
return db
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) String() string {
|
|
|
|
var b strings.Builder
|
|
|
|
b.WriteString("DB<dialect=")
|
|
|
|
b.WriteString(db.dialect.Name().String())
|
|
|
|
b.WriteString(">")
|
|
|
|
return b.String()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) DBStats() DBStats {
|
|
|
|
return DBStats{
|
2021-09-23 10:13:28 +01:00
|
|
|
Queries: atomic.LoadUint32(&db.stats.Queries),
|
|
|
|
Errors: atomic.LoadUint32(&db.stats.Errors),
|
2021-08-25 14:34:33 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewValues(model interface{}) *ValuesQuery {
|
|
|
|
return NewValuesQuery(db, model)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewSelect() *SelectQuery {
|
|
|
|
return NewSelectQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewInsert() *InsertQuery {
|
|
|
|
return NewInsertQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewUpdate() *UpdateQuery {
|
|
|
|
return NewUpdateQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewDelete() *DeleteQuery {
|
|
|
|
return NewDeleteQuery(db)
|
|
|
|
}
|
|
|
|
|
2022-09-28 18:30:40 +01:00
|
|
|
func (db *DB) NewRaw(query string, args ...interface{}) *RawQuery {
|
|
|
|
return NewRawQuery(db, query, args...)
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
func (db *DB) NewCreateTable() *CreateTableQuery {
|
|
|
|
return NewCreateTableQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewDropTable() *DropTableQuery {
|
|
|
|
return NewDropTableQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewCreateIndex() *CreateIndexQuery {
|
|
|
|
return NewCreateIndexQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewDropIndex() *DropIndexQuery {
|
|
|
|
return NewDropIndexQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewTruncateTable() *TruncateTableQuery {
|
|
|
|
return NewTruncateTableQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewAddColumn() *AddColumnQuery {
|
|
|
|
return NewAddColumnQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) NewDropColumn() *DropColumnQuery {
|
|
|
|
return NewDropColumnQuery(db)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error {
|
|
|
|
for _, model := range models {
|
2022-03-07 10:08:26 +00:00
|
|
|
if _, err := db.NewDropTable().Model(model).IfExists().Cascade().Exec(ctx); err != nil {
|
2021-08-25 14:34:33 +01:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) Dialect() schema.Dialect {
|
|
|
|
return db.dialect
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
|
2022-08-15 11:35:05 +01:00
|
|
|
defer rows.Close()
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
model, err := newModel(db, dest)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = model.ScanRows(ctx, rows)
|
2022-08-15 11:35:05 +01:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return rows.Err()
|
2021-08-25 14:34:33 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
|
|
|
|
model, err := newModel(db, dest)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
rs, ok := model.(rowScanner)
|
|
|
|
if !ok {
|
|
|
|
return fmt.Errorf("bun: %T does not support ScanRow", model)
|
|
|
|
}
|
|
|
|
|
|
|
|
return rs.ScanRow(ctx, rows)
|
|
|
|
}
|
|
|
|
|
2021-10-24 12:14:37 +01:00
|
|
|
type queryHookIniter interface {
|
|
|
|
Init(db *DB)
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
func (db *DB) AddQueryHook(hook QueryHook) {
|
2021-10-24 12:14:37 +01:00
|
|
|
if initer, ok := hook.(queryHookIniter); ok {
|
|
|
|
initer.Init(db)
|
|
|
|
}
|
2021-08-25 14:34:33 +01:00
|
|
|
db.queryHooks = append(db.queryHooks, hook)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) Table(typ reflect.Type) *schema.Table {
|
|
|
|
return db.dialect.Tables().Get(typ)
|
|
|
|
}
|
|
|
|
|
2021-11-27 14:26:58 +00:00
|
|
|
// RegisterModel registers models by name so they can be referenced in table relations
|
|
|
|
// and fixtures.
|
2021-08-25 14:34:33 +01:00
|
|
|
func (db *DB) RegisterModel(models ...interface{}) {
|
|
|
|
db.dialect.Tables().Register(models...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) clone() *DB {
|
|
|
|
clone := *db
|
|
|
|
|
|
|
|
l := len(clone.queryHooks)
|
|
|
|
clone.queryHooks = clone.queryHooks[:l:l]
|
|
|
|
|
|
|
|
return &clone
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) WithNamedArg(name string, value interface{}) *DB {
|
|
|
|
clone := db.clone()
|
|
|
|
clone.fmter = clone.fmter.WithNamedArg(name, value)
|
|
|
|
return clone
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) Formatter() schema.Formatter {
|
|
|
|
return db.fmter
|
|
|
|
}
|
|
|
|
|
2021-12-12 14:47:51 +00:00
|
|
|
// UpdateFQN returns a fully qualified column name. For MySQL, it returns the column name with
|
|
|
|
// the table alias. For other RDBMS, it returns just the column name.
|
|
|
|
func (db *DB) UpdateFQN(alias, column string) Ident {
|
|
|
|
if db.HasFeature(feature.UpdateMultiTable) {
|
|
|
|
return Ident(alias + "." + column)
|
|
|
|
}
|
|
|
|
return Ident(column)
|
|
|
|
}
|
|
|
|
|
2021-11-27 14:26:58 +00:00
|
|
|
// HasFeature uses feature package to report whether the underlying DBMS supports this feature.
|
|
|
|
func (db *DB) HasFeature(feat feature.Feature) bool {
|
|
|
|
return db.fmter.HasFeature(feat)
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
|
|
return db.ExecContext(context.Background(), query, args...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) ExecContext(
|
|
|
|
ctx context.Context, query string, args ...interface{},
|
|
|
|
) (sql.Result, error) {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := db.format(query, args)
|
|
|
|
ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
res, err := db.DB.ExecContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
db.afterQuery(ctx, event, res, err)
|
|
|
|
return res, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
|
|
return db.QueryContext(context.Background(), query, args...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) QueryContext(
|
|
|
|
ctx context.Context, query string, args ...interface{},
|
|
|
|
) (*sql.Rows, error) {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := db.format(query, args)
|
|
|
|
ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
rows, err := db.DB.QueryContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
db.afterQuery(ctx, event, nil, err)
|
|
|
|
return rows, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
|
|
|
|
return db.QueryRowContext(context.Background(), query, args...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := db.format(query, args)
|
|
|
|
ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
row := db.DB.QueryRowContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
db.afterQuery(ctx, event, nil, row.Err())
|
|
|
|
return row
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) format(query string, args []interface{}) string {
|
|
|
|
return db.fmter.FormatQuery(query, args...)
|
|
|
|
}
|
|
|
|
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
type Conn struct {
|
|
|
|
db *DB
|
|
|
|
*sql.Conn
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) Conn(ctx context.Context) (Conn, error) {
|
|
|
|
conn, err := db.DB.Conn(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return Conn{}, err
|
|
|
|
}
|
|
|
|
return Conn{
|
|
|
|
db: db,
|
|
|
|
Conn: conn,
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) ExecContext(
|
|
|
|
ctx context.Context, query string, args ...interface{},
|
|
|
|
) (sql.Result, error) {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := c.db.format(query, args)
|
|
|
|
ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
res, err := c.Conn.ExecContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
c.db.afterQuery(ctx, event, res, err)
|
|
|
|
return res, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) QueryContext(
|
|
|
|
ctx context.Context, query string, args ...interface{},
|
|
|
|
) (*sql.Rows, error) {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := c.db.format(query, args)
|
|
|
|
ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
rows, err := c.Conn.QueryContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
c.db.afterQuery(ctx, event, nil, err)
|
|
|
|
return rows, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := c.db.format(query, args)
|
|
|
|
ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
row := c.Conn.QueryRowContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
c.db.afterQuery(ctx, event, nil, row.Err())
|
|
|
|
return row
|
|
|
|
}
|
|
|
|
|
2022-03-07 10:08:26 +00:00
|
|
|
func (c Conn) Dialect() schema.Dialect {
|
|
|
|
return c.db.Dialect()
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
func (c Conn) NewValues(model interface{}) *ValuesQuery {
|
|
|
|
return NewValuesQuery(c.db, model).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewSelect() *SelectQuery {
|
|
|
|
return NewSelectQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewInsert() *InsertQuery {
|
|
|
|
return NewInsertQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewUpdate() *UpdateQuery {
|
|
|
|
return NewUpdateQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewDelete() *DeleteQuery {
|
|
|
|
return NewDeleteQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
2022-09-28 18:30:40 +01:00
|
|
|
func (c Conn) NewRaw(query string, args ...interface{}) *RawQuery {
|
|
|
|
return NewRawQuery(c.db, query, args...).Conn(c)
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
func (c Conn) NewCreateTable() *CreateTableQuery {
|
|
|
|
return NewCreateTableQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewDropTable() *DropTableQuery {
|
|
|
|
return NewDropTableQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewCreateIndex() *CreateIndexQuery {
|
|
|
|
return NewCreateIndexQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewDropIndex() *DropIndexQuery {
|
|
|
|
return NewDropIndexQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewTruncateTable() *TruncateTableQuery {
|
|
|
|
return NewTruncateTableQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewAddColumn() *AddColumnQuery {
|
|
|
|
return NewAddColumnQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) NewDropColumn() *DropColumnQuery {
|
|
|
|
return NewDropColumnQuery(c.db).Conn(c)
|
|
|
|
}
|
|
|
|
|
2022-08-15 11:35:05 +01:00
|
|
|
// RunInTx runs the function in a transaction. If the function returns an error,
|
|
|
|
// the transaction is rolled back. Otherwise, the transaction is committed.
|
|
|
|
func (c Conn) RunInTx(
|
|
|
|
ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
|
|
|
|
) error {
|
|
|
|
tx, err := c.BeginTx(ctx, opts)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
var done bool
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
if !done {
|
|
|
|
_ = tx.Rollback()
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
if err := fn(ctx, tx); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
done = true
|
|
|
|
return tx.Commit()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
|
|
|
|
ctx, event := c.db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil)
|
|
|
|
tx, err := c.Conn.BeginTx(ctx, opts)
|
|
|
|
c.db.afterQuery(ctx, event, nil, err)
|
|
|
|
if err != nil {
|
|
|
|
return Tx{}, err
|
|
|
|
}
|
|
|
|
return Tx{
|
|
|
|
ctx: ctx,
|
|
|
|
db: c.db,
|
|
|
|
Tx: tx,
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
type Stmt struct {
|
|
|
|
*sql.Stmt
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) Prepare(query string) (Stmt, error) {
|
|
|
|
return db.PrepareContext(context.Background(), query)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
|
|
|
|
stmt, err := db.DB.PrepareContext(ctx, query)
|
|
|
|
if err != nil {
|
|
|
|
return Stmt{}, err
|
|
|
|
}
|
|
|
|
return Stmt{Stmt: stmt}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
type Tx struct {
|
2021-11-13 11:29:08 +00:00
|
|
|
ctx context.Context
|
|
|
|
db *DB
|
2022-08-15 11:35:05 +01:00
|
|
|
// name is the name of a savepoint
|
|
|
|
name string
|
2021-08-25 14:34:33 +01:00
|
|
|
*sql.Tx
|
|
|
|
}
|
|
|
|
|
|
|
|
// RunInTx runs the function in a transaction. If the function returns an error,
|
|
|
|
// the transaction is rolled back. Otherwise, the transaction is committed.
|
|
|
|
func (db *DB) RunInTx(
|
|
|
|
ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
|
|
|
|
) error {
|
|
|
|
tx, err := db.BeginTx(ctx, opts)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2021-11-13 11:29:08 +00:00
|
|
|
|
|
|
|
var done bool
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
if !done {
|
|
|
|
_ = tx.Rollback()
|
|
|
|
}
|
|
|
|
}()
|
2021-08-25 14:34:33 +01:00
|
|
|
|
|
|
|
if err := fn(ctx, tx); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2021-11-13 11:29:08 +00:00
|
|
|
|
|
|
|
done = true
|
2021-08-25 14:34:33 +01:00
|
|
|
return tx.Commit()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) Begin() (Tx, error) {
|
|
|
|
return db.BeginTx(context.Background(), nil)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
|
2022-03-07 10:08:26 +00:00
|
|
|
ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil)
|
2021-08-25 14:34:33 +01:00
|
|
|
tx, err := db.DB.BeginTx(ctx, opts)
|
2021-11-13 11:29:08 +00:00
|
|
|
db.afterQuery(ctx, event, nil, err)
|
2021-08-25 14:34:33 +01:00
|
|
|
if err != nil {
|
|
|
|
return Tx{}, err
|
|
|
|
}
|
|
|
|
return Tx{
|
2021-11-13 11:29:08 +00:00
|
|
|
ctx: ctx,
|
|
|
|
db: db,
|
|
|
|
Tx: tx,
|
2021-08-25 14:34:33 +01:00
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
2021-11-13 11:29:08 +00:00
|
|
|
func (tx Tx) Commit() error {
|
2022-08-15 11:35:05 +01:00
|
|
|
if tx.name == "" {
|
|
|
|
return tx.commitTX()
|
|
|
|
}
|
|
|
|
return tx.commitSP()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) commitTX() error {
|
2022-03-07 10:08:26 +00:00
|
|
|
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil)
|
2021-11-13 11:29:08 +00:00
|
|
|
err := tx.Tx.Commit()
|
|
|
|
tx.db.afterQuery(ctx, event, nil, err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2022-08-15 11:35:05 +01:00
|
|
|
func (tx Tx) commitSP() error {
|
|
|
|
if tx.Dialect().Features().Has(feature.MSSavepoint) {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
query := "RELEASE SAVEPOINT " + tx.name
|
|
|
|
_, err := tx.ExecContext(tx.ctx, query)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2021-11-13 11:29:08 +00:00
|
|
|
func (tx Tx) Rollback() error {
|
2022-08-15 11:35:05 +01:00
|
|
|
if tx.name == "" {
|
|
|
|
return tx.rollbackTX()
|
|
|
|
}
|
|
|
|
return tx.rollbackSP()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) rollbackTX() error {
|
2022-03-07 10:08:26 +00:00
|
|
|
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil)
|
2021-11-13 11:29:08 +00:00
|
|
|
err := tx.Tx.Rollback()
|
|
|
|
tx.db.afterQuery(ctx, event, nil, err)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2022-08-15 11:35:05 +01:00
|
|
|
func (tx Tx) rollbackSP() error {
|
|
|
|
query := "ROLLBACK TO SAVEPOINT " + tx.name
|
|
|
|
if tx.Dialect().Features().Has(feature.MSSavepoint) {
|
|
|
|
query = "ROLLBACK TRANSACTION " + tx.name
|
|
|
|
}
|
|
|
|
_, err := tx.ExecContext(tx.ctx, query)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
|
|
return tx.ExecContext(context.TODO(), query, args...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) ExecContext(
|
|
|
|
ctx context.Context, query string, args ...interface{},
|
|
|
|
) (sql.Result, error) {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := tx.db.format(query, args)
|
|
|
|
ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
res, err := tx.Tx.ExecContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
tx.db.afterQuery(ctx, event, res, err)
|
|
|
|
return res, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
|
|
return tx.QueryContext(context.TODO(), query, args...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) QueryContext(
|
|
|
|
ctx context.Context, query string, args ...interface{},
|
|
|
|
) (*sql.Rows, error) {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := tx.db.format(query, args)
|
|
|
|
ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
rows, err := tx.Tx.QueryContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
tx.db.afterQuery(ctx, event, nil, err)
|
|
|
|
return rows, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row {
|
|
|
|
return tx.QueryRowContext(context.TODO(), query, args...)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
2022-03-07 10:08:26 +00:00
|
|
|
formattedQuery := tx.db.format(query, args)
|
|
|
|
ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
|
|
|
|
row := tx.Tx.QueryRowContext(ctx, formattedQuery)
|
2021-08-25 14:34:33 +01:00
|
|
|
tx.db.afterQuery(ctx, event, nil, row.Err())
|
|
|
|
return row
|
|
|
|
}
|
|
|
|
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
|
2022-08-15 11:35:05 +01:00
|
|
|
func (tx Tx) Begin() (Tx, error) {
|
|
|
|
return tx.BeginTx(tx.ctx, nil)
|
|
|
|
}
|
|
|
|
|
|
|
|
// BeginTx will save a point in the running transaction.
|
|
|
|
func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {
|
|
|
|
// mssql savepoint names are limited to 32 characters
|
|
|
|
sp := make([]byte, 14)
|
|
|
|
_, err := rand.Read(sp)
|
|
|
|
if err != nil {
|
|
|
|
return Tx{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
qName := "SP_" + hex.EncodeToString(sp)
|
|
|
|
query := "SAVEPOINT " + qName
|
|
|
|
if tx.Dialect().Features().Has(feature.MSSavepoint) {
|
|
|
|
query = "SAVE TRANSACTION " + qName
|
|
|
|
}
|
|
|
|
_, err = tx.ExecContext(ctx, query)
|
|
|
|
if err != nil {
|
|
|
|
return Tx{}, err
|
|
|
|
}
|
|
|
|
return Tx{
|
|
|
|
ctx: ctx,
|
|
|
|
db: tx.db,
|
|
|
|
Tx: tx.Tx,
|
|
|
|
name: qName,
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) RunInTx(
|
|
|
|
ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error,
|
|
|
|
) error {
|
|
|
|
sp, err := tx.BeginTx(ctx, nil)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
var done bool
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
if !done {
|
|
|
|
_ = sp.Rollback()
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
if err := fn(ctx, sp); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
done = true
|
|
|
|
return sp.Commit()
|
|
|
|
}
|
|
|
|
|
2022-03-07 10:08:26 +00:00
|
|
|
func (tx Tx) Dialect() schema.Dialect {
|
|
|
|
return tx.db.Dialect()
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
func (tx Tx) NewValues(model interface{}) *ValuesQuery {
|
|
|
|
return NewValuesQuery(tx.db, model).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewSelect() *SelectQuery {
|
|
|
|
return NewSelectQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewInsert() *InsertQuery {
|
|
|
|
return NewInsertQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewUpdate() *UpdateQuery {
|
|
|
|
return NewUpdateQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewDelete() *DeleteQuery {
|
|
|
|
return NewDeleteQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
2022-09-28 18:30:40 +01:00
|
|
|
func (tx Tx) NewRaw(query string, args ...interface{}) *RawQuery {
|
|
|
|
return NewRawQuery(tx.db, query, args...).Conn(tx)
|
|
|
|
}
|
|
|
|
|
2021-08-25 14:34:33 +01:00
|
|
|
func (tx Tx) NewCreateTable() *CreateTableQuery {
|
|
|
|
return NewCreateTableQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewDropTable() *DropTableQuery {
|
|
|
|
return NewDropTableQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewCreateIndex() *CreateIndexQuery {
|
|
|
|
return NewCreateIndexQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewDropIndex() *DropIndexQuery {
|
|
|
|
return NewDropIndexQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewTruncateTable() *TruncateTableQuery {
|
|
|
|
return NewTruncateTableQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewAddColumn() *AddColumnQuery {
|
|
|
|
return NewAddColumnQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (tx Tx) NewDropColumn() *DropColumnQuery {
|
|
|
|
return NewDropColumnQuery(tx.db).Conn(tx)
|
|
|
|
}
|
|
|
|
|
2021-09-08 20:05:26 +01:00
|
|
|
//------------------------------------------------------------------------------
|
2021-08-25 14:34:33 +01:00
|
|
|
|
|
|
|
func (db *DB) makeQueryBytes() []byte {
|
|
|
|
// TODO: make this configurable?
|
|
|
|
return make([]byte, 0, 4096)
|
|
|
|
}
|