shader: Fix memory barriers

This commit is contained in:
ReinUsesLisp 2021-04-17 03:21:03 -03:00 committed by ameerj
parent c9e4609d87
commit 0a0818c025
8 changed files with 30 additions and 62 deletions

View file

@ -29,9 +29,8 @@ void EmitReturn(EmitContext& ctx);
void EmitUnreachable(EmitContext& ctx);
void EmitDemoteToHelperInvocation(EmitContext& ctx, Id continue_label);
void EmitBarrier(EmitContext& ctx);
void EmitMemoryBarrierWorkgroupLevel(EmitContext& ctx);
void EmitMemoryBarrierDeviceLevel(EmitContext& ctx);
void EmitMemoryBarrierSystemLevel(EmitContext& ctx);
void EmitWorkgroupMemoryBarrier(EmitContext& ctx);
void EmitDeviceMemoryBarrier(EmitContext& ctx);
void EmitPrologue(EmitContext& ctx);
void EmitEpilogue(EmitContext& ctx);
void EmitEmitVertex(EmitContext& ctx, const IR::Value& stream);

View file

@ -7,7 +7,7 @@
namespace Shader::Backend::SPIRV {
namespace {
void EmitMemoryBarrierImpl(EmitContext& ctx, spv::Scope scope) {
void MemoryBarrier(EmitContext& ctx, spv::Scope scope) {
const auto semantics{
spv::MemorySemanticsMask::AcquireRelease | spv::MemorySemanticsMask::UniformMemory |
spv::MemorySemanticsMask::WorkgroupMemory | spv::MemorySemanticsMask::AtomicCounterMemory |
@ -27,16 +27,12 @@ void EmitBarrier(EmitContext& ctx) {
ctx.Constant(ctx.U32[1], static_cast<u32>(memory_semantics)));
}
void EmitMemoryBarrierWorkgroupLevel(EmitContext& ctx) {
EmitMemoryBarrierImpl(ctx, spv::Scope::Workgroup);
void EmitWorkgroupMemoryBarrier(EmitContext& ctx) {
MemoryBarrier(ctx, spv::Scope::Workgroup);
}
void EmitMemoryBarrierDeviceLevel(EmitContext& ctx) {
EmitMemoryBarrierImpl(ctx, spv::Scope::Device);
}
void EmitMemoryBarrierSystemLevel(EmitContext& ctx) {
EmitMemoryBarrierImpl(ctx, spv::Scope::CrossDevice);
void EmitDeviceMemoryBarrier(EmitContext& ctx) {
MemoryBarrier(ctx, spv::Scope::Device);
}
} // namespace Shader::Backend::SPIRV

View file

@ -86,20 +86,12 @@ void IREmitter::Barrier() {
Inst(Opcode::Barrier);
}
void IREmitter::MemoryBarrier(MemoryScope scope) {
switch (scope) {
case MemoryScope::Workgroup:
Inst(Opcode::MemoryBarrierWorkgroupLevel);
break;
case MemoryScope::Device:
Inst(Opcode::MemoryBarrierDeviceLevel);
break;
case MemoryScope::System:
Inst(Opcode::MemoryBarrierSystemLevel);
break;
default:
throw InvalidArgument("Invalid memory scope {}", scope);
}
void IREmitter::WorkgroupMemoryBarrier() {
Inst(Opcode::WorkgroupMemoryBarrier);
}
void IREmitter::DeviceMemoryBarrier() {
Inst(Opcode::DeviceMemoryBarrier);
}
void IREmitter::Return() {

View file

@ -144,8 +144,9 @@ public:
[[nodiscard]] Value Select(const U1& condition, const Value& true_value,
const Value& false_value);
[[nodiscard]] void Barrier();
[[nodiscard]] void MemoryBarrier(MemoryScope scope);
void Barrier();
void WorkgroupMemoryBarrier();
void DeviceMemoryBarrier();
template <typename Dest, typename Source>
[[nodiscard]] Dest BitCast(const Source& value);

View file

@ -64,9 +64,8 @@ bool Inst::MayHaveSideEffects() const noexcept {
case Opcode::Unreachable:
case Opcode::DemoteToHelperInvocation:
case Opcode::Barrier:
case Opcode::MemoryBarrierWorkgroupLevel:
case Opcode::MemoryBarrierDeviceLevel:
case Opcode::MemoryBarrierSystemLevel:
case Opcode::WorkgroupMemoryBarrier:
case Opcode::DeviceMemoryBarrier:
case Opcode::Prologue:
case Opcode::Epilogue:
case Opcode::EmitVertex:

View file

@ -25,14 +25,6 @@ enum class FpRounding : u8 {
RZ, // Round towards zero
};
enum class MemoryScope : u32 {
DontCare,
Warp,
Workgroup,
Device,
System,
};
struct FpControl {
bool no_contraction{false};
FpRounding rounding{FpRounding::DontCare};

View file

@ -18,9 +18,8 @@ OPCODE(DemoteToHelperInvocation, Void, Labe
// Barriers
OPCODE(Barrier, Void, )
OPCODE(MemoryBarrierWorkgroupLevel, Void, )
OPCODE(MemoryBarrierDeviceLevel, Void, )
OPCODE(MemoryBarrierSystemLevel, Void, )
OPCODE(WorkgroupMemoryBarrier, Void, )
OPCODE(DeviceMemoryBarrier, Void, )
// Special operations
OPCODE(Prologue, Void, )

View file

@ -12,34 +12,24 @@ namespace Shader::Maxwell {
namespace {
// Seems to be in CUDA terminology.
enum class LocalScope : u64 {
CTG = 0,
GL = 1,
SYS = 2,
VC = 3,
CTA,
GL,
SYS,
VC,
};
IR::MemoryScope LocalScopeToMemoryScope(LocalScope scope) {
switch (scope) {
case LocalScope::CTG:
return IR::MemoryScope::Workgroup;
case LocalScope::GL:
return IR::MemoryScope::Device;
case LocalScope::SYS:
return IR::MemoryScope::System;
default:
throw NotImplementedException("Unimplemented Local Scope {}", scope);
}
}
} // Anonymous namespace
void TranslatorVisitor::MEMBAR(u64 inst) {
union {
u64 raw;
BitField<8, 2, LocalScope> scope;
} membar{inst};
} const membar{inst};
ir.MemoryBarrier(LocalScopeToMemoryScope(membar.scope));
if (membar.scope == LocalScope::CTA) {
ir.WorkgroupMemoryBarrier();
} else {
ir.DeviceMemoryBarrier();
}
}
void TranslatorVisitor::DEPBAR() {