vulkan_device: Blacklist ampere devices from float16 math

This commit is contained in:
ameerj 2021-07-03 01:49:59 -04:00
parent 57f222c56e
commit 7277d7fe96
2 changed files with 23 additions and 12 deletions

View file

@ -194,12 +194,22 @@ std::unordered_map<VkFormat, VkFormatProperties> GetFormatProperties(vk::Physica
return format_properties; return format_properties;
} }
std::vector<std::string> GetSupportedExtensions(vk::PhysicalDevice physical) {
const std::vector extensions = physical.EnumerateDeviceExtensionProperties();
std::vector<std::string> supported_extensions(std::size(extensions));
for (const auto& extension : extensions) {
supported_extensions.emplace_back(extension.extensionName);
}
return supported_extensions;
}
} // Anonymous namespace } // Anonymous namespace
Device::Device(VkInstance instance_, vk::PhysicalDevice physical_, VkSurfaceKHR surface, Device::Device(VkInstance instance_, vk::PhysicalDevice physical_, VkSurfaceKHR surface,
const vk::InstanceDispatch& dld_) const vk::InstanceDispatch& dld_)
: instance{instance_}, dld{dld_}, physical{physical_}, properties{physical.GetProperties()}, : instance{instance_}, dld{dld_}, physical{physical_}, properties{physical.GetProperties()},
format_properties{GetFormatProperties(physical)} { supported_extensions{GetSupportedExtensions(physical)},
format_properties(GetFormatProperties(physical)) {
CheckSuitability(surface != nullptr); CheckSuitability(surface != nullptr);
SetupFamilies(surface); SetupFamilies(surface);
SetupFeatures(); SetupFeatures();
@ -510,6 +520,13 @@ Device::Device(VkInstance instance_, vk::PhysicalDevice physical_, VkSurfaceKHR
CollectTelemetryParameters(); CollectTelemetryParameters();
CollectToolingInfo(); CollectToolingInfo();
if (driver_id == VK_DRIVER_ID_NVIDIA_PROPRIETARY_KHR && is_float16_supported) {
if (std::ranges::find(supported_extensions, VK_KHR_FRAGMENT_SHADING_RATE_EXTENSION_NAME) !=
supported_extensions.end()) {
LOG_WARNING(Render_Vulkan, "Blacklisting Ampere devices from float16 math");
is_float16_supported = false;
}
}
if (ext_extended_dynamic_state && driver_id == VK_DRIVER_ID_MESA_RADV) { if (ext_extended_dynamic_state && driver_id == VK_DRIVER_ID_MESA_RADV) {
// Mask driver version variant // Mask driver version variant
const u32 version = (properties.driverVersion << 3) >> 3; const u32 version = (properties.driverVersion << 3) >> 3;
@ -778,10 +795,10 @@ std::vector<const char*> Device::LoadExtensions(bool requires_surface) {
bool has_ext_provoking_vertex{}; bool has_ext_provoking_vertex{};
bool has_ext_vertex_input_dynamic_state{}; bool has_ext_vertex_input_dynamic_state{};
bool has_ext_line_rasterization{}; bool has_ext_line_rasterization{};
for (const VkExtensionProperties& extension : physical.EnumerateDeviceExtensionProperties()) { for (const std::string& extension : supported_extensions) {
const auto test = [&](std::optional<std::reference_wrapper<bool>> status, const char* name, const auto test = [&](std::optional<std::reference_wrapper<bool>> status, const char* name,
bool push) { bool push) {
if (extension.extensionName != std::string_view(name)) { if (extension != name) {
return; return;
} }
if (push) { if (push) {
@ -1064,12 +1081,6 @@ void Device::CollectTelemetryParameters() {
driver_id = driver.driverID; driver_id = driver.driverID;
vendor_name = driver.driverName; vendor_name = driver.driverName;
const std::vector extensions = physical.EnumerateDeviceExtensionProperties();
reported_extensions.reserve(std::size(extensions));
for (const auto& extension : extensions) {
reported_extensions.emplace_back(extension.extensionName);
}
} }
void Device::CollectPhysicalMemoryInfo() { void Device::CollectPhysicalMemoryInfo() {

View file

@ -301,7 +301,7 @@ public:
/// Returns the list of available extensions. /// Returns the list of available extensions.
const std::vector<std::string>& GetAvailableExtensions() const { const std::vector<std::string>& GetAvailableExtensions() const {
return reported_extensions; return supported_extensions;
} }
u64 GetDeviceLocalMemory() const { u64 GetDeviceLocalMemory() const {
@ -399,7 +399,7 @@ private:
// Telemetry parameters // Telemetry parameters
std::string vendor_name; ///< Device's driver name. std::string vendor_name; ///< Device's driver name.
std::vector<std::string> reported_extensions; ///< Reported Vulkan extensions. std::vector<std::string> supported_extensions; ///< Reported Vulkan extensions.
/// Format properties dictionary. /// Format properties dictionary.
std::unordered_map<VkFormat, VkFormatProperties> format_properties; std::unordered_map<VkFormat, VkFormatProperties> format_properties;