gotosocial/vendor/github.com/uptrace/bun/query_base.go
2021-09-29 15:09:45 +02:00

899 lines
19 KiB
Go

package bun
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
// Bit flags stored in baseQuery.flags. Each constant occupies its own bit
// because the first one is declared as 1 << iota.
const (
// wherePKFlag makes appendWhere add a condition on the table primary keys.
wherePKFlag internal.Flag = 1 << iota
// forceDeleteFlag is not read anywhere in the visible part of this file.
// NOTE(review): presumably used by delete queries to bypass soft deletes — confirm.
forceDeleteFlag
// deletedFlag switches the soft delete condition from IS NULL to IS NOT NULL.
deletedFlag
// allWithDeletedFlag makes isSoftDelete report false, which drops the soft
// delete condition from the WHERE clause.
allWithDeletedFlag
)
// withQuery is one named query that appendWith renders as an entry of a
// `WITH name AS (query)` clause.
type withQuery struct {
// name is the identifier written before AS.
name string
// query is rendered inside the parentheses after AS.
query schema.QueryAppender
}
// IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx.
// It contains the three context-aware methods the queries in this file need:
// QueryContext, ExecContext and QueryRowContext.
type IConn interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// Compile-time checks that the database/sql types and the bun wrappers
// (*DB, *Conn, *Tx) all satisfy IConn.
var (
_ IConn = (*sql.DB)(nil)
_ IConn = (*sql.Conn)(nil)
_ IConn = (*sql.Tx)(nil)
_ IConn = (*DB)(nil)
_ IConn = (*Conn)(nil)
_ IConn = (*Tx)(nil)
)
// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx.
// It embeds IConn and adds one constructor per query type.
type IDB interface {
IConn
NewValues(model interface{}) *ValuesQuery
NewSelect() *SelectQuery
NewInsert() *InsertQuery
NewUpdate() *UpdateQuery
NewDelete() *DeleteQuery
NewCreateTable() *CreateTableQuery
NewDropTable() *DropTableQuery
NewCreateIndex() *CreateIndexQuery
NewDropIndex() *DropIndexQuery
NewTruncateTable() *TruncateTableQuery
NewAddColumn() *AddColumnQuery
NewDropColumn() *DropColumnQuery
}
// Compile-time checks that *DB, *Conn and *Tx satisfy IConn.
// NOTE(review): these repeat the IConn assertions declared earlier in this
// file; given their position right after IDB they were presumably meant to
// assert IDB instead — confirm against the method sets of Conn and Tx.
var (
_ IConn = (*DB)(nil)
_ IConn = (*Conn)(nil)
_ IConn = (*Tx)(nil)
)
// baseQuery holds the state shared by the query helpers in this file:
// the connection, the model, the tables/columns to render and option flags.
type baseQuery struct {
// db is passed to the model constructors and provides the
// beforeQuery/afterQuery hooks used by scan and exec.
db *DB
// conn executes the query; setConn stores an unwrapped connection here.
conn IConn
// model is set by setTableModel and returned by getModel when no
// destination is given.
model Model
// err is the first error recorded with setErr.
err error
// tableModel and table are set by setTableModel when the model
// implements TableModel.
tableModel TableModel
table *schema.Table
// with lists the entries of the WITH clause.
with []withQuery
// modelTable, when non-zero, is rendered instead of the table derived
// from the model.
modelTable schema.QueryWithArgs
// tables are additional tables rendered after the model table.
tables []schema.QueryWithArgs
// columns is the explicit column list; when empty, getFields falls back
// to the table fields.
columns []schema.QueryWithArgs
// flags holds the wherePKFlag/deletedFlag/... bits.
flags internal.Flag
}
// DB returns the *DB stored in the query.
func (q *baseQuery) DB() *DB {
return q.db
}
// query exposes the model and table name of a query. GetTableName uses it to
// inspect the queries stored in the WITH clause.
type query interface {
GetModel() Model
GetTableName() string
}
// Compile-time check that *baseQuery implements query.
var _ query = (*baseQuery)(nil)
// GetModel returns the model stored in the query; it may be nil.
func (q *baseQuery) GetModel() Model {
return q.model
}
// GetTableName returns a table name for the query. Candidates are tried in
// this order: the name of the model table, the table name of the first WITH
// query that has a model, the explicit model table expression, and finally
// the first additional table. An empty string is returned when none apply.
func (q *baseQuery) GetTableName() string {
	if tbl := q.table; tbl != nil {
		return tbl.Name
	}

	for idx := range q.with {
		inner, isQuery := q.with[idx].query.(query)
		if !isQuery || inner.GetModel() == nil {
			continue
		}
		return inner.GetTableName()
	}

	switch {
	case q.modelTable.Query != "":
		return q.modelTable.Query
	case len(q.tables) > 0:
		return q.tables[0].Query
	default:
		return ""
	}
}
// setConn stores the connection used to run the query. Bun wrappers (*DB,
// Conn, Tx) are replaced by the connection they embed so that query hooks are
// not called twice; any other IConn is stored as is.
func (q *baseQuery) setConn(db IConn) {
	switch wrapped := db.(type) {
	case *DB:
		q.conn = wrapped.DB
	case Conn:
		q.conn = wrapped.Conn
	case Tx:
		q.conn = wrapped.Tx
	default:
		q.conn = wrapped
	}
}
// setTableModel builds a model from modeli and stores it in the query. When
// the model implements TableModel, the table model and its table are stored
// as well. A construction error is recorded with setErr and nothing else is
// changed.
//
// TODO: rename to setModel
func (q *baseQuery) setTableModel(modeli interface{}) {
	built, err := newSingleModel(q.db, modeli)
	if err != nil {
		q.setErr(err)
		return
	}
	q.model = built

	tm, isTableModel := built.(TableModel)
	if !isTableModel {
		return
	}
	q.tableModel = tm
	q.table = tm.Table()
}
// setErr records err unless an error has already been recorded, so the first
// error wins.
func (q *baseQuery) setErr(err error) {
	if q.err != nil {
		return
	}
	q.err = err
}
// getModel returns the model to scan into. A non-empty dest gets a new model
// built from it; otherwise the query model is returned, or errNilModel when
// the query has no model.
func (q *baseQuery) getModel(dest []interface{}) (Model, error) {
	switch {
	case len(dest) > 0:
		return newModel(q.db, dest)
	case q.model != nil:
		return q.model, nil
	default:
		return nil, errNilModel
	}
}
//------------------------------------------------------------------------------
// checkSoftDelete reports whether soft delete options can be applied to the
// query: it needs a table, a soft delete field on that table, and a table
// model. The first missing prerequisite is returned as an error.
func (q *baseQuery) checkSoftDelete() error {
	switch {
	case q.table == nil:
		return errors.New("bun: can't use soft deletes without a table")
	case q.table.SoftDeleteField == nil:
		return fmt.Errorf("%s does not have a soft delete field", q.table)
	case q.tableModel == nil:
		return errors.New("bun: can't use soft deletes without a table model")
	}
	return nil
}
// whereDeleted makes the query match only soft deleted rows
// (`WHERE deleted_at IS NOT NULL`). It sets deletedFlag and clears
// allWithDeletedFlag. If the model does not support soft deletes, the error
// is recorded and the flags are left untouched.
func (q *baseQuery) whereDeleted() {
	err := q.checkSoftDelete()
	if err != nil {
		q.setErr(err)
		return
	}
	q.flags = q.flags.Set(deletedFlag).Remove(allWithDeletedFlag)
}
// whereAllWithDeleted makes the query return all rows including soft deleted
// ones. It sets allWithDeletedFlag and clears deletedFlag. If the model does
// not support soft deletes, the error is recorded and the flags are left
// untouched.
func (q *baseQuery) whereAllWithDeleted() {
	err := q.checkSoftDelete()
	if err != nil {
		q.setErr(err)
		return
	}
	q.flags = q.flags.Set(allWithDeletedFlag).Remove(deletedFlag)
}
// isSoftDelete reports whether a soft delete condition must be added to the
// WHERE clause: the table has a soft delete field and allWithDeletedFlag is
// not set.
func (q *baseQuery) isSoftDelete() bool {
	if q.table == nil {
		return false
	}
	if q.table.SoftDeleteField == nil {
		return false
	}
	return !q.flags.Has(allWithDeletedFlag)
}
//------------------------------------------------------------------------------
// addWith appends a named query to the list rendered by appendWith.
func (q *baseQuery) addWith(name string, query schema.QueryAppender) {
	entry := withQuery{name: name, query: query}
	q.with = append(q.with, entry)
}
// appendWith appends the WITH clause to b, followed by a single space.
// Entries are comma separated and rendered as `name AS (query)`; when the
// entry's query implements schema.ColumnsAppender, its column list is written
// in parentheses between the name and AS. Without any WITH entries b is
// returned unchanged.
func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) {
	if len(q.with) == 0 {
		return b, nil
	}

	b = append(b, "WITH "...)
	for idx := range q.with {
		cte := q.with[idx]
		if idx != 0 {
			b = append(b, ", "...)
		}
		b = fmter.AppendIdent(b, cte.name)

		if cols, hasCols := cte.query.(schema.ColumnsAppender); hasCols {
			b = append(b, " ("...)
			if b, err = cols.AppendColumns(fmter, b); err != nil {
				return nil, err
			}
			b = append(b, ')')
		}

		b = append(b, " AS ("...)
		if b, err = cte.query.AppendQuery(fmter, b); err != nil {
			return nil, err
		}
		b = append(b, ')')
	}
	return append(b, ' '), nil
}
//------------------------------------------------------------------------------
// addTable appends table to the list of additional tables.
func (q *baseQuery) addTable(table schema.QueryWithArgs) {
q.tables = append(q.tables, table)
}
// addColumn appends column to the explicit column list.
func (q *baseQuery) addColumn(column schema.QueryWithArgs) {
q.columns = append(q.columns, column)
}
// excludeColumn removes the given columns from the query column list.
//
// When no columns have been selected yet, the list is first populated with
// every field of the model table. That step needs a table: if the query has
// none, an error is recorded with setErr instead of dereferencing a nil
// table. The single column "*" clears the list. A column that is not in the
// list is recorded as an error and stops processing.
func (q *baseQuery) excludeColumn(columns []string) {
	if q.columns == nil {
		if q.table == nil {
			q.setErr(errors.New("bun: can't exclude columns without a table"))
			return
		}
		for _, f := range q.table.Fields {
			q.columns = append(q.columns, schema.UnsafeIdent(f.Name))
		}
	}

	if len(columns) == 1 && columns[0] == "*" {
		// Keep the slice non-nil so that a later call does not repopulate it
		// from the table fields.
		q.columns = make([]schema.QueryWithArgs, 0)
		return
	}

	for _, column := range columns {
		if !q._excludeColumn(column) {
			q.setErr(fmt.Errorf("bun: can't find column=%q", column))
			return
		}
	}
}
// _excludeColumn removes the first column without arguments whose query text
// equals column. It reports whether such a column was found.
func (q *baseQuery) _excludeColumn(column string) bool {
	for idx := 0; idx < len(q.columns); idx++ {
		candidate := q.columns[idx]
		if candidate.Args != nil || candidate.Query != column {
			continue
		}
		q.columns = append(q.columns[:idx], q.columns[idx+1:]...)
		return true
	}
	return false
}
//------------------------------------------------------------------------------
// modelHasTableName reports whether the model contributes a table to the
// query. An explicit model table expression decides on its own (it must have
// non-empty query text); otherwise the model table must be set.
func (q *baseQuery) modelHasTableName() bool {
	if q.modelTable.IsZero() {
		return q.table != nil
	}
	return q.modelTable.Query != ""
}
// hasTables reports whether the query has at least one table, either from the
// model or from the additional tables.
func (q *baseQuery) hasTables() bool {
	if q.modelHasTableName() {
		return true
	}
	return len(q.tables) > 0
}
// appendTables appends all query tables to b without a table alias.
// It delegates to _appendTables with withAlias set to false.
func (q *baseQuery) appendTables(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
return q._appendTables(fmter, b, false)
}
// appendTablesWithAlias appends all query tables to b, allowing an alias for
// the model table. It delegates to _appendTables with withAlias set to true.
func (q *baseQuery) appendTablesWithAlias(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
return q._appendTables(fmter, b, true)
}
// _appendTables appends the model table (if any) followed by the additional
// tables, separated by ", ".
//
// The model table is rendered from the explicit model table expression when
// one is set; otherwise the table's SQLNameForSelects is used and, when
// withAlias is true and the alias differs from that name, ` AS alias` is
// added.
func (q *baseQuery) _appendTables(
	fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
	base := len(b)

	switch {
	case !q.modelHasTableName():
		// The model does not contribute a table.
	case !q.modelTable.IsZero():
		if b, err = q.modelTable.AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	default:
		tbl := q.table
		b = fmter.AppendQuery(b, string(tbl.SQLNameForSelects))
		if withAlias && tbl.SQLAlias != tbl.SQLNameForSelects {
			b = append(b, " AS "...)
			b = append(b, tbl.SQLAlias...)
		}
	}

	for idx := range q.tables {
		// A separator is needed as soon as anything was written by this call.
		if len(b) > base {
			b = append(b, ", "...)
		}
		if b, err = q.tables[idx].AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	return b, nil
}
// appendFirstTable appends the first table of the query to b without an
// alias. It delegates to _appendFirstTable with withAlias set to false.
func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) {
return q._appendFirstTable(fmter, b, false)
}
// appendFirstTableWithAlias appends the first table of the query to b with
// an alias. It delegates to _appendFirstTable with withAlias set to true.
func (q *baseQuery) appendFirstTableWithAlias(
fmter schema.Formatter, b []byte,
) ([]byte, error) {
return q._appendFirstTable(fmter, b, true)
}
// _appendFirstTable appends the first table of the query. The explicit model
// table expression has priority, then the model table (rendered from SQLName,
// with ` AS alias` when withAlias is true), then the first additional table.
// An error is returned when the query has no table at all.
func (q *baseQuery) _appendFirstTable(
	fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
	switch {
	case !q.modelTable.IsZero():
		return q.modelTable.AppendQuery(fmter, b)
	case q.table != nil:
		b = fmter.AppendQuery(b, string(q.table.SQLName))
		if withAlias {
			b = append(b, " AS "...)
			b = append(b, q.table.SQLAlias...)
		}
		return b, nil
	case len(q.tables) > 0:
		return q.tables[0].AppendQuery(fmter, b)
	}
	return nil, errors.New("bun: query does not have a table")
}
// hasMultiTables reports whether the query has more than one table. When the
// model contributes a table, one additional table is enough; otherwise two
// additional tables are required.
func (q *baseQuery) hasMultiTables() bool {
	required := 2
	if q.modelHasTableName() {
		required = 1
	}
	return len(q.tables) >= required
}
// appendOtherTables appends every table except the first one, separated by
// ", ". When the model does not contribute a table, the first additional
// table counts as the first table and is skipped; in that case the list of
// additional tables must not be empty.
func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
	rest := q.tables
	if !q.modelHasTableName() {
		rest = rest[1:]
	}

	for idx := range rest {
		if idx != 0 {
			b = append(b, ", "...)
		}
		if b, err = rest[idx].AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	return b, nil
}
//------------------------------------------------------------------------------
// appendColumns appends the explicit column list to b, separated by ", ".
// Nothing is written when the list is empty.
func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) {
	for idx := range q.columns {
		if idx != 0 {
			b = append(b, ", "...)
		}
		if b, err = q.columns[idx].AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	return b, nil
}
// getFields returns the table fields the query operates on. With an explicit
// column list the fields are resolved from it; otherwise all table fields
// are returned. An error is returned instead of dereferencing a nil table
// when the query has neither explicit columns nor a table.
func (q *baseQuery) getFields() ([]*schema.Field, error) {
	if len(q.columns) > 0 {
		return q._getFields(false)
	}
	if q.table == nil {
		return nil, errors.New("bun: query does not have a table")
	}
	return q.table.Fields, nil
}
// getDataFields returns the data fields the query operates on. With an
// explicit column list the fields are resolved from it with primary keys
// omitted; otherwise the table's DataFields are returned. An error is
// returned instead of dereferencing a nil table when the query has neither
// explicit columns nor a table.
func (q *baseQuery) getDataFields() ([]*schema.Field, error) {
	if len(q.columns) > 0 {
		return q._getFields(true)
	}
	if q.table == nil {
		return nil, errors.New("bun: query does not have a table")
	}
	return q.table.DataFields, nil
}
// _getFields resolves the explicit column list to table fields. Columns that
// carry arguments are skipped; each remaining column is looked up by its
// query text and a failed lookup aborts with that error. When omitPK is true,
// primary key fields are left out of the result.
func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) {
	resolved := make([]*schema.Field, 0, len(q.columns))
	for idx := range q.columns {
		col := q.columns[idx]
		if col.Args != nil {
			continue
		}

		f, err := q.table.Field(col.Query)
		if err != nil {
			return nil, err
		}
		if !(omitPK && f.IsPK) {
			resolved = append(resolved, f)
		}
	}
	return resolved, nil
}
// scan runs query on the query connection and scans the returned rows into
// model. The beforeQuery hook is called first and afterQuery is called
// exactly once on every path.
//
// The result reports the number of scanned rows. sql.ErrNoRows is returned
// together with that result when no rows were scanned, hasDest is true and
// model is a single row model.
func (q *baseQuery) scan(
	ctx context.Context,
	iquery IQuery,
	query string,
	model Model,
	hasDest bool,
) (sql.Result, error) {
	hookCtx, ev := q.db.beforeQuery(ctx, iquery, query, nil, q.model)

	rows, err := q.conn.QueryContext(hookCtx, query)
	if err != nil {
		q.db.afterQuery(hookCtx, ev, nil, err)
		return nil, err
	}
	defer rows.Close()

	scanned, err := model.ScanRows(hookCtx, rows)
	if err != nil {
		q.db.afterQuery(hookCtx, ev, nil, err)
		return nil, err
	}

	var resErr error
	if scanned == 0 && hasDest && isSingleRowModel(model) {
		resErr = sql.ErrNoRows
	}

	res := driver.RowsAffected(scanned)
	q.db.afterQuery(hookCtx, ev, res, resErr)
	return res, resErr
}
// exec runs query on the query connection without scanning rows. The
// beforeQuery hook is called first; afterQuery receives the result on
// success and a nil result together with the error on failure. On failure
// the result returned by the connection is still passed back to the caller.
func (q *baseQuery) exec(
	ctx context.Context,
	iquery IQuery,
	query string,
) (sql.Result, error) {
	hookCtx, ev := q.db.beforeQuery(ctx, iquery, query, nil, q.model)

	result, execErr := q.conn.ExecContext(hookCtx, query)
	if execErr == nil {
		q.db.afterQuery(hookCtx, ev, result, nil)
		return result, nil
	}

	q.db.afterQuery(hookCtx, ev, nil, execErr)
	return result, execErr
}
//------------------------------------------------------------------------------
// AppendNamedArg appends the value of the named argument to b and reports
// whether the name was recognized. Without a table nothing is recognized.
//
// A struct table model gets the first chance to resolve the name. After that
// the built-in names are handled: TableName, TableAlias, PKs, TablePKs,
// Columns and TableColumns, where the Table-prefixed column lists are
// qualified with the table alias.
func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) {
	tbl := q.table
	if tbl == nil {
		return b, false
	}

	if structModel, isStruct := q.tableModel.(*structTableModel); isStruct {
		if out, found := structModel.AppendNamedArg(fmter, b, name); found {
			return out, true
		}
	}

	switch name {
	case "TableName":
		return fmter.AppendQuery(b, string(tbl.SQLName)), true
	case "TableAlias":
		return fmter.AppendQuery(b, string(tbl.SQLAlias)), true
	case "PKs":
		return appendColumns(b, "", tbl.PKs), true
	case "TablePKs":
		return appendColumns(b, tbl.SQLAlias, tbl.PKs), true
	case "Columns":
		return appendColumns(b, "", tbl.Fields), true
	case "TableColumns":
		return appendColumns(b, tbl.SQLAlias, tbl.Fields), true
	default:
		return b, false
	}
}
// appendColumns appends the SQL names of fields to b, separated by ", ".
// When table is not empty, every name is prefixed with `table.`.
func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte {
	qualify := len(table) > 0
	for idx := range fields {
		if idx != 0 {
			b = append(b, ", "...)
		}
		if qualify {
			b = append(b, table...)
			b = append(b, '.')
		}
		b = append(b, fields[idx].SQLName...)
	}
	return b
}
// formatterWithModel returns a formatter that has model registered as an
// argument. A no-op formatter is returned unchanged.
func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter {
	if !fmter.IsNop() {
		return fmter.WithArg(model)
	}
	return fmter
}
//------------------------------------------------------------------------------
// whereBaseQuery extends baseQuery with a list of WHERE conditions.
type whereBaseQuery struct {
baseQuery
// where holds the conditions together with the separator written before
// each of them.
where []schema.QueryWithSep
}
// addWhere appends a single condition to the WHERE list.
func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) {
q.where = append(q.where, where)
}
// addWhereGroup appends the conditions in where as a parenthesized group that
// is joined to the preceding conditions with sep. The group is encoded with
// entries that have empty query text and carry sep, "(" and ")" as their
// separators. The separator of the first condition is cleared in place so
// that nothing is written right after the opening parenthesis. An empty
// group is ignored.
func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) {
	if len(where) == 0 {
		return
	}

	q.where = append(
		q.where,
		schema.SafeQueryWithSep("", nil, sep),
		schema.SafeQueryWithSep("", nil, "("),
	)

	where[0].Sep = ""
	q.where = append(q.where, where...)
	q.where = append(q.where, schema.SafeQueryWithSep("", nil, ")"))
}
// mustAppendWhere appends the WHERE clause like appendWhere, but fails when
// the query has neither an explicit condition nor wherePKFlag set.
func (q *whereBaseQuery) mustAppendWhere(
	fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
	if len(q.where) > 0 || q.flags.Has(wherePKFlag) {
		return q.appendWhere(fmter, b, withAlias)
	}
	return nil, errors.New("bun: Update and Delete queries require at least one Where")
}
// appendWhere appends the WHERE clause to b. The clause is built from up to
// three parts joined with " AND ": the explicit conditions, the soft delete
// condition, and the primary key condition requested by wherePKFlag. When
// none of them apply, b is returned unchanged.
//
// withAlias controls whether the soft delete column and the primary key
// columns are qualified with the table alias.
func (q *whereBaseQuery) appendWhere(
	fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
	softDelete := q.isSoftDelete()
	wherePK := q.flags.Has(wherePKFlag)
	if len(q.where) == 0 && !softDelete && !wherePK {
		return b, nil
	}

	b = append(b, " WHERE "...)
	mark := len(b)

	if len(q.where) > 0 {
		if b, err = appendWhere(fmter, b, q.where); err != nil {
			return nil, err
		}
	}

	if softDelete {
		if len(b) > mark {
			b = append(b, " AND "...)
		}

		tbl := q.tableModel.Table()
		if withAlias {
			b = append(b, tbl.SQLAlias...)
			b = append(b, '.')
		}
		b = append(b, tbl.SoftDeleteField.SQLName...)

		// deletedFlag selects only the soft deleted rows.
		predicate := " IS NULL"
		if q.flags.Has(deletedFlag) {
			predicate = " IS NOT NULL"
		}
		b = append(b, predicate...)
	}

	if wherePK {
		if len(b) > mark {
			b = append(b, " AND "...)
		}
		if b, err = q.appendWherePK(fmter, b, withAlias); err != nil {
			return nil, err
		}
	}

	return b, nil
}
// appendWhere appends the conditions in where to b. The separator of every
// condition except the first is written before it. Conditions with empty
// query text contribute only their separator; all others are wrapped in
// parentheses.
func appendWhere(
	fmter schema.Formatter, b []byte, where []schema.QueryWithSep,
) (_ []byte, err error) {
	for idx := range where {
		cond := where[idx]
		if idx != 0 {
			b = append(b, cond.Sep...)
		}
		if cond.Query == "" {
			continue
		}

		b = append(b, '(')
		if b, err = cond.AppendQuery(fmter, b); err != nil {
			return nil, err
		}
		b = append(b, ')')
	}
	return b, nil
}
// appendWherePK appends a condition on the table primary keys taken from the
// model. It requires a table that passes CheckPKs and supports struct and
// slice table models; any other model is rejected with an error.
func (q *whereBaseQuery) appendWherePK(
	fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
	if q.table == nil {
		return nil, fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
	}
	if pkErr := q.table.CheckPKs(); pkErr != nil {
		return nil, pkErr
	}

	if structModel, isStruct := q.tableModel.(*structTableModel); isStruct {
		return q.appendWherePKStruct(fmter, b, structModel, withAlias)
	}
	if sliceModel, isSlice := q.tableModel.(*sliceTableModel); isSlice {
		return q.appendWherePKSlice(fmter, b, sliceModel, withAlias)
	}
	return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel)
}
// appendWherePKStruct appends `(pk1 = v1 AND pk2 = v2 ...)` using the primary
// key values of the struct model. Columns are qualified with the table alias
// when withAlias is true. With a no-op formatter each value is written as a
// `?` placeholder. errNilModel is returned when the struct value is invalid.
func (q *whereBaseQuery) appendWherePKStruct(
	fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool,
) (_ []byte, err error) {
	if !model.strct.IsValid() {
		return nil, errNilModel
	}

	placeholders := fmter.IsNop()

	b = append(b, '(')
	for idx, pk := range q.table.PKs {
		if idx != 0 {
			b = append(b, " AND "...)
		}
		if withAlias {
			b = append(b, q.table.SQLAlias...)
			b = append(b, '.')
		}
		b = append(b, pk.SQLName...)
		b = append(b, " = "...)

		if placeholders {
			b = append(b, '?')
			continue
		}
		b = pk.AppendValue(fmter, b, model.strct)
	}
	return append(b, ')'), nil
}
// appendWherePKSlice appends `pk IN (v1, v2, ...)` using the primary key
// values of every element of the slice model. With more than one primary key
// the column list and each value tuple are wrapped in parentheses. Columns
// are qualified with the table alias when withAlias is true. With a no-op
// formatter only the first element is rendered and its values are written as
// `?` placeholders.
func (q *whereBaseQuery) appendWherePKSlice(
	fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool,
) (_ []byte, err error) {
	pks := q.table.PKs
	composite := len(pks) > 1

	var prefix schema.Safe
	if withAlias {
		prefix = q.table.SQLAlias
	}

	if composite {
		b = append(b, '(')
	}
	b = appendColumns(b, prefix, pks)
	if composite {
		b = append(b, ')')
	}

	b = append(b, " IN ("...)

	placeholders := fmter.IsNop()
	count := model.slice.Len()
	if placeholders && count > 1 {
		// A template needs only a single tuple of placeholders.
		count = 1
	}

	for row := 0; row < count; row++ {
		if row != 0 {
			b = append(b, ", "...)
		}

		elem := indirect(model.slice.Index(row))

		if composite {
			b = append(b, '(')
		}
		for col, pk := range pks {
			if col != 0 {
				b = append(b, ", "...)
			}
			if placeholders {
				b = append(b, '?')
				continue
			}
			b = pk.AppendValue(fmter, b, elem)
		}
		if composite {
			b = append(b, ')')
		}
	}

	return append(b, ')'), nil
}
//------------------------------------------------------------------------------
// returningQuery holds the state of a RETURNING clause.
type returningQuery struct {
// returning holds explicit RETURNING expressions; when non-empty they
// are rendered instead of returningFields.
returning []schema.QueryWithArgs
// returningFields holds fields added with addReturningField.
returningFields []*schema.Field
}
// addReturning appends an explicit RETURNING expression.
func (q *returningQuery) addReturning(ret schema.QueryWithArgs) {
q.returning = append(q.returning, ret)
}
// addReturningField adds field to the fields rendered in the RETURNING
// clause. The call is ignored when explicit RETURNING expressions exist or
// when the same field pointer has already been added.
func (q *returningQuery) addReturningField(field *schema.Field) {
	if len(q.returning) != 0 {
		return
	}
	for idx := range q.returningFields {
		if q.returningFields[idx] == field {
			return
		}
	}
	q.returningFields = append(q.returningFields, field)
}
// hasReturning reports whether a RETURNING clause must be rendered. A single
// explicit expression spelled "null" or "NULL" disables the clause; otherwise
// any explicit expression or returning field enables it.
func (q *returningQuery) hasReturning() bool {
	if len(q.returning) == 1 {
		if expr := q.returning[0].Query; expr == "null" || expr == "NULL" {
			return false
		}
	}
	if len(q.returning) > 0 {
		return true
	}
	return len(q.returningFields) > 0
}
// appendReturning appends the RETURNING clause to b. Explicit expressions are
// rendered when present; otherwise the SQL names of the returning fields are
// listed. b is returned unchanged when hasReturning reports false.
func (q *returningQuery) appendReturning(
	fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
	if !q.hasReturning() {
		return b, nil
	}

	b = append(b, " RETURNING "...)
	if len(q.returning) == 0 {
		return appendColumns(b, "", q.returningFields), nil
	}

	for idx := range q.returning {
		if idx != 0 {
			b = append(b, ", "...)
		}
		if b, err = q.returning[idx].AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	return b, nil
}
//------------------------------------------------------------------------------
// columnValue pairs a column name with the value expression assigned to it.
type columnValue struct {
column string
value schema.QueryWithArgs
}
// customValueQuery stores custom value expressions registered with addValue.
type customValueQuery struct {
// modelValues is keyed by column name and holds values for columns found
// in the table's FieldMap; it is allocated on first use.
modelValues map[string]schema.QueryWithArgs
// extraValues holds values for columns that are not in the FieldMap, in
// the order they were added.
extraValues []columnValue
}
// addValue registers a custom value expression for column. Columns present
// in the table's FieldMap are stored in modelValues, replacing an earlier
// value for the same column; all other columns are appended to extraValues.
func (q *customValueQuery) addValue(
	table *schema.Table, column string, value string, args []interface{},
) {
	_, isModelField := table.FieldMap[column]
	expr := schema.SafeQuery(value, args)

	if !isModelField {
		q.extraValues = append(q.extraValues, columnValue{column: column, value: expr})
		return
	}

	if q.modelValues == nil {
		q.modelValues = make(map[string]schema.QueryWithArgs)
	}
	q.modelValues[column] = expr
}
//------------------------------------------------------------------------------
// setQuery holds the expressions rendered by appendSet.
type setQuery struct {
set []schema.QueryWithArgs
}
// addSet appends one expression to the set list.
func (q *setQuery) addSet(set schema.QueryWithArgs) {
q.set = append(q.set, set)
}
// appendSet appends the set expressions to b, separated by ", ". Nothing is
// written when there are no expressions.
func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) {
	for idx := range q.set {
		if idx != 0 {
			b = append(b, ", "...)
		}
		if b, err = q.set[idx].AppendQuery(fmter, b); err != nil {
			return nil, err
		}
	}
	return b, nil
}
//------------------------------------------------------------------------------
// cascadeQuery selects the keyword written by appendCascade.
type cascadeQuery struct {
// restrict selects RESTRICT; when false, CASCADE is written.
restrict bool
}
// appendCascade appends " RESTRICT" or " CASCADE" to b depending on the
// restrict field. b is returned unchanged when the formatter does not report
// the feature.TableCascade feature.
func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte {
	if !fmter.HasFeature(feature.TableCascade) {
		return b
	}

	keyword := " CASCADE"
	if q.restrict {
		keyword = " RESTRICT"
	}
	return append(b, keyword...)
}