shader_cache: Fix use-after-free and orphan invalidation cache entries

This fixes some cases where entries could have been removed multiple
times reading freed memory. To address this issue this commit removes
duplicates from entries marked for removal and sorts out the removal
process to fix another use-after-free situation.

Another issue fixed in this commit is orphan invalidation cache entries.
Previously only the entries that were invalidated in the current
operations had its entries removed. This led to more use-after-free
situations when these entries were actually invalidated but referenced
an object that didn't exist.
This commit is contained in:
ReinUsesLisp 2020-06-28 04:58:58 -03:00
parent 0b954a3305
commit f6cb128eac

View file

@ -20,6 +20,7 @@ namespace VideoCommon {
template <class T>
class ShaderCache {
static constexpr u64 PAGE_BITS = 14;
static constexpr u64 PAGE_SIZE = u64(1) << PAGE_BITS;
struct Entry {
VAddr addr_start;
@ -87,8 +88,8 @@ protected:
const VAddr addr_end = addr + size;
Entry* const entry = NewEntry(addr, addr_end, data.get());
const u64 page_end = addr_end >> PAGE_BITS;
for (u64 page = addr >> PAGE_BITS; page <= page_end; ++page) {
const u64 page_end = (addr_end + PAGE_SIZE - 1) >> PAGE_BITS;
for (u64 page = addr >> PAGE_BITS; page < page_end; ++page) {
invalidation_cache[page].push_back(entry);
}
@ -108,20 +109,13 @@ private:
/// @pre invalidation_mutex is locked
void InvalidatePagesInRegion(VAddr addr, std::size_t size) {
const VAddr addr_end = addr + size;
const u64 page_end = addr_end >> PAGE_BITS;
for (u64 page = addr >> PAGE_BITS; page <= page_end; ++page) {
const auto it = invalidation_cache.find(page);
const u64 page_end = (addr_end + PAGE_SIZE - 1) >> PAGE_BITS;
for (u64 page = addr >> PAGE_BITS; page < page_end; ++page) {
auto it = invalidation_cache.find(page);
if (it == invalidation_cache.end()) {
continue;
}
std::vector<Entry*>& entries = it->second;
InvalidatePageEntries(entries, addr, addr_end);
// If there's nothing else in this page, remove it to avoid overpopulating the hash map.
if (entries.empty()) {
invalidation_cache.erase(it);
}
InvalidatePageEntries(it->second, addr, addr_end);
}
}
@ -131,15 +125,22 @@ private:
if (marked_for_removal.empty()) {
return;
}
std::scoped_lock lock{lookup_mutex};
// Remove duplicates
std::sort(marked_for_removal.begin(), marked_for_removal.end());
marked_for_removal.erase(std::unique(marked_for_removal.begin(), marked_for_removal.end()),
marked_for_removal.end());
std::vector<T*> removed_shaders;
removed_shaders.reserve(marked_for_removal.size());
std::scoped_lock lock{lookup_mutex};
for (Entry* const entry : marked_for_removal) {
if (lookup_cache.erase(entry->addr_start) > 0) {
removed_shaders.push_back(entry->data);
}
removed_shaders.push_back(entry->data);
const auto it = lookup_cache.find(entry->addr_start);
ASSERT(it != lookup_cache.end());
lookup_cache.erase(it);
}
marked_for_removal.clear();
@ -154,17 +155,33 @@ private:
/// @param addr_end Non-inclusive end address of the invalidation
/// @pre invalidation_mutex is locked
void InvalidatePageEntries(std::vector<Entry*>& entries, VAddr addr, VAddr addr_end) {
auto it = entries.begin();
while (it != entries.end()) {
Entry* const entry = *it;
std::size_t index = 0;
while (index < entries.size()) {
Entry* const entry = entries[index];
if (!entry->Overlaps(addr, addr_end)) {
++it;
++index;
continue;
}
UnmarkMemory(entry);
marked_for_removal.push_back(entry);
it = entries.erase(it);
UnmarkMemory(entry);
RemoveEntryFromInvalidationCache(entry);
marked_for_removal.push_back(entry);
}
}
/// @brief Removes all references to an entry in the invalidation cache
/// @param entry Entry to remove from the invalidation cache
/// @pre invalidation_mutex is locked
void RemoveEntryFromInvalidationCache(const Entry* entry) {
const u64 page_end = (entry->addr_end + PAGE_SIZE - 1) >> PAGE_BITS;
for (u64 page = entry->addr_start >> PAGE_BITS; page < page_end; ++page) {
const auto entries_it = invalidation_cache.find(page);
ASSERT(entries_it != invalidation_cache.end());
std::vector<Entry*>& entries = entries_it->second;
const auto entry_it = std::find(entries.begin(), entries.end(), entry);
ASSERT(entry_it != entries.end());
entries.erase(entry_it);
}
}
@ -182,16 +199,11 @@ private:
}
/// @brief Removes a vector of shaders from a list
/// @param removed_shaders Shaders to be removed from the storage, it can contain duplicates
/// @param removed_shaders Shaders to be removed from the storage
/// @pre invalidation_mutex is locked
/// @pre lookup_mutex is locked
void RemoveShadersFromStorage(std::vector<T*> removed_shaders) {
// Remove duplicates
std::sort(removed_shaders.begin(), removed_shaders.end());
removed_shaders.erase(std::unique(removed_shaders.begin(), removed_shaders.end()),
removed_shaders.end());
// Now that there are no duplicates, we can notify removals
// Notify removals
for (T* const shader : removed_shaders) {
OnShaderRemoval(shader);
}