shader/control_flow: Abstract repeated code chunks in BRX tracking

Remove copied and pasted for cycles into a common templated function.
This commit is contained in:
ReinUsesLisp 2019-10-27 02:24:48 -03:00
parent ae7dfa93be
commit 46c3047283
No known key found for this signature in database
GPG key ID: 2DFC508897B39CFE

View file

@ -2,6 +2,7 @@
// Licensed under GPLv2 or any later version // Licensed under GPLv2 or any later version
// Refer to the license.txt file included. // Refer to the license.txt file included.
#include <functional>
#include <list> #include <list>
#include <map> #include <map>
#include <set> #include <set>
@ -16,7 +17,9 @@
#include "video_core/shader/shader_ir.h" #include "video_core/shader/shader_ir.h"
namespace VideoCommon::Shader { namespace VideoCommon::Shader {
namespace { namespace {
using Tegra::Shader::Instruction; using Tegra::Shader::Instruction;
using Tegra::Shader::OpCode; using Tegra::Shader::OpCode;
@ -136,15 +139,13 @@ struct BranchIndirectInfo {
s32 relative_position{}; s32 relative_position{};
}; };
std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState& state, struct BufferInfo {
u32 start_address, u32 current_position) { u32 index;
const u32 shader_start = state.start; u32 offset;
u32 pos = current_position; };
BranchIndirectInfo result{};
u64 track_register = 0;
// Step 0 Get BRX Info std::optional<std::pair<s32, u64>> GetBRXInfo(const CFGRebuildState& state, u32& pos) {
const Instruction instr = {state.program_code[pos]}; const Instruction instr = state.program_code[pos];
const auto opcode = OpCode::Decode(instr); const auto opcode = OpCode::Decode(instr);
if (opcode->get().GetId() != OpCode::Id::BRX) { if (opcode->get().GetId() != OpCode::Id::BRX) {
return std::nullopt; return std::nullopt;
@ -152,86 +153,93 @@ std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState&
if (instr.brx.constant_buffer != 0) { if (instr.brx.constant_buffer != 0) {
return std::nullopt; return std::nullopt;
} }
track_register = instr.gpr8.Value(); --pos;
result.relative_position = instr.brx.GetBranchExtend(); return std::make_pair(instr.brx.GetBranchExtend(), instr.gpr8.Value());
pos--; }
bool found_track = false;
// Step 1 Track LDC template <typename Result>
while (pos >= shader_start) { std::optional<Result> TrackInstruction(
if (IsSchedInstruction(pos, shader_start)) { const CFGRebuildState& state, u32& pos,
pos--; std::function<bool(Instruction, const OpCode::Matcher&)>&& test,
continue; std::function<Result(Instruction, const OpCode::Matcher&)>&& pack) {
} for (; pos >= state.start; --pos) {
const Instruction instr = {state.program_code[pos]}; if (IsSchedInstruction(pos, state.start)) {
const auto opcode = OpCode::Decode(instr);
if (opcode->get().GetId() == OpCode::Id::LD_C) {
if (instr.gpr0.Value() == track_register &&
instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single) {
result.buffer = instr.cbuf36.index.Value();
result.offset = static_cast<u32>(instr.cbuf36.GetOffset());
track_register = instr.gpr8.Value();
pos--;
found_track = true;
break;
}
}
pos--;
}
if (!found_track) {
return std::nullopt;
}
found_track = false;
// Step 2 Track SHL
while (pos >= shader_start) {
if (IsSchedInstruction(pos, shader_start)) {
pos--;
continue; continue;
} }
const Instruction instr = state.program_code[pos]; const Instruction instr = state.program_code[pos];
const auto opcode = OpCode::Decode(instr); const auto opcode = OpCode::Decode(instr);
if (opcode->get().GetId() == OpCode::Id::SHL_IMM) { if (!opcode) {
if (instr.gpr0.Value() == track_register) {
track_register = instr.gpr8.Value();
pos--;
found_track = true;
break;
}
}
pos--;
}
if (!found_track) {
return std::nullopt;
}
found_track = false;
// Step 3 Track IMNMX
while (pos >= shader_start) {
if (IsSchedInstruction(pos, shader_start)) {
pos--;
continue; continue;
} }
const Instruction instr = state.program_code[pos]; if (test(instr, opcode->get())) {
const auto opcode = OpCode::Decode(instr); --pos;
if (opcode->get().GetId() == OpCode::Id::IMNMX_IMM) { return std::make_optional(pack(instr, opcode->get()));
if (instr.gpr0.Value() == track_register) {
track_register = instr.gpr8.Value();
result.entries = instr.alu.GetSignedImm20_20() + 1;
pos--;
found_track = true;
break;
}
} }
pos--;
} }
return std::nullopt;
}
if (!found_track) { std::optional<std::pair<BufferInfo, u64>> TrackLDC(const CFGRebuildState& state, u32& pos,
u64 brx_tracked_register) {
return TrackInstruction<std::pair<BufferInfo, u64>>(
state, pos,
[brx_tracked_register](auto instr, auto& opcode) {
return opcode.GetId() == OpCode::Id::LD_C &&
instr.gpr0.Value() == brx_tracked_register &&
instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single;
},
[](auto instr, auto& opcode) {
const BufferInfo info = {static_cast<u32>(instr.cbuf36.index.Value()),
static_cast<u32>(instr.cbuf36.GetOffset())};
return std::make_pair(info, instr.gpr8.Value());
});
}
std::optional<u64> TrackSHLRegister(const CFGRebuildState& state, u32& pos,
u64 ldc_tracked_register) {
return TrackInstruction<u64>(state, pos,
[ldc_tracked_register](auto instr, auto& opcode) {
return opcode.GetId() == OpCode::Id::SHL_IMM &&
instr.gpr0.Value() == ldc_tracked_register;
},
[](auto instr, auto&) { return instr.gpr8.Value(); });
}
std::optional<u32> TrackIMNMXValue(const CFGRebuildState& state, u32& pos,
u64 shl_tracked_register) {
return TrackInstruction<u32>(
state, pos,
[shl_tracked_register](auto instr, auto& opcode) {
return opcode.GetId() == OpCode::Id::IMNMX_IMM &&
instr.gpr0.Value() == shl_tracked_register;
},
[](auto instr, auto&) { return static_cast<u32>(instr.alu.GetSignedImm20_20() + 1); });
}
std::optional<BranchIndirectInfo> TrackBranchIndirectInfo(const CFGRebuildState& state, u32 pos) {
const auto brx_info = GetBRXInfo(state, pos);
if (!brx_info) {
return std::nullopt; return std::nullopt;
} }
return result; const auto [relative_position, brx_tracked_register] = *brx_info;
const auto ldc_info = TrackLDC(state, pos, brx_tracked_register);
if (!ldc_info) {
return std::nullopt;
}
const auto [buffer_info, ldc_tracked_register] = *ldc_info;
const auto shl_tracked_register = TrackSHLRegister(state, pos, ldc_tracked_register);
if (!shl_tracked_register) {
return std::nullopt;
}
const auto entries = TrackIMNMXValue(state, pos, *shl_tracked_register);
if (!entries) {
return std::nullopt;
}
return BranchIndirectInfo{buffer_info.index, buffer_info.offset, *entries, relative_position};
} }
std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address) { std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address) {
@ -420,30 +428,30 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
break; break;
} }
case OpCode::Id::BRX: { case OpCode::Id::BRX: {
auto tmp = TrackBranchIndirectInfo(state, address, offset); const auto tmp = TrackBranchIndirectInfo(state, offset);
if (tmp) { if (!tmp) {
auto result = *tmp;
std::vector<CaseBranch> branches{};
s32 pc_target = offset + result.relative_position;
for (u32 i = 0; i < result.entries; i++) {
auto k = state.locker.ObtainKey(result.buffer, result.offset + i * 4);
if (!k) {
return {ParseResult::AbnormalFlow, parse_info};
}
u32 value = *k;
u32 target = static_cast<u32>((value >> 3) + pc_target);
insert_label(state, target);
branches.emplace_back(value, target);
}
parse_info.end_address = offset;
parse_info.branch_info = MakeBranchInfo<MultiBranch>(
static_cast<u32>(instr.gpr8.Value()), std::move(branches));
return {ParseResult::ControlCaught, parse_info};
} else {
LOG_WARNING(HW_GPU, "BRX Track Unsuccesful"); LOG_WARNING(HW_GPU, "BRX Track Unsuccesful");
return {ParseResult::AbnormalFlow, parse_info};
} }
return {ParseResult::AbnormalFlow, parse_info};
const auto result = *tmp;
const s32 pc_target = offset + result.relative_position;
std::vector<CaseBranch> branches;
for (u32 i = 0; i < result.entries; i++) {
auto key = state.locker.ObtainKey(result.buffer, result.offset + i * 4);
if (!key) {
return {ParseResult::AbnormalFlow, parse_info};
}
u32 value = *key;
u32 target = static_cast<u32>((value >> 3) + pc_target);
insert_label(state, target);
branches.emplace_back(value, target);
}
parse_info.end_address = offset;
parse_info.branch_info = MakeBranchInfo<MultiBranch>(
static_cast<u32>(instr.gpr8.Value()), std::move(branches));
return {ParseResult::ControlCaught, parse_info};
} }
default: default:
break; break;