From 6c512f4bffde6bd8e4dbc74ed27cc84cd7fffadb Mon Sep 17 00:00:00 2001 From: ameerj <52414509+ameerj@users.noreply.github.com> Date: Wed, 14 Apr 2021 00:32:18 -0400 Subject: [PATCH] spirv: Implement alpha test --- .../backend/spirv/emit_spirv_special.cpp | 45 +++++++++++++++++++ src/shader_recompiler/profile.h | 15 ++++++- .../renderer_vulkan/vk_pipeline_cache.cpp | 36 +++++++++++++++ 3 files changed, 95 insertions(+), 1 deletion(-) diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp index 7af29e4dd..8bb94f546 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_special.cpp @@ -37,6 +37,48 @@ Id DefaultVarying(EmitContext& ctx, u32 num_components, u32 element, Id zero, Id } throw InvalidArgument("Bad element"); } + +Id ComparisonFunction(EmitContext& ctx, CompareFunction comparison, Id operand_1, Id operand_2) { + switch (comparison) { + case CompareFunction::Never: + return ctx.false_value; + case CompareFunction::Less: + return ctx.OpFOrdLessThan(ctx.U1, operand_1, operand_2); + case CompareFunction::Equal: + return ctx.OpFOrdEqual(ctx.U1, operand_1, operand_2); + case CompareFunction::LessThanEqual: + return ctx.OpFOrdLessThanEqual(ctx.U1, operand_1, operand_2); + case CompareFunction::Greater: + return ctx.OpFOrdGreaterThan(ctx.U1, operand_1, operand_2); + case CompareFunction::NotEqual: + return ctx.OpFOrdNotEqual(ctx.U1, operand_1, operand_2); + case CompareFunction::GreaterThanEqual: + return ctx.OpFOrdGreaterThanEqual(ctx.U1, operand_1, operand_2); + case CompareFunction::Always: + return ctx.true_value; + } + throw InvalidArgument("Comparison function {}", comparison); +} + +void AlphaTest(EmitContext& ctx) { + const auto comparison{*ctx.profile.alpha_test_func}; + if (comparison == CompareFunction::Always) { + return; + } + const Id type{ctx.F32[1]}; + const Id rt0_color{ctx.OpLoad(ctx.F32[4], ctx.frag_color[0])}; + const Id alpha{ctx.OpCompositeExtract(type, rt0_color, 3u)}; + + const Id true_label{ctx.OpLabel()}; + const Id discard_label{ctx.OpLabel()}; + const Id alpha_reference{ctx.Constant(ctx.F32[1], ctx.profile.alpha_test_reference)}; + const Id condition{ComparisonFunction(ctx, comparison, alpha, alpha_reference)}; + + ctx.OpBranchConditional(condition, true_label, discard_label); + ctx.AddLabel(discard_label); + ctx.OpKill(); + ctx.AddLabel(true_label); +} } // Anonymous namespace void EmitPrologue(EmitContext& ctx) { @@ -68,6 +110,9 @@ void EmitEpilogue(EmitContext& ctx) { if (ctx.stage == Stage::VertexB && ctx.profile.convert_depth_mode) { ConvertDepthMode(ctx); } + if (ctx.stage == Stage::Fragment) { + AlphaTest(ctx); + } } void EmitEmitVertex(EmitContext& ctx, const IR::Value& stream) { diff --git a/src/shader_recompiler/profile.h b/src/shader_recompiler/profile.h index 5ecae71b9..c26017d75 100644 --- a/src/shader_recompiler/profile.h +++ b/src/shader_recompiler/profile.h @@ -5,8 +5,8 @@ #pragma once #include -#include #include +#include #include "common/common_types.h" @@ -27,6 +27,17 @@ enum class InputTopology { TrianglesAdjacency, }; +enum class CompareFunction { + Never, + Less, + Equal, + LessThanEqual, + Greater, + NotEqual, + GreaterThanEqual, + Always, +}; + struct TransformFeedbackVarying { u32 buffer{}; u32 stride{}; @@ -66,6 +77,8 @@ struct Profile { InputTopology input_topology{}; std::optional fixed_state_point_size; + std::optional alpha_test_func; + float alpha_test_reference{}; std::vector xfb_varyings; }; diff --git a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp index de52d0f30..80f196d0e 100644 --- a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp +++ b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp @@ -492,6 +492,37 @@ private: u32 read_lowest{}; u32 read_highest{}; }; + +Shader::CompareFunction MaxwellToCompareFunction(Maxwell::ComparisonOp comparison) { + switch (comparison) { + case Maxwell::ComparisonOp::Never: + case Maxwell::ComparisonOp::NeverOld: + return Shader::CompareFunction::Never; + case Maxwell::ComparisonOp::Less: + case Maxwell::ComparisonOp::LessOld: + return Shader::CompareFunction::Less; + case Maxwell::ComparisonOp::Equal: + case Maxwell::ComparisonOp::EqualOld: + return Shader::CompareFunction::Equal; + case Maxwell::ComparisonOp::LessEqual: + case Maxwell::ComparisonOp::LessEqualOld: + return Shader::CompareFunction::LessThanEqual; + case Maxwell::ComparisonOp::Greater: + case Maxwell::ComparisonOp::GreaterOld: + return Shader::CompareFunction::Greater; + case Maxwell::ComparisonOp::NotEqual: + case Maxwell::ComparisonOp::NotEqualOld: + return Shader::CompareFunction::NotEqual; + case Maxwell::ComparisonOp::GreaterEqual: + case Maxwell::ComparisonOp::GreaterEqualOld: + return Shader::CompareFunction::GreaterThanEqual; + case Maxwell::ComparisonOp::Always: + case Maxwell::ComparisonOp::AlwaysOld: + return Shader::CompareFunction::Always; + } + UNIMPLEMENTED_MSG("Unimplemented comparison op={}", comparison); + return {}; +} } // Anonymous namespace void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading, @@ -1016,6 +1047,11 @@ Shader::Profile PipelineCache::MakeProfile(const GraphicsPipelineCacheKey& key, } profile.convert_depth_mode = gl_ndc; break; + case Shader::Stage::Fragment: + profile.alpha_test_func = MaxwellToCompareFunction( + key.state.UnpackComparisonOp(key.state.alpha_test_func.Value())); + profile.alpha_test_reference = Common::BitCast(key.state.alpha_test_ref); + break; default: break; }