xref: /aosp_15_r20/external/pytorch/caffe2/utils/threadpool/pthreadpool_impl.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include "caffe2/utils/threadpool/pthreadpool.h"
2 #include "caffe2/utils/threadpool/pthreadpool-cpp.h"
3 #include "caffe2/utils/threadpool/ThreadPool.h"
4 
#if defined(USE_PTHREADPOOL)
namespace caffe2 {

namespace {
// Per-thread selector read by legacy_pthreadpool_compute_1d:
// true  -> dispatch through pthreadpool_parallelize_1d,
// false -> dispatch through caffe2::ThreadPool::run.
// The unnamed namespace already gives internal linkage, so no `static`.
thread_local bool using_new_threadpool{false};
} // namespace

// Overrides the calling thread's selector with `use_new_threadpool`,
// remembering the previous value in `use_new_threadpool_` so the destructor
// can put it back.
WithCastToNewThreadPool::WithCastToNewThreadPool(bool use_new_threadpool) {
  const bool previous = using_new_threadpool;
  using_new_threadpool = use_new_threadpool;
  use_new_threadpool_ = previous;
}

// Restores the selector value that was in effect before construction.
WithCastToNewThreadPool::~WithCastToNewThreadPool() {
  using_new_threadpool = use_new_threadpool_;
}

} // namespace caffe2
#endif
19 
20 //
21 // External API
22 //
23 
legacy_pthreadpool_compute_1d(legacy_pthreadpool_t threadpool,legacy_pthreadpool_function_1d_t function,void * argument,size_t range)24 void legacy_pthreadpool_compute_1d(
25     legacy_pthreadpool_t threadpool,
26     legacy_pthreadpool_function_1d_t function,
27     void* argument,
28     size_t range) {
29   if (threadpool == nullptr) {
30     /* No thread pool provided: execute function sequentially on the calling
31      * thread */
32     for (size_t i = 0; i < range; i++) {
33       function(argument, i);
34     }
35     return;
36   }
37 #if defined(USE_PTHREADPOOL)
38   if (caffe2::using_new_threadpool) {
39     pthreadpool_parallelize_1d(threadpool, function, argument, range, 0u);
40   } else {
41     reinterpret_cast<caffe2::ThreadPool*>(threadpool)
42         ->run(
43             [function, argument](int threadId, size_t workId) {
44               function(argument, workId);
45             },
46             range);
47   }
48 #else
49   reinterpret_cast<caffe2::ThreadPool*>(threadpool)
50       ->run(
51           [function, argument](int threadId, size_t workId) {
52             function(argument, workId);
53           },
54           range);
55 #endif
56 }
57 
legacy_pthreadpool_parallelize_1d(const legacy_pthreadpool_t threadpool,const legacy_pthreadpool_function_1d_t function,void * const argument,const size_t range,uint32_t)58 void legacy_pthreadpool_parallelize_1d(
59     const legacy_pthreadpool_t threadpool,
60     const legacy_pthreadpool_function_1d_t function,
61     void* const argument,
62     const size_t range,
63     uint32_t) {
64   legacy_pthreadpool_compute_1d(threadpool, function, argument, range);
65 }
66 
legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool)67 size_t legacy_pthreadpool_get_threads_count(legacy_pthreadpool_t threadpool) {
68   // The current fix only useful when XNNPACK calls legacy_pthreadpool_get_threads_count with nullptr.
69   if (threadpool == nullptr) {
70     return 1;
71   }
72   return reinterpret_cast<caffe2::ThreadPool*>(threadpool)->getNumThreads();
73 }
74 
legacy_pthreadpool_create(size_t threads_count)75 legacy_pthreadpool_t legacy_pthreadpool_create(size_t threads_count) {
76   std::mutex thread_pool_creation_mutex_;
77   std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
78 
79   return reinterpret_cast<legacy_pthreadpool_t>(caffe2::ThreadPool::createThreadPool(threads_count));
80 }
81 
legacy_pthreadpool_destroy(legacy_pthreadpool_t pthreadpool)82 void legacy_pthreadpool_destroy(legacy_pthreadpool_t pthreadpool) {
83   if (pthreadpool) {
84     caffe2::ThreadPool* threadpool =
85         reinterpret_cast<caffe2::ThreadPool*>(pthreadpool);
86     delete threadpool;
87   }
88 }
89