From b5d7279d878211654b4abb165d94af763a365f47 Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Tue, 16 Feb 2021 04:10:22 -0300 Subject: [PATCH] spirv: Initial bindings support --- externals/sirit | 2 +- src/shader_recompiler/CMakeLists.txt | 4 + .../backend/spirv/emit_context.cpp | 160 ++++++++++++++ .../backend/spirv/emit_context.h | 67 ++++++ .../backend/spirv/emit_spirv.cpp | 203 ++++++++---------- .../backend/spirv/emit_spirv.h | 84 +------- .../spirv/emit_spirv_bitwise_conversion.cpp | 4 +- .../backend/spirv/emit_spirv_composite.cpp | 2 +- .../spirv/emit_spirv_context_get_set.cpp | 20 +- .../backend/spirv/emit_spirv_control_flow.cpp | 26 +++ .../spirv/emit_spirv_floating_point.cpp | 18 +- .../backend/spirv/emit_spirv_integer.cpp | 16 +- .../backend/spirv/emit_spirv_memory.cpp | 36 +++- .../backend/spirv/emit_spirv_undefined.cpp | 4 +- .../frontend/ir/basic_block.h | 16 ++ src/shader_recompiler/frontend/ir/program.h | 2 + .../frontend/maxwell/program.cpp | 7 +- .../ir_opt/collect_shader_info_pass.cpp | 81 +++++++ .../ir_opt/constant_propagation_pass.cpp | 76 +++++-- .../global_memory_to_storage_buffer_pass.cpp | 110 +++++----- src/shader_recompiler/ir_opt/passes.h | 4 +- src/shader_recompiler/main.cpp | 4 +- src/shader_recompiler/shader_info.h | 33 ++- 23 files changed, 679 insertions(+), 300 deletions(-) create mode 100644 src/shader_recompiler/backend/spirv/emit_context.cpp create mode 100644 src/shader_recompiler/backend/spirv/emit_context.h create mode 100644 src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp diff --git a/externals/sirit b/externals/sirit index f819ade0e..200310e8f 160000 --- a/externals/sirit +++ b/externals/sirit @@ -1 +1 @@ -Subproject commit f819ade0efe925a782090dea9e1bf300fedffb39 +Subproject commit 200310e8faa756b9869dd6dfc902c255246ac74a diff --git a/src/shader_recompiler/CMakeLists.txt b/src/shader_recompiler/CMakeLists.txt index e1f4276a1..84be94a8d 100644 --- a/src/shader_recompiler/CMakeLists.txt +++ b/src/shader_recompiler/CMakeLists.txt @@ -1,4 +1,6 @@ add_executable(shader_recompiler + backend/spirv/emit_context.cpp + backend/spirv/emit_context.h backend/spirv/emit_spirv.cpp backend/spirv/emit_spirv.h backend/spirv/emit_spirv_bitwise_conversion.cpp @@ -75,6 +77,7 @@ add_executable(shader_recompiler frontend/maxwell/translate/impl/move_special_register.cpp frontend/maxwell/translate/translate.cpp frontend/maxwell/translate/translate.h + ir_opt/collect_shader_info_pass.cpp ir_opt/constant_propagation_pass.cpp ir_opt/dead_code_elimination_pass.cpp ir_opt/global_memory_to_storage_buffer_pass.cpp @@ -84,6 +87,7 @@ add_executable(shader_recompiler ir_opt/verification_pass.cpp main.cpp object_pool.h + shader_info.h ) target_include_directories(video_core PRIVATE sirit) diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp new file mode 100644 index 000000000..1c985aff8 --- /dev/null +++ b/src/shader_recompiler/backend/spirv/emit_context.cpp @@ -0,0 +1,160 @@ +// Copyright 2021 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include +#include + +#include + +#include "common/common_types.h" +#include "shader_recompiler/backend/spirv/emit_context.h" + +namespace Shader::Backend::SPIRV { + +void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) { + defs[0] = sirit_ctx.Name(base_type, name); + + std::array def_name; + for (int i = 1; i < 4; ++i) { + const std::string_view def_name_view( + def_name.data(), + fmt::format_to_n(def_name.data(), def_name.size(), "{}x{}", name, i + 1).size); + defs[i] = sirit_ctx.Name(sirit_ctx.TypeVector(base_type, i + 1), def_name_view); + } +} + +EmitContext::EmitContext(IR::Program& program) : Sirit::Module(0x00010000) { + AddCapability(spv::Capability::Shader); + DefineCommonTypes(program.info); + DefineCommonConstants(); + DefineSpecialVariables(program.info); + DefineConstantBuffers(program.info); + DefineStorageBuffers(program.info); + DefineLabels(program); +} + +EmitContext::~EmitContext() = default; + +Id EmitContext::Def(const IR::Value& value) { + if (!value.IsImmediate()) { + return value.Inst()->Definition(); + } + switch (value.Type()) { + case IR::Type::U1: + return value.U1() ? true_value : false_value; + case IR::Type::U32: + return Constant(U32[1], value.U32()); + case IR::Type::F32: + return Constant(F32[1], value.F32()); + default: + throw NotImplementedException("Immediate type {}", value.Type()); + } +} + +void EmitContext::DefineCommonTypes(const Info& info) { + void_id = TypeVoid(); + + U1 = Name(TypeBool(), "u1"); + + F32.Define(*this, TypeFloat(32), "f32"); + U32.Define(*this, TypeInt(32, false), "u32"); + + if (info.uses_fp16) { + AddCapability(spv::Capability::Float16); + F16.Define(*this, TypeFloat(16), "f16"); + } + if (info.uses_fp64) { + AddCapability(spv::Capability::Float64); + F64.Define(*this, TypeFloat(64), "f64"); + } +} + +void EmitContext::DefineCommonConstants() { + true_value = ConstantTrue(U1); + false_value = ConstantFalse(U1); + u32_zero_value = Constant(U32[1], 0U); +} + +void EmitContext::DefineSpecialVariables(const Info& info) { + const auto define{[this](Id type, spv::BuiltIn builtin, spv::StorageClass storage_class) { + const Id pointer_type{TypePointer(storage_class, type)}; + const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::Input)}; + Decorate(id, spv::Decoration::BuiltIn, builtin); + return id; + }}; + using namespace std::placeholders; + const auto define_input{std::bind(define, _1, _2, spv::StorageClass::Input)}; + + if (info.uses_workgroup_id) { + workgroup_id = define_input(U32[3], spv::BuiltIn::WorkgroupId); + } + if (info.uses_local_invocation_id) { + local_invocation_id = define_input(U32[3], spv::BuiltIn::LocalInvocationId); + } +} + +void EmitContext::DefineConstantBuffers(const Info& info) { + if (info.constant_buffer_descriptors.empty()) { + return; + } + const Id array_type{TypeArray(U32[1], Constant(U32[1], 4096))}; + Decorate(array_type, spv::Decoration::ArrayStride, 16U); + + const Id struct_type{TypeStruct(array_type)}; + Name(struct_type, "cbuf_block"); + Decorate(struct_type, spv::Decoration::Block); + MemberName(struct_type, 0, "data"); + MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U); + + const Id uniform_type{TypePointer(spv::StorageClass::Uniform, struct_type)}; + uniform_u32 = TypePointer(spv::StorageClass::Uniform, U32[1]); + + u32 binding{}; + for (const Info::ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) { + const Id id{AddGlobalVariable(uniform_type, spv::StorageClass::Uniform)}; + Decorate(id, spv::Decoration::Binding, binding); + Name(id, fmt::format("c{}", desc.index)); + std::fill_n(cbufs.data() + desc.index, desc.count, id); + binding += desc.count; + } +} + +void EmitContext::DefineStorageBuffers(const Info& info) { + if (info.storage_buffers_descriptors.empty()) { + return; + } + AddExtension("SPV_KHR_storage_buffer_storage_class"); + + const Id array_type{TypeRuntimeArray(U32[1])}; + Decorate(array_type, spv::Decoration::ArrayStride, 4U); + + const Id struct_type{TypeStruct(array_type)}; + Name(struct_type, "ssbo_block"); + Decorate(struct_type, spv::Decoration::Block); + MemberName(struct_type, 0, "data"); + MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U); + + const Id storage_type{TypePointer(spv::StorageClass::StorageBuffer, struct_type)}; + storage_u32 = TypePointer(spv::StorageClass::StorageBuffer, U32[1]); + + u32 binding{}; + for (const Info::StorageBufferDescriptor& desc : info.storage_buffers_descriptors) { + const Id id{AddGlobalVariable(storage_type, spv::StorageClass::StorageBuffer)}; + Decorate(id, spv::Decoration::Binding, binding); + Name(id, fmt::format("ssbo{}", binding)); + std::fill_n(ssbos.data() + binding, desc.count, id); + binding += desc.count; + } +} + +void EmitContext::DefineLabels(IR::Program& program) { + for (const IR::Function& function : program.functions) { + for (IR::Block* const block : function.blocks) { + block->SetDefinition(OpLabel()); + } + } +} + +} // namespace Shader::Backend::SPIRV diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h new file mode 100644 index 000000000..c4b84759d --- /dev/null +++ b/src/shader_recompiler/backend/spirv/emit_context.h @@ -0,0 +1,67 @@ +// Copyright 2021 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include + +#include + +#include "shader_recompiler/frontend/ir/program.h" +#include "shader_recompiler/shader_info.h" + +namespace Shader::Backend::SPIRV { + +using Sirit::Id; + +class VectorTypes { +public: + void Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name); + + [[nodiscard]] Id operator[](size_t size) const noexcept { + return defs[size - 1]; + } + +private: + std::array defs{}; +}; + +class EmitContext final : public Sirit::Module { +public: + explicit EmitContext(IR::Program& program); + ~EmitContext(); + + [[nodiscard]] Id Def(const IR::Value& value); + + Id void_id{}; + Id U1{}; + VectorTypes F32; + VectorTypes U32; + VectorTypes F16; + VectorTypes F64; + + Id true_value{}; + Id false_value{}; + Id u32_zero_value{}; + + Id uniform_u32{}; + Id storage_u32{}; + + std::array cbufs{}; + std::array ssbos{}; + + Id workgroup_id{}; + Id local_invocation_id{}; + +private: + void DefineCommonTypes(const Info& info); + void DefineCommonConstants(); + void DefineSpecialVariables(const Info& info); + void DefineConstantBuffers(const Info& info); + void DefineStorageBuffers(const Info& info); + void DefineLabels(IR::Program& program); +}; + +} // namespace Shader::Backend::SPIRV diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp index 0895414b4..c79c09774 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp @@ -12,60 +12,22 @@ #include "shader_recompiler/frontend/ir/program.h" namespace Shader::Backend::SPIRV { +namespace { +template +struct FuncTraits : FuncTraits {}; -EmitContext::EmitContext(IR::Program& program) { - AddCapability(spv::Capability::Shader); - AddCapability(spv::Capability::Float16); - AddCapability(spv::Capability::Float64); - void_id = TypeVoid(); +template +struct FuncTraits { + using ReturnType = ReturnType_; - u1 = Name(TypeBool(), "u1"); - f32.Define(*this, TypeFloat(32), "f32"); - u32.Define(*this, TypeInt(32, false), "u32"); - f16.Define(*this, TypeFloat(16), "f16"); - f64.Define(*this, TypeFloat(64), "f64"); + static constexpr size_t NUM_ARGS = sizeof...(Args); - true_value = ConstantTrue(u1); - false_value = ConstantFalse(u1); - - for (const IR::Function& function : program.functions) { - for (IR::Block* const block : function.blocks) { - block_label_map.emplace_back(block, OpLabel()); - } - } - std::ranges::sort(block_label_map, {}, &std::pair::first); -} - -EmitContext::~EmitContext() = default; - -EmitSPIRV::EmitSPIRV(IR::Program& program) { - EmitContext ctx{program}; - const Id void_function{ctx.TypeFunction(ctx.void_id)}; - // FIXME: Forward declare functions (needs sirit support) - Id func{}; - for (IR::Function& function : program.functions) { - func = ctx.OpFunction(ctx.void_id, spv::FunctionControlMask::MaskNone, void_function); - for (IR::Block* const block : function.blocks) { - ctx.AddLabel(ctx.BlockLabel(block)); - for (IR::Inst& inst : block->Instructions()) { - EmitInst(ctx, &inst); - } - } - ctx.OpFunctionEnd(); - } - ctx.AddEntryPoint(spv::ExecutionModel::GLCompute, func, "main"); - - std::vector result{ctx.Assemble()}; - std::FILE* file{std::fopen("shader.spv", "wb")}; - std::fwrite(result.data(), sizeof(u32), result.size(), file); - std::fclose(file); - std::system("spirv-dis shader.spv"); - std::system("spirv-val shader.spv"); - std::system("spirv-cross shader.spv"); -} + template + using ArgType = std::tuple_element_t>; +}; template -static void SetDefinition(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, Args... args) { +void SetDefinition(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, Args... args) { const Id forward_id{inst->Definition()}; const bool has_forward_id{Sirit::ValidId(forward_id)}; Id current_id{}; @@ -80,42 +42,90 @@ static void SetDefinition(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, Arg } } -template -static void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst) { - using M = decltype(method); - using std::is_invocable_r_v; - if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, ctx.Def(inst->Arg(0))); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1))); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)), - ctx.Def(inst->Arg(2))); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, inst); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1))); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)), - ctx.Def(inst->Arg(2))); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, ctx.Def(inst->Arg(0)), inst->Arg(1).U32()); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, inst->Arg(0)); - } else if constexpr (is_invocable_r_v) { - SetDefinition(emit, ctx, inst, inst->Arg(0), inst->Arg(1)); - } else if constexpr (is_invocable_r_v) { - (emit.*method)(ctx, inst); - } else if constexpr (is_invocable_r_v) { - (emit.*method)(ctx); - } else { - static_assert(false, "Bad format"); +template +ArgType Arg(EmitContext& ctx, const IR::Value& arg) { + if constexpr (std::is_same_v) { + return ctx.Def(arg); + } else if constexpr (std::is_same_v) { + return arg; + } else if constexpr (std::is_same_v) { + return arg.U32(); + } else if constexpr (std::is_same_v) { + return arg.Label(); } } +template +void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, std::index_sequence) { + using Traits = FuncTraits; + if constexpr (std::is_same_v) { + if constexpr (is_first_arg_inst) { + SetDefinition(emit, ctx, inst, inst, + Arg>(ctx, inst->Arg(I))...); + } else { + SetDefinition(emit, ctx, inst, + Arg>(ctx, inst->Arg(I))...); + } + } else { + if constexpr (is_first_arg_inst) { + (emit.*method)(ctx, inst, Arg>(ctx, inst->Arg(I))...); + } else { + (emit.*method)(ctx, Arg>(ctx, inst->Arg(I))...); + } + } +} + +template +void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst) { + using Traits = FuncTraits; + static_assert(Traits::NUM_ARGS >= 1, "Insufficient arguments"); + if constexpr (Traits::NUM_ARGS == 1) { + Invoke(emit, ctx, inst, std::make_index_sequence<0>{}); + } else { + using FirstArgType = typename Traits::template ArgType<1>; + static constexpr bool is_first_arg_inst = std::is_same_v; + using Indices = std::make_index_sequence; + Invoke(emit, ctx, inst, Indices{}); + } +} +} // Anonymous namespace + +EmitSPIRV::EmitSPIRV(IR::Program& program) { + EmitContext ctx{program}; + const Id void_function{ctx.TypeFunction(ctx.void_id)}; + // FIXME: Forward declare functions (needs sirit support) + Id func{}; + for (IR::Function& function : program.functions) { + func = ctx.OpFunction(ctx.void_id, spv::FunctionControlMask::MaskNone, void_function); + for (IR::Block* const block : function.blocks) { + ctx.AddLabel(block->Definition()); + for (IR::Inst& inst : block->Instructions()) { + EmitInst(ctx, &inst); + } + } + ctx.OpFunctionEnd(); + } + boost::container::small_vector interfaces; + if (program.info.uses_workgroup_id) { + interfaces.push_back(ctx.workgroup_id); + } + if (program.info.uses_local_invocation_id) { + interfaces.push_back(ctx.local_invocation_id); + } + + const std::span interfaces_span(interfaces.data(), interfaces.size()); + ctx.AddEntryPoint(spv::ExecutionModel::Fragment, func, "main", interfaces_span); + ctx.AddExecutionMode(func, spv::ExecutionMode::OriginUpperLeft); + + std::vector result{ctx.Assemble()}; + std::FILE* file{std::fopen("D:\\shader.spv", "wb")}; + std::fwrite(result.data(), sizeof(u32), result.size(), file); + std::fclose(file); + std::system("spirv-dis D:\\shader.spv") == 0 && + std::system("spirv-val --uniform-buffer-standard-layout D:\\shader.spv") == 0 && + std::system("spirv-cross -V D:\\shader.spv") == 0; +} + void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) { switch (inst->Opcode()) { #define OPCODE(name, result_type, ...) \ @@ -130,9 +140,9 @@ void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) { static Id TypeId(const EmitContext& ctx, IR::Type type) { switch (type) { case IR::Type::U1: - return ctx.u1; + return ctx.U1; case IR::Type::U32: - return ctx.u32[1]; + return ctx.U32[1]; default: throw NotImplementedException("Phi node type {}", type); } @@ -162,7 +172,7 @@ Id EmitSPIRV::EmitPhi(EmitContext& ctx, IR::Inst* inst) { } IR::Block* const phi_block{inst->PhiBlock(index)}; operands.push_back(def); - operands.push_back(ctx.BlockLabel(phi_block)); + operands.push_back(phi_block->Definition()); } const Id result_type{TypeId(ctx, inst->Arg(0).Type())}; return ctx.OpPhi(result_type, std::span(operands.data(), operands.size())); @@ -174,29 +184,6 @@ void EmitSPIRV::EmitIdentity(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } -// FIXME: Move to its own file -void EmitSPIRV::EmitBranch(EmitContext& ctx, IR::Inst* inst) { - ctx.OpBranch(ctx.BlockLabel(inst->Arg(0).Label())); -} - -void EmitSPIRV::EmitBranchConditional(EmitContext& ctx, IR::Inst* inst) { - ctx.OpBranchConditional(ctx.Def(inst->Arg(0)), ctx.BlockLabel(inst->Arg(1).Label()), - ctx.BlockLabel(inst->Arg(2).Label())); -} - -void EmitSPIRV::EmitLoopMerge(EmitContext& ctx, IR::Inst* inst) { - ctx.OpLoopMerge(ctx.BlockLabel(inst->Arg(0).Label()), ctx.BlockLabel(inst->Arg(1).Label()), - spv::LoopControlMask::MaskNone); -} - -void EmitSPIRV::EmitSelectionMerge(EmitContext& ctx, IR::Inst* inst) { - ctx.OpSelectionMerge(ctx.BlockLabel(inst->Arg(0).Label()), spv::SelectionControlMask::MaskNone); -} - -void EmitSPIRV::EmitReturn(EmitContext& ctx) { - ctx.OpReturn(); -} - void EmitSPIRV::EmitGetZeroFromOp(EmitContext&) { throw LogicError("Unreachable instruction"); } diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h index 7d76377b5..a5d0e1ec0 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv.h +++ b/src/shader_recompiler/backend/spirv/emit_spirv.h @@ -7,82 +7,12 @@ #include #include "common/common_types.h" +#include "shader_recompiler/backend/spirv/emit_context.h" #include "shader_recompiler/frontend/ir/microinstruction.h" #include "shader_recompiler/frontend/ir/program.h" namespace Shader::Backend::SPIRV { -using Sirit::Id; - -class VectorTypes { -public: - void Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) { - defs[0] = sirit_ctx.Name(base_type, name); - - std::array def_name; - for (int i = 1; i < 4; ++i) { - const std::string_view def_name_view( - def_name.data(), - fmt::format_to_n(def_name.data(), def_name.size(), "{}x{}", name, i + 1).size); - defs[i] = sirit_ctx.Name(sirit_ctx.TypeVector(base_type, i + 1), def_name_view); - } - } - - [[nodiscard]] Id operator[](size_t size) const noexcept { - return defs[size - 1]; - } - -private: - std::array defs; -}; - -class EmitContext final : public Sirit::Module { -public: - explicit EmitContext(IR::Program& program); - ~EmitContext(); - - [[nodiscard]] Id Def(const IR::Value& value) { - if (!value.IsImmediate()) { - return value.Inst()->Definition(); - } - switch (value.Type()) { - case IR::Type::U1: - return value.U1() ? true_value : false_value; - case IR::Type::U32: - return Constant(u32[1], value.U32()); - case IR::Type::F32: - return Constant(f32[1], value.F32()); - default: - throw NotImplementedException("Immediate type {}", value.Type()); - } - } - - [[nodiscard]] Id BlockLabel(IR::Block* block) const { - const auto it{std::ranges::lower_bound(block_label_map, block, {}, - &std::pair::first)}; - if (it == block_label_map.end()) { - throw LogicError("Undefined block"); - } - return it->second; - } - - Id void_id{}; - Id u1{}; - VectorTypes f32; - VectorTypes u32; - VectorTypes f16; - VectorTypes f64; - - Id true_value{}; - Id false_value{}; - - Id workgroup_id{}; - Id local_invocation_id{}; - -private: - std::vector> block_label_map; -}; - class EmitSPIRV { public: explicit EmitSPIRV(IR::Program& program); @@ -94,10 +24,11 @@ private: Id EmitPhi(EmitContext& ctx, IR::Inst* inst); void EmitVoid(EmitContext& ctx); void EmitIdentity(EmitContext& ctx); - void EmitBranch(EmitContext& ctx, IR::Inst* inst); - void EmitBranchConditional(EmitContext& ctx, IR::Inst* inst); - void EmitLoopMerge(EmitContext& ctx, IR::Inst* inst); - void EmitSelectionMerge(EmitContext& ctx, IR::Inst* inst); + void EmitBranch(EmitContext& ctx, IR::Block* label); + void EmitBranchConditional(EmitContext& ctx, Id condition, IR::Block* true_label, + IR::Block* false_label); + void EmitLoopMerge(EmitContext& ctx, IR::Block* merge_label, IR::Block* continue_label); + void EmitSelectionMerge(EmitContext& ctx, IR::Block* merge_label); void EmitReturn(EmitContext& ctx); void EmitGetRegister(EmitContext& ctx); void EmitSetRegister(EmitContext& ctx); @@ -150,7 +81,8 @@ private: void EmitWriteStorageS8(EmitContext& ctx); void EmitWriteStorageU16(EmitContext& ctx); void EmitWriteStorageS16(EmitContext& ctx); - void EmitWriteStorage32(EmitContext& ctx); + void EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, + Id value); void EmitWriteStorage64(EmitContext& ctx); void EmitWriteStorage128(EmitContext& ctx); void EmitCompositeConstructU32x2(EmitContext& ctx); diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp index 447df5b8c..af82df99c 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp @@ -11,7 +11,7 @@ void EmitSPIRV::EmitBitCastU16F16(EmitContext&) { } Id EmitSPIRV::EmitBitCastU32F32(EmitContext& ctx, Id value) { - return ctx.OpBitcast(ctx.u32[1], value); + return ctx.OpBitcast(ctx.U32[1], value); } void EmitSPIRV::EmitBitCastU64F64(EmitContext&) { @@ -23,7 +23,7 @@ void EmitSPIRV::EmitBitCastF16U16(EmitContext&) { } Id EmitSPIRV::EmitBitCastF32U32(EmitContext& ctx, Id value) { - return ctx.OpBitcast(ctx.f32[1], value); + return ctx.OpBitcast(ctx.F32[1], value); } void EmitSPIRV::EmitBitCastF64U64(EmitContext&) { diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp index b190cf876..a7374c89d 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp @@ -23,7 +23,7 @@ void EmitSPIRV::EmitCompositeExtractU32x2(EmitContext&) { } Id EmitSPIRV::EmitCompositeExtractU32x3(EmitContext& ctx, Id vector, u32 index) { - return ctx.OpCompositeExtract(ctx.u32[1], vector, index); + return ctx.OpCompositeExtract(ctx.U32[1], vector, index); } void EmitSPIRV::EmitCompositeExtractU32x4(EmitContext&) { diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp index 1eab739ed..f4c9970eb 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp @@ -37,7 +37,10 @@ Id EmitSPIRV::EmitGetCbuf(EmitContext& ctx, const IR::Value& binding, const IR:: if (!offset.IsImmediate()) { throw NotImplementedException("Variable constant buffer offset"); } - return ctx.Name(ctx.OpUndef(ctx.u32[1]), "unimplemented_cbuf"); + const Id imm_offset{ctx.Constant(ctx.U32[1], offset.U32() / 4)}; + const Id cbuf{ctx.cbufs[binding.U32()]}; + const Id access_chain{ctx.OpAccessChain(ctx.uniform_u32, cbuf, ctx.u32_zero_value, imm_offset)}; + return ctx.OpLoad(ctx.U32[1], access_chain); } void EmitSPIRV::EmitGetAttribute(EmitContext&) { @@ -89,22 +92,11 @@ void EmitSPIRV::EmitSetOFlag(EmitContext&) { } Id EmitSPIRV::EmitWorkgroupId(EmitContext& ctx) { - if (ctx.workgroup_id.value == 0) { - ctx.workgroup_id = ctx.AddGlobalVariable( - ctx.TypePointer(spv::StorageClass::Input, ctx.u32[3]), spv::StorageClass::Input); - ctx.Decorate(ctx.workgroup_id, spv::Decoration::BuiltIn, spv::BuiltIn::WorkgroupId); - } - return ctx.OpLoad(ctx.u32[3], ctx.workgroup_id); + return ctx.OpLoad(ctx.U32[3], ctx.workgroup_id); } Id EmitSPIRV::EmitLocalInvocationId(EmitContext& ctx) { - if (ctx.local_invocation_id.value == 0) { - ctx.local_invocation_id = ctx.AddGlobalVariable( - ctx.TypePointer(spv::StorageClass::Input, ctx.u32[3]), spv::StorageClass::Input); - ctx.Decorate(ctx.local_invocation_id, spv::Decoration::BuiltIn, - spv::BuiltIn::LocalInvocationId); - } - return ctx.OpLoad(ctx.u32[3], ctx.local_invocation_id); + return ctx.OpLoad(ctx.U32[3], ctx.local_invocation_id); } } // namespace Shader::Backend::SPIRV diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp index 66ce6c8c5..549c1907a 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp @@ -3,3 +3,29 @@ // Refer to the license.txt file included. #include "shader_recompiler/backend/spirv/emit_spirv.h" + +namespace Shader::Backend::SPIRV { + +void EmitSPIRV::EmitBranch(EmitContext& ctx, IR::Block* label) { + ctx.OpBranch(label->Definition()); +} + +void EmitSPIRV::EmitBranchConditional(EmitContext& ctx, Id condition, IR::Block* true_label, + IR::Block* false_label) { + ctx.OpBranchConditional(condition, true_label->Definition(), false_label->Definition()); +} + +void EmitSPIRV::EmitLoopMerge(EmitContext& ctx, IR::Block* merge_label, IR::Block* continue_label) { + ctx.OpLoopMerge(merge_label->Definition(), continue_label->Definition(), + spv::LoopControlMask::MaskNone); +} + +void EmitSPIRV::EmitSelectionMerge(EmitContext& ctx, IR::Block* merge_label) { + ctx.OpSelectionMerge(merge_label->Definition(), spv::SelectionControlMask::MaskNone); +} + +void EmitSPIRV::EmitReturn(EmitContext& ctx) { + ctx.OpReturn(); +} + +} // namespace Shader::Backend::SPIRV diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp index 9c39537e2..c9bc121f8 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp @@ -46,27 +46,27 @@ void EmitSPIRV::EmitFPAbs64(EmitContext&) { } Id EmitSPIRV::EmitFPAdd16(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { - return Decorate(ctx, inst, ctx.OpFAdd(ctx.f16[1], a, b)); + return Decorate(ctx, inst, ctx.OpFAdd(ctx.F16[1], a, b)); } Id EmitSPIRV::EmitFPAdd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { - return Decorate(ctx, inst, ctx.OpFAdd(ctx.f32[1], a, b)); + return Decorate(ctx, inst, ctx.OpFAdd(ctx.F32[1], a, b)); } Id EmitSPIRV::EmitFPAdd64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { - return Decorate(ctx, inst, ctx.OpFAdd(ctx.f64[1], a, b)); + return Decorate(ctx, inst, ctx.OpFAdd(ctx.F64[1], a, b)); } Id EmitSPIRV::EmitFPFma16(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) { - return Decorate(ctx, inst, ctx.OpFma(ctx.f16[1], a, b, c)); + return Decorate(ctx, inst, ctx.OpFma(ctx.F16[1], a, b, c)); } Id EmitSPIRV::EmitFPFma32(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) { - return Decorate(ctx, inst, ctx.OpFma(ctx.f32[1], a, b, c)); + return Decorate(ctx, inst, ctx.OpFma(ctx.F32[1], a, b, c)); } Id EmitSPIRV::EmitFPFma64(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) { - return Decorate(ctx, inst, ctx.OpFma(ctx.f64[1], a, b, c)); + return Decorate(ctx, inst, ctx.OpFma(ctx.F64[1], a, b, c)); } void EmitSPIRV::EmitFPMax32(EmitContext&) { @@ -86,15 +86,15 @@ void EmitSPIRV::EmitFPMin64(EmitContext&) { } Id EmitSPIRV::EmitFPMul16(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { - return Decorate(ctx, inst, ctx.OpFMul(ctx.f16[1], a, b)); + return Decorate(ctx, inst, ctx.OpFMul(ctx.F16[1], a, b)); } Id EmitSPIRV::EmitFPMul32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { - return Decorate(ctx, inst, ctx.OpFMul(ctx.f32[1], a, b)); + return Decorate(ctx, inst, ctx.OpFMul(ctx.F32[1], a, b)); } Id EmitSPIRV::EmitFPMul64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { - return Decorate(ctx, inst, ctx.OpFMul(ctx.f64[1], a, b)); + return Decorate(ctx, inst, ctx.OpFMul(ctx.F64[1], a, b)); } void EmitSPIRV::EmitFPNeg16(EmitContext&) { diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp index e811a63ab..32af94a73 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp @@ -10,7 +10,7 @@ Id EmitSPIRV::EmitIAdd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { if (inst->HasAssociatedPseudoOperation()) { throw NotImplementedException("Pseudo-operations on IAdd32"); } - return ctx.OpIAdd(ctx.u32[1], a, b); + return ctx.OpIAdd(ctx.U32[1], a, b); } void EmitSPIRV::EmitIAdd64(EmitContext&) { @@ -18,7 +18,7 @@ void EmitSPIRV::EmitIAdd64(EmitContext&) { } Id EmitSPIRV::EmitISub32(EmitContext& ctx, Id a, Id b) { - return ctx.OpISub(ctx.u32[1], a, b); + return ctx.OpISub(ctx.U32[1], a, b); } void EmitSPIRV::EmitISub64(EmitContext&) { @@ -26,7 +26,7 @@ void EmitSPIRV::EmitISub64(EmitContext&) { } Id EmitSPIRV::EmitIMul32(EmitContext& ctx, Id a, Id b) { - return ctx.OpIMul(ctx.u32[1], a, b); + return ctx.OpIMul(ctx.U32[1], a, b); } void EmitSPIRV::EmitINeg32(EmitContext&) { @@ -38,7 +38,7 @@ void EmitSPIRV::EmitIAbs32(EmitContext&) { } Id EmitSPIRV::EmitShiftLeftLogical32(EmitContext& ctx, Id base, Id shift) { - return ctx.OpShiftLeftLogical(ctx.u32[1], base, shift); + return ctx.OpShiftLeftLogical(ctx.U32[1], base, shift); } void EmitSPIRV::EmitShiftRightLogical32(EmitContext&) { @@ -70,11 +70,11 @@ void EmitSPIRV::EmitBitFieldSExtract(EmitContext&) { } Id EmitSPIRV::EmitBitFieldUExtract(EmitContext& ctx, Id base, Id offset, Id count) { - return ctx.OpBitFieldUExtract(ctx.u32[1], base, offset, count); + return ctx.OpBitFieldUExtract(ctx.U32[1], base, offset, count); } Id EmitSPIRV::EmitSLessThan(EmitContext& ctx, Id lhs, Id rhs) { - return ctx.OpSLessThan(ctx.u1, lhs, rhs); + return ctx.OpSLessThan(ctx.U1, lhs, rhs); } void EmitSPIRV::EmitULessThan(EmitContext&) { @@ -94,7 +94,7 @@ void EmitSPIRV::EmitULessThanEqual(EmitContext&) { } Id EmitSPIRV::EmitSGreaterThan(EmitContext& ctx, Id lhs, Id rhs) { - return ctx.OpSGreaterThan(ctx.u1, lhs, rhs); + return ctx.OpSGreaterThan(ctx.U1, lhs, rhs); } void EmitSPIRV::EmitUGreaterThan(EmitContext&) { @@ -110,7 +110,7 @@ void EmitSPIRV::EmitSGreaterThanEqual(EmitContext&) { } Id EmitSPIRV::EmitUGreaterThanEqual(EmitContext& ctx, Id lhs, Id rhs) { - return ctx.OpUGreaterThanEqual(ctx.u1, lhs, rhs); + return ctx.OpUGreaterThanEqual(ctx.U1, lhs, rhs); } void EmitSPIRV::EmitLogicalOr(EmitContext&) { diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp index 21a0d72fa..5769a3c95 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp @@ -2,10 +2,26 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include + #include "shader_recompiler/backend/spirv/emit_spirv.h" namespace Shader::Backend::SPIRV { +static Id StorageIndex(EmitContext& ctx, const IR::Value& offset, size_t element_size) { + if (offset.IsImmediate()) { + const u32 imm_offset{static_cast(offset.U32() / element_size)}; + return ctx.Constant(ctx.U32[1], imm_offset); + } + const u32 shift{static_cast(std::countr_zero(element_size))}; + const Id index{ctx.Def(offset)}; + if (shift == 0) { + return index; + } + const Id shift_id{ctx.Constant(ctx.U32[1], shift)}; + return ctx.OpShiftRightLogical(ctx.U32[1], index, shift_id); +} + void EmitSPIRV::EmitLoadGlobalU8(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } @@ -79,11 +95,14 @@ void EmitSPIRV::EmitLoadStorageS16(EmitContext&) { } Id EmitSPIRV::EmitLoadStorage32(EmitContext& ctx, const IR::Value& binding, - [[maybe_unused]] const IR::Value& offset) { + const IR::Value& offset) { if (!binding.IsImmediate()) { - throw NotImplementedException("Storage buffer indexing"); + throw NotImplementedException("Dynamic storage buffer indexing"); } - return ctx.Name(ctx.OpUndef(ctx.u32[1]), "unimplemented_sbuf"); + const Id ssbo{ctx.ssbos[binding.U32()]}; + const Id index{StorageIndex(ctx, offset, sizeof(u32))}; + const Id pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index)}; + return ctx.OpLoad(ctx.U32[1], pointer); } void EmitSPIRV::EmitLoadStorage64(EmitContext&) { @@ -110,8 +129,15 @@ void EmitSPIRV::EmitWriteStorageS16(EmitContext&) { throw NotImplementedException("SPIR-V Instruction"); } -void EmitSPIRV::EmitWriteStorage32(EmitContext& ctx) { - ctx.Name(ctx.OpUndef(ctx.u32[1]), "unimplemented_sbuf_store"); +void EmitSPIRV::EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding, + const IR::Value& offset, Id value) { + if (!binding.IsImmediate()) { + throw NotImplementedException("Dynamic storage buffer indexing"); + } + const Id ssbo{ctx.ssbos[binding.U32()]}; + const Id index{StorageIndex(ctx, offset, sizeof(u32))}; + const Id pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index)}; + ctx.OpStore(pointer, value); } void EmitSPIRV::EmitWriteStorage64(EmitContext&) { diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp index a6f542360..c1ed8f281 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp @@ -7,7 +7,7 @@ namespace Shader::Backend::SPIRV { Id EmitSPIRV::EmitUndefU1(EmitContext& ctx) { - return ctx.OpUndef(ctx.u1); + return ctx.OpUndef(ctx.U1); } Id EmitSPIRV::EmitUndefU8(EmitContext&) { @@ -19,7 +19,7 @@ Id EmitSPIRV::EmitUndefU16(EmitContext&) { } Id EmitSPIRV::EmitUndefU32(EmitContext& ctx) { - return ctx.OpUndef(ctx.u32[1]); + return ctx.OpUndef(ctx.U32[1]); } Id EmitSPIRV::EmitUndefU64(EmitContext&) { diff --git a/src/shader_recompiler/frontend/ir/basic_block.h b/src/shader_recompiler/frontend/ir/basic_block.h index 778b32e43..b14a35ec5 100644 --- a/src/shader_recompiler/frontend/ir/basic_block.h +++ b/src/shader_recompiler/frontend/ir/basic_block.h @@ -11,6 +11,7 @@ #include +#include "common/bit_cast.h" #include "shader_recompiler/frontend/ir/condition.h" #include "shader_recompiler/frontend/ir/microinstruction.h" #include "shader_recompiler/frontend/ir/value.h" @@ -68,6 +69,18 @@ public: /// Gets an immutable span to the immediate predecessors. [[nodiscard]] std::span ImmediatePredecessors() const noexcept; + /// Intrusively store the host definition of this instruction. + template + void SetDefinition(DefinitionType def) { + definition = Common::BitCast(def); + } + + /// Return the intrusively stored host definition of this instruction. + template + [[nodiscard]] DefinitionType Definition() const noexcept { + return Common::BitCast(definition); + } + [[nodiscard]] Condition BranchCondition() const noexcept { return branch_cond; } @@ -161,6 +174,9 @@ private: Block* branch_false{nullptr}; /// Block immediate predecessors std::vector imm_predecessors; + + /// Intrusively stored host definition of this block. + u32 definition{}; }; using BlockList = std::vector; diff --git a/src/shader_recompiler/frontend/ir/program.h b/src/shader_recompiler/frontend/ir/program.h index efaf1aa1e..98aab2dc6 100644 --- a/src/shader_recompiler/frontend/ir/program.h +++ b/src/shader_recompiler/frontend/ir/program.h @@ -9,11 +9,13 @@ #include #include "shader_recompiler/frontend/ir/function.h" +#include "shader_recompiler/shader_info.h" namespace Shader::IR { struct Program { boost::container::small_vector functions; + Info info; }; [[nodiscard]] std::string DumpProgram(const Program& program); diff --git a/src/shader_recompiler/frontend/maxwell/program.cpp b/src/shader_recompiler/frontend/maxwell/program.cpp index dab6d68c0..8331d576c 100644 --- a/src/shader_recompiler/frontend/maxwell/program.cpp +++ b/src/shader_recompiler/frontend/maxwell/program.cpp @@ -53,21 +53,22 @@ IR::Program TranslateProgram(ObjectPool& inst_pool, ObjectPoolInstructions()) { + Visit(info, inst); + } + } + } +} + +} // namespace Shader::Optimization diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp index cbde65b9b..f1ad16d60 100644 --- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp @@ -77,6 +77,16 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { return true; } +template +bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) { + if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) { + return false; + } + using Indices = std::make_index_sequence::NUM_ARGS>; + inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{})); + return true; +} + void FoldGetRegister(IR::Inst& inst) { if (inst.Arg(0).Reg() == IR::Reg::RZ) { inst.ReplaceUsesWith(IR::Value{u32{0}}); @@ -103,6 +113,52 @@ void FoldAdd(IR::Inst& inst) { } } +void FoldISub32(IR::Inst& inst) { + if (FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a - b; })) { + return; + } + if (inst.Arg(0).IsImmediate() || inst.Arg(1).IsImmediate()) { + return; + } + // ISub32 is generally used to subtract two constant buffers, compare and replace this with + // zero if they equal. + const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) { + return a->Opcode() == IR::Opcode::GetCbuf && b->Opcode() == IR::Opcode::GetCbuf && + a->Arg(0) == b->Arg(0) && a->Arg(1) == b->Arg(1); + }}; + IR::Inst* op_a{inst.Arg(0).InstRecursive()}; + IR::Inst* op_b{inst.Arg(1).InstRecursive()}; + if (equal_cbuf(op_a, op_b)) { + inst.ReplaceUsesWith(IR::Value{u32{0}}); + return; + } + // It's also possible a value is being added to a cbuf and then subtracted + if (op_b->Opcode() == IR::Opcode::IAdd32) { + // Canonicalize local variables to simplify the following logic + std::swap(op_a, op_b); + } + if (op_b->Opcode() != IR::Opcode::GetCbuf) { + return; + } + IR::Inst* const inst_cbuf{op_b}; + if (op_a->Opcode() != IR::Opcode::IAdd32) { + return; + } + IR::Value add_op_a{op_a->Arg(0)}; + IR::Value add_op_b{op_a->Arg(1)}; + if (add_op_b.IsImmediate()) { + // Canonicalize + std::swap(add_op_a, add_op_b); + } + if (add_op_b.IsImmediate()) { + return; + } + IR::Inst* const add_cbuf{add_op_b.InstRecursive()}; + if (equal_cbuf(add_cbuf, inst_cbuf)) { + inst.ReplaceUsesWith(add_op_a); + } +} + template void FoldSelect(IR::Inst& inst) { const IR::Value cond{inst.Arg(0)}; @@ -170,15 +226,6 @@ IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence< return IR::Value{func(Arg>(inst.Arg(I))...)}; } -template -void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) { - if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) { - return; - } - using Indices = std::make_index_sequence::NUM_ARGS>; - inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{})); -} - void FoldBranchConditional(IR::Inst& inst) { const IR::U1 cond{inst.Arg(0)}; if (cond.IsImmediate()) { @@ -205,6 +252,8 @@ void ConstantPropagation(IR::Inst& inst) { return FoldGetPred(inst); case IR::Opcode::IAdd32: return FoldAdd(inst); + case IR::Opcode::ISub32: + return FoldISub32(inst); case IR::Opcode::BitCastF32U32: return FoldBitCast(inst, IR::Opcode::BitCastU32F32); case IR::Opcode::BitCastU32F32: @@ -220,17 +269,20 @@ void ConstantPropagation(IR::Inst& inst) { case IR::Opcode::LogicalNot: return FoldLogicalNot(inst); case IR::Opcode::SLessThan: - return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); + FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); + return; case IR::Opcode::ULessThan: - return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); + FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); + return; case IR::Opcode::BitFieldUExtract: - return FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) { + FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) { if (static_cast(shift) + static_cast(count) > Common::BitSize()) { throw LogicError("Undefined result in {}({}, {}, {})", IR::Opcode::BitFieldUExtract, base, shift, count); } return (base >> shift) & ((1U << count) - 1); }); + return; case IR::Opcode::BranchConditional: return FoldBranchConditional(inst); default: diff --git a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp index b40c0c57b..bf230a850 100644 --- a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp +++ b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp @@ -28,7 +28,8 @@ struct StorageBufferAddr { /// Block iterator to a global memory instruction and the storage buffer it uses struct StorageInst { StorageBufferAddr storage_buffer; - IR::Block::iterator inst; + IR::Inst* inst; + IR::Block* block; }; /// Bias towards a certain range of constant buffers when looking for storage buffers @@ -41,7 +42,7 @@ struct Bias { using StorageBufferSet = boost::container::flat_set, boost::container::small_vector>; -using StorageInstVector = boost::container::small_vector; +using StorageInstVector = boost::container::small_vector; /// Returns true when the instruction is a global memory instruction bool IsGlobalMemory(const IR::Inst& inst) { @@ -109,23 +110,22 @@ bool MeetsBias(const StorageBufferAddr& storage_buffer, const Bias& bias) noexce } /// Discards a global memory operation, reads return zero and writes are ignored -void DiscardGlobalMemory(IR::Block& block, IR::Block::iterator inst) { +void DiscardGlobalMemory(IR::Block& block, IR::Inst& inst) { + IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)}; const IR::Value zero{u32{0}}; - switch (inst->Opcode()) { + switch (inst.Opcode()) { case IR::Opcode::LoadGlobalS8: case IR::Opcode::LoadGlobalU8: case IR::Opcode::LoadGlobalS16: case IR::Opcode::LoadGlobalU16: case IR::Opcode::LoadGlobal32: - inst->ReplaceUsesWith(zero); + inst.ReplaceUsesWith(zero); break; case IR::Opcode::LoadGlobal64: - inst->ReplaceUsesWith(IR::Value{ - &*block.PrependNewInst(inst, IR::Opcode::CompositeConstructU32x2, {zero, zero})}); + inst.ReplaceUsesWith(IR::Value{ir.CompositeConstruct(zero, zero)}); break; case IR::Opcode::LoadGlobal128: - inst->ReplaceUsesWith(IR::Value{&*block.PrependNewInst( - inst, IR::Opcode::CompositeConstructU32x4, {zero, zero, zero, zero})}); + inst.ReplaceUsesWith(IR::Value{ir.CompositeConstruct(zero, zero, zero, zero)}); break; case IR::Opcode::WriteGlobalS8: case IR::Opcode::WriteGlobalU8: @@ -134,11 +134,10 @@ void DiscardGlobalMemory(IR::Block& block, IR::Block::iterator inst) { case IR::Opcode::WriteGlobal32: case IR::Opcode::WriteGlobal64: case IR::Opcode::WriteGlobal128: - inst->Invalidate(); + inst.Invalidate(); break; default: - throw LogicError("Invalid opcode to discard its global memory operation {}", - inst->Opcode()); + throw LogicError("Invalid opcode to discard its global memory operation {}", inst.Opcode()); } } @@ -232,8 +231,8 @@ std::optional Track(const IR::Value& value, const Bias* bias) } /// Collects the storage buffer used by a global memory instruction and the instruction itself -void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst, - StorageBufferSet& storage_buffer_set, StorageInstVector& to_replace) { +void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageBufferSet& storage_buffer_set, + StorageInstVector& to_replace) { // NVN puts storage buffers in a specific range, we have to bias towards these addresses to // avoid getting false positives static constexpr Bias nvn_bias{ @@ -241,19 +240,13 @@ void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst, .offset_begin{0x110}, .offset_end{0x610}, }; - // First try to find storage buffers in the NVN address - const IR::U64 addr{inst->Arg(0)}; - if (addr.IsImmediate()) { - // Immediate addresses can't be lowered to a storage buffer - DiscardGlobalMemory(block, inst); - return; - } // Track the low address of the instruction - const std::optional low_addr_info{TrackLowAddress(addr.InstRecursive())}; + const std::optional low_addr_info{TrackLowAddress(&inst)}; if (!low_addr_info) { DiscardGlobalMemory(block, inst); return; } + // First try to find storage buffers in the NVN address const IR::U32 low_addr{low_addr_info->value}; std::optional storage_buffer{Track(low_addr, &nvn_bias)}; if (!storage_buffer) { @@ -269,21 +262,22 @@ void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst, storage_buffer_set.insert(*storage_buffer); to_replace.push_back(StorageInst{ .storage_buffer{*storage_buffer}, - .inst{inst}, + .inst{&inst}, + .block{&block}, }); } /// Returns the offset in indices (not bytes) for an equivalent storage instruction -IR::U32 StorageOffset(IR::Block& block, IR::Block::iterator inst, StorageBufferAddr buffer) { - IR::IREmitter ir{block, inst}; +IR::U32 StorageOffset(IR::Block& block, IR::Inst& inst, StorageBufferAddr buffer) { + IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)}; IR::U32 offset; - if (const std::optional low_addr{TrackLowAddress(&*inst)}) { + if (const std::optional low_addr{TrackLowAddress(&inst)}) { offset = low_addr->value; if (low_addr->imm_offset != 0) { offset = ir.IAdd(offset, ir.Imm32(low_addr->imm_offset)); } } else { - offset = ir.ConvertU(32, IR::U64{inst->Arg(0)}); + offset = ir.ConvertU(32, IR::U64{inst.Arg(0)}); } // Subtract the least significant 32 bits from the guest offset. The result is the storage // buffer offset in bytes. @@ -292,25 +286,27 @@ IR::U32 StorageOffset(IR::Block& block, IR::Block::iterator inst, StorageBufferA } /// Replace a global memory load instruction with its storage buffer equivalent -void ReplaceLoad(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index, +void ReplaceLoad(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index, const IR::U32& offset) { - const IR::Opcode new_opcode{GlobalToStorage(inst->Opcode())}; - const IR::Value value{&*block.PrependNewInst(inst, new_opcode, {storage_index, offset})}; - inst->ReplaceUsesWith(value); + const IR::Opcode new_opcode{GlobalToStorage(inst.Opcode())}; + const auto it{IR::Block::InstructionList::s_iterator_to(inst)}; + const IR::Value value{&*block.PrependNewInst(it, new_opcode, {storage_index, offset})}; + inst.ReplaceUsesWith(value); } /// Replace a global memory write instruction with its storage buffer equivalent -void ReplaceWrite(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index, +void ReplaceWrite(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index, const IR::U32& offset) { - const IR::Opcode new_opcode{GlobalToStorage(inst->Opcode())}; - block.PrependNewInst(inst, new_opcode, {storage_index, offset, inst->Arg(1)}); - inst->Invalidate(); + const IR::Opcode new_opcode{GlobalToStorage(inst.Opcode())}; + const auto it{IR::Block::InstructionList::s_iterator_to(inst)}; + block.PrependNewInst(it, new_opcode, {storage_index, offset, inst.Arg(1)}); + inst.Invalidate(); } /// Replace a global memory instruction with its storage buffer equivalent -void Replace(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index, +void Replace(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index, const IR::U32& offset) { - switch (inst->Opcode()) { + switch (inst.Opcode()) { case IR::Opcode::LoadGlobalS8: case IR::Opcode::LoadGlobalU8: case IR::Opcode::LoadGlobalS16: @@ -328,26 +324,44 @@ void Replace(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_ case IR::Opcode::WriteGlobal128: return ReplaceWrite(block, inst, storage_index, offset); default: - throw InvalidArgument("Invalid global memory opcode {}", inst->Opcode()); + throw InvalidArgument("Invalid global memory opcode {}", inst.Opcode()); } } } // Anonymous namespace -void GlobalMemoryToStorageBufferPass(IR::Block& block) { +void GlobalMemoryToStorageBufferPass(IR::Program& program) { StorageBufferSet storage_buffers; StorageInstVector to_replace; - for (IR::Block::iterator inst{block.begin()}; inst != block.end(); ++inst) { - if (!IsGlobalMemory(*inst)) { - continue; + for (IR::Function& function : program.functions) { + for (IR::Block* const block : function.post_order_blocks) { + for (IR::Inst& inst : block->Instructions()) { + if (!IsGlobalMemory(inst)) { + continue; + } + CollectStorageBuffers(*block, inst, storage_buffers, to_replace); + } } - CollectStorageBuffers(block, inst, storage_buffers, to_replace); } - for (const auto [storage_buffer, inst] : to_replace) { - const auto it{storage_buffers.find(storage_buffer)}; - const IR::U32 storage_index{IR::Value{static_cast(storage_buffers.index_of(it))}}; - const IR::U32 offset{StorageOffset(block, inst, storage_buffer)}; - Replace(block, inst, storage_index, offset); + Info& info{program.info}; + u32 storage_index{}; + for (const StorageBufferAddr& storage_buffer : storage_buffers) { + info.storage_buffers_descriptors.push_back({ + .cbuf_index{storage_buffer.index}, + .cbuf_offset{storage_buffer.offset}, + .count{1}, + }); + info.storage_buffers[storage_index] = &info.storage_buffers_descriptors.back(); + ++storage_index; + } + for (const StorageInst& storage_inst : to_replace) { + const StorageBufferAddr storage_buffer{storage_inst.storage_buffer}; + const auto it{storage_buffers.find(storage_inst.storage_buffer)}; + const IR::U32 index{IR::Value{static_cast(storage_buffers.index_of(it))}}; + IR::Block* const block{storage_inst.block}; + IR::Inst* const inst{storage_inst.inst}; + const IR::U32 offset{StorageOffset(*block, *inst, storage_buffer)}; + Replace(*block, *inst, index, offset); } } diff --git a/src/shader_recompiler/ir_opt/passes.h b/src/shader_recompiler/ir_opt/passes.h index 30eb31588..89e5811d3 100644 --- a/src/shader_recompiler/ir_opt/passes.h +++ b/src/shader_recompiler/ir_opt/passes.h @@ -8,6 +8,7 @@ #include "shader_recompiler/frontend/ir/basic_block.h" #include "shader_recompiler/frontend/ir/function.h" +#include "shader_recompiler/frontend/ir/program.h" namespace Shader::Optimization { @@ -18,9 +19,10 @@ void PostOrderInvoke(Func&& func, IR::Function& function) { } } +void CollectShaderInfoPass(IR::Program& program); void ConstantPropagationPass(IR::Block& block); void DeadCodeEliminationPass(IR::Block& block); -void GlobalMemoryToStorageBufferPass(IR::Block& block); +void GlobalMemoryToStorageBufferPass(IR::Program& program); void IdentityRemovalPass(IR::Function& function); void SsaRewritePass(std::span post_order_blocks); void VerificationPass(const IR::Function& function); diff --git a/src/shader_recompiler/main.cpp b/src/shader_recompiler/main.cpp index 216345e91..1610bb34e 100644 --- a/src/shader_recompiler/main.cpp +++ b/src/shader_recompiler/main.cpp @@ -67,8 +67,8 @@ int main() { ObjectPool inst_pool; ObjectPool block_pool; - // FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; - FileEnvironment env{"D:\\Shaders\\shader.bin"}; + FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; + // FileEnvironment env{"D:\\Shaders\\shader.bin"}; block_pool.ReleaseContents(); inst_pool.ReleaseContents(); flow_block_pool.ReleaseContents(); diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h index 1760bf4a9..f49a79368 100644 --- a/src/shader_recompiler/shader_info.h +++ b/src/shader_recompiler/shader_info.h @@ -6,23 +6,40 @@ #include +#include "common/common_types.h" + #include namespace Shader { struct Info { - struct ConstantBuffer { + static constexpr size_t MAX_CBUFS{18}; + static constexpr size_t MAX_SSBOS{16}; + struct ConstantBufferDescriptor { + u32 index; + u32 count; }; - struct { - bool workgroup_id{}; - bool local_invocation_id{}; - bool fp16{}; - bool fp64{}; - } uses; + struct StorageBufferDescriptor { + u32 cbuf_index; + u32 cbuf_offset; + u32 count; + }; - std::array<18 + bool uses_workgroup_id{}; + bool uses_local_invocation_id{}; + bool uses_fp16{}; + bool uses_fp64{}; + + u32 constant_buffer_mask{}; + + std::array constant_buffers{}; + boost::container::static_vector + constant_buffer_descriptors; + + std::array storage_buffers{}; + boost::container::static_vector storage_buffers_descriptors; }; } // namespace Shader