package otelsql

import (
	"context"
	"database/sql/driver"

	"go.opentelemetry.io/otel/trace"
)

// otelStmt wraps a driver.Stmt and runs ExecContext and QueryContext inside
// spans created by dbInstrum.withSpan. Any driver.Stmt method not overridden
// here is served by the embedded statement.
type otelStmt struct {
	driver.Stmt

	query   string // SQL text handed to withSpan for every span of this statement
	instrum *dbInstrum

	// execCtx and queryCtx are built once in newStmt so the interface checks
	// against the wrapped statement are not repeated on every call.
	execCtx  stmtExecCtxFunc
	queryCtx stmtQueryCtxFunc
}

var _ driver.Stmt = (*otelStmt)(nil)

// newStmt wraps stmt, prepared from query, so that its executions are traced
// through instrum.
func newStmt(stmt driver.Stmt, query string, instrum *dbInstrum) *otelStmt {
	s := &otelStmt{
		Stmt:    stmt,
		query:   query,
		instrum: instrum,
	}
	s.execCtx = s.createExecCtxFunc(stmt)
	s.queryCtx = s.createQueryCtxFunc(stmt)
	return s
}

//------------------------------------------------------------------------------

var _ driver.StmtExecContext = (*otelStmt)(nil)

// ExecContext implements driver.StmtExecContext by running the wrapped
// statement inside a "stmt.Exec" span.
func (s *otelStmt) ExecContext(
	ctx context.Context, args []driver.NamedValue,
) (driver.Result, error) {
	return s.execCtx(ctx, args)
}

// stmtExecCtxFunc has the signature of driver.StmtExecContext.ExecContext.
type stmtExecCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error)

// createExecCtxFunc builds the traced exec function for stmt.
//
// If stmt implements driver.StmtExecContext, its ExecContext is called
// directly. Otherwise the arguments are converted to plain values and the
// legacy Exec method is used, after first checking that ctx is not already
// done. When the result is available and the span is recording, the number of
// affected rows is attached as an attribute.
func (s *otelStmt) createExecCtxFunc(stmt driver.Stmt) stmtExecCtxFunc {
	var fn stmtExecCtxFunc
	if execer, ok := stmt.(driver.StmtExecContext); ok {
		fn = execer.ExecContext
	} else {
		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
			vArgs, err := namedValueToValue(args)
			if err != nil {
				return nil, err
			}
			// Because otelStmt implements StmtExecContext, database/sql no
			// longer performs its own cancellation check before falling back
			// to the legacy Exec, so replicate that check here.
			select {
			default:
			case <-ctx.Done():
				return nil, ctx.Err()
			}
			return stmt.Exec(vArgs)
		}
	}

	return func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
		var res driver.Result
		err := s.instrum.withSpan(ctx, "stmt.Exec", s.query,
			func(ctx context.Context, span trace.Span) error {
				var err error
				res, err = fn(ctx, args)
				if err != nil {
					return err
				}
				if span.IsRecording() {
					// RowsAffected may be unsupported by the driver; the
					// attribute is best-effort, so its error is not reported.
					if n, rowsErr := res.RowsAffected(); rowsErr == nil {
						span.SetAttributes(dbRowsAffected.Int64(n))
					}
				}
				return nil
			})
		return res, err
	}
}

//------------------------------------------------------------------------------

var _ driver.StmtQueryContext = (*otelStmt)(nil)

// QueryContext implements driver.StmtQueryContext by running the wrapped
// statement inside a "stmt.Query" span.
func (s *otelStmt) QueryContext(
	ctx context.Context, args []driver.NamedValue,
) (driver.Rows, error) {
	return s.queryCtx(ctx, args)
}

// stmtQueryCtxFunc has the signature of driver.StmtQueryContext.QueryContext.
type stmtQueryCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error)

// createQueryCtxFunc builds the traced query function for stmt.
//
// If stmt implements driver.StmtQueryContext, its QueryContext is called
// directly. Otherwise the arguments are converted to plain values and the
// legacy Query method is used, after first checking that ctx is not already
// done. The span wraps only the call that produces the rows; it ends before
// the caller iterates them.
func (s *otelStmt) createQueryCtxFunc(stmt driver.Stmt) stmtQueryCtxFunc {
	var fn stmtQueryCtxFunc
	if queryer, ok := stmt.(driver.StmtQueryContext); ok {
		fn = queryer.QueryContext
	} else {
		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
			vArgs, err := namedValueToValue(args)
			if err != nil {
				return nil, err
			}
			// Because otelStmt implements StmtQueryContext, database/sql no
			// longer performs its own cancellation check before falling back
			// to the legacy Query, so replicate that check here.
			select {
			default:
			case <-ctx.Done():
				return nil, ctx.Err()
			}
			return stmt.Query(vArgs)
		}
	}

	return func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
		var rows driver.Rows
		err := s.instrum.withSpan(ctx, "stmt.Query", s.query,
			func(ctx context.Context, span trace.Span) error {
				var err error
				rows, err = fn(ctx, args)
				return err
			})
		return rows, err
	}
}