diff --git a/CMakeModules/GenerateSCMRev.cmake b/CMakeModules/GenerateSCMRev.cmake index a1ace89cb..09eabe2c7 100644 --- a/CMakeModules/GenerateSCMRev.cmake +++ b/CMakeModules/GenerateSCMRev.cmake @@ -83,9 +83,15 @@ set(HASH_FILES "${VIDEO_CORE}/shader/decode/video.cpp" "${VIDEO_CORE}/shader/decode/warp.cpp" "${VIDEO_CORE}/shader/decode/xmad.cpp" + "${VIDEO_CORE}/shader/ast.cpp" + "${VIDEO_CORE}/shader/ast.h" "${VIDEO_CORE}/shader/control_flow.cpp" "${VIDEO_CORE}/shader/control_flow.h" + "${VIDEO_CORE}/shader/compiler_settings.cpp" + "${VIDEO_CORE}/shader/compiler_settings.h" "${VIDEO_CORE}/shader/decode.cpp" + "${VIDEO_CORE}/shader/expr.cpp" + "${VIDEO_CORE}/shader/expr.h" "${VIDEO_CORE}/shader/node.h" "${VIDEO_CORE}/shader/node_helper.cpp" "${VIDEO_CORE}/shader/node_helper.h" diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index dfed8b51d..0ed96c0d4 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -60,9 +60,15 @@ add_custom_command(OUTPUT scm_rev.cpp "${VIDEO_CORE}/shader/decode/video.cpp" "${VIDEO_CORE}/shader/decode/warp.cpp" "${VIDEO_CORE}/shader/decode/xmad.cpp" + "${VIDEO_CORE}/shader/ast.cpp" + "${VIDEO_CORE}/shader/ast.h" "${VIDEO_CORE}/shader/control_flow.cpp" "${VIDEO_CORE}/shader/control_flow.h" + "${VIDEO_CORE}/shader/compiler_settings.cpp" + "${VIDEO_CORE}/shader/compiler_settings.h" "${VIDEO_CORE}/shader/decode.cpp" + "${VIDEO_CORE}/shader/expr.cpp" + "${VIDEO_CORE}/shader/expr.h" "${VIDEO_CORE}/shader/node.h" "${VIDEO_CORE}/shader/node_helper.cpp" "${VIDEO_CORE}/shader/node_helper.h" diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt index e2f85c5f1..eaa694ff8 100644 --- a/src/video_core/CMakeLists.txt +++ b/src/video_core/CMakeLists.txt @@ -105,9 +105,15 @@ add_library(video_core STATIC shader/decode/warp.cpp shader/decode/xmad.cpp shader/decode/other.cpp + shader/ast.cpp + shader/ast.h shader/control_flow.cpp shader/control_flow.h + shader/compiler_settings.cpp + shader/compiler_settings.h shader/decode.cpp + shader/expr.cpp + shader/expr.h shader/node_helper.cpp shader/node_helper.h shader/node.h diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp index 8fa9e6534..6a610a3bc 100644 --- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp +++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp @@ -19,6 +19,7 @@ #include "video_core/renderer_opengl/gl_device.h" #include "video_core/renderer_opengl/gl_rasterizer.h" #include "video_core/renderer_opengl/gl_shader_decompiler.h" +#include "video_core/shader/ast.h" #include "video_core/shader/node.h" #include "video_core/shader/shader_ir.h" @@ -334,39 +335,24 @@ constexpr bool IsVertexShader(ProgramType stage) { return stage == ProgramType::VertexA || stage == ProgramType::VertexB; } +class ASTDecompiler; +class ExprDecompiler; + class GLSLDecompiler final { public: explicit GLSLDecompiler(const Device& device, const ShaderIR& ir, ProgramType stage, std::string suffix) : device{device}, ir{ir}, stage{stage}, suffix{suffix}, header{ir.GetHeader()} {} - void Decompile() { - DeclareVertex(); - DeclareGeometry(); - DeclareRegisters(); - DeclarePredicates(); - DeclareLocalMemory(); - DeclareSharedMemory(); - DeclareInternalFlags(); - DeclareInputAttributes(); - DeclareOutputAttributes(); - DeclareConstantBuffers(); - DeclareGlobalMemory(); - DeclareSamplers(); - DeclarePhysicalAttributeReader(); - DeclareImages(); - - code.AddLine("void execute_{}() {{", suffix); - ++code.scope; - + void DecompileBranchMode() { // VM's program counter const auto first_address = ir.GetBasicBlocks().begin()->first; code.AddLine("uint jmp_to = {}U;", first_address); // TODO(Subv): Figure out the actual depth of the flow stack, for now it seems // unlikely that shaders will use 20 nested SSYs and PBKs. + constexpr u32 FLOW_STACK_SIZE = 20; if (!ir.IsFlowStackDisabled()) { - constexpr u32 FLOW_STACK_SIZE = 20; for (const auto stack : std::array{MetaStackClass::Ssy, MetaStackClass::Pbk}) { code.AddLine("uint {}[{}];", FlowStackName(stack), FLOW_STACK_SIZE); code.AddLine("uint {} = 0U;", FlowStackTopName(stack)); @@ -392,10 +378,37 @@ public: code.AddLine("default: return;"); code.AddLine("}}"); - for (std::size_t i = 0; i < 2; ++i) { - --code.scope; - code.AddLine("}}"); + --code.scope; + code.AddLine("}}"); + } + + void DecompileAST(); + + void Decompile() { + DeclareVertex(); + DeclareGeometry(); + DeclareRegisters(); + DeclarePredicates(); + DeclareLocalMemory(); + DeclareInternalFlags(); + DeclareInputAttributes(); + DeclareOutputAttributes(); + DeclareConstantBuffers(); + DeclareGlobalMemory(); + DeclareSamplers(); + DeclarePhysicalAttributeReader(); + + code.AddLine("void execute_{}() {{", suffix); + ++code.scope; + + if (ir.IsDecompiled()) { + DecompileAST(); + } else { + DecompileBranchMode(); } + + --code.scope; + code.AddLine("}}"); } std::string GetResult() { @@ -424,6 +437,9 @@ public: } private: + friend class ASTDecompiler; + friend class ExprDecompiler; + void DeclareVertex() { if (!IsVertexShader(stage)) return; @@ -1821,10 +1837,9 @@ private: return {}; } - Expression Exit(Operation operation) { + void PreExit() { if (stage != ProgramType::Fragment) { - code.AddLine("return;"); - return {}; + return; } const auto& used_registers = ir.GetRegisters(); const auto SafeGetRegister = [&](u32 reg) -> Expression { @@ -1856,7 +1871,10 @@ private: // already contains one past the last color register. code.AddLine("gl_FragDepth = {};", SafeGetRegister(current_reg + 1).AsFloat()); } + } + Expression Exit(Operation operation) { + PreExit(); code.AddLine("return;"); return {}; } @@ -2253,6 +2271,208 @@ private: ShaderWriter code; }; +static constexpr std::string_view flow_var = "flow_var_"; + +std::string GetFlowVariable(u32 i) { + return fmt::format("{}{}", flow_var, i); +} + +class ExprDecompiler { +public: + explicit ExprDecompiler(GLSLDecompiler& decomp) : decomp{decomp} {} + + void operator()(VideoCommon::Shader::ExprAnd& expr) { + inner += "( "; + std::visit(*this, *expr.operand1); + inner += " && "; + std::visit(*this, *expr.operand2); + inner += ')'; + } + + void operator()(VideoCommon::Shader::ExprOr& expr) { + inner += "( "; + std::visit(*this, *expr.operand1); + inner += " || "; + std::visit(*this, *expr.operand2); + inner += ')'; + } + + void operator()(VideoCommon::Shader::ExprNot& expr) { + inner += '!'; + std::visit(*this, *expr.operand1); + } + + void operator()(VideoCommon::Shader::ExprPredicate& expr) { + const auto pred = static_cast(expr.predicate); + inner += decomp.GetPredicate(pred); + } + + void operator()(VideoCommon::Shader::ExprCondCode& expr) { + const Node cc = decomp.ir.GetConditionCode(expr.cc); + std::string target; + + if (const auto pred = std::get_if(&*cc)) { + const auto index = pred->GetIndex(); + switch (index) { + case Tegra::Shader::Pred::NeverExecute: + target = "false"; + case Tegra::Shader::Pred::UnusedIndex: + target = "true"; + default: + target = decomp.GetPredicate(index); + } + } else if (const auto flag = std::get_if(&*cc)) { + target = decomp.GetInternalFlag(flag->GetFlag()); + } else { + UNREACHABLE(); + } + inner += target; + } + + void operator()(VideoCommon::Shader::ExprVar& expr) { + inner += GetFlowVariable(expr.var_index); + } + + void operator()(VideoCommon::Shader::ExprBoolean& expr) { + inner += expr.value ? "true" : "false"; + } + + std::string& GetResult() { + return inner; + } + +private: + std::string inner; + GLSLDecompiler& decomp; +}; + +class ASTDecompiler { +public: + explicit ASTDecompiler(GLSLDecompiler& decomp) : decomp{decomp} {} + + void operator()(VideoCommon::Shader::ASTProgram& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(VideoCommon::Shader::ASTIfThen& ast) { + ExprDecompiler expr_parser{decomp}; + std::visit(expr_parser, *ast.condition); + decomp.code.AddLine("if ({}) {{", expr_parser.GetResult()); + decomp.code.scope++; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + decomp.code.scope--; + decomp.code.AddLine("}}"); + } + + void operator()(VideoCommon::Shader::ASTIfElse& ast) { + decomp.code.AddLine("else {{"); + decomp.code.scope++; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + decomp.code.scope--; + decomp.code.AddLine("}}"); + } + + void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) { + decomp.VisitBlock(ast.nodes); + } + + void operator()(VideoCommon::Shader::ASTVarSet& ast) { + ExprDecompiler expr_parser{decomp}; + std::visit(expr_parser, *ast.condition); + decomp.code.AddLine("{} = {};", GetFlowVariable(ast.index), expr_parser.GetResult()); + } + + void operator()(VideoCommon::Shader::ASTLabel& ast) { + decomp.code.AddLine("// Label_{}:", ast.index); + } + + void operator()(VideoCommon::Shader::ASTGoto& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTDoWhile& ast) { + ExprDecompiler expr_parser{decomp}; + std::visit(expr_parser, *ast.condition); + decomp.code.AddLine("do {{"); + decomp.code.scope++; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + decomp.code.scope--; + decomp.code.AddLine("}} while({});", expr_parser.GetResult()); + } + + void operator()(VideoCommon::Shader::ASTReturn& ast) { + const bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition); + if (!is_true) { + ExprDecompiler expr_parser{decomp}; + std::visit(expr_parser, *ast.condition); + decomp.code.AddLine("if ({}) {{", expr_parser.GetResult()); + decomp.code.scope++; + } + if (ast.kills) { + decomp.code.AddLine("discard;"); + } else { + decomp.PreExit(); + decomp.code.AddLine("return;"); + } + if (!is_true) { + decomp.code.scope--; + decomp.code.AddLine("}}"); + } + } + + void operator()(VideoCommon::Shader::ASTBreak& ast) { + const bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition); + if (!is_true) { + ExprDecompiler expr_parser{decomp}; + std::visit(expr_parser, *ast.condition); + decomp.code.AddLine("if ({}) {{", expr_parser.GetResult()); + decomp.code.scope++; + } + decomp.code.AddLine("break;"); + if (!is_true) { + decomp.code.scope--; + decomp.code.AddLine("}}"); + } + } + + void Visit(VideoCommon::Shader::ASTNode& node) { + std::visit(*this, *node->GetInnerData()); + } + +private: + GLSLDecompiler& decomp; +}; + +void GLSLDecompiler::DecompileAST() { + const u32 num_flow_variables = ir.GetASTNumVariables(); + for (u32 i = 0; i < num_flow_variables; i++) { + code.AddLine("bool {} = false;", GetFlowVariable(i)); + } + ASTDecompiler decompiler{*this}; + VideoCommon::Shader::ASTNode program = ir.GetASTProgram(); + decompiler.Visit(program); +} + } // Anonymous namespace std::string GetCommonDeclarations() { diff --git a/src/video_core/renderer_opengl/gl_shader_gen.cpp b/src/video_core/renderer_opengl/gl_shader_gen.cpp index 3a8d9e1da..b5a43e79e 100644 --- a/src/video_core/renderer_opengl/gl_shader_gen.cpp +++ b/src/video_core/renderer_opengl/gl_shader_gen.cpp @@ -11,12 +11,16 @@ namespace OpenGL::GLShader { using Tegra::Engines::Maxwell3D; +using VideoCommon::Shader::CompileDepth; +using VideoCommon::Shader::CompilerSettings; using VideoCommon::Shader::ProgramCode; using VideoCommon::Shader::ShaderIR; static constexpr u32 PROGRAM_OFFSET = 10; static constexpr u32 COMPUTE_OFFSET = 0; +static constexpr CompilerSettings settings{CompileDepth::NoFlowStack, true}; + ProgramResult GenerateVertexShader(const Device& device, const ShaderSetup& setup) { const std::string id = fmt::format("{:016x}", setup.program.unique_identifier); @@ -31,13 +35,14 @@ layout (std140, binding = EMULATION_UBO_BINDING) uniform vs_config { )"; - const ShaderIR program_ir(setup.program.code, PROGRAM_OFFSET, setup.program.size_a); + const ShaderIR program_ir(setup.program.code, PROGRAM_OFFSET, setup.program.size_a, settings); const auto stage = setup.IsDualProgram() ? ProgramType::VertexA : ProgramType::VertexB; ProgramResult program = Decompile(device, program_ir, stage, "vertex"); out += program.first; if (setup.IsDualProgram()) { - const ShaderIR program_ir_b(setup.program.code_b, PROGRAM_OFFSET, setup.program.size_b); + const ShaderIR program_ir_b(setup.program.code_b, PROGRAM_OFFSET, setup.program.size_b, + settings); ProgramResult program_b = Decompile(device, program_ir_b, ProgramType::VertexB, "vertex_b"); out += program_b.first; } @@ -80,7 +85,7 @@ layout (std140, binding = EMULATION_UBO_BINDING) uniform gs_config { )"; - const ShaderIR program_ir(setup.program.code, PROGRAM_OFFSET, setup.program.size_a); + const ShaderIR program_ir(setup.program.code, PROGRAM_OFFSET, setup.program.size_a, settings); ProgramResult program = Decompile(device, program_ir, ProgramType::Geometry, "geometry"); out += program.first; @@ -114,7 +119,8 @@ layout (std140, binding = EMULATION_UBO_BINDING) uniform fs_config { }; )"; - const ShaderIR program_ir(setup.program.code, PROGRAM_OFFSET, setup.program.size_a); + + const ShaderIR program_ir(setup.program.code, PROGRAM_OFFSET, setup.program.size_a, settings); ProgramResult program = Decompile(device, program_ir, ProgramType::Fragment, "fragment"); out += program.first; @@ -133,7 +139,7 @@ ProgramResult GenerateComputeShader(const Device& device, const ShaderSetup& set std::string out = "// Shader Unique Id: CS" + id + "\n\n"; out += GetCommonDeclarations(); - const ShaderIR program_ir(setup.program.code, COMPUTE_OFFSET, setup.program.size_a); + const ShaderIR program_ir(setup.program.code, COMPUTE_OFFSET, setup.program.size_a, settings); ProgramResult program = Decompile(device, program_ir, ProgramType::Compute, "compute"); out += program.first; diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp index 77fc58f25..8bcd04221 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp @@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) { } // namespace +class ASTDecompiler; +class ExprDecompiler; + class SPIRVDecompiler : public Sirit::Module { public: explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage) @@ -97,27 +100,7 @@ public: AddExtension("SPV_KHR_variable_pointers"); } - void Decompile() { - AllocateBindings(); - AllocateLabels(); - - DeclareVertex(); - DeclareGeometry(); - DeclareFragment(); - DeclareRegisters(); - DeclarePredicates(); - DeclareLocalMemory(); - DeclareInternalFlags(); - DeclareInputAttributes(); - DeclareOutputAttributes(); - DeclareConstantBuffers(); - DeclareGlobalBuffers(); - DeclareSamplers(); - - execute_function = - Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void))); - Emit(OpLabel()); - + void DecompileBranchMode() { const u32 first_address = ir.GetBasicBlocks().begin()->first; const Id loop_label = OpLabel("loop"); const Id merge_label = OpLabel("merge"); @@ -174,6 +157,43 @@ public: Emit(continue_label); Emit(OpBranch(loop_label)); Emit(merge_label); + } + + void DecompileAST(); + + void Decompile() { + const bool is_fully_decompiled = ir.IsDecompiled(); + AllocateBindings(); + if (!is_fully_decompiled) { + AllocateLabels(); + } + + DeclareVertex(); + DeclareGeometry(); + DeclareFragment(); + DeclareRegisters(); + DeclarePredicates(); + if (is_fully_decompiled) { + DeclareFlowVariables(); + } + DeclareLocalMemory(); + DeclareInternalFlags(); + DeclareInputAttributes(); + DeclareOutputAttributes(); + DeclareConstantBuffers(); + DeclareGlobalBuffers(); + DeclareSamplers(); + + execute_function = + Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void))); + Emit(OpLabel()); + + if (is_fully_decompiled) { + DecompileAST(); + } else { + DecompileBranchMode(); + } + Emit(OpReturn()); Emit(OpFunctionEnd()); } @@ -206,6 +226,9 @@ public: } private: + friend class ASTDecompiler; + friend class ExprDecompiler; + static constexpr auto INTERNAL_FLAGS_COUNT = static_cast(InternalFlag::Amount); void AllocateBindings() { @@ -294,6 +317,14 @@ private: } } + void DeclareFlowVariables() { + for (u32 i = 0; i < ir.GetASTNumVariables(); i++) { + const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false); + Name(id, fmt::format("flow_var_{}", static_cast(i))); + flow_variables.emplace(i, AddGlobalVariable(id)); + } + } + void DeclareLocalMemory() { if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { const auto element_count = static_cast(Common::AlignUp(local_memory_size, 4) / 4); @@ -615,9 +646,15 @@ private: Emit(OpBranchConditional(condition, true_label, skip_label)); Emit(true_label); + ++conditional_nest_count; VisitBasicBlock(conditional->GetCode()); + --conditional_nest_count; - Emit(OpBranch(skip_label)); + if (inside_branch == 0) { + Emit(OpBranch(skip_label)); + } else { + inside_branch--; + } Emit(skip_label); return {}; @@ -980,7 +1017,11 @@ private: UNIMPLEMENTED_IF(!target); Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue()))); - BranchingOp([&]() { Emit(OpBranch(continue_label)); }); + Emit(OpBranch(continue_label)); + inside_branch = conditional_nest_count; + if (conditional_nest_count == 0) { + Emit(OpLabel()); + } return {}; } @@ -988,7 +1029,11 @@ private: const Id op_a = VisitOperand(operation, 0); Emit(OpStore(jmp_to, op_a)); - BranchingOp([&]() { Emit(OpBranch(continue_label)); }); + Emit(OpBranch(continue_label)); + inside_branch = conditional_nest_count; + if (conditional_nest_count == 0) { + Emit(OpLabel()); + } return {}; } @@ -1015,11 +1060,15 @@ private: Emit(OpStore(flow_stack_top, previous)); Emit(OpStore(jmp_to, target)); - BranchingOp([&]() { Emit(OpBranch(continue_label)); }); + Emit(OpBranch(continue_label)); + inside_branch = conditional_nest_count; + if (conditional_nest_count == 0) { + Emit(OpLabel()); + } return {}; } - Id Exit(Operation operation) { + Id PreExit() { switch (stage) { case ShaderStage::Vertex: { // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't @@ -1067,12 +1116,35 @@ private: } } - BranchingOp([&]() { Emit(OpReturn()); }); + return {}; + } + + Id Exit(Operation operation) { + PreExit(); + inside_branch = conditional_nest_count; + if (conditional_nest_count > 0) { + Emit(OpReturn()); + } else { + const Id dummy = OpLabel(); + Emit(OpBranch(dummy)); + Emit(dummy); + Emit(OpReturn()); + Emit(OpLabel()); + } return {}; } Id Discard(Operation operation) { - BranchingOp([&]() { Emit(OpKill()); }); + inside_branch = conditional_nest_count; + if (conditional_nest_count > 0) { + Emit(OpKill()); + } else { + const Id dummy = OpLabel(); + Emit(OpBranch(dummy)); + Emit(dummy); + Emit(OpKill()); + Emit(OpLabel()); + } return {}; } @@ -1267,17 +1339,6 @@ private: return {}; } - void BranchingOp(std::function call) { - const Id true_label = OpLabel(); - const Id skip_label = OpLabel(); - Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::Flatten)); - Emit(OpBranchConditional(v_true, true_label, skip_label, 1, 0)); - Emit(true_label); - call(); - - Emit(skip_label); - } - std::tuple CreateFlowStack() { // TODO(Rodrigo): Figure out the actual depth of the flow stack, for now it seems unlikely // that shaders will use 20 nested SSYs and PBKs. @@ -1483,6 +1544,8 @@ private: const ShaderIR& ir; const ShaderStage stage; const Tegra::Shader::Header header; + u64 conditional_nest_count{}; + u64 inside_branch{}; const Id t_void = Name(TypeVoid(), "void"); @@ -1545,6 +1608,7 @@ private: Id per_vertex{}; std::map registers; std::map predicates; + std::map flow_variables; Id local_memory{}; std::array internal_flags{}; std::map input_attributes; @@ -1580,6 +1644,223 @@ private: std::map labels; }; +class ExprDecompiler { +public: + explicit ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {} + + Id operator()(VideoCommon::Shader::ExprAnd& expr) { + const Id type_def = decomp.GetTypeDefinition(Type::Bool); + const Id op1 = Visit(expr.operand1); + const Id op2 = Visit(expr.operand2); + return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2)); + } + + Id operator()(VideoCommon::Shader::ExprOr& expr) { + const Id type_def = decomp.GetTypeDefinition(Type::Bool); + const Id op1 = Visit(expr.operand1); + const Id op2 = Visit(expr.operand2); + return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2)); + } + + Id operator()(VideoCommon::Shader::ExprNot& expr) { + const Id type_def = decomp.GetTypeDefinition(Type::Bool); + const Id op1 = Visit(expr.operand1); + return decomp.Emit(decomp.OpLogicalNot(type_def, op1)); + } + + Id operator()(VideoCommon::Shader::ExprPredicate& expr) { + const auto pred = static_cast(expr.predicate); + return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred))); + } + + Id operator()(VideoCommon::Shader::ExprCondCode& expr) { + const Node cc = decomp.ir.GetConditionCode(expr.cc); + Id target; + + if (const auto pred = std::get_if(&*cc)) { + const auto index = pred->GetIndex(); + switch (index) { + case Tegra::Shader::Pred::NeverExecute: + target = decomp.v_false; + case Tegra::Shader::Pred::UnusedIndex: + target = decomp.v_true; + default: + target = decomp.predicates.at(index); + } + } else if (const auto flag = std::get_if(&*cc)) { + target = decomp.internal_flags.at(static_cast(flag->GetFlag())); + } + return decomp.Emit(decomp.OpLoad(decomp.t_bool, target)); + } + + Id operator()(VideoCommon::Shader::ExprVar& expr) { + return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index))); + } + + Id operator()(VideoCommon::Shader::ExprBoolean& expr) { + return expr.value ? decomp.v_true : decomp.v_false; + } + + Id Visit(VideoCommon::Shader::Expr& node) { + return std::visit(*this, *node); + } + +private: + SPIRVDecompiler& decomp; +}; + +class ASTDecompiler { +public: + explicit ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {} + + void operator()(VideoCommon::Shader::ASTProgram& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(VideoCommon::Shader::ASTIfThen& ast) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + const Id then_label = decomp.OpLabel(); + const Id endif_label = decomp.OpLabel(); + decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); + decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); + decomp.Emit(then_label); + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + decomp.Emit(decomp.OpBranch(endif_label)); + decomp.Emit(endif_label); + } + + void operator()(VideoCommon::Shader::ASTIfElse& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) { + decomp.VisitBasicBlock(ast.nodes); + } + + void operator()(VideoCommon::Shader::ASTVarSet& ast) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition)); + } + + void operator()(VideoCommon::Shader::ASTLabel& ast) { + // Do nothing + } + + void operator()(VideoCommon::Shader::ASTGoto& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTDoWhile& ast) { + const Id loop_label = decomp.OpLabel(); + const Id endloop_label = decomp.OpLabel(); + const Id loop_start_block = decomp.OpLabel(); + const Id loop_end_block = decomp.OpLabel(); + current_loop_exit = endloop_label; + decomp.Emit(decomp.OpBranch(loop_label)); + decomp.Emit(loop_label); + decomp.Emit( + decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone)); + decomp.Emit(decomp.OpBranch(loop_start_block)); + decomp.Emit(loop_start_block); + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label)); + decomp.Emit(endloop_label); + } + + void operator()(VideoCommon::Shader::ASTReturn& ast) { + if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + const Id then_label = decomp.OpLabel(); + const Id endif_label = decomp.OpLabel(); + decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); + decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); + decomp.Emit(then_label); + if (ast.kills) { + decomp.Emit(decomp.OpKill()); + } else { + decomp.PreExit(); + decomp.Emit(decomp.OpReturn()); + } + decomp.Emit(endif_label); + } else { + const Id next_block = decomp.OpLabel(); + decomp.Emit(decomp.OpBranch(next_block)); + decomp.Emit(next_block); + if (ast.kills) { + decomp.Emit(decomp.OpKill()); + } else { + decomp.PreExit(); + decomp.Emit(decomp.OpReturn()); + } + decomp.Emit(decomp.OpLabel()); + } + } + + void operator()(VideoCommon::Shader::ASTBreak& ast) { + if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + const Id then_label = decomp.OpLabel(); + const Id endif_label = decomp.OpLabel(); + decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); + decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); + decomp.Emit(then_label); + decomp.Emit(decomp.OpBranch(current_loop_exit)); + decomp.Emit(endif_label); + } else { + const Id next_block = decomp.OpLabel(); + decomp.Emit(decomp.OpBranch(next_block)); + decomp.Emit(next_block); + decomp.Emit(decomp.OpBranch(current_loop_exit)); + decomp.Emit(decomp.OpLabel()); + } + } + + void Visit(VideoCommon::Shader::ASTNode& node) { + std::visit(*this, *node->GetInnerData()); + } + +private: + SPIRVDecompiler& decomp; + Id current_loop_exit{}; +}; + +void SPIRVDecompiler::DecompileAST() { + const u32 num_flow_variables = ir.GetASTNumVariables(); + for (u32 i = 0; i < num_flow_variables; i++) { + const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false); + Name(id, fmt::format("flow_var_{}", i)); + flow_variables.emplace(i, AddGlobalVariable(id)); + } + ASTDecompiler decompiler{*this}; + VideoCommon::Shader::ASTNode program = ir.GetASTProgram(); + decompiler.Visit(program); + const Id next_block = OpLabel(); + Emit(OpBranch(next_block)); + Emit(next_block); +} + DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, Maxwell::ShaderStage stage) { auto decompiler = std::make_unique(device, ir, stage); diff --git a/src/video_core/shader/ast.cpp b/src/video_core/shader/ast.cpp new file mode 100644 index 000000000..2eb065c3d --- /dev/null +++ b/src/video_core/shader/ast.cpp @@ -0,0 +1,766 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include + +#include "common/assert.h" +#include "common/common_types.h" +#include "video_core/shader/ast.h" +#include "video_core/shader/expr.h" + +namespace VideoCommon::Shader { + +ASTZipper::ASTZipper() = default; + +void ASTZipper::Init(const ASTNode new_first, const ASTNode parent) { + ASSERT(new_first->manager == nullptr); + first = new_first; + last = new_first; + ASTNode current = first; + while (current) { + current->manager = this; + current->parent = parent; + last = current; + current = current->next; + } +} + +void ASTZipper::PushBack(const ASTNode new_node) { + ASSERT(new_node->manager == nullptr); + new_node->previous = last; + if (last) { + last->next = new_node; + } + new_node->next.reset(); + last = new_node; + if (!first) { + first = new_node; + } + new_node->manager = this; +} + +void ASTZipper::PushFront(const ASTNode new_node) { + ASSERT(new_node->manager == nullptr); + new_node->previous.reset(); + new_node->next = first; + if (first) { + first->previous = new_node; + } + if (last == first) { + last = new_node; + } + first = new_node; + new_node->manager = this; +} + +void ASTZipper::InsertAfter(const ASTNode new_node, const ASTNode at_node) { + ASSERT(new_node->manager == nullptr); + if (!at_node) { + PushFront(new_node); + return; + } + const ASTNode next = at_node->next; + if (next) { + next->previous = new_node; + } + new_node->previous = at_node; + if (at_node == last) { + last = new_node; + } + new_node->next = next; + at_node->next = new_node; + new_node->manager = this; +} + +void ASTZipper::InsertBefore(const ASTNode new_node, const ASTNode at_node) { + ASSERT(new_node->manager == nullptr); + if (!at_node) { + PushBack(new_node); + return; + } + const ASTNode previous = at_node->previous; + if (previous) { + previous->next = new_node; + } + new_node->next = at_node; + if (at_node == first) { + first = new_node; + } + new_node->previous = previous; + at_node->previous = new_node; + new_node->manager = this; +} + +void ASTZipper::DetachTail(const ASTNode node) { + ASSERT(node->manager == this); + if (node == first) { + first.reset(); + last.reset(); + return; + } + + last = node->previous; + last->next.reset(); + node->previous.reset(); + ASTNode current = node; + while (current) { + current->manager = nullptr; + current->parent.reset(); + current = current->next; + } +} + +void ASTZipper::DetachSegment(const ASTNode start, const ASTNode end) { + ASSERT(start->manager == this && end->manager == this); + if (start == end) { + DetachSingle(start); + return; + } + const ASTNode prev = start->previous; + const ASTNode post = end->next; + if (!prev) { + first = post; + } else { + prev->next = post; + } + if (!post) { + last = prev; + } else { + post->previous = prev; + } + start->previous.reset(); + end->next.reset(); + ASTNode current = start; + bool found = false; + while (current) { + current->manager = nullptr; + current->parent.reset(); + found |= current == end; + current = current->next; + } + ASSERT(found); +} + +void ASTZipper::DetachSingle(const ASTNode node) { + ASSERT(node->manager == this); + const ASTNode prev = node->previous; + const ASTNode post = node->next; + node->previous.reset(); + node->next.reset(); + if (!prev) { + first = post; + } else { + prev->next = post; + } + if (!post) { + last = prev; + } else { + post->previous = prev; + } + + node->manager = nullptr; + node->parent.reset(); +} + +void ASTZipper::Remove(const ASTNode node) { + ASSERT(node->manager == this); + const ASTNode next = node->next; + const ASTNode previous = node->previous; + if (previous) { + previous->next = next; + } + if (next) { + next->previous = previous; + } + node->parent.reset(); + node->manager = nullptr; + if (node == last) { + last = previous; + } + if (node == first) { + first = next; + } +} + +class ExprPrinter final { +public: + ExprPrinter() = default; + + void operator()(ExprAnd const& expr) { + inner += "( "; + std::visit(*this, *expr.operand1); + inner += " && "; + std::visit(*this, *expr.operand2); + inner += ')'; + } + + void operator()(ExprOr const& expr) { + inner += "( "; + std::visit(*this, *expr.operand1); + inner += " || "; + std::visit(*this, *expr.operand2); + inner += ')'; + } + + void operator()(ExprNot const& expr) { + inner += "!"; + std::visit(*this, *expr.operand1); + } + + void operator()(ExprPredicate const& expr) { + inner += "P" + std::to_string(expr.predicate); + } + + void operator()(ExprCondCode const& expr) { + u32 cc = static_cast(expr.cc); + inner += "CC" + std::to_string(cc); + } + + void operator()(ExprVar const& expr) { + inner += "V" + std::to_string(expr.var_index); + } + + void operator()(ExprBoolean const& expr) { + inner += expr.value ? "true" : "false"; + } + + std::string& GetResult() { + return inner; + } + + std::string inner{}; +}; + +class ASTPrinter { +public: + ASTPrinter() = default; + + void operator()(ASTProgram& ast) { + scope++; + inner += "program {\n"; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + inner += "}\n"; + scope--; + } + + void operator()(ASTIfThen& ast) { + ExprPrinter expr_parser{}; + std::visit(expr_parser, *ast.condition); + inner += Ident() + "if (" + expr_parser.GetResult() + ") {\n"; + scope++; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + scope--; + inner += Ident() + "}\n"; + } + + void operator()(ASTIfElse& ast) { + inner += Ident() + "else {\n"; + scope++; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + scope--; + inner += Ident() + "}\n"; + } + + void operator()(ASTBlockEncoded& ast) { + inner += Ident() + "Block(" + std::to_string(ast.start) + ", " + std::to_string(ast.end) + + ");\n"; + } + + void operator()(ASTBlockDecoded& ast) { + inner += Ident() + "Block;\n"; + } + + void operator()(ASTVarSet& ast) { + ExprPrinter expr_parser{}; + std::visit(expr_parser, *ast.condition); + inner += + Ident() + "V" + std::to_string(ast.index) + " := " + expr_parser.GetResult() + ";\n"; + } + + void operator()(ASTLabel& ast) { + inner += "Label_" + std::to_string(ast.index) + ":\n"; + } + + void operator()(ASTGoto& ast) { + ExprPrinter expr_parser{}; + std::visit(expr_parser, *ast.condition); + inner += Ident() + "(" + expr_parser.GetResult() + ") -> goto Label_" + + std::to_string(ast.label) + ";\n"; + } + + void operator()(ASTDoWhile& ast) { + ExprPrinter expr_parser{}; + std::visit(expr_parser, *ast.condition); + inner += Ident() + "do {\n"; + scope++; + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + scope--; + inner += Ident() + "} while (" + expr_parser.GetResult() + ");\n"; + } + + void operator()(ASTReturn& ast) { + ExprPrinter expr_parser{}; + std::visit(expr_parser, *ast.condition); + inner += Ident() + "(" + expr_parser.GetResult() + ") -> " + + (ast.kills ? "discard" : "exit") + ";\n"; + } + + void operator()(ASTBreak& ast) { + ExprPrinter expr_parser{}; + std::visit(expr_parser, *ast.condition); + inner += Ident() + "(" + expr_parser.GetResult() + ") -> break;\n"; + } + + std::string& Ident() { + if (memo_scope == scope) { + return tabs_memo; + } + tabs_memo = tabs.substr(0, scope * 2); + memo_scope = scope; + return tabs_memo; + } + + void Visit(ASTNode& node) { + std::visit(*this, *node->GetInnerData()); + } + + std::string& GetResult() { + return inner; + } + +private: + std::string inner{}; + u32 scope{}; + + std::string tabs_memo{}; + u32 memo_scope{}; + + static std::string tabs; +}; + +std::string ASTPrinter::tabs = " "; + +std::string ASTManager::Print() { + ASTPrinter printer{}; + printer.Visit(main_node); + return printer.GetResult(); +} + +ASTManager::ASTManager(bool full_decompile, bool disable_else_derivation) + : full_decompile{full_decompile}, disable_else_derivation{disable_else_derivation} {}; + +ASTManager::~ASTManager() { + Clear(); +} + +void ASTManager::Init() { + main_node = ASTBase::Make(ASTNode{}); + program = std::get_if(main_node->GetInnerData()); + false_condition = MakeExpr(false); +} + +ASTManager::ASTManager(ASTManager&& other) noexcept + : labels_map(std::move(other.labels_map)), labels_count{other.labels_count}, + gotos(std::move(other.gotos)), labels(std::move(other.labels)), variables{other.variables}, + program{other.program}, main_node{other.main_node}, false_condition{other.false_condition}, + disable_else_derivation{other.disable_else_derivation} { + other.main_node.reset(); +} + +ASTManager& ASTManager::operator=(ASTManager&& other) noexcept { + full_decompile = other.full_decompile; + labels_map = std::move(other.labels_map); + labels_count = other.labels_count; + gotos = std::move(other.gotos); + labels = std::move(other.labels); + variables = other.variables; + program = other.program; + main_node = other.main_node; + false_condition = other.false_condition; + disable_else_derivation = other.disable_else_derivation; + + other.main_node.reset(); + return *this; +} + +void ASTManager::DeclareLabel(u32 address) { + const auto pair = labels_map.emplace(address, labels_count); + if (pair.second) { + labels_count++; + labels.resize(labels_count); + } +} + +void ASTManager::InsertLabel(u32 address) { + const u32 index = labels_map[address]; + const ASTNode label = ASTBase::Make(main_node, index); + labels[index] = label; + program->nodes.PushBack(label); +} + +void ASTManager::InsertGoto(Expr condition, u32 address) { + const u32 index = labels_map[address]; + const ASTNode goto_node = ASTBase::Make(main_node, condition, index); + gotos.push_back(goto_node); + program->nodes.PushBack(goto_node); +} + +void ASTManager::InsertBlock(u32 start_address, u32 end_address) { + const ASTNode block = ASTBase::Make(main_node, start_address, end_address); + program->nodes.PushBack(block); +} + +void ASTManager::InsertReturn(Expr condition, bool kills) { + const ASTNode node = ASTBase::Make(main_node, condition, kills); + program->nodes.PushBack(node); +} + +// The decompile algorithm is based on +// "Taming control flow: A structured approach to eliminating goto statements" +// by AM Erosa, LJ Hendren 1994. In general, the idea is to get gotos to be +// on the same structured level as the label which they jump to. This is done, +// through outward/inward movements and lifting. Once they are at the same +// level, you can enclose them in an "if" structure or a "do-while" structure. +void ASTManager::Decompile() { + auto it = gotos.begin(); + while (it != gotos.end()) { + const ASTNode goto_node = *it; + const auto label_index = goto_node->GetGotoLabel(); + if (!label_index) { + return; + } + const ASTNode label = labels[*label_index]; + if (!full_decompile) { + // We only decompile backward jumps + if (!IsBackwardsJump(goto_node, label)) { + it++; + continue; + } + } + if (IndirectlyRelated(goto_node, label)) { + while (!DirectlyRelated(goto_node, label)) { + MoveOutward(goto_node); + } + } + if (DirectlyRelated(goto_node, label)) { + u32 goto_level = goto_node->GetLevel(); + const u32 label_level = label->GetLevel(); + while (label_level < goto_level) { + MoveOutward(goto_node); + goto_level--; + } + // TODO(Blinkhawk): Implement Lifting and Inward Movements + } + if (label->GetParent() == goto_node->GetParent()) { + bool is_loop = false; + ASTNode current = goto_node->GetPrevious(); + while (current) { + if (current == label) { + is_loop = true; + break; + } + current = current->GetPrevious(); + } + + if (is_loop) { + EncloseDoWhile(goto_node, label); + } else { + EncloseIfThen(goto_node, label); + } + it = gotos.erase(it); + continue; + } + it++; + } + if (full_decompile) { + for (const ASTNode& label : labels) { + auto& manager = label->GetManager(); + manager.Remove(label); + } + labels.clear(); + } else { + auto it = labels.begin(); + while (it != labels.end()) { + bool can_remove = true; + ASTNode label = *it; + for (const ASTNode& goto_node : gotos) { + const auto label_index = goto_node->GetGotoLabel(); + if (!label_index) { + return; + } + ASTNode& glabel = labels[*label_index]; + if (glabel == label) { + can_remove = false; + break; + } + } + if (can_remove) { + label->MarkLabelUnused(); + } + } + } +} + +bool ASTManager::IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const { + u32 goto_level = goto_node->GetLevel(); + u32 label_level = label_node->GetLevel(); + while (goto_level > label_level) { + goto_level--; + goto_node = goto_node->GetParent(); + } + while (label_level > goto_level) { + label_level--; + label_node = label_node->GetParent(); + } + while (goto_node->GetParent() != label_node->GetParent()) { + goto_node = goto_node->GetParent(); + label_node = label_node->GetParent(); + } + ASTNode current = goto_node->GetPrevious(); + while (current) { + if (current == label_node) { + return true; + } + current = current->GetPrevious(); + } + return false; +} + +bool ASTManager::IndirectlyRelated(ASTNode first, ASTNode second) { + return !(first->GetParent() == second->GetParent() || DirectlyRelated(first, second)); +} + +bool ASTManager::DirectlyRelated(ASTNode first, ASTNode second) { + if (first->GetParent() == second->GetParent()) { + return false; + } + const u32 first_level = first->GetLevel(); + const u32 second_level = second->GetLevel(); + u32 min_level; + u32 max_level; + ASTNode max; + ASTNode min; + if (first_level > second_level) { + min_level = second_level; + min = second; + max_level = first_level; + max = first; + } else { + min_level = first_level; + min = first; + max_level = second_level; + max = second; + } + + while (max_level > min_level) { + max_level--; + max = max->GetParent(); + } + + return min->GetParent() == max->GetParent(); +} + +void ASTManager::ShowCurrentState(std::string state) { + LOG_CRITICAL(HW_GPU, "\nState {}:\n\n{}\n", state, Print()); + SanityCheck(); +} + +void ASTManager::SanityCheck() { + for (auto& label : labels) { + if (!label->GetParent()) { + LOG_CRITICAL(HW_GPU, "Sanity Check Failed"); + } + } +} + +void ASTManager::EncloseDoWhile(ASTNode goto_node, ASTNode label) { + ASTZipper& zipper = goto_node->GetManager(); + const ASTNode loop_start = label->GetNext(); + if (loop_start == goto_node) { + zipper.Remove(goto_node); + return; + } + const ASTNode parent = label->GetParent(); + const Expr condition = goto_node->GetGotoCondition(); + zipper.DetachSegment(loop_start, goto_node); + const ASTNode do_while_node = ASTBase::Make(parent, condition); + ASTZipper* sub_zipper = do_while_node->GetSubNodes(); + sub_zipper->Init(loop_start, do_while_node); + zipper.InsertAfter(do_while_node, label); + sub_zipper->Remove(goto_node); +} + +void ASTManager::EncloseIfThen(ASTNode goto_node, ASTNode label) { + ASTZipper& zipper = goto_node->GetManager(); + const ASTNode if_end = label->GetPrevious(); + if (if_end == goto_node) { + zipper.Remove(goto_node); + return; + } + const ASTNode prev = goto_node->GetPrevious(); + const Expr condition = goto_node->GetGotoCondition(); + bool do_else = false; + if (!disable_else_derivation && prev->IsIfThen()) { + const Expr if_condition = prev->GetIfCondition(); + do_else = ExprAreEqual(if_condition, condition); + } + const ASTNode parent = label->GetParent(); + zipper.DetachSegment(goto_node, if_end); + ASTNode if_node; + if (do_else) { + if_node = ASTBase::Make(parent); + } else { + Expr neg_condition = MakeExprNot(condition); + if_node = ASTBase::Make(parent, neg_condition); + } + ASTZipper* sub_zipper = if_node->GetSubNodes(); + sub_zipper->Init(goto_node, if_node); + zipper.InsertAfter(if_node, prev); + sub_zipper->Remove(goto_node); +} + +void ASTManager::MoveOutward(ASTNode goto_node) { + ASTZipper& zipper = goto_node->GetManager(); + const ASTNode parent = goto_node->GetParent(); + ASTZipper& zipper2 = parent->GetManager(); + const ASTNode grandpa = parent->GetParent(); + const bool is_loop = parent->IsLoop(); + const bool is_else = parent->IsIfElse(); + const bool is_if = parent->IsIfThen(); + + const ASTNode prev = goto_node->GetPrevious(); + const ASTNode post = goto_node->GetNext(); + + const Expr condition = goto_node->GetGotoCondition(); + zipper.DetachSingle(goto_node); + if (is_loop) { + const u32 var_index = NewVariable(); + const Expr var_condition = MakeExpr(var_index); + const ASTNode var_node = ASTBase::Make(parent, var_index, condition); + const ASTNode var_node_init = ASTBase::Make(parent, var_index, false_condition); + zipper2.InsertBefore(var_node_init, parent); + zipper.InsertAfter(var_node, prev); + goto_node->SetGotoCondition(var_condition); + const ASTNode break_node = ASTBase::Make(parent, var_condition); + zipper.InsertAfter(break_node, var_node); + } else if (is_if || is_else) { + const u32 var_index = NewVariable(); + const Expr var_condition = MakeExpr(var_index); + const ASTNode var_node = ASTBase::Make(parent, var_index, condition); + const ASTNode var_node_init = ASTBase::Make(parent, var_index, false_condition); + if (is_if) { + zipper2.InsertBefore(var_node_init, parent); + } else { + zipper2.InsertBefore(var_node_init, parent->GetPrevious()); + } + zipper.InsertAfter(var_node, prev); + goto_node->SetGotoCondition(var_condition); + if (post) { + zipper.DetachTail(post); + const ASTNode if_node = ASTBase::Make(parent, MakeExprNot(var_condition)); + ASTZipper* sub_zipper = if_node->GetSubNodes(); + sub_zipper->Init(post, if_node); + zipper.InsertAfter(if_node, var_node); + } + } else { + UNREACHABLE(); + } + const ASTNode next = parent->GetNext(); + if (is_if && next && next->IsIfElse()) { + zipper2.InsertAfter(goto_node, next); + goto_node->SetParent(grandpa); + return; + } + zipper2.InsertAfter(goto_node, parent); + goto_node->SetParent(grandpa); +} + +class ASTClearer { +public: + ASTClearer() = default; + + void operator()(ASTProgram& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTIfThen& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTIfElse& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTBlockEncoded& ast) {} + + void operator()(ASTBlockDecoded& ast) { + ast.nodes.clear(); + } + + void operator()(ASTVarSet& ast) {} + + void operator()(ASTLabel& ast) {} + + void operator()(ASTGoto& ast) {} + + void operator()(ASTDoWhile& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTReturn& ast) {} + + void operator()(ASTBreak& ast) {} + + void Visit(ASTNode& node) { + std::visit(*this, *node->GetInnerData()); + node->Clear(); + } +}; + +void ASTManager::Clear() { + if (!main_node) { + return; + } + ASTClearer clearer{}; + clearer.Visit(main_node); + main_node.reset(); + program = nullptr; + labels_map.clear(); + labels.clear(); + gotos.clear(); +} + +} // namespace VideoCommon::Shader diff --git a/src/video_core/shader/ast.h b/src/video_core/shader/ast.h new file mode 100644 index 000000000..ba234138e --- /dev/null +++ b/src/video_core/shader/ast.h @@ -0,0 +1,391 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "video_core/shader/expr.h" +#include "video_core/shader/node.h" + +namespace VideoCommon::Shader { + +class ASTBase; +class ASTProgram; +class ASTIfThen; +class ASTIfElse; +class ASTBlockEncoded; +class ASTBlockDecoded; +class ASTVarSet; +class ASTGoto; +class ASTLabel; +class ASTDoWhile; +class ASTReturn; +class ASTBreak; + +using ASTData = std::variant; + +using ASTNode = std::shared_ptr; + +enum class ASTZipperType : u32 { + Program, + IfThen, + IfElse, + Loop, +}; + +class ASTZipper final { +public: + explicit ASTZipper(); + + void Init(ASTNode first, ASTNode parent); + + ASTNode GetFirst() { + return first; + } + + ASTNode GetLast() { + return last; + } + + void PushBack(ASTNode new_node); + void PushFront(ASTNode new_node); + void InsertAfter(ASTNode new_node, ASTNode at_node); + void InsertBefore(ASTNode new_node, ASTNode at_node); + void DetachTail(ASTNode node); + void DetachSingle(ASTNode node); + void DetachSegment(ASTNode start, ASTNode end); + void Remove(ASTNode node); + + ASTNode first{}; + ASTNode last{}; +}; + +class ASTProgram { +public: + explicit ASTProgram() = default; + ASTZipper nodes{}; +}; + +class ASTIfThen { +public: + explicit ASTIfThen(Expr condition) : condition(condition) {} + Expr condition; + ASTZipper nodes{}; +}; + +class ASTIfElse { +public: + explicit ASTIfElse() = default; + ASTZipper nodes{}; +}; + +class ASTBlockEncoded { +public: + explicit ASTBlockEncoded(u32 start, u32 end) : start{start}, end{end} {} + u32 start; + u32 end; +}; + +class ASTBlockDecoded { +public: + explicit ASTBlockDecoded(NodeBlock&& new_nodes) : nodes(std::move(new_nodes)) {} + NodeBlock nodes; +}; + +class ASTVarSet { +public: + explicit ASTVarSet(u32 index, Expr condition) : index{index}, condition{condition} {} + u32 index; + Expr condition; +}; + +class ASTLabel { +public: + explicit ASTLabel(u32 index) : index{index} {} + u32 index; + bool unused{}; +}; + +class ASTGoto { +public: + explicit ASTGoto(Expr condition, u32 label) : condition{condition}, label{label} {} + Expr condition; + u32 label; +}; + +class ASTDoWhile { +public: + explicit ASTDoWhile(Expr condition) : condition(condition) {} + Expr condition; + ASTZipper nodes{}; +}; + +class ASTReturn { +public: + explicit ASTReturn(Expr condition, bool kills) : condition{condition}, kills{kills} {} + Expr condition; + bool kills; +}; + +class ASTBreak { +public: + explicit ASTBreak(Expr condition) : condition{condition} {} + Expr condition; +}; + +class ASTBase { +public: + explicit ASTBase(ASTNode parent, ASTData data) : parent{parent}, data{data} {} + + template + static ASTNode Make(ASTNode parent, Args&&... args) { + return std::make_shared(parent, ASTData(U(std::forward(args)...))); + } + + void SetParent(ASTNode new_parent) { + parent = new_parent; + } + + ASTNode& GetParent() { + return parent; + } + + const ASTNode& GetParent() const { + return parent; + } + + u32 GetLevel() const { + u32 level = 0; + auto next_parent = parent; + while (next_parent) { + next_parent = next_parent->GetParent(); + level++; + } + return level; + } + + ASTData* GetInnerData() { + return &data; + } + + ASTNode GetNext() const { + return next; + } + + ASTNode GetPrevious() const { + return previous; + } + + ASTZipper& GetManager() { + return *manager; + } + + std::optional GetGotoLabel() const { + auto inner = std::get_if(&data); + if (inner) { + return {inner->label}; + } + return {}; + } + + Expr GetGotoCondition() const { + auto inner = std::get_if(&data); + if (inner) { + return inner->condition; + } + return nullptr; + } + + void MarkLabelUnused() { + auto inner = std::get_if(&data); + if (inner) { + inner->unused = true; + } + } + + bool IsLabelUnused() const { + auto inner = std::get_if(&data); + if (inner) { + return inner->unused; + } + return true; + } + + std::optional GetLabelIndex() const { + auto inner = std::get_if(&data); + if (inner) { + return {inner->index}; + } + return {}; + } + + Expr GetIfCondition() const { + auto inner = std::get_if(&data); + if (inner) { + return inner->condition; + } + return nullptr; + } + + void SetGotoCondition(Expr new_condition) { + auto inner = std::get_if(&data); + if (inner) { + inner->condition = new_condition; + } + } + + bool IsIfThen() const { + return std::holds_alternative(data); + } + + bool IsIfElse() const { + return std::holds_alternative(data); + } + + bool IsBlockEncoded() const { + return std::holds_alternative(data); + } + + void TransformBlockEncoded(NodeBlock&& nodes) { + data = ASTBlockDecoded(std::move(nodes)); + } + + bool IsLoop() const { + return std::holds_alternative(data); + } + + ASTZipper* GetSubNodes() { + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + if (std::holds_alternative(data)) { + return &std::get_if(&data)->nodes; + } + return nullptr; + } + + void Clear() { + next.reset(); + previous.reset(); + parent.reset(); + manager = nullptr; + } + +private: + friend class ASTZipper; + + ASTData data; + ASTNode parent{}; + ASTNode next{}; + ASTNode previous{}; + ASTZipper* manager{}; +}; + +class ASTManager final { +public: + ASTManager(bool full_decompile, bool disable_else_derivation); + ~ASTManager(); + + ASTManager(const ASTManager& o) = delete; + ASTManager& operator=(const ASTManager& other) = delete; + + ASTManager(ASTManager&& other) noexcept; + ASTManager& operator=(ASTManager&& other) noexcept; + + void Init(); + + void DeclareLabel(u32 address); + + void InsertLabel(u32 address); + + void InsertGoto(Expr condition, u32 address); + + void InsertBlock(u32 start_address, u32 end_address); + + void InsertReturn(Expr condition, bool kills); + + std::string Print(); + + void Decompile(); + + void ShowCurrentState(std::string state); + + void SanityCheck(); + + void Clear(); + + bool IsFullyDecompiled() const { + if (full_decompile) { + return gotos.size() == 0; + } else { + for (ASTNode goto_node : gotos) { + auto label_index = goto_node->GetGotoLabel(); + if (!label_index) { + return false; + } + ASTNode glabel = labels[*label_index]; + if (IsBackwardsJump(goto_node, glabel)) { + return false; + } + } + return true; + } + } + + ASTNode GetProgram() const { + return main_node; + } + + u32 GetVariables() const { + return variables; + } + + const std::vector& GetLabels() const { + return labels; + } + +private: + bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const; + + bool IndirectlyRelated(ASTNode first, ASTNode second); + + bool DirectlyRelated(ASTNode first, ASTNode second); + + void EncloseDoWhile(ASTNode goto_node, ASTNode label); + + void EncloseIfThen(ASTNode goto_node, ASTNode label); + + void MoveOutward(ASTNode goto_node); + + u32 NewVariable() { + return variables++; + } + + bool full_decompile{}; + bool disable_else_derivation{}; + std::unordered_map labels_map{}; + u32 labels_count{}; + std::vector labels{}; + std::list gotos{}; + u32 variables{}; + ASTProgram* program{}; + ASTNode main_node{}; + Expr false_condition{}; +}; + +} // namespace VideoCommon::Shader diff --git a/src/video_core/shader/compiler_settings.cpp b/src/video_core/shader/compiler_settings.cpp new file mode 100644 index 000000000..cddcbd4f0 --- /dev/null +++ b/src/video_core/shader/compiler_settings.cpp @@ -0,0 +1,26 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include "video_core/shader/compiler_settings.h" + +namespace VideoCommon::Shader { + +std::string CompileDepthAsString(const CompileDepth cd) { + switch (cd) { + case CompileDepth::BruteForce: + return "Brute Force Compile"; + case CompileDepth::FlowStack: + return "Simple Flow Stack Mode"; + case CompileDepth::NoFlowStack: + return "Remove Flow Stack"; + case CompileDepth::DecompileBackwards: + return "Decompile Backward Jumps"; + case CompileDepth::FullDecompile: + return "Full Decompilation"; + default: + return "Unknown Compiler Process"; + } +} + +} // namespace VideoCommon::Shader diff --git a/src/video_core/shader/compiler_settings.h b/src/video_core/shader/compiler_settings.h new file mode 100644 index 000000000..916018c01 --- /dev/null +++ b/src/video_core/shader/compiler_settings.h @@ -0,0 +1,26 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include "video_core/engines/shader_bytecode.h" + +namespace VideoCommon::Shader { + +enum class CompileDepth : u32 { + BruteForce = 0, + FlowStack = 1, + NoFlowStack = 2, + DecompileBackwards = 3, + FullDecompile = 4, +}; + +std::string CompileDepthAsString(CompileDepth cd); + +struct CompilerSettings { + CompileDepth depth{CompileDepth::NoFlowStack}; + bool disable_else_derivation{true}; +}; + +} // namespace VideoCommon::Shader diff --git a/src/video_core/shader/control_flow.cpp b/src/video_core/shader/control_flow.cpp index ec3a76690..3c3a41ba6 100644 --- a/src/video_core/shader/control_flow.cpp +++ b/src/video_core/shader/control_flow.cpp @@ -4,13 +4,14 @@ #include #include +#include #include #include -#include #include #include "common/assert.h" #include "common/common_types.h" +#include "video_core/shader/ast.h" #include "video_core/shader/control_flow.h" #include "video_core/shader/shader_ir.h" @@ -64,12 +65,13 @@ struct CFGRebuildState { std::list inspect_queries{}; std::list queries{}; std::unordered_map registered{}; - std::unordered_set labels{}; + std::set labels{}; std::map ssy_labels{}; std::map pbk_labels{}; std::unordered_map stacks{}; const ProgramCode& program_code; const std::size_t program_size; + ASTManager* manager; }; enum class BlockCollision : u32 { None, Found, Inside }; @@ -415,38 +417,132 @@ bool TryQuery(CFGRebuildState& state) { } } // Anonymous namespace -std::optional ScanFlow(const ProgramCode& program_code, - std::size_t program_size, u32 start_address) { - CFGRebuildState state{program_code, program_size, start_address}; +void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch) { + const auto get_expr = ([&](const Condition& cond) -> Expr { + Expr result{}; + if (cond.cc != ConditionCode::T) { + result = MakeExpr(cond.cc); + } + if (cond.predicate != Pred::UnusedIndex) { + u32 pred = static_cast(cond.predicate); + bool negate = false; + if (pred > 7) { + negate = true; + pred -= 8; + } + Expr extra = MakeExpr(pred); + if (negate) { + extra = MakeExpr(extra); + } + if (result) { + return MakeExpr(extra, result); + } + return extra; + } + if (result) { + return result; + } + return MakeExpr(true); + }); + if (branch.address < 0) { + if (branch.kill) { + mm.InsertReturn(get_expr(branch.condition), true); + return; + } + mm.InsertReturn(get_expr(branch.condition), false); + return; + } + mm.InsertGoto(get_expr(branch.condition), branch.address); +} +void DecompileShader(CFGRebuildState& state) { + state.manager->Init(); + for (auto label : state.labels) { + state.manager->DeclareLabel(label); + } + for (auto& block : state.block_info) { + if (state.labels.count(block.start) != 0) { + state.manager->InsertLabel(block.start); + } + u32 end = block.branch.ignore ? block.end + 1 : block.end; + state.manager->InsertBlock(block.start, end); + if (!block.branch.ignore) { + InsertBranch(*state.manager, block.branch); + } + } + state.manager->Decompile(); +} + +std::unique_ptr ScanFlow(const ProgramCode& program_code, u32 program_size, + u32 start_address, + const CompilerSettings& settings) { + auto result_out = std::make_unique(); + if (settings.depth == CompileDepth::BruteForce) { + result_out->settings.depth = CompileDepth::BruteForce; + return std::move(result_out); + } + + CFGRebuildState state{program_code, program_size, start_address}; // Inspect Code and generate blocks state.labels.clear(); state.labels.emplace(start_address); state.inspect_queries.push_back(state.start); while (!state.inspect_queries.empty()) { if (!TryInspectAddress(state)) { - return {}; + result_out->settings.depth = CompileDepth::BruteForce; + return std::move(result_out); } } - // Decompile Stacks - state.queries.push_back(Query{state.start, {}, {}}); - bool decompiled = true; - while (!state.queries.empty()) { - if (!TryQuery(state)) { - decompiled = false; - break; + bool use_flow_stack = true; + + bool decompiled = false; + + if (settings.depth != CompileDepth::FlowStack) { + // Decompile Stacks + state.queries.push_back(Query{state.start, {}, {}}); + decompiled = true; + while (!state.queries.empty()) { + if (!TryQuery(state)) { + decompiled = false; + break; + } } } + use_flow_stack = !decompiled; + // Sort and organize results std::sort(state.block_info.begin(), state.block_info.end(), - [](const BlockInfo& a, const BlockInfo& b) { return a.start < b.start; }); - ShaderCharacteristics result_out{}; - result_out.decompilable = decompiled; - result_out.start = start_address; - result_out.end = start_address; - for (const auto& block : state.block_info) { + [](const BlockInfo& a, const BlockInfo& b) -> bool { return a.start < b.start; }); + if (decompiled && settings.depth != CompileDepth::NoFlowStack) { + ASTManager manager{settings.depth != CompileDepth::DecompileBackwards, + settings.disable_else_derivation}; + state.manager = &manager; + DecompileShader(state); + decompiled = state.manager->IsFullyDecompiled(); + if (!decompiled) { + if (settings.depth == CompileDepth::FullDecompile) { + LOG_CRITICAL(HW_GPU, "Failed to remove all the gotos!:"); + } else { + LOG_CRITICAL(HW_GPU, "Failed to remove all backward gotos!:"); + } + state.manager->ShowCurrentState("Of Shader"); + state.manager->Clear(); + } else { + auto result_out = std::make_unique(); + result_out->start = start_address; + result_out->settings.depth = settings.depth; + result_out->manager = std::move(manager); + result_out->end = state.block_info.back().end + 1; + return std::move(result_out); + } + } + result_out->start = start_address; + result_out->settings.depth = + use_flow_stack ? CompileDepth::FlowStack : CompileDepth::NoFlowStack; + result_out->blocks.clear(); + for (auto& block : state.block_info) { ShaderBlock new_block{}; new_block.start = block.start; new_block.end = block.end; @@ -456,26 +552,24 @@ std::optional ScanFlow(const ProgramCode& program_code, new_block.branch.kills = block.branch.kill; new_block.branch.address = block.branch.address; } - result_out.end = std::max(result_out.end, block.end); - result_out.blocks.push_back(new_block); + result_out->end = std::max(result_out->end, block.end); + result_out->blocks.push_back(new_block); } - if (result_out.decompilable) { - result_out.labels = std::move(state.labels); - return {std::move(result_out)}; + if (!use_flow_stack) { + result_out->labels = std::move(state.labels); + return std::move(result_out); } - - // If it's not decompilable, merge the unlabelled blocks together - auto back = result_out.blocks.begin(); + auto back = result_out->blocks.begin(); auto next = std::next(back); - while (next != result_out.blocks.end()) { + while (next != result_out->blocks.end()) { if (state.labels.count(next->start) == 0 && next->start == back->end + 1) { back->end = next->end; - next = result_out.blocks.erase(next); + next = result_out->blocks.erase(next); continue; } back = next; ++next; } - return {std::move(result_out)}; + return std::move(result_out); } } // namespace VideoCommon::Shader diff --git a/src/video_core/shader/control_flow.h b/src/video_core/shader/control_flow.h index b0a5e4f8c..74e54a5c7 100644 --- a/src/video_core/shader/control_flow.h +++ b/src/video_core/shader/control_flow.h @@ -6,9 +6,11 @@ #include #include -#include +#include #include "video_core/engines/shader_bytecode.h" +#include "video_core/shader/ast.h" +#include "video_core/shader/compiler_settings.h" #include "video_core/shader/shader_ir.h" namespace VideoCommon::Shader { @@ -67,13 +69,15 @@ struct ShaderBlock { struct ShaderCharacteristics { std::list blocks{}; - bool decompilable{}; + std::set labels{}; u32 start{}; u32 end{}; - std::unordered_set labels{}; + ASTManager manager{true, true}; + CompilerSettings settings{}; }; -std::optional ScanFlow(const ProgramCode& program_code, - std::size_t program_size, u32 start_address); +std::unique_ptr ScanFlow(const ProgramCode& program_code, u32 program_size, + u32 start_address, + const CompilerSettings& settings); } // namespace VideoCommon::Shader diff --git a/src/video_core/shader/decode.cpp b/src/video_core/shader/decode.cpp index 47a9fd961..2626b1616 100644 --- a/src/video_core/shader/decode.cpp +++ b/src/video_core/shader/decode.cpp @@ -35,58 +35,138 @@ constexpr bool IsSchedInstruction(u32 offset, u32 main_offset) { } // namespace +class ASTDecoder { +public: + ASTDecoder(ShaderIR& ir) : ir(ir) {} + + void operator()(ASTProgram& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTIfThen& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTIfElse& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTBlockEncoded& ast) {} + + void operator()(ASTBlockDecoded& ast) {} + + void operator()(ASTVarSet& ast) {} + + void operator()(ASTLabel& ast) {} + + void operator()(ASTGoto& ast) {} + + void operator()(ASTDoWhile& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(ASTReturn& ast) {} + + void operator()(ASTBreak& ast) {} + + void Visit(ASTNode& node) { + std::visit(*this, *node->GetInnerData()); + if (node->IsBlockEncoded()) { + auto block = std::get_if(node->GetInnerData()); + NodeBlock bb = ir.DecodeRange(block->start, block->end); + node->TransformBlockEncoded(std::move(bb)); + } + } + +private: + ShaderIR& ir; +}; + void ShaderIR::Decode() { std::memcpy(&header, program_code.data(), sizeof(Tegra::Shader::Header)); - disable_flow_stack = false; - const auto info = ScanFlow(program_code, program_size, main_offset); - if (info) { - const auto& shader_info = *info; - coverage_begin = shader_info.start; - coverage_end = shader_info.end; - if (shader_info.decompilable) { - disable_flow_stack = true; - const auto insert_block = [this](NodeBlock& nodes, u32 label) { - if (label == static_cast(exit_branch)) { - return; - } - basic_blocks.insert({label, nodes}); - }; - const auto& blocks = shader_info.blocks; - NodeBlock current_block; - u32 current_label = static_cast(exit_branch); - for (auto& block : blocks) { - if (shader_info.labels.count(block.start) != 0) { - insert_block(current_block, current_label); - current_block.clear(); - current_label = block.start; - } - if (!block.ignore_branch) { - DecodeRangeInner(current_block, block.start, block.end); - InsertControlFlow(current_block, block); - } else { - DecodeRangeInner(current_block, block.start, block.end + 1); - } - } - insert_block(current_block, current_label); - return; - } - LOG_WARNING(HW_GPU, "Flow Stack Removing Failed! Falling back to old method"); - // we can't decompile it, fallback to standard method + decompiled = false; + auto info = ScanFlow(program_code, program_size, main_offset, settings); + auto& shader_info = *info; + coverage_begin = shader_info.start; + coverage_end = shader_info.end; + switch (shader_info.settings.depth) { + case CompileDepth::FlowStack: { for (const auto& block : shader_info.blocks) { basic_blocks.insert({block.start, DecodeRange(block.start, block.end + 1)}); } - return; + break; } - LOG_WARNING(HW_GPU, "Flow Analysis Failed! Falling back to brute force compiling"); - - // Now we need to deal with an undecompilable shader. We need to brute force - // a shader that captures every position. - coverage_begin = main_offset; - const u32 shader_end = static_cast(program_size / sizeof(u64)); - coverage_end = shader_end; - for (u32 label = main_offset; label < shader_end; label++) { - basic_blocks.insert({label, DecodeRange(label, label + 1)}); + case CompileDepth::NoFlowStack: { + disable_flow_stack = true; + const auto insert_block = [this](NodeBlock& nodes, u32 label) { + if (label == static_cast(exit_branch)) { + return; + } + basic_blocks.insert({label, nodes}); + }; + const auto& blocks = shader_info.blocks; + NodeBlock current_block; + u32 current_label = static_cast(exit_branch); + for (auto& block : blocks) { + if (shader_info.labels.count(block.start) != 0) { + insert_block(current_block, current_label); + current_block.clear(); + current_label = block.start; + } + if (!block.ignore_branch) { + DecodeRangeInner(current_block, block.start, block.end); + InsertControlFlow(current_block, block); + } else { + DecodeRangeInner(current_block, block.start, block.end + 1); + } + } + insert_block(current_block, current_label); + break; + } + case CompileDepth::DecompileBackwards: + case CompileDepth::FullDecompile: { + program_manager = std::move(shader_info.manager); + disable_flow_stack = true; + decompiled = true; + ASTDecoder decoder{*this}; + ASTNode program = GetASTProgram(); + decoder.Visit(program); + break; + } + default: + LOG_CRITICAL(HW_GPU, "Unknown decompilation mode!"); + [[fallthrough]]; + case CompileDepth::BruteForce: { + coverage_begin = main_offset; + const u32 shader_end = static_cast(program_size / sizeof(u64)); + coverage_end = shader_end; + for (u32 label = main_offset; label < shader_end; label++) { + basic_blocks.insert({label, DecodeRange(label, label + 1)}); + } + break; + } + } + if (settings.depth != shader_info.settings.depth) { + LOG_WARNING( + HW_GPU, "Decompiling to this setting \"{}\" failed, downgrading to this setting \"{}\"", + CompileDepthAsString(settings.depth), CompileDepthAsString(shader_info.settings.depth)); } } diff --git a/src/video_core/shader/expr.cpp b/src/video_core/shader/expr.cpp new file mode 100644 index 000000000..ca633ffb1 --- /dev/null +++ b/src/video_core/shader/expr.cpp @@ -0,0 +1,82 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include + +#include "video_core/shader/expr.h" + +namespace VideoCommon::Shader { + +bool ExprAnd::operator==(const ExprAnd& b) const { + return (*operand1 == *b.operand1) && (*operand2 == *b.operand2); +} + +bool ExprOr::operator==(const ExprOr& b) const { + return (*operand1 == *b.operand1) && (*operand2 == *b.operand2); +} + +bool ExprNot::operator==(const ExprNot& b) const { + return (*operand1 == *b.operand1); +} + +bool ExprIsBoolean(Expr expr) { + return std::holds_alternative(*expr); +} + +bool ExprBooleanGet(Expr expr) { + return std::get_if(expr.get())->value; +} + +Expr MakeExprNot(Expr first) { + if (std::holds_alternative(*first)) { + return std::get_if(first.get())->operand1; + } + return MakeExpr(first); +} + +Expr MakeExprAnd(Expr first, Expr second) { + if (ExprIsBoolean(first)) { + return ExprBooleanGet(first) ? second : first; + } + if (ExprIsBoolean(second)) { + return ExprBooleanGet(second) ? first : second; + } + return MakeExpr(first, second); +} + +Expr MakeExprOr(Expr first, Expr second) { + if (ExprIsBoolean(first)) { + return ExprBooleanGet(first) ? first : second; + } + if (ExprIsBoolean(second)) { + return ExprBooleanGet(second) ? second : first; + } + return MakeExpr(first, second); +} + +bool ExprAreEqual(Expr first, Expr second) { + return (*first) == (*second); +} + +bool ExprAreOpposite(Expr first, Expr second) { + if (std::holds_alternative(*first)) { + return ExprAreEqual(std::get_if(first.get())->operand1, second); + } + if (std::holds_alternative(*second)) { + return ExprAreEqual(std::get_if(second.get())->operand1, first); + } + return false; +} + +bool ExprIsTrue(Expr first) { + if (ExprIsBoolean(first)) { + return ExprBooleanGet(first); + } + return false; +} + +} // namespace VideoCommon::Shader diff --git a/src/video_core/shader/expr.h b/src/video_core/shader/expr.h new file mode 100644 index 000000000..4c399cef9 --- /dev/null +++ b/src/video_core/shader/expr.h @@ -0,0 +1,120 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include + +#include "video_core/engines/shader_bytecode.h" + +namespace VideoCommon::Shader { + +using Tegra::Shader::ConditionCode; +using Tegra::Shader::Pred; + +class ExprAnd; +class ExprOr; +class ExprNot; +class ExprPredicate; +class ExprCondCode; +class ExprVar; +class ExprBoolean; + +using ExprData = + std::variant; +using Expr = std::shared_ptr; + +class ExprAnd final { +public: + explicit ExprAnd(Expr a, Expr b) : operand1{a}, operand2{b} {} + + bool operator==(const ExprAnd& b) const; + + Expr operand1; + Expr operand2; +}; + +class ExprOr final { +public: + explicit ExprOr(Expr a, Expr b) : operand1{a}, operand2{b} {} + + bool operator==(const ExprOr& b) const; + + Expr operand1; + Expr operand2; +}; + +class ExprNot final { +public: + explicit ExprNot(Expr a) : operand1{a} {} + + bool operator==(const ExprNot& b) const; + + Expr operand1; +}; + +class ExprVar final { +public: + explicit ExprVar(u32 index) : var_index{index} {} + + bool operator==(const ExprVar& b) const { + return var_index == b.var_index; + } + + u32 var_index; +}; + +class ExprPredicate final { +public: + explicit ExprPredicate(u32 predicate) : predicate{predicate} {} + + bool operator==(const ExprPredicate& b) const { + return predicate == b.predicate; + } + + u32 predicate; +}; + +class ExprCondCode final { +public: + explicit ExprCondCode(ConditionCode cc) : cc{cc} {} + + bool operator==(const ExprCondCode& b) const { + return cc == b.cc; + } + + ConditionCode cc; +}; + +class ExprBoolean final { +public: + explicit ExprBoolean(bool val) : value{val} {} + + bool operator==(const ExprBoolean& b) const { + return value == b.value; + } + + bool value; +}; + +template +Expr MakeExpr(Args&&... args) { + static_assert(std::is_convertible_v); + return std::make_shared(T(std::forward(args)...)); +} + +bool ExprAreEqual(Expr first, Expr second); + +bool ExprAreOpposite(Expr first, Expr second); + +Expr MakeExprNot(Expr first); + +Expr MakeExprAnd(Expr first, Expr second); + +Expr MakeExprOr(Expr first, Expr second); + +bool ExprIsTrue(Expr first); + +} // namespace VideoCommon::Shader diff --git a/src/video_core/shader/shader_ir.cpp b/src/video_core/shader/shader_ir.cpp index 2c357f310..c1f2b88c8 100644 --- a/src/video_core/shader/shader_ir.cpp +++ b/src/video_core/shader/shader_ir.cpp @@ -22,8 +22,10 @@ using Tegra::Shader::PredCondition; using Tegra::Shader::PredOperation; using Tegra::Shader::Register; -ShaderIR::ShaderIR(const ProgramCode& program_code, u32 main_offset, const std::size_t size) - : program_code{program_code}, main_offset{main_offset}, program_size{size} { +ShaderIR::ShaderIR(const ProgramCode& program_code, u32 main_offset, const std::size_t size, + CompilerSettings settings) + : program_code{program_code}, main_offset{main_offset}, program_size{size}, basic_blocks{}, + program_manager{true, true}, settings{settings} { Decode(); } @@ -137,7 +139,7 @@ Node ShaderIR::GetOutputAttribute(Attribute::Index index, u64 element, Node buff return MakeNode(index, static_cast(element), std::move(buffer)); } -Node ShaderIR::GetInternalFlag(InternalFlag flag, bool negated) { +Node ShaderIR::GetInternalFlag(InternalFlag flag, bool negated) const { const Node node = MakeNode(flag); if (negated) { return Operation(OperationCode::LogicalNegate, node); @@ -367,13 +369,13 @@ OperationCode ShaderIR::GetPredicateCombiner(PredOperation operation) { return op->second; } -Node ShaderIR::GetConditionCode(Tegra::Shader::ConditionCode cc) { +Node ShaderIR::GetConditionCode(Tegra::Shader::ConditionCode cc) const { switch (cc) { case Tegra::Shader::ConditionCode::NEU: return GetInternalFlag(InternalFlag::Zero, true); default: UNIMPLEMENTED_MSG("Unimplemented condition code: {}", static_cast(cc)); - return GetPredicate(static_cast(Pred::NeverExecute)); + return MakeNode(Pred::NeverExecute, false); } } diff --git a/src/video_core/shader/shader_ir.h b/src/video_core/shader/shader_ir.h index 6f666ee30..105981d67 100644 --- a/src/video_core/shader/shader_ir.h +++ b/src/video_core/shader/shader_ir.h @@ -15,6 +15,8 @@ #include "video_core/engines/maxwell_3d.h" #include "video_core/engines/shader_bytecode.h" #include "video_core/engines/shader_header.h" +#include "video_core/shader/ast.h" +#include "video_core/shader/compiler_settings.h" #include "video_core/shader/node.h" namespace VideoCommon::Shader { @@ -64,7 +66,8 @@ struct GlobalMemoryUsage { class ShaderIR final { public: - explicit ShaderIR(const ProgramCode& program_code, u32 main_offset, std::size_t size); + explicit ShaderIR(const ProgramCode& program_code, u32 main_offset, std::size_t size, + CompilerSettings settings); ~ShaderIR(); const std::map& GetBasicBlocks() const { @@ -144,11 +147,31 @@ public: return disable_flow_stack; } + bool IsDecompiled() const { + return decompiled; + } + + const ASTManager& GetASTManager() const { + return program_manager; + } + + ASTNode GetASTProgram() const { + return program_manager.GetProgram(); + } + + u32 GetASTNumVariables() const { + return program_manager.GetVariables(); + } + u32 ConvertAddressToNvidiaSpace(const u32 address) const { return (address - main_offset) * sizeof(Tegra::Shader::Instruction); } + /// Returns a condition code evaluated from internal flags + Node GetConditionCode(Tegra::Shader::ConditionCode cc) const; + private: + friend class ASTDecoder; void Decode(); NodeBlock DecodeRange(u32 begin, u32 end); @@ -213,7 +236,7 @@ private: /// Generates a node representing an output attribute. Keeps track of used attributes. Node GetOutputAttribute(Tegra::Shader::Attribute::Index index, u64 element, Node buffer); /// Generates a node representing an internal flag - Node GetInternalFlag(InternalFlag flag, bool negated = false); + Node GetInternalFlag(InternalFlag flag, bool negated = false) const; /// Generates a node representing a local memory address Node GetLocalMemory(Node address); /// Generates a node representing a shared memory address @@ -271,9 +294,6 @@ private: /// Returns a predicate combiner operation OperationCode GetPredicateCombiner(Tegra::Shader::PredOperation operation); - /// Returns a condition code evaluated from internal flags - Node GetConditionCode(Tegra::Shader::ConditionCode cc); - /// Accesses a texture sampler const Sampler& GetSampler(const Tegra::Shader::Sampler& sampler, Tegra::Shader::TextureType type, bool is_array, bool is_shadow); @@ -357,6 +377,7 @@ private: const ProgramCode& program_code; const u32 main_offset; const std::size_t program_size; + bool decompiled{}; bool disable_flow_stack{}; u32 coverage_begin{}; @@ -364,6 +385,8 @@ private: std::map basic_blocks; NodeBlock global_code; + ASTManager program_manager; + CompilerSettings settings{}; std::set used_registers; std::set used_predicates;