diff --git a/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp b/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp index fb7ac4f..5d20d1e 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp +++ b/ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp @@ -16,17 +16,23 @@ TaskQueue::TaskQueue() { } TaskQueue::~TaskQueue() { - exit_flag.store(true, std::memory_order_seq_cst); + { + std::unique_lock lock(mutex); + exit_flag.store(true, std::memory_order_seq_cst); + } + cv.notify_all(); if (worker.joinable()) { worker.join(); } } void TaskQueue::enqueue(std::function task) { - mutex.lock(); - tasks.push(task); - sync_flag.store(false, std::memory_order_seq_cst); - mutex.unlock(); + { + std::unique_lock lock(mutex); + tasks.push(task); + sync_flag.store(false, std::memory_order_seq_cst); + } + cv.notify_one(); } void TaskQueue::sync() { @@ -36,22 +42,22 @@ void TaskQueue::sync() { void TaskQueue::processTasks() { while (true) { - mutex.lock(); - if (tasks.empty()) { - if (exit_flag.load(std::memory_order_seq_cst)) { + std::function task; + { + std::unique_lock lock(mutex); + cv.wait(lock, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); }); + if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) { return; } - mutex.unlock(); - continue; + task = tasks.front(); + tasks.pop(); } - std::function task = tasks.front(); - mutex.unlock(); task(); - mutex.lock(); - tasks.pop(); - if (tasks.empty()) { - sync_flag.store(true, std::memory_order_seq_cst); + { + std::lock_guard lock(mutex); + if (tasks.empty()) { + sync_flag.store(true, std::memory_order_seq_cst); + } } - mutex.unlock(); } -} \ No newline at end of file +} diff --git a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h index a633a40..5325dcc 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/task_queue.h +++ b/ktransformers/ktransformers_ext/cpu_backend/task_queue.h @@ -69,8 +69,9 @@ class TaskQueue { void processTasks(); std::queue> tasks; + std::mutex mutex; + std::condition_variable cv; std::thread worker; - custom_mutex mutex; std::atomic sync_flag; std::atomic exit_flag; };