diff --git a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp index 5980ba3..a254db9 100644 --- a/ktransformers/ktransformers_ext/cpu_backend/backend.cpp +++ b/ktransformers/ktransformers_ext/cpu_backend/backend.cpp @@ -54,7 +54,12 @@ void Backend::do_work_stealing_job(int task_num, init_func_ = init_func; compute_func_ = compute_func; finalize_func_ = finalize_func; +#ifdef USE_NUMA + // numa node location will be calculated based on the number of threads + thread_num_ = max_thread_num_; +#else thread_num_ = std::min(max_thread_num_, task_num); +#endif int base = task_num / thread_num_; int remain = task_num % thread_num_; thread_state_[0].end = base + (0 < remain); @@ -146,4 +151,4 @@ void Backend::worker_thread(int thread_id) { return; } } -} \ No newline at end of file +}