diff --git a/src/video_core/shader/control_flow.cpp b/src/video_core/shader/control_flow.cpp index 1775dfd81..7b424d65d 100644 --- a/src/video_core/shader/control_flow.cpp +++ b/src/video_core/shader/control_flow.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -20,68 +21,18 @@ using Tegra::Shader::OpCode; constexpr s32 unassigned_branch = -2; -/** - * 'ControlStack' represents a static stack of control jumps such as SSY and PBK - * stacks in Maxwell. - **/ -struct ControlStack { - static constexpr std::size_t stack_fixed_size = 20; - std::array stack{}; - u32 index{}; - - bool Compare(const ControlStack& cs) const { - if (index != cs.index) { - return false; - } - return std::memcmp(stack.data(), cs.stack.data(), index * sizeof(u32)) == 0; - } - - /// This compare just compares the top of the stack against one another - bool SoftCompare(const ControlStack& cs) const { - if (index == 0 || cs.index == 0) { - return index == cs.index; - } - return Top() == cs.Top(); - } - - u32 Size() const { - return index; - } - - u32 Top() const { - return stack[index - 1]; - } - - bool Push(u32 address) { - if (index >= stack.size()) { - return false; - } - stack[index] = address; - index++; - return true; - } - - bool Pop() { - if (index == 0) { - return false; - } - index--; - return true; - } -}; - struct Query { u32 address{}; - ControlStack ssy_stack{}; - ControlStack pbk_stack{}; + std::stack ssy_stack{}; + std::stack pbk_stack{}; }; struct BlockStack { BlockStack() = default; BlockStack(const BlockStack& b) = default; BlockStack(const Query& q) : ssy_stack{q.ssy_stack}, pbk_stack{q.pbk_stack} {} - ControlStack ssy_stack{}; - ControlStack pbk_stack{}; + std::stack ssy_stack{}; + std::stack pbk_stack{}; }; struct BlockBranchInfo { @@ -144,13 +95,13 @@ struct ParseInfo { u32 end_address{}; }; -BlockInfo* CreateBlockInfo(CFGRebuildState& state, u32 start, u32 end) { +BlockInfo& CreateBlockInfo(CFGRebuildState& state, u32 start, u32 end) { auto& it = state.block_info.emplace_back(); it.start = start; it.end = end; const u32 index = static_cast(state.block_info.size() - 1); state.registered.insert({start, index}); - return ⁢ + return it; } Pred GetPredicate(u32 index, bool negated) { @@ -174,16 +125,17 @@ enum class ParseResult : u32 { AbnormalFlow, }; -ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info) { +std::pair ParseCode(CFGRebuildState& state, u32 address) { u32 offset = static_cast(address); const u32 end_address = static_cast(state.program_size / sizeof(Instruction)); + ParseInfo parse_info{}; - const auto insert_label = ([](CFGRebuildState& state, u32 address) { - auto pair = state.labels.emplace(address); + const auto insert_label = [](CFGRebuildState& state, u32 address) { + const auto pair = state.labels.emplace(address); if (pair.second) { state.inspect_queries.push_back(address); } - }); + }; while (true) { if (offset >= end_address) { @@ -229,11 +181,11 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info parse_info.branch_info.ignore = false; parse_info.end_address = offset; - return ParseResult::ControlCaught; + return {ParseResult::ControlCaught, parse_info}; } case OpCode::Id::BRA: { if (instr.bra.constant_buffer != 0) { - return ParseResult::AbnormalFlow; + return {ParseResult::AbnormalFlow, parse_info}; } const auto pred_index = static_cast(instr.pred.pred_index); parse_info.branch_info.condition.predicate = @@ -248,7 +200,7 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info offset++; continue; } - u32 branch_offset = offset + instr.bra.GetBranchTarget(); + const u32 branch_offset = offset + instr.bra.GetBranchTarget(); if (branch_offset == 0) { parse_info.branch_info.address = exit_branch; } else { @@ -261,10 +213,9 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info parse_info.branch_info.ignore = false; parse_info.end_address = offset; - return ParseResult::ControlCaught; + return {ParseResult::ControlCaught, parse_info}; } case OpCode::Id::SYNC: { - parse_info.branch_info.condition; const auto pred_index = static_cast(instr.pred.pred_index); parse_info.branch_info.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0); @@ -285,10 +236,9 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info parse_info.branch_info.ignore = false; parse_info.end_address = offset; - return ParseResult::ControlCaught; + return {ParseResult::ControlCaught, parse_info}; } case OpCode::Id::BRK: { - parse_info.branch_info.condition; const auto pred_index = static_cast(instr.pred.pred_index); parse_info.branch_info.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0); @@ -309,10 +259,9 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info parse_info.branch_info.ignore = false; parse_info.end_address = offset; - return ParseResult::ControlCaught; + return {ParseResult::ControlCaught, parse_info}; } case OpCode::Id::KIL: { - parse_info.branch_info.condition; const auto pred_index = static_cast(instr.pred.pred_index); parse_info.branch_info.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0); @@ -333,7 +282,7 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info parse_info.branch_info.ignore = false; parse_info.end_address = offset; - return ParseResult::ControlCaught; + return {ParseResult::ControlCaught, parse_info}; } case OpCode::Id::SSY: { const u32 target = offset + instr.bra.GetBranchTarget(); @@ -348,7 +297,7 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info break; } case OpCode::Id::BRX: { - return ParseResult::AbnormalFlow; + return {ParseResult::AbnormalFlow, parse_info}; } default: break; @@ -360,7 +309,7 @@ ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info parse_info.branch_info.is_sync = false; parse_info.branch_info.is_brk = false; parse_info.end_address = offset - 1; - return ParseResult::BlockEnd; + return {ParseResult::BlockEnd, parse_info}; } bool TryInspectAddress(CFGRebuildState& state) { @@ -377,10 +326,10 @@ bool TryInspectAddress(CFGRebuildState& state) { case BlockCollision::Inside: { // This case is the tricky one: // We need to Split the block in 2 sepparate blocks - auto it = search_result.second; - BlockInfo* block_info = CreateBlockInfo(state, address, it->end); + const auto it = search_result.second; + BlockInfo& block_info = CreateBlockInfo(state, address, it->end); it->end = address - 1; - block_info->branch = it->branch; + block_info.branch = it->branch; BlockBranchInfo forward_branch{}; forward_branch.address = address; forward_branch.ignore = true; @@ -390,15 +339,14 @@ bool TryInspectAddress(CFGRebuildState& state) { default: break; } - ParseInfo parse_info; - const ParseResult parse_result = ParseCode(state, address, parse_info); + const auto [parse_result, parse_info] = ParseCode(state, address); if (parse_result == ParseResult::AbnormalFlow) { // if it's AbnormalFlow, we end it as false, ending the CFG reconstruction return false; } - BlockInfo* block_info = CreateBlockInfo(state, address, parse_info.end_address); - block_info->branch = parse_info.branch_info; + BlockInfo& block_info = CreateBlockInfo(state, address, parse_info.end_address); + block_info.branch = parse_info.branch_info; if (parse_info.branch_info.condition.IsUnconditional()) { return true; } @@ -409,14 +357,15 @@ bool TryInspectAddress(CFGRebuildState& state) { } bool TryQuery(CFGRebuildState& state) { - const auto gather_labels = ([](ControlStack& cc, std::map& labels, BlockInfo& block) { + const auto gather_labels = [](std::stack& cc, std::map& labels, + BlockInfo& block) { auto gather_start = labels.lower_bound(block.start); const auto gather_end = labels.upper_bound(block.end); while (gather_start != gather_end) { - cc.Push(gather_start->second); + cc.push(gather_start->second); gather_start++; } - }); + }; if (state.queries.empty()) { return false; } @@ -428,9 +377,8 @@ bool TryQuery(CFGRebuildState& state) { // consumes a label. Schedule new queries accordingly if (block.visited) { BlockStack& stack = state.stacks[q.address]; - const bool all_okay = - (stack.ssy_stack.Size() == 0 || q.ssy_stack.Compare(stack.ssy_stack)) && - (stack.pbk_stack.Size() == 0 || q.pbk_stack.Compare(stack.pbk_stack)); + const bool all_okay = (stack.ssy_stack.size() == 0 || q.ssy_stack == stack.ssy_stack) && + (stack.pbk_stack.size() == 0 || q.pbk_stack == stack.pbk_stack); state.queries.pop_front(); return all_okay; } @@ -447,15 +395,15 @@ bool TryQuery(CFGRebuildState& state) { Query conditional_query{q2}; if (block.branch.is_sync) { if (block.branch.address == unassigned_branch) { - block.branch.address = conditional_query.ssy_stack.Top(); + block.branch.address = conditional_query.ssy_stack.top(); } - conditional_query.ssy_stack.Pop(); + conditional_query.ssy_stack.pop(); } if (block.branch.is_brk) { if (block.branch.address == unassigned_branch) { - block.branch.address = conditional_query.pbk_stack.Top(); + block.branch.address = conditional_query.pbk_stack.top(); } - conditional_query.pbk_stack.Pop(); + conditional_query.pbk_stack.pop(); } conditional_query.address = block.branch.address; state.queries.push_back(conditional_query);