diff --git a/src/shader_recompiler/frontend/ir/value.h b/src/shader_recompiler/frontend/ir/value.h index 795194d41..334bb47aa 100644 --- a/src/shader_recompiler/frontend/ir/value.h +++ b/src/shader_recompiler/frontend/ir/value.h @@ -57,6 +57,7 @@ public: [[nodiscard]] IR::Inst* Inst() const; [[nodiscard]] IR::Inst* InstRecursive() const; + [[nodiscard]] IR::Inst* TryInstRecursive() const; [[nodiscard]] IR::Value Resolve() const; [[nodiscard]] IR::Reg Reg() const; [[nodiscard]] IR::Pred Pred() const; @@ -308,6 +309,13 @@ inline IR::Inst* Value::InstRecursive() const { return inst; } +inline IR::Inst* Value::TryInstRecursive() const { + if (IsIdentity()) { + return inst->Arg(0).TryInstRecursive(); + } + return type == Type::Opaque ? inst : nullptr; +} + inline IR::Value Value::Resolve() const { if (IsIdentity()) { return inst->Arg(0).Resolve(); diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp index 8dd6d6c2c..c403a5fae 100644 --- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp @@ -3,6 +3,7 @@ // Refer to the license.txt file included. #include +#include #include #include @@ -88,6 +89,26 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) { return true; } +/// Return true when all values in a range are equal +template +bool AreEqual(const Range& range) { + auto resolver{[](const auto& value) { return value.Resolve(); }}; + auto equal{[](const IR::Value& lhs, const IR::Value& rhs) { + if (lhs == rhs) { + return true; + } + // Not equal, but try to match if they read the same constant buffer + if (!lhs.IsImmediate() && !rhs.IsImmediate() && + lhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 && + rhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 && + lhs.Inst()->Arg(0) == rhs.Inst()->Arg(0) && lhs.Inst()->Arg(1) == rhs.Inst()->Arg(1)) { + return true; + } + return false; + }}; + return std::ranges::adjacent_find(range, std::not_fn(equal), resolver) == std::end(range); +} + void FoldGetRegister(IR::Inst& inst) { if (inst.Arg(0).Reg() == IR::Reg::RZ) { inst.ReplaceUsesWith(IR::Value{u32{0}}); @@ -100,6 +121,157 @@ void FoldGetPred(IR::Inst& inst) { } } +/// Replaces the XMAD pattern generated by an integer FMA +bool FoldXmadMultiplyAdd(IR::Block& block, IR::Inst& inst) { + /* + * We are looking for this specific pattern: + * %6 = BitFieldUExtract %op_b, #0, #16 + * %7 = BitFieldUExtract %op_a', #16, #16 + * %8 = IMul32 %6, %7 + * %10 = BitFieldUExtract %op_a', #0, #16 + * %11 = BitFieldInsert %8, %10, #16, #16 + * %15 = BitFieldUExtract %op_b, #0, #16 + * %16 = BitFieldUExtract %op_a, #0, #16 + * %17 = IMul32 %15, %16 + * %18 = IAdd32 %17, %op_c + * %22 = BitFieldUExtract %op_b, #16, #16 + * %23 = BitFieldUExtract %11, #16, #16 + * %24 = IMul32 %22, %23 + * %25 = ShiftLeftLogical32 %24, #16 + * %26 = ShiftLeftLogical32 %11, #16 + * %27 = IAdd32 %26, %18 + * %result = IAdd32 %25, %27 + * + * And replace it with: + * %temp = IMul32 %op_a, %op_b + * %result = IAdd32 %temp, %op_c + * + * This optimization has been proven safe by Nvidia's compiler logic being reversed. + * (If Nvidia generates this code from 'fma(a, b, c)', we can do the same in the reverse order.) + */ + const IR::Value zero{0u}; + const IR::Value sixteen{16u}; + IR::Inst* const _25{inst.Arg(0).TryInstRecursive()}; + IR::Inst* const _27{inst.Arg(1).TryInstRecursive()}; + if (!_25 || !_27) { + return false; + } + if (_27->GetOpcode() != IR::Opcode::IAdd32) { + return false; + } + if (_25->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _25->Arg(1) != sixteen) { + return false; + } + IR::Inst* const _24{_25->Arg(0).TryInstRecursive()}; + if (!_24 || _24->GetOpcode() != IR::Opcode::IMul32) { + return false; + } + IR::Inst* const _22{_24->Arg(0).TryInstRecursive()}; + IR::Inst* const _23{_24->Arg(1).TryInstRecursive()}; + if (!_22 || !_23) { + return false; + } + if (_22->GetOpcode() != IR::Opcode::BitFieldUExtract) { + return false; + } + if (_23->GetOpcode() != IR::Opcode::BitFieldUExtract) { + return false; + } + if (_22->Arg(1) != sixteen || _22->Arg(2) != sixteen) { + return false; + } + if (_23->Arg(1) != sixteen || _23->Arg(2) != sixteen) { + return false; + } + IR::Inst* const _11{_23->Arg(0).TryInstRecursive()}; + if (!_11 || _11->GetOpcode() != IR::Opcode::BitFieldInsert) { + return false; + } + if (_11->Arg(2) != sixteen || _11->Arg(3) != sixteen) { + return false; + } + IR::Inst* const _8{_11->Arg(0).TryInstRecursive()}; + IR::Inst* const _10{_11->Arg(1).TryInstRecursive()}; + if (!_8 || !_10) { + return false; + } + if (_8->GetOpcode() != IR::Opcode::IMul32) { + return false; + } + if (_10->GetOpcode() != IR::Opcode::BitFieldUExtract) { + return false; + } + IR::Inst* const _6{_8->Arg(0).TryInstRecursive()}; + IR::Inst* const _7{_8->Arg(1).TryInstRecursive()}; + if (!_6 || !_7) { + return false; + } + if (_6->GetOpcode() != IR::Opcode::BitFieldUExtract) { + return false; + } + if (_7->GetOpcode() != IR::Opcode::BitFieldUExtract) { + return false; + } + if (_6->Arg(1) != zero || _6->Arg(2) != sixteen) { + return false; + } + if (_7->Arg(1) != sixteen || _7->Arg(2) != sixteen) { + return false; + } + IR::Inst* const _26{_27->Arg(0).TryInstRecursive()}; + IR::Inst* const _18{_27->Arg(1).TryInstRecursive()}; + if (!_26 || !_18) { + return false; + } + if (_26->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _26->Arg(1) != sixteen) { + return false; + } + if (_26->Arg(0).InstRecursive() != _11) { + return false; + } + if (_18->GetOpcode() != IR::Opcode::IAdd32) { + return false; + } + IR::Inst* const _17{_18->Arg(0).TryInstRecursive()}; + if (!_17 || _17->GetOpcode() != IR::Opcode::IMul32) { + return false; + } + IR::Inst* const _15{_17->Arg(0).TryInstRecursive()}; + IR::Inst* const _16{_17->Arg(1).TryInstRecursive()}; + if (!_15 || !_16) { + return false; + } + if (_15->GetOpcode() != IR::Opcode::BitFieldUExtract) { + return false; + } + if (_16->GetOpcode() != IR::Opcode::BitFieldUExtract) { + return false; + } + if (_15->Arg(1) != zero || _16->Arg(1) != zero || _10->Arg(1) != zero) { + return false; + } + if (_15->Arg(2) != sixteen || _16->Arg(2) != sixteen || _10->Arg(2) != sixteen) { + return false; + } + const std::array op_as{ + _7->Arg(0).Resolve(), + _16->Arg(0).Resolve(), + _10->Arg(0).Resolve(), + }; + const std::array op_bs{ + _22->Arg(0).Resolve(), + _6->Arg(0).Resolve(), + _15->Arg(0).Resolve(), + }; + const IR::U32 op_c{_18->Arg(1)}; + if (!AreEqual(op_as) || !AreEqual(op_bs)) { + return false; + } + IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)}; + inst.ReplaceUsesWith(ir.IAdd(ir.IMul(IR::U32{op_as[0]}, IR::U32{op_bs[1]}), op_c)); + return true; +} + /// Replaces the pattern generated by two XMAD multiplications bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) { /* @@ -116,33 +288,31 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) { * * This optimization has been proven safe by LLVM and MSVC. */ - const IR::Value lhs_arg{inst.Arg(0)}; - const IR::Value rhs_arg{inst.Arg(1)}; - if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) { + IR::Inst* const lhs_shl{inst.Arg(0).TryInstRecursive()}; + IR::Inst* const rhs_mul{inst.Arg(1).TryInstRecursive()}; + if (!lhs_shl || !rhs_mul) { return false; } - IR::Inst* const lhs_shl{lhs_arg.InstRecursive()}; if (lhs_shl->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) { return false; } - if (lhs_shl->Arg(0).IsImmediate()) { + IR::Inst* const lhs_mul{lhs_shl->Arg(0).TryInstRecursive()}; + if (!lhs_mul) { return false; } - IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()}; - IR::Inst* const rhs_mul{rhs_arg.InstRecursive()}; if (lhs_mul->GetOpcode() != IR::Opcode::IMul32 || rhs_mul->GetOpcode() != IR::Opcode::IMul32) { return false; } - if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) { - return false; - } const IR::U32 factor_b{lhs_mul->Arg(1)}; - if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) { + if (factor_b.Resolve() != rhs_mul->Arg(1).Resolve()) { + return false; + } + IR::Inst* const lhs_bfe{lhs_mul->Arg(0).TryInstRecursive()}; + IR::Inst* const rhs_bfe{rhs_mul->Arg(0).TryInstRecursive()}; + if (!lhs_bfe || !rhs_bfe) { return false; } - IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()}; - IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()}; if (lhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) { return false; } @@ -155,10 +325,10 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) { if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) { return false; } - if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) { + const IR::U32 factor_a{lhs_bfe->Arg(0)}; + if (factor_a.Resolve() != rhs_bfe->Arg(0).Resolve()) { return false; } - const IR::U32 factor_a{lhs_bfe->Arg(0)}; IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)}; inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b)); return true; @@ -181,6 +351,9 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) { if (FoldXmadMultiply(block, inst)) { return; } + if (FoldXmadMultiplyAdd(block, inst)) { + return; + } } }