diff --git a/src/core/hle/kernel/k_address_arbiter.cpp b/src/core/hle/kernel/k_address_arbiter.cpp index c5c81a880..165475fbf 100644 --- a/src/core/hle/kernel/k_address_arbiter.cpp +++ b/src/core/hle/kernel/k_address_arbiter.cpp @@ -8,6 +8,7 @@ #include "core/hle/kernel/k_scheduler.h" #include "core/hle/kernel/k_scoped_scheduler_lock_and_sleep.h" #include "core/hle/kernel/k_thread.h" +#include "core/hle/kernel/k_thread_queue.h" #include "core/hle/kernel/kernel.h" #include "core/hle/kernel/svc_results.h" #include "core/hle/kernel/time_manager.h" @@ -85,6 +86,27 @@ bool UpdateIfEqual(Core::System& system, s32* out, VAddr address, s32 value, s32 return true; } +class ThreadQueueImplForKAddressArbiter final : public KThreadQueue { +private: + KAddressArbiter::ThreadTree* m_tree; + +public: + explicit ThreadQueueImplForKAddressArbiter(KernelCore& kernel_, KAddressArbiter::ThreadTree* t) + : KThreadQueue(kernel_), m_tree(t) {} + + virtual void CancelWait(KThread* waiting_thread, ResultCode wait_result, + bool cancel_timer_task) override { + // If the thread is waiting on an address arbiter, remove it from the tree. + if (waiting_thread->IsWaitingForAddressArbiter()) { + m_tree->erase(m_tree->iterator_to(*waiting_thread)); + waiting_thread->ClearAddressArbiter(); + } + + // Invoke the base cancel wait handler. + KThreadQueue::CancelWait(waiting_thread, wait_result, cancel_timer_task); + } +}; + } // namespace ResultCode KAddressArbiter::Signal(VAddr addr, s32 count) { @@ -96,14 +118,14 @@ ResultCode KAddressArbiter::Signal(VAddr addr, s32 count) { auto it = thread_tree.nfind_light({addr, -1}); while ((it != thread_tree.end()) && (count <= 0 || num_waiters < count) && (it->GetAddressArbiterKey() == addr)) { + // End the thread's wait. KThread* target_thread = std::addressof(*it); - target_thread->SetWaitResult(ResultSuccess); + target_thread->EndWait(ResultSuccess); ASSERT(target_thread->IsWaitingForAddressArbiter()); - target_thread->Wakeup(); + target_thread->ClearAddressArbiter(); it = thread_tree.erase(it); - target_thread->ClearAddressArbiter(); ++num_waiters; } } @@ -129,14 +151,14 @@ ResultCode KAddressArbiter::SignalAndIncrementIfEqual(VAddr addr, s32 value, s32 auto it = thread_tree.nfind_light({addr, -1}); while ((it != thread_tree.end()) && (count <= 0 || num_waiters < count) && (it->GetAddressArbiterKey() == addr)) { + // End the thread's wait. KThread* target_thread = std::addressof(*it); - target_thread->SetWaitResult(ResultSuccess); + target_thread->EndWait(ResultSuccess); ASSERT(target_thread->IsWaitingForAddressArbiter()); - target_thread->Wakeup(); + target_thread->ClearAddressArbiter(); it = thread_tree.erase(it); - target_thread->ClearAddressArbiter(); ++num_waiters; } } @@ -197,14 +219,14 @@ ResultCode KAddressArbiter::SignalAndModifyByWaitingCountIfEqual(VAddr addr, s32 while ((it != thread_tree.end()) && (count <= 0 || num_waiters < count) && (it->GetAddressArbiterKey() == addr)) { + // End the thread's wait. KThread* target_thread = std::addressof(*it); - target_thread->SetWaitResult(ResultSuccess); + target_thread->EndWait(ResultSuccess); ASSERT(target_thread->IsWaitingForAddressArbiter()); - target_thread->Wakeup(); + target_thread->ClearAddressArbiter(); it = thread_tree.erase(it); - target_thread->ClearAddressArbiter(); ++num_waiters; } } @@ -214,6 +236,7 @@ ResultCode KAddressArbiter::SignalAndModifyByWaitingCountIfEqual(VAddr addr, s32 ResultCode KAddressArbiter::WaitIfLessThan(VAddr addr, s32 value, bool decrement, s64 timeout) { // Prepare to wait. KThread* cur_thread = kernel.CurrentScheduler()->GetCurrentThread(); + ThreadQueueImplForKAddressArbiter wait_queue(kernel, std::addressof(thread_tree)); { KScopedSchedulerLockAndSleep slp{kernel, cur_thread, timeout}; @@ -224,9 +247,6 @@ ResultCode KAddressArbiter::WaitIfLessThan(VAddr addr, s32 value, bool decrement return ResultTerminationRequested; } - // Set the synced object. - cur_thread->SetWaitResult(ResultTimedOut); - // Read the value from userspace. s32 user_value{}; bool succeeded{}; @@ -256,23 +276,12 @@ ResultCode KAddressArbiter::WaitIfLessThan(VAddr addr, s32 value, bool decrement // Set the arbiter. cur_thread->SetAddressArbiter(&thread_tree, addr); thread_tree.insert(*cur_thread); - cur_thread->SetState(ThreadState::Waiting); + + // Wait for the thread to finish. + cur_thread->BeginWait(std::addressof(wait_queue)); cur_thread->SetWaitReasonForDebugging(ThreadWaitReasonForDebugging::Arbitration); } - // Cancel the timer wait. - kernel.TimeManager().UnscheduleTimeEvent(cur_thread); - - // Remove from the address arbiter. - { - KScopedSchedulerLock sl(kernel); - - if (cur_thread->IsWaitingForAddressArbiter()) { - thread_tree.erase(thread_tree.iterator_to(*cur_thread)); - cur_thread->ClearAddressArbiter(); - } - } - // Get the result. return cur_thread->GetWaitResult(); } @@ -280,6 +289,7 @@ ResultCode KAddressArbiter::WaitIfLessThan(VAddr addr, s32 value, bool decrement ResultCode KAddressArbiter::WaitIfEqual(VAddr addr, s32 value, s64 timeout) { // Prepare to wait. KThread* cur_thread = kernel.CurrentScheduler()->GetCurrentThread(); + ThreadQueueImplForKAddressArbiter wait_queue(kernel, std::addressof(thread_tree)); { KScopedSchedulerLockAndSleep slp{kernel, cur_thread, timeout}; @@ -290,9 +300,6 @@ ResultCode KAddressArbiter::WaitIfEqual(VAddr addr, s32 value, s64 timeout) { return ResultTerminationRequested; } - // Set the synced object. - cur_thread->SetWaitResult(ResultTimedOut); - // Read the value from userspace. s32 user_value{}; if (!ReadFromUser(system, &user_value, addr)) { @@ -315,23 +322,12 @@ ResultCode KAddressArbiter::WaitIfEqual(VAddr addr, s32 value, s64 timeout) { // Set the arbiter. cur_thread->SetAddressArbiter(&thread_tree, addr); thread_tree.insert(*cur_thread); - cur_thread->SetState(ThreadState::Waiting); + + // Wait for the thread to finish. + cur_thread->BeginWait(std::addressof(wait_queue)); cur_thread->SetWaitReasonForDebugging(ThreadWaitReasonForDebugging::Arbitration); } - // Cancel the timer wait. - kernel.TimeManager().UnscheduleTimeEvent(cur_thread); - - // Remove from the address arbiter. - { - KScopedSchedulerLock sl(kernel); - - if (cur_thread->IsWaitingForAddressArbiter()) { - thread_tree.erase(thread_tree.iterator_to(*cur_thread)); - cur_thread->ClearAddressArbiter(); - } - } - // Get the result. return cur_thread->GetWaitResult(); }