package bun

import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"fmt"
	"reflect"
	"strings"
	"sync/atomic"
	"time"

	"github.com/uptrace/bun/dialect/feature"
	"github.com/uptrace/bun/internal"
	"github.com/uptrace/bun/schema"
)

const (
	// discardUnknownColumns is set on DB.flags by WithDiscardUnknownColumns.
	discardUnknownColumns internal.Flag = 1 << iota
)

// DBStats holds query and error counters for a DB. DB.DBStats reads both
// fields atomically; NOTE(review): the code that increments them is not in
// this chunk (presumably the query hook plumbing) — confirm it also uses atomics.
type DBStats struct {
	Queries uint32
	Errors  uint32
}

// DBOption configures a DB; NewDB applies options in the order given.
type DBOption func(db *DB)

// WithOptions combines several DBOptions into one that applies them in order.
func WithOptions(opts ...DBOption) DBOption {
	return func(db *DB) {
		for _, opt := range opts {
			opt(db)
		}
	}
}

// WithDiscardUnknownColumns sets the discardUnknownColumns flag on the DB.
// NOTE(review): presumably makes model scanning skip result columns that have
// no matching field — confirm against the scanning code.
func WithDiscardUnknownColumns() DBOption {
	return func(db *DB) {
		db.flags = db.flags.Set(discardUnknownColumns)
	}
}

// WithConnResolver installs resolver on the DB. DB.Close also closes it.
func WithConnResolver(resolver ConnResolver) DBOption {
	return func(db *DB) {
		db.resolver = resolver
	}
}

// DB wraps *sql.DB with a dialect, a query formatter, query hooks and counters.
type DB struct {
	// Must be a pointer so we copy the whole state, not individual fields.
	// Copies made by clone() therefore share a single noCopyState.
	*noCopyState

	queryHooks []QueryHook

	fmter schema.Formatter
	stats DBStats
}

// noCopyState contains DB fields that must not be copied on clone();
// for example, atomic values such as closed (an atomic.Bool) must not be copied.
type noCopyState struct {
	*sql.DB
	dialect  schema.Dialect
	resolver ConnResolver

	flags  internal.Flag
	closed atomic.Bool
}

// NewDB wraps sqldb. It calls dialect.Init(sqldb), creates a formatter for
// dialect, and then applies opts in order.
func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
	dialect.Init(sqldb)

	db := &DB{
		noCopyState: &noCopyState{
			DB:      sqldb,
			dialect: dialect,
		},
		fmter: schema.NewFormatter(dialect),
	}

	for _, opt := range opts {
		opt(db)
	}

	return db
}

// String implements fmt.Stringer; it always returns "DB".
func (db *DB) String() string {
	var b strings.Builder
	b.WriteString("DB")
	return b.String()
}

// Close closes the underlying *sql.DB and then the resolver, if one is set.
// Only the first call does any work; subsequent calls return nil. The first
// error encountered is returned.
func (db *DB) Close() error {
	if db.closed.Swap(true) {
		return nil
	}

	firstErr := db.DB.Close()

	if db.resolver != nil {
		if err := db.resolver.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}

	return firstErr
}

// DBStats returns the current counters. Each counter is loaded atomically,
// but the two loads are independent of each other.
func (db *DB) DBStats() DBStats {
	return DBStats{
		Queries: atomic.LoadUint32(&db.stats.Queries),
		Errors:  atomic.LoadUint32(&db.stats.Errors),
	}
}

// NewValues returns a ValuesQuery for model bound to db.
func (db *DB) NewValues(model interface{}) *ValuesQuery {
	return NewValuesQuery(db, model)
}

// NewMerge returns a MergeQuery bound to db.
func (db *DB) NewMerge() *MergeQuery {
	return NewMergeQuery(db)
}

// NewSelect returns a SelectQuery bound to db.
func (db *DB) NewSelect() *SelectQuery {
	return NewSelectQuery(db)
}

// NewInsert returns an InsertQuery bound to db.
func (db *DB) NewInsert() *InsertQuery {
	return NewInsertQuery(db)
}

// NewUpdate returns an UpdateQuery bound to db.
func (db *DB) NewUpdate() *UpdateQuery {
	return NewUpdateQuery(db)
}

// NewDelete returns a DeleteQuery bound to db.
func (db *DB) NewDelete() *DeleteQuery {
	return NewDeleteQuery(db)
}

// NewRaw returns a RawQuery for query/args bound to db.
func (db *DB) NewRaw(query string, args ...interface{}) *RawQuery {
	return NewRawQuery(db, query, args...)
}

// NewCreateTable returns a CreateTableQuery bound to db.
func (db *DB) NewCreateTable() *CreateTableQuery {
	return NewCreateTableQuery(db)
}

// NewDropTable returns a DropTableQuery bound to db.
func (db *DB) NewDropTable() *DropTableQuery {
	return NewDropTableQuery(db)
}

// NewCreateIndex returns a CreateIndexQuery bound to db.
func (db *DB) NewCreateIndex() *CreateIndexQuery {
	return NewCreateIndexQuery(db)
}

// NewDropIndex returns a DropIndexQuery bound to db.
func (db *DB) NewDropIndex() *DropIndexQuery {
	return NewDropIndexQuery(db)
}

// NewTruncateTable returns a TruncateTableQuery bound to db.
func (db *DB) NewTruncateTable() *TruncateTableQuery {
	return NewTruncateTableQuery(db)
}

// NewAddColumn returns an AddColumnQuery bound to db.
func (db *DB) NewAddColumn() *AddColumnQuery {
	return NewAddColumnQuery(db)
}

// NewDropColumn returns a DropColumnQuery bound to db.
func (db *DB) NewDropColumn() *DropColumnQuery {
	return NewDropColumnQuery(db)
}

// ResetModel drops (IF EXISTS ... CASCADE) and then recreates the table of
// each model, in order, stopping at the first error.
func (db *DB) ResetModel(ctx context.Context, models ...interface{}) error {
	for _, model := range models {
		if _, err := db.NewDropTable().Model(model).IfExists().Cascade().Exec(ctx); err != nil {
			return err
		}
		if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil {
			return err
		}
	}
	return nil
}

// Dialect returns the dialect passed to NewDB.
func (db *DB) Dialect() schema.Dialect {
	return db.dialect
}

// ScanRows scans every remaining row of rows into dest, closes rows before
// returning, and reports rows.Err() when scanning itself succeeded.
func (db *DB) ScanRows(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
	defer rows.Close()

	model, err := newModel(db, dest)
	if err != nil {
		return err
	}

	_, err = model.ScanRows(ctx, rows)
	if err != nil {
		return err
	}

	return rows.Err()
}

// ScanRow scans the current row of rows into dest. Unlike ScanRows it does
// not close rows. It fails if the model built for dest is not a rowScanner.
func (db *DB) ScanRow(ctx context.Context, rows *sql.Rows, dest ...interface{}) error {
	model, err := newModel(db, dest)
	if err != nil {
		return err
	}

	rs, ok := model.(rowScanner)
	if !ok {
		return fmt.Errorf("bun: %T does not support ScanRow", model)
	}

	return rs.ScanRow(ctx, rows)
}

// queryHookIniter is implemented by hooks that want a reference to the DB;
// AddQueryHook calls Init before registering such a hook.
type queryHookIniter interface {
	Init(db *DB)
}

// AddQueryHook calls hook.Init(db) if hook implements queryHookIniter, then
// appends hook to db's hooks. It mutates db in place.
func (db *DB) AddQueryHook(hook QueryHook) {
	if initer, ok := hook.(queryHookIniter); ok {
		initer.Init(db)
	}
	db.queryHooks = append(db.queryHooks, hook)
}

// Table returns the dialect's table metadata for typ.
func (db *DB) Table(typ reflect.Type) *schema.Table {
	return db.dialect.Tables().Get(typ)
}

// RegisterModel registers models by name so they can be referenced in table relations
// and fixtures.
func (db *DB) RegisterModel(models ...interface{}) {
	db.dialect.Tables().Register(models...)
}

// clone returns a shallow copy of db that shares noCopyState with db.
func (db *DB) clone() *DB {
	clone := *db

	// Cap the hook slice at its length so that a later append on either copy
	// allocates a fresh backing array instead of overwriting the other's hooks.
	l := len(clone.queryHooks)
	clone.queryHooks = clone.queryHooks[:l:l]

	return &clone
}

// WithNamedArg returns a clone of db whose formatter is replaced by
// fmter.WithNamedArg(name, value); db's own fmter field is left unchanged.
func (db *DB) WithNamedArg(name string, value interface{}) *DB {
	clone := db.clone()
	clone.fmter = clone.fmter.WithNamedArg(name, value)
	return clone
}

// Formatter returns the DB's query formatter.
func (db *DB) Formatter() schema.Formatter {
	return db.fmter
}

// UpdateFQN returns a column name for use in UPDATE statements. When the
// dialect supports feature.UpdateMultiTable (e.g. MySQL) it returns the column
// qualified with the table alias ("alias.column"); otherwise just the column.
func (db *DB) UpdateFQN(alias, column string) Ident {
	if db.HasFeature(feature.UpdateMultiTable) {
		return Ident(alias + "." + column)
	}
	return Ident(column)
}

// HasFeature uses feature package to report whether the underlying DBMS supports this feature.
func (db *DB) HasFeature(feat feature.Feature) bool {
	return db.dialect.Features().Has(feat)
}

//------------------------------------------------------------------------------

// Exec is ExecContext with context.Background().
func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
	return db.ExecContext(context.Background(), query, args...)
}

// ExecContext interpolates args into query with the DB formatter, runs the
// formatted SQL without driver-level arguments, and reports it to query hooks.
func (db *DB) ExecContext(
	ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
	formattedQuery := db.format(query, args)
	ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
	res, err := db.DB.ExecContext(ctx, formattedQuery)
	db.afterQuery(ctx, event, res, err)
	return res, err
}

// Query is QueryContext with context.Background().
func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
	return db.QueryContext(context.Background(), query, args...)
}

// QueryContext formats query like ExecContext, runs it, and reports it to
// query hooks. The caller owns (and must close) the returned rows.
func (db *DB) QueryContext(
	ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
	formattedQuery := db.format(query, args)
	ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
	rows, err := db.DB.QueryContext(ctx, formattedQuery)
	db.afterQuery(ctx, event, nil, err)
	return rows, err
}

// QueryRow is QueryRowContext with context.Background().
func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
	return db.QueryRowContext(context.Background(), query, args...)
}

// QueryRowContext formats and runs query, passing row.Err() to the hooks.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	formattedQuery := db.format(query, args)
	ctx, event := db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
	row := db.DB.QueryRowContext(ctx, formattedQuery)
	db.afterQuery(ctx, event, nil, row.Err())
	return row
}

// format interpolates args into query using the DB's formatter.
func (db *DB) format(query string, args []interface{}) string {
	return db.fmter.FormatQuery(query, args...)
}

//------------------------------------------------------------------------------

// Conn is a single connection from the pool; queries run on it are formatted
// and reported to hooks through the parent DB.
type Conn struct {
	db *DB
	*sql.Conn
}

// Conn obtains a dedicated connection from db's pool. The caller must Close it
// (method promoted from *sql.Conn) to return it to the pool.
func (db *DB) Conn(ctx context.Context) (Conn, error) {
	conn, err := db.DB.Conn(ctx)
	if err != nil {
		return Conn{}, err
	}
	return Conn{
		db:   db,
		Conn: conn,
	}, nil
}

// ExecContext formats query with the parent DB's formatter, runs it on this
// connection, and reports it to the DB's query hooks.
func (c Conn) ExecContext(
	ctx context.Context, query string, args ...interface{},
) (sql.Result, error) {
	formattedQuery := c.db.format(query, args)
	ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
	res, err := c.Conn.ExecContext(ctx, formattedQuery)
	c.db.afterQuery(ctx, event, res, err)
	return res, err
}

// QueryContext formats and runs query on this connection, reporting to hooks.
func (c Conn) QueryContext(
	ctx context.Context, query string, args ...interface{},
) (*sql.Rows, error) {
	formattedQuery := c.db.format(query, args)
	ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
	rows, err := c.Conn.QueryContext(ctx, formattedQuery)
	c.db.afterQuery(ctx, event, nil, err)
	return rows, err
}

// QueryRowContext formats and runs query on this connection, passing
// row.Err() to the hooks.
func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	formattedQuery := c.db.format(query, args)
	ctx, event := c.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil)
	row := c.Conn.QueryRowContext(ctx, formattedQuery)
	c.db.afterQuery(ctx, event, nil, row.Err())
	return row
}

// Dialect returns the parent DB's dialect.
func (c Conn) Dialect() schema.Dialect {
	return c.db.Dialect()
}

// NewValues returns a ValuesQuery for model that runs on connection c.
func (c Conn) NewValues(model interface{}) *ValuesQuery {
	return NewValuesQuery(c.db, model).Conn(c)
}

// NewMerge returns a MergeQuery that runs on connection c.
func (c Conn) NewMerge() *MergeQuery {
	return NewMergeQuery(c.db).Conn(c)
}

// NewSelect returns a SelectQuery that runs on connection c.
func (c Conn) NewSelect() *SelectQuery {
	return NewSelectQuery(c.db).Conn(c)
}

// NewInsert returns an InsertQuery that runs on connection c.
func (c Conn) NewInsert() *InsertQuery {
	return NewInsertQuery(c.db).Conn(c)
}

// NewUpdate returns an UpdateQuery that runs on connection c.
func (c Conn) NewUpdate() *UpdateQuery {
	return NewUpdateQuery(c.db).Conn(c)
}

// NewDelete returns a DeleteQuery that runs on connection c.
func (c Conn) NewDelete() *DeleteQuery {
	return NewDeleteQuery(c.db).Conn(c)
}

// NewRaw returns a RawQuery for query/args that runs on connection c.
func (c Conn) NewRaw(query string, args ...interface{}) *RawQuery {
	return NewRawQuery(c.db, query, args...).Conn(c)
}

// NewCreateTable returns a CreateTableQuery that runs on connection c.
func (c Conn) NewCreateTable() *CreateTableQuery {
	return NewCreateTableQuery(c.db).Conn(c)
}

// NewDropTable returns a DropTableQuery that runs on connection c.
func (c Conn) NewDropTable() *DropTableQuery {
	return NewDropTableQuery(c.db).Conn(c)
}

// NewCreateIndex returns a CreateIndexQuery that runs on connection c.
func (c Conn) NewCreateIndex() *CreateIndexQuery {
	return NewCreateIndexQuery(c.db).Conn(c)
}

// NewDropIndex returns a DropIndexQuery that runs on connection c.
func (c Conn) NewDropIndex() *DropIndexQuery {
	return NewDropIndexQuery(c.db).Conn(c)
}

// NewTruncateTable returns a TruncateTableQuery that runs on connection c.
func (c Conn) NewTruncateTable() *TruncateTableQuery {
	return NewTruncateTableQuery(c.db).Conn(c)
}

// NewAddColumn returns an AddColumnQuery that runs on connection c.
func (c Conn) NewAddColumn() *AddColumnQuery {
	return NewAddColumnQuery(c.db).Conn(c)
}

// NewDropColumn returns a DropColumnQuery that runs on connection c.
func (c Conn) NewDropColumn() *DropColumnQuery {
	return NewDropColumnQuery(c.db).Conn(c)
}

// RunInTx runs the function in a transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func (c Conn) RunInTx( ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, ) error { tx, err := c.BeginTx(ctx, opts) if err != nil { return err } var done bool defer func() { if !done { _ = tx.Rollback() } }() if err := fn(ctx, tx); err != nil { return err } done = true return tx.Commit() } func (c Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { ctx, event := c.db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil) tx, err := c.Conn.BeginTx(ctx, opts) c.db.afterQuery(ctx, event, nil, err) if err != nil { return Tx{}, err } return Tx{ ctx: ctx, db: c.db, Tx: tx, }, nil } //------------------------------------------------------------------------------ type Stmt struct { *sql.Stmt } func (db *DB) Prepare(query string) (Stmt, error) { return db.PrepareContext(context.Background(), query) } func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) { stmt, err := db.DB.PrepareContext(ctx, query) if err != nil { return Stmt{}, err } return Stmt{Stmt: stmt}, nil } //------------------------------------------------------------------------------ type Tx struct { ctx context.Context db *DB // name is the name of a savepoint name string *sql.Tx } // RunInTx runs the function in a transaction. If the function returns an error, // the transaction is rolled back. Otherwise, the transaction is committed. 
func (db *DB) RunInTx( ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, ) error { tx, err := db.BeginTx(ctx, opts) if err != nil { return err } var done bool defer func() { if !done { _ = tx.Rollback() } }() if err := fn(ctx, tx); err != nil { return err } done = true return tx.Commit() } func (db *DB) Begin() (Tx, error) { return db.BeginTx(context.Background(), nil) } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, "BEGIN", nil) tx, err := db.DB.BeginTx(ctx, opts) db.afterQuery(ctx, event, nil, err) if err != nil { return Tx{}, err } return Tx{ ctx: ctx, db: db, Tx: tx, }, nil } func (tx Tx) Commit() error { if tx.name == "" { return tx.commitTX() } return tx.commitSP() } func (tx Tx) commitTX() error { ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, "COMMIT", nil) err := tx.Tx.Commit() tx.db.afterQuery(ctx, event, nil, err) return err } func (tx Tx) commitSP() error { if tx.db.HasFeature(feature.MSSavepoint) { return nil } query := "RELEASE SAVEPOINT " + tx.name _, err := tx.ExecContext(tx.ctx, query) return err } func (tx Tx) Rollback() error { if tx.name == "" { return tx.rollbackTX() } return tx.rollbackSP() } func (tx Tx) rollbackTX() error { ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, "ROLLBACK", nil) err := tx.Tx.Rollback() tx.db.afterQuery(ctx, event, nil, err) return err } func (tx Tx) rollbackSP() error { query := "ROLLBACK TO SAVEPOINT " + tx.name if tx.db.HasFeature(feature.MSSavepoint) { query = "ROLLBACK TRANSACTION " + tx.name } _, err := tx.ExecContext(tx.ctx, query) return err } func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) { return tx.ExecContext(context.TODO(), query, args...) 
} func (tx Tx) ExecContext( ctx context.Context, query string, args ...interface{}, ) (sql.Result, error) { formattedQuery := tx.db.format(query, args) ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) res, err := tx.Tx.ExecContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, res, err) return res, err } func (tx Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { return tx.QueryContext(context.TODO(), query, args...) } func (tx Tx) QueryContext( ctx context.Context, query string, args ...interface{}, ) (*sql.Rows, error) { formattedQuery := tx.db.format(query, args) ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) rows, err := tx.Tx.QueryContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, nil, err) return rows, err } func (tx Tx) QueryRow(query string, args ...interface{}) *sql.Row { return tx.QueryRowContext(context.TODO(), query, args...) } func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { formattedQuery := tx.db.format(query, args) ctx, event := tx.db.beforeQuery(ctx, nil, query, args, formattedQuery, nil) row := tx.Tx.QueryRowContext(ctx, formattedQuery) tx.db.afterQuery(ctx, event, nil, row.Err()) return row } //------------------------------------------------------------------------------ func (tx Tx) Begin() (Tx, error) { return tx.BeginTx(tx.ctx, nil) } // BeginTx will save a point in the running transaction. 
func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) { // mssql savepoint names are limited to 32 characters sp := make([]byte, 14) _, err := rand.Read(sp) if err != nil { return Tx{}, err } qName := "SP_" + hex.EncodeToString(sp) query := "SAVEPOINT " + qName if tx.db.HasFeature(feature.MSSavepoint) { query = "SAVE TRANSACTION " + qName } _, err = tx.ExecContext(ctx, query) if err != nil { return Tx{}, err } return Tx{ ctx: ctx, db: tx.db, Tx: tx.Tx, name: qName, }, nil } func (tx Tx) RunInTx( ctx context.Context, _ *sql.TxOptions, fn func(ctx context.Context, tx Tx) error, ) error { sp, err := tx.BeginTx(ctx, nil) if err != nil { return err } var done bool defer func() { if !done { _ = sp.Rollback() } }() if err := fn(ctx, sp); err != nil { return err } done = true return sp.Commit() } func (tx Tx) Dialect() schema.Dialect { return tx.db.Dialect() } func (tx Tx) NewValues(model interface{}) *ValuesQuery { return NewValuesQuery(tx.db, model).Conn(tx) } func (tx Tx) NewMerge() *MergeQuery { return NewMergeQuery(tx.db).Conn(tx) } func (tx Tx) NewSelect() *SelectQuery { return NewSelectQuery(tx.db).Conn(tx) } func (tx Tx) NewInsert() *InsertQuery { return NewInsertQuery(tx.db).Conn(tx) } func (tx Tx) NewUpdate() *UpdateQuery { return NewUpdateQuery(tx.db).Conn(tx) } func (tx Tx) NewDelete() *DeleteQuery { return NewDeleteQuery(tx.db).Conn(tx) } func (tx Tx) NewRaw(query string, args ...interface{}) *RawQuery { return NewRawQuery(tx.db, query, args...).Conn(tx) } func (tx Tx) NewCreateTable() *CreateTableQuery { return NewCreateTableQuery(tx.db).Conn(tx) } func (tx Tx) NewDropTable() *DropTableQuery { return NewDropTableQuery(tx.db).Conn(tx) } func (tx Tx) NewCreateIndex() *CreateIndexQuery { return NewCreateIndexQuery(tx.db).Conn(tx) } func (tx Tx) NewDropIndex() *DropIndexQuery { return NewDropIndexQuery(tx.db).Conn(tx) } func (tx Tx) NewTruncateTable() *TruncateTableQuery { return NewTruncateTableQuery(tx.db).Conn(tx) } func (tx Tx) 
NewAddColumn() *AddColumnQuery { return NewAddColumnQuery(tx.db).Conn(tx) } func (tx Tx) NewDropColumn() *DropColumnQuery { return NewDropColumnQuery(tx.db).Conn(tx) } //------------------------------------------------------------------------------ func (db *DB) makeQueryBytes() []byte { return internal.MakeQueryBytes() } //------------------------------------------------------------------------------ // ConnResolver enables routing queries to multiple databases. type ConnResolver interface { ResolveConn(query Query) IConn Close() error } // TODO: // - make monitoring interval configurable // - make ping timeout configutable // - allow adding read/write replicas for multi-master replication type ReadWriteConnResolver struct { replicas []*sql.DB // read-only replicas healthyReplicas atomic.Pointer[[]*sql.DB] nextReplica atomic.Int64 closed atomic.Bool } func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver { r := new(ReadWriteConnResolver) for _, opt := range opts { opt(r) } if len(r.replicas) > 0 { r.healthyReplicas.Store(&r.replicas) go r.monitor() } return r } type ReadWriteConnResolverOption func(r *ReadWriteConnResolver) func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption { return func(r *ReadWriteConnResolver) { r.replicas = append(r.replicas, dbs...) } } func (r *ReadWriteConnResolver) Close() error { if r.closed.Swap(true) { return nil } var firstErr error for _, db := range r.replicas { if err := db.Close(); err != nil && firstErr == nil { firstErr = err } } return firstErr } // healthyReplica returns a random healthy replica. 
func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn { if len(r.replicas) == 0 || !isReadOnlyQuery(query) { return nil } replicas := r.loadHealthyReplicas() if len(replicas) == 0 { return nil } if len(replicas) == 1 { return replicas[0] } i := r.nextReplica.Add(1) return replicas[int(i)%len(replicas)] } func isReadOnlyQuery(query Query) bool { sel, ok := query.(*SelectQuery) if !ok { return false } for _, el := range sel.with { if !isReadOnlyQuery(el.query) { return false } } return true } func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB { if ptr := r.healthyReplicas.Load(); ptr != nil { return *ptr } return nil } func (r *ReadWriteConnResolver) monitor() { const interval = 5 * time.Second for !r.closed.Load() { healthy := make([]*sql.DB, 0, len(r.replicas)) for _, replica := range r.replicas { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) err := replica.PingContext(ctx) cancel() if err == nil { healthy = append(healthy, replica) } } r.healthyReplicas.Store(&healthy) time.Sleep(interval) } }