From da34d3704405665b68d3d992f37a7eeb541238af Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Tue, 25 May 2021 20:37:06 -0300 Subject: [PATCH] common/thread_worker: Add support for stateful threads --- src/common/CMakeLists.txt | 1 - src/common/thread_worker.cpp | 66 ------------------------ src/common/thread_worker.h | 97 ++++++++++++++++++++++++++++++++---- 3 files changed, 86 insertions(+), 78 deletions(-) delete mode 100644 src/common/thread_worker.cpp diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index c05b78cd57..e03fffd8d0 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -180,7 +180,6 @@ add_library(common STATIC thread.cpp thread.h thread_queue_list.h - thread_worker.cpp thread_worker.h threadsafe_queue.h time_zone.cpp diff --git a/src/common/thread_worker.cpp b/src/common/thread_worker.cpp deleted file mode 100644 index 32be49b154..0000000000 --- a/src/common/thread_worker.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2020 yuzu emulator team -// Licensed under GPLv2 or any later version -// Refer to the license.txt file included. - -#include "common/thread.h" -#include "common/thread_worker.h" - -namespace Common { - -ThreadWorker::ThreadWorker(std::size_t num_workers, const std::string& name) { - workers_queued.store(static_cast(num_workers), std::memory_order_release); - const auto lambda = [this, thread_name{std::string{name}}] { - Common::SetCurrentThreadName(thread_name.c_str()); - - while (!stop) { - UniqueFunction task; - { - std::unique_lock lock{queue_mutex}; - if (requests.empty()) { - wait_condition.notify_all(); - } - condition.wait(lock, [this] { return stop || !requests.empty(); }); - if (stop) { - break; - } - task = std::move(requests.front()); - requests.pop(); - } - task(); - work_done++; - } - workers_stopped++; - wait_condition.notify_all(); - }; - for (size_t i = 0; i < num_workers; ++i) { - threads.emplace_back(lambda); - } -} - -ThreadWorker::~ThreadWorker() { - { - std::unique_lock lock{queue_mutex}; - stop = true; - } - condition.notify_all(); - for (std::thread& thread : threads) { - thread.join(); - } -} - -void ThreadWorker::QueueWork(UniqueFunction work) { - { - std::unique_lock lock{queue_mutex}; - requests.emplace(std::move(work)); - work_scheduled++; - } - condition.notify_one(); -} - -void ThreadWorker::WaitForRequests() { - std::unique_lock lock{queue_mutex}; - wait_condition.wait( - lock, [this] { return workers_stopped >= workers_queued || work_done >= work_scheduled; }); -} - -} // namespace Common diff --git a/src/common/thread_worker.h b/src/common/thread_worker.h index 12bbf5fef4..16aa673bd0 100644 --- a/src/common/thread_worker.h +++ b/src/common/thread_worker.h @@ -8,32 +8,107 @@ #include #include #include +#include #include #include -#include "common/common_types.h" +#include "common/thread.h" #include "common/unique_function.h" namespace Common { -class ThreadWorker final { +template +class StatefulThreadWorker { + static constexpr bool with_state = !std::is_same_v; + + struct DummyCallable { + int operator()() const noexcept { + return 0; + } + }; + + using Task = + std::conditional_t, UniqueFunction>; + using StateMaker = std::conditional_t, DummyCallable>; + public: - explicit ThreadWorker(std::size_t num_workers, const std::string& name); - ~ThreadWorker(); - void QueueWork(UniqueFunction work); - void WaitForRequests(); + explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {}) + : workers_queued{num_workers}, thread_name{std::move(name)} { + const auto lambda = [this, func] { + Common::SetCurrentThreadName(thread_name.c_str()); + { + std::conditional_t state{func()}; + while (!stop) { + Task task; + { + std::unique_lock lock{queue_mutex}; + if (requests.empty()) { + wait_condition.notify_all(); + } + condition.wait(lock, [this] { return stop || !requests.empty(); }); + if (stop) { + break; + } + task = std::move(requests.front()); + requests.pop(); + } + if constexpr (with_state) { + task(&state); + } else { + task(); + } + ++work_done; + } + } + ++workers_stopped; + wait_condition.notify_all(); + }; + for (size_t i = 0; i < num_workers; ++i) { + threads.emplace_back(lambda); + } + } + + ~StatefulThreadWorker() { + { + std::unique_lock lock{queue_mutex}; + stop = true; + } + condition.notify_all(); + for (std::thread& thread : threads) { + thread.join(); + } + } + + void QueueWork(Task work) { + { + std::unique_lock lock{queue_mutex}; + requests.emplace(std::move(work)); + ++work_scheduled; + } + condition.notify_one(); + } + + void WaitForRequests() { + std::unique_lock lock{queue_mutex}; + wait_condition.wait(lock, [this] { + return workers_stopped >= workers_queued || work_done >= work_scheduled; + }); + } private: std::vector threads; - std::queue> requests; + std::queue requests; std::mutex queue_mutex; std::condition_variable condition; std::condition_variable wait_condition; std::atomic_bool stop{}; - std::atomic work_scheduled{}; - std::atomic work_done{}; - std::atomic workers_stopped{}; - std::atomic workers_queued{}; + std::atomic work_scheduled{}; + std::atomic work_done{}; + std::atomic workers_stopped{}; + std::atomic workers_queued{}; + std::string thread_name; }; +using ThreadWorker = StatefulThreadWorker<>; + } // namespace Common